Adds support for Splittable DoFn (SDF) in the Universal Local Runner (ULR) and the Java SDK.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index db25581..b2f8195 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -28,6 +28,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -93,7 +94,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) throws IOException {
 
       BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
           .setPrimitiveTransformReference(pTransformId)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
index 34d704d..a4f6e54 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -84,7 +85,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) throws IOException {
       BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
           .setPrimitiveTransformReference(pTransformId)
           .setName(getOnlyElement(pTransform.getInputsMap().keySet()))
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
index 5df178f..185fc11 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
@@ -28,6 +28,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -81,7 +82,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) {
 
       ImmutableList.Builder<FnDataReceiver<WindowedValue<?>>> consumers = ImmutableList.builder();
       for (String pCollectionId : pTransform.getOutputsMap().values()) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
new file mode 100644
index 0000000..9cbe34c
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.SideInputSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/**
+ * A {@link PTransformRunnerFactory} for transforms invoking a {@link DoFn}.
+ *
+ * <p>Subclasses only implement {@link #createRunner}; this base class decodes the ParDo payload
+ * into a {@link Context} and registers the runner's bundle callbacks with the harness.
+ */
+abstract class DoFnPTransformRunnerFactory<
+        TransformInputT,
+        FnInputT,
+        OutputT,
+        RunnerT extends DoFnPTransformRunnerFactory.DoFnPTransformRunner<TransformInputT>>
+    implements PTransformRunnerFactory<RunnerT> {
+  /** Bundle lifecycle callbacks that every runner created by this factory provides. */
+  interface DoFnPTransformRunner<T> {
+    /** Registered with the harness as this transform's bundle start function. */
+    void startBundle() throws Exception;
+
+    /** Receives each element of every main (non side) input of the transform. */
+    void processElement(WindowedValue<T> input) throws Exception;
+
+    /** Registered with the harness as this transform's bundle finish function. */
+    void finishBundle() throws Exception;
+  }
+
+  @Override
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  public final RunnerT createRunnerForPTransform(
+      PipelineOptions pipelineOptions,
+      BeamFnDataClient beamFnDataClient,
+      BeamFnStateClient beamFnStateClient,
+      String ptransformId,
+      PTransform pTransform,
+      Supplier<String> processBundleInstructionId,
+      Map<String, PCollection> pCollections,
+      Map<String, RunnerApi.Coder> coders,
+      Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+      Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+      Consumer<ThrowingRunnable> addStartFunction,
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener) {
+    Context<FnInputT, OutputT> context =
+        new Context<>(
+            pipelineOptions,
+            beamFnStateClient,
+            ptransformId,
+            pTransform,
+            processBundleInstructionId,
+            pCollections,
+            coders,
+            windowingStrategies,
+            pCollectionIdsToConsumers,
+            splitListener);
+
+    RunnerT runner = createRunner(context);
+
+    // Register the appropriate handlers. Every input that is not a side input is a main input
+    // whose elements are delivered to processElement.
+    addStartFunction.accept(runner::startBundle);
+    Iterable<String> mainInput =
+        Sets.difference(
+            pTransform.getInputsMap().keySet(), context.parDoPayload.getSideInputsMap().keySet());
+    for (String localInputName : mainInput) {
+      pCollectionIdsToConsumers.put(
+          pTransform.getInputsOrThrow(localInputName),
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<TransformInputT>>) runner::processElement);
+    }
+    addFinishFunction.accept(runner::finishBundle);
+    return runner;
+  }
+
+  /** Builds the runner for one transform from its decoded {@link Context}. */
+  abstract RunnerT createRunner(Context<FnInputT, OutputT> context);
+
+  /** Everything decoded from a ParDo {@link PTransform} that {@link #createRunner} may need. */
+  static class Context<InputT, OutputT> {
+    final PipelineOptions pipelineOptions;
+    final BeamFnStateClient beamFnStateClient;
+    final String ptransformId;
+    final PTransform pTransform;
+    final Supplier<String> processBundleInstructionId;
+    final RehydratedComponents rehydratedComponents;
+    final DoFn<InputT, OutputT> doFn;
+    final TupleTag<OutputT> mainOutputTag;
+    /** Coder of the main input PCollection; may be a {@link WindowedValueCoder}. */
+    final Coder<?> inputCoder;
+    /** Key coder of the main input elements when they are KVs, otherwise null. */
+    final Coder<?> keyCoder;
+    final Coder<? extends BoundedWindow> windowCoder;
+    final WindowingStrategy<InputT, ?> windowingStrategy;
+    /** Side input specifications, keyed by each side input's local name. */
+    final Map<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap;
+    final ParDoPayload parDoPayload;
+    /** Downstream consumers of each output, keyed by the output's local name. */
+    final ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer;
+    final BundleSplitListener splitListener;
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    Context(
+        PipelineOptions pipelineOptions,
+        BeamFnStateClient beamFnStateClient,
+        String ptransformId,
+        PTransform pTransform,
+        Supplier<String> processBundleInstructionId,
+        Map<String, PCollection> pCollections,
+        Map<String, RunnerApi.Coder> coders,
+        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+        BundleSplitListener splitListener) {
+      this.pipelineOptions = pipelineOptions;
+      this.beamFnStateClient = beamFnStateClient;
+      this.ptransformId = ptransformId;
+      this.pTransform = pTransform;
+      this.processBundleInstructionId = processBundleInstructionId;
+      this.splitListener = splitListener;
+      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMapBuilder =
+          ImmutableMap.builder();
+      try {
+        rehydratedComponents =
+            RehydratedComponents.forComponents(
+                RunnerApi.Components.newBuilder()
+                    .putAllCoders(coders)
+                    .putAllWindowingStrategies(windowingStrategies)
+                    .build());
+        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
+        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
+        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
+        String mainInputTag =
+            Iterables.getOnlyElement(
+                Sets.difference(
+                    pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
+        PCollection mainInput =
+            lookupPCollection(
+                pCollections, pTransform.getInputsOrThrow(mainInputTag), ptransformId);
+        inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
+        // The main input coder may arrive wrapped in a WindowedValueCoder; unwrap it before
+        // looking for a KV key coder.
+        // TODO: Stop passing windowed value coders within PCollections.
+        Coder<?> elementCoder =
+            inputCoder instanceof WindowedValueCoder
+                ? ((WindowedValueCoder) inputCoder).getValueCoder()
+                : inputCoder;
+        keyCoder = elementCoder instanceof KvCoder ? ((KvCoder) elementCoder).getKeyCoder() : null;
+
+        windowingStrategy =
+            (WindowingStrategy)
+                rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+        windowCoder = windowingStrategy.getWindowFn().windowCoder();
+
+        // Build the map from tag id to side input specification. Only multimap materializations
+        // are supported by this SDK.
+        for (Map.Entry<String, RunnerApi.SideInput> entry :
+            parDoPayload.getSideInputsMap().entrySet()) {
+          String sideInputTag = entry.getKey();
+          RunnerApi.SideInput sideInput = entry.getValue();
+          checkArgument(
+              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+                  sideInput.getAccessPattern().getUrn()),
+              "This SDK is only capable of dealing with %s materializations "
+                  + "but was asked to handle %s for PCollectionView with tag %s.",
+              Materializations.MULTIMAP_MATERIALIZATION_URN,
+              sideInput.getAccessPattern().getUrn(),
+              sideInputTag);
+
+          PCollection sideInputPCollection =
+              lookupPCollection(
+                  pCollections, pTransform.getInputsOrThrow(sideInputTag), ptransformId);
+          WindowingStrategy sideInputWindowingStrategy =
+              rehydratedComponents.getWindowingStrategy(
+                  sideInputPCollection.getWindowingStrategyId());
+          tagToSideInputSpecMapBuilder.put(
+              new TupleTag<>(sideInputTag),
+              SideInputSpec.create(
+                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
+                  PCollectionViewTranslation.viewFnFromProto(sideInput.getViewFn()),
+                  PCollectionViewTranslation.windowMappingFnFromProto(
+                      sideInput.getWindowMappingFn())));
+        }
+      } catch (IOException exn) {
+        // May be thrown while parsing the payload or while rehydrating coders and windowing
+        // strategies, so identify the transform to make the failure actionable.
+        throw new IllegalArgumentException(
+            String.format("Malformed ParDoPayload or components for PTransform %s", ptransformId),
+            exn);
+      }
+
+      // For every output, collect the downstream consumers of the PCollection it produces.
+      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
+          tagToConsumerBuilder = ImmutableListMultimap.builder();
+      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
+        tagToConsumerBuilder.putAll(
+            new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
+      }
+      tagToConsumer = tagToConsumerBuilder.build();
+      tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build();
+    }
+
+    /** Returns the PCollection with the given id, failing descriptively if it is unknown. */
+    private static PCollection lookupPCollection(
+        Map<String, PCollection> pCollections, String pCollectionId, String ptransformId) {
+      PCollection pCollection = pCollections.get(pCollectionId);
+      checkArgument(
+          pCollection != null,
+          "PTransform %s references unknown PCollection %s.",
+          ptransformId,
+          pCollectionId);
+      return pCollection;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
index 52d9435..6ac860e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -69,7 +70,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction)
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener)
         throws IOException {
 
       // Give each input a MultiplexingFnDataReceiver to all outputs of the flatten.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 9778aac..4cfb5d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,75 +17,28 @@
  */
 package org.apache.beam.fn.harness;
 
-import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.service.AutoService;
-import com.google.auto.value.AutoValue;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.ListMultimap;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Sets;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.state.BagUserState;
-import org.apache.beam.fn.harness.state.BeamFnStateClient;
-import org.apache.beam.fn.harness.state.MultimapSideInput;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
-import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
 import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.core.construction.RehydratedComponents;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.fn.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
-import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateBinder;
-import org.apache.beam.sdk.state.StateContext;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFn.StartBundleContext;
 import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
-import org.apache.beam.sdk.transforms.Materializations;
-import org.apache.beam.sdk.transforms.ViewFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -94,399 +47,145 @@
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
-import org.apache.beam.sdk.util.CombineFnUtil;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
-import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
 import org.joda.time.Instant;
 
 /**
- * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers
- * of abstraction caused by StateInternals/TimerInternals since they model state and timer
- * concepts differently.
+ * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers of
+ * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
+ * differently.
  */
-public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
-
-  private final ProcessBundleContext processContext;
-  private final FinishBundleContext finishBundleContext;
-  private StartBundleContext startBundleContext;
-
-  /**
-   * A registrar which provides a factory to handle Java {@link DoFn}s.
-   */
+public class FnApiDoFnRunner<InputT, OutputT>
+    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<InputT> {
+  /** A registrar which provides a factory to handle Java {@link DoFn}s. */
   @AutoService(PTransformRunnerFactory.Registrar.class)
-  public static class Registrar implements
-      PTransformRunnerFactory.Registrar {
-
+  public static class Registrar implements PTransformRunnerFactory.Registrar {
     @Override
     public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
-      return ImmutableMap.of(
-          PTransformTranslation.PAR_DO_TRANSFORM_URN, new NewFactory(),
-          ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory());
+      return ImmutableMap.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, new Factory());
     }
   }
 
-  /** A factory for {@link FnApiDoFnRunner}. */
   static class Factory<InputT, OutputT>
-      implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
+      extends DoFnPTransformRunnerFactory<
+          InputT, InputT, OutputT, FnApiDoFnRunner<InputT, OutputT>> {
     @Override
-    public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
-        PipelineOptions pipelineOptions,
-        BeamFnDataClient beamFnDataClient,
-        BeamFnStateClient beamFnStateClient,
-        String ptransformId,
-        PTransform pTransform,
-        Supplier<String> processBundleInstructionId,
-        Map<String, PCollection> pCollections,
-        Map<String, RunnerApi.Coder> coders,
-        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
-        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
-        Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
-
-      // For every output PCollection, create a map from output name to Consumer
-      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
-          tagToOutputMapBuilder = ImmutableListMultimap.builder();
-      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
-        tagToOutputMapBuilder.putAll(
-            new TupleTag<>(entry.getKey()),
-            pCollectionIdsToConsumers.get(entry.getValue()));
-      }
-      ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap =
-          tagToOutputMapBuilder.build();
-
-      // Get the DoFnInfo from the serialized blob.
-      ByteString serializedFn = pTransform.getSpec().getPayload();
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray(
-          serializedFn.toByteArray(), "DoFnInfo");
-
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnRunner<InputT, OutputT> runner =
-          new FnApiDoFnRunner<>(
-              pipelineOptions,
-              beamFnStateClient,
-              ptransformId,
-              processBundleInstructionId,
-              doFnInfo.getDoFn(),
-              doFnInfo.getInputCoder(),
-              (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
-                  (Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
-              tagToOutputMap,
-              ImmutableMap.of(),
-              doFnInfo.getWindowingStrategy());
-
-      registerHandlers(
-          runner,
-          pTransform,
-          ImmutableSet.of(),
-          addStartFunction,
-          addFinishFunction,
-          pCollectionIdsToConsumers);
-      return runner;
+    public FnApiDoFnRunner<InputT, OutputT> createRunner(Context<InputT, OutputT> context) {
+      return new FnApiDoFnRunner<>(context);
     }
   }
 
-  static class NewFactory<InputT, OutputT>
-      implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
-    @Override
-    public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
-        PipelineOptions pipelineOptions,
-        BeamFnDataClient beamFnDataClient,
-        BeamFnStateClient beamFnStateClient,
-        String ptransformId,
-        RunnerApi.PTransform pTransform,
-        Supplier<String> processBundleInstructionId,
-        Map<String, RunnerApi.PCollection> pCollections,
-        Map<String, RunnerApi.Coder> coders,
-        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
-        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
-        Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
-
-      DoFn<InputT, OutputT> doFn;
-      TupleTag<OutputT> mainOutputTag;
-      Coder<InputT> inputCoder;
-      WindowingStrategy<InputT, ?> windowingStrategy;
-
-      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
-          ImmutableMap.builder();
-      ParDoPayload parDoPayload;
-      try {
-        RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(
-            RunnerApi.Components.newBuilder()
-                .putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
-        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
-        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
-        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
-        String mainInputTag = Iterables.getOnlyElement(Sets.difference(
-            pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
-        RunnerApi.PCollection mainInput =
-            pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
-        inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
-            mainInput.getCoderId());
-        windowingStrategy = (WindowingStrategy) rehydratedComponents.getWindowingStrategy(
-            mainInput.getWindowingStrategyId());
-
-        // Build the map from tag id to side input specification
-        for (Map.Entry<String, RunnerApi.SideInput> entry
-            : parDoPayload.getSideInputsMap().entrySet()) {
-          String sideInputTag = entry.getKey();
-          RunnerApi.SideInput sideInput = entry.getValue();
-          checkArgument(
-              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
-                  sideInput.getAccessPattern().getUrn()),
-              "This SDK is only capable of dealing with %s materializations "
-                  + "but was asked to handle %s for PCollectionView with tag %s.",
-              Materializations.MULTIMAP_MATERIALIZATION_URN,
-              sideInput.getAccessPattern().getUrn(),
-              sideInputTag);
-
-          RunnerApi.PCollection sideInputPCollection =
-              pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
-          WindowingStrategy sideInputWindowingStrategy =
-              rehydratedComponents.getWindowingStrategy(
-                  sideInputPCollection.getWindowingStrategyId());
-          tagToSideInputSpecMap.put(
-              new TupleTag<>(entry.getKey()),
-              SideInputSpec.create(
-                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
-                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
-                  PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
-                  PCollectionViewTranslation.windowMappingFnFromProto(
-                      entry.getValue().getWindowMappingFn())));
-        }
-      } catch (InvalidProtocolBufferException exn) {
-        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
-      } catch (IOException exn) {
-        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
-      }
-
-      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
-          tagToConsumerBuilder = ImmutableListMultimap.builder();
-      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
-        tagToConsumerBuilder.putAll(
-            new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
-      }
-      ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer =
-          tagToConsumerBuilder.build();
-
-      @SuppressWarnings({"unchecked", "rawtypes"})
-      DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
-          pipelineOptions,
-          beamFnStateClient,
-          ptransformId,
-          processBundleInstructionId,
-          doFn,
-          inputCoder,
-          (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
-              tagToConsumer.get(mainOutputTag),
-          tagToConsumer,
-          tagToSideInputSpecMap.build(),
-          windowingStrategy);
-      registerHandlers(
-          runner,
-          pTransform,
-          parDoPayload.getSideInputsMap().keySet(),
-          addStartFunction,
-          addFinishFunction,
-          pCollectionIdsToConsumers);
-      return runner;
-    }
-  }
-
-  private static <InputT, OutputT> void registerHandlers(
-      DoFnRunner<InputT, OutputT> runner,
-      RunnerApi.PTransform pTransform,
-      Set<String> sideInputLocalNames,
-      Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction,
-      Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers) {
-    // Register the appropriate handlers.
-    addStartFunction.accept(runner::startBundle);
-    for (String localInputName
-        : Sets.difference(pTransform.getInputsMap().keySet(), sideInputLocalNames)) {
-      pCollectionIdsToConsumers.put(
-          pTransform.getInputsOrThrow(localInputName),
-          (FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
-    }
-    addFinishFunction.accept(runner::finishBundle);
-  }
-
   //////////////////////////////////////////////////////////////////////////////////////////////////
 
-  private final PipelineOptions pipelineOptions;
-  private final BeamFnStateClient beamFnStateClient;
-  private final String ptransformId;
-  private final Supplier<String> processBundleInstructionId;
-  private final DoFn<InputT, OutputT> doFn;
-  private final Coder<InputT> inputCoder;
+  private final Context<InputT, OutputT> context;
   private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
-  private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
-  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
-  private final Map<StateKey, Object> stateKeyObjectCache;
-  private final WindowingStrategy windowingStrategy;
+  private FnApiStateAccessor stateAccessor;
   private final DoFnSignature doFnSignature;
   private final DoFnInvoker<InputT, OutputT> doFnInvoker;
-  private final StateBinder stateBinder;
-  private final Collection<ThrowingRunnable> stateFinalizers;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
+  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+  private final ProcessBundleContext processContext;
+  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+  /** Only valid during {@link #processElement}, null otherwise. */
   private WindowedValue<InputT> currentElement;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
   private BoundedWindow currentWindow;
 
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and only when processing a {@link KV} and is null otherwise.
-   */
-  private ByteString encodedCurrentKey;
-
-  /**
-   * The lifetime of this member is only valid during {@link #processElement}
-   * and is null otherwise.
-   */
-  private ByteString encodedCurrentWindow;
-
-  FnApiDoFnRunner(
-      PipelineOptions pipelineOptions,
-      BeamFnStateClient beamFnStateClient,
-      String ptransformId,
-      Supplier<String> processBundleInstructionId,
-      DoFn<InputT, OutputT> doFn,
-      Coder<InputT> inputCoder,
-      Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
-      Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
-      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
-      WindowingStrategy windowingStrategy) {
-    this.pipelineOptions = pipelineOptions;
-    this.beamFnStateClient = beamFnStateClient;
-    this.ptransformId = ptransformId;
-    this.processBundleInstructionId = processBundleInstructionId;
-    this.doFn = doFn;
-    this.inputCoder = inputCoder;
-    this.mainOutputConsumers = mainOutputConsumers;
-    this.outputMap = outputMap;
-    this.sideInputSpecMap = sideInputSpecMap;
-    this.stateKeyObjectCache = new HashMap<>();
-    this.windowingStrategy = windowingStrategy;
-    this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
-    this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
+  FnApiDoFnRunner(Context<InputT, OutputT> context) {
+    this.context = context;
+    this.mainOutputConsumers =
+        (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+            (Collection) context.tagToConsumer.get(context.mainOutputTag);
+    this.doFnSignature = DoFnSignatures.signatureForDoFn(context.doFn);
+    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
     this.doFnInvoker.invokeSetup();
-    this.stateBinder = new BeamFnStateBinder();
-    this.stateFinalizers = new ArrayList<>();
 
-    this.startBundleContext = doFn.new StartBundleContext() {
-      @Override
-      public PipelineOptions getPipelineOptions() {
-        return pipelineOptions;
-      }
-    };
+    this.startBundleContext =
+        this.context.doFn.new StartBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+        };
     this.processContext = new ProcessBundleContext();
-    finishBundleContext = doFn.new FinishBundleContext() {
-      @Override
-      public PipelineOptions getPipelineOptions() {
-        return pipelineOptions;
-      }
+    finishBundleContext =
+        this.context.doFn.new FinishBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
 
-      @Override
-      public void output(OutputT output, Instant timestamp, BoundedWindow window) {
-        outputTo(mainOutputConsumers,
-            WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
-      }
+          @Override
+          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+            outputTo(
+                mainOutputConsumers,
+                WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+          }
 
-      @Override
-      public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
-        Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
-        if (consumers == null) {
-          throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
-        }
-        outputTo(consumers,
-            WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
-      }
-    };
+          @Override
+          public <T> void output(
+              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+            Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+                (Collection) context.tagToConsumer.get(tag);
+            if (consumers == null) {
+              throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
+            }
+            outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+          }
+        };
   }
 
   @Override
   public void startBundle() {
+    this.stateAccessor =
+        new FnApiStateAccessor(
+            context.pipelineOptions,
+            context.ptransformId,
+            context.processBundleInstructionId,
+            context.tagToSideInputSpecMap,
+            context.beamFnStateClient,
+            context.keyCoder,
+            (Coder<BoundedWindow>) context.windowCoder);
+
     doFnInvoker.invokeStartBundle(startBundleContext);
   }
 
   @Override
   public void processElement(WindowedValue<InputT> elem) {
     currentElement = elem;
+    stateAccessor.setCurrentElement(elem);
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
         currentWindow = windowIterator.next();
+        stateAccessor.setCurrentWindow(currentWindow);
         doFnInvoker.invokeProcessElement(processContext);
       }
     } finally {
       currentElement = null;
       currentWindow = null;
-      encodedCurrentKey = null;
-      encodedCurrentWindow = null;
+      stateAccessor.setCurrentElement(null);
+      stateAccessor.setCurrentWindow(null);
     }
   }
 
   @Override
-  public void onTimer(
-      String timerId,
-      BoundedWindow window,
-      Instant timestamp,
-      TimeDomain timeDomain) {
-    throw new UnsupportedOperationException("TODO: Add support for timers");
-  }
-
-  @Override
   public void finishBundle() {
     doFnInvoker.invokeFinishBundle(finishBundleContext);
 
-    // Persist all dirty state cells
-    try {
-      for (ThrowingRunnable runnable : stateFinalizers) {
-        runnable.run();
-      }
-    } catch (InterruptedException e) {
-      Thread.currentThread().interrupt();
-      throw new IllegalStateException(e);
-    } catch (Exception e) {
-      throw new IllegalStateException(e);
-    }
-
     // TODO: Support caching state data across bundle boundaries.
-    stateKeyObjectCache.clear();
+    this.stateAccessor.finalizeState();
+    this.stateAccessor = null;
   }
 
-  @Override
-  public DoFn<InputT, OutputT> getFn() {
-    return doFnInvoker.getFn();
-  }
-
-  /**
-   * Outputs the given element to the specified set of consumers wrapping any exceptions.
-   */
+  /** Outputs the given element to the specified set of consumers, wrapping any exceptions. */
   private <T> void outputTo(
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers,
-      WindowedValue<T> output) {
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
     try {
       for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
         consumer.accept(output);
@@ -499,12 +198,11 @@
   /**
    * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}.
    */
-  private class ProcessBundleContext
-      extends DoFn<InputT, OutputT>.ProcessContext
+  private class ProcessBundleContext extends DoFn<InputT, OutputT>.ProcessContext
       implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
 
     private ProcessBundleContext() {
-      doFn.super();
+      context.doFn.super();
     }
 
     @Override
@@ -577,11 +275,11 @@
       checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
       StateSpec<?> spec;
       try {
-        spec = (StateSpec<?>) stateDeclaration.field().get(doFn);
+        spec = (StateSpec<?>) stateDeclaration.field().get(context.doFn);
       } catch (IllegalAccessException e) {
         throw new RuntimeException(e);
       }
-      return spec.bind(stateId, stateBinder);
+      return spec.bind(stateId, stateAccessor);
     }
 
     @Override
@@ -591,60 +289,51 @@
 
     @Override
     public PipelineOptions getPipelineOptions() {
-      return pipelineOptions;
+      return context.pipelineOptions;
     }
 
     @Override
     public PipelineOptions pipelineOptions() {
-      return pipelineOptions;
+      return context.pipelineOptions;
     }
 
     @Override
     public void output(OutputT output) {
-      outputTo(mainOutputConsumers,
+      outputTo(
+          mainOutputConsumers,
           WindowedValue.of(
-              output,
-              currentElement.getTimestamp(),
-              currentWindow,
-              currentElement.getPane()));
+              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
     }
 
     @Override
     public void outputWithTimestamp(OutputT output, Instant timestamp) {
-      outputTo(mainOutputConsumers,
-          WindowedValue.of(
-              output,
-              timestamp,
-              currentWindow,
-              currentElement.getPane()));
+      outputTo(
+          mainOutputConsumers,
+          WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
     }
 
     @Override
     public <T> void output(TupleTag<T> tag, T output) {
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+          (Collection) context.tagToConsumer.get(tag);
       if (consumers == null) {
         throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
       }
-      outputTo(consumers,
+      outputTo(
+          consumers,
           WindowedValue.of(
-              output,
-              currentElement.getTimestamp(),
-              currentWindow,
-              currentElement.getPane()));
+              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
     }
 
     @Override
     public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
-      Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+          (Collection) context.tagToConsumer.get(tag);
       if (consumers == null) {
         throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
       }
-      outputTo(consumers,
-          WindowedValue.of(
-              output,
-              timestamp,
-              currentWindow,
-              currentElement.getPane()));
+      outputTo(
+          consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
     }
 
     @Override
@@ -654,7 +343,7 @@
 
     @Override
     public <T> T sideInput(PCollectionView<T> view) {
-      return (T) bindSideInputView(view.getTagInternal());
+      return stateAccessor.get(view, currentWindow);
     }
 
     @Override
@@ -672,359 +361,4 @@
       throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
     }
   }
-
-  /**
-   * A {@link StateBinder} that uses the Beam Fn State API to read and write user state.
-   *
-   * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that
-   * {@link #bindWatermark} should never be implemented.
-   */
-  private class BeamFnStateBinder implements StateBinder {
-    @Override
-    public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
-      return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              return new ValueState<T>() {
-                private final BagUserState<T> impl = createBagUserState(id, coder);
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-
-                @Override
-                public void write(T input) {
-                  impl.clear();
-                  impl.append(input);
-                }
-
-                @Override
-                public T read() {
-                  Iterator<T> value = impl.get().iterator();
-                  if (value.hasNext()) {
-                    return value.next();
-                  } else {
-                    return null;
-                  }
-                }
-
-                @Override
-                public ValueState<T> readLater() {
-                  // TODO: Support prefetching.
-                  return this;
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
-      return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              return new BagState<T>() {
-                private final BagUserState<T> impl = createBagUserState(id, elemCoder);
-
-                @Override
-                public void add(T value) {
-                  impl.append(value);
-                }
-
-                @Override
-                public ReadableState<Boolean> isEmpty() {
-                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
-                }
-
-                @Override
-                public Iterable<T> read() {
-                  return impl.get();
-                }
-
-                @Override
-                public BagState<T> readLater() {
-                  // TODO: Support prefetching.
-                  return this;
-                }
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
-      throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
-    }
-
-    @Override
-    public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id,
-        StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder,
-        Coder<ValueT> mapValueCoder) {
-      throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
-    }
-
-    @Override
-    public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
-        String id,
-        StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, Coder<AccumT> accumCoder,
-        CombineFn<ElementT, AccumT, ResultT> combineFn) {
-      return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          new Function<StateKey, Object>() {
-            @Override
-            public Object apply(StateKey key) {
-              // TODO: Support squashing accumulators depending on whether we know of all
-              // remote accumulators and local accumulators or just local accumulators.
-              return new CombiningState<ElementT, AccumT, ResultT>() {
-                private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
-
-                @Override
-                public AccumT getAccum() {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-                  if (iterator.hasNext()) {
-                    return iterator.next();
-                  }
-                  return combineFn.createAccumulator();
-                }
-
-                @Override
-                public void addAccum(AccumT accum) {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-
-                  // Only merge if there was a prior value
-                  if (iterator.hasNext()) {
-                    accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
-                    // Since there was a prior value, we need to clear.
-                    impl.clear();
-                  }
-
-                  impl.append(accum);
-                }
-
-                @Override
-                public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
-                  return combineFn.mergeAccumulators(accumulators);
-                }
-
-                @Override
-                public CombiningState<ElementT, AccumT, ResultT> readLater() {
-                  return this;
-                }
-
-                @Override
-                public ResultT read() {
-                  Iterator<AccumT> iterator = impl.get().iterator();
-                  if (iterator.hasNext()) {
-                    return combineFn.extractOutput(iterator.next());
-                  }
-                  return combineFn.defaultValue();
-                }
-
-                @Override
-                public void add(ElementT value) {
-                  AccumT newAccumulator = combineFn.addInput(getAccum(), value);
-                  impl.clear();
-                  impl.append(newAccumulator);
-                }
-
-                @Override
-                public ReadableState<Boolean> isEmpty() {
-                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
-                }
-
-                @Override
-                public void clear() {
-                  impl.clear();
-                }
-              };
-            }
-          });
-    }
-
-    @Override
-    public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT>
-    bindCombiningWithContext(
-        String id,
-        StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
-        Coder<AccumT> accumCoder,
-        CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
-      return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
-          createBagUserStateKey(id),
-          key -> bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn,
-              new StateContext<BoundedWindow>() {
-                @Override
-                public PipelineOptions getPipelineOptions() {
-                  return pipelineOptions;
-                }
-
-                @Override
-                public <T> T sideInput(PCollectionView<T> view) {
-                  return (T) bindSideInputView(view.getTagInternal());
-                }
-
-                @Override
-                public BoundedWindow window() {
-                  return currentWindow;
-                }
-              })));
-    }
-
-    /**
-     * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing
-     * and is waiting on resolution of BEAM-2535.
-     */
-    @Override
-    @Deprecated
-    public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec,
-        TimestampCombiner timestampCombiner) {
-      throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
-    }
-
-    private <T> BagUserState<T> createBagUserState(
-        String stateId, Coder<T> valueCoder) {
-      BagUserState<T> rval = new BagUserState<>(
-          beamFnStateClient,
-          processBundleInstructionId.get(),
-          ptransformId,
-          stateId,
-          encodedCurrentWindow,
-          encodedCurrentKey,
-          valueCoder);
-      stateFinalizers.add(rval::asyncClose);
-      return rval;
-    }
-  }
-
-  private StateKey createBagUserStateKey(String stateId) {
-    cacheEncodedKeyAndWindowForKeyedContext();
-    StateKey.Builder builder = StateKey.newBuilder();
-    builder.getBagUserStateBuilder()
-        .setWindow(encodedCurrentWindow)
-        .setKey(encodedCurrentKey)
-        .setPtransformId(ptransformId)
-        .setUserStateId(stateId);
-    return builder.build();
-  }
-
-  /**
-   * Memoizes an encoded key and window for the current element being processed saving on the
-   * encoding cost of the key and window across multiple state cells for the lifetime of
-   * {@link #processElement}.
-   *
-   * <p>This should only be called during {@link #processElement}.
-   */
-  private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
-    if (encodedCurrentKey == null) {
-      checkState(currentElement.getValue() instanceof KV,
-          "Accessing state in unkeyed context. Current element is not a KV: %s.",
-          currentElement);
-      checkState(
-          // TODO: Stop passing windowed value coders within PCollections.
-          inputCoder instanceof KvCoder
-              || (inputCoder instanceof WindowedValueCoder
-              && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder)),
-          "Accessing state in unkeyed context. Keyed coder expected but found %s.",
-          inputCoder);
-
-      ByteString.Output encodedKeyOut = ByteString.newOutput();
-
-      Coder<K> keyCoder = inputCoder instanceof WindowedValueCoder
-          ? ((KvCoder<K, ?>) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
-          : ((KvCoder<K, ?>) inputCoder).getKeyCoder();
-      try {
-        keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
-      } catch (IOException e) {
-        throw new IllegalStateException(e);
-      }
-      encodedCurrentKey = encodedKeyOut.toByteString();
-    }
-
-    if (encodedCurrentWindow == null) {
-      ByteString.Output encodedWindowOut = ByteString.newOutput();
-      try {
-        windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut);
-      } catch (IOException e) {
-        throw new IllegalStateException(e);
-      }
-      encodedCurrentWindow = encodedWindowOut.toByteString();
-    }
-  }
-
-  /**
-   * A specification for side inputs containing a value {@link Coder},
-   * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
-   * @param <W>
-   */
-  @AutoValue
-  abstract static class SideInputSpec<W extends BoundedWindow> {
-    static <W extends BoundedWindow> SideInputSpec create(
-        Coder<?> coder,
-        Coder<W> windowCoder,
-        ViewFn<?, ?> viewFn,
-        WindowMappingFn<W> windowMappingFn) {
-      return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
-          coder, windowCoder, viewFn, windowMappingFn);
-    }
-
-    abstract Coder<?> getCoder();
-
-    abstract Coder<W> getWindowCoder();
-
-    abstract ViewFn<?, ?> getViewFn();
-
-    abstract WindowMappingFn<W> getWindowMappingFn();
-  }
-
-  private <K, V> Object bindSideInputView(TupleTag<?> view) {
-    SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
-    checkArgument(sideInputSpec != null,
-        "Attempting to access unknown side input %s.",
-        view);
-    KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
-
-    ByteString.Output encodedWindowOut = ByteString.newOutput();
-    try {
-      sideInputSpec.getWindowCoder().encode(
-          sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow), encodedWindowOut);
-    } catch (IOException e) {
-      throw new IllegalStateException(e);
-    }
-    ByteString encodedWindow = encodedWindowOut.toByteString();
-
-    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
-    cacheKeyBuilder.getMultimapSideInputBuilder()
-        .setPtransformId(ptransformId)
-        .setSideInputId(view.getId())
-        .setWindow(encodedWindow);
-    return stateKeyObjectCache.computeIfAbsent(
-        cacheKeyBuilder.build(),
-        key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
-            view.getId(), encodedWindow, kvCoder.getKeyCoder(), kvCoder.getValueCoder())));
-  }
-
-  private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
-      String sideInputId,
-      ByteString encodedWindow,
-      Coder<K> keyCoder,
-      Coder<V> valueCoder) {
-
-    return new MultimapSideInput<>(
-        beamFnStateClient,
-        processBundleInstructionId.get(),
-        ptransformId,
-        sideInputId,
-        encodedWindow,
-        keyCoder,
-        valueCoder);
-  }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
index 7819726..e2c61de 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -109,7 +110,8 @@
         Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction)
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener)
         throws IOException {
 
       Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
index c130d4d..efd4a38 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
@@ -22,6 +22,7 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -37,7 +38,6 @@
  * A factory able to instantiate an appropriate handler for a given PTransform.
  */
 public interface PTransformRunnerFactory<T> {
-
   /**
    * Creates and returns a handler for a given PTransform. Note that the handler must support
    * processing multiple bundles. The handler will be discarded if an error is thrown during element
@@ -60,6 +60,7 @@
    *     registered within this multimap.
    * @param addStartFunction A consumer to register a start bundle handler with.
    * @param addFinishFunction A consumer to register a finish bundle handler with.
+   * @param splitListener A listener to be invoked when the PTransform splits itself.
    */
   T createRunnerForPTransform(
       PipelineOptions pipelineOptions,
@@ -73,7 +74,8 @@
       Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
       Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
       Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction)
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener)
       throws IOException;
 
   /**
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
new file mode 100644
index 0000000..c8cac40
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.SplittableProcessElementInvoker;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/** Runs the {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform. */
+public class SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>
+    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<KV<InputT, RestrictionT>> {
+  /** A registrar which provides a factory to handle Java {@link DoFn}s. */
+  @AutoService(PTransformRunnerFactory.Registrar.class)
+  public static class Registrar implements PTransformRunnerFactory.Registrar {
+    @Override
+    public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+      return ImmutableMap.of(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, new Factory());
+    }
+  }
+
+  static class Factory<InputT, RestrictionT, OutputT>
+      extends DoFnPTransformRunnerFactory<
+          KV<InputT, RestrictionT>,
+          InputT,
+          OutputT,
+          SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>> {
+
+    @Override
+    SplittableProcessElementsRunner<InputT, RestrictionT, OutputT> createRunner(
+        Context<InputT, OutputT> context) {
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> windowedCoder =
+          FullWindowedValueCoder.of(
+              (Coder<KV<InputT, RestrictionT>>) context.inputCoder, context.windowCoder);
+
+      return new SplittableProcessElementsRunner<>(
+          context,
+          windowedCoder,
+          (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+              (Collection) context.tagToConsumer.get(context.mainOutputTag),
+          Iterables.getOnlyElement(context.pTransform.getInputsMap().keySet()));
+    }
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private final Context<InputT, OutputT> context;
+  private final String mainInputId;
+  private final Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder;
+  private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
+  private final DoFnInvoker<InputT, OutputT> doFnInvoker;
+  private final ScheduledExecutorService executor;
+
+  private FnApiStateAccessor stateAccessor;
+
+  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+  SplittableProcessElementsRunner(
+      Context<InputT, OutputT> context,
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder,
+      Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
+      String mainInputId) {
+    this.context = context;
+    this.mainInputId = mainInputId;
+    this.inputCoder = inputCoder;
+    this.mainOutputConsumers = mainOutputConsumers;
+    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
+    this.doFnInvoker.invokeSetup();
+    this.executor = Executors.newSingleThreadScheduledExecutor();
+
+    this.startBundleContext =
+        context.doFn.new StartBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+        };
+    this.finishBundleContext =
+        context.doFn.new FinishBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+
+          @Override
+          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public <T> void output(
+              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+        };
+  }
+
+  @Override
+  public void startBundle() {
+    this.stateAccessor =
+        new FnApiStateAccessor(
+            context.pipelineOptions,
+            context.ptransformId,
+            context.processBundleInstructionId,
+            context.tagToSideInputSpecMap,
+            context.beamFnStateClient,
+            context.keyCoder,
+            (Coder<BoundedWindow>) context.windowCoder);
+    doFnInvoker.invokeStartBundle(startBundleContext);
+  }
+
+  @Override
+  public void processElement(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    processElementTyped(elem);
+  }
+
+  private <PositionT, TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
+      void processElementTyped(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    checkArgument(
+        elem.getWindows().size() == 1,
+        "SPLITTABLE_PROCESS_ELEMENTS expects its input to be in 1 window, but got %s windows",
+        elem.getWindows().size());
+    this.stateAccessor.setCurrentWindow(elem.getWindows().iterator().next());
+    WindowedValue<InputT> element = elem.withValue(elem.getValue().getKey());
+    TrackerT tracker = doFnInvoker.invokeNewTracker(elem.getValue().getValue());
+    OutputAndTimeBoundedSplittableProcessElementInvoker<
+            InputT, OutputT, RestrictionT, PositionT, TrackerT>
+        processElementInvoker =
+            new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
+                context.doFn,
+                context.pipelineOptions,
+                new OutputWindowedValue<OutputT>() {
+                  @Override
+                  public void outputWindowedValue(
+                      OutputT output,
+                      Instant timestamp,
+                      Collection<? extends BoundedWindow> windows,
+                      PaneInfo pane) {
+                    outputTo(
+                        mainOutputConsumers, WindowedValue.of(output, timestamp, windows, pane));
+                  }
+
+                  @Override
+                  public <AdditionalOutputT> void outputWindowedValue(
+                      TupleTag<AdditionalOutputT> tag,
+                      AdditionalOutputT output,
+                      Instant timestamp,
+                      Collection<? extends BoundedWindow> windows,
+                      PaneInfo pane) {
+                    Collection<FnDataReceiver<WindowedValue<AdditionalOutputT>>> consumers =
+                        (Collection) context.tagToConsumer.get(tag);
+                    if (consumers == null) {
+                      throw new IllegalArgumentException(
+                          String.format("Unknown output tag %s", tag));
+                    }
+                    outputTo(consumers, WindowedValue.of(output, timestamp, windows, pane));
+                  }
+                },
+                stateAccessor,
+                executor,
+                10000,
+                Duration.standardSeconds(10));
+
+    SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, TrackerT>.Result result =
+        processElementInvoker.invokeProcessElement(doFnInvoker, element, tracker);
+    if (result.getContinuation().shouldResume()) {
+      WindowedValue<KV<InputT, RestrictionT>> primary =
+          element.withValue(KV.of(element.getValue(), tracker.currentRestriction()));
+      WindowedValue<KV<InputT, RestrictionT>> residual =
+          element.withValue(KV.of(element.getValue(), result.getResidualRestriction()));
+      ByteString.Output primaryBytes = ByteString.newOutput();
+      ByteString.Output residualBytes = ByteString.newOutput();
+      try {
+        inputCoder.encode(primary, primaryBytes);
+        inputCoder.encode(residual, residualBytes);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      Application primaryApplication =
+          Application.newBuilder()
+              .setPtransformId(context.ptransformId)
+              .setInputId(mainInputId)
+              .setElement(primaryBytes.toByteString())
+              .build();
+      Application residualApplication =
+          Application.newBuilder()
+              .setPtransformId(context.ptransformId)
+              .setInputId(mainInputId)
+              .setElement(residualBytes.toByteString())
+              .build();
+      context.splitListener.split(
+          ImmutableList.of(primaryApplication),
+          ImmutableList.of(
+              DelayedApplication.newBuilder()
+                  .setApplication(residualApplication)
+                  .setDelaySec(0.001 * result.getContinuation().resumeDelay().getMillis())
+                  .build()));
+    }
+  }
+
+  @Override
+  public void finishBundle() {
+    doFnInvoker.invokeFinishBundle(finishBundleContext);
+  }
+
+  /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
+  private <T> void outputTo(
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
+    try {
+      for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
+        consumer.accept(output);
+      }
+    } catch (Throwable t) {
+      throw UserCodeException.wrap(t);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
new file mode 100644
index 0000000..db73447
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+
+/**
+ * Listens to splits happening to a single bundle. See <a
+ * href="https://s.apache.org/beam-breaking-fusion">Breaking the Fusion Barrier</a> for a
+ * discussion of the design.
+ *
+ * <p>This interface has a single abstract method, so it can be implemented with a lambda.
+ */
+@FunctionalInterface
+public interface BundleSplitListener {
+  /**
+   * Signals that the current bundle should be split into the given set of primary and residual
+   * roots.
+   *
+   * <p>Primary roots are the new decomposition of the bundle's work into transform applications
+   * that have happened or will happen as part of this bundle (modulo future splits). Residual roots
+   * are a decomposition of work that has been given away by the bundle, so the runner must delegate
+   * it for someone else to execute.
+   *
+   * @param primaryRoots the applications that remain part of this bundle
+   * @param residualRoots the applications, each with a requested delay, that the runner must
+   *     delegate for someone else to execute
+   */
+  void split(List<Application> primaryRoots, List<DelayedApplication> residualRoots);
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 64d713e..ee907a8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -46,8 +46,12 @@
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
@@ -144,7 +148,8 @@
       SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
       ListMultimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
       Consumer<ThrowingRunnable> addStartFunction,
-      Consumer<ThrowingRunnable> addFinishFunction)
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener)
       throws IOException {
 
     // Recursively ensure that all consumers of the output PCollection have been created.
@@ -166,7 +171,8 @@
             pCollectionIdsToConsumingPTransforms,
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction);
+            addFinishFunction,
+            splitListener);
       }
     }
 
@@ -196,15 +202,12 @@
             processBundleDescriptor.getWindowingStrategiesMap(),
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction);
+            addFinishFunction,
+            splitListener);
   }
 
   public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
       throws Exception {
-    BeamFnApi.InstructionResponse.Builder response =
-        BeamFnApi.InstructionResponse.newBuilder()
-            .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
-
     String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
     BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
         (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);
@@ -223,6 +226,8 @@
       }
     }
 
+    ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder();
+
     // Instantiate a State API call handler depending on whether a State Api service descriptor
     // was specified.
     try (HandleStateCallsForBundle beamFnStateClient =
@@ -230,6 +235,23 @@
         ? new BlockTillStateCallsFinish(beamFnStateGrpcClientCache.forApiServiceDescriptor(
             bundleDescriptor.getStateApiServiceDescriptor()))
         : new FailAllStateCallsForBundle(request.getProcessBundle())) {
+      Multimap<String, Application> allPrimaries = ArrayListMultimap.create();
+      Multimap<String, DelayedApplication> allResiduals = ArrayListMultimap.create();
+      BundleSplitListener splitListener =
+          (List<Application> primaries, List<DelayedApplication> residuals) -> {
+            // Reset primaries and accumulate residuals.
+            Multimap<String, Application> newPrimaries = ArrayListMultimap.create();
+            for (Application primary : primaries) {
+              newPrimaries.put(primary.getPtransformId(), primary);
+            }
+            allPrimaries.clear();
+            allPrimaries.putAll(newPrimaries);
+
+            for (DelayedApplication residual : residuals) {
+              allResiduals.put(residual.getApplication().getPtransformId(), residual);
+            }
+          };
+
       // Create a BeamFnStateClient
       for (Map.Entry<String, RunnerApi.PTransform> entry
           : bundleDescriptor.getTransformsMap().entrySet()) {
@@ -251,7 +273,8 @@
             pCollectionIdsToConsumingPTransforms,
             pCollectionIdsToConsumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            splitListener);
       }
 
       // Already in reverse topological order so we don't need to do anything.
@@ -265,9 +288,16 @@
         LOG.debug("Finishing function {}", finishFunction);
         finishFunction.run();
       }
+      if (!allPrimaries.isEmpty()) {
+        response.setSplit(
+            BundleSplit.newBuilder()
+                .addAllPrimaryRoots(allPrimaries.values())
+                .addAllResidualRoots(allResiduals.values())
+                .build());
+      }
     }
 
-    return response;
+    return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response);
   }
 
   /**
@@ -355,7 +385,8 @@
         Map<String, WindowingStrategy> windowingStrategies,
         Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
         Consumer<ThrowingRunnable> addStartFunction,
-        Consumer<ThrowingRunnable> addFinishFunction) {
+        Consumer<ThrowingRunnable> addFinishFunction,
+        BundleSplitListener splitListener) {
       String message =
           String.format(
               "No factory registered for %s, known factories %s",
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
new file mode 100644
index 0000000..e7a043b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.sdk.state.SetState;
+import org.apache.beam.sdk.state.StateBinder;
+import org.apache.beam.sdk.state.StateContext;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CombineFnUtil;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+
+/** Provides access to side inputs and state via a {@link BeamFnStateClient}. */
+public class FnApiStateAccessor implements SideInputReader, StateBinder {
+  private final PipelineOptions pipelineOptions;
+  private final Map<StateKey, Object> stateKeyObjectCache; // side input views/state per StateKey
+  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+  private final BeamFnStateClient beamFnStateClient;
+  private final String ptransformId;
+  private final Supplier<String> processBundleInstructionId;
+  private final Collection<ThrowingRunnable> stateFinalizers;
+  // Coders supplied by the constructor (presumably for the current key/window — confirm).
+  private final Coder<?> keyCoder;
+  private final Coder<BoundedWindow> windowCoder;
+  // Set by setCurrentElement/setCurrentWindow, which also reset the matching encoded form.
+  private WindowedValue<?> currentElement;
+  private BoundedWindow currentWindow;
+  private ByteString encodedCurrentKey;
+  private ByteString encodedCurrentWindow;
+
+  public FnApiStateAccessor(
+      PipelineOptions pipelineOptions,
+      String ptransformId,
+      Supplier<String> processBundleInstructionId,
+      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
+      BeamFnStateClient beamFnStateClient,
+      Coder<?> keyCoder,
+      Coder<BoundedWindow> windowCoder) {
+    // Transform id and instruction-id supplier are used when building state requests.
+    this.ptransformId = ptransformId;
+    this.processBundleInstructionId = processBundleInstructionId;
+    this.pipelineOptions = pipelineOptions;
+    this.beamFnStateClient = beamFnStateClient;
+    this.sideInputSpecMap = sideInputSpecMap;
+    this.keyCoder = keyCoder;
+    this.windowCoder = windowCoder;
+    this.stateKeyObjectCache = Maps.newHashMap();
+    this.stateFinalizers = new ArrayList<>();
+  }
+
+  public void setCurrentElement(WindowedValue<?> currentElement) {
+    encodedCurrentKey = null; // drop any cached key encoding from the previous element
+    this.currentElement = currentElement;
+  }
+
+  public void setCurrentWindow(BoundedWindow currentWindow) {
+    encodedCurrentWindow = null; // drop any cached encoding of the previous window
+    this.currentWindow = currentWindow;
+  }
+
+  @Override
+  @Nullable
+  public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+    TupleTag<?> tag = view.getTagInternal();
+    // Resolve how the requested side input is coded and windowed.
+    SideInputSpec sideInputSpec = sideInputSpecMap.get(tag);
+    checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view);
+    KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
+    // Map the main input window to its side input window and encode it.
+    ByteString.Output encodedWindowOut = ByteString.newOutput();
+    try {
+      sideInputSpec
+          .getWindowCoder()
+          .encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+    ByteString encodedWindow = encodedWindowOut.toByteString();
+    // One view object is cached per (transform, side input, encoded window) state key.
+    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+    cacheKeyBuilder
+        .getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(tag.getId())
+        .setWindow(encodedWindow);
+    return (T)
+        stateKeyObjectCache.computeIfAbsent(
+            cacheKeyBuilder.build(),
+            key ->
+                sideInputSpec
+                    .getViewFn()
+                    .apply(
+                        new MultimapSideInput<>(
+                            beamFnStateClient,
+                            processBundleInstructionId.get(),
+                            ptransformId,
+                            tag.getId(),
+                            encodedWindow,
+                            kvCoder.getKeyCoder(),
+                            kvCoder.getValueCoder())));
+  }
+
+  @Override // True when a SideInputSpec was supplied for the view's tag.
+  public <T> boolean contains(PCollectionView<T> view) {
+    return sideInputSpecMap.keySet().contains(view.getTagInternal());
+  }
+
+  @Override // True when no side input specs were supplied to this accessor.
+  public boolean isEmpty() {
+    return sideInputSpecMap.size() == 0;
+  }
+
+  @Override
+  public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
+    // The value lives as the first element of a bag; one instance is cached per state key.
+    return (ValueState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            key ->
+                new ValueState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, coder);
+
+                  // Drops the stored value.
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+
+                  // Replaces whatever was stored with the given input.
+                  @Override
+                  public void write(T input) {
+                    impl.clear();
+                    impl.append(input);
+                  }
+
+                  // Returns the first stored element, or null when the bag is empty.
+                  @Override
+                  public T read() {
+                    Iterator<T> stored = impl.get().iterator();
+                    if (!stored.hasNext()) {
+                      return null;
+                    }
+                    return stored.next();
+                  }
+
+                  // Prefetching is not supported; this state is returned unchanged.
+                  @Override
+                  public ValueState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+                });
+  }
+
+  @Override
+  public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
+    // Backed directly by a BagUserState; one instance is cached per state key.
+    return (BagState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            key ->
+                new BagState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, elemCoder);
+
+                  // Appends without reading the existing contents.
+                  @Override
+                  public void add(T value) {
+                    impl.append(value);
+                  }
+
+                  // Empty when the backing bag yields no elements.
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  // Returns the backing bag's contents as-is.
+                  @Override
+                  public Iterable<T> read() {
+                    return impl.get();
+                  }
+
+                  @Override
+                  public BagState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                });
+  }
+
+  @Override // Set state is not supported by the Fn API yet; always throws.
+  public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
+    throw new UnsupportedOperationException("TODO: Add support for a set state to the Fn API.");
+  }
+
+  @Override
+  public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+      String id, StateSpec<MapState<KeyT, ValueT>> spec,
+      Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+    // Map state is not supported by the Fn API yet; always throws.
+    String message = "TODO: Add support for a map state to the Fn API.";
+    throw new UnsupportedOperationException(message);
+  }
+
+  @Override
+  public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
+      String id,
+      StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+      Coder<AccumT> accumCoder,
+      CombineFn<ElementT, AccumT, ResultT> combineFn) {
+    // The accumulator is kept as the first element of a bag; one instance is cached per key.
+    return (CombiningState<ElementT, AccumT, ResultT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            // TODO: Support squashing accumulators depending on whether we know of all
+            // remote accumulators and local accumulators or just local accumulators.
+            key ->
+                new CombiningState<ElementT, AccumT, ResultT>() {
+                  private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
+
+                  // The stored accumulator, or a fresh one when nothing is stored yet.
+                  @Override
+                  public AccumT getAccum() {
+                    Iterator<AccumT> stored = impl.get().iterator();
+                    if (stored.hasNext()) {
+                      return stored.next();
+                    }
+                    return combineFn.createAccumulator();
+                  }
+
+                  // Merges accum with the stored accumulator (if any) and stores the result.
+                  @Override
+                  public void addAccum(AccumT accum) {
+                    Iterator<AccumT> stored = impl.get().iterator();
+                    AccumT merged = accum;
+                    if (stored.hasNext()) {
+                      merged = combineFn.mergeAccumulators(ImmutableList.of(stored.next(), accum));
+                      impl.clear(); // the prior accumulator is replaced by the merged one
+                    }
+                    impl.append(merged);
+                  }
+
+                  @Override
+                  public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+                    return combineFn.mergeAccumulators(accumulators);
+                  }
+
+                  // Prefetching is not supported; this state is returned unchanged.
+                  @Override
+                  public CombiningState<ElementT, AccumT, ResultT> readLater() {
+                    return this;
+                  }
+
+                  // Output of the stored accumulator, or the CombineFn's default value.
+                  @Override
+                  public ResultT read() {
+                    Iterator<AccumT> stored = impl.get().iterator();
+                    return stored.hasNext()
+                        ? combineFn.extractOutput(stored.next())
+                        : combineFn.defaultValue();
+                  }
+
+                  // Folds one input into the current accumulator and stores the result.
+                  @Override
+                  public void add(ElementT value) {
+                    AccumT updated = combineFn.addInput(getAccum(), value);
+                    impl.clear();
+                    impl.append(updated);
+                  }
+
+                  // Empty when no accumulator has been stored.
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  // Removes the stored accumulator.
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                });
+  }
+
+  @Override
+  public <ElementT, AccumT, ResultT>
+      CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext(
+          String id,
+          StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+          Coder<AccumT> accumCoder,
+          CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
+    // bindCombining already memoizes the cell in stateKeyObjectCache under this same key, so
+    // delegate directly: nesting it inside another computeIfAbsent on that map would modify the
+    // map from within its own mapping function, which the Map#computeIfAbsent contract forbids.
+    return bindCombining(
+        id,
+        spec,
+        accumCoder,
+        CombineFnUtil.bindContext(
+            combineFn,
+            // currentWindow is read on each call, so the context follows the current element.
+            new StateContext<BoundedWindow>() {
+              @Override
+              public PipelineOptions getPipelineOptions() {
+                return pipelineOptions;
+              }
+
+              @Override
+              public <T> T sideInput(PCollectionView<T> view) {
+                return get(view, currentWindow);
+              }
+
+              @Override
+              public BoundedWindow window() {
+                return currentWindow;
+              }
+            }));
+  }
+
+  /**
+   * Always throws {@link UnsupportedOperationException}: watermark holds cannot be bound here.
+   * @deprecated The Fn API has no plans to implement WatermarkHoldState, pending BEAM-2535.
+   */
+  @Override
+  @Deprecated
+  public WatermarkHoldState bindWatermark(
+      String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
+    throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
+  }
+
+  /**
+   * Creates a {@link BagUserState} for {@code stateId} using the current encoded key and window,
+   * registering its {@code asyncClose} so that {@link #finalizeState} persists it. Expects the
+   * encoded key/window to already be populated (e.g. via {@link #createBagUserStateKey}).
+   */
+  private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) {
+    BagUserState<T> bagState =
+        new BagUserState<>(
+            beamFnStateClient, processBundleInstructionId.get(), ptransformId, stateId,
+            encodedCurrentWindow, encodedCurrentKey, valueCoder);
+    stateFinalizers.add(bagState::asyncClose);
+    return bagState;
+  }
+
+  /** Returns the bag user state key for {@code stateId} in the current key and window. */
+  private StateKey createBagUserStateKey(String stateId) {
+    // Must run first: populates encodedCurrentKey/encodedCurrentWindow used below.
+    cacheEncodedKeyAndWindowForKeyedContext();
+    StateKey.Builder keyBuilder = StateKey.newBuilder();
+    keyBuilder.getBagUserStateBuilder()
+        .setPtransformId(ptransformId).setUserStateId(stateId)
+        .setWindow(encodedCurrentWindow)
+        .setKey(encodedCurrentKey);
+    return keyBuilder.build();
+  }
+
+  /**
+   * Memoizes the encoded key and window of the current element so the encoding cost is not
+   * repeated for each state cell accessed within one {@link DoFn.ProcessElement} call. Each
+   * value is encoded only while its cached field is still {@code null}.
+   *
+   * <p>This should only be called during {@link DoFn.ProcessElement}.
+   */
+  private void cacheEncodedKeyAndWindowForKeyedContext() {
+    if (encodedCurrentKey == null) {
+      checkState(
+          currentElement.getValue() instanceof KV,
+          "Accessing state in unkeyed context. Current element is not a KV: %s.",
+          currentElement);
+      checkState(keyCoder != null, "Accessing state in unkeyed context, no key coder available");
+      // The key is encoded through a raw Coder cast, i.e. without compile-time type checking.
+      ByteString.Output encodedKeyOut = ByteString.newOutput();
+      try {
+        ((Coder) keyCoder).encode(((KV<?, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentKey = encodedKeyOut.toByteString();
+    }
+
+    if (encodedCurrentWindow == null) {
+      ByteString.Output encodedWindowOut = ByteString.newOutput();
+      try {
+        windowCoder.encode(currentWindow, encodedWindowOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentWindow = encodedWindowOut.toByteString();
+    }
+  }
+
+  /** Runs every registered state finalizer, wrapping failures in IllegalStateException. */
+  public void finalizeState() {
+    try {
+      for (ThrowingRunnable finalizer : stateFinalizers) {
+        finalizer.run();
+      }
+    } catch (InterruptedException interrupted) {
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException(interrupted);
+    } catch (Exception failure) {
+      throw new IllegalStateException(failure);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
new file mode 100644
index 0000000..8aec3a1
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+
+/**
+ * A specification for side inputs containing a value {@link Coder}, the window {@link Coder},
+ * {@link ViewFn}, and the {@link WindowMappingFn}.
+ *
+ * @param <W> the window type of the window coder and the {@link WindowMappingFn}
+ */
+@AutoValue
+public abstract class SideInputSpec<W extends BoundedWindow> {
+  /** Creates a {@link SideInputSpec} from the given components, typed by their window type. */
+  public static <W extends BoundedWindow> SideInputSpec<W> create(
+      Coder<?> coder,
+      Coder<W> windowCoder,
+      ViewFn<?, ?> viewFn,
+      WindowMappingFn<W> windowMappingFn) {
+    return new AutoValue_SideInputSpec<>(coder, windowCoder, viewFn, windowMappingFn);
+  }
+
+  abstract Coder<?> getCoder();
+
+  abstract Coder<W> getWindowCoder();
+
+  abstract ViewFn<?, ?> getViewFn();
+
+  abstract WindowMappingFn<W> getWindowMappingFn();
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
index d72753b..9871c37 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
@@ -201,7 +201,8 @@
             null /* windowingStrategies */,
             receivers,
             null /* addStartFunction */,
-            null /* addFinishFunction */);
+            null, /* addFinishFunction */
+            null /* splitListener */);
 
     WindowedValue<Integer> value =
         WindowedValue.of(
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index f138677..4b4c64f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -152,7 +152,8 @@
         COMPONENTS.getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     verifyZeroInteractions(mockBeamFnDataClient);
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
index 12540c9..a70aedd 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -141,7 +141,8 @@
         COMPONENTS.getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     verifyZeroInteractions(mockBeamFnDataClient);
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
index 044de10..e438bd5 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
@@ -151,7 +151,8 @@
         Collections.emptyMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     // This is testing a deprecated way of running sources and should be removed
     // once all source definitions are instead propagated along the input edge.
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
index 9b1bd75..3b58517 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
@@ -87,7 +87,8 @@
         Collections.emptyMap(),
         consumers,
         null /* addStartFunction */,
-        null /* addFinishFunction */);
+        null, /* addFinishFunction */
+        null /* splitListener */);
 
     mainOutputValues.clear();
     assertThat(consumers.keySet(), containsInAnyOrder(
@@ -149,7 +150,8 @@
             Collections.emptyMap(),
             consumers,
             null /* addStartFunction */,
-            null /* addFinishFunction */);
+            null, /* addFinishFunction */
+            null /* splitListener */);
 
     mainOutputValues.clear();
     assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget"));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 85aa564..ee04cf8 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.beam.fn.harness;
 
-import static com.google.common.base.Preconditions.checkState;
-import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -27,29 +25,23 @@
 import static org.hamcrest.Matchers.hasSize;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
 
 import com.google.common.base.Suppliers;
 import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Multimap;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
-import java.util.ServiceLoader;
-import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
 import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -60,8 +52,6 @@
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
-import org.apache.beam.sdk.transforms.CombineWithContext.Context;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -73,15 +63,13 @@
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.hamcrest.collection.IsMapContaining;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Test;
@@ -90,137 +78,10 @@
 
 /** Tests for {@link FnApiDoFnRunner}. */
 @RunWith(JUnit4.class)
-public class FnApiDoFnRunnerTest {
+public class FnApiDoFnRunnerTest implements Serializable {
 
   public static final String TEST_PTRANSFORM_ID = "pTransformId";
 
-  private static class TestDoFn extends DoFn<String, String> {
-    private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
-    private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
-
-    private enum State {
-      NOT_SET_UP,
-      OUTSIDE_BUNDLE,
-      INSIDE_BUNDLE,
-    }
-
-    private State state = State.NOT_SET_UP;
-
-    private BoundedWindow window;
-
-    @Setup
-    public void setUp() {
-      checkState(State.NOT_SET_UP.equals(state), "Unexpected state: %s", state);
-      state = State.OUTSIDE_BUNDLE;
-    }
-
-    // No testing for TearDown - it's currently not supported by FnHarness.
-
-    @StartBundle
-    public void startBundle() {
-      checkState(State.OUTSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      state = State.INSIDE_BUNDLE;
-    }
-
-    @ProcessElement
-    public void processElement(ProcessContext context, BoundedWindow window) {
-      checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      context.output("MainOutput" + context.element());
-      context.output(additionalOutput, "AdditionalOutput" + context.element());
-      this.window = window;
-    }
-
-    @FinishBundle
-    public void finishBundle(FinishBundleContext context) {
-      checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
-      state = State.OUTSIDE_BUNDLE;
-      if (window != null) {
-        context.output("FinishBundle", window.maxTimestamp(), window);
-        window = null;
-      }
-    }
-  }
-
-  /**
-   * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
-   * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
-   * are directed to the correct consumers.
-   */
-  @Test
-  public void testCreatingAndProcessingDoFn() throws Exception {
-    String pTransformId = "pTransformId";
-
-    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
-        new TestDoFn(),
-        WindowingStrategy.globalDefault(),
-        ImmutableList.of(),
-        StringUtf8Coder.of(),
-        TestDoFn.mainOutput);
-    RunnerApi.FunctionSpec functionSpec =
-        RunnerApi.FunctionSpec.newBuilder()
-            .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
-            .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
-            .build();
-    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
-        .setSpec(functionSpec)
-        .putInputs("inputA", "inputATarget")
-        .putInputs("inputB", "inputBTarget")
-        .putOutputs(TestDoFn.mainOutput.getId(), "mainOutputTarget")
-        .putOutputs(TestDoFn.additionalOutput.getId(), "additionalOutputTarget")
-        .build();
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
-    Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put("mainOutputTarget",
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    consumers.put("additionalOutputTarget",
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
-    List<ThrowingRunnable> startFunctions = new ArrayList<>();
-    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
-        PipelineOptionsFactory.create(),
-        null /* beamFnDataClient */,
-        null /* beamFnStateClient */,
-        pTransformId,
-        pTransform,
-        Suppliers.ofInstance("57L")::get,
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        consumers,
-        startFunctions::add,
-        finishFunctions::add);
-
-    Iterables.getOnlyElement(startFunctions).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(
-        "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
-
-    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
-    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
-    Iterables.getOnlyElement(consumers.get("inputBTarget")).accept(valueInGlobalWindow("B"));
-    assertThat(mainOutputValues, contains(
-        valueInGlobalWindow("MainOutputA1"),
-        valueInGlobalWindow("MainOutputA2"),
-        valueInGlobalWindow("MainOutputB")));
-    assertThat(additionalOutputValues, contains(
-        valueInGlobalWindow("AdditionalOutputA1"),
-        valueInGlobalWindow("AdditionalOutputA2"),
-        valueInGlobalWindow("AdditionalOutputB")));
-    mainOutputValues.clear();
-    additionalOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctions).run();
-    assertThat(
-        mainOutputValues,
-        contains(
-            timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
-    mainOutputValues.clear();
-  }
-
   private static class ConcatCombineFn extends CombineFn<String, String, String> {
     @Override
     public String createAccumulator() {
@@ -247,33 +108,6 @@
     }
   }
 
-  private static class ConcatCombineFnWithContext
-      extends CombineFnWithContext<String, String, String> {
-    @Override
-    public String createAccumulator(Context c) {
-      return "";
-    }
-
-    @Override
-    public String addInput(String accumulator, String input, Context c) {
-      return accumulator.concat(input);
-    }
-
-    @Override
-    public String mergeAccumulators(Iterable<String> accumulators, Context c) {
-      StringBuilder builder = new StringBuilder();
-      for (String value : accumulators) {
-        builder.append(value);
-      }
-      return builder.toString();
-    }
-
-    @Override
-    public String extractOutput(String accumulator, Context c) {
-      return accumulator;
-    }
-  }
-
   private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
     private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
     private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -287,17 +121,12 @@
     @StateId("combine")
     private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
         StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
-    @StateId("combineWithContext")
-    private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec =
-        StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext());
 
     @ProcessElement
     public void processElement(ProcessContext context,
         @StateId("value") ValueState<String> valueState,
         @StateId("bag") BagState<String> bagState,
-        @StateId("combine") CombiningState<String, String, String> combiningState,
-        @StateId("combineWithContext")
-            CombiningState<String, String, String> combiningWithContextState) {
+        @StateId("combine") CombiningState<String, String, String> combiningState) {
       context.output("value:" + valueState.read());
       valueState.write(context.element().getValue());
 
@@ -306,41 +135,34 @@
 
       context.output("combine:" + combiningState.read());
       combiningState.add(context.element().getValue());
-
-      context.output("combineWithContext:" + combiningWithContextState.read());
-      combiningWithContextState.add(context.element().getValue());
     }
   }
 
   @Test
   public void testUsingUserState() throws Exception {
-    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
-        new TestStatefulDoFn(),
-        WindowingStrategy.globalDefault(),
-        ImmutableList.of(),
-        KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
-        new TupleTag<>("mainOutput"));
-    RunnerApi.FunctionSpec functionSpec =
-        RunnerApi.FunctionSpec.newBuilder()
-            .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
-            .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
-            .build();
-    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
-        .setSpec(functionSpec)
-        .putInputs("input", "inputTarget")
-        .putOutputs("mainOutput", "mainOutputTarget")
-        .build();
+    Pipeline p = Pipeline.create();
+    PCollection<KV<String, String>> valuePCollection =
+        p.apply(Create.of(KV.of("unused", "unused")));
+    PCollection<String> outputPCollection =
+        valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
+
+    SdkComponents sdkComponents = SdkComponents.create();
+    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+    String outputPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection);
+    RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
+        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
 
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of(
         bagUserStateKey("value", "X"), encode("X0"),
         bagUserStateKey("bag", "X"), encode("X0"),
-        bagUserStateKey("combine", "X"), encode("X0"),
-        bagUserStateKey("combineWithContext", "X"), encode("X0")
+        bagUserStateKey("combine", "X"), encode("X0")
     ));
 
     List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
     Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put("mainOutputTarget",
+    consumers.put(outputPCollectionId,
         (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
@@ -352,22 +174,23 @@
         TEST_PTRANSFORM_ID,
         pTransform,
         Suppliers.ofInstance("57L")::get,
-        Collections.emptyMap(),
-        Collections.emptyMap(),
-        Collections.emptyMap(),
+        pProto.getComponents().getPcollectionsMap(),
+        pProto.getComponents().getCodersMap(),
+        pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
 
-    assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget"));
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
     // Ensure that bag user state that is initially empty or populated works.
     // Ensure that the key order does not matter when we traverse over KV pairs.
     FnDataReceiver<WindowedValue<?>> mainInput =
-        Iterables.getOnlyElement(consumers.get("inputTarget"));
+        Iterables.getOnlyElement(consumers.get(inputPCollectionId));
     mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
     mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
     mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
@@ -376,19 +199,15 @@
         valueInGlobalWindow("value:X0"),
         valueInGlobalWindow("bag:[X0]"),
         valueInGlobalWindow("combine:X0"),
-        valueInGlobalWindow("combineWithContext:X0"),
         valueInGlobalWindow("value:null"),
         valueInGlobalWindow("bag:[]"),
         valueInGlobalWindow("combine:"),
-        valueInGlobalWindow("combineWithContext:"),
         valueInGlobalWindow("value:X1"),
         valueInGlobalWindow("bag:[X0, X1]"),
         valueInGlobalWindow("combine:X0X1"),
-        valueInGlobalWindow("combineWithContext:X0X1"),
         valueInGlobalWindow("value:Y1"),
         valueInGlobalWindow("bag:[Y1]"),
-        valueInGlobalWindow("combine:Y1"),
-        valueInGlobalWindow("combineWithContext:Y1")));
+        valueInGlobalWindow("combine:Y1")));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctions).run();
@@ -399,11 +218,9 @@
             .put(bagUserStateKey("value", "X"), encode("X2"))
             .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
             .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
-            .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
             .put(bagUserStateKey("value", "Y"), encode("Y2"))
             .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
             .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
-            .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
             .build(),
         fakeClient.getData());
     mainOutputValues.clear();
@@ -425,13 +242,17 @@
     private final PCollectionView<String> defaultSingletonSideInput;
     private final PCollectionView<String> singletonSideInput;
     private final PCollectionView<Iterable<String>> iterableSideInput;
+    private final TupleTag<String> additionalOutput;
+
     private TestSideInputDoFn(
         PCollectionView<String> defaultSingletonSideInput,
         PCollectionView<String> singletonSideInput,
-        PCollectionView<Iterable<String>> iterableSideInput) {
+        PCollectionView<Iterable<String>> iterableSideInput,
+        TupleTag<String> additionalOutput) {
       this.defaultSingletonSideInput = defaultSingletonSideInput;
       this.singletonSideInput = singletonSideInput;
       this.iterableSideInput = iterableSideInput;
+      this.additionalOutput = additionalOutput;
     }
 
     @ProcessElement
@@ -441,6 +262,7 @@
       for (String sideInputValue : context.sideInput(iterableSideInput)) {
         context.output(context.element() + ":" + sideInputValue);
       }
+      context.output(additionalOutput, context.element() + ":additional");
     }
   }
 
@@ -453,21 +275,31 @@
     PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
     PCollectionView<Iterable<String>> iterableSideInputView =
         valuePCollection.apply(View.asIterable());
-    PCollection<String> outputPCollection = valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
-        new TestSideInputDoFn(
-            defaultSingletonSideInputView,
-            singletonSideInputView,
-            iterableSideInputView))
-        .withSideInputs(
-            defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView));
+    TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+    TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+    PCollectionTuple outputPCollection =
+        valuePCollection.apply(
+            TEST_PTRANSFORM_ID,
+            ParDo.of(
+                    new TestSideInputDoFn(
+                        defaultSingletonSideInputView,
+                        singletonSideInputView,
+                        iterableSideInputView,
+                        additionalOutput))
+                .withSideInputs(
+                    defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
+                .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
 
     SdkComponents sdkComponents = SdkComponents.create();
     RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
     String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+    String outputPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+    String additionalPCollectionId =
+        sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
 
-    RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
-        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID);
 
     ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
         multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
@@ -478,13 +310,16 @@
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
 
     List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
     Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
-    consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+    consumers.put(outputPCollectionId,
         (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+    consumers.put(additionalPCollectionId,
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
 
-    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
         PipelineOptionsFactory.create(),
         null /* beamFnDataClient */,
         fakeClient,
@@ -496,12 +331,15 @@
         pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
 
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+    assertThat(
+        consumers.keySet(),
+        containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
 
     // Ensure that bag user state that is initially empty or populated works.
     // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
@@ -520,6 +358,9 @@
         valueInGlobalWindow("Y:iterableValue1"),
         valueInGlobalWindow("Y:iterableValue2"),
         valueInGlobalWindow("Y:iterableValue3")));
+    assertThat(
+        additionalOutputValues,
+        contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctions).run();
@@ -589,7 +430,7 @@
     List<ThrowingRunnable> startFunctions = new ArrayList<>();
     List<ThrowingRunnable> finishFunctions = new ArrayList<>();
 
-    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
         PipelineOptionsFactory.create(),
         null /* beamFnDataClient */,
         fakeClient,
@@ -601,7 +442,8 @@
         pProto.getComponents().getWindowingStrategiesMap(),
         consumers,
         startFunctions::add,
-        finishFunctions::add);
+        finishFunctions::add,
+        null /* splitListener */);
 
     Iterables.getOnlyElement(startFunctions).run();
     mainOutputValues.clear();
@@ -658,17 +500,4 @@
     }
     return out.toByteString();
   }
-
-  @Test
-  public void testRegistration() {
-    for (Registrar registrar :
-        ServiceLoader.load(Registrar.class)) {
-      if (registrar instanceof FnApiDoFnRunner.Registrar) {
-        assertThat(registrar.getPTransformRunnerFactories(),
-            IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN));
-        return;
-      }
-    }
-    fail("Expected registrar not found.");
-  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
index 906bdeb..93ca808 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
@@ -82,7 +82,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
@@ -116,7 +117,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
@@ -150,7 +152,8 @@
             Collections.emptyMap(),
             consumers,
             startFunctions::add,
-            finishFunctions::add);
+            finishFunctions::add,
+            null /* splitListener */);
 
     assertThat(startFunctions, empty());
     assertThat(finishFunctions, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 48d56fd..eafef99 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -50,6 +50,7 @@
 import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingConsumer;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -115,7 +116,8 @@
             windowingStrategies,
             pCollectionIdsToConsumers,
             addStartFunction,
-            addFinishFunction) -> {
+            addFinishFunction,
+            splitListener) -> {
           assertThat(processBundleInstructionId.get(), equalTo("999L"));
 
           transformsProcessed.add(pTransform);
@@ -176,7 +178,8 @@
                     windowingStrategies,
                     pCollectionIdsToConsumers,
                     addStartFunction,
-                    addFinishFunction) -> {
+                    addFinishFunction,
+                    splitListener) -> {
                   thrown.expect(IllegalStateException.class);
                   thrown.expectMessage("TestException");
                   throw new IllegalStateException("TestException");
@@ -217,7 +220,8 @@
                         windowingStrategies,
                         pCollectionIdsToConsumers,
                         addStartFunction,
-                        addFinishFunction) -> {
+                        addFinishFunction,
+                        splitListener) -> {
                       thrown.expect(IllegalStateException.class);
                       thrown.expectMessage("TestException");
                       addStartFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -261,7 +265,8 @@
                         windowingStrategies,
                         pCollectionIdsToConsumers,
                         addStartFunction,
-                        addFinishFunction) -> {
+                        addFinishFunction,
+                        splitListener) -> {
                       thrown.expect(IllegalStateException.class);
                       thrown.expectMessage("TestException");
                       addFinishFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -335,10 +340,11 @@
               Supplier<String> processBundleInstructionId,
               Map<String, PCollection> pCollections,
               Map<String, Coder> coders,
-              Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+              Map<String, WindowingStrategy> windowingStrategies,
               Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
               Consumer<ThrowingRunnable> addStartFunction,
-              Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+              Consumer<ThrowingRunnable> addFinishFunction,
+              BundleSplitListener splitListener) throws IOException {
             addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
             return null;
           }
@@ -386,10 +392,11 @@
               Supplier<String> processBundleInstructionId,
               Map<String, PCollection> pCollections,
               Map<String, Coder> coders,
-              Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+              Map<String, WindowingStrategy> windowingStrategies,
               Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
               Consumer<ThrowingRunnable> addStartFunction,
-              Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+              Consumer<ThrowingRunnable> addFinishFunction,
+              BundleSplitListener splitListener) throws IOException {
             addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
             return null;
           }