Merge pull request #15498 from apache/BEAM-12873

diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 12b7df2..17b59ec 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -794,7 +794,8 @@
 
     final int numIters = 2000;
     for (int i = 0; i < numIters; ++i) {
-      server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+      server.addWorkToOffer(
+          makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
     }
 
     Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
@@ -829,7 +830,8 @@
 
     final int numIters = 2000;
     for (int i = 0; i < numIters; ++i) {
-      server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+      server.addWorkToOffer(
+          makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
     }
 
     Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
new file mode 100644
index 0000000..5496d8b
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statecache implements the state caching feature described by the
+// Beam Fn API.
+//
+// The Beam State API and the intended caching behavior are described here:
+// https://docs.google.com/document/d/1BOozW0bzBuz4oHJEuZNDOHdzaV5Y56ix58Ozrqm2jFg/edit#heading=h.7ghoih5aig5m
+package statecache
+
+import (
+	"sync"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+type token string
+
+// SideInputCache stores a cache of reusable inputs for the purposes of
+// eliminating redundant calls to the runner during execution of ParDos
+// using side inputs.
+//
+// A SideInputCache should be initialized when the SDK harness is initialized,
+// creating storage for side input caching. On each ProcessBundleRequest,
+// the cache will process the list of tokens for cacheable side inputs and
+// be queried when side inputs are requested in bundle execution. Once a
+// new bundle request comes in, the valid tokens will be updated and the cache
+// will be re-used. In the event that the cache reaches capacity, a random,
+// currently invalid cached object will be evicted.
+type SideInputCache struct {
+	capacity    int
+	mu          sync.Mutex
+	cache       map[token]exec.ReusableInput
+	idsToTokens map[string]token
+	validTokens map[token]int8 // Maps tokens to active bundle counts
+	metrics     CacheMetrics
+}
+
+// CacheMetrics records the number of cache hits and misses, along with evictions of invalid
+// and still-in-use entries.
+type CacheMetrics struct {
+	Hits           int64
+	Misses         int64
+	Evictions      int64
+	InUseEvictions int64
+}
+
+// Init makes the cache map and the map of IDs to cache tokens for the
+// SideInputCache. Should only be called once. Returns an error for
+// non-positive capacities.
+func (c *SideInputCache) Init(cap int) error {
+	if cap <= 0 {
+		return errors.Errorf("capacity must be a positive integer, got %v", cap)
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.cache = make(map[token]exec.ReusableInput, cap)
+	c.idsToTokens = make(map[string]token)
+	c.validTokens = make(map[token]int8)
+	c.capacity = cap
+	return nil
+}
+
+// SetValidTokens marks each of the given cache tokens as valid, incrementing its count of
+// in-flight bundles and updating the mapping of transform and side input IDs to cache tokens
+// in the process. Should be called at the start of every new ProcessBundleRequest. If the
+// runner does not support caching, the passed cache token values should be empty and all
+// get/set requests will silently be no-ops.
+func (c *SideInputCache) SetValidTokens(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for _, tok := range cacheTokens {
+		// User State caching is currently not supported, so these tokens are ignored
+		if tok.GetUserState() != nil {
+			continue
+		}
+		s := tok.GetSideInput()
+		transformID := s.GetTransformId()
+		sideInputID := s.GetSideInputId()
+		t := token(tok.GetToken())
+		c.setValidToken(transformID, sideInputID, t)
+	}
+}
+
+// setValidToken adds a new valid token for a request into the SideInputCache struct
+// by mapping the transform ID and side input ID pairing to the cache token.
+func (c *SideInputCache) setValidToken(transformID, sideInputID string, tok token) {
+	idKey := transformID + sideInputID
+	c.idsToTokens[idKey] = tok
+	count, ok := c.validTokens[tok]
+	if !ok {
+		c.validTokens[tok] = 1
+	} else {
+		c.validTokens[tok] = count + 1
+	}
+}
+
+// CompleteBundle decrements the usage count of each of the given cache tokens (the same tokens
+// passed to SetValidTokens), tracking whether the associated values are still in use. Should be
+// called once the ProcessBundleRequest has completed.
+func (c *SideInputCache) CompleteBundle(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for _, tok := range cacheTokens {
+		// User State caching is currently not supported, so these tokens are ignored
+		if tok.GetUserState() != nil {
+			continue
+		}
+		t := token(tok.GetToken())
+		c.decrementTokenCount(t)
+	}
+}
+
+// decrementTokenCount decrements the validTokens entry for
+// a given token by 1. Should only be called when completing
+// a bundle.
+func (c *SideInputCache) decrementTokenCount(tok token) {
+	count := c.validTokens[tok]
+	if count == 1 {
+		delete(c.validTokens, tok)
+	} else {
+		c.validTokens[tok] = count - 1
+	}
+}
+
+// makeAndValidateToken looks up the cache token for the given transform and side input ID
+// pairing and reports whether that token is currently valid.
+func (c *SideInputCache) makeAndValidateToken(transformID, sideInputID string) (token, bool) {
+	idKey := transformID + sideInputID
+	// Check if it's a known token
+	tok, ok := c.idsToTokens[idKey]
+	if !ok {
+		return "", false
+	}
+	return tok, c.isValid(tok)
+}
+
+// QueryCache takes a transform ID and side input ID and checks whether a corresponding side
+// input has been cached. A query with a bad token (e.g. one that doesn't map to a known
+// token or one that maps to a known but currently invalid token) is treated the same as a
+// cache miss.
+func (c *SideInputCache) QueryCache(transformID, sideInputID string) exec.ReusableInput {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+	if !ok {
+		return nil
+	}
+	// Check to see if cached
+	input, ok := c.cache[tok]
+	if !ok {
+		c.metrics.Misses++
+		return nil
+	}
+
+	c.metrics.Hits++
+	return input
+}
+
+// SetCache allows a user to place a ReusableInput materialized from the reader into the SideInputCache
+// with its corresponding transform ID and side input ID. If the IDs do not pair with a known, valid token
+// then we silently do not cache the input, as this is an indication that the runner is treating that input
+// as uncacheable.
+func (c *SideInputCache) SetCache(transformID, sideInputID string, input exec.ReusableInput) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+	if !ok {
+		return
+	}
+	if len(c.cache) >= c.capacity {
+		c.evictElement()
+	}
+	c.cache[tok] = input
+}
+
+func (c *SideInputCache) isValid(tok token) bool {
+	count, ok := c.validTokens[tok]
+	// If the token is not known or not in use, return false
+	return ok && count > 0
+}
+
+// evictElement randomly evicts a ReusableInput that is not currently valid from the cache.
+// It should only be called by a goroutine that obtained the lock in SetCache.
+func (c *SideInputCache) evictElement() {
+	deleted := false
+	// Select a key from the cache at random
+	for k := range c.cache {
+		// Do not evict an element if it's currently valid
+		if !c.isValid(k) {
+			delete(c.cache, k)
+			c.metrics.Evictions++
+			deleted = true
+			break
+		}
+	}
+	// Nothing is deleted if every side input is still valid. Clear
+	// out a random entry and record the in-use eviction
+	if !deleted {
+		for k := range c.cache {
+			delete(c.cache, k)
+			c.metrics.InUseEvictions++
+			break
+		}
+	}
+}
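
Reviewer note: the package and struct comments above describe the intended lifecycle of this
cache (Init once when the harness starts, SetValidTokens at the start of each
ProcessBundleRequest, QueryCache/SetCache during bundle execution, CompleteBundle once the
bundle finishes). Below is a minimal sketch of that flow using only the API added in this file;
the harness wiring, the transform/side input IDs, and the materialize callback are illustrative
assumptions, not part of this change.

    // Illustrative only: how a harness might drive the SideInputCache for one bundle.
    package example

    import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/harness/statecache"
        fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
    )

    // processBundle marks the bundle's cache tokens valid, serves the side input from the
    // cache when possible, and releases the tokens once the bundle completes.
    func processBundle(
        cache *statecache.SideInputCache,
        tokens []fnpb.ProcessBundleRequest_CacheToken,
        materialize func() exec.ReusableInput,
    ) exec.ReusableInput {
        cache.SetValidTokens(tokens...)       // start of ProcessBundleRequest
        defer cache.CompleteBundle(tokens...) // end of ProcessBundleRequest

        if cached := cache.QueryCache("transform1", "sideInput1"); cached != nil {
            return cached // hit: reuse the previously materialized side input
        }
        input := materialize() // miss or uncacheable: read from the runner
        cache.SetCache("transform1", "sideInput1", input)
        return input
    }

    // The cache itself is created once at harness startup:
    //    var cache statecache.SideInputCache
    //    _ = cache.Init(100) // capacity must be positive
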
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
new file mode 100644
index 0000000..b9970c3
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statecache
+
+import (
+	"testing"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+// TestReusableInput implements the ReusableInput interface for the purposes
+// of testing.
+type TestReusableInput struct {
+	transformID string
+	sideInputID string
+	value       interface{}
+}
+
+func makeTestReusableInput(transformID, sideInputID string, value interface{}) exec.ReusableInput {
+	return &TestReusableInput{transformID: transformID, sideInputID: sideInputID, value: value}
+}
+
+// Init is a ReusableInput interface method; it is a no-op here.
+func (r *TestReusableInput) Init() error {
+	return nil
+}
+
+// Value returns the stored value in the TestReusableInput.
+func (r *TestReusableInput) Value() interface{} {
+	return r.value
+}
+
+// Reset clears the value in the TestReusableInput.
+func (r *TestReusableInput) Reset() error {
+	r.value = nil
+	return nil
+}
+
+func TestInit(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(5)
+	if err != nil {
+		t.Errorf("SideInputCache init failed but should have succeeded, got %v", err)
+	}
+}
+
+func TestInit_Bad(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(0)
+	if err == nil {
+		t.Error("SideInputCache init succeeded but should have failed")
+	}
+}
+
+func TestQueryCache_EmptyCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	output := s.QueryCache("transform1", "side1")
+	if output != nil {
+		t.Errorf("Cache hit when it should have missed, got %v", output)
+	}
+}
+
+func TestSetCache_UncacheableCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	input := makeTestReusableInput("t1", "s1", 10)
+	s.SetCache("t1", "s1", input)
+	output := s.QueryCache("t1", "s1")
+	if output != nil {
+		t.Errorf("Cache hit when it should have missed, got %v", output)
+	}
+}
+
+func TestSetCache_CacheableCase(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+	transID := "t1"
+	sideID := "s1"
+	tok := token("tok1")
+	s.setValidToken(transID, sideID, tok)
+	input := makeTestReusableInput(transID, sideID, 10)
+	s.SetCache(transID, sideID, input)
+	output := s.QueryCache(transID, sideID)
+	if output == nil {
+		t.Fatalf("call to query cache missed when it should have hit")
+	}
+	val, ok := output.Value().(int)
+	if !ok {
+		t.Errorf("failed to convert value to integer, got %v", output.Value())
+	}
+	if val != 10 {
+		t.Errorf("element mismatch, expected 10, got %v", val)
+	}
+}
+
+func makeRequest(transformID, sideInputID string, t token) fnpb.ProcessBundleRequest_CacheToken {
+	var tok fnpb.ProcessBundleRequest_CacheToken
+	var wrap fnpb.ProcessBundleRequest_CacheToken_SideInput_
+	var side fnpb.ProcessBundleRequest_CacheToken_SideInput
+	side.TransformId = transformID
+	side.SideInputId = sideInputID
+	wrap.SideInput = &side
+	tok.Type = &wrap
+	tok.Token = []byte(t)
+	return tok
+}
+
+func TestSetValidTokens(t *testing.T) {
+	inputs := []struct {
+		transformID string
+		sideInputID string
+		tok         token
+	}{
+		{
+			"t1",
+			"s1",
+			"tok1",
+		},
+		{
+			"t2",
+			"s2",
+			"tok2",
+		},
+		{
+			"t3",
+			"s3",
+			"tok3",
+		},
+	}
+
+	var s SideInputCache
+	err := s.Init(3)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	var tokens []fnpb.ProcessBundleRequest_CacheToken
+	for _, input := range inputs {
+		t := makeRequest(input.transformID, input.sideInputID, input.tok)
+		tokens = append(tokens, t)
+	}
+
+	s.SetValidTokens(tokens...)
+	if len(s.idsToTokens) != len(inputs) {
+		t.Errorf("Missing tokens, expected %v, got %v", len(inputs), len(s.idsToTokens))
+	}
+
+	for i, input := range inputs {
+		// Check that the token is in the valid list
+		if !s.isValid(input.tok) {
+			t.Errorf("error in input %v, token %v is not valid", i, input.tok)
+		}
+		// Check that the mapping of IDs to tokens is correct
+		mapped := s.idsToTokens[input.transformID+input.sideInputID]
+		if mapped != input.tok {
+			t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tok, mapped)
+		}
+	}
+}
+
+func TestSetValidTokens_ClearingBetween(t *testing.T) {
+	inputs := []struct {
+		transformID string
+		sideInputID string
+		tk          token
+	}{
+		{
+			"t1",
+			"s1",
+			"tok1",
+		},
+		{
+			"t2",
+			"s2",
+			"tok2",
+		},
+		{
+			"t3",
+			"s3",
+			"tok3",
+		},
+	}
+
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	for i, input := range inputs {
+		tok := makeRequest(input.transformID, input.sideInputID, input.tk)
+
+		s.SetValidTokens(tok)
+
+		// Check that the token is in the valid list
+		if !s.isValid(input.tk) {
+			t.Errorf("error in input %v, token %v is not valid", i, input.tk)
+		}
+		// Check that the mapping of IDs to tokens is correct
+		mapped := s.idsToTokens[input.transformID+input.sideInputID]
+		if mapped != input.tk {
+			t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tk, mapped)
+		}
+
+		s.CompleteBundle(tok)
+	}
+
+	for k := range s.validTokens {
+		if s.validTokens[k] != 0 {
+			t.Errorf("token count mismatch for token %v, expected 0, got %v", k, s.validTokens[k])
+		}
+	}
+}
+
+func TestSetCache_Eviction(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	tokOne := makeRequest("t1", "s1", "tok1")
+	inOne := makeTestReusableInput("t1", "s1", 10)
+	s.SetValidTokens(tokOne)
+	s.SetCache("t1", "s1", inOne)
+	// Mark bundle as complete, drop count for tokOne to 0
+	s.CompleteBundle(tokOne)
+
+	tokTwo := makeRequest("t2", "s2", "tok2")
+	inTwo := makeTestReusableInput("t2", "s2", 20)
+	s.SetValidTokens(tokTwo)
+	s.SetCache("t2", "s2", inTwo)
+
+	if len(s.cache) != 1 {
+		t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+	}
+	if s.metrics.Evictions != 1 {
+		t.Errorf("number evictions incorrect, expected 1, got %v", s.metrics.Evictions)
+	}
+}
+
+func TestSetCache_EvictionFailure(t *testing.T) {
+	var s SideInputCache
+	err := s.Init(1)
+	if err != nil {
+		t.Fatalf("cache init failed, got %v", err)
+	}
+
+	tokOne := makeRequest("t1", "s1", "tok1")
+	inOne := makeTestReusableInput("t1", "s1", 10)
+
+	tokTwo := makeRequest("t2", "s2", "tok2")
+	inTwo := makeTestReusableInput("t2", "s2", 20)
+
+	s.SetValidTokens(tokOne, tokTwo)
+	s.SetCache("t1", "s1", inOne)
+	// Should fail to evict because the first token is still valid
+	s.SetCache("t2", "s2", inTwo)
+	// Cache should not exceed size 1
+	if len(s.cache) != 1 {
+		t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+	}
+	if s.metrics.InUseEvictions != 1 {
+		t.Errorf("number of failed eviction calls incorrect, expected 1, got %v", s.metrics.InUseEvictions)
+	}
+}
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
index 511f839..f4ab8bb 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
@@ -156,16 +156,20 @@
   }
 
   /**
-   * An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values
-   * using the specified {@link Coder}.
+   * An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T}
+   * values using the specified {@link Coder}.
    *
    * <p>Note that this adapter follows the Beam Fn API specification for forcing values that decode
    * consuming zero bytes to consuming exactly one byte.
    *
    * <p>Note that access to the underlying {@link InputStream} is lazy and will only be invoked on
-   * first access to {@link #next()} or {@link #hasNext()}.
+   * first access to {@link #next}, {@link #hasNext}, {@link #isReady}, and {@link #prefetch}.
+   *
+   * <p>Note that {@link #isReady} and {@link #prefetch} rely on non-empty {@link ByteString}s being
+   * returned via the underlying {@link PrefetchableIterator}; otherwise {@link #prefetch} will
+   * seemingly make zero progress even though it is actually advancing through the empty pages.
    */
-  public static class DataStreamDecoder<T> implements Iterator<T> {
+  public static class DataStreamDecoder<T> implements PrefetchableIterator<T> {
 
     private enum State {
       READ_REQUIRED,
@@ -173,13 +177,13 @@
       EOF
     }
 
-    private final Iterator<ByteString> inputByteStrings;
+    private final PrefetchableIterator<ByteString> inputByteStrings;
     private final Inbound inbound;
     private final Coder<T> coder;
     private State currentState;
     private T next;
 
-    public DataStreamDecoder(Coder<T> coder, Iterator<ByteString> inputStream) {
+    public DataStreamDecoder(Coder<T> coder, PrefetchableIterator<ByteString> inputStream) {
       this.currentState = State.READ_REQUIRED;
       this.coder = coder;
       this.inputByteStrings = inputStream;
@@ -187,6 +191,31 @@
     }
 
     @Override
+    public boolean isReady() {
+      switch (currentState) {
+        case EOF:
+          return true;
+        case READ_REQUIRED:
+          try {
+            return inbound.isReady();
+          } catch (IOException e) {
+            throw new RuntimeException(e);
+          }
+        case HAS_NEXT:
+          return true;
+        default:
+          throw new IllegalStateException(String.format("Unknown state %s", currentState));
+      }
+    }
+
+    @Override
+    public void prefetch() {
+      if (!isReady()) {
+        inputByteStrings.prefetch();
+      }
+    }
+
+    @Override
     public boolean hasNext() {
       switch (currentState) {
         case EOF:
@@ -232,8 +261,8 @@
     private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput();
 
     /**
-     * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first
-     * {@link Iterator} on first access of this input stream.
+     * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link
+     * Iterator} on first access of this input stream.
      *
      * <p>Closing this input stream has no effect.
      */
@@ -245,6 +274,22 @@
         this.currentStream = EMPTY_STREAM;
       }
 
+      public boolean isReady() throws IOException {
+        // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
+        // minus the number of bytes that have been read so far and can be reliably used to tell
+        // us whether we are at the end of the stream.
+        while (currentStream.available() == 0) {
+          if (!inputByteStrings.isReady()) {
+            return false;
+          }
+          if (!inputByteStrings.hasNext()) {
+            return true;
+          }
+          currentStream = inputByteStrings.next().newInput();
+        }
+        return true;
+      }
+
       public boolean isEof() throws IOException {
         // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
         // minus the number of bytes that have been read so far and can be reliably used to tell
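
The javadoc above describes the new prefetch contract of DataStreamDecoder: isReady and
prefetch lazily touch the underlying ByteString iterator, and the prefetch accounting assumes
non-empty pages. A brief sketch of the decode-and-prefetch pattern follows; the class name and
in-memory page are illustrative assumptions, since in real use the ByteStrings come from the
State API rather than an array.

    import java.io.IOException;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
    import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
    import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
    import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;

    public class DataStreamDecoderSketch {
      public static void main(String[] args) throws IOException {
        // One "page" holding two encoded elements, as a state stream would return it.
        ByteString.Output page = ByteString.newOutput();
        StringUtf8Coder.of().encode("A", page);
        StringUtf8Coder.of().encode("BC", page);

        PrefetchableIterator<String> decoder =
            new DataStreamDecoder<>(
                StringUtf8Coder.of(), PrefetchableIterators.fromArray(page.toByteString()));

        decoder.prefetch(); // requests the next page if one is not already available
        while (decoder.hasNext()) {
          System.out.println(decoder.next());
          decoder.prefetch(); // stay one page ahead, per the javadoc above
        }
      }
    }
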
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
index 9dd5ee4..a8b48e8 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeTrue;
 
 import java.io.IOException;
@@ -106,7 +107,7 @@
     }
 
     @Test
-    public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
+    public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception {
       CountingOutputStream countingOutputStream =
           new CountingOutputStream(ByteStreams.nullOutputStream());
       GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream);
@@ -115,6 +116,55 @@
       testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE);
     }
 
+    @Test
+    public void testPrefetch() throws Exception {
+      List<ByteString> encodings = new ArrayList<>();
+      {
+        ByteString.Output encoding = ByteString.newOutput();
+        StringUtf8Coder.of().encode("A", encoding);
+        StringUtf8Coder.of().encode("BC", encoding);
+        encodings.add(encoding.toByteString());
+      }
+      encodings.add(ByteString.EMPTY);
+      {
+        ByteString.Output encoding = ByteString.newOutput();
+        StringUtf8Coder.of().encode("DEF", encoding);
+        StringUtf8Coder.of().encode("GHIJ", encoding);
+        encodings.add(encoding.toByteString());
+      }
+
+      PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<ByteString> iterator =
+          new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator());
+      PrefetchableIterator<String> decoder =
+          new DataStreamDecoder<>(StringUtf8Coder.of(), iterator);
+      assertFalse(decoder.isReady());
+      decoder.prefetch();
+      assertTrue(decoder.isReady());
+      assertEquals(1, iterator.getNumPrefetchCalls());
+
+      decoder.next();
+      // Now we will have moved off of the empty byte array that we start with so prefetch will
+      // do nothing since we are ready
+      assertTrue(decoder.isReady());
+      decoder.prefetch();
+      assertEquals(1, iterator.getNumPrefetchCalls());
+
+      decoder.next();
+      // Now we are at the end of the first ByteString so we expect a prefetch to pass through
+      assertFalse(decoder.isReady());
+      decoder.prefetch();
+      assertEquals(2, iterator.getNumPrefetchCalls());
+      // We also expect the decoder to not be ready since the next byte string is empty which
+      // would require us to move to the next page. This typically wouldn't happen in practice
+      // though, because we expect non-empty pages.
+      assertFalse(decoder.isReady());
+
+      // Prefetching will allow us to move to the third ByteString
+      decoder.prefetch();
+      assertEquals(3, iterator.getNumPrefetchCalls());
+      assertTrue(decoder.isReady());
+    }
+
     private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOException {
       ByteString.Output output = ByteString.newOutput();
       for (T value : expected) {
@@ -131,7 +181,9 @@
     }
 
     private <T> void testDecoderWith(Coder<T> coder, T[] expected, List<ByteString> encoded) {
-      Iterator<T> decoder = new DataStreamDecoder<>(coder, encoded.iterator());
+      Iterator<T> decoder =
+          new DataStreamDecoder<>(
+              coder, PrefetchableIterators.maybePrefetchable(encoded.iterator()));
 
       Object[] actual = Iterators.toArray(decoder, Object.class);
       assertArrayEquals(expected, actual);
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
index 6131634..9ada175 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
@@ -120,10 +120,14 @@
         "F");
   }
 
-  private static class NeverReady implements PrefetchableIterator<String> {
-    PrefetchableIterator<String> delegate = PrefetchableIterators.fromArray("A", "B");
+  public static class NeverReady<T> implements PrefetchableIterator<T> {
+    private final Iterator<T> delegate;
     int prefetchCalled;
 
+    public NeverReady(Iterator<T> delegate) {
+      this.delegate = delegate;
+    }
+
     @Override
     public boolean isReady() {
       return false;
@@ -140,74 +144,117 @@
     }
 
     @Override
-    public String next() {
+    public T next() {
       return delegate.next();
     }
+
+    public int getNumPrefetchCalls() {
+      return prefetchCalled;
+    }
   }
 
-  private static class ReadyAfterPrefetch extends NeverReady {
+  public static class ReadyAfterPrefetch<T> extends NeverReady<T> {
+
+    public ReadyAfterPrefetch(Iterator<T> delegate) {
+      super(delegate);
+    }
+
     @Override
     public boolean isReady() {
       return prefetchCalled > 0;
     }
   }
 
+  public static class ReadyAfterPrefetchUntilNext<T> extends ReadyAfterPrefetch<T> {
+    boolean advancedSincePrefetch;
+
+    public ReadyAfterPrefetchUntilNext(Iterator<T> delegate) {
+      super(delegate);
+    }
+
+    @Override
+    public boolean isReady() {
+      return !advancedSincePrefetch && super.isReady();
+    }
+
+    @Override
+    public void prefetch() {
+      advancedSincePrefetch = false;
+      super.prefetch();
+    }
+
+    @Override
+    public T next() {
+      advancedSincePrefetch = true;
+      return super.next();
+    }
+
+    @Override
+    public boolean hasNext() {
+      advancedSincePrefetch = true;
+      return super.hasNext();
+    }
+  }
+
   @Test
   public void testConcatIsReadyAdvancesToNextIteratorWhenAble() {
-    NeverReady readyAfterPrefetch1 = new NeverReady();
-    ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch();
-    ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch();
+    NeverReady<String> readyAfterPrefetch1 =
+        new NeverReady<>(PrefetchableIterators.fromArray("A", "B"));
+    ReadyAfterPrefetch<String> readyAfterPrefetch2 =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
+    ReadyAfterPrefetch<String> readyAfterPrefetch3 =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
 
     PrefetchableIterator<String> iterator =
         PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3);
 
     // Expect no prefetches yet
-    assertEquals(0, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
 
     // We expect to attempt to prefetch for the first time.
     iterator.prefetch();
-    assertEquals(1, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // We expect to attempt to prefetch again since we aren't ready.
     iterator.prefetch();
-    assertEquals(2, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The current iterator is done but is never ready so we can't advance to the next one and
     // expect another prefetch to go to the current iterator.
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // Now that we know the last iterator is done and have advanced to the next one we expect
     // prefetch to go through
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The last iterator is done so we should be able to prefetch the next one before advancing
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
 
     // The current iterator is ready so no additional prefetch is necessary
     iterator.prefetch();
-    assertEquals(3, readyAfterPrefetch1.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch2.prefetchCalled);
-    assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+    assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+    assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
     iterator.next();
   }
 
diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle
index 3c859ae..6337cd4 100644
--- a/sdks/java/harness/build.gradle
+++ b/sdks/java/harness/build.gradle
@@ -72,6 +72,7 @@
   testCompile library.java.mockito_core
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   testCompile project(":runners:core-construction-java")
+  testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime")
   shadowTestRuntimeClasspath library.java.slf4j_jdk14
   jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest")
   jmhRuntime library.java.slf4j_jdk14
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 5ddf0ae..777036a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -26,6 +26,8 @@
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 
@@ -49,7 +51,7 @@
   private final BeamFnStateClient beamFnStateClient;
   private final StateRequest request;
   private final Coder<T> valueCoder;
-  private Iterable<T> oldValues;
+  private PrefetchableIterable<T> oldValues;
   private ArrayList<T> newValues;
   private boolean isClosed;
 
@@ -80,19 +82,19 @@
     this.newValues = new ArrayList<>();
   }
 
-  public Iterable<T> get() {
+  public PrefetchableIterable<T> get() {
     checkState(
         !isClosed,
         "Bag user state is no longer usable because it is closed for %s",
         request.getStateKey());
     if (oldValues == null) {
       // If we were cleared we should disregard old values.
-      return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size());
+      return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size());
     } else if (newValues.isEmpty()) {
       // If we have no new values then just return the old values.
       return oldValues;
     }
-    return Iterables.concat(
+    return PrefetchableIterables.concat(
         oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()));
   }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 2f2789e..5a931c5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -38,7 +38,6 @@
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.OrderedListState;
 import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.StateBinder;
 import org.apache.beam.sdk.state.StateContext;
@@ -264,7 +263,7 @@
 
                   @Override
                   public ValueState<T> readLater() {
-                    // TODO(BEAM-12802): Support prefetching.
+                    impl.get().iterator().prefetch();
                     return this;
                   }
                 };
@@ -310,7 +309,7 @@
 
                   @Override
                   public BagState<T> readLater() {
-                    // TODO(BEAM-12802): Support prefetching.
+                    impl.get().iterator().prefetch();
                     return this;
                   }
 
@@ -391,6 +390,7 @@
 
                   @Override
                   public CombiningState<ElementT, AccumT, ResultT> readLater() {
+                    impl.get().iterator().prefetch();
                     return this;
                   }
 
@@ -412,7 +412,18 @@
 
                   @Override
                   public ReadableState<Boolean> isEmpty() {
-                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public @Nullable Boolean read() {
+                        return !impl.get().iterator().hasNext();
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        impl.get().iterator().prefetch();
+                        return this;
+                      }
+                    };
                   }
 
                   @Override
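
With the changes above, readLater() on value, bag, and combining state now starts a prefetch via
the underlying PrefetchableIterator instead of being a no-op, and isEmpty() gained a readLater()
that does the same. From user code this is the familiar read-later pattern; a hedged sketch
follows, in which the DoFn, state id, and element types are illustrative and not part of this
change.

    import org.apache.beam.sdk.state.BagState;
    import org.apache.beam.sdk.state.StateSpec;
    import org.apache.beam.sdk.state.StateSpecs;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;

    class BufferingFn extends DoFn<KV<String, Integer>, Integer> {
      @StateId("buffer")
      private final StateSpec<BagState<Integer>> bufferSpec = StateSpecs.bag();

      @ProcessElement
      public void process(
          @Element KV<String, Integer> element, @StateId("buffer") BagState<Integer> buffer) {
        buffer.readLater(); // now issues the State API fetch in the background
        // ... unrelated per-element work can overlap with the fetch ...
        for (int value : buffer.read()) { // more likely to find the data already available
          // use value
        }
        buffer.add(element.getValue());
      }
    }
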
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index cfc76cf..7828f93 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -21,6 +21,9 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
+import java.util.Objects;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
@@ -28,30 +31,42 @@
  * Converts an iterator to an iterable lazily loading values from the underlying iterator and
  * caching them to support reiteration.
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
-class LazyCachingIteratorToIterable<T> implements Iterable<T> {
+class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
   private final List<T> cachedElements;
-  private final Iterator<T> iterator;
+  private final PrefetchableIterator<T> iterator;
 
-  public LazyCachingIteratorToIterable(Iterator<T> iterator) {
+  public LazyCachingIteratorToIterable(PrefetchableIterator<T> iterator) {
     this.cachedElements = new ArrayList<>();
     this.iterator = iterator;
   }
 
   @Override
-  public Iterator<T> iterator() {
+  public PrefetchableIterator<T> iterator() {
     return new CachingIterator();
   }
 
   /** An {@link Iterator} which adds fetched values to the cached elements list. */
-  private class CachingIterator implements Iterator<T> {
+  private class CachingIterator implements PrefetchableIterator<T> {
     private int position = 0;
 
     private CachingIterator() {}
 
     @Override
+    public boolean isReady() {
+      if (position < cachedElements.size()) {
+        return true;
+      }
+      return iterator.isReady();
+    }
+
+    @Override
+    public void prefetch() {
+      if (!isReady()) {
+        iterator.prefetch();
+      }
+    }
+
+    @Override
     public boolean hasNext() {
       // The order of the short circuit is important below.
       return position < cachedElements.size() || iterator.hasNext();
@@ -76,7 +91,7 @@
 
   @Override
   public int hashCode() {
-    return iterator.hasNext() ? iterator.next().hashCode() : -1789023489;
+    return iterator.hasNext() ? Objects.hashCode(iterator.next()) : -1789023489;
   }
 
   @Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 22be306..1026ba5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -17,20 +17,21 @@
  */
 package org.apache.beam.fn.harness.state;
 
-import java.util.Collections;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 
 /**
  * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -54,7 +55,7 @@
    *     only) chunk of a state stream. This state request will be populated with a continuation
    *     token to request further chunks of the stream if required.
    */
-  public static Iterator<ByteString> readAllStartingFrom(
+  public static PrefetchableIterator<ByteString> readAllStartingFrom(
       BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
     return new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk);
   }
@@ -74,94 +75,142 @@
    *     token to request further chunks of the stream if required.
    * @param valueCoder A coder for decoding the state stream.
    */
-  public static <T> Iterable<T> readAllAndDecodeStartingFrom(
+  public static <T> PrefetchableIterable<T> readAllAndDecodeStartingFrom(
       BeamFnStateClient beamFnStateClient,
       StateRequest stateRequestForFirstChunk,
       Coder<T> valueCoder) {
-    FirstPageAndRemainder firstPageAndRemainder =
-        new FirstPageAndRemainder(beamFnStateClient, stateRequestForFirstChunk);
-    return Iterables.concat(
-        new LazyCachingIteratorToIterable<T>(
-            new DataStreams.DataStreamDecoder<>(
-                valueCoder, new LazySingletonIterator<>(firstPageAndRemainder::firstPage))),
-        () -> new DataStreams.DataStreamDecoder<>(valueCoder, firstPageAndRemainder.remainder()));
-  }
-
-  /** A iterable that contains a single element, provided by a Supplier which is invoked lazily. */
-  static class LazySingletonIterator<T> implements Iterator<T> {
-
-    private final Supplier<T> supplier;
-    private boolean hasNext;
-
-    private LazySingletonIterator(Supplier<T> supplier) {
-      this.supplier = supplier;
-      hasNext = true;
-    }
-
-    @Override
-    public boolean hasNext() {
-      return hasNext;
-    }
-
-    @Override
-    public T next() {
-      hasNext = false;
-      return supplier.get();
-    }
+    return new FirstPageAndRemainder<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
   }
 
   /**
-   * An helper class that (lazily) gives the first page of a paginated state request separately from
+   * A helper class that (lazily) gives the first page of a paginated state request separately from
    * all the remaining pages.
    */
-  static class FirstPageAndRemainder {
+  @VisibleForTesting
+  static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
     private final BeamFnStateClient beamFnStateClient;
     private final StateRequest stateRequestForFirstChunk;
-    private ByteString firstPage = null;
+    private final Coder<T> valueCoder;
+    private LazyCachingIteratorToIterable<T> firstPage;
+    private CompletableFuture<StateResponse> firstPageResponseFuture;
     private ByteString continuationToken;
 
-    private FirstPageAndRemainder(
-        BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
+    FirstPageAndRemainder(
+        BeamFnStateClient beamFnStateClient,
+        StateRequest stateRequestForFirstChunk,
+        Coder<T> valueCoder) {
       this.beamFnStateClient = beamFnStateClient;
       this.stateRequestForFirstChunk = stateRequestForFirstChunk;
+      this.valueCoder = valueCoder;
     }
 
-    public ByteString firstPage() {
-      if (firstPage == null) {
-        CompletableFuture<StateResponse> stateResponseFuture = new CompletableFuture<>();
+    @Override
+    public PrefetchableIterator<T> iterator() {
+      return new PrefetchableIterator<T>() {
+        PrefetchableIterator<T> delegate;
+
+        private void ensureDelegateExists() {
+          if (delegate == null) {
+            // Fetch the first page if necessary
+            prefetchFirstPage();
+            if (firstPage == null) {
+              StateResponse stateResponse;
+              try {
+                stateResponse = firstPageResponseFuture.get();
+              } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                throw new IllegalStateException(e);
+              } catch (ExecutionException e) {
+                if (e.getCause() == null) {
+                  throw new IllegalStateException(e);
+                }
+                Throwables.throwIfUnchecked(e.getCause());
+                throw new IllegalStateException(e.getCause());
+              }
+              continuationToken = stateResponse.getGet().getContinuationToken();
+              firstPage =
+                  new LazyCachingIteratorToIterable<>(
+                      new DataStreamDecoder<>(
+                          valueCoder,
+                          PrefetchableIterators.fromArray(stateResponse.getGet().getData())));
+            }
+
+            if (ByteString.EMPTY.equals(continuationToken)) {
+              delegate = firstPage.iterator();
+            } else {
+              delegate =
+                  PrefetchableIterators.concat(
+                      firstPage.iterator(),
+                      new DataStreamDecoder<>(
+                          valueCoder,
+                          new LazyBlockingStateFetchingIterator(
+                              beamFnStateClient,
+                              stateRequestForFirstChunk
+                                  .toBuilder()
+                                  .setGet(
+                                      StateGetRequest.newBuilder()
+                                          .setContinuationToken(continuationToken))
+                                  .build())));
+            }
+          }
+        }
+
+        @Override
+        public boolean isReady() {
+          if (delegate == null) {
+            if (firstPageResponseFuture != null) {
+              return firstPageResponseFuture.isDone();
+            }
+            return false;
+          }
+          return delegate.isReady();
+        }
+
+        @Override
+        public void prefetch() {
+          if (firstPageResponseFuture == null) {
+            prefetchFirstPage();
+          } else if (delegate != null && !delegate.isReady()) {
+            delegate.prefetch();
+          }
+        }
+
+        @Override
+        public boolean hasNext() {
+          if (delegate == null) {
+            // Ensure that we prefetch the second page after the first has been accessed.
+            // Prefetching subsequent pages after the first will be handled by the
+            // LazyBlockingStateFetchingIterator
+            ensureDelegateExists();
+            boolean rval = delegate.hasNext();
+            delegate.prefetch();
+            return rval;
+          }
+          return delegate.hasNext();
+        }
+
+        @Override
+        public T next() {
+          if (delegate == null) {
+            // Ensure that we prefetch the second page after the first has been accessed.
+            // Prefetching subsequent pages after the first will be handled by the
+            // LazyBlockingStateFetchingIterator
+            ensureDelegateExists();
+            T rval = delegate.next();
+            delegate.prefetch();
+            return rval;
+          }
+          return delegate.next();
+        }
+      };
+    }
+
+    private void prefetchFirstPage() {
+      if (firstPageResponseFuture == null) {
+        firstPageResponseFuture = new CompletableFuture<>();
         beamFnStateClient.handle(
             stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()),
-            stateResponseFuture);
-        StateResponse stateResponse;
-        try {
-          stateResponse = stateResponseFuture.get();
-        } catch (InterruptedException e) {
-          Thread.currentThread().interrupt();
-          throw new IllegalStateException(e);
-        } catch (ExecutionException e) {
-          if (e.getCause() == null) {
-            throw new IllegalStateException(e);
-          }
-          Throwables.throwIfUnchecked(e.getCause());
-          throw new IllegalStateException(e.getCause());
-        }
-        continuationToken = stateResponse.getGet().getContinuationToken();
-        firstPage = stateResponse.getGet().getData();
-      }
-      return firstPage;
-    }
-
-    public Iterator<ByteString> remainder() {
-      firstPage();
-      if (ByteString.EMPTY.equals(continuationToken)) {
-        return Collections.emptyIterator();
-      } else {
-        return new LazyBlockingStateFetchingIterator(
-            beamFnStateClient,
-            stateRequestForFirstChunk
-                .toBuilder()
-                .setGet(StateGetRequest.newBuilder().setContinuationToken(continuationToken))
-                .build());
+            firstPageResponseFuture);
       }
     }
   }
@@ -169,10 +218,11 @@
   /**
    * An {@link Iterator} which fetches {@link ByteString} chunks using the State API.
    *
-   * <p>This iterator will only request a chunk on first access. Subsiquently it eagerly pre-fetches
-   * one future chunks at a time.
+   * <p>This iterator will only request a chunk on first access. Subsequently it eagerly pre-fetches
+   * one future chunk at a time.
    */
-  static class LazyBlockingStateFetchingIterator implements Iterator<ByteString> {
+  @VisibleForTesting
+  static class LazyBlockingStateFetchingIterator implements PrefetchableIterator<ByteString> {
 
     private enum State {
       READ_REQUIRED,
@@ -195,8 +245,17 @@
       this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
     }
 
-    private void prefetch() {
-      if (prefetchedResponse == null && currentState == State.READ_REQUIRED) {
+    @Override
+    public boolean isReady() {
+      if (prefetchedResponse == null) {
+        return currentState != State.READ_REQUIRED;
+      }
+      return prefetchedResponse.isDone();
+    }
+
+    @Override
+    public void prefetch() {
+      if (currentState == State.READ_REQUIRED && prefetchedResponse == null) {
         prefetchedResponse = new CompletableFuture<>();
         beamFnStateClient.handle(
             stateRequestForFirstChunk
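
StateFetchingIterators now hands back prefetchable iterables: FirstPageAndRemainder lazily issues
the request for the first page and chains the remaining pages through
LazyBlockingStateFetchingIterator, which keeps one chunk in flight. A hedged usage sketch of
readAllAndDecodeStartingFrom follows; the state client, state key, and the idea of summing bag
contents are illustrative assumptions.

    import org.apache.beam.fn.harness.state.BeamFnStateClient;
    import org.apache.beam.fn.harness.state.StateFetchingIterators;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
    import org.apache.beam.sdk.coders.VarIntCoder;
    import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
    import org.apache.beam.sdk.fn.stream.PrefetchableIterator;

    class StateFetchingSketch {
      static long sumBagContents(BeamFnStateClient stateClient) {
        StateRequest firstChunk =
            StateRequest.newBuilder()
                .setStateKey(StateKey.getDefaultInstance()) // a real key identifies the user state
                .setGet(StateGetRequest.getDefaultInstance())
                .build();

        PrefetchableIterable<Integer> values =
            StateFetchingIterators.readAllAndDecodeStartingFrom(
                stateClient, firstChunk, VarIntCoder.of());

        PrefetchableIterator<Integer> iterator = values.iterator();
        iterator.prefetch(); // issues the first StateGetRequest without blocking
        long sum = 0;
        while (iterator.hasNext()) {
          sum += iterator.next(); // later pages are requested one ahead automatically
        }
        return sum;
      }
    }
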
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 7597128..0914b01 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -25,8 +25,11 @@
 
 import java.util.Iterator;
 import java.util.NoSuchElementException;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
+import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetch;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -40,7 +43,8 @@
 
   @Test
   public void testEmptyIterator() {
-    Iterable<Object> iterable = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+    Iterable<Object> iterable =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.emptyIterator());
     assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
     // iterate multiple times
     assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
@@ -52,7 +56,7 @@
   @Test
   public void testInterleavedIteration() {
     Iterable<String> iterable =
-        new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
 
     Iterator<String> iterator1 = iterable.iterator();
     assertTrue(iterator1.hasNext());
@@ -77,14 +81,45 @@
 
   @Test
   public void testEqualsAndHashCode() {
-    Iterable<String> iterA = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
-    Iterable<String> iterB = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
-    Iterable<String> iterC = new LazyCachingIteratorToIterable<>(Iterators.forArray());
-    Iterable<String> iterD = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+    Iterable<String> iterA =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    Iterable<String> iterB =
+        new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    Iterable<String> iterC = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
+    Iterable<String> iterD = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
     assertEquals(iterA, iterB);
     assertEquals(iterC, iterD);
     assertNotEquals(iterA, iterC);
     assertEquals(iterA.hashCode(), iterB.hashCode());
     assertEquals(iterC.hashCode(), iterD.hashCode());
   }
+
+  @Test
+  public void testPrefetch() {
+    ReadyAfterPrefetch<String> underlying =
+        new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B", "C"));
+    PrefetchableIterable<String> iterable = new LazyCachingIteratorToIterable<>(underlying);
+    PrefetchableIterator<String> iterator1 = iterable.iterator();
+    PrefetchableIterator<String> iterator2 = iterable.iterator();
+
+    // Check that the lazy iterable doesn't do any prefetch/access on instantiation
+    assertFalse(underlying.isReady());
+    assertFalse(iterator1.isReady());
+    assertFalse(iterator2.isReady());
+
+    // Check that if both iterators prefetch there is only one prefetch for the underlying
+    // iterator.
+    iterator1.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+    iterator2.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+
+    // Check that once one iterator has advanced, the second doesn't perform any prefetch since
+    // the element is now cached.
+    iterator1.next();
+    iterator1.next();
+    iterator2.next();
+    iterator2.prefetch();
+    assertEquals(1, underlying.getNumPrefetchCalls());
+  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index fc729cc..384d2df 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -19,12 +19,16 @@
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
+import org.apache.beam.fn.harness.state.StateFetchingIterators.FirstPageAndRemainder;
 import org.apache.beam.fn.harness.state.StateFetchingIterators.LazyBlockingStateFetchingIterator;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -32,16 +36,56 @@
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for {@link StateFetchingIterators}. */
+@RunWith(Enclosed.class)
 public class StateFetchingIteratorsTest {
+
+  private static BeamFnStateClient fakeStateClient(
+      AtomicInteger callCount, ByteString... expected) {
+    return (requestBuilder, response) -> {
+      callCount.incrementAndGet();
+      if (expected.length == 0) {
+        response.complete(
+            StateResponse.newBuilder()
+                .setId(requestBuilder.getId())
+                .setGet(StateGetResponse.newBuilder())
+                .build());
+        return;
+      }
+
+      ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
+
+      int requestedPosition = 0; // Default position is 0
+      if (!ByteString.EMPTY.equals(continuationToken)) {
+        requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
+      }
+
+      // Compute the new continuation token
+      ByteString newContinuationToken = ByteString.EMPTY;
+      if (requestedPosition != expected.length - 1) {
+        newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
+      }
+      response.complete(
+          StateResponse.newBuilder()
+              .setId(requestBuilder.getId())
+              .setGet(
+                  StateGetResponse.newBuilder()
+                      .setData(expected[requestedPosition])
+                      .setContinuationToken(newContinuationToken))
+              .build());
+    };
+  }
+
   /** Tests for {@link StateFetchingIterators.LazyBlockingStateFetchingIterator}. */
   @RunWith(JUnit4.class)
   public static class LazyBlockingStateFetchingIteratorTest {
@@ -77,49 +121,55 @@
           ByteString.EMPTY);
     }
 
-    private BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString... expected) {
-      return (requestBuilder, response) -> {
-        callCount.incrementAndGet();
-        if (expected.length == 0) {
-          response.complete(
-              StateResponse.newBuilder()
-                  .setId(requestBuilder.getId())
-                  .setGet(StateGetResponse.newBuilder())
-                  .build());
-          return;
-        }
-
-        ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
-
-        int requestedPosition = 0; // Default position is 0
-        if (!ByteString.EMPTY.equals(continuationToken)) {
-          requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
-        }
-
-        // Compute the new continuation token
-        ByteString newContinuationToken = ByteString.EMPTY;
-        if (requestedPosition != expected.length - 1) {
-          newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
-        }
-        response.complete(
-            StateResponse.newBuilder()
-                .setId(requestBuilder.getId())
-                .setGet(
-                    StateGetResponse.newBuilder()
-                        .setData(expected[requestedPosition])
-                        .setContinuationToken(newContinuationToken))
-                .build());
-      };
+    @Test
+    public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
+      AtomicInteger callCount = new AtomicInteger();
+      BeamFnStateClient fakeStateClient =
+          new BeamFnStateClient() {
+            @Override
+            public void handle(
+                StateRequest.Builder requestBuilder, CompletableFuture<StateResponse> response) {
+              callCount.incrementAndGet();
+            }
+          };
+      PrefetchableIterator<ByteString> byteStrings =
+          new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
+      assertEquals(0, callCount.get());
+      byteStrings.prefetch();
+      assertEquals(1, callCount.get()); // first prefetch
+      byteStrings.prefetch();
+      assertEquals(1, callCount.get()); // subsequent is ignored
     }
 
     private void testFetch(ByteString... expected) {
       AtomicInteger callCount = new AtomicInteger();
       BeamFnStateClient fakeStateClient = fakeStateClient(callCount, expected);
-      Iterator<ByteString> byteStrings =
+      PrefetchableIterator<ByteString> byteStrings =
           new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
       assertEquals(0, callCount.get()); // Ensure it's fully lazy.
-      assertArrayEquals(expected, Iterators.toArray(byteStrings, Object.class));
+      assertFalse(byteStrings.isReady());
+
+      // Prefetch every second element in the iterator, capturing the results.
+      List<ByteString> results = new ArrayList<>();
+      for (int i = 0; i < expected.length; ++i) {
+        if (i % 2 == 0) {
+          // Ensure that prefetch performs the call
+          byteStrings.prefetch();
+          assertEquals(i + 1, callCount.get());
+          assertTrue(byteStrings.isReady());
+        }
+        assertTrue(byteStrings.hasNext());
+        results.add(byteStrings.next());
+      }
+      assertFalse(byteStrings.hasNext());
+      assertTrue(byteStrings.isReady());
+
+      assertEquals(Arrays.asList(expected), results);
     }
+  }
+
+  @RunWith(JUnit4.class)
+  public static class FirstPageAndRemainderTest {
 
     @Test
     public void testEmptyValues() throws Exception {
@@ -133,7 +183,7 @@
 
     @Test
     public void testManyValues() throws Exception {
-      testFetchValues(VarIntCoder.of(), 11, 37, 389, 5077);
+      testFetchValues(VarIntCoder.of(), 1, 22, 333, 4444, 55555, 666666);
     }
 
     private <T> void testFetchValues(Coder<T> coder, T... expected) {
@@ -153,35 +203,42 @@
       AtomicInteger callCount = new AtomicInteger();
       BeamFnStateClient fakeStateClient =
           fakeStateClient(callCount, Iterables.toArray(byteStrings, ByteString.class));
-      Iterable<T> values =
-          StateFetchingIterators.readAllAndDecodeStartingFrom(
-              fakeStateClient, StateRequest.getDefaultInstance(), coder);
+      PrefetchableIterable<T> values =
+          new FirstPageAndRemainder<>(fakeStateClient, StateRequest.getDefaultInstance(), coder);
 
       // Ensure it's fully lazy.
       assertEquals(0, callCount.get());
-      Iterator<T> valuesIter = values.iterator();
+      PrefetchableIterator<T> valuesIter = values.iterator();
+      assertFalse(valuesIter.isReady());
       assertEquals(0, callCount.get());
 
-      // No more is read than necessary.
-      if (valuesIter.hasNext()) {
-        valuesIter.next();
-      }
+      // Ensure that the first page result is cached across multiple iterators, so subsequent
+      // iterators are already ready and their prefetch does nothing.
+      valuesIter.prefetch();
+      assertTrue(valuesIter.isReady());
       assertEquals(1, callCount.get());
 
-      // The first page is cached.
-      Iterator<T> valuesIter2 = values.iterator();
-      assertEquals(1, callCount.get());
-      if (valuesIter2.hasNext()) {
-        valuesIter2.next();
-      }
+      PrefetchableIterator<T> valuesIter2 = values.iterator();
+      assertTrue(valuesIter2.isReady());
+      valuesIter2.prefetch();
       assertEquals(1, callCount.get());
 
-      if (valuesIter.hasNext()) {
-        valuesIter.next();
-        // Subsequent pages are pre-fetched, so after accessing the second page,
-        // the third should be requested.
-        assertEquals(3, callCount.get());
+      // Prefetch every second element in the iterator, capturing the results.
+      List<T> results = new ArrayList<>();
+      for (int i = 0; i < expected.length; ++i) {
+        if (i % 2 == 1) {
+          // Ensure that prefetch performs the call
+          valuesIter2.prefetch();
+          assertTrue(valuesIter2.isReady());
+          // Note that this is i+2 because we expect to prefetch the page after the current one.
+          // We also have to bound it by the total number of pages.
+          assertEquals(Math.min(i + 2, expected.length), callCount.get());
+        }
+        assertTrue(valuesIter2.hasNext());
+        results.add(valuesIter2.next());
       }
+      assertFalse(valuesIter2.hasNext());
+      assertTrue(valuesIter2.isReady());
 
       // The contents agree.
       assertArrayEquals(expected, Iterables.toArray(values, Object.class));
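
The fakeStateClient helper above models the Fn API state paging protocol that LazyBlockingStateFetchingIterator consumes: each StateGetResponse carries one page of data plus a continuation token, the first request carries an empty token, and an empty token in the response marks the last page. A minimal sketch of that request loop follows, using the hypothetical helpers fetch_page and read_all_pages (illustrative names, not Beam APIs).

# Minimal sketch of the continuation-token paging that fakeStateClient above
# simulates. `pages` stands in for the runner's state backend; fetch_page and
# read_all_pages are hypothetical helpers, not Beam APIs.
def fetch_page(pages, continuation_token):
    """Returns (data, next_token); an empty next_token marks the last page."""
    position = int(continuation_token) if continuation_token else 0
    next_token = '' if position == len(pages) - 1 else str(position + 1)
    return pages[position], next_token


def read_all_pages(pages):
    if not pages:
        return
    token = ''                     # The first request carries an empty token.
    while True:
        data, token = fetch_page(pages, token)
        yield data
        if not token:              # Empty continuation token: no more pages.
            return


assert list(read_all_pages([b'A', b'BC', b'DEF'])) == [b'A', b'BC', b'DEF']

The tests above additionally assert that calling prefetch() while consuming page i issues the request for page i+1 ahead of time, which is why the expected call count is min(i + 2, number of pages).
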
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
index dd5202e..74d6636 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
@@ -1218,6 +1218,10 @@
         List<Cursor> cursors = new ArrayList<>(partitionQueryResponse.getPartitionsList());
         cursors.sort(CURSOR_REFERENCE_VALUE_COMPARATOR);
         final int size = cursors.size();
+        if (size == 0) {
+          emit(c, dbRoot, structuredQuery.toBuilder());
+          return;
+        }
         final int lastIdx = size - 1;
         for (int i = 0; i < size; i++) {
           Cursor curr = cursors.get(i);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
index 0c9bbf1..1f29883 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
@@ -99,6 +99,41 @@
     assertEquals(expected, allValues);
   }
 
+  @Test
+  public void endToEnd_emptyCursors() throws Exception {
+    // First page of the response
+    PartitionQueryRequest request1 =
+        PartitionQueryRequest.newBuilder()
+            .setParent(String.format("projects/%s/databases/(default)/document", projectId))
+            .build();
+    PartitionQueryResponse response1 = PartitionQueryResponse.newBuilder().build();
+    when(callable.call(request1)).thenReturn(pagedResponse1);
+    when(page1.getResponse()).thenReturn(response1);
+    when(pagedResponse1.iteratePages()).thenReturn(ImmutableList.of(page1));
+
+    when(stub.partitionQueryPagedCallable()).thenReturn(callable);
+
+    when(ff.getFirestoreStub(any())).thenReturn(stub);
+    RpcQosOptions options = RpcQosOptions.defaultOptions();
+    when(ff.getRpcQos(any()))
+        .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options));
+
+    ArgumentCaptor<PartitionQueryPair> responses =
+        ArgumentCaptor.forClass(PartitionQueryPair.class);
+
+    doNothing().when(processContext).output(responses.capture());
+
+    when(processContext.element()).thenReturn(request1);
+
+    PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options);
+
+    runFunction(fn);
+
+    List<PartitionQueryPair> expected = newArrayList(new PartitionQueryPair(request1, response1));
+    List<PartitionQueryPair> allValues = responses.getAllValues();
+    assertEquals(expected, allValues);
+  }
+
   @Override
   public void resumeFromLastReadValue() throws Exception {
     when(ff.getFirestoreStub(any())).thenReturn(stub);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
index 25ed63c..ed789da 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
@@ -121,6 +121,39 @@
     assertEquals(expectedQueries, actualQueries);
   }
 
+  @Test
+  public void ensureCursorPairingWorks_emptyCursorsInResponse() {
+    StructuredQuery query =
+        StructuredQuery.newBuilder()
+            .addFrom(
+                CollectionSelector.newBuilder()
+                    .setAllDescendants(true)
+                    .setCollectionId("c1")
+                    .build())
+            .build();
+
+    List<StructuredQuery> expectedQueries = newArrayList(query);
+
+    PartitionQueryPair partitionQueryPair =
+        new PartitionQueryPair(
+            PartitionQueryRequest.newBuilder().setStructuredQuery(query).build(),
+            PartitionQueryResponse.newBuilder().build());
+
+    ArgumentCaptor<RunQueryRequest> captor = ArgumentCaptor.forClass(RunQueryRequest.class);
+    when(processContext.element()).thenReturn(partitionQueryPair);
+    doNothing().when(processContext).output(captor.capture());
+
+    PartitionQueryResponseToRunQueryRequest fn = new PartitionQueryResponseToRunQueryRequest();
+    fn.processElement(processContext);
+
+    List<StructuredQuery> actualQueries =
+        captor.getAllValues().stream()
+            .map(RunQueryRequest::getStructuredQuery)
+            .collect(Collectors.toList());
+
+    assertEquals(expectedQueries, actualQueries);
+  }
+
   private static Cursor referenceValueCursor(String referenceValue) {
     return Cursor.newBuilder()
         .addValues(Value.newBuilder().setReferenceValue(referenceValue).build())
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 1c02f71..98d2faa 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -185,9 +185,12 @@
 @ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
 class CombinePerKeyTransform(ptransform.PTransform):
   def expand(self, pcoll):
-    return pcoll \
-           | beam.CombinePerKey(sum).with_output_types(
-               typing.Tuple[str, int])
+    output = pcoll \
+           | beam.CombinePerKey(sum)
+    # TODO: Use `with_output_types` instead of explicitly
+    #  assigning to `.element_type` after fixing BEAM-12872
+    output.element_type = beam.typehints.Tuple[str, int]
+    return output
 
   def to_runner_api_parameter(self, unused_context):
     return TEST_COMPK_URN, None
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index dfe45c4..41ad3df 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -21,6 +21,7 @@
 
 import copy
 import heapq
+import itertools
 import operator
 import random
 from typing import Any
@@ -606,16 +607,25 @@
 
 
 class _TupleCombineFnBase(core.CombineFn):
-  def __init__(self, *combiners):
+  def __init__(self, *combiners, merge_accumulators_batch_size=None):
     self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
     self._named_combiners = combiners
+    # If the `merge_accumulators_batch_size` value is not specified, we choose a
+    # bounded default that is inversely proportional to the number of
+    # accumulators in merged tuples.
+    num_combiners = max(1, len(combiners))
+    self._merge_accumulators_batch_size = (
+        merge_accumulators_batch_size or max(10, 1000 // num_combiners))
 
   def display_data(self):
     combiners = [
         c.__name__ if hasattr(c, '__name__') else c.__class__.__name__
         for c in self._named_combiners
     ]
-    return {'combiners': str(combiners)}
+    return {
+        'combiners': str(combiners),
+        'merge_accumulators_batch_size': self._merge_accumulators_batch_size
+    }
 
   def setup(self, *args, **kwargs):
     for c in self._combiners:
@@ -625,10 +635,22 @@
     return [c.create_accumulator(*args, **kwargs) for c in self._combiners]
 
   def merge_accumulators(self, accumulators, *args, **kwargs):
-    return [
-        c.merge_accumulators(a, *args, **kwargs) for c,
-        a in zip(self._combiners, zip(*accumulators))
-    ]
+    # Make sure that `accumulators` is an iterator (so that the position is
+    # remembered).
+    accumulators = iter(accumulators)
+    result = next(accumulators)
+    while True:
+      # Load accumulators into memory and merge in batches to decrease peak
+      # memory usage.
+      accumulators_batch = [result] + list(
+          itertools.islice(accumulators, self._merge_accumulators_batch_size))
+      if len(accumulators_batch) == 1:
+        break
+      result = [
+          c.merge_accumulators(a, *args, **kwargs) for c,
+          a in zip(self._combiners, zip(*accumulators_batch))
+      ]
+    return result
 
   def compact(self, accumulator, *args, **kwargs):
     return [
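
For context on the combiners.py change: TupleCombineFn now accepts a keyword-only merge_accumulators_batch_size argument that caps how many accumulator tuples are materialized per merge step, and when it is omitted the default is max(10, 1000 // num_combiners), e.g. 333 for three combiners, the value asserted in the display-data test below. The following usage sketch is illustrative; the pipeline, element values, and the explicit batch size of 100 are not taken from the patch.

# Usage sketch for the new merge_accumulators_batch_size argument. The element
# values and the explicit batch size of 100 are illustrative choices.
import apache_beam as beam
from apache_beam.transforms import combiners

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([('a', 100, 0.0), ('b', 10, -1), ('c', 1, 100)])
        | beam.CombineGlobally(
            combiners.TupleCombineFn(
                max,
                combiners.MeanCombineFn(),
                sum,
                # Cap how many accumulator tuples are loaded per merge step.
                merge_accumulators_batch_size=100)).without_defaults()
        | beam.Map(print))

# When merge_accumulators_batch_size is omitted, the default is
# max(10, 1000 // num_combiners), i.e. 333 for the three combiners above.

Batching the merge trades a few extra intermediate merge passes for a bounded peak memory footprint, which is what the new combiners_test case below verifies with its counted accumulators.
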
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index d826287..7e0e835 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -249,7 +249,8 @@
     dd = DisplayData.create_from(transform)
     expected_items = [
         DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
-        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']")
+        DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
+        DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
@@ -358,6 +359,49 @@
                   max).with_common_input()).without_defaults())
       assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
 
+  def test_empty_tuple_combine_fn(self):
+    with TestPipeline() as p:
+      result = (
+          p
+          | Create([(), (), ()])
+          | beam.CombineGlobally(combine.TupleCombineFn()))
+      assert_that(result, equal_to([()]))
+
+  def test_tuple_combine_fn_batched_merge(self):
+    num_combine_fns = 10
+    max_num_accumulators_in_memory = 30
+    # Batch size = max accumulator tuples in memory, minus 1 for the merge result.
+    merge_accumulators_batch_size = (
+        max_num_accumulators_in_memory // num_combine_fns - 1)
+    num_accumulator_tuples_to_merge = 20
+
+    class CountedAccumulator:
+      count = 0
+      oom = False
+
+      def __init__(self):
+        if CountedAccumulator.count > max_num_accumulators_in_memory:
+          CountedAccumulator.oom = True
+        else:
+          CountedAccumulator.count += 1
+
+    class CountedAccumulatorCombineFn(beam.CombineFn):
+      def create_accumulator(self):
+        return CountedAccumulator()
+
+      def merge_accumulators(self, accumulators):
+        CountedAccumulator.count += 1
+        for _ in accumulators:
+          CountedAccumulator.count -= 1
+
+    combine_fn = combine.TupleCombineFn(
+        *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
+        merge_accumulators_batch_size=merge_accumulators_batch_size)
+    combine_fn.merge_accumulators(
+        combine_fn.create_accumulator()
+        for _ in range(num_accumulator_tuples_to_merge))
+    assert not CountedAccumulator.oom
+
   def test_to_list_and_to_dict1(self):
     with TestPipeline() as pipeline:
       the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]