[BEAM-5605] Add support for channel splitting to the gRPC read "source" and propagate "split" calls to the downstream receiver (#10501)

* [BEAM-3684] Add support for channel splitting to the gRPC read "source" and propagate "split" calls to the downstream receiver

This code mirrors the logic/implementation within https://github.com/apache/beam/blob/16757ef9a6da4d0ac218c6c4d6b19e2a49ccca45/sdks/python/apache_beam/runners/worker/bundle_processor.py#L206

To be able to propagate the split call to the downstream receiver, I collapsed all the harness FnDataReceiver types into two existing implementations and one new implementations.
The previous hierarchy was:
element counting receiver -> time counting receiver -> multiplexing receiver (possibly the original receiver)

The current implementation combined the element counting, time counting and multiplexing into the MultiplexingMetricTrackingFnDataReceiver while for the singleton case into the MetricTrackingFnDataReceiver.
To propagate splits, a SplittingMetricTrackingFnDataReceiver was created that extends the MetricTrackingFnDataReceiver. Note, like in Python, there is currently no support for splitting as in https://github.com/apache/beam/blob/c167d8ef99b21148bcab7c37538a6ef2f64864c7/sdks/python/apache_beam/runners/worker/operations.py#L133

* fixup! Fix spot/find bugs issues

* fixup! Address comments.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index e93b2ba..65f11d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.fn.harness;
 
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement;
 
 import com.google.auto.service.AutoService;
@@ -24,12 +25,16 @@
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import org.apache.beam.fn.harness.HandlesSplits.SplitResult;
 import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
 import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest.DesiredSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.RemoteGrpcPort;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -47,6 +52,7 @@
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -131,6 +137,11 @@
   private final BeamFnDataClient beamFnDataClient;
   private final Coder<WindowedValue<OutputT>> coder;
 
+  private final Object splittingLock = new Object();
+  // 0-based index of the current element being processed
+  private long index = -1;
+  // 0-based index of the first element to not process, aka the first element of the residual
+  private long stopIndex = Long.MAX_VALUE;
   private InboundDataClient readFuture;
 
   BeamFnDataReadRunner(
@@ -170,7 +181,109 @@
             apiServiceDescriptor,
             LogicalEndpoint.of(processBundleInstructionIdSupplier.get(), pTransformId),
             coder,
-            consumer);
+            this::forwardElementToConsumer);
+  }
+
+  public void forwardElementToConsumer(WindowedValue<OutputT> element) throws Exception {
+    synchronized (splittingLock) {
+      if (index == stopIndex - 1) {
+        return;
+      }
+      index += 1;
+    }
+    consumer.accept(element);
+  }
+
+  public void split(
+      ProcessBundleSplitRequest request, ProcessBundleSplitResponse.Builder response) {
+    DesiredSplit desiredSplit = request.getDesiredSplitsMap().get(pTransformId);
+    if (desiredSplit == null) {
+      return;
+    }
+
+    long totalBufferSize = desiredSplit.getEstimatedInputElements();
+
+    HandlesSplits splittingConsumer = null;
+    if (consumer instanceof HandlesSplits) {
+      splittingConsumer = ((HandlesSplits) consumer);
+    }
+
+    synchronized (splittingLock) {
+      // Since we hold the splittingLock, we guarantee that we will not pass the next element
+      // to the downstream consumer. We still have a race where the downstream consumer may
+      // have yet to see the element or has completed processing the element by the time
+      // we ask it to split (even after we have asked for its progress).
+
+      // If the split request we received was delayed and is less then the known number of elements
+      // then use "index + 1" as the total size. Similarly, if we have already split and the
+      // split request is bounded incorrectly, use the stop index as the upper bound.
+      if (totalBufferSize < index + 1) {
+        totalBufferSize = index + 1;
+      } else if (totalBufferSize > stopIndex) {
+        totalBufferSize = stopIndex;
+      }
+
+      // In the case where we have yet to process an element, set the current element progress to 1.
+      double currentElementProgress = 1;
+
+      // If we have started processing at least one element, attempt to get the downstream
+      // progress defaulting to 0.5 if no progress was able to get fetched.
+      if (index >= 0) {
+        if (splittingConsumer != null) {
+          currentElementProgress = splittingConsumer.getProgress();
+        } else {
+          currentElementProgress = 0.5;
+        }
+      }
+
+      checkArgument(
+          desiredSplit.getAllowedSplitPointsList().isEmpty(),
+          "TODO: BEAM-3836, support split point restrictions.");
+
+      // Now figure out where to split.
+      //
+      // The units here (except for keepOfElementRemainder) are all in terms of number or
+      // (possibly fractional) elements.
+
+      // Compute the amount of "remaining" work that we know of.
+      double remainder = totalBufferSize - index - currentElementProgress;
+      // Compute the number of elements (including fractional elements) that we should "keep".
+      double keep = remainder * desiredSplit.getFractionOfRemainder();
+
+      // If the downstream operator says the progress is less than 1 then the element could be
+      // splittable.
+      if (currentElementProgress < 1) {
+        // See if the amount we need to keep falls within the current element's remainder and if
+        // so, attempt to split it.
+        double keepOfElementRemainder = keep / (1 - currentElementProgress);
+        if (keepOfElementRemainder < 1) {
+          SplitResult splitResult =
+              splittingConsumer != null ? splittingConsumer.trySplit(keepOfElementRemainder) : null;
+          if (splitResult != null) {
+            stopIndex = index + 1;
+            response
+                .addPrimaryRoots(splitResult.getPrimaryRoot())
+                .addResidualRoots(splitResult.getResidualRoot())
+                .addChannelSplitsBuilder()
+                .setLastPrimaryElement(index - 1)
+                .setFirstResidualElement(stopIndex);
+            return;
+          }
+        }
+      }
+
+      // Otherwise, split at the closest element boundary.
+      int newStopIndex =
+          Ints.checkedCast(index + Math.max(1, Math.round(currentElementProgress + keep)));
+      if (newStopIndex < stopIndex) {
+        stopIndex = newStopIndex;
+        response
+            .addChannelSplitsBuilder()
+            .setLastPrimaryElement(stopIndex - 1)
+            .setFirstResidualElement(stopIndex);
+        return;
+      }
+    }
   }
 
   public void blockTillReadFinishes() throws Exception {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
new file mode 100644
index 0000000..a2ac123
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+
+public interface HandlesSplits {
+  SplitResult trySplit(double fractionOfRemainder);
+
+  double getProgress();
+
+  @AutoValue
+  abstract class SplitResult {
+    public static SplitResult of(
+        BeamFnApi.BundleApplication primaryRoot, BeamFnApi.DelayedBundleApplication residualRoot) {
+      return new AutoValue_HandlesSplits_SplitResult(primaryRoot, residualRoot);
+    }
+
+    public abstract BeamFnApi.BundleApplication getPrimaryRoot();
+
+    public abstract BeamFnApi.DelayedBundleApplication getResidualRoot();
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiver.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiver.java
deleted file mode 100644
index f234844..0000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiver.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.fn.harness.data;
-
-import java.io.Closeable;
-import java.util.HashMap;
-import org.apache.beam.runners.core.metrics.LabeledMetrics;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
-import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
-import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
-import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.metrics.Counter;
-import org.apache.beam.sdk.metrics.MetricsContainer;
-import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.sdk.util.WindowedValue;
-
-/**
- * A wrapping {@code FnDataReceiver<WindowedValue<T>>} which counts the number of elements consumed
- * by the original {@code FnDataReceiver<WindowedValue<T>>}.
- *
- * @param <T> - The receiving type of the PTransform.
- */
-public class ElementCountFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>> {
-
-  private FnDataReceiver<WindowedValue<T>> original;
-  private Counter counter;
-  private MetricsContainer unboundMetricContainer;
-
-  public ElementCountFnDataReceiver(
-      FnDataReceiver<WindowedValue<T>> original,
-      String pCollection,
-      MetricsContainerStepMap metricContainerRegistry) {
-    this.original = original;
-    HashMap<String, String> labels = new HashMap<String, String>();
-    labels.put(Labels.PCOLLECTION, pCollection);
-    MonitoringInfoMetricName metricName =
-        MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.ELEMENT_COUNT, labels);
-    this.counter = LabeledMetrics.counter(metricName);
-    // Collect the metric in a metric container which is not bound to the step name.
-    // This is required to count elements from impulse steps, which will produce elements outside
-    // of a pTransform context.
-    this.unboundMetricContainer = metricContainerRegistry.getUnboundContainer();
-  }
-
-  @Override
-  public void accept(WindowedValue<T> input) throws Exception {
-    try (Closeable close = MetricsEnvironment.scopedMetricsContainer(this.unboundMetricContainer)) {
-      // Increment the counter for each window the element occurs in.
-      this.counter.inc(input.getWindows().size());
-      this.original.accept(input);
-    }
-  }
-}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiver.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiver.java
deleted file mode 100644
index 65f75b0..0000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiver.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.fn.harness.data;
-
-import java.util.Collection;
-import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-
-/**
- * A {@link FnDataReceiver} which forwards all received inputs to a collection of {@link
- * FnDataReceiver receivers}.
- */
-public class MultiplexingFnDataReceiver<T> implements FnDataReceiver<T> {
-  public static <T> FnDataReceiver<T> forConsumers(Collection<FnDataReceiver<T>> consumers) {
-    if (consumers.size() == 1) {
-      return Iterables.getOnlyElement(consumers);
-    }
-    return new MultiplexingFnDataReceiver<>(consumers);
-  }
-
-  private final Collection<FnDataReceiver<T>> consumers;
-
-  private MultiplexingFnDataReceiver(Collection<FnDataReceiver<T>> consumers) {
-    this.consumers = consumers;
-  }
-
-  @Override
-  public void accept(T input) throws Exception {
-    for (FnDataReceiver<T> consumer : consumers) {
-      consumer.accept(input);
-    }
-  }
-}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 80d270f..3276c46 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -17,23 +17,32 @@
  */
 package org.apache.beam.fn.harness.data;
 
+import com.google.auto.value.AutoValue;
 import java.io.Closeable;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
+import org.apache.beam.runners.core.metrics.LabeledMetrics;
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
+import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.runners.core.metrics.SimpleExecutionState;
 import org.apache.beam.runners.core.metrics.SimpleStateRegistry;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.MetricsContainer;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 
 /**
  * The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -43,8 +52,24 @@
  */
 public class PCollectionConsumerRegistry {
 
-  private ListMultimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers;
-  private Map<String, ElementCountFnDataReceiver> pCollectionIdsToWrappedConsumer;
+  /** Stores metadata about each consumer so that the appropriate metrics tracking can occur. */
+  @AutoValue
+  abstract static class ConsumerAndMetadata {
+    public static ConsumerAndMetadata forConsumer(
+        FnDataReceiver consumer, String pTransformId, SimpleExecutionState state) {
+      return new AutoValue_PCollectionConsumerRegistry_ConsumerAndMetadata(
+          consumer, pTransformId, state);
+    }
+
+    public abstract FnDataReceiver getConsumer();
+
+    public abstract String getPTransformId();
+
+    public abstract SimpleExecutionState getExecutionState();
+  }
+
+  private ListMultimap<String, ConsumerAndMetadata> pCollectionIdsToConsumers;
+  private Map<String, FnDataReceiver> pCollectionIdsToWrappedConsumer;
   private MetricsContainerStepMap metricsContainerRegistry;
   private ExecutionStateTracker stateTracker;
   private SimpleStateRegistry executionStates = new SimpleStateRegistry();
@@ -54,7 +79,7 @@
     this.metricsContainerRegistry = metricsContainerRegistry;
     this.stateTracker = stateTracker;
     this.pCollectionIdsToConsumers = ArrayListMultimap.create();
-    this.pCollectionIdsToWrappedConsumer = new HashMap<String, ElementCountFnDataReceiver>();
+    this.pCollectionIdsToWrappedConsumer = new HashMap<>();
   }
 
   /**
@@ -77,15 +102,13 @@
     // Just save these consumers for now, but package them up later with an
     // ElementCountFnDataReceiver and possibly a MultiplexingFnDataReceiver
     // if there are multiple consumers.
-    ElementCountFnDataReceiver wrappedConsumer =
-        pCollectionIdsToWrappedConsumer.getOrDefault(pCollectionId, null);
-    if (wrappedConsumer != null) {
+    if (pCollectionIdsToWrappedConsumer.containsKey(pCollectionId)) {
       throw new RuntimeException(
           "New consumers for a pCollectionId cannot be register()-d after "
               + "calling getMultiplexingConsumer.");
     }
 
-    HashMap<String, String> labelsMetadata = new HashMap<String, String>();
+    HashMap<String, String> labelsMetadata = new HashMap<>();
     labelsMetadata.put(MonitoringInfoConstants.Labels.PTRANSFORM, pTransformId);
     SimpleExecutionState state =
         new SimpleExecutionState(
@@ -93,20 +116,9 @@
             MonitoringInfoConstants.Urns.PROCESS_BUNDLE_MSECS,
             labelsMetadata);
     executionStates.register(state);
-    // Wrap the consumer with extra logic to set the metric container with the appropriate
-    // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
-    // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
-    // Process Bundle Execution time metric.
-    FnDataReceiver<WindowedValue<T>> wrapAndEnableMetricContainer =
-        (WindowedValue<T> input) -> {
-          MetricsContainerImpl container = metricsContainerRegistry.getContainer(pTransformId);
-          try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(container)) {
-            try (Closeable trackerCloseable = this.stateTracker.enterState(state)) {
-              consumer.accept(input);
-            }
-          }
-        };
-    pCollectionIdsToConsumers.put(pCollectionId, (FnDataReceiver) wrapAndEnableMetricContainer);
+
+    pCollectionIdsToConsumers.put(
+        pCollectionId, ConsumerAndMetadata.forConsumer(consumer, pTransformId, state));
   }
 
   /** Reset the execution states of the registered functions. */
@@ -121,24 +133,28 @@
 
   /**
    * New consumers should not be register()-ed after calling this method. This will cause a
-   * RuntimeException, as this would fail to properly wrap the late-added consumer to the
-   * ElementCountFnDataReceiver.
+   * RuntimeException, as this would fail to properly wrap the late-added consumer to track metrics.
    *
-   * @return A single ElementCountFnDataReceiver which directly wraps all the registered consumers.
+   * @return A {@link FnDataReceiver} which directly wraps all the registered consumers.
    */
   public FnDataReceiver<WindowedValue<?>> getMultiplexingConsumer(String pCollectionId) {
-    ElementCountFnDataReceiver wrappedConsumer =
-        pCollectionIdsToWrappedConsumer.getOrDefault(pCollectionId, null);
-    if (wrappedConsumer == null) {
-      List<FnDataReceiver<WindowedValue<?>>> consumers =
-          pCollectionIdsToConsumers.get(pCollectionId);
-      FnDataReceiver<WindowedValue<?>> consumer =
-          MultiplexingFnDataReceiver.forConsumers(consumers);
-      wrappedConsumer =
-          new ElementCountFnDataReceiver(consumer, pCollectionId, metricsContainerRegistry);
-      pCollectionIdsToWrappedConsumer.put(pCollectionId, wrappedConsumer);
-    }
-    return wrappedConsumer;
+    return pCollectionIdsToWrappedConsumer.computeIfAbsent(
+        pCollectionId,
+        pcId -> {
+          List<ConsumerAndMetadata> consumerAndMetadatas = pCollectionIdsToConsumers.get(pcId);
+          if (consumerAndMetadatas == null) {
+            throw new IllegalArgumentException(
+                String.format("Unknown PCollectionId %s", pCollectionId));
+          } else if (consumerAndMetadatas.size() == 1) {
+            if (consumerAndMetadatas.get(0).getConsumer() instanceof HandlesSplits) {
+              return new SplittingMetricTrackingFnDataReceiver(pcId, consumerAndMetadatas.get(0));
+            }
+            return new MetricTrackingFnDataReceiver(pcId, consumerAndMetadatas.get(0));
+          } else {
+            /* TODO(SDF), Consider supporting splitting each consumer individually. This would never come up in the existing SDF expansion, but might be useful to support fused SDF nodes. This would require dedicated delivery of the split results to each of the consumers separately. */
+            return new MultiplexingMetricTrackingFnDataReceiver(pcId, consumerAndMetadatas);
+          }
+        });
   }
 
   /** @return Execution Time MonitoringInfos based on the tracked start or finish function. */
@@ -146,11 +162,141 @@
     return executionStates.getExecutionTimeMonitoringInfos();
   }
 
+  /** @return the underlying consumers for a pCollectionId, some tests may wish to check this. */
+  @VisibleForTesting
+  public List<FnDataReceiver> getUnderlyingConsumers(String pCollectionId) {
+    return Lists.transform(
+        pCollectionIdsToConsumers.get(pCollectionId), input -> input.getConsumer());
+  }
+
   /**
-   * @return the number of underlying consumers for a pCollectionId, some tests may wish to check
-   *     this.
+   * A wrapping {@code FnDataReceiver<WindowedValue<T>>} which counts the number of elements
+   * consumed by the original {@code FnDataReceiver<WindowedValue<T>> consumer} and sets up metrics
+   * for tracking PTransform processing time.
+   *
+   * @param <T> - The receiving type of the PTransform.
    */
-  public List<FnDataReceiver<WindowedValue<?>>> getUnderlyingConsumers(String pCollectionId) {
-    return pCollectionIdsToConsumers.get(pCollectionId);
+  private class MetricTrackingFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>> {
+    private final FnDataReceiver<WindowedValue<T>> delegate;
+    private final String pTransformId;
+    private final SimpleExecutionState state;
+    private final Counter counter;
+    private final MetricsContainer unboundMetricContainer;
+
+    public MetricTrackingFnDataReceiver(
+        String pCollectionId, ConsumerAndMetadata consumerAndMetadata) {
+      this.delegate = consumerAndMetadata.getConsumer();
+      this.state = consumerAndMetadata.getExecutionState();
+      this.pTransformId = consumerAndMetadata.getPTransformId();
+      HashMap<String, String> labels = new HashMap<String, String>();
+      labels.put(Labels.PCOLLECTION, pCollectionId);
+      MonitoringInfoMetricName metricName =
+          MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.ELEMENT_COUNT, labels);
+      this.counter = LabeledMetrics.counter(metricName);
+      // Collect the metric in a metric container which is not bound to the step name.
+      // This is required to count elements from impulse steps, which will produce elements outside
+      // of a pTransform context.
+      this.unboundMetricContainer = metricsContainerRegistry.getUnboundContainer();
+    }
+
+    @Override
+    public void accept(WindowedValue<T> input) throws Exception {
+      try (Closeable close =
+          MetricsEnvironment.scopedMetricsContainer(this.unboundMetricContainer)) {
+        // Increment the counter for each window the element occurs in.
+        this.counter.inc(input.getWindows().size());
+
+        // Wrap the consumer with extra logic to set the metric container with the appropriate
+        // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
+        // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
+        // Process Bundle Execution time metric.
+        MetricsContainerImpl container = metricsContainerRegistry.getContainer(pTransformId);
+        try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(container)) {
+          try (Closeable trackerCloseable = stateTracker.enterState(state)) {
+            this.delegate.accept(input);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * A wrapping {@code FnDataReceiver<WindowedValue<T>>} which counts the number of elements
+   * consumed by the original {@code FnDataReceiver<WindowedValue<T>> consumers} and sets up metrics
+   * for tracking PTransform processing time.
+   *
+   * @param <T> - The receiving type of the PTransform.
+   */
+  private class MultiplexingMetricTrackingFnDataReceiver<T>
+      implements FnDataReceiver<WindowedValue<T>> {
+    private final List<ConsumerAndMetadata> consumerAndMetadatas;
+    private final Counter counter;
+    private final MetricsContainer unboundMetricContainer;
+
+    public MultiplexingMetricTrackingFnDataReceiver(
+        String pCollectionId, List<ConsumerAndMetadata> consumerAndMetadatas) {
+      this.consumerAndMetadatas = consumerAndMetadatas;
+      HashMap<String, String> labels = new HashMap<String, String>();
+      labels.put(Labels.PCOLLECTION, pCollectionId);
+      MonitoringInfoMetricName metricName =
+          MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.ELEMENT_COUNT, labels);
+      this.counter = LabeledMetrics.counter(metricName);
+      // Collect the metric in a metric container which is not bound to the step name.
+      // This is required to count elements from impulse steps, which will produce elements outside
+      // of a pTransform context.
+      this.unboundMetricContainer = metricsContainerRegistry.getUnboundContainer();
+    }
+
+    @Override
+    public void accept(WindowedValue<T> input) throws Exception {
+      try (Closeable close =
+          MetricsEnvironment.scopedMetricsContainer(this.unboundMetricContainer)) {
+        // Increment the counter for each window the element occurs in.
+        this.counter.inc(input.getWindows().size());
+
+        // Wrap the consumer with extra logic to set the metric container with the appropriate
+        // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
+        // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
+        // Process Bundle Execution time metric.
+        for (ConsumerAndMetadata consumerAndMetadata : consumerAndMetadatas) {
+          MetricsContainerImpl container =
+              metricsContainerRegistry.getContainer(consumerAndMetadata.getPTransformId());
+          try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(container)) {
+            try (Closeable trackerCloseable =
+                stateTracker.enterState(consumerAndMetadata.getExecutionState())) {
+              consumerAndMetadata.getConsumer().accept(input);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * A wrapping {@code FnDataReceiver<WindowedValue<T>>} which counts the number of elements
+   * consumed by the original {@code FnDataReceiver<WindowedValue<T>> consumer} and forwards split
+   * and progress requests to the original consumer.
+   *
+   * @param <T> - The receiving type of the PTransform.
+   */
+  private class SplittingMetricTrackingFnDataReceiver<T> extends MetricTrackingFnDataReceiver<T>
+      implements HandlesSplits {
+    private final HandlesSplits delegate;
+
+    public SplittingMetricTrackingFnDataReceiver(
+        String pCollection, ConsumerAndMetadata consumerAndMetadata) {
+      super(pCollection, consumerAndMetadata);
+      this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer();
+    }
+
+    @Override
+    public SplitResult trySplit(double fractionOfRemainder) {
+      return delegate.trySplit(fractionOfRemainder);
+    }
+
+    @Override
+    public double getProgress() {
+      return delegate.getProgress();
+    }
   }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
index 01cf0c7..8fae554 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
@@ -21,7 +21,6 @@
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
 import org.apache.beam.sdk.coders.Coder;
@@ -96,7 +95,7 @@
    *
    * <p>This method is NOT thread safe. This should only be invoked by a single thread, and is
    * intended for use with a newly constructed QueueingBeamFnDataClient in {@link
-   * ProcessBundleHandler#processBundle(InstructionRequest)}.
+   * ProcessBundleHandler#processBundle}.
    */
   public void drainAndBlock() throws Exception {
     while (true) {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
index e6e5297..c27400f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
@@ -184,7 +184,7 @@
             null /* pipelineOptions */,
             null /* beamFnDataClient */,
             null /* beamFnStateClient */,
-            null /* pTransformId */,
+            "ptransform",
             PTransform.newBuilder()
                 .putInputs("in", "input")
                 .putOutputs("out", "output")
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index e8a814c..39e1f9e 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -21,9 +21,11 @@
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyDouble;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -39,12 +41,18 @@
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.fn.harness.HandlesSplits.SplitResult;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
 import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest.DesiredSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse.ChannelSplit;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
@@ -65,7 +73,6 @@
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
@@ -203,10 +210,8 @@
     when(mockBeamFnDataClient.receive(any(), any(), any(), any()))
         .thenReturn(bundle1Future)
         .thenReturn(bundle2Future);
-    List<WindowedValue<String>> valuesA = new ArrayList<>();
-    List<WindowedValue<String>> valuesB = new ArrayList<>();
-    FnDataReceiver<WindowedValue<String>> consumers =
-        MultiplexingFnDataReceiver.forConsumers(ImmutableList.of(valuesA::add, valuesB::add));
+    List<WindowedValue<String>> values = new ArrayList<>();
+    FnDataReceiver<WindowedValue<String>> consumers = values::add;
     AtomicReference<String> bundleId = new AtomicReference<>("0");
     BeamFnDataReadRunner<String> readRunner =
         new BeamFnDataReadRunner<>(
@@ -245,13 +250,11 @@
 
     readRunner.blockTillReadFinishes();
     future.get();
-    assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
-    assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
+    assertThat(values, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
 
     // Process for bundle id 1
     bundleId.set("1");
-    valuesA.clear();
-    valuesB.clear();
+    values.clear();
     readRunner.registerInputLocation();
 
     verify(mockBeamFnDataClient)
@@ -278,8 +281,7 @@
 
     readRunner.blockTillReadFinishes();
     future.get();
-    assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
-    assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
+    assertThat(values, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
 
     verifyNoMoreInteractions(mockBeamFnDataClient);
   }
@@ -296,4 +298,185 @@
     }
     fail("Expected registrar not found.");
   }
+
+  @Test
+  public void testSplittingWhenNoElementsProcessed() throws Exception {
+    List<WindowedValue<String>> outputValues = new ArrayList<>();
+    BeamFnDataReadRunner<String> readRunner = createReadRunner(outputValues::add);
+
+    ProcessBundleSplitRequest request =
+        ProcessBundleSplitRequest.newBuilder()
+            .putDesiredSplits(
+                "pTransformId",
+                DesiredSplit.newBuilder()
+                    .setEstimatedInputElements(10)
+                    .setFractionOfRemainder(0.5)
+                    .build())
+            .build();
+    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
+    readRunner.split(request, responseBuilder);
+
+    ProcessBundleSplitResponse expected =
+        ProcessBundleSplitResponse.newBuilder()
+            .addChannelSplits(
+                ChannelSplit.newBuilder()
+                    .setLastPrimaryElement(4)
+                    .setFirstResidualElement(5)
+                    .build())
+            .build();
+    assertEquals(expected, responseBuilder.build());
+
+    // Ensure that we process the correct number of elements after splitting.
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("D"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("E"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("F"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("G"));
+    assertThat(
+        outputValues,
+        contains(
+            valueInGlobalWindow("A"),
+            valueInGlobalWindow("B"),
+            valueInGlobalWindow("C"),
+            valueInGlobalWindow("D"),
+            valueInGlobalWindow("E")));
+  }
+
+  @Test
+  public void testSplittingWhenSomeElementsProcessed() throws Exception {
+    List<WindowedValue<String>> outputValues = new ArrayList<>();
+    BeamFnDataReadRunner<String> readRunner = createReadRunner(outputValues::add);
+
+    ProcessBundleSplitRequest request =
+        ProcessBundleSplitRequest.newBuilder()
+            .putDesiredSplits(
+                "pTransformId",
+                DesiredSplit.newBuilder()
+                    .setEstimatedInputElements(10)
+                    .setFractionOfRemainder(0.5)
+                    .build())
+            .build();
+    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
+
+    // Process 2 elements then split
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
+    readRunner.split(request, responseBuilder);
+
+    ProcessBundleSplitResponse expected =
+        ProcessBundleSplitResponse.newBuilder()
+            .addChannelSplits(
+                ChannelSplit.newBuilder()
+                    .setLastPrimaryElement(5)
+                    .setFirstResidualElement(6)
+                    .build())
+            .build();
+    assertEquals(expected, responseBuilder.build());
+
+    // Ensure that we process the correct number of elements after splitting.
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("D"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("E"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("F"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("G"));
+    assertThat(
+        outputValues,
+        contains(
+            valueInGlobalWindow("A"),
+            valueInGlobalWindow("B"),
+            valueInGlobalWindow("C"),
+            valueInGlobalWindow("D"),
+            valueInGlobalWindow("E"),
+            valueInGlobalWindow("F")));
+  }
+
+  @Test
+  public void testSplittingDownstreamReceiver() throws Exception {
+    SplitResult splitResult =
+        SplitResult.of(
+            BundleApplication.newBuilder().setInputId("primary").build(),
+            DelayedBundleApplication.newBuilder()
+                .setApplication(BundleApplication.newBuilder().setInputId("residual").build())
+                .build());
+    SplittingReceiver splittingReceiver = mock(SplittingReceiver.class);
+    when(splittingReceiver.getProgress()).thenReturn(0.3);
+    when(splittingReceiver.trySplit(anyDouble())).thenReturn(splitResult);
+    BeamFnDataReadRunner<String> readRunner = createReadRunner(splittingReceiver);
+
+    ProcessBundleSplitRequest request =
+        ProcessBundleSplitRequest.newBuilder()
+            .putDesiredSplits(
+                "pTransformId",
+                DesiredSplit.newBuilder()
+                    .setEstimatedInputElements(10)
+                    .setFractionOfRemainder(0.05)
+                    .build())
+            .build();
+    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
+
+    // We will be "processing" the 'C' element, aka 2nd index
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
+    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
+    readRunner.split(request, responseBuilder);
+
+    ProcessBundleSplitResponse expected =
+        ProcessBundleSplitResponse.newBuilder()
+            .addPrimaryRoots(splitResult.getPrimaryRoot())
+            .addResidualRoots(splitResult.getResidualRoot())
+            .addChannelSplits(
+                ChannelSplit.newBuilder()
+                    .setLastPrimaryElement(1)
+                    .setFirstResidualElement(3)
+                    .build())
+            .build();
+    assertEquals(expected, responseBuilder.build());
+  }
+
+  private abstract static class SplittingReceiver
+      implements FnDataReceiver<WindowedValue<String>>, HandlesSplits {}
+
+  private BeamFnDataReadRunner<String> createReadRunner(
+      FnDataReceiver<WindowedValue<String>> consumer) throws Exception {
+    String bundleId = "57";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    String localOutputId = "outputPC";
+    String pTransformId = "pTransformId";
+    consumers.register(localOutputId, pTransformId, consumer);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+    RunnerApi.PTransform pTransform =
+        RemoteGrpcPortRead.readFromPort(PORT_SPEC, localOutputId).toPTransform();
+
+    return new BeamFnDataReadRunner.Factory<String>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            mockBeamFnDataClient,
+            null /* beamFnStateClient */,
+            pTransformId,
+            pTransform,
+            Suppliers.ofInstance(bundleId)::get,
+            ImmutableMap.of(
+                localOutputId,
+                RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()),
+            COMPONENTS.getCodersMap(),
+            COMPONENTS.getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            null /* splitListener */);
+  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiverTest.java
deleted file mode 100644
index ace16c6..0000000
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/ElementCountFnDataReceiverTest.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.fn.harness.data;
-
-import static junit.framework.TestCase.assertEquals;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.withSettings;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
-
-import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
-import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
-import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-/** Tests for {@link ElementCountFnDataReceiver}. */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(MetricsEnvironment.class)
-public class ElementCountFnDataReceiverTest {
-  /**
-   * Test that the elements are counted, and a MonitoringInfo can be extracted from a
-   * metricsContainer, if it is in scope.
-   *
-   * @throws Exception
-   */
-  @Test
-  public void testCountsElements() throws Exception {
-    final String pCollectionA = "pCollectionA";
-
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-
-    FnDataReceiver<WindowedValue<String>> consumer = mock(FnDataReceiver.class);
-    ElementCountFnDataReceiver<String> wrapperConsumer =
-        new ElementCountFnDataReceiver(consumer, pCollectionA, metricsContainerRegistry);
-    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("elem");
-    int numElements = 20;
-    for (int i = 0; i < numElements; i++) {
-      wrapperConsumer.accept(element);
-    }
-    verify(consumer, times(numElements)).accept(element);
-
-    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
-    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
-    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
-    builder.setInt64Value(numElements);
-    MonitoringInfo expected = builder.build();
-
-    // Clear the timestamp before comparison.
-    MonitoringInfo first = metricsContainerRegistry.getMonitoringInfos().iterator().next();
-    MonitoringInfo result = SimpleMonitoringInfoBuilder.copyAndClearTimestamp(first);
-    assertEquals(expected, result);
-  }
-
-  @Test
-  public void testScopedMetricContainerInvokedUponAccept() throws Exception {
-    mockStatic(MetricsEnvironment.class, withSettings().verboseLogging());
-    final String pCollectionA = "pCollectionA";
-
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-
-    FnDataReceiver<WindowedValue<String>> consumer =
-        mock(FnDataReceiver.class, withSettings().verboseLogging());
-    ElementCountFnDataReceiver<String> wrapperConsumer =
-        new ElementCountFnDataReceiver(consumer, pCollectionA, metricsContainerRegistry);
-    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("elem");
-    wrapperConsumer.accept(element);
-
-    verify(consumer, times(1)).accept(element);
-
-    // Verify that static scopedMetricsContainer is called with unbound container.
-    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
-    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getUnboundContainer());
-  }
-}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiverTest.java
deleted file mode 100644
index f3e21c2..0000000
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/MultiplexingFnDataReceiverTest.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.fn.harness.data;
-
-import static org.hamcrest.Matchers.contains;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.assertThat;
-
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** Tests for {@link MultiplexingFnDataReceiver}. */
-@RunWith(JUnit4.class)
-public class MultiplexingFnDataReceiverTest {
-  @Rule public ExpectedException thrown = ExpectedException.none();
-
-  @Test
-  public void singleConsumer() throws Exception {
-    List<String> consumer = new ArrayList<>();
-    FnDataReceiver<String> multiplexer =
-        MultiplexingFnDataReceiver.forConsumers(
-            ImmutableList.<FnDataReceiver<String>>of(consumer::add));
-
-    multiplexer.accept("foo");
-    multiplexer.accept("bar");
-
-    assertThat(consumer, contains("foo", "bar"));
-  }
-
-  @Test
-  public void singleConsumerException() throws Exception {
-    String message = "my_exception";
-    FnDataReceiver<Integer> multiplexer =
-        MultiplexingFnDataReceiver.forConsumers(
-            ImmutableList.<FnDataReceiver<Integer>>of(
-                (Integer i) -> {
-                  if (i > 1) {
-                    throw new Exception(message);
-                  }
-                }));
-
-    multiplexer.accept(0);
-    multiplexer.accept(1);
-    thrown.expectMessage(message);
-    thrown.expect(Exception.class);
-    multiplexer.accept(2);
-  }
-
-  @Test
-  public void multipleConsumers() throws Exception {
-    List<String> consumer = new ArrayList<>();
-    Set<String> otherConsumer = new HashSet<>();
-    FnDataReceiver<String> multiplexer =
-        MultiplexingFnDataReceiver.forConsumers(
-            ImmutableList.<FnDataReceiver<String>>of(consumer::add, otherConsumer::add));
-
-    multiplexer.accept("foo");
-    multiplexer.accept("bar");
-    multiplexer.accept("foo");
-
-    assertThat(consumer, contains("foo", "bar", "foo"));
-    assertThat(otherConsumer, containsInAnyOrder("foo", "bar"));
-  }
-
-  @Test
-  public void multipleConsumersException() throws Exception {
-    String message = "my_exception";
-    List<Integer> consumer = new ArrayList<>();
-    FnDataReceiver<Integer> multiplexer =
-        MultiplexingFnDataReceiver.forConsumers(
-            ImmutableList.<FnDataReceiver<Integer>>of(
-                consumer::add,
-                (Integer i) -> {
-                  if (i > 1) {
-                    throw new Exception(message);
-                  }
-                }));
-
-    multiplexer.accept(0);
-    multiplexer.accept(1);
-    assertThat(consumer, containsInAnyOrder(0, 1));
-
-    thrown.expectMessage(message);
-    thrown.expect(Exception.class);
-    multiplexer.accept(2);
-  }
-}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index e8b377b..dac1fbe 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -17,18 +17,30 @@
  */
 package org.apache.beam.fn.harness.data;
 
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.withSettings;
 import static org.powermock.api.mockito.PowerMockito.mockStatic;
 
+import org.apache.beam.fn.harness.HandlesSplits;
+import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
+import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -44,6 +56,72 @@
 
   @Rule public ExpectedException expectedException = ExpectedException.none();
 
+  @Test
+  public void singleConsumer() throws Exception {
+    final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
+
+    consumers.register(pCollectionA, pTransformIdA, consumerA1);
+
+    FnDataReceiver<WindowedValue<String>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<String>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+
+    WindowedValue<String> element = valueInGlobalWindow("elem");
+    int numElements = 20;
+    for (int i = 0; i < numElements; i++) {
+      wrapperConsumer.accept(element);
+    }
+
+    // Check that the underlying consumers are each invoked per element.
+    verify(consumerA1, times(numElements)).accept(element);
+    assertThat(consumers.keySet(), contains(pCollectionA));
+
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    builder.setInt64Value(numElements);
+    MonitoringInfo expected = builder.build();
+
+    // Clear the timestamp before comparison.
+    MonitoringInfo pCollectionCount =
+        Iterables.find(
+            metricsContainerRegistry.getMonitoringInfos(),
+            monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+    MonitoringInfo result = SimpleMonitoringInfoBuilder.copyAndClearTimestamp(pCollectionCount);
+    assertEquals(expected, result);
+  }
+
+  @Test
+  public void singleConsumerException() throws Exception {
+    final String pCollectionA = "pCollectionA";
+    final String pTransformId = "pTransformId";
+    final String message = "testException";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<String>> consumer = mock(FnDataReceiver.class);
+
+    consumers.register(pCollectionA, pTransformId, consumer);
+
+    FnDataReceiver<WindowedValue<String>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<String>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+    doThrow(new Exception(message)).when(consumer).accept(any());
+
+    expectedException.expectMessage(message);
+    expectedException.expect(Exception.class);
+    wrapperConsumer.accept(valueInGlobalWindow("elem"));
+  }
+
   /**
    * Test that the counter increments only once when multiple consumers of same pCollection read the
    * same element.
@@ -51,7 +129,54 @@
   @Test
   public void multipleConsumersSamePCollection() throws Exception {
     final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+    final String pTransformIdB = "pTransformIdB";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
+    FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class);
+
+    consumers.register(pCollectionA, pTransformIdA, consumerA1);
+    consumers.register(pCollectionA, pTransformIdB, consumerA2);
+
+    FnDataReceiver<WindowedValue<String>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<String>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+
+    WindowedValue<String> element = valueInGlobalWindow("elem");
+    int numElements = 20;
+    for (int i = 0; i < numElements; i++) {
+      wrapperConsumer.accept(element);
+    }
+
+    // Check that the underlying consumers are each invoked per element.
+    verify(consumerA1, times(numElements)).accept(element);
+    verify(consumerA2, times(numElements)).accept(element);
+    assertThat(consumers.keySet(), contains(pCollectionA));
+
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    builder.setInt64Value(numElements);
+    MonitoringInfo expected = builder.build();
+
+    // Clear the timestamp before comparison.
+    MonitoringInfo pCollectionCount =
+        Iterables.find(
+            metricsContainerRegistry.getMonitoringInfos(),
+            monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+    MonitoringInfo result = SimpleMonitoringInfoBuilder.copyAndClearTimestamp(pCollectionCount);
+    assertEquals(expected, result);
+  }
+
+  @Test
+  public void multipleConsumersSamePCollectionException() throws Exception {
+    final String pCollectionA = "pCollectionA";
     final String pTransformId = "pTransformId";
+    final String message = "testException";
 
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
     PCollectionConsumerRegistry consumers =
@@ -66,17 +191,11 @@
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
             (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+    doThrow(new Exception(message)).when(consumerA2).accept(any());
 
-    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("elem");
-    int numElements = 20;
-    for (int i = 0; i < numElements; i++) {
-      wrapperConsumer.accept(element);
-    }
-
-    // Check that the underlying consumers are each invoked per element.
-    verify(consumerA1, times(numElements)).accept(element);
-    verify(consumerA2, times(numElements)).accept(element);
-    assertThat(consumers.keySet(), contains(pCollectionA));
+    expectedException.expectMessage(message);
+    expectedException.expect(Exception.class);
+    wrapperConsumer.accept(valueInGlobalWindow("elem"));
   }
 
   @Test
@@ -118,7 +237,7 @@
         (FnDataReceiver<WindowedValue<String>>)
             (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
 
-    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("elem");
+    WindowedValue<String> element = valueInGlobalWindow("elem");
     wrapperConsumer.accept(element);
 
     // Verify that static scopedMetricsContainer is called with pTransformA's container.
@@ -129,4 +248,61 @@
     PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
     MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getContainer("pTransformB"));
   }
+
+  @Test
+  public void testScopedMetricContainerInvokedUponAccept() throws Exception {
+    mockStatic(MetricsEnvironment.class, withSettings().verboseLogging());
+    final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<String>> consumer =
+        mock(FnDataReceiver.class, withSettings().verboseLogging());
+
+    consumers.register(pCollectionA, pTransformIdA, consumer);
+
+    FnDataReceiver<WindowedValue<String>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<String>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+
+    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("elem");
+    wrapperConsumer.accept(element);
+
+    verify(consumer, times(1)).accept(element);
+
+    // Verify that static scopedMetricsContainer is called with unbound container.
+    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
+    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getUnboundContainer());
+  }
+
+  @Test
+  public void testHandlesSplitsPassedToOriginalConsumer() throws Exception {
+    final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    SplittingReceiver consumerA1 = mock(SplittingReceiver.class);
+
+    consumers.register(pCollectionA, pTransformIdA, consumerA1);
+
+    FnDataReceiver<WindowedValue<String>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<String>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+
+    assertTrue(wrapperConsumer instanceof HandlesSplits);
+
+    ((HandlesSplits) wrapperConsumer).getProgress();
+    verify(consumerA1).getProgress();
+
+    ((HandlesSplits) wrapperConsumer).trySplit(0.3);
+    verify(consumerA1).trySplit(0.3);
+  }
+
+  private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
 }