Merge pull request #15492 from rohdesamuel/hotkeyflake

[BEAM-12842] Add timestamp to test work item to deflake
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 4cdf4ae..67fe8b8 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -446,7 +446,7 @@
     def errorprone_version = "2.3.4"
     def google_clients_version = "1.32.1"
     def google_cloud_bigdataoss_version = "2.2.2"
-    def google_cloud_pubsublite_version = "0.13.2"
+    def google_cloud_pubsublite_version = "1.0.4"
     def google_code_gson_version = "2.8.6"
     def google_oauth_clients_version = "1.31.0"
     // Try to keep grpc_version consistent with gRPC version in google_cloud_platform_libraries_bom
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
new file mode 100644
index 0000000..5496d8b
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statecache implements the state caching feature described by the
+// Beam Fn API.
+//
+// The Beam State API and the intended caching behavior are described here:
+// https://docs.google.com/document/d/1BOozW0bzBuz4oHJEuZNDOHdzaV5Y56ix58Ozrqm2jFg/edit#heading=h.7ghoih5aig5m
+package statecache
+
+import (
+	"sync"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+type token string
+
+// SideInputCache stores a cache of reusable inputs for the purposes of
+// eliminating redundant calls to the runner during execution of ParDos
+// using side inputs.
+//
+// A SideInputCache should be initialized when the SDK harness is initialized,
+// creating storage for side input caching. On each ProcessBundleRequest,
+// the cache will process the list of tokens for cacheable side inputs and
+// be queried when side inputs are requested during bundle execution. Once a
+// new bundle request comes in, the valid tokens will be updated and the
+// cache will be reused. If the cache reaches capacity, a random, currently
+// invalid cached object will be evicted.
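+//
+// A minimal lifecycle sketch (identifiers such as tokens and fetchSideInput
+// below are illustrative only, not part of this package):
+//
+//	var cache SideInputCache
+//	cache.Init(64)                  // once, when the harness starts
+//	cache.SetValidTokens(tokens...) // at the start of each ProcessBundleRequest
+//	input := cache.QueryCache("transformID", "sideInputID")
+//	if input == nil {
+//		input = fetchSideInput() // assumed runner fetch on a cache miss
+//		cache.SetCache("transformID", "sideInputID", input)
+//	}
+//	cache.CompleteBundle(tokens...) // once the bundle finishes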
+type SideInputCache struct {
+	capacity    int
+	mu          sync.Mutex
+	cache       map[token]exec.ReusableInput
+	idsToTokens map[string]token
+	validTokens map[token]int8 // Maps tokens to active bundle counts
+	metrics     CacheMetrics
+}
+
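+// CacheMetrics tracks cache statistics: hits and misses on lookup, evictions
+// of invalid entries (Evictions), and evictions of still-valid entries forced
+// by capacity pressure (InUseEvictions).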
+type CacheMetrics struct {
+	Hits           int64
+	Misses         int64
+	Evictions      int64
+	InUseEvictions int64
+}
+
+// Init makes the cache map and the map of IDs to cache tokens for the
+// SideInputCache. Should only be called once. Returns an error for
+// non-positive capacities.
+func (c *SideInputCache) Init(cap int) error {
+	if cap <= 0 {
+		return errors.Errorf("capacity must be a positive integer, got %v", cap)
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.cache = make(map[token]exec.ReusableInput, cap)
+	c.idsToTokens = make(map[string]token)
+	c.validTokens = make(map[token]int8)
+	c.capacity = cap
+	return nil
+}
+
+// SetValidTokens clears the list of valid tokens then sets new ones, also updating the mapping of
+// transform and side input IDs to cache tokens in the process. Should be called at the start of every
+// new ProcessBundleRequest. If the runner does not support caching, the passed cache token values
+// should be empty and all get/set requests will silently be no-ops.
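+//
+// For example, if two in-flight bundles both carry the token "tok1", its
+// count in validTokens rises to 2 and entries cached under it remain valid
+// until both bundles have called CompleteBundle.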
+func (c *SideInputCache) SetValidTokens(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for _, tok := range cacheTokens {
+		// User State caching is currently not supported, so these tokens are ignored
+		if tok.GetUserState() != nil {
+			continue
+		}
+		s := tok.GetSideInput()
+		transformID := s.GetTransformId()
+		sideInputID := s.GetSideInputId()
+		t := token(tok.GetToken())
+		c.setValidToken(transformID, sideInputID, t)
+	}
+}
+
+// setValidToken adds a new valid token for a request into the SideInputCache struct
+// by mapping the transform ID and side input ID pairing to the cache token.
+func (c *SideInputCache) setValidToken(transformID, sideInputID string, tok token) {
+	idKey := transformID + sideInputID
+	c.idsToTokens[idKey] = tok
+	count, ok := c.validTokens[tok]
+	if !ok {
+		c.validTokens[tok] = 1
+	} else {
+		c.validTokens[tok] = count + 1
+	}
+}
+
+// CompleteBundle takes the cache tokens that were passed to SetValidTokens and
+// decrements their usage counts, keeping an accurate record of whether each
+// value is still in use. Should be called once ProcessBundle has completed.
+func (c *SideInputCache) CompleteBundle(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for _, tok := range cacheTokens {
+		// User State caching is currently not supported, so these tokens are ignored
+		if tok.GetUserState() != nil {
+			continue
+		}
+		t := token(tok.GetToken())
+		c.decrementTokenCount(t)
+	}
+}
+
+// decrementTokenCount decrements the validTokens entry for
+// a given token by 1. Should only be called when completing
+// a bundle.
+func (c *SideInputCache) decrementTokenCount(tok token) {
+	count := c.validTokens[tok]
+	if count == 1 {
+		delete(c.validTokens, tok)
+	} else {
+		c.validTokens[tok] = count - 1
+	}
+}
+
+func (c *SideInputCache) makeAndValidateToken(transformID, sideInputID string) (token, bool) {
+	idKey := transformID + sideInputID
+	// Check if it's a known token
+	tok, ok := c.idsToTokens[idKey]
+	if !ok {
+		return "", false
+	}
+	return tok, c.isValid(tok)
+}
+
+// QueryCache takes a transform ID and side input ID and checks whether a corresponding side
+// input has been cached. A query with a bad token (e.g. one that doesn't map to a known
+// token, or one that maps to a known but currently invalid token) is treated the same as a
+// cache miss.
+func (c *SideInputCache) QueryCache(transformID, sideInputID string) exec.ReusableInput {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+	if !ok {
+		return nil
+	}
+	// Check to see if cached
+	input, ok := c.cache[tok]
+	if !ok {
+		c.metrics.Misses++
+		return nil
+	}
+
+	c.metrics.Hits++
+	return input
+}
+
+// SetCache allows a user to place a ReusableInput materialized from the reader into the SideInputCache
+// with its corresponding transform ID and side input ID. If the IDs do not pair with a known, valid token
+// then we silently do not cache the input, as this is an indication that the runner is treating that input
+// as uncacheable.
+func (c *SideInputCache) SetCache(transformID, sideInputID string, input exec.ReusableInput) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+	if !ok {
+		return
+	}
+	if len(c.cache) >= c.capacity {
+		c.evictElement()
+	}
+	c.cache[tok] = input
+}
+
+func (c *SideInputCache) isValid(tok token) bool {
+	count, ok := c.validTokens[tok]
+	// If the token is not known or not in use, return false
+	return ok && count > 0
+}
+
+// evictElement randomly evicts a ReusableInput that is not currently valid from the cache.
+// It should only be called by a goroutine that obtained the lock in SetCache.
+func (c *SideInputCache) evictElement() {
+	deleted := false
+	// Select a key from the cache at random
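+	// (Go's map iteration order is unspecified, which supplies the randomness.)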
+	for k := range c.cache {
+		// Do not evict an element if it's currently valid
+		if !c.isValid(k) {
+			delete(c.cache, k)
+			c.metrics.Evictions++
+			deleted = true
+			break
+		}
+	}
+	// Nothing is deleted if every side input is still valid. In that case,
+	// clear out a random entry and record the in-use eviction.
+	if !deleted {
+		for k := range c.cache {
+			delete(c.cache, k)
+			c.metrics.InUseEvictions++
+			break
+		}
+	}
+}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
new file mode 100644
index 0000000..b9970c3
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statecache
+
+import (
+	"testing"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+// TestReusableInput implements the ReusableInput interface for the purposes
+// of testing.
+type TestReusableInput struct {
+	transformID string
+	sideInputID string
+	value       interface{}
+}
+
+func makeTestReusableInput(transformID, sideInputID string, value interface{}) exec.ReusableInput {
+	return &TestReusableInput{transformID: transformID, sideInputID: sideInputID, value: value}
+}
+
+// Init is a ReusableInput interface method; this test implementation is a no-op.
+func (r *TestReusableInput) Init() error {
+	return nil
+}
+
+// Value returns the stored value in the TestReusableInput.
+func (r *TestReusableInput) Value() interface{} {
+	return r.value
+}
+
+// Reset clears the value in the TestReusableInput.
+func (r *TestReusableInput) Reset() error {
+	r.value = nil
+	return nil
+}
+
+func TestInit(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(5)
+	if err != nil {
+		t.Errorf("SideInputCache failed but should have succeeded, got %v", err)
+	}
+}
+
+func TestInit_Bad(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(0)
+	if err == nil {
+		t.Error("SideInputCache init succeeded but should have failed")
+	}
+}
+
+func TestQueryCache_EmptyCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	output := s.QueryCache("transform1", "side1")
+	if output != nil {
+		t.Errorf("Cache hit when it should have missed, got %v", output)
+	}
+}
+
+func TestSetCache_UncacheableCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	input := makeTestReusableInput("t1", "s1", 10)
+	s.SetCache("t1", "s1", input)
+	output := s.QueryCache("t1", "s1")
+	if output != nil {
+		t.Errorf("Cache hit when should have missed, got %v", output)
+	}
+}
+
+func TestSetCache_CacheableCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	transID := "t1"
+	sideID := "s1"
+	tok := token("tok1")
+	s.setValidToken(transID, sideID, tok)
+	input := makeTestReusableInput(transID, sideID, 10)
+	s.SetCache(transID, sideID, input)
+	output := s.QueryCache(transID, sideID)
+	if output == nil {
+		t.Fatalf("call to query cache missed when should have hit")
+	}
+	val, ok := output.Value().(int)
+	if !ok {
+		t.Errorf("failed to convert value to integer, got %v", output.Value())
+	}
+	if val != 10 {
+		t.Errorf("element mismatch, expected 10, got %v", val)
+	}
+}
+
+func makeRequest(transformID, sideInputID string, t token) fnpb.ProcessBundleRequest_CacheToken {
+	var tok fnpb.ProcessBundleRequest_CacheToken
+	var wrap fnpb.ProcessBundleRequest_CacheToken_SideInput_
+	var side fnpb.ProcessBundleRequest_CacheToken_SideInput
+	side.TransformId = transformID
+	side.SideInputId = sideInputID
+	wrap.SideInput = &side
+	tok.Type = &wrap
+	tok.Token = []byte(t)
+	return tok
+}
+
+func TestSetValidTokens(t *testing.T) {
+	inputs := []struct {
+		transformID string
+		sideInputID string
+		tok         token
+	}{
+		{
+			"t1",
+			"s1",
+			"tok1",
+		},
+		{
+			"t2",
+			"s2",
+			"tok2",
+		},
+		{
+			"t3",
+			"s3",
+			"tok3",
+		},
+	}
+
+	var s SideInputCache
+	err := s.Init(3)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	var tokens []fnpb.ProcessBundleRequest_CacheToken
+	for _, input := range inputs {
+		t := makeRequest(input.transformID, input.sideInputID, input.tok)
+		tokens = append(tokens, t)
+	}
+
+	s.SetValidTokens(tokens...)
+	if len(s.idsToTokens) != len(inputs) {
+		t.Errorf("Missing tokens, expected %v, got %v", len(inputs), len(s.idsToTokens))
+	}
+
+	for i, input := range inputs {
+		// Check that the token is in the valid list
+		if !s.isValid(input.tok) {
+			t.Errorf("error in input %v, token %v is not valid", i, input.tok)
+		}
+		// Check that the mapping of IDs to tokens is correct
+		mapped := s.idsToTokens[input.transformID+input.sideInputID]
+		if mapped != input.tok {
+			t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tok, mapped)
+		}
+	}
+}
+
+func TestSetValidTokens_ClearingBetween(t *testing.T) {
+	inputs := []struct {
+		transformID string
+		sideInputID string
+		tk          token
+	}{
+		{
+			"t1",
+			"s1",
+			"tok1",
+		},
+		{
+			"t2",
+			"s2",
+			"tok2",
+		},
+		{
+			"t3",
+			"s3",
+			"tok3",
+		},
+	}
+
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	for i, input := range inputs {
+		tok := makeRequest(input.transformID, input.sideInputID, input.tk)
+
+		s.SetValidTokens(tok)
+
+		// Check that the token is in the valid list
+		if !s.isValid(input.tk) {
+			t.Errorf("error in input %v, token %v is not valid", i, input.tk)
+		}
+		// Check that the mapping of IDs to tokens is correct
+		mapped := s.idsToTokens[input.transformID+input.sideInputID]
+		if mapped != input.tk {
+			t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tk, mapped)
+		}
+
+		s.CompleteBundle(tok)
+	}
+
+	for k := range s.validTokens {
+		if s.validTokens[k] != 0 {
+			t.Errorf("token count mismatch for token %v, expected 0, got %v", k, s.validTokens[k])
+		}
+	}
+}
+
+func TestSetCache_Eviction(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	tokOne := makeRequest("t1", "s1", "tok1")
+	inOne := makeTestReusableInput("t1", "s1", 10)
+	s.SetValidTokens(tokOne)
+	s.SetCache("t1", "s1", inOne)
+	// Mark bundle as complete, drop count for tokOne to 0
+	s.CompleteBundle(tokOne)
+
+	tokTwo := makeRequest("t2", "s2", "tok2")
+	inTwo := makeTestReusableInput("t2", "s2", 20)
+	s.SetValidTokens(tokTwo)
+	s.SetCache("t2", "s2", inTwo)
+
+	if len(s.cache) != 1 {
+		t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+	}
+	if s.metrics.Evictions != 1 {
+		t.Errorf("number evictions incorrect, expected 1, got %v", s.metrics.Evictions)
+	}
+}
+
+func TestSetCache_EvictionFailure(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	tokOne := makeRequest("t1", "s1", "tok1")
+	inOne := makeTestReusableInput("t1", "s1", 10)
+
+	tokTwo := makeRequest("t2", "s2", "tok2")
+	inTwo := makeTestReusableInput("t2", "s2", 20)
+
+	s.SetValidTokens(tokOne, tokTwo)
+	s.SetCache("t1", "s1", inOne)
+	// Should fail to evict because the first token is still valid
+	s.SetCache("t2", "s2", inTwo)
+	// Cache should not exceed size 1
+	if len(s.cache) != 1 {
+		t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+	}
+	if s.metrics.InUseEvictions != 1 {
+		t.Errorf("number of failed evicition calls incorrect, expected 1, got %v", s.metrics.InUseEvictions)
+	}
+}
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
index 511f839..f4ab8bb 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
@@ -156,16 +156,20 @@
   }
 
   /**
-   * An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values
-   * using the specified {@link Coder}.
+   * An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T}
+   * values using the specified {@link Coder}.
    *
    * <p>Note that this adapter follows the Beam Fn API specification for forcing values that decode
    * consuming zero bytes to consuming exactly one byte.
    *
    * <p>Note that access to the underlying {@link InputStream} is lazy and will only be invoked on
-   * first access to {@link #next()} or {@link #hasNext()}.
+   * first access to {@link #next}, {@link #hasNext}, {@link #isReady}, and {@link #prefetch}.
+   *
+   * <p>Note that {@link #isReady} and {@link #prefetch} rely on non-empty {@link ByteString}s
+   * being returned by the underlying {@link PrefetchableIterator}; otherwise {@link #prefetch}
+   * will appear to make no progress even though it is actually advancing through the empty pages.
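+   *
+   * <p>A minimal decoding sketch, assuming {@code pages} is a {@code
+   * PrefetchableIterator<ByteString>} of encoded pages and {@code process} is a caller-supplied
+   * consumer:
+   *
+   * <pre>{@code
+   * DataStreamDecoder<String> decoder = new DataStreamDecoder<>(StringUtf8Coder.of(), pages);
+   * decoder.prefetch(); // request the first page without blocking
+   * while (decoder.hasNext()) {
+   *   process(decoder.next());
+   * }
+   * }</pre>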
    */
-  public static class DataStreamDecoder<T> implements Iterator<T> {
+  public static class DataStreamDecoder<T> implements PrefetchableIterator<T> {
 
     private enum State {
       READ_REQUIRED,
@@ -173,13 +177,13 @@
       EOF
     }
 
-    private final Iterator<ByteString> inputByteStrings;
+    private final PrefetchableIterator<ByteString> inputByteStrings;
     private final Inbound inbound;
     private final Coder<T> coder;
     private State currentState;
     private T next;
 
-    public DataStreamDecoder(Coder<T> coder, Iterator<ByteString> inputStream) {
+    public DataStreamDecoder(Coder<T> coder, PrefetchableIterator<ByteString> inputStream) {
       this.currentState = State.READ_REQUIRED;
       this.coder = coder;
       this.inputByteStrings = inputStream;
@@ -187,6 +191,31 @@
     }
 
     @Override
+    public boolean isReady() {
+      switch (currentState) {
+        case EOF:
+          return true;
+        case READ_REQUIRED:
+          try {
+            return inbound.isReady();
+          } catch (IOException e) {
+            throw new RuntimeException(e);
+          }
+        case HAS_NEXT:
+          return true;
+        default:
+          throw new IllegalStateException(String.format("Unknown state %s", currentState));
+      }
+    }
+
+    @Override
+    public void prefetch() {
+      if (!isReady()) {
+        inputByteStrings.prefetch();
+      }
+    }
+
+    @Override
     public boolean hasNext() {
       switch (currentState) {
         case EOF:
@@ -232,8 +261,8 @@
     private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput();
 
     /**
-     * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first
-     * {@link Iterator} on first access of this input stream.
+     * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link
+     * Iterator} on first access of this input stream.
      *
      * <p>Closing this input stream has no effect.
      */
@@ -245,6 +274,22 @@
         this.currentStream = EMPTY_STREAM;
       }
 
+      public boolean isReady() throws IOException {
+        // Note that the InputStream returned by ByteString#newInput guarantees that available()
+        // reports the remaining length of the ByteString, so it can be reliably used to tell
+        // us whether we are at the end of the stream.
+        while (currentStream.available() == 0) {
+          if (!inputByteStrings.isReady()) {
+            return false;
+          }
+          if (!inputByteStrings.hasNext()) {
+            return true;
+          }
+          currentStream = inputByteStrings.next().newInput();
+        }
+        return true;
+      }
+
       public boolean isEof() throws IOException {
         // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
         // minus the number of bytes that have been read so far and can be reliably used to tell
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
index 9dd5ee4..a8b48e8 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeTrue;
 
 import java.io.IOException;
@@ -106,7 +107,7 @@
     }
 
     @Test
-    public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
+    public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception {
       CountingOutputStream countingOutputStream =
           new CountingOutputStream(ByteStreams.nullOutputStream());
       GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream);
@@ -115,6 +116,55 @@
       testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE);
     }
 
+    @Test
+    public void testPrefetch() throws Exception {
+      List<ByteString> encodings = new ArrayList<>();
+      {
+        ByteString.Output encoding = ByteString.newOutput();
+        StringUtf8Coder.of().encode("A", encoding);
+        StringUtf8Coder.of().encode("BC", encoding);
+        encodings.add(encoding.toByteString());
+      }
+      encodings.add(ByteString.EMPTY);
+      {
+        ByteString.Output encoding = ByteString.newOutput();
+        StringUtf8Coder.of().encode("DEF", encoding);
+        StringUtf8Coder.of().encode("GHIJ", encoding);
+        encodings.add(encoding.toByteString());
+      }
+
+      PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<ByteString> iterator =
+          new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator());
+      PrefetchableIterator<String> decoder =
+          new DataStreamDecoder<>(StringUtf8Coder.of(), iterator);
+      assertFalse(decoder.isReady());
+      decoder.prefetch();
+      assertTrue(decoder.isReady());
+      assertEquals(1, iterator.getNumPrefetchCalls());
+
+      decoder.next();
+      // Now we have moved off of the empty byte array that we start with, so prefetch will
+      // do nothing since we are ready.
+      assertTrue(decoder.isReady());
+      decoder.prefetch();
+      assertEquals(1, iterator.getNumPrefetchCalls());
+
+      decoder.next();
+      // Now we are at the end of the first ByteString so we expect a prefetch to pass through
+      assertFalse(decoder.isReady());
+      decoder.prefetch();
+      assertEquals(2, iterator.getNumPrefetchCalls());
+      // We also expect the decoder to not be ready since the next byte string is empty, which
+      // would require us to move to the next page. This typically wouldn't happen in practice
+      // though, because we expect non-empty pages.
+      assertFalse(decoder.isReady());
+
+      // Prefetching will allow us to move to the third ByteString
+      decoder.prefetch();
+      assertEquals(3, iterator.getNumPrefetchCalls());
+      assertTrue(decoder.isReady());
+    }
+
     private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOException {
       ByteString.Output output = ByteString.newOutput();
       for (T value : expected) {
@@ -131,7 +181,9 @@
     }
 
     private <T> void testDecoderWith(Coder<T> coder, T[] expected, List<ByteString> encoded) {
-      Iterator<T> decoder = new DataStreamDecoder<>(coder, encoded.iterator());
+      Iterator<T> decoder =
+          new DataStreamDecoder<>(
+              coder, PrefetchableIterators.maybePrefetchable(encoded.iterator()));
 
       Object[] actual = Iterators.toArray(decoder, Object.class);
       assertArrayEquals(expected, actual);
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
index 6131634..9ada175 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
@@ -120,10 +120,14 @@
         "F");
   }
 
-  private static class NeverReady implements PrefetchableIterator<String> {
-    PrefetchableIterator<String> delegate = PrefetchableIterators.fromArray("A", "B");
+  public static class NeverReady<T> implements PrefetchableIterator<T> {
+    private final Iterator<T> delegate;
     int prefetchCalled;
 
+    public NeverReady(Iterator<T> delegate) {
+      this.delegate = delegate;
+    }
+
     @Override
     public boolean isReady() {
       return false;
@@ -140,74 +144,117 @@
     }
 
     @Override
-    public String next() {
+    public T next() {
       return delegate.next();
     }
+
+    public int getNumPrefetchCalls() {
+      return prefetchCalled;
+    }
   }
 
-  private static class ReadyAfterPrefetch extends NeverReady {
+  public static class ReadyAfterPrefetch<T> extends NeverReady<T> {
+
+    public ReadyAfterPrefetch(Iterator<T> delegate) {
+      super(delegate);
+    }
+
     @Override
     public boolean isReady() {
       return prefetchCalled > 0;
     }
   }
 
+  public static class ReadyAfterPrefetchUntilNext<T> extends ReadyAfterPrefetch<T> {
+    boolean advancedSincePrefetch;
+
+    public ReadyAfterPrefetchUntilNext(Iterator<T> delegate) {
+      super(delegate);
+    }
+
+    @Override
+    public boolean isReady() {
+      return !advancedSincePrefetch && super.isReady();
+    }
+
+    @Override
+    public void prefetch() {
+      advancedSincePrefetch = false;
+      super.prefetch();
+    }
+
+    @Override
+    public T next() {
+      advancedSincePrefetch = true;
+      return super.next();
+    }
+
+    @Override
+    public boolean hasNext() {
+      advancedSincePrefetch = true;
+      return super.hasNext();
+    }
+  }
+
   @Test
   public void testConcatIsReadyAdvancesToNextIteratorWhenAble() {
-    NeverReady readyAfterPrefetch1 = new NeverReady();
-    ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch();
-    ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch();
+    NeverReady<String> readyAfterPrefetch1 =
+        new NeverReady<>(PrefetchableIterators.fromArray("A", "B"));
+    ReadyAfterPrefetch<String> readyAfterPrefetch2 =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
+    ReadyAfterPrefetch<String> readyAfterPrefetch3 =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
 
     PrefetchableIterator<String> iterator =
         PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3);
 
     // Expect no prefetches yet
-    assertEquals(0, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
 
     // We expect to attempt to prefetch for the first time.
     iterator.prefetch();
-    assertEquals(1, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // We expect to attempt to prefetch again since we aren't ready.
     iterator.prefetch();
-    assertEquals(2, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The current iterator is done but is never ready so we can't advance to the next one and
     // expect another prefetch to go to the current iterator.
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // Now that we know the last iterator is done and have advanced to the next one we expect
     // prefetch to go through
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The last iterator is done so we should be able to prefetch the next one before advancing
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The current iterator is ready so no additional prefetch is necessary
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
   }
 
diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle
index 3c859ae..6337cd4 100644
--- a/sdks/java/harness/build.gradle
+++ b/sdks/java/harness/build.gradle
@@ -72,6 +72,7 @@
   testCompile library.java.mockito_core
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   testCompile project(":runners:core-construction-java")
+  testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime")
   shadowTestRuntimeClasspath library.java.slf4j_jdk14
   jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest")
   jmhRuntime library.java.slf4j_jdk14
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 5ddf0ae..777036a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -26,6 +26,8 @@
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 
@@ -49,7 +51,7 @@
   private final BeamFnStateClient beamFnStateClient;
   private final StateRequest request;
   private final Coder<T> valueCoder;
-  private Iterable<T> oldValues;
+  private PrefetchableIterable<T> oldValues;
   private ArrayList<T> newValues;
   private boolean isClosed;
 
@@ -80,19 +82,19 @@
     this.newValues = new ArrayList<>();
   }
 
-  public Iterable<T> get() {
+  public PrefetchableIterable<T> get() {
     checkState(
         !isClosed,
         "Bag user state is no longer usable because it is closed for %s",
         request.getStateKey());
     if (oldValues == null) {
       // If we were cleared we should disregard old values.
-      return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size());
+      return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size());
     } else if (newValues.isEmpty()) {
       // If we have no new values then just return the old values.
       return oldValues;
     }
-    return Iterables.concat(
+    return PrefetchableIterables.concat(
         oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()));
   }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 2f2789e..5a931c5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -38,7 +38,6 @@
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.OrderedListState;
 import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.StateBinder;
 import org.apache.beam.sdk.state.StateContext;
@@ -264,7 +263,7 @@
 
                   @Override
                   public ValueState<T> readLater() {
-                    // TODO(BEAM-12802): Support prefetching.
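+                    // Prefetch the backing iterable so a later read() is less likely to block.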
+                    impl.get().iterator().prefetch();
                     return this;
                   }
                 };
@@ -310,7 +309,7 @@
 
                   @Override
                   public BagState<T> readLater() {
-                    // TODO(BEAM-12802): Support prefetching.
+                    impl.get().iterator().prefetch();
                     return this;
                   }
 
@@ -391,6 +390,7 @@
 
                   @Override
                   public CombiningState<ElementT, AccumT, ResultT> readLater() {
+                    impl.get().iterator().prefetch();
                     return this;
                   }
 
@@ -412,7 +412,18 @@
 
                   @Override
                   public ReadableState<Boolean> isEmpty() {
-                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public @Nullable Boolean read() {
+                        return !impl.get().iterator().hasNext();
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        impl.get().iterator().prefetch();
+                        return this;
+                      }
+                    };
                   }
 
                   @Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index cfc76cf..7828f93 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -21,6 +21,9 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
+import java.util.Objects;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
@@ -28,30 +31,42 @@
  * Converts an iterator to an iterable lazily loading values from the underlying iterator and
  * caching them to support reiteration.
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
-class LazyCachingIteratorToIterable<T> implements Iterable<T> {
+class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
   private final List<T> cachedElements;
-  private final Iterator<T> iterator;
+  private final PrefetchableIterator<T> iterator;
 
-  public LazyCachingIteratorToIterable(Iterator<T> iterator) {
+  public LazyCachingIteratorToIterable(PrefetchableIterator<T> iterator) {
     this.cachedElements = new ArrayList<>();
     this.iterator = iterator;
   }
 
   @Override
-  public Iterator<T> iterator() {
+  public PrefetchableIterator<T> iterator() {
     return new CachingIterator();
   }
 
   /** An {@link Iterator} which adds fetched values into the cached elements list. */
-  private class CachingIterator implements Iterator<T> {
+  private class CachingIterator implements PrefetchableIterator<T> {
     private int position = 0;
 
     private CachingIterator() {}
 
     @Override
+    public boolean isReady() {
+      if (position < cachedElements.size()) {
+        return true;
+      }
+      return iterator.isReady();
+    }
+
+    @Override
+    public void prefetch() {
+      if (!isReady()) {
+        iterator.prefetch();
+      }
+    }
+
+    @Override
     public boolean hasNext() {
       // The order of the short circuit is important below.
       return position < cachedElements.size() || iterator.hasNext();
@@ -76,7 +91,7 @@
 
   @Override
   public int hashCode() {
-    return iterator.hasNext() ? iterator.next().hashCode() : -1789023489;
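+    // Objects.hashCode tolerates a null element, unlike calling hashCode() directly.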
+    return iterator.hasNext() ? Objects.hashCode(iterator.next()) : -1789023489;
   }
 
   @Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 22be306..1026ba5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -17,20 +17,21 @@
  */
 package org.apache.beam.fn.harness.state;
 
-import java.util.Collections;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 
 /**
  * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -54,7 +55,7 @@
    *     only) chunk of a state stream. This state request will be populated with a continuation
    *     token to request further chunks of the stream if required.
    */
-  public static Iterator<ByteString> readAllStartingFrom(
+  public static PrefetchableIterator<ByteString> readAllStartingFrom(
       BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
     return new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk);
   }
@@ -74,94 +75,142 @@
    *     token to request further chunks of the stream if required.
    * @param valueCoder A coder for decoding the state stream.
    */
-  public static <T> Iterable<T> readAllAndDecodeStartingFrom(
+  public static <T> PrefetchableIterable<T> readAllAndDecodeStartingFrom(
       BeamFnStateClient beamFnStateClient,
       StateRequest stateRequestForFirstChunk,
       Coder<T> valueCoder) {
-    FirstPageAndRemainder firstPageAndRemainder =
-        new FirstPageAndRemainder(beamFnStateClient, stateRequestForFirstChunk);
-    return Iterables.concat(
-        new LazyCachingIteratorToIterable<T>(
-            new DataStreams.DataStreamDecoder<>(
-                valueCoder, new LazySingletonIterator<>(firstPageAndRemainder::firstPage))),
-        () -> new DataStreams.DataStreamDecoder<>(valueCoder, firstPageAndRemainder.remainder()));
-  }
-
-  /** A iterable that contains a single element, provided by a Supplier which is invoked lazily. */
-  static class LazySingletonIterator<T> implements Iterator<T> {
-
-    private final Supplier<T> supplier;
-    private boolean hasNext;
-
-    private LazySingletonIterator(Supplier<T> supplier) {
-      this.supplier = supplier;
-      hasNext = true;
-    }
-
-    @Override
-    public boolean hasNext() {
-      return hasNext;
-    }
-
-    @Override
-    public T next() {
-      hasNext = false;
-      return supplier.get();
-    }
+    return new FirstPageAndRemainder<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
   }
 
   /**
-   * An helper class that (lazily) gives the first page of a paginated state request separately from
+   * A helper class that (lazily) gives the first page of a paginated state request separately from
    * all the remaining pages.
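+   *
+   * <p>The first page is wrapped in a {@link LazyCachingIteratorToIterable} so that reiteration
+   * does not refetch it, while the remaining pages are streamed on demand via {@link
+   * LazyBlockingStateFetchingIterator}.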
    */
-  static class FirstPageAndRemainder {
+  @VisibleForTesting
+  static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
     private final BeamFnStateClient beamFnStateClient;
     private final StateRequest stateRequestForFirstChunk;
-    private ByteString firstPage = null;
+    private final Coder<T> valueCoder;
+    private LazyCachingIteratorToIterable<T> firstPage;
+    private CompletableFuture<StateResponse> firstPageResponseFuture;
     private ByteString continuationToken;
 
-    private FirstPageAndRemainder(
-        BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
+    FirstPageAndRemainder(
+        BeamFnStateClient beamFnStateClient,
+        StateRequest stateRequestForFirstChunk,
+        Coder<T> valueCoder) {
       this.beamFnStateClient = beamFnStateClient;
       this.stateRequestForFirstChunk = stateRequestForFirstChunk;
+      this.valueCoder = valueCoder;
     }
 
-    public ByteString firstPage() {
-      if (firstPage == null) {
-        CompletableFuture<StateResponse> stateResponseFuture = new CompletableFuture<>();
+    @Override
+    public PrefetchableIterator<T> iterator() {
+      return new PrefetchableIterator<T>() {
+        PrefetchableIterator<T> delegate;
+
+        private void ensureDelegateExists() {
+          if (delegate == null) {
+            // Fetch the first page if necessary
+            prefetchFirstPage();
+            if (firstPage == null) {
+              StateResponse stateResponse;
+              try {
+                stateResponse = firstPageResponseFuture.get();
+              } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                throw new IllegalStateException(e);
+              } catch (ExecutionException e) {
+                if (e.getCause() == null) {
+                  throw new IllegalStateException(e);
+                }
+                Throwables.throwIfUnchecked(e.getCause());
+                throw new IllegalStateException(e.getCause());
+              }
+              continuationToken = stateResponse.getGet().getContinuationToken();
+              firstPage =
+                  new LazyCachingIteratorToIterable<>(
+                      new DataStreamDecoder<>(
+                          valueCoder,
+                          PrefetchableIterators.fromArray(stateResponse.getGet().getData())));
+            }
+
+            if (ByteString.EMPTY.equals(continuationToken)) {
+              delegate = firstPage.iterator();
+            } else {
+              delegate =
+                  PrefetchableIterators.concat(
+                      firstPage.iterator(),
+                      new DataStreamDecoder<>(
+                          valueCoder,
+                          new LazyBlockingStateFetchingIterator(
+                              beamFnStateClient,
+                              stateRequestForFirstChunk
+                                  .toBuilder()
+                                  .setGet(
+                                      StateGetRequest.newBuilder()
+                                          .setContinuationToken(continuationToken))
+                                  .build())));
+            }
+          }
+        }
+
+        @Override
+        public boolean isReady() {
+          if (delegate == null) {
+            if (firstPageResponseFuture != null) {
+              return firstPageResponseFuture.isDone();
+            }
+            return false;
+          }
+          return delegate.isReady();
+        }
+
+        @Override
+        public void prefetch() {
+          if (firstPageResponseFuture == null) {
+            prefetchFirstPage();
+          } else if (delegate != null && !delegate.isReady()) {
+            delegate.prefetch();
+          }
+        }
+
+        @Override
+        public boolean hasNext() {
+          if (delegate == null) {
+            // Ensure that we prefetch the second page after the first has been accessed.
+            // Prefetching subsequent pages after the first will be handled by the
+            // LazyBlockingStateFetchingIterator
+            ensureDelegateExists();
+            boolean rval = delegate.hasNext();
+            delegate.prefetch();
+            return rval;
+          }
+          return delegate.hasNext();
+        }
+
+        @Override
+        public T next() {
+          if (delegate == null) {
+            // Ensure that we prefetch the second page after the first has been accessed.
+            // Prefetching subsequent pages after the first will be handled by the
+            // LazyBlockingStateFetchingIterator
+            ensureDelegateExists();
+            T rval = delegate.next();
+            delegate.prefetch();
+            return rval;
+          }
+          return delegate.next();
+        }
+      };
+    }
+
+    private void prefetchFirstPage() {
+      if (firstPageResponseFuture == null) {
+        firstPageResponseFuture = new CompletableFuture<>();
         beamFnStateClient.handle(
             stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()),
-            stateResponseFuture);
-        StateResponse stateResponse;
-        try {
-          stateResponse = stateResponseFuture.get();
-        } catch (InterruptedException e) {
-          Thread.currentThread().interrupt();
-          throw new IllegalStateException(e);
-        } catch (ExecutionException e) {
-          if (e.getCause() == null) {
-            throw new IllegalStateException(e);
-          }
-          Throwables.throwIfUnchecked(e.getCause());
-          throw new IllegalStateException(e.getCause());
-        }
-        continuationToken = stateResponse.getGet().getContinuationToken();
-        firstPage = stateResponse.getGet().getData();
-      }
-      return firstPage;
-    }
-
-    public Iterator<ByteString> remainder() {
-      firstPage();
-      if (ByteString.EMPTY.equals(continuationToken)) {
-        return Collections.emptyIterator();
-      } else {
-        return new LazyBlockingStateFetchingIterator(
-            beamFnStateClient,
-            stateRequestForFirstChunk
-                .toBuilder()
-                .setGet(StateGetRequest.newBuilder().setContinuationToken(continuationToken))
-                .build());
+            firstPageResponseFuture);
       }
     }
   }
@@ -169,10 +218,11 @@
   /**
    * An {@link Iterator} which fetches {@link ByteString} chunks using the State API.
    *
-   * <p>This iterator will only request a chunk on first access. Subsiquently it eagerly pre-fetches
-   * one future chunks at a time.
+   * <p>This iterator will only request a chunk on first access. Subsequently it eagerly pre-fetches
+   * one future chunk at a time.
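+   *
+   * <p>For example, {@link #prefetch} issues the state request for the next chunk without
+   * blocking; a later {@link #next} then blocks only if that response has not yet arrived.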
    */
-  static class LazyBlockingStateFetchingIterator implements Iterator<ByteString> {
+  @VisibleForTesting
+  static class LazyBlockingStateFetchingIterator implements PrefetchableIterator<ByteString> {
 
     private enum State {
       READ_REQUIRED,
@@ -195,8 +245,17 @@
       this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
     }
 
-    private void prefetch() {
-      if (prefetchedResponse == null && currentState == State.READ_REQUIRED) {
+    @Override
+    public boolean isReady() {
+      if (prefetchedResponse == null) {
+        return currentState != State.READ_REQUIRED;
+      }
+      return prefetchedResponse.isDone();
+    }
+
+    @Override
+    public void prefetch() {
+      if (currentState == State.READ_REQUIRED && prefetchedResponse == null) {
         prefetchedResponse = new CompletableFuture<>();
         beamFnStateClient.handle(
             stateRequestForFirstChunk
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 7597128..0914b01 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -25,8 +25,11 @@
 
 import java.util.Iterator;
 import java.util.NoSuchElementException;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
+import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetch;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -40,7 +43,8 @@
 
   @Test
   public void testEmptyIterator() {
-    Iterable<Object> iterable = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+    Iterable<Object> iterable =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.emptyIterator());
     assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
     // iterate multiple times
     assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
@@ -52,7 +56,7 @@
   @Test
   public void testInterleavedIteration() {
     Iterable<String> iterable =
-        new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
 
     Iterator<String> iterator1 = iterable.iterator();
     assertTrue(iterator1.hasNext());
@@ -77,14 +81,45 @@
 
   @Test
   public void testEqualsAndHashCode() {
-    Iterable<String> iterA = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
-    Iterable<String> iterB = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
-    Iterable<String> iterC = new LazyCachingIteratorToIterable<>(Iterators.forArray());
-    Iterable<String> iterD = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+    Iterable<String> iterA =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    Iterable<String> iterB =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    Iterable<String> iterC = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
+    Iterable<String> iterD = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
     assertEquals(iterA, iterB);
     assertEquals(iterC, iterD);
     assertNotEquals(iterA, iterC);
     assertEquals(iterA.hashCode(), iterB.hashCode());
     assertEquals(iterC.hashCode(), iterD.hashCode());
   }
+
+  @Test
+  public void testPrefetch() {
+    ReadyAfterPrefetch<String> underlying =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    PrefetchableIterable<String> iterable = new LazyCachingIteratorToIterable<>(underlying);
+    PrefetchableIterator<String> iterator1 = iterable.iterator();
+    PrefetchableIterator<String> iterator2 = iterable.iterator();
+
+    // Check that the lazy iterable doesn't do any prefetch/access on instantiation
+    assertFalse(underlying.isReady());
+    assertFalse(iterator1.isReady());
+    assertFalse(iterator2.isReady());
+
+    // Check that if both iterators prefetch there is only one prefetch for the underlying
+    // iterator.
+    iterator1.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+    iterator2.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+
+    // Check that once one iterator has advanced, the second doesn't perform any prefetch since
+    // the element is now cached.
+    iterator1.next();
+    iterator1.next();
+    iterator2.next();
+    iterator2.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index fc729cc..384d2df 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -19,12 +19,16 @@
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
+import org.apache.beam.fn.harness.state.StateFetchingIterators.FirstPageAndRemainder;
 import org.apache.beam.fn.harness.state.StateFetchingIterators.LazyBlockingStateFetchingIterator;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -32,16 +36,56 @@
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for {@link StateFetchingIterators}. */
+@RunWith(Enclosed.class)
 public class StateFetchingIteratorsTest {
+
+  private static BeamFnStateClient fakeStateClient(
+      AtomicInteger callCount, ByteString... expected) {
+    return (requestBuilder, response) -> {
+      callCount.incrementAndGet();
+      if (expected.length == 0) {
+        response.complete(
+            StateResponse.newBuilder()
+                .setId(requestBuilder.getId())
+                .setGet(StateGetResponse.newBuilder())
+                .build());
+        return;
+      }
+
+      ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
+
+      int requestedPosition = 0; // Default position is 0
+      if (!ByteString.EMPTY.equals(continuationToken)) {
+        requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
+      }
+
+      // Compute the new continuation token
+      ByteString newContinuationToken = ByteString.EMPTY;
+      if (requestedPosition != expected.length - 1) {
+        newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
+      }
+      response.complete(
+          StateResponse.newBuilder()
+              .setId(requestBuilder.getId())
+              .setGet(
+                  StateGetResponse.newBuilder()
+                      .setData(expected[requestedPosition])
+                      .setContinuationToken(newContinuationToken))
+              .build());
+    };
+  }
+
   /** Tests for {@link StateFetchingIterators.LazyBlockingStateFetchingIterator}. */
   @RunWith(JUnit4.class)
   public static class LazyBlockingStateFetchingIteratorTest {
@@ -77,49 +121,55 @@
           ByteString.EMPTY);
     }
 
-    private BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString... expected) {
-      return (requestBuilder, response) -> {
-        callCount.incrementAndGet();
-        if (expected.length == 0) {
-          response.complete(
-              StateResponse.newBuilder()
-                  .setId(requestBuilder.getId())
-                  .setGet(StateGetResponse.newBuilder())
-                  .build());
-          return;
-        }
-
-        ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
-
-        int requestedPosition = 0; // Default position is 0
-        if (!ByteString.EMPTY.equals(continuationToken)) {
-          requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
-        }
-
-        // Compute the new continuation token
-        ByteString newContinuationToken = ByteString.EMPTY;
-        if (requestedPosition != expected.length - 1) {
-          newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
-        }
-        response.complete(
-            StateResponse.newBuilder()
-                .setId(requestBuilder.getId())
-                .setGet(
-                    StateGetResponse.newBuilder()
-                        .setData(expected[requestedPosition])
-                        .setContinuationToken(newContinuationToken))
-                .build());
-      };
+    @Test
+    public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
+      AtomicInteger callCount = new AtomicInteger();
+      BeamFnStateClient fakeStateClient =
+          new BeamFnStateClient() {
+            @Override
+            public void handle(
+                StateRequest.Builder requestBuilder, CompletableFuture<StateResponse> response) {
+              callCount.incrementAndGet();
+            }
+          };
+      PrefetchableIterator<ByteString> byteStrings =
+          new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
+      assertEquals(0, callCount.get());
+      byteStrings.prefetch();
+      assertEquals(1, callCount.get()); // the first prefetch issues a request
+      byteStrings.prefetch();
+      assertEquals(1, callCount.get()); // a subsequent prefetch is ignored while one is ongoing
     }
 
     private void testFetch(ByteString... expected) {
       AtomicInteger callCount = new AtomicInteger();
       BeamFnStateClient fakeStateClient = fakeStateClient(callCount, expected);
-      Iterator<ByteString> byteStrings =
+      PrefetchableIterator<ByteString> byteStrings =
           new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
       assertEquals(0, callCount.get()); // Ensure it's fully lazy.
-      assertArrayEquals(expected, Iterators.toArray(byteStrings, Object.class));
+      assertFalse(byteStrings.isReady());
+
+      // Prefetch every second element in the iterator, capturing the results.
+      List<ByteString> results = new ArrayList<>();
+      for (int i = 0; i < expected.length; ++i) {
+        if (i % 2 == 0) {
+          // Ensure that prefetch performs the call
+          byteStrings.prefetch();
+          assertEquals(i + 1, callCount.get());
+          assertTrue(byteStrings.isReady());
+        }
+        assertTrue(byteStrings.hasNext());
+        results.add(byteStrings.next());
+      }
+      assertFalse(byteStrings.hasNext());
+      assertTrue(byteStrings.isReady());
+
+      assertEquals(Arrays.asList(expected), results);
     }
+  }
+
+  @RunWith(JUnit4.class)
+  public static class FirstPageAndRemainderTest {
 
     @Test
     public void testEmptyValues() throws Exception {
@@ -133,7 +183,7 @@
 
     @Test
     public void testManyValues() throws Exception {
-      testFetchValues(VarIntCoder.of(), 11, 37, 389, 5077);
+      testFetchValues(VarIntCoder.of(), 1, 22, 333, 4444, 55555, 666666);
     }
 
     private <T> void testFetchValues(Coder<T> coder, T... expected) {
@@ -153,35 +203,42 @@
       AtomicInteger callCount = new AtomicInteger();
       BeamFnStateClient fakeStateClient =
           fakeStateClient(callCount, Iterables.toArray(byteStrings, ByteString.class));
-      Iterable<T> values =
-          StateFetchingIterators.readAllAndDecodeStartingFrom(
-              fakeStateClient, StateRequest.getDefaultInstance(), coder);
+      PrefetchableIterable<T> values =
+          new FirstPageAndRemainder<>(fakeStateClient, StateRequest.getDefaultInstance(), coder);
 
       // Ensure it's fully lazy.
       assertEquals(0, callCount.get());
-      Iterator<T> valuesIter = values.iterator();
+      PrefetchableIterator<T> valuesIter = values.iterator();
+      assertFalse(valuesIter.isReady());
       assertEquals(0, callCount.get());
 
-      // No more is read than necessary.
-      if (valuesIter.hasNext()) {
-        valuesIter.next();
-      }
+      // Ensure that the first page result is cached across multiple iterators, so subsequent
+      // iterators are immediately ready and prefetch does nothing.
+      valuesIter.prefetch();
+      assertTrue(valuesIter.isReady());
       assertEquals(1, callCount.get());
 
-      // The first page is cached.
-      Iterator<T> valuesIter2 = values.iterator();
-      assertEquals(1, callCount.get());
-      if (valuesIter2.hasNext()) {
-        valuesIter2.next();
-      }
+      PrefetchableIterator<T> valuesIter2 = values.iterator();
+      assertTrue(valuesIter2.isReady());
+      valuesIter2.prefetch();
       assertEquals(1, callCount.get());
 
-      if (valuesIter.hasNext()) {
-        valuesIter.next();
-        // Subsequent pages are pre-fetched, so after accessing the second page,
-        // the third should be requested.
-        assertEquals(3, callCount.get());
+      // Prefetch every second element in the iterator, capturing the results.
+      List<T> results = new ArrayList<>();
+      for (int i = 0; i < expected.length; ++i) {
+        if (i % 2 == 1) {
+          // Ensure that prefetch performs the call
+          valuesIter2.prefetch();
+          assertTrue(valuesIter2.isReady());
+          // Note that this is i+2 because we expect to prefetch the page after the current one.
+          // We also have to bound it by the total number of pages.
+          assertEquals(Math.min(i + 2, expected.length), callCount.get());
+        }
+        assertTrue(valuesIter2.hasNext());
+        results.add(valuesIter2.next());
       }
+      assertFalse(valuesIter2.hasNext());
+      assertTrue(valuesIter2.isReady());
 
       // The contents agree.
       assertArrayEquals(expected, Iterables.toArray(values, Object.class));
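The fake state client above models the Fn API paging protocol: each response carries a continuation
token naming the next page, and an empty token marks the end. A sketch of the client-side loop that
LazyBlockingStateFetchingIterator effectively performs (the helper below is illustrative and not
part of the harness):

    // Drain all pages from a BeamFnStateClient by following continuation tokens.
    static List<ByteString> readAllPages(BeamFnStateClient client, StateRequest base)
        throws Exception {
      List<ByteString> pages = new ArrayList<>();
      ByteString token = ByteString.EMPTY; // empty token means "start at the first page"
      do {
        StateRequest.Builder request = base.toBuilder();
        request.getGetBuilder().setContinuationToken(token);
        CompletableFuture<StateResponse> response = new CompletableFuture<>();
        client.handle(request, response);
        StateGetResponse get = response.get().getGet();
        pages.add(get.getData());
        token = get.getContinuationToken(); // empty again once the last page is returned
      } while (!ByteString.EMPTY.equals(token));
      return pages;
    }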
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
index dd5202e..74d6636 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
@@ -1218,6 +1218,10 @@
         List<Cursor> cursors = new ArrayList<>(partitionQueryResponse.getPartitionsList());
         cursors.sort(CURSOR_REFERENCE_VALUE_COMPARATOR);
         final int size = cursors.size();
+        if (size == 0) {
+          emit(c, dbRoot, structuredQuery.toBuilder());
+          return;
+        }
         final int lastIdx = size - 1;
         for (int i = 0; i < size; i++) {
           Cursor curr = cursors.get(i);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
index b0cc681..bf6a288 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
@@ -24,25 +24,36 @@
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 
 /** Common util functions for converting between PubsubMessage proto and {@link PubsubMessage}. */
-public class PubsubMessages {
+public final class PubsubMessages {
+  private PubsubMessages() {}
+
+  public static com.google.pubsub.v1.PubsubMessage toProto(PubsubMessage input) {
+    Map<String, String> attributes = input.getAttributeMap();
+    com.google.pubsub.v1.PubsubMessage.Builder message =
+        com.google.pubsub.v1.PubsubMessage.newBuilder()
+            .setData(ByteString.copyFrom(input.getPayload()));
+    // TODO(BEAM-8085) this should not be null
+    if (attributes != null) {
+      message.putAllAttributes(attributes);
+    }
+    String messageId = input.getMessageId();
+    if (messageId != null) {
+      message.setMessageId(messageId);
+    }
+    return message.build();
+  }
+
+  public static PubsubMessage fromProto(com.google.pubsub.v1.PubsubMessage input) {
+    return new PubsubMessage(
+        input.getData().toByteArray(), input.getAttributesMap(), input.getMessageId());
+  }
+
   // Convert the PubsubMessage to a PubsubMessage proto, then return its serialized representation.
   public static class ParsePayloadAsPubsubMessageProto
       implements SerializableFunction<PubsubMessage, byte[]> {
     @Override
     public byte[] apply(PubsubMessage input) {
-      Map<String, String> attributes = input.getAttributeMap();
-      com.google.pubsub.v1.PubsubMessage.Builder message =
-          com.google.pubsub.v1.PubsubMessage.newBuilder()
-              .setData(ByteString.copyFrom(input.getPayload()));
-      // TODO(BEAM-8085) this should not be null
-      if (attributes != null) {
-        message.putAllAttributes(attributes);
-      }
-      String messageId = input.getMessageId();
-      if (messageId != null) {
-        message.setMessageId(messageId);
-      }
-      return message.build().toByteArray();
+      return toProto(input).toByteArray();
     }
   }
 
@@ -54,8 +65,7 @@
       try {
         com.google.pubsub.v1.PubsubMessage message =
             com.google.pubsub.v1.PubsubMessage.parseFrom(input);
-        return new PubsubMessage(
-            message.getData().toByteArray(), message.getAttributesMap(), message.getMessageId());
+        return fromProto(message);
       } catch (InvalidProtocolBufferException e) {
         throw new RuntimeException("Could not decode Pubsub message", e);
       }
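The refactor above lifts the proto conversions into reusable toProto/fromProto helpers so that both
serializer functions and other IOs (see CloudPubsubTransforms below) can share them. A round-trip
sketch with illustrative values:

    // toProto/fromProto round trip: payload, attributes, and message id survive.
    PubsubMessage original =
        new PubsubMessage(
            "payload".getBytes(StandardCharsets.UTF_8),
            ImmutableMap.of("key", "value"),
            "message-id-1");
    com.google.pubsub.v1.PubsubMessage proto = PubsubMessages.toProto(original);
    PubsubMessage roundTripped = PubsubMessages.fromProto(proto);
    assert Arrays.equals(original.getPayload(), roundTripped.getPayload());
    assert original.getMessageId().equals(roundTripped.getMessageId());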
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java
deleted file mode 100644
index 6dc1516..0000000
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.gcp.pubsublite;
-
-import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsPublishTransformer;
-
-import com.google.cloud.pubsublite.Message;
-import com.google.cloud.pubsublite.proto.PubSubMessage;
-import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.TypeDescriptor;
-
-/**
- * A class providing a conversion validity check between Cloud Pub/Sub and Pub/Sub Lite message
- * types.
- */
-public final class CloudPubsubChecks {
-  private CloudPubsubChecks() {}
-
-  /**
-   * Ensure that all messages that pass through can be converted to Cloud Pub/Sub messages using the
-   * standard transformation methods in the client library.
-   *
-   * <p>Will fail the pipeline if a message has multiple attributes per key.
-   */
-  public static PTransform<PCollection<? extends PubSubMessage>, PCollection<PubSubMessage>>
-      ensureUsableAsCloudPubsub() {
-    return MapElements.into(TypeDescriptor.of(PubSubMessage.class))
-        .via(
-            message -> {
-              Object unused = toCpsPublishTransformer().transform(Message.fromProto(message));
-              return message;
-            });
-  }
-}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java
new file mode 100644
index 0000000..1140c11
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.fromCpsPublishTransformer;
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsPublishTransformer;
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsSubscribeTransformer;
+
+import com.google.cloud.pubsublite.Message;
+import com.google.cloud.pubsublite.cloudpubsub.KeyExtractor;
+import com.google.cloud.pubsublite.proto.PubSubMessage;
+import com.google.cloud.pubsublite.proto.SequencedMessage;
+import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage;
+import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+
+/** A class providing transforms between Cloud Pub/Sub and Pub/Sub Lite message types. */
+public final class CloudPubsubTransforms {
+  private CloudPubsubTransforms() {}
+  /**
+   * Ensure that all messages that pass through can be converted to Cloud Pub/Sub messages using the
+   * standard transformation methods in the client library.
+   *
+   * <p>Will fail the pipeline if a message has multiple attributes per key.
+   */
+  public static PTransform<PCollection<PubSubMessage>, PCollection<PubSubMessage>>
+      ensureUsableAsCloudPubsub() {
+    return new PTransform<PCollection<PubSubMessage>, PCollection<PubSubMessage>>() {
+      @Override
+      public PCollection<PubSubMessage> expand(PCollection<PubSubMessage> input) {
+        return input.apply(
+            MapElements.into(TypeDescriptor.of(PubSubMessage.class))
+                .via(
+                    message -> {
+                      Object unused =
+                          toCpsPublishTransformer().transform(Message.fromProto(message));
+                      return message;
+                    }));
+      }
+    };
+  }
+
+  /**
+   * Transform messages read from Pub/Sub Lite to the equivalent Cloud Pub/Sub messages that would
+   * have been read from PubsubIO.
+   *
+   * <p>Will fail the pipeline if a message has multiple attributes per map key.
+   */
+  public static PTransform<PCollection<SequencedMessage>, PCollection<PubsubMessage>>
+      toCloudPubsubMessages() {
+    return new PTransform<PCollection<SequencedMessage>, PCollection<PubsubMessage>>() {
+      @Override
+      public PCollection<PubsubMessage> expand(PCollection<SequencedMessage> input) {
+        return input.apply(
+            MapElements.into(TypeDescriptor.of(PubsubMessage.class))
+                .via(
+                    message ->
+                        PubsubMessages.fromProto(
+                            toCpsSubscribeTransformer()
+                                .transform(
+                                    com.google.cloud.pubsublite.SequencedMessage.fromProto(
+                                        message)))));
+      }
+    };
+  }
+
+  /**
+   * Transform messages publishable using PubsubIO to their equivalent Pub/Sub Lite publishable
+   * messages.
+   */
+  public static PTransform<PCollection<PubsubMessage>, PCollection<PubSubMessage>>
+      fromCloudPubsubMessages() {
+    return new PTransform<PCollection<PubsubMessage>, PCollection<PubSubMessage>>() {
+      @Override
+      public PCollection<PubSubMessage> expand(PCollection<PubsubMessage> input) {
+        return input.apply(
+            MapElements.into(TypeDescriptor.of(PubSubMessage.class))
+                .via(
+                    message ->
+                        fromCpsPublishTransformer(KeyExtractor.DEFAULT)
+                            .transform(PubsubMessages.toProto(message))
+                            .toProto()));
+      }
+    };
+  }
+}
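A sketch of how these transforms could be wired into a pipeline; PubsubLiteIO.read/write exist in
this module, while the subscriberOptions/publisherOptions construction is elided and assumed:

    // Read from Pub/Sub Lite and convert to the Cloud Pub/Sub message shape.
    PCollection<SequencedMessage> liteMessages =
        pipeline.apply(PubsubLiteIO.read(subscriberOptions));
    PCollection<PubsubMessage> cpsMessages =
        liteMessages.apply("ToCloudPubsub", CloudPubsubTransforms.toCloudPubsubMessages());

    // In the publish direction, convert and validate before writing.
    cpsMessages
        .apply("FromCloudPubsub", CloudPubsubTransforms.fromCloudPubsubMessages())
        .apply("Validate", CloudPubsubTransforms.ensureUsableAsCloudPubsub())
        .apply(PubsubLiteIO.write(publisherOptions));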
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java
new file mode 100644
index 0000000..de0cf43
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import java.io.Serializable;
+
+/**
+ * A ManagedBacklogReaderFactory produces TopicBacklogReaders and tears down any produced readers
+ * when it is itself closed.
+ *
+ * <p>close() should never be called on produced readers.
+ */
+public interface ManagedBacklogReaderFactory extends AutoCloseable, Serializable {
+  TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition);
+
+  @Override
+  void close();
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java
new file mode 100644
index 0000000..9a337bf
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import com.google.api.gax.rpc.ApiException;
+import com.google.cloud.pubsublite.Offset;
+import com.google.cloud.pubsublite.proto.ComputeMessageStatsResponse;
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+public class ManagedBacklogReaderFactoryImpl implements ManagedBacklogReaderFactory {
+  private final SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader;
+
+  @GuardedBy("this")
+  private final Map<SubscriptionPartition, TopicBacklogReader> readers = new HashMap<>();
+
+  ManagedBacklogReaderFactoryImpl(
+      SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader) {
+    this.newReader = newReader;
+  }
+
+  private static final class NonCloseableTopicBacklogReader implements TopicBacklogReader {
+    private final TopicBacklogReader underlying;
+
+    NonCloseableTopicBacklogReader(TopicBacklogReader underlying) {
+      this.underlying = underlying;
+    }
+
+    @Override
+    public ComputeMessageStatsResponse computeMessageStats(Offset offset) throws ApiException {
+      return underlying.computeMessageStats(offset);
+    }
+
+    @Override
+    public void close() {
+      throw new IllegalArgumentException(
+          "Cannot call close() on a reader returned from ManagedBacklogReaderFactory.");
+    }
+  }
+
+  @Override
+  public synchronized TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
+    return new NonCloseableTopicBacklogReader(
+        readers.computeIfAbsent(subscriptionPartition, newReader::apply));
+  }
+
+  @Override
+  public synchronized void close() {
+    readers.values().forEach(TopicBacklogReader::close);
+  }
+}
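A usage sketch of the factory contract: readers are memoized per partition, callers receive
non-closeable views, and only the factory's own close() tears the underlying readers down
(newBacklogReaderFor and subscriptionPartition are assumed here):

    ManagedBacklogReaderFactory factory =
        new ManagedBacklogReaderFactoryImpl(sp -> newBacklogReaderFor(sp));
    TopicBacklogReader r1 = factory.newReader(subscriptionPartition);
    TopicBacklogReader r2 = factory.newReader(subscriptionPartition);
    // r1 and r2 wrap the same underlying reader.
    // r1.close(); // would throw: views returned by the factory must not be closed
    factory.close(); // closes every reader the factory ever produced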
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java
new file mode 100644
index 0000000..b39d87e
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.DefaultCoder;
+import org.apache.beam.sdk.io.range.OffsetRange;
+
+@AutoValue
+@DefaultCoder(OffsetByteRangeCoder.class)
+abstract class OffsetByteRange {
+  abstract OffsetRange getRange();
+
+  abstract long getByteCount();
+
+  static OffsetByteRange of(OffsetRange range, long byteCount) {
+    return new AutoValue_OffsetByteRange(range, byteCount);
+  }
+
+  static OffsetByteRange of(OffsetRange range) {
+    return of(range, 0);
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java
new file mode 100644
index 0000000..076cda1
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderProvider;
+import org.apache.beam.sdk.coders.CoderProviders;
+import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TypeDescriptor;
+
+public class OffsetByteRangeCoder extends AtomicCoder<OffsetByteRange> {
+  private static final Coder<OffsetByteRange> CODER =
+      DelegateCoder.of(
+          KvCoder.of(OffsetRange.Coder.of(), VarLongCoder.of()),
+          OffsetByteRangeCoder::toKv,
+          OffsetByteRangeCoder::fromKv);
+
+  private static KV<OffsetRange, Long> toKv(OffsetByteRange value) {
+    return KV.of(value.getRange(), value.getByteCount());
+  }
+
+  private static OffsetByteRange fromKv(KV<OffsetRange, Long> kv) {
+    return OffsetByteRange.of(kv.getKey(), kv.getValue());
+  }
+
+  @Override
+  public void encode(OffsetByteRange value, OutputStream outStream) throws IOException {
+    CODER.encode(value, outStream);
+  }
+
+  @Override
+  public OffsetByteRange decode(InputStream inStream) throws IOException {
+    return CODER.decode(inStream);
+  }
+
+  public static CoderProvider getCoderProvider() {
+    return CoderProviders.forCoder(
+        TypeDescriptor.of(OffsetByteRange.class), new OffsetByteRangeCoder());
+  }
+}
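The coder simply delegates to a KV<OffsetRange, Long> encoding via DelegateCoder. A round-trip
sketch using Beam's CoderUtils, with illustrative values:

    // OffsetByteRange encodes as KV(range, byteCount) and round-trips losslessly.
    OffsetByteRange restriction =
        OffsetByteRange.of(new OffsetRange(7, Long.MAX_VALUE), 1024);
    byte[] encoded = CoderUtils.encodeToByteArray(new OffsetByteRangeCoder(), restriction);
    OffsetByteRange decoded = CoderUtils.decodeFromByteArray(new OffsetByteRangeCoder(), encoded);
    // decoded.getRange() equals [7, Long.MAX_VALUE) and decoded.getByteCount() is 1024.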
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
index 608af8f..da9aaaa 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
@@ -26,8 +26,6 @@
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.io.range.OffsetRange;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Stopwatch;
 import org.joda.time.Duration;
@@ -44,24 +42,27 @@
  * received. IMPORTANT: minTrackingTime must be strictly smaller than the SDF read timeout when it
  * would return ProcessContinuation.resume().
  */
-class OffsetByteRangeTracker extends RestrictionTracker<OffsetRange, OffsetByteProgress>
-    implements HasProgress {
-  private final TopicBacklogReader backlogReader;
+class OffsetByteRangeTracker extends TrackerWithProgress {
+  private final TopicBacklogReader unownedBacklogReader;
   private final Duration minTrackingTime;
   private final long minBytesReceived;
   private final Stopwatch stopwatch;
-  private OffsetRange range;
+  private OffsetByteRange range;
   private @Nullable Long lastClaimed;
-  private long byteCount = 0;
 
   public OffsetByteRangeTracker(
-      OffsetRange range,
-      TopicBacklogReader backlogReader,
+      OffsetByteRange range,
+      TopicBacklogReader unownedBacklogReader,
       Stopwatch stopwatch,
       Duration minTrackingTime,
       long minBytesReceived) {
-    checkArgument(range.getTo() == Long.MAX_VALUE);
-    this.backlogReader = backlogReader;
+    checkArgument(
+        range.getRange().getTo() == Long.MAX_VALUE,
+        "May only construct OffsetByteRangeTracker with an unbounded range with no progress.");
+    checkArgument(
+        range.getByteCount() == 0L,
+        "May only construct OffsetByteRangeTracker with an unbounded range with no progress.");
+    this.unownedBacklogReader = unownedBacklogReader;
     this.minTrackingTime = minTrackingTime;
     this.minBytesReceived = minBytesReceived;
     this.stopwatch = stopwatch.reset().start();
@@ -69,11 +70,6 @@
   }
 
   @Override
-  public void finalize() {
-    this.backlogReader.close();
-  }
-
-  @Override
   public IsBounded isBounded() {
     return IsBounded.UNBOUNDED;
   }
@@ -87,32 +83,32 @@
         position.lastOffset().value(),
         lastClaimed);
     checkArgument(
-        toClaim >= range.getFrom(),
+        toClaim >= range.getRange().getFrom(),
         "Trying to claim offset %s before start of the range %s",
         toClaim,
         range);
     // split() has already been called, truncating this range. No more offsets may be claimed.
-    if (range.getTo() != Long.MAX_VALUE) {
-      boolean isRangeEmpty = range.getTo() == range.getFrom();
-      boolean isValidClosedRange = nextOffset() == range.getTo();
+    if (range.getRange().getTo() != Long.MAX_VALUE) {
+      boolean isRangeEmpty = range.getRange().getTo() == range.getRange().getFrom();
+      boolean isValidClosedRange = nextOffset() == range.getRange().getTo();
       checkState(
           isRangeEmpty || isValidClosedRange,
           "Violated class precondition: offset range improperly split. Please report a beam bug.");
       return false;
     }
     lastClaimed = toClaim;
-    byteCount += position.batchBytes();
+    range = OffsetByteRange.of(range.getRange(), range.getByteCount() + position.batchBytes());
     return true;
   }
 
   @Override
-  public OffsetRange currentRestriction() {
+  public OffsetByteRange currentRestriction() {
     return range;
   }
 
   private long nextOffset() {
     checkState(lastClaimed == null || lastClaimed < Long.MAX_VALUE);
-    return lastClaimed == null ? currentRestriction().getFrom() : lastClaimed + 1;
+    return lastClaimed == null ? currentRestriction().getRange().getFrom() : lastClaimed + 1;
   }
 
   /**
@@ -124,29 +120,33 @@
     if (duration.isLongerThan(minTrackingTime)) {
       return true;
     }
-    if (byteCount >= minBytesReceived) {
+    if (currentRestriction().getByteCount() >= minBytesReceived) {
       return true;
     }
     return false;
   }
 
   @Override
-  public @Nullable SplitResult<OffsetRange> trySplit(double fractionOfRemainder) {
+  public @Nullable SplitResult<OffsetByteRange> trySplit(double fractionOfRemainder) {
     // Cannot split a bounded range. This should already be completely claimed.
-    if (range.getTo() != Long.MAX_VALUE) {
+    if (range.getRange().getTo() != Long.MAX_VALUE) {
       return null;
     }
     if (!receivedEnough()) {
       return null;
     }
-    range = new OffsetRange(currentRestriction().getFrom(), nextOffset());
-    return SplitResult.of(this.range, new OffsetRange(nextOffset(), Long.MAX_VALUE));
+    range =
+        OffsetByteRange.of(
+            new OffsetRange(currentRestriction().getRange().getFrom(), nextOffset()),
+            range.getByteCount());
+    return SplitResult.of(
+        this.range, OffsetByteRange.of(new OffsetRange(nextOffset(), Long.MAX_VALUE), 0));
   }
 
   @Override
   @SuppressWarnings("unboxing.of.nullable")
   public void checkDone() throws IllegalStateException {
-    if (range.getFrom() == range.getTo()) {
+    if (range.getRange().getFrom() == range.getRange().getTo()) {
       return;
     }
     checkState(
@@ -155,18 +155,18 @@
         range);
     long lastClaimedNotNull = checkNotNull(lastClaimed);
     checkState(
-        lastClaimedNotNull >= range.getTo() - 1,
+        lastClaimedNotNull >= range.getRange().getTo() - 1,
         "Last attempted offset was %s in range %s, claiming work in [%s, %s) was not attempted",
         lastClaimedNotNull,
         range,
         lastClaimedNotNull + 1,
-        range.getTo());
+        range.getRange().getTo());
   }
 
   @Override
   public Progress getProgress() {
     ComputeMessageStatsResponse stats =
-        this.backlogReader.computeMessageStats(Offset.of(nextOffset()));
-    return Progress.from(byteCount, stats.getMessageBytes());
+        this.unownedBacklogReader.computeMessageStats(Offset.of(nextOffset()));
+    return Progress.from(range.getByteCount(), stats.getMessageBytes());
   }
 }
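With this change the byte progress lives inside the restriction itself: tryClaim folds each batch's
bytes into the OffsetByteRange, and trySplit truncates the primary to the claimed prefix while the
residual restarts at zero bytes. A hedged sketch of one claim/split cycle (OffsetByteProgress.of is
assumed to take the last claimed offset and the batch byte size; backlogReader is assumed):

    OffsetByteRangeTracker tracker =
        new OffsetByteRangeTracker(
            OffsetByteRange.of(new OffsetRange(0, Long.MAX_VALUE)),
            backlogReader, // unowned: its lifecycle belongs to the ManagedBacklogReaderFactory
            Stopwatch.createUnstarted(),
            Duration.standardSeconds(30),
            /* minBytesReceived= */ 1 << 20);
    tracker.tryClaim(OffsetByteProgress.of(Offset.of(41), 512)); // range now carries 512 bytes
    // After minTrackingTime elapses or minBytesReceived accumulates, a split
    // truncates the primary to [0, 42) and yields [42, MAX) with zero bytes.
    SplitResult<OffsetByteRange> split = tracker.trySplit(0.0);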
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
index 623e20c..d7526d8 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
@@ -27,4 +27,8 @@
   private PerServerPublisherCache() {}
 
   static final PublisherCache PUBLISHER_CACHE = new PublisherCache();
+
+  static {
+    Runtime.getRuntime().addShutdownHook(new Thread(PUBLISHER_CACHE::close));
+  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
index a9f7a43..fdf7920 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
@@ -17,13 +17,12 @@
  */
 package org.apache.beam.sdk.io.gcp.pubsublite;
 
-import static com.google.cloud.pubsublite.internal.ExtractStatus.toCanonical;
-import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static com.google.cloud.pubsublite.internal.wire.ApiServiceUtils.blockingShutdown;
 
 import com.google.cloud.pubsublite.Offset;
+import com.google.cloud.pubsublite.internal.ExtractStatus;
 import com.google.cloud.pubsublite.internal.wire.Committer;
 import com.google.cloud.pubsublite.proto.SequencedMessage;
-import java.util.concurrent.ExecutionException;
 import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.SerializableBiFunction;
@@ -35,31 +34,35 @@
 
 class PerSubscriptionPartitionSdf extends DoFn<SubscriptionPartition, SequencedMessage> {
   private final Duration maxSleepTime;
+  private final ManagedBacklogReaderFactory backlogReaderFactory;
   private final SubscriptionPartitionProcessorFactory processorFactory;
   private final SerializableFunction<SubscriptionPartition, InitialOffsetReader>
       offsetReaderFactory;
-  private final SerializableBiFunction<
-          SubscriptionPartition, OffsetRange, RestrictionTracker<OffsetRange, OffsetByteProgress>>
+  private final SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
       trackerFactory;
   private final SerializableFunction<SubscriptionPartition, Committer> committerFactory;
 
   PerSubscriptionPartitionSdf(
       Duration maxSleepTime,
+      ManagedBacklogReaderFactory backlogReaderFactory,
       SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory,
-      SerializableBiFunction<
-              SubscriptionPartition,
-              OffsetRange,
-              RestrictionTracker<OffsetRange, OffsetByteProgress>>
+      SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
           trackerFactory,
       SubscriptionPartitionProcessorFactory processorFactory,
       SerializableFunction<SubscriptionPartition, Committer> committerFactory) {
     this.maxSleepTime = maxSleepTime;
+    this.backlogReaderFactory = backlogReaderFactory;
     this.processorFactory = processorFactory;
     this.offsetReaderFactory = offsetReaderFactory;
     this.trackerFactory = trackerFactory;
     this.committerFactory = committerFactory;
   }
 
+  @Teardown
+  public void teardown() {
+    backlogReaderFactory.close();
+  }
+
   @GetInitialWatermarkEstimatorState
   public Instant getInitialWatermarkState() {
     return Instant.EPOCH;
@@ -72,7 +75,7 @@
 
   @ProcessElement
   public ProcessContinuation processElement(
-      RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+      RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
       @Element SubscriptionPartition subscriptionPartition,
       OutputReceiver<SequencedMessage> receiver)
       throws Exception {
@@ -83,38 +86,44 @@
       processor
           .lastClaimed()
           .ifPresent(
-              lastClaimedOffset ->
-              /* TODO(boyuanzz): When default dataflow can use finalizers, undo this.
-              finalizer.afterBundleCommit(
-                  Instant.ofEpochMilli(Long.MAX_VALUE),
-                  () -> */ {
+              lastClaimedOffset -> {
                 Committer committer = committerFactory.apply(subscriptionPartition);
                 committer.startAsync().awaitRunning();
                 // Commit the next-to-deliver offset.
                 try {
                   committer.commitOffset(Offset.of(lastClaimedOffset.value() + 1)).get();
-                } catch (ExecutionException e) {
-                  throw toCanonical(checkArgumentNotNull(e.getCause())).underlying;
                 } catch (Exception e) {
-                  throw toCanonical(e).underlying;
+                  throw ExtractStatus.toCanonical(e).underlying;
                 }
-                committer.stopAsync().awaitTerminated();
+                blockingShutdown(committer);
               });
       return result;
     }
   }
 
   @GetInitialRestriction
-  public OffsetRange getInitialRestriction(@Element SubscriptionPartition subscriptionPartition) {
+  public OffsetByteRange getInitialRestriction(
+      @Element SubscriptionPartition subscriptionPartition) {
     try (InitialOffsetReader reader = offsetReaderFactory.apply(subscriptionPartition)) {
       Offset offset = reader.read();
-      return new OffsetRange(offset.value(), Long.MAX_VALUE /* open interval */);
+      return OffsetByteRange.of(
+          new OffsetRange(offset.value(), Long.MAX_VALUE /* open interval */));
     }
   }
 
   @NewTracker
-  public RestrictionTracker<OffsetRange, OffsetByteProgress> newTracker(
-      @Element SubscriptionPartition subscriptionPartition, @Restriction OffsetRange range) {
-    return trackerFactory.apply(subscriptionPartition, range);
+  public TrackerWithProgress newTracker(
+      @Element SubscriptionPartition subscriptionPartition, @Restriction OffsetByteRange range) {
+    return trackerFactory.apply(backlogReaderFactory.newReader(subscriptionPartition), range);
+  }
+
+  @GetSize
+  public double getSize(
+      @Element SubscriptionPartition subscriptionPartition,
+      @Restriction OffsetByteRange restriction) {
+    if (restriction.getRange().getTo() != Long.MAX_VALUE) {
+      return restriction.getByteCount();
+    }
+    return newTracker(subscriptionPartition, restriction).getProgress().getWorkRemaining();
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
index f8dc24b..3dbdec6 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
@@ -23,52 +23,50 @@
 import com.google.api.core.ApiService.State;
 import com.google.api.gax.rpc.ApiException;
 import com.google.cloud.pubsublite.MessageMetadata;
-import com.google.cloud.pubsublite.internal.CloseableMonitor;
 import com.google.cloud.pubsublite.internal.Publisher;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
 import java.util.HashMap;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 
 /** A map of working publishers by PublisherOptions. */
-class PublisherCache {
-  private final CloseableMonitor monitor = new CloseableMonitor();
-
-  private final Executor listenerExecutor = Executors.newSingleThreadExecutor();
-
-  @GuardedBy("monitor.monitor")
+class PublisherCache implements AutoCloseable {
+  @GuardedBy("this")
   private final HashMap<PublisherOptions, Publisher<MessageMetadata>> livePublishers =
       new HashMap<>();
 
-  Publisher<MessageMetadata> get(PublisherOptions options) throws ApiException {
+  private synchronized void evict(PublisherOptions options) {
+    livePublishers.remove(options);
+  }
+
+  synchronized Publisher<MessageMetadata> get(PublisherOptions options) throws ApiException {
     checkArgument(options.usesCache());
-    try (CloseableMonitor.Hold h = monitor.enter()) {
-      Publisher<MessageMetadata> publisher = livePublishers.get(options);
-      if (publisher != null) {
-        return publisher;
-      }
-      publisher = Publishers.newPublisher(options);
-      livePublishers.put(options, publisher);
-      publisher.addListener(
-          new Listener() {
-            @Override
-            public void failed(State s, Throwable t) {
-              try (CloseableMonitor.Hold h = monitor.enter()) {
-                livePublishers.remove(options);
-              }
-            }
-          },
-          listenerExecutor);
-      publisher.startAsync().awaitRunning();
+    Publisher<MessageMetadata> publisher = livePublishers.get(options);
+    if (publisher != null) {
       return publisher;
     }
+    publisher = Publishers.newPublisher(options);
+    livePublishers.put(options, publisher);
+    publisher.addListener(
+        new Listener() {
+          @Override
+          public void failed(State s, Throwable t) {
+            evict(options);
+          }
+        },
+        SystemExecutors.getFuturesExecutor());
+    publisher.startAsync().awaitRunning();
+    return publisher;
   }
 
   @VisibleForTesting
-  void set(PublisherOptions options, Publisher<MessageMetadata> toCache) {
-    try (CloseableMonitor.Hold h = monitor.enter()) {
-      livePublishers.put(options, toCache);
-    }
+  synchronized void set(PublisherOptions options, Publisher<MessageMetadata> toCache) {
+    livePublishers.put(options, toCache);
+  }
+
+  @Override
+  public synchronized void close() {
+    livePublishers.forEach(((options, publisher) -> publisher.stopAsync()));
+    livePublishers.clear();
   }
 }
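With the CloseableMonitor gone, the cache is guarded by synchronized methods, and a failed
publisher evicts itself through its listener so the next lookup rebuilds it. A usage sketch
(options is assumed to be a PublisherOptions with usesCache() == true):

    try (PublisherCache cache = new PublisherCache()) {
      Publisher<MessageMetadata> p1 = cache.get(options);
      Publisher<MessageMetadata> p2 = cache.get(options); // same cached instance as p1
      // If p1 later transitions to FAILED, its listener evicts it, so a
      // subsequent get(options) constructs and starts a fresh publisher.
    } // close() stops every live publisher and clears the map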
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
index 34012f7..67ea6cf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
@@ -17,17 +17,27 @@
  */
 package org.apache.beam.sdk.io.gcp.pubsublite;
 
+import static com.google.cloud.pubsublite.internal.ExtractStatus.toCanonical;
 import static com.google.cloud.pubsublite.internal.UncheckedApiPreconditions.checkArgument;
+import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultMetadata;
+import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultSettings;
 
 import com.google.api.gax.rpc.ApiException;
 import com.google.cloud.pubsublite.AdminClient;
 import com.google.cloud.pubsublite.AdminClientSettings;
 import com.google.cloud.pubsublite.MessageMetadata;
-import com.google.cloud.pubsublite.TopicPath;
+import com.google.cloud.pubsublite.Partition;
+import com.google.cloud.pubsublite.cloudpubsub.PublisherSettings;
 import com.google.cloud.pubsublite.internal.Publisher;
 import com.google.cloud.pubsublite.internal.wire.PartitionCountWatchingPublisherSettings;
+import com.google.cloud.pubsublite.internal.wire.PubsubContext;
 import com.google.cloud.pubsublite.internal.wire.PubsubContext.Framework;
+import com.google.cloud.pubsublite.internal.wire.RoutingMetadata;
 import com.google.cloud.pubsublite.internal.wire.SinglePartitionPublisherBuilder;
+import com.google.cloud.pubsublite.v1.AdminServiceClient;
+import com.google.cloud.pubsublite.v1.AdminServiceSettings;
+import com.google.cloud.pubsublite.v1.PublisherServiceClient;
+import com.google.cloud.pubsublite.v1.PublisherServiceSettings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.reflect.TypeToken;
 
 class Publishers {
@@ -35,6 +45,38 @@
 
   private Publishers() {}
 
+  private static AdminClient newAdminClient(PublisherOptions options) throws ApiException {
+    try {
+      return AdminClient.create(
+          AdminClientSettings.newBuilder()
+              .setServiceClient(
+                  AdminServiceClient.create(
+                      addDefaultSettings(
+                          options.topicPath().location().extractRegion(),
+                          AdminServiceSettings.newBuilder())))
+              .setRegion(options.topicPath().location().extractRegion())
+              .build());
+    } catch (Throwable t) {
+      throw toCanonical(t).underlying;
+    }
+  }
+
+  private static PublisherServiceClient newServiceClient(
+      PublisherOptions options, Partition partition) {
+    PublisherServiceSettings.Builder settingsBuilder = PublisherServiceSettings.newBuilder();
+    settingsBuilder =
+        addDefaultMetadata(
+            PubsubContext.of(FRAMEWORK),
+            RoutingMetadata.of(options.topicPath(), partition),
+            settingsBuilder);
+    try {
+      return PublisherServiceClient.create(
+          addDefaultSettings(options.topicPath().location().extractRegion(), settingsBuilder));
+    } catch (Throwable t) {
+      throw toCanonical(t).underlying;
+    }
+  }
+
   @SuppressWarnings("unchecked")
   static Publisher<MessageMetadata> newPublisher(PublisherOptions options) throws ApiException {
     SerializableSupplier<Object> supplier = options.publisherSupplier();
@@ -44,20 +86,18 @@
       checkArgument(token.isSupertypeOf(supplied.getClass()));
       return (Publisher<MessageMetadata>) supplied;
     }
-
-    TopicPath topic = options.topicPath();
-    PartitionCountWatchingPublisherSettings.Builder publisherSettings =
-        PartitionCountWatchingPublisherSettings.newBuilder()
-            .setTopic(topic)
-            .setPublisherFactory(
-                partition ->
-                    SinglePartitionPublisherBuilder.newBuilder()
-                        .setTopic(topic)
-                        .setPartition(partition)
-                        .build())
-            .setAdminClient(
-                AdminClient.create(
-                    AdminClientSettings.newBuilder().setRegion(topic.location().region()).build()));
-    return publisherSettings.build().instantiate();
+    return PartitionCountWatchingPublisherSettings.newBuilder()
+        .setTopic(options.topicPath())
+        .setPublisherFactory(
+            partition ->
+                SinglePartitionPublisherBuilder.newBuilder()
+                    .setTopic(options.topicPath())
+                    .setPartition(partition)
+                    .setServiceClient(newServiceClient(options, partition))
+                    .setBatchingSettings(PublisherSettings.DEFAULT_BATCHING_SETTINGS)
+                    .build())
+        .setAdminClient(newAdminClient(options))
+        .build()
+        .instantiate();
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
index ca1f2be..b93ac61 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
@@ -107,7 +107,7 @@
    * }</pre>
    */
   public static PTransform<PCollection<PubSubMessage>, PDone> write(PublisherOptions options) {
-    return new PTransform<PCollection<PubSubMessage>, PDone>("PubsubLiteIO") {
+    return new PTransform<PCollection<PubSubMessage>, PDone>() {
       @Override
       public PDone expand(PCollection<PubSubMessage> input) {
         PubsubLiteSink sink = new PubsubLiteSink(options);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
index d3acdfa..d0e3afa 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
@@ -28,16 +28,14 @@
 import com.google.cloud.pubsublite.internal.CheckedApiException;
 import com.google.cloud.pubsublite.internal.ExtractStatus;
 import com.google.cloud.pubsublite.internal.Publisher;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
 import com.google.cloud.pubsublite.proto.PubSubMessage;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
 import java.util.ArrayDeque;
 import java.util.Deque;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
 import java.util.function.Consumer;
 import org.apache.beam.sdk.io.gcp.pubsublite.PublisherOrError.Kind;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
 
 /** A sink which publishes messages to Pub/Sub Lite. */
 @SuppressWarnings({
@@ -56,8 +54,6 @@
   @GuardedBy("this")
   private transient Deque<CheckedApiException> errorsSinceLastFinish;
 
-  private static final Executor executor = Executors.newCachedThreadPool();
-
   PubsubLiteSink(PublisherOptions options) {
     this.options = options;
   }
@@ -89,7 +85,7 @@
             onFailure.accept(t);
           }
         },
-        MoreExecutors.directExecutor());
+        SystemExecutors.getFuturesExecutor());
     if (!options.usesCache()) {
       publisher.startAsync();
     }
@@ -130,7 +126,7 @@
             onFailure.accept(t);
           }
         },
-        executor);
+        SystemExecutors.getFuturesExecutor());
   }
 
   // Intentionally don't flush on bundle finish to allow multi-sink client reuse.
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
index 9875880..b6a9f5d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
@@ -23,6 +23,7 @@
 import com.google.api.gax.rpc.ApiException;
 import com.google.cloud.pubsublite.AdminClient;
 import com.google.cloud.pubsublite.AdminClientSettings;
+import com.google.cloud.pubsublite.Offset;
 import com.google.cloud.pubsublite.Partition;
 import com.google.cloud.pubsublite.TopicPath;
 import com.google.cloud.pubsublite.internal.wire.Committer;
@@ -31,8 +32,6 @@
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
-import org.apache.beam.sdk.io.range.OffsetRange;
-import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -54,10 +53,11 @@
     checkArgument(subscriptionPartition.subscription().equals(options.subscriptionPath()));
   }
 
-  private Subscriber newSubscriber(Partition partition, Consumer<List<SequencedMessage>> consumer) {
+  private Subscriber newSubscriber(
+      Partition partition, Offset initialOffset, Consumer<List<SequencedMessage>> consumer) {
     try {
       return options
-          .getSubscriberFactory(partition)
+          .getSubscriberFactory(partition, initialOffset)
           .newSubscriber(
               messages ->
                   consumer.accept(
@@ -71,23 +71,31 @@
 
   private SubscriptionPartitionProcessor newPartitionProcessor(
       SubscriptionPartition subscriptionPartition,
-      RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+      RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
       OutputReceiver<SequencedMessage> receiver)
       throws ApiException {
     checkSubscription(subscriptionPartition);
     return new SubscriptionPartitionProcessorImpl(
         tracker,
         receiver,
-        consumer -> newSubscriber(subscriptionPartition.partition(), consumer),
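+        // Create the subscriber starting at the restriction's initial offset; this replaces
+        // the explicit seek previously performed after startup.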
+        consumer ->
+            newSubscriber(
+                subscriptionPartition.partition(),
+                Offset.of(tracker.currentRestriction().getRange().getFrom()),
+                consumer),
         options.flowControlSettings());
   }
 
-  private RestrictionTracker<OffsetRange, OffsetByteProgress> newRestrictionTracker(
-      SubscriptionPartition subscriptionPartition, OffsetRange initial) {
+  private TopicBacklogReader newBacklogReader(SubscriptionPartition subscriptionPartition) {
     checkSubscription(subscriptionPartition);
+    return options.getBacklogReader(subscriptionPartition.partition());
+  }
+
+  private TrackerWithProgress newRestrictionTracker(
+      TopicBacklogReader backlogReader, OffsetByteRange initial) {
     return new OffsetByteRangeTracker(
         initial,
-        options.getBacklogReader(subscriptionPartition.partition()),
+        backlogReader,
         Stopwatch.createUnstarted(),
         options.minBundleTimeout(),
         LongMath.saturatedMultiply(options.flowControlSettings().bytesOutstanding(), 10));
@@ -107,7 +115,7 @@
     try (AdminClient admin =
         AdminClient.create(
             AdminClientSettings.newBuilder()
-                .setRegion(options.subscriptionPath().location().region())
+                .setRegion(options.subscriptionPath().location().extractRegion())
                 .build())) {
       return TopicPath.parse(admin.getSubscription(options.subscriptionPath()).get().getTopic());
     } catch (Throwable t) {
@@ -118,25 +126,15 @@
   @Override
   public PCollection<SequencedMessage> expand(PBegin input) {
     PCollection<SubscriptionPartition> subscriptionPartitions;
-    if (options.partitions().isEmpty()) {
-      subscriptionPartitions =
-          input.apply(new SubscriptionPartitionLoader(getTopicPath(), options.subscriptionPath()));
-    } else {
-      subscriptionPartitions =
-          input.apply(
-              Create.of(
-                  options.partitions().stream()
-                      .map(
-                          partition ->
-                              SubscriptionPartition.of(options.subscriptionPath(), partition))
-                      .collect(Collectors.toList())));
-    }
+    subscriptionPartitions =
+        input.apply(new SubscriptionPartitionLoader(getTopicPath(), options.subscriptionPath()));
 
     return subscriptionPartitions.apply(
         ParDo.of(
             new PerSubscriptionPartitionSdf(
                 // Ensure we read for at least 5 seconds more than the bundle timeout.
                 options.minBundleTimeout().plus(Duration.standardSeconds(5)),
+                new ManagedBacklogReaderFactoryImpl(this::newBacklogReader),
                 this::newInitialOffsetReader,
                 this::newRestrictionTracker,
                 this::newPartitionProcessor,
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
index 0d3afe2..a9625be 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
@@ -23,6 +23,7 @@
 
 import com.google.api.gax.rpc.ApiException;
 import com.google.auto.value.AutoValue;
+import com.google.cloud.pubsublite.Offset;
 import com.google.cloud.pubsublite.Partition;
 import com.google.cloud.pubsublite.SubscriptionPath;
 import com.google.cloud.pubsublite.cloudpubsub.FlowControlSettings;
@@ -35,13 +36,13 @@
 import com.google.cloud.pubsublite.internal.wire.RoutingMetadata;
 import com.google.cloud.pubsublite.internal.wire.SubscriberBuilder;
 import com.google.cloud.pubsublite.internal.wire.SubscriberFactory;
+import com.google.cloud.pubsublite.proto.Cursor;
+import com.google.cloud.pubsublite.proto.SeekRequest;
 import com.google.cloud.pubsublite.v1.CursorServiceClient;
 import com.google.cloud.pubsublite.v1.CursorServiceSettings;
 import com.google.cloud.pubsublite.v1.SubscriberServiceClient;
 import com.google.cloud.pubsublite.v1.SubscriberServiceSettings;
 import java.io.Serializable;
-import java.util.Set;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 
@@ -69,11 +70,6 @@
   public abstract FlowControlSettings flowControlSettings();
 
   /**
-   * A set of partitions. If empty, continuously poll the set of partitions using an admin client.
-   */
-  public abstract Set<Partition> partitions();
-
-  /**
    * The minimum wall time to pass before allowing bundle closure.
    *
    * <p>Setting this to too small of a value will result in increased compute costs and lower
@@ -108,7 +104,6 @@
   public static Builder newBuilder() {
     Builder builder = new AutoValue_SubscriberOptions.Builder();
     return builder
-        .setPartitions(ImmutableSet.of())
         .setFlowControlSettings(DEFAULT_FLOW_CONTROL)
         .setMinBundleTimeout(MIN_BUNDLE_TIMEOUT);
   }
@@ -119,20 +114,19 @@
       throws ApiException {
     try {
       SubscriberServiceSettings.Builder settingsBuilder = SubscriberServiceSettings.newBuilder();
-
       settingsBuilder =
           addDefaultMetadata(
               PubsubContext.of(FRAMEWORK),
               RoutingMetadata.of(subscriptionPath(), partition),
               settingsBuilder);
       return SubscriberServiceClient.create(
-          addDefaultSettings(subscriptionPath().location().region(), settingsBuilder));
+          addDefaultSettings(subscriptionPath().location().extractRegion(), settingsBuilder));
     } catch (Throwable t) {
       throw toCanonical(t).underlying;
     }
   }
 
-  SubscriberFactory getSubscriberFactory(Partition partition) {
+  SubscriberFactory getSubscriberFactory(Partition partition, Offset initialOffset) {
     SubscriberFactory factory = subscriberFactory();
     if (factory != null) {
       return factory;
@@ -143,6 +137,10 @@
             .setSubscriptionPath(subscriptionPath())
             .setPartition(partition)
             .setServiceClient(newSubscriberServiceClient(partition))
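+            // Begin reading at the provided offset instead of issuing a separate seek RPC.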
+            .setInitialLocation(
+                SeekRequest.newBuilder()
+                    .setCursor(Cursor.newBuilder().setOffset(initialOffset.value()))
+                    .build())
             .build();
   }
 
@@ -150,7 +148,7 @@
     try {
       return CursorServiceClient.create(
           addDefaultSettings(
-              subscriptionPath().location().region(), CursorServiceSettings.newBuilder()));
+              subscriptionPath().location().extractRegion(), CursorServiceSettings.newBuilder()));
     } catch (Throwable t) {
       throw toCanonical(t).underlying;
     }
@@ -189,7 +187,7 @@
     return new InitialOffsetReaderImpl(
         CursorClient.create(
             CursorClientSettings.newBuilder()
-                .setRegion(subscriptionPath().location().region())
+                .setRegion(subscriptionPath().location().extractRegion())
                 .build()),
         subscriptionPath(),
         partition);
@@ -201,8 +199,6 @@
     public abstract Builder setSubscriptionPath(SubscriptionPath path);
 
     // Optional parameters
-    public abstract Builder setPartitions(Set<Partition> partitions);
-
     public abstract Builder setFlowControlSettings(FlowControlSettings flowControlSettings);
 
     public abstract Builder setMinBundleTimeout(Duration minBundleTimeout);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
index 866e922..e411d80 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
@@ -92,7 +92,9 @@
                     })
                 .withPollInterval(pollDuration)
                 .withTerminationPerInput(
-                    terminate ? Watch.Growth.afterIterations(10) : Watch.Growth.never()));
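+                    // Bound termination by total wall time (10 poll intervals) instead of by
+                    // 10 polling iterations.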
+                    terminate
+                        ? Watch.Growth.afterTotalOf(pollDuration.multipliedBy(10))
+                        : Watch.Growth.never()));
     return partitions.apply(
         MapElements.into(TypeDescriptor.of(SubscriptionPartition.class))
             .via(kv -> SubscriptionPartition.of(subscription, kv.getValue())));
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
index 6bf3623..530c180 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
@@ -19,7 +19,6 @@
 
 import com.google.cloud.pubsublite.proto.SequencedMessage;
 import java.io.Serializable;
-import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 
@@ -28,6 +27,6 @@
 
   SubscriptionPartitionProcessor newProcessor(
       SubscriptionPartition subscriptionPartition,
-      RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+      RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
       OutputReceiver<SequencedMessage> receiver);
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
index 8d2a137..a086d18 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.gcp.pubsublite;
 
+import static com.google.cloud.pubsublite.internal.wire.ApiServiceUtils.blockingShutdown;
+
 import com.google.api.core.ApiService.Listener;
 import com.google.api.core.ApiService.State;
 import com.google.cloud.pubsublite.Offset;
@@ -24,9 +26,8 @@
 import com.google.cloud.pubsublite.internal.CheckedApiException;
 import com.google.cloud.pubsublite.internal.ExtractStatus;
 import com.google.cloud.pubsublite.internal.wire.Subscriber;
-import com.google.cloud.pubsublite.proto.Cursor;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
 import com.google.cloud.pubsublite.proto.FlowControlRequest;
-import com.google.cloud.pubsublite.proto.SeekRequest;
 import com.google.cloud.pubsublite.proto.SequencedMessage;
 import com.google.protobuf.util.Timestamps;
 import java.util.List;
@@ -36,19 +37,17 @@
 import java.util.concurrent.TimeoutException;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 
 class SubscriptionPartitionProcessorImpl extends Listener
     implements SubscriptionPartitionProcessor {
-  private final RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+  private final RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker;
   private final OutputReceiver<SequencedMessage> receiver;
   private final Subscriber subscriber;
   private final SettableFuture<Void> completionFuture = SettableFuture.create();
@@ -57,7 +56,7 @@
 
   @SuppressWarnings("methodref.receiver.bound.invalid")
   SubscriptionPartitionProcessorImpl(
-      RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+      RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
       OutputReceiver<SequencedMessage> receiver,
       Function<Consumer<List<SequencedMessage>>, Subscriber> subscriberFactory,
       FlowControlSettings flowControlSettings) {
@@ -70,23 +69,15 @@
   @Override
   @SuppressWarnings("argument.type.incompatible")
   public void start() throws CheckedApiException {
-    this.subscriber.addListener(this, MoreExecutors.directExecutor());
+    this.subscriber.addListener(this, SystemExecutors.getFuturesExecutor());
     this.subscriber.startAsync();
     this.subscriber.awaitRunning();
     try {
-      this.subscriber
-          .seek(
-              SeekRequest.newBuilder()
-                  .setCursor(Cursor.newBuilder().setOffset(tracker.currentRestriction().getFrom()))
-                  .build())
-          .get();
       this.subscriber.allowFlow(
           FlowControlRequest.newBuilder()
               .setAllowedBytes(flowControlSettings.bytesOutstanding())
               .setAllowedMessages(flowControlSettings.messagesOutstanding())
               .build());
-    } catch (ExecutionException e) {
-      throw ExtractStatus.toCanonical(e.getCause());
     } catch (Throwable t) {
       throw ExtractStatus.toCanonical(t);
     }
@@ -125,7 +116,7 @@
 
   @Override
   public void close() {
-    subscriber.stopAsync().awaitTerminated();
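+    // Shut down the subscriber and block until it has fully terminated.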
+    blockingShutdown(subscriber);
   }
 
   @Override
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
index 8c1dd94..79db0f1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
@@ -62,7 +62,7 @@
       try (AdminClient adminClient =
           AdminClient.create(
               AdminClientSettings.newBuilder()
-                  .setRegion(subscriptionPath.location().region())
+                  .setRegion(subscriptionPath.location().extractRegion())
                   .build())) {
         return setTopicPath(
             TopicPath.parse(adminClient.getSubscription(subscriptionPath).get().getTopic()));
@@ -81,7 +81,9 @@
 
   TopicBacklogReader instantiate() throws ApiException {
     TopicStatsClientSettings settings =
-        TopicStatsClientSettings.newBuilder().setRegion(topicPath().location().region()).build();
+        TopicStatsClientSettings.newBuilder()
+            .setRegion(topicPath().location().extractRegion())
+            .build();
     TopicBacklogReader impl =
         new TopicBacklogReaderImpl(TopicStatsClient.create(settings), topicPath(), partition());
     return new LimitingTopicBacklogReader(impl, Ticker.systemTicker());
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java
new file mode 100644
index 0000000..7f0d030
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
+
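+/**
+ * A {@link RestrictionTracker} over {@link OffsetByteRange} restrictions that also exposes
+ * {@link HasProgress progress} reporting, letting callers delegate backlog estimation to the
+ * tracker.
+ */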
+public abstract class TrackerWithProgress
+    extends RestrictionTracker<OffsetByteRange, OffsetByteProgress> implements HasProgress {}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
index 0c9bbf1..1f29883 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
@@ -99,6 +99,41 @@
     assertEquals(expected, allValues);
   }
 
+  @Test
+  public void endToEnd_emptyCursors() throws Exception {
+    // First page of the response
+    PartitionQueryRequest request1 =
+        PartitionQueryRequest.newBuilder()
+            .setParent(String.format("projects/%s/databases/(default)/document", projectId))
+            .build();
+    PartitionQueryResponse response1 = PartitionQueryResponse.newBuilder().build();
+    when(callable.call(request1)).thenReturn(pagedResponse1);
+    when(page1.getResponse()).thenReturn(response1);
+    when(pagedResponse1.iteratePages()).thenReturn(ImmutableList.of(page1));
+
+    when(stub.partitionQueryPagedCallable()).thenReturn(callable);
+
+    when(ff.getFirestoreStub(any())).thenReturn(stub);
+    RpcQosOptions options = RpcQosOptions.defaultOptions();
+    when(ff.getRpcQos(any()))
+        .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options));
+
+    ArgumentCaptor<PartitionQueryPair> responses =
+        ArgumentCaptor.forClass(PartitionQueryPair.class);
+
+    doNothing().when(processContext).output(responses.capture());
+
+    when(processContext.element()).thenReturn(request1);
+
+    PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options);
+
+    runFunction(fn);
+
+    List<PartitionQueryPair> expected = newArrayList(new PartitionQueryPair(request1, response1));
+    List<PartitionQueryPair> allValues = responses.getAllValues();
+    assertEquals(expected, allValues);
+  }
+
   @Override
   public void resumeFromLastReadValue() throws Exception {
     when(ff.getFirestoreStub(any())).thenReturn(stub);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
index 25ed63c..ed789da 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
@@ -121,6 +121,39 @@
     assertEquals(expectedQueries, actualQueries);
   }
 
+  @Test
+  public void ensureCursorPairingWorks_emptyCursorsInResponse() {
+    StructuredQuery query =
+        StructuredQuery.newBuilder()
+            .addFrom(
+                CollectionSelector.newBuilder()
+                    .setAllDescendants(true)
+                    .setCollectionId("c1")
+                    .build())
+            .build();
+
+    List<StructuredQuery> expectedQueries = newArrayList(query);
+
+    PartitionQueryPair partitionQueryPair =
+        new PartitionQueryPair(
+            PartitionQueryRequest.newBuilder().setStructuredQuery(query).build(),
+            PartitionQueryResponse.newBuilder().build());
+
+    ArgumentCaptor<RunQueryRequest> captor = ArgumentCaptor.forClass(RunQueryRequest.class);
+    when(processContext.element()).thenReturn(partitionQueryPair);
+    doNothing().when(processContext).output(captor.capture());
+
+    PartitionQueryResponseToRunQueryRequest fn = new PartitionQueryResponseToRunQueryRequest();
+    fn.processElement(processContext);
+
+    List<StructuredQuery> actualQueries =
+        captor.getAllValues().stream()
+            .map(RunQueryRequest::getStructuredQuery)
+            .collect(Collectors.toList());
+
+    assertEquals(expectedQueries, actualQueries);
+  }
+
   private static Cursor referenceValueCursor(String referenceValue) {
     return Cursor.newBuilder()
         .addValues(Value.newBuilder().setReferenceValue(referenceValue).build())
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
index f34ebb6..5a31f4f 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
@@ -49,7 +49,7 @@
   private static final double IGNORED_FRACTION = -10000000.0;
   private static final long MIN_BYTES = 1000;
   private static final OffsetRange RANGE = new OffsetRange(123L, Long.MAX_VALUE);
-  private final TopicBacklogReader reader = mock(TopicBacklogReader.class);
+  private final TopicBacklogReader unownedBacklogReader = mock(TopicBacklogReader.class);
 
   @Spy Ticker ticker;
   private OffsetByteRangeTracker tracker;
@@ -60,14 +60,18 @@
     when(ticker.read()).thenReturn(0L);
     tracker =
         new OffsetByteRangeTracker(
-            RANGE, reader, Stopwatch.createUnstarted(ticker), Duration.millis(500), MIN_BYTES);
+            OffsetByteRange.of(RANGE, 0),
+            unownedBacklogReader,
+            Stopwatch.createUnstarted(ticker),
+            Duration.millis(500),
+            MIN_BYTES);
   }
 
   @Test
   public void progressTracked() {
     assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(123), 10)));
     assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(124), 11)));
-    when(reader.computeMessageStats(Offset.of(125)))
+    when(unownedBacklogReader.computeMessageStats(Offset.of(125)))
         .thenReturn(ComputeMessageStatsResponse.newBuilder().setMessageBytes(1000).build());
     Progress progress = tracker.getProgress();
     assertEquals(21, progress.getWorkCompleted(), .0001);
@@ -76,7 +80,7 @@
 
   @Test
   public void getProgressStatsFailure() {
-    when(reader.computeMessageStats(Offset.of(123)))
+    when(unownedBacklogReader.computeMessageStats(Offset.of(123)))
         .thenThrow(new CheckedApiException(Code.INTERNAL).underlying);
     assertThrows(ApiException.class, tracker::getProgress);
   }
@@ -86,11 +90,15 @@
   public void claimSplitSuccess() {
     assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(1_000), MIN_BYTES)));
     assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(10_000), MIN_BYTES)));
-    SplitResult<OffsetRange> splits = tracker.trySplit(IGNORED_FRACTION);
-    assertEquals(RANGE.getFrom(), splits.getPrimary().getFrom());
-    assertEquals(10_001, splits.getPrimary().getTo());
-    assertEquals(10_001, splits.getResidual().getFrom());
-    assertEquals(Long.MAX_VALUE, splits.getResidual().getTo());
+    SplitResult<OffsetByteRange> splits = tracker.trySplit(IGNORED_FRACTION);
+    OffsetByteRange primary = splits.getPrimary();
+    assertEquals(RANGE.getFrom(), primary.getRange().getFrom());
+    assertEquals(10_001, primary.getRange().getTo());
+    assertEquals(MIN_BYTES * 2, primary.getByteCount());
+    OffsetByteRange residual = splits.getResidual();
+    assertEquals(10_001, residual.getRange().getFrom());
+    assertEquals(Long.MAX_VALUE, residual.getRange().getTo());
+    assertEquals(0, residual.getByteCount());
     assertEquals(splits.getPrimary(), tracker.currentRestriction());
     tracker.checkDone();
     assertNull(tracker.trySplit(IGNORED_FRACTION));
@@ -100,10 +108,10 @@
   @SuppressWarnings({"dereference.of.nullable", "argument.type.incompatible"})
   public void splitWithoutClaimEmpty() {
     when(ticker.read()).thenReturn(100000000000000L);
-    SplitResult<OffsetRange> splits = tracker.trySplit(IGNORED_FRACTION);
-    assertEquals(RANGE.getFrom(), splits.getPrimary().getFrom());
-    assertEquals(RANGE.getFrom(), splits.getPrimary().getTo());
-    assertEquals(RANGE, splits.getResidual());
+    SplitResult<OffsetByteRange> splits = tracker.trySplit(IGNORED_FRACTION);
+    assertEquals(RANGE.getFrom(), splits.getPrimary().getRange().getFrom());
+    assertEquals(RANGE.getFrom(), splits.getPrimary().getRange().getTo());
+    assertEquals(RANGE, splits.getResidual().getRange());
     assertEquals(splits.getPrimary(), tracker.currentRestriction());
     tracker.checkDone();
     assertNull(tracker.trySplit(IGNORED_FRACTION));
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
index 598037e..0a4e3e7 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
@@ -28,6 +28,7 @@
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.mockito.MockitoAnnotations.initMocks;
 
@@ -51,6 +52,8 @@
 import org.apache.beam.sdk.transforms.SerializableBiFunction;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.math.DoubleMath;
 import org.joda.time.Duration;
 import org.junit.Before;
 import org.junit.Test;
@@ -65,22 +68,24 @@
 public class PerSubscriptionPartitionSdfTest {
   private static final Duration MAX_SLEEP_TIME =
       Duration.standardMinutes(10).plus(Duration.millis(10));
-  private static final OffsetRange RESTRICTION = new OffsetRange(1, Long.MAX_VALUE);
+  private static final OffsetByteRange RESTRICTION =
+      OffsetByteRange.of(new OffsetRange(1, Long.MAX_VALUE), 0);
   private static final SubscriptionPartition PARTITION =
       SubscriptionPartition.of(example(SubscriptionPath.class), example(Partition.class));
 
   @Mock SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory;
 
+  @Mock ManagedBacklogReaderFactory backlogReaderFactory;
+  @Mock TopicBacklogReader backlogReader;
+
   @Mock
-  SerializableBiFunction<
-          SubscriptionPartition, OffsetRange, RestrictionTracker<OffsetRange, OffsetByteProgress>>
-      trackerFactory;
+  SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress> trackerFactory;
 
   @Mock SubscriptionPartitionProcessorFactory processorFactory;
   @Mock SerializableFunction<SubscriptionPartition, Committer> committerFactory;
 
   @Mock InitialOffsetReader initialOffsetReader;
-  @Spy RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+  @Spy TrackerWithProgress tracker;
   @Mock OutputReceiver<SequencedMessage> output;
   @Mock SubscriptionPartitionProcessor processor;
 
@@ -98,9 +103,11 @@
     when(trackerFactory.apply(any(), any())).thenReturn(tracker);
     when(committerFactory.apply(any())).thenReturn(committer);
     when(tracker.currentRestriction()).thenReturn(RESTRICTION);
+    when(backlogReaderFactory.newReader(any())).thenReturn(backlogReader);
     sdf =
         new PerSubscriptionPartitionSdf(
             MAX_SLEEP_TIME,
+            backlogReaderFactory,
             offsetReaderFactory,
             trackerFactory,
             processorFactory,
@@ -110,9 +117,10 @@
   @Test
   public void getInitialRestrictionReadSuccess() {
     when(initialOffsetReader.read()).thenReturn(example(Offset.class));
-    OffsetRange range = sdf.getInitialRestriction(PARTITION);
-    assertEquals(example(Offset.class).value(), range.getFrom());
-    assertEquals(Long.MAX_VALUE, range.getTo());
+    OffsetByteRange range = sdf.getInitialRestriction(PARTITION);
+    assertEquals(example(Offset.class).value(), range.getRange().getFrom());
+    assertEquals(Long.MAX_VALUE, range.getRange().getTo());
+    assertEquals(0, range.getByteCount());
     verify(offsetReaderFactory).apply(PARTITION);
   }
 
@@ -125,7 +133,13 @@
   @Test
   public void newTrackerCallsFactory() {
     assertSame(tracker, sdf.newTracker(PARTITION, RESTRICTION));
-    verify(trackerFactory).apply(PARTITION, RESTRICTION);
+    verify(trackerFactory).apply(backlogReader, RESTRICTION);
+  }
+
+  @Test
+  public void tearDownClosesBacklogReaderFactory() {
+    sdf.teardown();
+    verify(backlogReaderFactory).close();
   }
 
   @Test
@@ -159,12 +173,48 @@
     order2.verify(committer).awaitTerminated();
   }
 
+  private static final class NoopManagedBacklogReaderFactory
+      implements ManagedBacklogReaderFactory {
+    @Override
+    public TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
+      return null;
+    }
+
+    @Override
+    public void close() {}
+  }
+
   @Test
   @SuppressWarnings("return.type.incompatible")
   public void dofnIsSerializable() throws Exception {
     ObjectOutputStream output = new ObjectOutputStream(new ByteArrayOutputStream());
     output.writeObject(
         new PerSubscriptionPartitionSdf(
-            MAX_SLEEP_TIME, x -> null, (x, y) -> null, (x, y, z) -> null, (x) -> null));
+            MAX_SLEEP_TIME,
+            new NoopManagedBacklogReaderFactory(),
+            x -> null,
+            (x, y) -> null,
+            (x, y, z) -> null,
+            (x) -> null));
+  }
+
+  @Test
+  public void getProgressUnboundedRangeDelegates() {
+    Progress progress = Progress.from(0, 0.2);
+    when(tracker.getProgress()).thenReturn(progress);
+    assertTrue(
+        DoubleMath.fuzzyEquals(
+            progress.getWorkRemaining(), sdf.getSize(PARTITION, RESTRICTION), .0001));
+    verify(tracker).getProgress();
+  }
+
+  @Test
+  public void getProgressBoundedReturnsBytes() {
+    assertTrue(
+        DoubleMath.fuzzyEquals(
+            123.0,
+            sdf.getSize(PARTITION, OffsetByteRange.of(new OffsetRange(87, 8000), 123)),
+            .0001));
+    verifyNoInteractions(tracker);
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java
new file mode 100644
index 0000000..e242942
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static org.junit.Assert.fail;
+
+import com.google.cloud.pubsublite.AdminClient;
+import com.google.cloud.pubsublite.AdminClientSettings;
+import com.google.cloud.pubsublite.BacklogLocation;
+import com.google.cloud.pubsublite.CloudZone;
+import com.google.cloud.pubsublite.Message;
+import com.google.cloud.pubsublite.ProjectId;
+import com.google.cloud.pubsublite.SubscriptionName;
+import com.google.cloud.pubsublite.SubscriptionPath;
+import com.google.cloud.pubsublite.TopicName;
+import com.google.cloud.pubsublite.TopicPath;
+import com.google.cloud.pubsublite.proto.PubSubMessage;
+import com.google.cloud.pubsublite.proto.SequencedMessage;
+import com.google.cloud.pubsublite.proto.Subscription;
+import com.google.cloud.pubsublite.proto.Subscription.DeliveryConfig.DeliveryRequirement;
+import com.google.cloud.pubsublite.proto.Topic;
+import com.google.cloud.pubsublite.proto.Topic.PartitionConfig.Capacity;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import com.google.protobuf.ByteString;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.FlatMapElements;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Duration;
+import org.junit.After;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@RunWith(JUnit4.class)
+public class ReadWriteIT {
+  private static final Logger LOG = LoggerFactory.getLogger(ReadWriteIT.class);
+  private static final CloudZone ZONE = CloudZone.parse("us-central1-b");
+  private static final int MESSAGE_COUNT = 90;
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  private static ProjectId getProject(PipelineOptions options) {
+    return ProjectId.of(checkArgumentNotNull(options.as(GcpOptions.class).getProject()));
+  }
+
+  private static String randomName() {
+    return "beam_it_resource_" + ThreadLocalRandom.current().nextLong();
+  }
+
+  private static AdminClient newAdminClient() {
+    return AdminClient.create(AdminClientSettings.newBuilder().setRegion(ZONE.region()).build());
+  }
+
+  private final Deque<Runnable> cleanupActions = new ArrayDeque<>();
+
+  private TopicPath createTopic(ProjectId id) throws Exception {
+    TopicPath toReturn =
+        TopicPath.newBuilder()
+            .setProject(id)
+            .setLocation(ZONE)
+            .setName(TopicName.of(randomName()))
+            .build();
+    Topic.Builder topic = Topic.newBuilder().setName(toReturn.toString());
+    topic
+        .getPartitionConfigBuilder()
+        .setCount(2)
+        .setCapacity(Capacity.newBuilder().setPublishMibPerSec(4).setSubscribeMibPerSec(4));
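+    // Retain up to 30 GiB (30 * 2^30 bytes) per partition.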
+    topic.getRetentionConfigBuilder().setPerPartitionBytes(30 * (1L << 30));
+    cleanupActions.addLast(
+        () -> {
+          try (AdminClient client = newAdminClient()) {
+            client.deleteTopic(toReturn).get();
+          } catch (Throwable t) {
+            LOG.error("Failed to clean up topic.", t);
+          }
+        });
+    try (AdminClient client = newAdminClient()) {
+      client.createTopic(topic.build()).get();
+    }
+    return toReturn;
+  }
+
+  private SubscriptionPath createSubscription(TopicPath topic) throws Exception {
+    SubscriptionPath toReturn =
+        SubscriptionPath.newBuilder()
+            .setProject(topic.project())
+            .setLocation(ZONE)
+            .setName(SubscriptionName.of(randomName()))
+            .build();
+    Subscription.Builder subscription = Subscription.newBuilder().setName(toReturn.toString());
+    subscription
+        .getDeliveryConfigBuilder()
+        .setDeliveryRequirement(DeliveryRequirement.DELIVER_IMMEDIATELY);
+    subscription.setTopic(topic.toString());
+    cleanupActions.addLast(
+        () -> {
+          try (AdminClient client = newAdminClient()) {
+            client.deleteSubscription(toReturn).get();
+          } catch (Throwable t) {
+            LOG.error("Failed to clean up subscription.", t);
+          }
+        });
+    try (AdminClient client = newAdminClient()) {
+      client.createSubscription(subscription.build(), BacklogLocation.BEGINNING).get();
+    }
+    return toReturn;
+  }
+
+  @After
+  public void tearDown() {
+    while (!cleanupActions.isEmpty()) {
+      cleanupActions.removeLast().run();
+    }
+  }
+
+  // Workaround for BEAM-12867
+  // TODO(BEAM-12867): Remove this.
+  private static class CustomCreate extends PTransform<PCollection<Void>, PCollection<Integer>> {
+    @Override
+    public PCollection<Integer> expand(PCollection<Void> input) {
+      return input.apply(
+          "createIndexes",
+          FlatMapElements.via(
+              new SimpleFunction<Void, Iterable<Integer>>() {
+                @Override
+                public Iterable<Integer> apply(Void input) {
+                  return IntStream.range(0, MESSAGE_COUNT).boxed().collect(Collectors.toList());
+                }
+              }));
+    }
+  }
+
+  public static void writeMessages(TopicPath topicPath, Pipeline pipeline) {
+    PCollection<Void> trigger = pipeline.apply(Create.of((Void) null));
+    PCollection<Integer> indexes = trigger.apply("createIndexes", new CustomCreate());
+    PCollection<PubSubMessage> messages =
+        indexes.apply(
+            "createMessages",
+            MapElements.via(
+                new SimpleFunction<Integer, PubSubMessage>(
+                    index ->
+                        Message.builder()
+                            .setData(ByteString.copyFromUtf8(index.toString()))
+                            .build()
+                            .toProto()) {}));
+    // Add UUIDs to messages for later deduplication.
+    messages = messages.apply("addUuids", PubsubLiteIO.addUuids());
+    messages.apply(
+        "writeMessages",
+        PubsubLiteIO.write(PublisherOptions.newBuilder().setTopicPath(topicPath).build()));
+  }
+
+  public static PCollection<SequencedMessage> readMessages(
+      SubscriptionPath subscriptionPath, Pipeline pipeline) {
+    PCollection<SequencedMessage> messages =
+        pipeline.apply(
+            "readMessages",
+            PubsubLiteIO.read(
+                SubscriberOptions.newBuilder()
+                    .setSubscriptionPath(subscriptionPath)
+                    // setMinBundleTimeout INTENDED FOR TESTING ONLY
+                    // This sacrifices efficiency to make tests run faster. Do not use this in a
+                    // real pipeline!
+                    .setMinBundleTimeout(Duration.standardSeconds(5))
+                    .build()));
+    // Deduplicate messages based on the uuids added in PubsubLiteIO.addUuids() when writing.
+    return messages.apply(
+        "dedupeMessages", PubsubLiteIO.deduplicate(UuidDeduplicationOptions.newBuilder().build()));
+  }
+
+  // This static out-of-band communication is needed to retain serializability.
+  @GuardedBy("ReadWriteIT.class")
+  private static final List<SequencedMessage> received = new ArrayList<>();
+
+  private static synchronized void addMessageReceived(SequencedMessage message) {
+    received.add(message);
+  }
+
+  private static synchronized List<SequencedMessage> getTestQuickstartReceived() {
+    return ImmutableList.copyOf(received);
+  }
+
+  private static PTransform<PCollection<? extends SequencedMessage>, PCollection<Void>>
+      collectTestQuickstart() {
+    return MapElements.via(
+        new SimpleFunction<SequencedMessage, Void>() {
+          @Override
+          public Void apply(SequencedMessage input) {
+            addMessageReceived(input);
+            return null;
+          }
+        });
+  }
+
+  @Test
+  public void testReadWrite() throws Exception {
+    pipeline.getOptions().as(StreamingOptions.class).setStreaming(true);
+    pipeline.getOptions().as(TestPipelineOptions.class).setBlockOnRun(false);
+
+    TopicPath topic = createTopic(getProject(pipeline.getOptions()));
+    SubscriptionPath subscription = createSubscription(topic);
+
+    // Publish some messages
+    writeMessages(topic, pipeline);
+
+    // Read some messages. They should be deduplicated by the time we see them, so there should be
+    // exactly MESSAGE_COUNT messages, one for every index in [0, MESSAGE_COUNT).
+    PCollection<SequencedMessage> messages = readMessages(subscription, pipeline);
+    messages.apply("messageReceiver", collectTestQuickstart());
+    pipeline.run();
+    LOG.info("Running!");
+    for (int round = 0; round < 120; ++round) {
+      Thread.sleep(1000);
+      Map<Integer, Integer> receivedCounts = new HashMap<>();
+      for (SequencedMessage message : getTestQuickstartReceived()) {
+        int id = Integer.parseInt(message.getMessage().getData().toStringUtf8());
+        receivedCounts.put(id, receivedCounts.getOrDefault(id, 0) + 1);
+      }
+      LOG.info("Performing comparison round {}.\n", round);
+      boolean done = true;
+      List<Integer> missing = new ArrayList<>();
+      for (int id = 0; id < MESSAGE_COUNT; id++) {
+        int idCount = receivedCounts.getOrDefault(id, 0);
+        if (idCount == 0) {
+          missing.add(id);
+          done = false;
+        }
+        if (idCount > 1) {
+          fail(String.format("Failed to deduplicate message with id %s.", id));
+        }
+      }
+      LOG.info("Still messing messages: {}.\n", missing);
+      if (done) {
+        return;
+      }
+    }
+    fail(
+        String.format(
+            "Failed to receive all messages after 2 minutes. Received %s messages.",
+            getTestQuickstartReceived().size()));
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
index dbf3b93..3d743758 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
@@ -31,7 +31,6 @@
 import static org.mockito.Mockito.when;
 import static org.mockito.MockitoAnnotations.initMocks;
 
-import com.google.api.core.ApiFutures;
 import com.google.api.gax.rpc.ApiException;
 import com.google.api.gax.rpc.StatusCode.Code;
 import com.google.cloud.pubsublite.Offset;
@@ -40,7 +39,6 @@
 import com.google.cloud.pubsublite.internal.wire.Subscriber;
 import com.google.cloud.pubsublite.proto.Cursor;
 import com.google.cloud.pubsublite.proto.FlowControlRequest;
-import com.google.cloud.pubsublite.proto.SeekRequest;
 import com.google.cloud.pubsublite.proto.SequencedMessage;
 import com.google.protobuf.util.Timestamps;
 import java.util.List;
@@ -64,7 +62,7 @@
 @RunWith(JUnit4.class)
 @SuppressWarnings("initialization.fields.uninitialized")
 public class SubscriptionPartitionProcessorImplTest {
-  @Spy RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+  @Spy RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker;
   @Mock OutputReceiver<SequencedMessage> receiver;
   @Mock Function<Consumer<List<SequencedMessage>>, Subscriber> subscriberFactory;
 
@@ -83,6 +81,10 @@
         .build();
   }
 
+  private OffsetByteRange initialRange() {
+    return OffsetByteRange.of(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
+  }
+
   @Before
   public void setUp() {
     initMocks(this);
@@ -100,18 +102,11 @@
 
   @Test
   public void lifecycle() throws Exception {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+    when(tracker.currentRestriction()).thenReturn(initialRange());
     processor.start();
     verify(subscriber).startAsync();
     verify(subscriber).awaitRunning();
     verify(subscriber)
-        .seek(
-            SeekRequest.newBuilder()
-                .setCursor(Cursor.newBuilder().setOffset(example(Offset.class).value()))
-                .build());
-    verify(subscriber)
         .allowFlow(
             FlowControlRequest.newBuilder()
                 .setAllowedBytes(DEFAULT_FLOW_CONTROL.bytesOutstanding())
@@ -123,29 +118,15 @@
   }
 
   @Test
-  public void lifecycleSeekThrows() throws Exception {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any()))
-        .thenReturn(ApiFutures.immediateFailedFuture(new CheckedApiException(Code.OUT_OF_RANGE)));
+  public void lifecycleFlowControlThrows() throws Exception {
+    when(tracker.currentRestriction()).thenReturn(initialRange());
     doThrow(new CheckedApiException(Code.OUT_OF_RANGE)).when(subscriber).allowFlow(any());
     assertThrows(CheckedApiException.class, () -> processor.start());
   }
 
   @Test
-  public void lifecycleFlowControlThrows() {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any()))
-        .thenReturn(ApiFutures.immediateFailedFuture(new CheckedApiException(Code.OUT_OF_RANGE)));
-    assertThrows(CheckedApiException.class, () -> processor.start());
-  }
-
-  @Test
   public void lifecycleSubscriberAwaitThrows() throws Exception {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+    when(tracker.currentRestriction()).thenReturn(initialRange());
     processor.start();
     doThrow(new CheckedApiException(Code.INTERNAL).underlying).when(subscriber).awaitTerminated();
     assertThrows(ApiException.class, () -> processor.close());
@@ -155,21 +136,19 @@
 
   @Test
   public void subscriberFailureFails() throws Exception {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+    when(tracker.currentRestriction()).thenReturn(initialRange());
     processor.start();
     subscriber.fail(new CheckedApiException(Code.OUT_OF_RANGE));
     ApiException e =
-        assertThrows(ApiException.class, () -> processor.waitForCompletion(Duration.ZERO));
+        assertThrows(
+            // Longer wait is needed due to listener asynchrony.
+            ApiException.class, () -> processor.waitForCompletion(Duration.standardSeconds(1)));
     assertEquals(Code.OUT_OF_RANGE, e.getStatusCode().getCode());
   }
 
   @Test
   public void allowFlowFailureFails() throws Exception {
-    when(tracker.currentRestriction())
-        .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
-    when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+    when(tracker.currentRestriction()).thenReturn(initialRange());
     processor.start();
     when(tracker.tryClaim(any())).thenReturn(true);
     doThrow(new CheckedApiException(Code.OUT_OF_RANGE)).when(subscriber).allowFlow(any());
diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle
index 69cd88f..3f9d0c7 100644
--- a/sdks/java/io/jms/build.gradle
+++ b/sdks/java/io/jms/build.gradle
@@ -36,6 +36,7 @@
   testCompile library.java.activemq_kahadb_store
   testCompile library.java.activemq_client
   testCompile library.java.junit
+  testCompile library.java.mockito_core
   testRuntimeOnly library.java.slf4j_jdk14
   testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
new file mode 100644
index 0000000..0e023d1
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.io.UnboundedSource;
+
+/**
+ * Enables users to specify their own `JMS` backlog reporters so that {@link JmsIO} can report
+ * {@link UnboundedSource.UnboundedReader#getTotalBacklogBytes()}.
+ */
+public interface AutoScaler extends Serializable {
+
+  /** The {@link AutoScaler} is started when the {@link JmsIO.UnboundedJmsReader} is started. */
+  void start();
+
+  /**
+   * Returns the size of the backlog of unread data in the underlying data source represented by all
+   * splits of this source.
+   */
+  long getTotalBacklogBytes();
+
+  /** The {@link AutoScaler} is stopped when the {@link JmsIO.UnboundedJmsReader} is closed. */
+  void stop();
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
new file mode 100644
index 0000000..2b05cf6
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+
+/**
+ * Default implementation of {@link AutoScaler}. Returns {@link
+ * org.apache.beam.sdk.io.UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} as the default value.
+ */
+public class DefaultAutoscaler implements AutoScaler {
+  @Override
+  public void start() {}
+
+  @Override
+  public long getTotalBacklogBytes() {
+    return BACKLOG_UNKNOWN;
+  }
+
+  @Override
+  public void stop() {}
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index 4999e10..9fa4492 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -196,6 +196,8 @@
 
     abstract @Nullable Coder<T> getCoder();
 
+    abstract @Nullable AutoScaler getAutoScaler();
+
     abstract Builder<T> builder();
 
     @AutoValue.Builder
@@ -218,6 +220,8 @@
 
       abstract Builder<T> setCoder(Coder<T> coder);
 
+      abstract Builder<T> setAutoScaler(AutoScaler autoScaler);
+
       abstract Read<T> build();
     }
 
@@ -344,6 +348,14 @@
       return builder().setCoder(coder).build();
     }
 
+    /**
+     * Sets the {@link AutoScaler} to use for reporting backlog during the execution of this source.
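+     *
+     * <p>A minimal usage sketch ({@code MyAutoScaler} here is a hypothetical user-supplied
+     * {@link AutoScaler} implementation):
+     *
+     * <pre>{@code
+     * JmsIO.<String>read().withAutoScaler(new MyAutoScaler());
+     * }</pre>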
+     */
+    public Read<T> withAutoScaler(AutoScaler autoScaler) {
+      checkArgument(autoScaler != null, "autoScaler can not be null");
+      return builder().setAutoScaler(autoScaler).build();
+    }
+
     @Override
     public PCollection<T> expand(PBegin input) {
       checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
@@ -447,6 +459,7 @@
     private Connection connection;
     private Session session;
     private MessageConsumer consumer;
+    private AutoScaler autoScaler;
 
     private T currentMessage;
     private Instant currentTimestamp;
@@ -474,6 +487,12 @@
         }
         connection.start();
         this.connection = connection;
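+        // Fall back to DefaultAutoscaler, which reports BACKLOG_UNKNOWN, when no custom
+        // AutoScaler was configured.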
+        if (spec.getAutoScaler() == null) {
+          this.autoScaler = new DefaultAutoscaler();
+        } else {
+          this.autoScaler = spec.getAutoScaler();
+        }
+        this.autoScaler.start();
       } catch (Exception e) {
         throw new IOException("Error connecting to JMS", e);
       }
@@ -545,6 +564,11 @@
     }
 
     @Override
+    public long getTotalBacklogBytes() {
+      return this.autoScaler.getTotalBacklogBytes();
+    }
+
+    @Override
     public UnboundedSource<T, ?> getCurrentSource() {
       return source;
     }
@@ -565,6 +589,10 @@
           connection.close();
           connection = null;
         }
+        if (autoScaler != null) {
+          autoScaler.stop();
+          autoScaler = null;
+        }
       } catch (Exception e) {
         throw new IOException(e);
       }
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
index c335f8a..a9f3c3f 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
@@ -17,12 +17,17 @@
  */
 package org.apache.beam.sdk.io.jms;
 
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsString;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
 import java.lang.reflect.Proxy;
@@ -421,6 +426,50 @@
     CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark);
   }
 
+  @Test
+  public void testDefaultAutoscaler() throws IOException {
+    JmsIO.Read spec =
+        JmsIO.read()
+            .withConnectionFactory(connectionFactory)
+            .withUsername(USERNAME)
+            .withPassword(PASSWORD)
+            .withQueue(QUEUE);
+    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+    // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+    reader.start();
+    assertEquals(BACKLOG_UNKNOWN, reader.getSplitBacklogBytes());
+    assertEquals(BACKLOG_UNKNOWN, reader.getTotalBacklogBytes());
+    reader.close();
+  }
+
+  @Test
+  public void testCustomAutoscaler() throws IOException {
+    long expectedTotalBacklogBytes = 1111L;
+
+    AutoScaler autoScaler = mock(DefaultAutoscaler.class);
+    when(autoScaler.getTotalBacklogBytes()).thenReturn(expectedTotalBacklogBytes);
+    JmsIO.Read spec =
+        JmsIO.read()
+            .withConnectionFactory(connectionFactory)
+            .withUsername(USERNAME)
+            .withPassword(PASSWORD)
+            .withQueue(QUEUE)
+            .withAutoScaler(autoScaler);
+
+    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+    // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+    reader.start();
+    verify(autoScaler, times(1)).start();
+    assertEquals(expectedTotalBacklogBytes, reader.getTotalBacklogBytes());
+    verify(autoScaler, times(1)).getTotalBacklogBytes();
+    reader.close();
+    verify(autoScaler, times(1)).stop();
+  }
+
   private int count(String queue) throws Exception {
     Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
     connection.start();
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 1c02f71..98d2faa 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -185,9 +185,12 @@
 @ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
 class CombinePerKeyTransform(ptransform.PTransform):
   def expand(self, pcoll):
-    return pcoll \
-           | beam.CombinePerKey(sum).with_output_types(
-               typing.Tuple[str, int])
+    output = pcoll \
+           | beam.CombinePerKey(sum)
+    # TODO: Use `with_output_types` instead of explicitly
+    #  assigning to `.element_type` after fixing BEAM-12872
+    output.element_type = beam.typehints.Tuple[str, int]
+    return output
 
   def to_runner_api_parameter(self, unused_context):
     return TEST_COMPK_URN, None
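
For context, the workaround relies on the fact that a PCollection's element
type can be assigned directly when a transform-level declaration cannot be
used. A minimal sketch of the same pattern outside this test (hypothetical
standalone pipeline, not part of the change):

    import apache_beam as beam

    with beam.Pipeline() as p:
        counts = (
            p
            | beam.Create([('a', 1), ('a', 2), ('b', 3)])
            | beam.CombinePerKey(sum))
        # Declare the output element type on the PCollection itself,
        # mirroring the BEAM-12872 workaround above.
        counts.element_type = beam.typehints.Tuple[str, int]
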
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
index bebde98..9b0d6ff 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
@@ -63,12 +63,16 @@
       if not os.path.exists(self._executable_jar):
         parsed = urllib.parse.urlparse(self._executable_jar)
         if not parsed.scheme:
+          try:
+            flink_version = self.flink_version()
+          except Exception:
+            flink_version = '$FLINK_VERSION'
           raise ValueError(
               'Unable to parse jar URL "%s". If using a full URL, make sure '
               'the scheme is specified. If using a local file path, make sure '
               'the file exists; you may have to first build the job server '
               'using `./gradlew runners:flink:%s:job-server:shadowJar`.' %
-              (self._executable_jar, self._flink_version))
+              (self._executable_jar, flink_version))
       url = self._executable_jar
     else:
       url = job_server.JavaJarJobServer.path_to_beam_jar(
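
The guarded lookup keeps the original bad-URL error from being masked:
fetching the Flink version requires a REST call to the cluster, which can
itself fail. A standalone sketch of the pattern, with a hypothetical helper
name:

    def version_for_error_message(job_server):
        # Best effort only: if the version endpoint is unreachable, fall
        # back to a placeholder instead of raising a second exception from
        # inside the error path.
        try:
            return job_server.flink_version()
        except Exception:
            return '$FLINK_VERSION'
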
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
index 281e2d9..1294f46 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
@@ -192,6 +192,36 @@
     self.assertEqual(
         options_proto['beam:option:unknown_option_foo:v1'], 'some_value')
 
+  @requests_mock.mock()
+  def test_bad_url_flink_version(self, http_mock):
+    http_mock.get('http://flink/v1/config', json={'flink-version': '1.2.3.4'})
+    options = pipeline_options.FlinkRunnerOptions()
+    options.flink_job_server_jar = "bad url"
+    job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
+        'http://flink', options)
+    with self.assertRaises(ValueError) as context:
+      job_server.executable_jar()
+    self.assertEqual(
+        'Unable to parse jar URL "bad url". If using a full URL, make sure '
+        'the scheme is specified. If using a local file path, make sure '
+        'the file exists; you may have to first build the job server '
+        'using `./gradlew runners:flink:1.2:job-server:shadowJar`.',
+        str(context.exception))
+
+  def test_bad_url_placeholder_version(self):
+    options = pipeline_options.FlinkRunnerOptions()
+    options.flink_job_server_jar = "bad url"
+    job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
+        'http://example.com/bad', options)
+    with self.assertRaises(ValueError) as context:
+      job_server.executable_jar()
+    self.assertEqual(
+        'Unable to parse jar URL "bad url". If using a full URL, make sure '
+        'the scheme is specified. If using a local file path, make sure '
+        'the file exists; you may have to first build the job server '
+        'using `./gradlew runners:flink:$FLINK_VERSION:job-server:shadowJar`.',
+        str(context.exception))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index bcedd86..41ad3df 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -21,6 +21,7 @@
 
 import copy
 import heapq
+import itertools
 import operator
 import random
 from typing import Any
@@ -168,7 +169,8 @@
   """Combiners for obtaining extremal elements."""
 
   # pylint: disable=no-self-argument
-
+  @with_input_types(T)
+  @with_output_types(List[T])
   class Of(CombinerWithoutDefaults):
     """Obtain a list of the compare-most N elements in a PCollection.
 
@@ -202,10 +204,12 @@
         # This is a more efficient global algorithm.
         top_per_bundle = pcoll | core.ParDo(
             _TopPerBundle(self._n, self._key, self._reverse))
-        # If pcoll is empty, we can't guerentee that top_per_bundle
+        # If pcoll is empty, we can't guarantee that top_per_bundle
         # won't be empty, so inject at least one empty accumulator
-        # so that downstream is guerenteed to produce non-empty output.
-        empty_bundle = pcoll.pipeline | core.Create([(None, [])])
+        # so that downstream is guaranteed to produce non-empty output.
+        empty_bundle = (
+            pcoll.pipeline | core.Create([(None, [])]).with_output_types(
+                top_per_bundle.element_type))
         return ((top_per_bundle, empty_bundle) | core.Flatten()
                 | core.GroupByKey()
                 | core.ParDo(
@@ -219,6 +223,8 @@
               TopCombineFn(self._n, self._key,
                            self._reverse)).without_defaults()
 
+  @with_input_types(Tuple[K, V])
+  @with_output_types(Tuple[K, List[V]])
   class PerKey(ptransform.PTransform):
     """Identifies the compare-most N elements associated with each key.
 
@@ -524,6 +530,8 @@
 
   # pylint: disable=no-self-argument
 
+  @with_input_types(T)
+  @with_output_types(List[T])
   class FixedSizeGlobally(CombinerWithoutDefaults):
     """Sample n elements from the input PCollection without replacement."""
     def __init__(self, n):
@@ -543,6 +551,8 @@
     def default_label(self):
       return 'FixedSizeGlobally(%d)' % self._n
 
+  @with_input_types(Tuple[K, V])
+  @with_output_types(Tuple[K, List[V]])
   class FixedSizePerKey(ptransform.PTransform):
     """Sample n elements associated with each key without replacement."""
     def __init__(self, n):
@@ -597,16 +607,25 @@
 
 
 class _TupleCombineFnBase(core.CombineFn):
-  def __init__(self, *combiners):
+  def __init__(self, *combiners, merge_accumulators_batch_size=None):
     self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
     self._named_combiners = combiners
+    # If the `merge_accumulators_batch_size` value is not specified, we choose a
+    # bounded default that is inversely proportional to the number of
+    # accumulators in merged tuples.
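+    # For example, 3 combiners yield a batch size of 1000 // 3 == 333,
+    # while 100 or more combiners hit the floor of 10.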
+    num_combiners = max(1, len(combiners))
+    self._merge_accumulators_batch_size = (
+        merge_accumulators_batch_size or max(10, 1000 // num_combiners))
 
   def display_data(self):
     combiners = [
         c.__name__ if hasattr(c, '__name__') else c.__class__.__name__
         for c in self._named_combiners
     ]
-    return {'combiners': str(combiners)}
+    return {
+        'combiners': str(combiners),
+        'merge_accumulators_batch_size': self._merge_accumulators_batch_size
+    }
 
   def setup(self, *args, **kwargs):
     for c in self._combiners:
@@ -616,10 +635,22 @@
     return [c.create_accumulator(*args, **kwargs) for c in self._combiners]
 
   def merge_accumulators(self, accumulators, *args, **kwargs):
-    return [
-        c.merge_accumulators(a, *args, **kwargs) for c,
-        a in zip(self._combiners, zip(*accumulators))
-    ]
+    # Make sure that `accumulators` is an iterator (so that the position is
+    # remembered).
+    accumulators = iter(accumulators)
+    result = next(accumulators)
+    while True:
+      # Load accumulators into memory and merge in batches to decrease peak
+      # memory usage.
+      accumulators_batch = [result] + list(
+          itertools.islice(accumulators, self._merge_accumulators_batch_size))
+      if len(accumulators_batch) == 1:
+        break
+      result = [
+          c.merge_accumulators(a, *args, **kwargs) for c,
+          a in zip(self._combiners, zip(*accumulators_batch))
+      ]
+    return result
 
   def compact(self, accumulator, *args, **kwargs):
     return [
@@ -670,6 +701,8 @@
     ]
 
 
+@with_input_types(T)
+@with_output_types(List[T])
 class ToList(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a single list."""
   def expand(self, pcoll):
@@ -698,6 +731,8 @@
     return accumulator
 
 
+@with_input_types(Tuple[K, V])
+@with_output_types(Dict[K, V])
 class ToDict(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a single dict.
 
@@ -735,6 +770,8 @@
     return accumulator
 
 
+@with_input_types(T)
+@with_output_types(Set[T])
 class ToSet(CombinerWithoutDefaults):
   """A global CombineFn that condenses a PCollection into a set."""
   def expand(self, pcoll):
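
The batched merge above bounds peak memory: at most
merge_accumulators_batch_size + 1 accumulator tuples (the running result
plus one batch) are materialized at a time, rather than the whole input. A
standalone sketch of the same loop, with hypothetical names:

    import itertools

    def merge_in_batches(merge, accumulators, batch_size):
        # Wrap in an iterator so each islice() call resumes where the
        # previous batch stopped.
        accumulators = iter(accumulators)
        result = next(accumulators)
        while True:
            # Fold the running result into the next batch; at most
            # batch_size + 1 items are in memory at once.
            batch = [result] + list(
                itertools.islice(accumulators, batch_size))
            if len(batch) == 1:
                break  # Input exhausted; `result` holds the final merge.
            result = merge(batch)
        return result

    # For example, summing ten integers three at a time:
    assert merge_in_batches(sum, range(10), 3) == 45
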
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index d826287..7e0e835 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -249,7 +249,8 @@
     dd = DisplayData.create_from(transform)
     expected_items = [
         DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
-        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']")
+        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
+        DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
@@ -358,6 +359,49 @@
                   max).with_common_input()).without_defaults())
       assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
 
+  def test_empty_tuple_combine_fn(self):
+    with TestPipeline() as p:
+      result = (
+          p
+          | Create([(), (), ()])
+          | beam.CombineGlobally(combine.TupleCombineFn()))
+      assert_that(result, equal_to([()]))
+
+  def test_tuple_combine_fn_batched_merge(self):
+    num_combine_fns = 10
+    max_num_accumulators_in_memory = 30
+    # The largest batch whose accumulator tuples fit in the memory budget,
+    # reserving one tuple for the running merge result.
+    merge_accumulators_batch_size = (
+        max_num_accumulators_in_memory // num_combine_fns - 1)
+    num_accumulator_tuples_to_merge = 20
+
+    class CountedAccumulator:
+      count = 0
+      oom = False
+
+      def __init__(self):
+        if CountedAccumulator.count > max_num_accumulators_in_memory:
+          CountedAccumulator.oom = True
+        else:
+          CountedAccumulator.count += 1
+
+    class CountedAccumulatorCombineFn(beam.CombineFn):
+      def create_accumulator(self):
+        return CountedAccumulator()
+
+      def merge_accumulators(self, accumulators):
+        CountedAccumulator.count += 1
+        for _ in accumulators:
+          CountedAccumulator.count -= 1
+
+    combine_fn = combine.TupleCombineFn(
+        *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
+        merge_accumulators_batch_size=merge_accumulators_batch_size)
+    combine_fn.merge_accumulators(
+        combine_fn.create_accumulator()
+        for _ in range(num_accumulator_tuples_to_merge))
+    assert not CountedAccumulator.oom
+
   def test_to_list_and_to_dict1(self):
     with TestPipeline() as pipeline:
       the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
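
For reference, the new keyword is threaded through TupleCombineFn, so a
pipeline can cap the batch size explicitly. A sketch with arbitrary values
(the derived default is usually sufficient):

    import apache_beam as beam
    from apache_beam.transforms import combiners

    with beam.Pipeline() as p:
        result = (
            p
            | beam.Create([(1, 4.0), (2, 5.0), (3, 6.0)])
            # Combine position 0 with min and position 1 with max,
            # merging accumulators in batches of 100.
            | beam.CombineGlobally(
                combiners.TupleCombineFn(
                    min, max, merge_accumulators_batch_size=100)))
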
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index ec001e2..9ce33f5 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -2333,9 +2333,8 @@
 
     self.assertStartswith(
         e.exception.args[0],
-        "Type hint violation for 'CombinePerKey(TopCombineFn)': "
-        "requires Tuple[TypeVariable[K], TypeVariable[T]] "
-        "but got {} for element".format(int))
+        "Input type hint violation at TopMod: expected Tuple[TypeVariable[K], "
+        "TypeVariable[V]], got {}".format(int))
 
   def test_per_key_pipeline_checking_satisfied(self):
     d = (
@@ -2492,10 +2491,8 @@
 
     self.assertStartswith(
         e.exception.args[0],
-        "Type hint violation for 'CombinePerKey': "
-        "requires "
-        "Tuple[TypeVariable[K], Tuple[TypeVariable[K], TypeVariable[V]]] "
-        "but got Tuple[None, int] for element")
+        "Input type hint violation at ToDict: expected Tuple[TypeVariable[K], "
+        "TypeVariable[V]], got {}".format(int))
 
   def test_to_dict_pipeline_check_satisfied(self):
     d = (