Merge pull request #15498 from apache/BEAM-12873
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 12b7df2..17b59ec 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -794,7 +794,8 @@
final int numIters = 2000;
for (int i = 0; i < numIters; ++i) {
- server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+ server.addWorkToOffer(
+ makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
}
Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
@@ -829,7 +830,8 @@
final int numIters = 2000;
for (int i = 0; i < numIters; ++i) {
- server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+ server.addWorkToOffer(
+ makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
}
Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
new file mode 100644
index 0000000..5496d8b
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statecache implements the state caching feature described by the
+// Beam Fn API.
+//
+// The Beam State API and the intended caching behavior are described here:
+// https://docs.google.com/document/d/1BOozW0bzBuz4oHJEuZNDOHdzaV5Y56ix58Ozrqm2jFg/edit#heading=h.7ghoih5aig5m
+package statecache
+
+import (
+ "sync"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+type token string
+
+// SideInputCache stores a cache of reusable inputs for the purposes of
+// eliminating redundant calls to the runner during execution of ParDos
+// using side inputs.
+//
+// A SideInputCache should be initialized when the SDK harness is initialized,
+// creating storage for side input caching. On each ProcessBundleRequest,
+// the cache will process the list of tokens for cacheable side inputs and
+// be queried when side inputs are requested in bundle execution. Once a
+// new bundle request comes in the valid tokens will be updated and the cache
+// will be re-used. In the event that the cache reaches capacity, a random,
+// currently invalid cached object will be evicted.
+type SideInputCache struct {
+ capacity int
+ mu sync.Mutex
+ cache map[token]exec.ReusableInput
+ idsToTokens map[string]token
+ validTokens map[token]int8 // Maps tokens to active bundle counts
+ metrics CacheMetrics
+}
+
+type CacheMetrics struct {
+ Hits int64
+ Misses int64
+ Evictions int64
+ InUseEvictions int64
+}
+
+// Init makes the cache map and the map of IDs to cache tokens for the
+// SideInputCache. Should only be called once. Returns an error for
+// non-positive capacities.
+func (c *SideInputCache) Init(cap int) error {
+ if cap <= 0 {
+ return errors.Errorf("capacity must be a positive integer, got %v", cap)
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.cache = make(map[token]exec.ReusableInput, cap)
+ c.idsToTokens = make(map[string]token)
+ c.validTokens = make(map[token]int8)
+ c.capacity = cap
+ return nil
+}
+
+// SetValidTokens clears the list of valid tokens then sets new ones, also updating the mapping of
+// transform and side input IDs to cache tokens in the process. Should be called at the start of every
+// new ProcessBundleRequest. If the runner does not support caching, the passed cache token values
+// should be empty and all get/set requests will silently be no-ops.
+func (c *SideInputCache) SetValidTokens(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for _, tok := range cacheTokens {
+ // User State caching is currently not supported, so these tokens are ignored
+ if tok.GetUserState() != nil {
+ continue
+ }
+ s := tok.GetSideInput()
+ transformID := s.GetTransformId()
+ sideInputID := s.GetSideInputId()
+ t := token(tok.GetToken())
+ c.setValidToken(transformID, sideInputID, t)
+ }
+}
+
+// setValidToken adds a new valid token for a request into the SideInputCache struct
+// by mapping the transform ID and side input ID pairing to the cache token.
+func (c *SideInputCache) setValidToken(transformID, sideInputID string, tok token) {
+ idKey := transformID + sideInputID
+ c.idsToTokens[idKey] = tok
+ count, ok := c.validTokens[tok]
+ if !ok {
+ c.validTokens[tok] = 1
+ } else {
+ c.validTokens[tok] = count + 1
+ }
+}
+
+// CompleteBundle takes the cache tokens passed to set the valid tokens and decrements their
+// usage count for the purposes of maintaining a valid count of whether or not a value is
+// still in use. Should be called once ProcessBundle has completed.
+func (c *SideInputCache) CompleteBundle(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for _, tok := range cacheTokens {
+ // User State caching is currently not supported, so these tokens are ignored
+ if tok.GetUserState() != nil {
+ continue
+ }
+ t := token(tok.GetToken())
+ c.decrementTokenCount(t)
+ }
+}
+
+// decrementTokenCount decrements the validTokens entry for
+// a given token by 1. Should only be called when completing
+// a bundle.
+func (c *SideInputCache) decrementTokenCount(tok token) {
+ count := c.validTokens[tok]
+ if count == 1 {
+ delete(c.validTokens, tok)
+ } else {
+ c.validTokens[tok] = count - 1
+ }
+}
+
+func (c *SideInputCache) makeAndValidateToken(transformID, sideInputID string) (token, bool) {
+ idKey := transformID + sideInputID
+ // Check if it's a known token
+ tok, ok := c.idsToTokens[idKey]
+ if !ok {
+ return "", false
+ }
+ return tok, c.isValid(tok)
+}
+
+// QueryCache takes a transform ID and side input ID and checking if a corresponding side
+// input has been cached. A query having a bad token (e.g. one that doesn't make a known
+// token or one that makes a known but currently invalid token) is treated the same as a
+// cache miss.
+func (c *SideInputCache) QueryCache(transformID, sideInputID string) exec.ReusableInput {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+ if !ok {
+ return nil
+ }
+ // Check to see if cached
+ input, ok := c.cache[tok]
+ if !ok {
+ c.metrics.Misses++
+ return nil
+ }
+
+ c.metrics.Hits++
+ return input
+}
+
+// SetCache allows a user to place a ReusableInput materialized from the reader into the SideInputCache
+// with its corresponding transform ID and side input ID. If the IDs do not pair with a known, valid token
+// then we silently do not cache the input, as this is an indication that the runner is treating that input
+// as uncacheable.
+func (c *SideInputCache) SetCache(transformID, sideInputID string, input exec.ReusableInput) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+ if !ok {
+ return
+ }
+ if len(c.cache) >= c.capacity {
+ c.evictElement()
+ }
+ c.cache[tok] = input
+}
+
+func (c *SideInputCache) isValid(tok token) bool {
+ count, ok := c.validTokens[tok]
+ // If the token is not known or not in use, return false
+ return ok && count > 0
+}
+
+// evictElement randomly evicts a ReusableInput that is not currently valid from the cache.
+// It should only be called by a goroutine that obtained the lock in SetCache.
+func (c *SideInputCache) evictElement() {
+ deleted := false
+ // Select a key from the cache at random
+ for k := range c.cache {
+ // Do not evict an element if it's currently valid
+ if !c.isValid(k) {
+ delete(c.cache, k)
+ c.metrics.Evictions++
+ deleted = true
+ break
+ }
+ }
+ // Nothing is deleted if every side input is still valid. Clear
+ // out a random entry and record the in-use eviction
+ if !deleted {
+ for k := range c.cache {
+ delete(c.cache, k)
+ c.metrics.InUseEvictions++
+ break
+ }
+ }
+}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
new file mode 100644
index 0000000..b9970c3
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statecache
+
+import (
+ "testing"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+// TestReusableInput implements the ReusableInput interface for the purposes
+// of testing.
+type TestReusableInput struct {
+ transformID string
+ sideInputID string
+ value interface{}
+}
+
+func makeTestReusableInput(transformID, sideInputID string, value interface{}) exec.ReusableInput {
+ return &TestReusableInput{transformID: transformID, sideInputID: sideInputID, value: value}
+}
+
+// Init is a ReusableInput interface method, this is a no-op.
+func (r *TestReusableInput) Init() error {
+ return nil
+}
+
+// Value returns the stored value in the TestReusableInput.
+func (r *TestReusableInput) Value() interface{} {
+ return r.value
+}
+
+// Reset clears the value in the TestReusableInput.
+func (r *TestReusableInput) Reset() error {
+ r.value = nil
+ return nil
+}
+
+func TestInit(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(5)
+ if err != nil {
+ t.Errorf("SideInputCache failed but should have succeeded, got %v", err)
+ }
+}
+
+func TestInit_Bad(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(0)
+ if err == nil {
+ t.Error("SideInputCache init succeeded but should have failed")
+ }
+}
+
+func TestQueryCache_EmptyCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ output := s.QueryCache("side1", "transform1")
+ if output != nil {
+ t.Errorf("Cache hit when it should have missed, got %v", output)
+ }
+}
+
+func TestSetCache_UncacheableCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ input := makeTestReusableInput("t1", "s1", 10)
+ s.SetCache("t1", "s1", input)
+ output := s.QueryCache("t1", "s1")
+ if output != nil {
+ t.Errorf("Cache hit when should have missed, got %v", output)
+ }
+}
+
+func TestSetCache_CacheableCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ transID := "t1"
+ sideID := "s1"
+ tok := token("tok1")
+ s.setValidToken(transID, sideID, tok)
+ input := makeTestReusableInput(transID, sideID, 10)
+ s.SetCache(transID, sideID, input)
+ output := s.QueryCache(transID, sideID)
+ if output == nil {
+ t.Fatalf("call to query cache missed when should have hit")
+ }
+ val, ok := output.Value().(int)
+ if !ok {
+ t.Errorf("failed to convert value to integer, got %v", output.Value())
+ }
+ if val != 10 {
+ t.Errorf("element mismatch, expected 10, got %v", val)
+ }
+}
+
+func makeRequest(transformID, sideInputID string, t token) fnpb.ProcessBundleRequest_CacheToken {
+ var tok fnpb.ProcessBundleRequest_CacheToken
+ var wrap fnpb.ProcessBundleRequest_CacheToken_SideInput_
+ var side fnpb.ProcessBundleRequest_CacheToken_SideInput
+ side.TransformId = transformID
+ side.SideInputId = sideInputID
+ wrap.SideInput = &side
+ tok.Type = &wrap
+ tok.Token = []byte(t)
+ return tok
+}
+
+func TestSetValidTokens(t *testing.T) {
+ inputs := []struct {
+ transformID string
+ sideInputID string
+ tok token
+ }{
+ {
+ "t1",
+ "s1",
+ "tok1",
+ },
+ {
+ "t2",
+ "s2",
+ "tok2",
+ },
+ {
+ "t3",
+ "s3",
+ "tok3",
+ },
+ }
+
+ var s SideInputCache
+ err := s.Init(3)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ var tokens []fnpb.ProcessBundleRequest_CacheToken
+ for _, input := range inputs {
+ t := makeRequest(input.transformID, input.sideInputID, input.tok)
+ tokens = append(tokens, t)
+ }
+
+ s.SetValidTokens(tokens...)
+ if len(s.idsToTokens) != len(inputs) {
+ t.Errorf("Missing tokens, expected %v, got %v", len(inputs), len(s.idsToTokens))
+ }
+
+ for i, input := range inputs {
+ // Check that the token is in the valid list
+ if !s.isValid(input.tok) {
+ t.Errorf("error in input %v, token %v is not valid", i, input.tok)
+ }
+ // Check that the mapping of IDs to tokens is correct
+ mapped := s.idsToTokens[input.transformID+input.sideInputID]
+ if mapped != input.tok {
+ t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tok, mapped)
+ }
+ }
+}
+
+func TestSetValidTokens_ClearingBetween(t *testing.T) {
+ inputs := []struct {
+ transformID string
+ sideInputID string
+ tk token
+ }{
+ {
+ "t1",
+ "s1",
+ "tok1",
+ },
+ {
+ "t2",
+ "s2",
+ "tok2",
+ },
+ {
+ "t3",
+ "s3",
+ "tok3",
+ },
+ }
+
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ for i, input := range inputs {
+ tok := makeRequest(input.transformID, input.sideInputID, input.tk)
+
+ s.SetValidTokens(tok)
+
+ // Check that the token is in the valid list
+ if !s.isValid(input.tk) {
+ t.Errorf("error in input %v, token %v is not valid", i, input.tk)
+ }
+ // Check that the mapping of IDs to tokens is correct
+ mapped := s.idsToTokens[input.transformID+input.sideInputID]
+ if mapped != input.tk {
+ t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tk, mapped)
+ }
+
+ s.CompleteBundle(tok)
+ }
+
+ for k, _ := range s.validTokens {
+ if s.validTokens[k] != 0 {
+ t.Errorf("token count mismatch for token %v, expected 0, got %v", k, s.validTokens[k])
+ }
+ }
+}
+
+func TestSetCache_Eviction(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ tokOne := makeRequest("t1", "s1", "tok1")
+ inOne := makeTestReusableInput("t1", "s1", 10)
+ s.SetValidTokens(tokOne)
+ s.SetCache("t1", "s1", inOne)
+ // Mark bundle as complete, drop count for tokOne to 0
+ s.CompleteBundle(tokOne)
+
+ tokTwo := makeRequest("t2", "s2", "tok2")
+ inTwo := makeTestReusableInput("t2", "s2", 20)
+ s.SetValidTokens(tokTwo)
+ s.SetCache("t2", "s2", inTwo)
+
+ if len(s.cache) != 1 {
+ t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+ }
+ if s.metrics.Evictions != 1 {
+ t.Errorf("number evictions incorrect, expected 1, got %v", s.metrics.Evictions)
+ }
+}
+
+func TestSetCache_EvictionFailure(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ tokOne := makeRequest("t1", "s1", "tok1")
+ inOne := makeTestReusableInput("t1", "s1", 10)
+
+ tokTwo := makeRequest("t2", "s2", "tok2")
+ inTwo := makeTestReusableInput("t2", "s2", 20)
+
+ s.SetValidTokens(tokOne, tokTwo)
+ s.SetCache("t1", "s1", inOne)
+ // Should fail to evict because the first token is still valid
+ s.SetCache("t2", "s2", inTwo)
+ // Cache should not exceed size 1
+ if len(s.cache) != 1 {
+ t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+ }
+ if s.metrics.InUseEvictions != 1 {
+ t.Errorf("number of failed evicition calls incorrect, expected 1, got %v", s.metrics.InUseEvictions)
+ }
+}
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
index 511f839..f4ab8bb 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
@@ -156,16 +156,20 @@
}
/**
- * An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values
- * using the specified {@link Coder}.
+ * An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T}
+ * values using the specified {@link Coder}.
*
* <p>Note that this adapter follows the Beam Fn API specification for forcing values that decode
* consuming zero bytes to consuming exactly one byte.
*
* <p>Note that access to the underlying {@link InputStream} is lazy and will only be invoked on
- * first access to {@link #next()} or {@link #hasNext()}.
+ * first access to {@link #next}, {@link #hasNext}, {@link #isReady}, or {@link #prefetch}.
+ *
+ * <p>Note that {@link #isReady} and {@link #prefetch} rely on non-empty {@link ByteString}s being
+ * returned by the underlying {@link PrefetchableIterator}; otherwise {@link #prefetch} will
+ * seemingly make zero progress while it actually advances through the empty pages.
*/
- public static class DataStreamDecoder<T> implements Iterator<T> {
+ public static class DataStreamDecoder<T> implements PrefetchableIterator<T> {
private enum State {
READ_REQUIRED,
@@ -173,13 +177,13 @@
EOF
}
- private final Iterator<ByteString> inputByteStrings;
+ private final PrefetchableIterator<ByteString> inputByteStrings;
private final Inbound inbound;
private final Coder<T> coder;
private State currentState;
private T next;
- public DataStreamDecoder(Coder<T> coder, Iterator<ByteString> inputStream) {
+ public DataStreamDecoder(Coder<T> coder, PrefetchableIterator<ByteString> inputStream) {
this.currentState = State.READ_REQUIRED;
this.coder = coder;
this.inputByteStrings = inputStream;
@@ -187,6 +191,31 @@
}
@Override
+ public boolean isReady() {
+ switch (currentState) {
+ case EOF:
+ return true;
+ case READ_REQUIRED:
+ try {
+ return inbound.isReady();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ case HAS_NEXT:
+ return true;
+ default:
+ throw new IllegalStateException(String.format("Unknown state %s", currentState));
+ }
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ inputByteStrings.prefetch();
+ }
+ }
+
+ @Override
public boolean hasNext() {
switch (currentState) {
case EOF:
@@ -232,8 +261,8 @@
private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput();
/**
- * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first
- * {@link Iterator} on first access of this input stream.
+ * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link
+ * Iterator} on first access of this input stream.
*
* <p>Closing this input stream has no effect.
*/
@@ -245,6 +274,22 @@
this.currentStream = EMPTY_STREAM;
}
+ public boolean isReady() throws IOException {
+ // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
+ // minus the number of bytes that have been read so far and can be reliably used to tell
+ // us whether we are at the end of the stream.
+ while (currentStream.available() == 0) {
+ if (!inputByteStrings.isReady()) {
+ return false;
+ }
+ if (!inputByteStrings.hasNext()) {
+ return true;
+ }
+ currentStream = inputByteStrings.next().newInput();
+ }
+ return true;
+ }
+
public boolean isEof() throws IOException {
// Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
// minus the number of bytes that have been read so far and can be reliably used to tell
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
index 9dd5ee4..a8b48e8 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
@@ -23,6 +23,7 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
import java.io.IOException;
@@ -106,7 +107,7 @@
}
@Test
- public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
+ public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception {
CountingOutputStream countingOutputStream =
new CountingOutputStream(ByteStreams.nullOutputStream());
GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream);
@@ -115,6 +116,55 @@
testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE);
}
+ @Test
+ public void testPrefetch() throws Exception {
+ List<ByteString> encodings = new ArrayList<>();
+ {
+ ByteString.Output encoding = ByteString.newOutput();
+ StringUtf8Coder.of().encode("A", encoding);
+ StringUtf8Coder.of().encode("BC", encoding);
+ encodings.add(encoding.toByteString());
+ }
+ encodings.add(ByteString.EMPTY);
+ {
+ ByteString.Output encoding = ByteString.newOutput();
+ StringUtf8Coder.of().encode("DEF", encoding);
+ StringUtf8Coder.of().encode("GHIJ", encoding);
+ encodings.add(encoding.toByteString());
+ }
+
+ PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<ByteString> iterator =
+ new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator());
+ PrefetchableIterator<String> decoder =
+ new DataStreamDecoder<>(StringUtf8Coder.of(), iterator);
+ assertFalse(decoder.isReady());
+ decoder.prefetch();
+ assertTrue(decoder.isReady());
+ assertEquals(1, iterator.getNumPrefetchCalls());
+
+ decoder.next();
+ // Now we will have moved off of the empty byte array that we start with so prefetch will
+ // do nothing since we are ready
+ assertTrue(decoder.isReady());
+ decoder.prefetch();
+ assertEquals(1, iterator.getNumPrefetchCalls());
+
+ decoder.next();
+ // Now we are at the end of the first ByteString so we expect a prefetch to pass through
+ assertFalse(decoder.isReady());
+ decoder.prefetch();
+ assertEquals(2, iterator.getNumPrefetchCalls());
+ // We also expect the decoder to not be ready since the next byte string is empty which
+ // would require us to move to the next page. This typically wouldn't happen in practice
+ // though because we expect non empty pages.
+ assertFalse(decoder.isReady());
+
+ // Prefetching will allow us to move to the third ByteString
+ decoder.prefetch();
+ assertEquals(3, iterator.getNumPrefetchCalls());
+ assertTrue(decoder.isReady());
+ }
+
private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOException {
ByteString.Output output = ByteString.newOutput();
for (T value : expected) {
@@ -131,7 +181,9 @@
}
private <T> void testDecoderWith(Coder<T> coder, T[] expected, List<ByteString> encoded) {
- Iterator<T> decoder = new DataStreamDecoder<>(coder, encoded.iterator());
+ Iterator<T> decoder =
+ new DataStreamDecoder<>(
+ coder, PrefetchableIterators.maybePrefetchable(encoded.iterator()));
Object[] actual = Iterators.toArray(decoder, Object.class);
assertArrayEquals(expected, actual);
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
index 6131634..9ada175 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
@@ -120,10 +120,14 @@
"F");
}
- private static class NeverReady implements PrefetchableIterator<String> {
- PrefetchableIterator<String> delegate = PrefetchableIterators.fromArray("A", "B");
+ public static class NeverReady<T> implements PrefetchableIterator<T> {
+ private final Iterator<T> delegate;
int prefetchCalled;
+ public NeverReady(Iterator<T> delegate) {
+ this.delegate = delegate;
+ }
+
@Override
public boolean isReady() {
return false;
@@ -140,74 +144,117 @@
}
@Override
- public String next() {
+ public T next() {
return delegate.next();
}
+
+ public int getNumPrefetchCalls() {
+ return prefetchCalled;
+ }
}
- private static class ReadyAfterPrefetch extends NeverReady {
+ public static class ReadyAfterPrefetch<T> extends NeverReady<T> {
+
+ public ReadyAfterPrefetch(Iterator<T> delegate) {
+ super(delegate);
+ }
+
@Override
public boolean isReady() {
return prefetchCalled > 0;
}
}
+ public static class ReadyAfterPrefetchUntilNext<T> extends ReadyAfterPrefetch<T> {
+ boolean advancedSincePrefetch;
+
+ public ReadyAfterPrefetchUntilNext(Iterator<T> delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public boolean isReady() {
+ return !advancedSincePrefetch && super.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ advancedSincePrefetch = false;
+ super.prefetch();
+ }
+
+ @Override
+ public T next() {
+ advancedSincePrefetch = true;
+ return super.next();
+ }
+
+ @Override
+ public boolean hasNext() {
+ advancedSincePrefetch = true;
+ return super.hasNext();
+ }
+ }
+
@Test
public void testConcatIsReadyAdvancesToNextIteratorWhenAble() {
- NeverReady readyAfterPrefetch1 = new NeverReady();
- ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch();
- ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch();
+ NeverReady<String> readyAfterPrefetch1 =
+ new NeverReady<>(PrefetchableIterators.fromArray("A", "B"));
+ ReadyAfterPrefetch<String> readyAfterPrefetch2 =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
+ ReadyAfterPrefetch<String> readyAfterPrefetch3 =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
PrefetchableIterator<String> iterator =
PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3);
// Expect no prefetches yet
- assertEquals(0, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
// We expect to attempt to prefetch for the first time.
iterator.prefetch();
- assertEquals(1, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// We expect to attempt to prefetch again since we aren't ready.
iterator.prefetch();
- assertEquals(2, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The current iterator is done but is never ready so we can't advance to the next one and
// expect another prefetch to go to the current iterator.
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// Now that we know the last iterator is done and have advanced to the next one we expect
// prefetch to go through
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The last iterator is done so we should be able to prefetch the next one before advancing
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The current iterator is ready so no additional prefetch is necessary
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
}
diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle
index 3c859ae..6337cd4 100644
--- a/sdks/java/harness/build.gradle
+++ b/sdks/java/harness/build.gradle
@@ -72,6 +72,7 @@
testCompile library.java.mockito_core
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
testCompile project(":runners:core-construction-java")
+ testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime")
shadowTestRuntimeClasspath library.java.slf4j_jdk14
jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest")
jmhRuntime library.java.slf4j_jdk14
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 5ddf0ae..777036a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -26,6 +26,8 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -49,7 +51,7 @@
private final BeamFnStateClient beamFnStateClient;
private final StateRequest request;
private final Coder<T> valueCoder;
- private Iterable<T> oldValues;
+ private PrefetchableIterable<T> oldValues;
private ArrayList<T> newValues;
private boolean isClosed;
@@ -80,19 +82,19 @@
this.newValues = new ArrayList<>();
}
- public Iterable<T> get() {
+ public PrefetchableIterable<T> get() {
checkState(
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
if (oldValues == null) {
// If we were cleared we should disregard old values.
- return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size());
+ return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size());
} else if (newValues.isEmpty()) {
// If we have no new values then just return the old values.
return oldValues;
}
- return Iterables.concat(
+ return PrefetchableIterables.concat(
oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()));
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 2f2789e..5a931c5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -38,7 +38,6 @@
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.StateBinder;
import org.apache.beam.sdk.state.StateContext;
@@ -264,7 +263,7 @@
@Override
public ValueState<T> readLater() {
- // TODO(BEAM-12802): Support prefetching.
+ impl.get().iterator().prefetch();
return this;
}
};
@@ -310,7 +309,7 @@
@Override
public BagState<T> readLater() {
- // TODO(BEAM-12802): Support prefetching.
+ impl.get().iterator().prefetch();
return this;
}
@@ -391,6 +390,7 @@
@Override
public CombiningState<ElementT, AccumT, ResultT> readLater() {
+ impl.get().iterator().prefetch();
return this;
}
@@ -412,7 +412,18 @@
@Override
public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
+ return new ReadableState<Boolean>() {
+ @Override
+ public @Nullable Boolean read() {
+ return !impl.get().iterator().hasNext();
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ impl.get().iterator().prefetch();
+ return this;
+ }
+ };
}
@Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index cfc76cf..7828f93 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -21,6 +21,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
+import java.util.Objects;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -28,30 +31,42 @@
* Converts an iterator to an iterable lazily loading values from the underlying iterator and
* caching them to support reiteration.
*/
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
-class LazyCachingIteratorToIterable<T> implements Iterable<T> {
+class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
private final List<T> cachedElements;
- private final Iterator<T> iterator;
+ private final PrefetchableIterator<T> iterator;
- public LazyCachingIteratorToIterable(Iterator<T> iterator) {
+ public LazyCachingIteratorToIterable(PrefetchableIterator<T> iterator) {
this.cachedElements = new ArrayList<>();
this.iterator = iterator;
}
@Override
- public Iterator<T> iterator() {
+ public PrefetchableIterator<T> iterator() {
return new CachingIterator();
}
/** An {@link Iterator} which adds and fetched values into the cached elements list. */
- private class CachingIterator implements Iterator<T> {
+ private class CachingIterator implements PrefetchableIterator<T> {
private int position = 0;
private CachingIterator() {}
@Override
+ public boolean isReady() {
+ if (position < cachedElements.size()) {
+ return true;
+ }
+ return iterator.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ iterator.prefetch();
+ }
+ }
+
+ @Override
public boolean hasNext() {
// The order of the short circuit is important below.
return position < cachedElements.size() || iterator.hasNext();
@@ -76,7 +91,7 @@
@Override
public int hashCode() {
- return iterator.hasNext() ? iterator.next().hashCode() : -1789023489;
+ return iterator.hasNext() ? Objects.hashCode(iterator.next()) : -1789023489;
}
@Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 22be306..1026ba5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -17,20 +17,21 @@
*/
package org.apache.beam.fn.harness.state;
-import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
/**
* Adapters which convert a a logical series of chunks using continuation tokens over the Beam Fn
@@ -54,7 +55,7 @@
* only) chunk of a state stream. This state request will be populated with a continuation
* token to request further chunks of the stream if required.
*/
- public static Iterator<ByteString> readAllStartingFrom(
+ public static PrefetchableIterator<ByteString> readAllStartingFrom(
BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
return new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk);
}
@@ -74,94 +75,142 @@
* token to request further chunks of the stream if required.
* @param valueCoder A coder for decoding the state stream.
*/
- public static <T> Iterable<T> readAllAndDecodeStartingFrom(
+ public static <T> PrefetchableIterable<T> readAllAndDecodeStartingFrom(
BeamFnStateClient beamFnStateClient,
StateRequest stateRequestForFirstChunk,
Coder<T> valueCoder) {
- FirstPageAndRemainder firstPageAndRemainder =
- new FirstPageAndRemainder(beamFnStateClient, stateRequestForFirstChunk);
- return Iterables.concat(
- new LazyCachingIteratorToIterable<T>(
- new DataStreams.DataStreamDecoder<>(
- valueCoder, new LazySingletonIterator<>(firstPageAndRemainder::firstPage))),
- () -> new DataStreams.DataStreamDecoder<>(valueCoder, firstPageAndRemainder.remainder()));
- }
-
- /** A iterable that contains a single element, provided by a Supplier which is invoked lazily. */
- static class LazySingletonIterator<T> implements Iterator<T> {
-
- private final Supplier<T> supplier;
- private boolean hasNext;
-
- private LazySingletonIterator(Supplier<T> supplier) {
- this.supplier = supplier;
- hasNext = true;
- }
-
- @Override
- public boolean hasNext() {
- return hasNext;
- }
-
- @Override
- public T next() {
- hasNext = false;
- return supplier.get();
- }
+ return new FirstPageAndRemainder<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
}
/**
- * An helper class that (lazily) gives the first page of a paginated state request separately from
+ * A helper class that (lazily) gives the first page of a paginated state request separately from
* all the remaining pages.
*/
- static class FirstPageAndRemainder {
+ @VisibleForTesting
+ static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
private final BeamFnStateClient beamFnStateClient;
private final StateRequest stateRequestForFirstChunk;
- private ByteString firstPage = null;
+ private final Coder<T> valueCoder;
+ private LazyCachingIteratorToIterable<T> firstPage;
+ private CompletableFuture<StateResponse> firstPageResponseFuture;
private ByteString continuationToken;
- private FirstPageAndRemainder(
- BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
+ FirstPageAndRemainder(
+ BeamFnStateClient beamFnStateClient,
+ StateRequest stateRequestForFirstChunk,
+ Coder<T> valueCoder) {
this.beamFnStateClient = beamFnStateClient;
this.stateRequestForFirstChunk = stateRequestForFirstChunk;
+ this.valueCoder = valueCoder;
}
- public ByteString firstPage() {
- if (firstPage == null) {
- CompletableFuture<StateResponse> stateResponseFuture = new CompletableFuture<>();
+ @Override
+ public PrefetchableIterator<T> iterator() {
+ return new PrefetchableIterator<T>() {
+ PrefetchableIterator<T> delegate;
+
+ private void ensureDelegateExists() {
+ if (delegate == null) {
+ // Fetch the first page if necessary
+ prefetchFirstPage();
+ if (firstPage == null) {
+ StateResponse stateResponse;
+ try {
+ stateResponse = firstPageResponseFuture.get();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IllegalStateException(e);
+ } catch (ExecutionException e) {
+ if (e.getCause() == null) {
+ throw new IllegalStateException(e);
+ }
+ Throwables.throwIfUnchecked(e.getCause());
+ throw new IllegalStateException(e.getCause());
+ }
+ continuationToken = stateResponse.getGet().getContinuationToken();
+ firstPage =
+ new LazyCachingIteratorToIterable<>(
+ new DataStreamDecoder<>(
+ valueCoder,
+ PrefetchableIterators.fromArray(stateResponse.getGet().getData())));
+ }
+
+ if (ByteString.EMPTY.equals((continuationToken))) {
+ delegate = firstPage.iterator();
+ } else {
+ delegate =
+ PrefetchableIterators.concat(
+ firstPage.iterator(),
+ new DataStreamDecoder<>(
+ valueCoder,
+ new LazyBlockingStateFetchingIterator(
+ beamFnStateClient,
+ stateRequestForFirstChunk
+ .toBuilder()
+ .setGet(
+ StateGetRequest.newBuilder()
+ .setContinuationToken(continuationToken))
+ .build())));
+ }
+ }
+ }
+
+ @Override
+ public boolean isReady() {
+ if (delegate == null) {
+ if (firstPageResponseFuture != null) {
+ return firstPageResponseFuture.isDone();
+ }
+ return false;
+ }
+ return delegate.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (firstPageResponseFuture == null) {
+ prefetchFirstPage();
+ } else if (delegate != null && !delegate.isReady()) {
+ delegate.prefetch();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (delegate == null) {
+ // Ensure that we prefetch the second page after the first has been accessed.
+ // Prefetching subsequent pages after the first will be handled by the
+ // LazyBlockingStateFetchingIterator
+ ensureDelegateExists();
+ boolean rval = delegate.hasNext();
+ delegate.prefetch();
+ return rval;
+ }
+ return delegate.hasNext();
+ }
+
+ @Override
+ public T next() {
+ if (delegate == null) {
+ // Ensure that we prefetch the second page after the first has been accessed.
+ // Prefetching subsequent pages after the first will be handled by the
+ // LazyBlockingStateFetchingIterator
+ ensureDelegateExists();
+ T rval = delegate.next();
+ delegate.prefetch();
+ return rval;
+ }
+ return delegate.next();
+ }
+ };
+ }
+
+ private void prefetchFirstPage() {
+ if (firstPageResponseFuture == null) {
+ firstPageResponseFuture = new CompletableFuture<>();
beamFnStateClient.handle(
stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()),
- stateResponseFuture);
- StateResponse stateResponse;
- try {
- stateResponse = stateResponseFuture.get();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new IllegalStateException(e);
- } catch (ExecutionException e) {
- if (e.getCause() == null) {
- throw new IllegalStateException(e);
- }
- Throwables.throwIfUnchecked(e.getCause());
- throw new IllegalStateException(e.getCause());
- }
- continuationToken = stateResponse.getGet().getContinuationToken();
- firstPage = stateResponse.getGet().getData();
- }
- return firstPage;
- }
-
- public Iterator<ByteString> remainder() {
- firstPage();
- if (ByteString.EMPTY.equals(continuationToken)) {
- return Collections.emptyIterator();
- } else {
- return new LazyBlockingStateFetchingIterator(
- beamFnStateClient,
- stateRequestForFirstChunk
- .toBuilder()
- .setGet(StateGetRequest.newBuilder().setContinuationToken(continuationToken))
- .build());
+ firstPageResponseFuture);
}
}
}
@@ -169,10 +218,11 @@
/**
* An {@link Iterator} which fetches {@link ByteString} chunks using the State API.
*
- * <p>This iterator will only request a chunk on first access. Subsiquently it eagerly pre-fetches
- * one future chunks at a time.
+ * <p>This iterator will only request a chunk on first access. Subsequently it eagerly pre-fetches
+ * one future chunk at a time.
*/
- static class LazyBlockingStateFetchingIterator implements Iterator<ByteString> {
+ @VisibleForTesting
+ static class LazyBlockingStateFetchingIterator implements PrefetchableIterator<ByteString> {
private enum State {
READ_REQUIRED,
@@ -195,8 +245,17 @@
this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
}
- private void prefetch() {
- if (prefetchedResponse == null && currentState == State.READ_REQUIRED) {
+ @Override
+ public boolean isReady() {
+ if (prefetchedResponse == null) {
+ return currentState != State.READ_REQUIRED;
+ }
+ return prefetchedResponse.isDone();
+ }
+
+ @Override
+ public void prefetch() {
+ if (currentState == State.READ_REQUIRED && prefetchedResponse == null) {
prefetchedResponse = new CompletableFuture<>();
beamFnStateClient.handle(
stateRequestForFirstChunk
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 7597128..0914b01 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -25,8 +25,11 @@
import java.util.Iterator;
import java.util.NoSuchElementException;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
+import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetch;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -40,7 +43,8 @@
@Test
public void testEmptyIterator() {
- Iterable<Object> iterable = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+ Iterable<Object> iterable =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.emptyIterator());
assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
// iterate multiple times
assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
@@ -52,7 +56,7 @@
@Test
public void testInterleavedIteration() {
Iterable<String> iterable =
- new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
Iterator<String> iterator1 = iterable.iterator();
assertTrue(iterator1.hasNext());
@@ -77,14 +81,45 @@
@Test
public void testEqualsAndHashCode() {
- Iterable<String> iterA = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
- Iterable<String> iterB = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
- Iterable<String> iterC = new LazyCachingIteratorToIterable<>(Iterators.forArray());
- Iterable<String> iterD = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+ Iterable<String> iterA =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ Iterable<String> iterB =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ Iterable<String> iterC = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
+ Iterable<String> iterD = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
assertEquals(iterA, iterB);
assertEquals(iterC, iterD);
assertNotEquals(iterA, iterC);
assertEquals(iterA.hashCode(), iterB.hashCode());
assertEquals(iterC.hashCode(), iterD.hashCode());
}
+
+ @Test
+ public void testPrefetch() {
+ ReadyAfterPrefetch<String> underlying =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ PrefetchableIterable<String> iterable = new LazyCachingIteratorToIterable<>(underlying);
+ PrefetchableIterator<String> iterator1 = iterable.iterator();
+ PrefetchableIterator<String> iterator2 = iterable.iterator();
+
+ // Check that the lazy iterable doesn't do any prefetch/access on instantiation
+ assertFalse(underlying.isReady());
+ assertFalse(iterator1.isReady());
+ assertFalse(iterator2.isReady());
+
+ // Check that if both iterators prefetch, there is only one prefetch call on the underlying
+ // iterator.
+ iterator1.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+ iterator2.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+
+ // Check that once one iterator has advanced, the second doesn't perform any prefetch since
+ // the element is now cached.
+ iterator1.next();
+ iterator1.next();
+ iterator2.next();
+ iterator2.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+ }
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index fc729cc..384d2df 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -19,12 +19,16 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Iterator;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
+import org.apache.beam.fn.harness.state.StateFetchingIterators.FirstPageAndRemainder;
import org.apache.beam.fn.harness.state.StateFetchingIterators.LazyBlockingStateFetchingIterator;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -32,16 +36,56 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link StateFetchingIterators}. */
+@RunWith(Enclosed.class)
public class StateFetchingIteratorsTest {
+
+ private static BeamFnStateClient fakeStateClient(
+ AtomicInteger callCount, ByteString... expected) {
+ return (requestBuilder, response) -> {
+ callCount.incrementAndGet();
+ if (expected.length == 0) {
+ response.complete(
+ StateResponse.newBuilder()
+ .setId(requestBuilder.getId())
+ .setGet(StateGetResponse.newBuilder())
+ .build());
+ return;
+ }
+
+ ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
+
+ int requestedPosition = 0; // Default position is 0
+ if (!ByteString.EMPTY.equals(continuationToken)) {
+ requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
+ }
+
+ // Compute the new continuation token
+ ByteString newContinuationToken = ByteString.EMPTY;
+ if (requestedPosition != expected.length - 1) {
+ newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
+ }
+ response.complete(
+ StateResponse.newBuilder()
+ .setId(requestBuilder.getId())
+ .setGet(
+ StateGetResponse.newBuilder()
+ .setData(expected[requestedPosition])
+ .setContinuationToken(newContinuationToken))
+ .build());
+ };
+ }
+
/** Tests for {@link StateFetchingIterators.LazyBlockingStateFetchingIterator}. */
@RunWith(JUnit4.class)
public static class LazyBlockingStateFetchingIteratorTest {
@@ -77,49 +121,55 @@
ByteString.EMPTY);
}
- private BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString... expected) {
- return (requestBuilder, response) -> {
- callCount.incrementAndGet();
- if (expected.length == 0) {
- response.complete(
- StateResponse.newBuilder()
- .setId(requestBuilder.getId())
- .setGet(StateGetResponse.newBuilder())
- .build());
- return;
- }
-
- ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
-
- int requestedPosition = 0; // Default position is 0
- if (!ByteString.EMPTY.equals(continuationToken)) {
- requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
- }
-
- // Compute the new continuation token
- ByteString newContinuationToken = ByteString.EMPTY;
- if (requestedPosition != expected.length - 1) {
- newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
- }
- response.complete(
- StateResponse.newBuilder()
- .setId(requestBuilder.getId())
- .setGet(
- StateGetResponse.newBuilder()
- .setData(expected[requestedPosition])
- .setContinuationToken(newContinuationToken))
- .build());
- };
+ @Test
+ public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
+ AtomicInteger callCount = new AtomicInteger();
+ BeamFnStateClient fakeStateClient =
+ new BeamFnStateClient() {
+ @Override
+ public void handle(
+ StateRequest.Builder requestBuilder, CompletableFuture<StateResponse> response) {
+ callCount.incrementAndGet();
+ }
+ };
+ PrefetchableIterator<ByteString> byteStrings =
+ new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
+ assertEquals(0, callCount.get());
+ byteStrings.prefetch();
+ assertEquals(1, callCount.get()); // first prefetch
+ byteStrings.prefetch();
+ assertEquals(1, callCount.get()); // subsequent is ignored
}
private void testFetch(ByteString... expected) {
AtomicInteger callCount = new AtomicInteger();
BeamFnStateClient fakeStateClient = fakeStateClient(callCount, expected);
- Iterator<ByteString> byteStrings =
+ PrefetchableIterator<ByteString> byteStrings =
new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
assertEquals(0, callCount.get()); // Ensure it's fully lazy.
- assertArrayEquals(expected, Iterators.toArray(byteStrings, Object.class));
+ assertFalse(byteStrings.isReady());
+
+ // Prefetch every second element in the iterator capturing the results
+ List<ByteString> results = new ArrayList<>();
+ for (int i = 0; i < expected.length; ++i) {
+ if (i % 2 == 0) {
+ // Ensure that prefetch performs the call
+ byteStrings.prefetch();
+ assertEquals(i + 1, callCount.get());
+ assertTrue(byteStrings.isReady());
+ }
+ assertTrue(byteStrings.hasNext());
+ results.add(byteStrings.next());
+ }
+ assertFalse(byteStrings.hasNext());
+ assertTrue(byteStrings.isReady());
+
+ assertEquals(Arrays.asList(expected), results);
}
+ }
+
+ @RunWith(JUnit4.class)
+ public static class FirstPageAndRemainderTest {
@Test
public void testEmptyValues() throws Exception {
@@ -133,7 +183,7 @@
@Test
public void testManyValues() throws Exception {
- testFetchValues(VarIntCoder.of(), 11, 37, 389, 5077);
+ testFetchValues(VarIntCoder.of(), 1, 22, 333, 4444, 55555, 666666);
}
private <T> void testFetchValues(Coder<T> coder, T... expected) {
@@ -153,35 +203,42 @@
AtomicInteger callCount = new AtomicInteger();
BeamFnStateClient fakeStateClient =
fakeStateClient(callCount, Iterables.toArray(byteStrings, ByteString.class));
- Iterable<T> values =
- StateFetchingIterators.readAllAndDecodeStartingFrom(
- fakeStateClient, StateRequest.getDefaultInstance(), coder);
+ PrefetchableIterable<T> values =
+ new FirstPageAndRemainder<>(fakeStateClient, StateRequest.getDefaultInstance(), coder);
// Ensure it's fully lazy.
assertEquals(0, callCount.get());
- Iterator<T> valuesIter = values.iterator();
+ PrefetchableIterator<T> valuesIter = values.iterator();
+ assertFalse(valuesIter.isReady());
assertEquals(0, callCount.get());
- // No more is read than necessary.
- if (valuesIter.hasNext()) {
- valuesIter.next();
- }
+ // Ensure that the first page result is cached across multiple iterators and subsequent
+ // iterators are ready and prefetch does nothing
+ valuesIter.prefetch();
+ assertTrue(valuesIter.isReady());
assertEquals(1, callCount.get());
- // The first page is cached.
- Iterator<T> valuesIter2 = values.iterator();
- assertEquals(1, callCount.get());
- if (valuesIter2.hasNext()) {
- valuesIter2.next();
- }
+ PrefetchableIterator<T> valuesIter2 = values.iterator();
+ assertTrue(valuesIter2.isReady());
+ valuesIter2.prefetch();
assertEquals(1, callCount.get());
- if (valuesIter.hasNext()) {
- valuesIter.next();
- // Subsequent pages are pre-fetched, so after accessing the second page,
- // the third should be requested.
- assertEquals(3, callCount.get());
+ // Prefetch every second element in the iterator capturing the results
+ List<T> results = new ArrayList<>();
+ for (int i = 0; i < expected.length; ++i) {
+ if (i % 2 == 1) {
+ // Ensure that prefetch performs the call
+ valuesIter2.prefetch();
+ assertTrue(valuesIter2.isReady());
+ // Note that this is i+2 because we expect to prefetch the page after the current one.
+ // We also have to bound it by the total number of pages.
+ assertEquals(Math.min(i + 2, expected.length), callCount.get());
+ }
+ assertTrue(valuesIter2.hasNext());
+ results.add(valuesIter2.next());
}
+ assertFalse(valuesIter2.hasNext());
+ assertTrue(valuesIter2.isReady());
// The contents agree.
assertArrayEquals(expected, Iterables.toArray(values, Object.class));
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
index dd5202e..74d6636 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
@@ -1218,6 +1218,10 @@
List<Cursor> cursors = new ArrayList<>(partitionQueryResponse.getPartitionsList());
cursors.sort(CURSOR_REFERENCE_VALUE_COMPARATOR);
final int size = cursors.size();
+ if (size == 0) {
+ emit(c, dbRoot, structuredQuery.toBuilder());
+ return;
+ }
final int lastIdx = size - 1;
for (int i = 0; i < size; i++) {
Cursor curr = cursors.get(i);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
index 0c9bbf1..1f29883 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
@@ -99,6 +99,41 @@
assertEquals(expected, allValues);
}
+ @Test
+ public void endToEnd_emptyCursors() throws Exception {
+ // First page of the response
+ PartitionQueryRequest request1 =
+ PartitionQueryRequest.newBuilder()
+ .setParent(String.format("projects/%s/databases/(default)/document", projectId))
+ .build();
+ PartitionQueryResponse response1 = PartitionQueryResponse.newBuilder().build();
+ when(callable.call(request1)).thenReturn(pagedResponse1);
+ when(page1.getResponse()).thenReturn(response1);
+ when(pagedResponse1.iteratePages()).thenReturn(ImmutableList.of(page1));
+
+ when(stub.partitionQueryPagedCallable()).thenReturn(callable);
+
+ when(ff.getFirestoreStub(any())).thenReturn(stub);
+ RpcQosOptions options = RpcQosOptions.defaultOptions();
+ when(ff.getRpcQos(any()))
+ .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options));
+
+ ArgumentCaptor<PartitionQueryPair> responses =
+ ArgumentCaptor.forClass(PartitionQueryPair.class);
+
+ doNothing().when(processContext).output(responses.capture());
+
+ when(processContext.element()).thenReturn(request1);
+
+ PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options);
+
+ runFunction(fn);
+
+ List<PartitionQueryPair> expected = newArrayList(new PartitionQueryPair(request1, response1));
+ List<PartitionQueryPair> allValues = responses.getAllValues();
+ assertEquals(expected, allValues);
+ }
+
@Override
public void resumeFromLastReadValue() throws Exception {
when(ff.getFirestoreStub(any())).thenReturn(stub);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
index 25ed63c..ed789da 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
@@ -121,6 +121,39 @@
assertEquals(expectedQueries, actualQueries);
}
+ @Test
+ public void ensureCursorPairingWorks_emptyCursorsInResponse() {
+ StructuredQuery query =
+ StructuredQuery.newBuilder()
+ .addFrom(
+ CollectionSelector.newBuilder()
+ .setAllDescendants(true)
+ .setCollectionId("c1")
+ .build())
+ .build();
+
+ List<StructuredQuery> expectedQueries = newArrayList(query);
+
+ PartitionQueryPair partitionQueryPair =
+ new PartitionQueryPair(
+ PartitionQueryRequest.newBuilder().setStructuredQuery(query).build(),
+ PartitionQueryResponse.newBuilder().build());
+
+ ArgumentCaptor<RunQueryRequest> captor = ArgumentCaptor.forClass(RunQueryRequest.class);
+ when(processContext.element()).thenReturn(partitionQueryPair);
+ doNothing().when(processContext).output(captor.capture());
+
+ PartitionQueryResponseToRunQueryRequest fn = new PartitionQueryResponseToRunQueryRequest();
+ fn.processElement(processContext);
+
+ List<StructuredQuery> actualQueries =
+ captor.getAllValues().stream()
+ .map(RunQueryRequest::getStructuredQuery)
+ .collect(Collectors.toList());
+
+ assertEquals(expectedQueries, actualQueries);
+ }
+
private static Cursor referenceValueCursor(String referenceValue) {
return Cursor.newBuilder()
.addValues(Value.newBuilder().setReferenceValue(referenceValue).build())
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 1c02f71..98d2faa 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -185,9 +185,12 @@
@ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
class CombinePerKeyTransform(ptransform.PTransform):
def expand(self, pcoll):
- return pcoll \
- | beam.CombinePerKey(sum).with_output_types(
- typing.Tuple[str, int])
+ output = pcoll \
+ | beam.CombinePerKey(sum)
+ # TODO: Use `with_output_types` instead of explicitly
+ # assigning to `.element_type` after fixing BEAM-12872
+ output.element_type = beam.typehints.Tuple[str, int]
+ return output
def to_runner_api_parameter(self, unused_context):
return TEST_COMPK_URN, None
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index dfe45c4..41ad3df 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -21,6 +21,7 @@
import copy
import heapq
+import itertools
import operator
import random
from typing import Any
@@ -606,16 +607,25 @@
class _TupleCombineFnBase(core.CombineFn):
- def __init__(self, *combiners):
+ def __init__(self, *combiners, merge_accumulators_batch_size=None):
self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
self._named_combiners = combiners
+ # If the `merge_accumulators_batch_size` value is not specified, we chose a
+ # bounded default that is inversely proportional to the number of
+ # accumulators in merged tuples.
+ num_combiners = max(1, len(combiners))
+ self._merge_accumulators_batch_size = (
+ merge_accumulators_batch_size or max(10, 1000 // num_combiners))
def display_data(self):
combiners = [
c.__name__ if hasattr(c, '__name__') else c.__class__.__name__
for c in self._named_combiners
]
- return {'combiners': str(combiners)}
+ return {
+ 'combiners': str(combiners),
+ 'merge_accumulators_batch_size': self._merge_accumulators_batch_size
+ }
def setup(self, *args, **kwargs):
for c in self._combiners:
@@ -625,10 +635,22 @@
return [c.create_accumulator(*args, **kwargs) for c in self._combiners]
def merge_accumulators(self, accumulators, *args, **kwargs):
- return [
- c.merge_accumulators(a, *args, **kwargs) for c,
- a in zip(self._combiners, zip(*accumulators))
- ]
+ # Make sure that `accumulators` is an iterator (so that the position is
+ # remembered).
+ accumulators = iter(accumulators)
+ result = next(accumulators)
+ while True:
+ # Load accumulators into memory and merge in batches to decrease peak
+ # memory usage.
+ accumulators_batch = [result] + list(
+ itertools.islice(accumulators, self._merge_accumulators_batch_size))
+ if len(accumulators_batch) == 1:
+ break
+ result = [
+ c.merge_accumulators(a, *args, **kwargs) for c,
+ a in zip(self._combiners, zip(*accumulators_batch))
+ ]
+ return result
def compact(self, accumulator, *args, **kwargs):
return [
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index d826287..7e0e835 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -249,7 +249,8 @@
dd = DisplayData.create_from(transform)
expected_items = [
DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
- DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']")
+ DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
+ DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -358,6 +359,49 @@
max).with_common_input()).without_defaults())
assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
+ def test_empty_tuple_combine_fn(self):
+ with TestPipeline() as p:
+ result = (
+ p
+ | Create([(), (), ()])
+ | beam.CombineGlobally(combine.TupleCombineFn()))
+ assert_that(result, equal_to([()]))
+
+ def test_tuple_combine_fn_batched_merge(self):
+ num_combine_fns = 10
+ max_num_accumulators_in_memory = 30
+ # Maximum number of accumulator tuples in memory - 1 for the merge result.
+ merge_accumulators_batch_size = (
+ max_num_accumulators_in_memory // num_combine_fns - 1)
+ num_accumulator_tuples_to_merge = 20
+
+ class CountedAccumulator:
+ count = 0
+ oom = False
+
+ def __init__(self):
+ if CountedAccumulator.count > max_num_accumulators_in_memory:
+ CountedAccumulator.oom = True
+ else:
+ CountedAccumulator.count += 1
+
+ class CountedAccumulatorCombineFn(beam.CombineFn):
+ def create_accumulator(self):
+ return CountedAccumulator()
+
+ def merge_accumulators(self, accumulators):
+ CountedAccumulator.count += 1
+ for _ in accumulators:
+ CountedAccumulator.count -= 1
+
+ combine_fn = combine.TupleCombineFn(
+ *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
+ merge_accumulators_batch_size=merge_accumulators_batch_size)
+ combine_fn.merge_accumulators(
+ combine_fn.create_accumulator()
+ for _ in range(num_accumulator_tuples_to_merge))
+ assert not CountedAccumulator.oom
+
def test_to_list_and_to_dict1(self):
with TestPipeline() as pipeline:
the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]