Merge pull request #15495 [BEAM-12689] fixed broken Python tab on HCatalog IO page
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 4cdf4ae..67fe8b8 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -446,7 +446,7 @@
def errorprone_version = "2.3.4"
def google_clients_version = "1.32.1"
def google_cloud_bigdataoss_version = "2.2.2"
- def google_cloud_pubsublite_version = "0.13.2"
+ def google_cloud_pubsublite_version = "1.0.4"
def google_code_gson_version = "2.8.6"
def google_oauth_clients_version = "1.31.0"
// Try to keep grpc_version consistent with gRPC version in google_cloud_platform_libraries_bom
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index d532f53..95693e3 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -709,6 +709,15 @@
bytes window = 3;
}
+ // Represents a request for an unordered set of values associated with a
+ // specified user key and window for a PTransform. See
+ // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
+ // details.
+ //
+ // The response data stream will be a concatenation of all V's associated
+ // with the specified user key and window.
+ // See https://s.apache.org/beam-fn-api-send-and-receive-data for further
+ // details.
message BagUserState {
// (Required) The id of the PTransform containing user state.
string transform_id = 1;
@@ -721,6 +730,53 @@
bytes key = 4;
}
+ // Represents a request for the keys of a multimap associated with a specified
+ // user key and window for a PTransform. See
+ // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
+ // details.
+ //
+ // Can only be used to perform StateGetRequests and StateClearRequests on the
+ // user state.
+ //
+ // The response data stream will be a concatenation of all K's associated
+ // with the specified user key and window.
+ // See https://s.apache.org/beam-fn-api-send-and-receive-data for further
+ // details.
+ message MultimapKeysUserState {
+ // (Required) The id of the PTransform containing user state.
+ string transform_id = 1;
+ // (Required) The id of the user state.
+ string user_state_id = 2;
+ // (Required) The window encoded in a nested context.
+ bytes window = 3;
+ // (Required) The key of the currently executing element encoded in a
+ // nested context.
+ bytes key = 4;
+ }
+
+ // Represents a request for the values of the map key associated with a
+ // specified user key and window for a PTransform. See
+ // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
+ // details.
+ //
+ // The response data stream will be a concatenation of all V's associated
+ // with the specified map key, user key, and window.
+ // See https://s.apache.org/beam-fn-api-send-and-receive-data for further
+ // details.
+ message MultimapUserState {
+ // (Required) The id of the PTransform containing user state.
+ string transform_id = 1;
+ // (Required) The id of the user state.
+ string user_state_id = 2;
+ // (Required) The window encoded in a nested context.
+ bytes window = 3;
+ // (Required) The key of the currently executing element encoded in a
+ // nested context.
+ bytes key = 4;
+ // (Required) The map key encoded in a nested context.
+ bytes map_key = 5;
+ }
+
// (Required) One of the following state keys must be set.
oneof type {
Runner runner = 1;
@@ -728,7 +784,8 @@
BagUserState bag_user_state = 3;
IterableSideInput iterable_side_input = 4;
MultimapKeysSideInput multimap_keys_side_input = 5;
- // TODO: represent a state key for user map state
+ MultimapKeysUserState multimap_keys_user_state = 6;
+ MultimapUserState multimap_user_state = 7;
}
}
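
For context on the two new state keys, here is a hedged Go sketch of how an SDK harness might address the new `multimap_user_state` key, assuming standard protoc-gen-go naming for the nested messages and oneof wrappers (the same pattern the test code below uses for `ProcessBundleRequest_CacheToken_SideInput_`); the helper name and IDs are illustrative, not part of this change:

```go
package example

import (
	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
)

// buildMultimapGet sketches a StateGetRequest targeting the new
// multimap_user_state key. The byte arguments are assumed to already be
// encoded in a nested context, as the proto comments require.
func buildMultimapGet(transformID, userStateID string, window, key, mapKey []byte) *fnpb.StateRequest {
	return &fnpb.StateRequest{
		StateKey: &fnpb.StateKey{
			Type: &fnpb.StateKey_MultimapUserState_{
				MultimapUserState: &fnpb.StateKey_MultimapUserState{
					TransformId: transformID,
					UserStateId: userStateID,
					Window:      window,
					Key:         key,
					MapKey:      mapKey,
				},
			},
		},
		// An empty continuation token requests the first chunk.
		Request: &fnpb.StateRequest_Get{Get: &fnpb.StateGetRequest{}},
	}
}
```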
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java
index cb54c60..19670f4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java
@@ -75,6 +75,8 @@
"org.apache.beam.sdk.extensions.gcp.util.RetryHttpRequestInitializer$LoggingHttpBackOffHandler";
protected static final String BIGQUERY_STREAMING_INSERT_THROTTLE_TIME_NAMESPACE =
"org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl";
+ protected static final String BIGQUERY_READ_THROTTLE_TIME_NAMESPACE =
+ "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$StorageClientImpl";
protected static final String THROTTLE_TIME_COUNTER_NAME = "throttling-msecs";
private BatchModeExecutionContext(
@@ -555,6 +557,13 @@
totalThrottleMsecs += bigqueryStreamingInsertThrottleTime.getCumulative();
}
+ CounterCell bigqueryReadThrottleTime =
+ container.tryGetCounter(
+ MetricName.named(BIGQUERY_READ_THROTTLE_TIME_NAMESPACE, THROTTLE_TIME_COUNTER_NAME));
+ if (bigqueryReadThrottleTime != null) {
+ totalThrottleMsecs += bigqueryReadThrottleTime.getCumulative();
+ }
+
CounterCell throttlingMsecs =
container.tryGetCounter(DataflowSystemMetrics.THROTTLING_MSECS_METRIC_NAME);
if (throttlingMsecs != null) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 12b7df2..17b59ec 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -794,7 +794,8 @@
final int numIters = 2000;
for (int i = 0; i < numIters; ++i) {
- server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+ server.addWorkToOffer(
+ makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
}
Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
@@ -829,7 +830,8 @@
final int numIters = 2000;
for (int i = 0; i < numIters; ++i) {
- server.addWorkToOffer(makeInput(i, 0, "key", DEFAULT_SHARDING_KEY));
+ server.addWorkToOffer(
+ makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), "key", DEFAULT_SHARDING_KEY));
}
Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
diff --git a/sdks/go/examples/cookbook/combine/combine.go b/sdks/go/examples/cookbook/combine/combine.go
index 205f515..33023b2 100644
--- a/sdks/go/examples/cookbook/combine/combine.go
+++ b/sdks/go/examples/cookbook/combine/combine.go
@@ -34,7 +34,8 @@
input = flag.String("input", "publicdata:samples.shakespeare", "Shakespeare plays BQ table.")
output = flag.String("output", "", "Output BQ table.")
- minLength = flag.Int("min_length", 9, "Minimum word length")
+ minLength = flag.Int("min_length", 9, "Minimum word length")
+ small_words = beam.NewCounter("extract", "small_words")
)
func init() {
@@ -67,11 +68,12 @@
MinLength int `json:"min_length"`
}
-func (f *extractFn) ProcessElement(row WordRow, emit func(string, string)) {
+func (f *extractFn) ProcessElement(ctx context.Context, row WordRow, emit func(string, string)) {
if len(row.Word) >= f.MinLength {
emit(row.Word, row.Corpus)
+ } else {
+ small_words.Inc(ctx, 1)
}
- // TODO(herohde) 7/14/2017: increment counter for "small words"
}
// TODO(herohde) 7/14/2017: the choice of a string (instead of []string) for the
diff --git a/sdks/go/examples/wordcount/wordcount.go b/sdks/go/examples/wordcount/wordcount.go
index b2eb67d..08d9cb5 100644
--- a/sdks/go/examples/wordcount/wordcount.go
+++ b/sdks/go/examples/wordcount/wordcount.go
@@ -106,23 +106,33 @@
// done automatically by the starcgen code generator, or it can be done manually
// by calling beam.RegisterFunction in an init() call.
func init() {
- beam.RegisterFunction(extractFn)
beam.RegisterFunction(formatFn)
}
var (
- wordRE = regexp.MustCompile(`[a-zA-Z]+('[a-z])?`)
- empty = beam.NewCounter("extract", "emptyLines")
- lineLen = beam.NewDistribution("extract", "lineLenDistro")
+ wordRE = regexp.MustCompile(`[a-zA-Z]+('[a-z])?`)
+ empty = beam.NewCounter("extract", "emptyLines")
+	small_word_length = flag.Int("small_word_length", 9, "Length below which a word is counted as small")
+ small_words = beam.NewCounter("extract", "small_words")
+ lineLen = beam.NewDistribution("extract", "lineLenDistro")
)
-// extractFn is a DoFn that emits the words in a given line.
-func extractFn(ctx context.Context, line string, emit func(string)) {
+// extractFn is a DoFn that emits the words in a given line and keeps a count of small words.
+type extractFn struct {
+ SmallWordLength int `json:"min_length"`
+}
+
+func (f *extractFn) ProcessElement(ctx context.Context, line string, emit func(string)) {
lineLen.Update(ctx, int64(len(line)))
if len(strings.TrimSpace(line)) == 0 {
empty.Inc(ctx, 1)
}
for _, word := range wordRE.FindAllString(line, -1) {
+		// Increment the counter for small words if the word's length is
+		// less than small_word_length.
+ if len(word) < f.SmallWordLength {
+ small_words.Inc(ctx, 1)
+ }
emit(word)
}
}
@@ -150,7 +160,7 @@
s = s.Scope("CountWords")
// Convert lines of text into individual words.
- col := beam.ParDo(s, extractFn, lines)
+ col := beam.ParDo(s, &extractFn{SmallWordLength: *small_word_length}, lines)
// Count the number of times each word occurs.
return stats.Count(s, col)
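
Note that this diff drops `beam.RegisterFunction(extractFn)` from `init()` now that `extractFn` is a struct DoFn. Whether this change registers the struct type elsewhere isn't shown; as a hedged sketch, a struct DoFn is typically registered with `beam.RegisterType` (the line below is an assumption, not part of the diff):

```go
import (
	"reflect"

	"github.com/apache/beam/sdks/v2/go/pkg/beam"
)

func init() {
	beam.RegisterFunction(formatFn)
	// Hypothetical addition: register the struct DoFn's type so portable
	// runners can serialize it and reconstruct it by name at execution time.
	beam.RegisterType(reflect.TypeOf((*extractFn)(nil)).Elem())
}
```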
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
new file mode 100644
index 0000000..5496d8b
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statecache implements the state caching feature described by the
+// Beam Fn API.
+//
+// The Beam State API and the intended caching behavior are described here:
+// https://docs.google.com/document/d/1BOozW0bzBuz4oHJEuZNDOHdzaV5Y56ix58Ozrqm2jFg/edit#heading=h.7ghoih5aig5m
+package statecache
+
+import (
+ "sync"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+type token string
+
+// SideInputCache stores a cache of reusable inputs for the purposes of
+// eliminating redundant calls to the runner during execution of ParDos
+// using side inputs.
+//
+// A SideInputCache should be initialized when the SDK harness is initialized,
+// creating storage for side input caching. On each ProcessBundleRequest,
+// the cache will process the list of tokens for cacheable side inputs and
+// be queried when side inputs are requested in bundle execution. Once a
+// new bundle request comes in, the valid tokens will be updated and the cache
+// will be re-used. In the event that the cache reaches capacity, a random,
+// currently invalid cached object will be evicted.
+type SideInputCache struct {
+ capacity int
+ mu sync.Mutex
+ cache map[token]exec.ReusableInput
+ idsToTokens map[string]token
+ validTokens map[token]int8 // Maps tokens to active bundle counts
+ metrics CacheMetrics
+}
+
+// CacheMetrics records cache hit, miss, and eviction counts for the SideInputCache.
+type CacheMetrics struct {
+ Hits int64
+ Misses int64
+ Evictions int64
+ InUseEvictions int64
+}
+
+// Init makes the cache map and the map of IDs to cache tokens for the
+// SideInputCache. Should only be called once. Returns an error for
+// non-positive capacities.
+func (c *SideInputCache) Init(cap int) error {
+ if cap <= 0 {
+ return errors.Errorf("capacity must be a positive integer, got %v", cap)
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.cache = make(map[token]exec.ReusableInput, cap)
+ c.idsToTokens = make(map[string]token)
+ c.validTokens = make(map[token]int8)
+ c.capacity = cap
+ return nil
+}
+
+// SetValidTokens registers the tokens valid for the incoming bundle, updating the mapping of
+// transform and side input IDs to cache tokens in the process. Should be called at the start of every
+// new ProcessBundleRequest. If the runner does not support caching, the passed cache token values
+// should be empty and all get/set requests will silently be no-ops.
+func (c *SideInputCache) SetValidTokens(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for _, tok := range cacheTokens {
+ // User State caching is currently not supported, so these tokens are ignored
+ if tok.GetUserState() != nil {
+ continue
+ }
+ s := tok.GetSideInput()
+ transformID := s.GetTransformId()
+ sideInputID := s.GetSideInputId()
+ t := token(tok.GetToken())
+ c.setValidToken(transformID, sideInputID, t)
+ }
+}
+
+// setValidToken adds a new valid token for a request into the SideInputCache struct
+// by mapping the transform ID and side input ID pairing to the cache token.
+func (c *SideInputCache) setValidToken(transformID, sideInputID string, tok token) {
+ idKey := transformID + sideInputID
+ c.idsToTokens[idKey] = tok
+ count, ok := c.validTokens[tok]
+ if !ok {
+ c.validTokens[tok] = 1
+ } else {
+ c.validTokens[tok] = count + 1
+ }
+}
+
+// CompleteBundle takes the cache tokens that were passed to SetValidTokens and decrements
+// their usage counts, maintaining an accurate count of whether or not a value is
+// still in use. Should be called once ProcessBundle has completed.
+func (c *SideInputCache) CompleteBundle(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for _, tok := range cacheTokens {
+ // User State caching is currently not supported, so these tokens are ignored
+ if tok.GetUserState() != nil {
+ continue
+ }
+ t := token(tok.GetToken())
+ c.decrementTokenCount(t)
+ }
+}
+
+// decrementTokenCount decrements the validTokens entry for
+// a given token by 1. Should only be called when completing
+// a bundle.
+func (c *SideInputCache) decrementTokenCount(tok token) {
+ count := c.validTokens[tok]
+ if count == 1 {
+ delete(c.validTokens, tok)
+ } else {
+ c.validTokens[tok] = count - 1
+ }
+}
+
+func (c *SideInputCache) makeAndValidateToken(transformID, sideInputID string) (token, bool) {
+ idKey := transformID + sideInputID
+ // Check if it's a known token
+ tok, ok := c.idsToTokens[idKey]
+ if !ok {
+ return "", false
+ }
+ return tok, c.isValid(tok)
+}
+
+// QueryCache takes a transform ID and side input ID and checks if a corresponding side
+// input has been cached. A query with a bad token (e.g. one that doesn't map to a known
+// token or one that maps to a known but currently invalid token) is treated the same as a
+// cache miss.
+func (c *SideInputCache) QueryCache(transformID, sideInputID string) exec.ReusableInput {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+ if !ok {
+ return nil
+ }
+ // Check to see if cached
+ input, ok := c.cache[tok]
+ if !ok {
+ c.metrics.Misses++
+ return nil
+ }
+
+ c.metrics.Hits++
+ return input
+}
+
+// SetCache allows a user to place a ReusableInput materialized from the reader into the SideInputCache
+// with its corresponding transform ID and side input ID. If the IDs do not pair with a known, valid token
+// then we silently do not cache the input, as this is an indication that the runner is treating that input
+// as uncacheable.
+func (c *SideInputCache) SetCache(transformID, sideInputID string, input exec.ReusableInput) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ tok, ok := c.makeAndValidateToken(transformID, sideInputID)
+ if !ok {
+ return
+ }
+ if len(c.cache) >= c.capacity {
+ c.evictElement()
+ }
+ c.cache[tok] = input
+}
+
+func (c *SideInputCache) isValid(tok token) bool {
+ count, ok := c.validTokens[tok]
+ // If the token is not known or not in use, return false
+ return ok && count > 0
+}
+
+// evictElement randomly evicts a ReusableInput that is not currently valid from the cache.
+// It should only be called by a goroutine that obtained the lock in SetCache.
+func (c *SideInputCache) evictElement() {
+ deleted := false
+ // Select a key from the cache at random
+ for k := range c.cache {
+ // Do not evict an element if it's currently valid
+ if !c.isValid(k) {
+ delete(c.cache, k)
+ c.metrics.Evictions++
+ deleted = true
+ break
+ }
+ }
+ // Nothing is deleted if every side input is still valid. Clear
+ // out a random entry and record the in-use eviction
+ if !deleted {
+ for k := range c.cache {
+ delete(c.cache, k)
+ c.metrics.InUseEvictions++
+ break
+ }
+ }
+}
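
To make the intended lifecycle concrete, here is a minimal usage sketch of the new package, grounded in the API above; the harness wiring is simplified and the IDs and token value are hypothetical:

```go
package main

import (
	"fmt"

	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/harness/statecache"
	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
)

// sideInputToken builds a side-input cache token, mirroring makeRequest
// in statecache_test.go below.
func sideInputToken(transformID, sideInputID, token string) fnpb.ProcessBundleRequest_CacheToken {
	var tok fnpb.ProcessBundleRequest_CacheToken
	tok.Type = &fnpb.ProcessBundleRequest_CacheToken_SideInput_{
		SideInput: &fnpb.ProcessBundleRequest_CacheToken_SideInput{
			TransformId: transformID,
			SideInputId: sideInputID,
		},
	}
	tok.Token = []byte(token)
	return tok
}

func main() {
	var cache statecache.SideInputCache
	if err := cache.Init(100); err != nil {
		panic(err)
	}

	// Start of a ProcessBundleRequest: register the runner-provided tokens.
	tok := sideInputToken("t1", "s1", "tok1")
	cache.SetValidTokens(tok)

	// During execution: check the cache before re-reading the side input.
	if input := cache.QueryCache("t1", "s1"); input == nil {
		fmt.Println("cache miss: materialize from the runner, then offer it to the cache")
		// cache.SetCache("t1", "s1", materialized) // materialized is an exec.ReusableInput
	}

	// Bundle finished: decrement the tokens' active-bundle counts.
	cache.CompleteBundle(tok)
}
```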
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
new file mode 100644
index 0000000..b9970c3
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statecache
+
+import (
+ "testing"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+)
+
+// TestReusableInput implements the ReusableInput interface for the purposes
+// of testing.
+type TestReusableInput struct {
+ transformID string
+ sideInputID string
+ value interface{}
+}
+
+func makeTestReusableInput(transformID, sideInputID string, value interface{}) exec.ReusableInput {
+ return &TestReusableInput{transformID: transformID, sideInputID: sideInputID, value: value}
+}
+
+// Init is a ReusableInput interface method, this is a no-op.
+func (r *TestReusableInput) Init() error {
+ return nil
+}
+
+// Value returns the stored value in the TestReusableInput.
+func (r *TestReusableInput) Value() interface{} {
+ return r.value
+}
+
+// Reset clears the value in the TestReusableInput.
+func (r *TestReusableInput) Reset() error {
+ r.value = nil
+ return nil
+}
+
+func TestInit(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(5)
+ if err != nil {
+ t.Errorf("SideInputCache failed but should have succeeded, got %v", err)
+ }
+}
+
+func TestInit_Bad(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(0)
+ if err == nil {
+ t.Error("SideInputCache init succeeded but should have failed")
+ }
+}
+
+func TestQueryCache_EmptyCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ output := s.QueryCache("side1", "transform1")
+ if output != nil {
+ t.Errorf("Cache hit when it should have missed, got %v", output)
+ }
+}
+
+func TestSetCache_UncacheableCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ input := makeTestReusableInput("t1", "s1", 10)
+ s.SetCache("t1", "s1", input)
+ output := s.QueryCache("t1", "s1")
+ if output != nil {
+ t.Errorf("Cache hit when should have missed, got %v", output)
+ }
+}
+
+func TestSetCache_CacheableCase(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+ transID := "t1"
+ sideID := "s1"
+ tok := token("tok1")
+ s.setValidToken(transID, sideID, tok)
+ input := makeTestReusableInput(transID, sideID, 10)
+ s.SetCache(transID, sideID, input)
+ output := s.QueryCache(transID, sideID)
+ if output == nil {
+ t.Fatalf("call to query cache missed when should have hit")
+ }
+ val, ok := output.Value().(int)
+ if !ok {
+ t.Errorf("failed to convert value to integer, got %v", output.Value())
+ }
+ if val != 10 {
+ t.Errorf("element mismatch, expected 10, got %v", val)
+ }
+}
+
+func makeRequest(transformID, sideInputID string, t token) fnpb.ProcessBundleRequest_CacheToken {
+ var tok fnpb.ProcessBundleRequest_CacheToken
+ var wrap fnpb.ProcessBundleRequest_CacheToken_SideInput_
+ var side fnpb.ProcessBundleRequest_CacheToken_SideInput
+ side.TransformId = transformID
+ side.SideInputId = sideInputID
+ wrap.SideInput = &side
+ tok.Type = &wrap
+ tok.Token = []byte(t)
+ return tok
+}
+
+func TestSetValidTokens(t *testing.T) {
+ inputs := []struct {
+ transformID string
+ sideInputID string
+ tok token
+ }{
+ {
+ "t1",
+ "s1",
+ "tok1",
+ },
+ {
+ "t2",
+ "s2",
+ "tok2",
+ },
+ {
+ "t3",
+ "s3",
+ "tok3",
+ },
+ }
+
+ var s SideInputCache
+ err := s.Init(3)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ var tokens []fnpb.ProcessBundleRequest_CacheToken
+ for _, input := range inputs {
+ t := makeRequest(input.transformID, input.sideInputID, input.tok)
+ tokens = append(tokens, t)
+ }
+
+ s.SetValidTokens(tokens...)
+ if len(s.idsToTokens) != len(inputs) {
+ t.Errorf("Missing tokens, expected %v, got %v", len(inputs), len(s.idsToTokens))
+ }
+
+ for i, input := range inputs {
+ // Check that the token is in the valid list
+ if !s.isValid(input.tok) {
+ t.Errorf("error in input %v, token %v is not valid", i, input.tok)
+ }
+ // Check that the mapping of IDs to tokens is correct
+ mapped := s.idsToTokens[input.transformID+input.sideInputID]
+ if mapped != input.tok {
+ t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tok, mapped)
+ }
+ }
+}
+
+func TestSetValidTokens_ClearingBetween(t *testing.T) {
+ inputs := []struct {
+ transformID string
+ sideInputID string
+ tk token
+ }{
+ {
+ "t1",
+ "s1",
+ "tok1",
+ },
+ {
+ "t2",
+ "s2",
+ "tok2",
+ },
+ {
+ "t3",
+ "s3",
+ "tok3",
+ },
+ }
+
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ for i, input := range inputs {
+ tok := makeRequest(input.transformID, input.sideInputID, input.tk)
+
+ s.SetValidTokens(tok)
+
+ // Check that the token is in the valid list
+ if !s.isValid(input.tk) {
+ t.Errorf("error in input %v, token %v is not valid", i, input.tk)
+ }
+ // Check that the mapping of IDs to tokens is correct
+ mapped := s.idsToTokens[input.transformID+input.sideInputID]
+ if mapped != input.tk {
+ t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tk, mapped)
+ }
+
+ s.CompleteBundle(tok)
+ }
+
+	for k := range s.validTokens {
+ if s.validTokens[k] != 0 {
+ t.Errorf("token count mismatch for token %v, expected 0, got %v", k, s.validTokens[k])
+ }
+ }
+}
+
+func TestSetCache_Eviction(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ tokOne := makeRequest("t1", "s1", "tok1")
+ inOne := makeTestReusableInput("t1", "s1", 10)
+ s.SetValidTokens(tokOne)
+ s.SetCache("t1", "s1", inOne)
+ // Mark bundle as complete, drop count for tokOne to 0
+ s.CompleteBundle(tokOne)
+
+ tokTwo := makeRequest("t2", "s2", "tok2")
+ inTwo := makeTestReusableInput("t2", "s2", 20)
+ s.SetValidTokens(tokTwo)
+ s.SetCache("t2", "s2", inTwo)
+
+ if len(s.cache) != 1 {
+ t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+ }
+ if s.metrics.Evictions != 1 {
+ t.Errorf("number evictions incorrect, expected 1, got %v", s.metrics.Evictions)
+ }
+}
+
+func TestSetCache_EvictionFailure(t *testing.T) {
+ var s SideInputCache
+ err := s.Init(1)
+ if err != nil {
+ t.Fatalf("cache init failed, got %v", err)
+ }
+
+ tokOne := makeRequest("t1", "s1", "tok1")
+ inOne := makeTestReusableInput("t1", "s1", 10)
+
+ tokTwo := makeRequest("t2", "s2", "tok2")
+ inTwo := makeTestReusableInput("t2", "s2", 20)
+
+ s.SetValidTokens(tokOne, tokTwo)
+ s.SetCache("t1", "s1", inOne)
+ // Should fail to evict because the first token is still valid
+ s.SetCache("t2", "s2", inTwo)
+ // Cache should not exceed size 1
+ if len(s.cache) != 1 {
+ t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache))
+ }
+ if s.metrics.InUseEvictions != 1 {
+ t.Errorf("number of failed evicition calls incorrect, expected 1, got %v", s.metrics.InUseEvictions)
+ }
+}
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
index 511f839..f4ab8bb 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java
@@ -156,16 +156,20 @@
}
/**
- * An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values
- * using the specified {@link Coder}.
+ * An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T}
+ * values using the specified {@link Coder}.
*
* <p>Note that this adapter follows the Beam Fn API specification for forcing values that decode
* consuming zero bytes to consuming exactly one byte.
*
* <p>Note that access to the underlying {@link InputStream} is lazy and will only be invoked on
- * first access to {@link #next()} or {@link #hasNext()}.
+ * first access to {@link #next}, {@link #hasNext}, {@link #isReady}, or {@link #prefetch}.
+ *
+ * <p>Note that {@link #isReady} and {@link #prefetch} rely on non-empty {@link ByteString}s being
+ * returned via the underlying {@link PrefetchableIterator}; otherwise {@link #prefetch} will
+ * seemingly make zero progress even though it is actually advancing through the empty pages.
*/
- public static class DataStreamDecoder<T> implements Iterator<T> {
+ public static class DataStreamDecoder<T> implements PrefetchableIterator<T> {
private enum State {
READ_REQUIRED,
@@ -173,13 +177,13 @@
EOF
}
- private final Iterator<ByteString> inputByteStrings;
+ private final PrefetchableIterator<ByteString> inputByteStrings;
private final Inbound inbound;
private final Coder<T> coder;
private State currentState;
private T next;
- public DataStreamDecoder(Coder<T> coder, Iterator<ByteString> inputStream) {
+ public DataStreamDecoder(Coder<T> coder, PrefetchableIterator<ByteString> inputStream) {
this.currentState = State.READ_REQUIRED;
this.coder = coder;
this.inputByteStrings = inputStream;
@@ -187,6 +191,31 @@
}
@Override
+ public boolean isReady() {
+ switch (currentState) {
+ case EOF:
+ return true;
+ case READ_REQUIRED:
+ try {
+ return inbound.isReady();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ case HAS_NEXT:
+ return true;
+ default:
+ throw new IllegalStateException(String.format("Unknown state %s", currentState));
+ }
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ inputByteStrings.prefetch();
+ }
+ }
+
+ @Override
public boolean hasNext() {
switch (currentState) {
case EOF:
@@ -232,8 +261,8 @@
private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput();
/**
- * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first
- * {@link Iterator} on first access of this input stream.
+ * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link
+ * Iterator} on first access of this input stream.
*
* <p>Closing this input stream has no effect.
*/
@@ -245,6 +274,22 @@
this.currentStream = EMPTY_STREAM;
}
+ public boolean isReady() throws IOException {
+ // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
+ // minus the number of bytes that have been read so far and can be reliably used to tell
+ // us whether we are at the end of the stream.
+ while (currentStream.available() == 0) {
+ if (!inputByteStrings.isReady()) {
+ return false;
+ }
+ if (!inputByteStrings.hasNext()) {
+ return true;
+ }
+ currentStream = inputByteStrings.next().newInput();
+ }
+ return true;
+ }
+
public boolean isEof() throws IOException {
// Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
// minus the number of bytes that have been read so far and can be reliably used to tell
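
The isReady/prefetch pair added above follows a simple contract: prefetch starts fetching the next chunk without blocking, and isReady reports whether next would block. A self-contained Go sketch of that contract (illustrative only, not Beam API) using a single in-flight future:

```go
package main

import "fmt"

// prefetcher wraps a blocking chunk source with a one-slot prefetch buffer.
type prefetcher struct {
	fetch   func() (string, bool) // blocking source: returns next chunk, ok
	pending chan result           // holds at most one in-flight prefetch
}

type result struct {
	val string
	ok  bool
}

func newPrefetcher(fetch func() (string, bool)) *prefetcher {
	return &prefetcher{fetch: fetch}
}

// Prefetch starts fetching the next chunk without blocking the caller.
func (p *prefetcher) Prefetch() {
	if p.pending != nil {
		return // a prefetch is already in flight
	}
	ch := make(chan result, 1)
	p.pending = ch
	go func() {
		v, ok := p.fetch()
		ch <- result{v, ok}
	}()
}

// IsReady reports whether Next would return without blocking.
func (p *prefetcher) IsReady() bool {
	if p.pending == nil {
		return false
	}
	select {
	case r := <-p.pending:
		p.pending <- r // put it back; the channel is buffered with capacity 1
		return true
	default:
		return false
	}
}

// Next returns the next chunk, blocking on the in-flight prefetch if needed.
func (p *prefetcher) Next() (string, bool) {
	p.Prefetch()
	r := <-p.pending
	p.pending = nil
	return r.val, r.ok
}

func main() {
	pages := []string{"A", "BC", "DEF"}
	i := 0
	p := newPrefetcher(func() (string, bool) {
		if i >= len(pages) {
			return "", false
		}
		v := pages[i]
		i++
		return v, true
	})
	p.Prefetch() // overlap the fetch with other work
	for {
		v, ok := p.Next()
		if !ok {
			break
		}
		fmt.Println(v)
	}
}
```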
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
index 9dd5ee4..a8b48e8 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java
@@ -23,6 +23,7 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
import java.io.IOException;
@@ -106,7 +107,7 @@
}
@Test
- public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
+ public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception {
CountingOutputStream countingOutputStream =
new CountingOutputStream(ByteStreams.nullOutputStream());
GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream);
@@ -115,6 +116,55 @@
testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE);
}
+ @Test
+ public void testPrefetch() throws Exception {
+ List<ByteString> encodings = new ArrayList<>();
+ {
+ ByteString.Output encoding = ByteString.newOutput();
+ StringUtf8Coder.of().encode("A", encoding);
+ StringUtf8Coder.of().encode("BC", encoding);
+ encodings.add(encoding.toByteString());
+ }
+ encodings.add(ByteString.EMPTY);
+ {
+ ByteString.Output encoding = ByteString.newOutput();
+ StringUtf8Coder.of().encode("DEF", encoding);
+ StringUtf8Coder.of().encode("GHIJ", encoding);
+ encodings.add(encoding.toByteString());
+ }
+
+ PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<ByteString> iterator =
+ new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator());
+ PrefetchableIterator<String> decoder =
+ new DataStreamDecoder<>(StringUtf8Coder.of(), iterator);
+ assertFalse(decoder.isReady());
+ decoder.prefetch();
+ assertTrue(decoder.isReady());
+ assertEquals(1, iterator.getNumPrefetchCalls());
+
+ decoder.next();
+    // Now we will have moved off of the empty byte array that we start with, so prefetch will
+ // do nothing since we are ready
+ assertTrue(decoder.isReady());
+ decoder.prefetch();
+ assertEquals(1, iterator.getNumPrefetchCalls());
+
+ decoder.next();
+ // Now we are at the end of the first ByteString so we expect a prefetch to pass through
+ assertFalse(decoder.isReady());
+ decoder.prefetch();
+ assertEquals(2, iterator.getNumPrefetchCalls());
+    // We also expect the decoder to not be ready since the next byte string is empty, which
+    // would require us to move to the next page. This typically wouldn't happen in practice,
+    // though, because we expect non-empty pages.
+ assertFalse(decoder.isReady());
+
+ // Prefetching will allow us to move to the third ByteString
+ decoder.prefetch();
+ assertEquals(3, iterator.getNumPrefetchCalls());
+ assertTrue(decoder.isReady());
+ }
+
private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOException {
ByteString.Output output = ByteString.newOutput();
for (T value : expected) {
@@ -131,7 +181,9 @@
}
private <T> void testDecoderWith(Coder<T> coder, T[] expected, List<ByteString> encoded) {
- Iterator<T> decoder = new DataStreamDecoder<>(coder, encoded.iterator());
+ Iterator<T> decoder =
+ new DataStreamDecoder<>(
+ coder, PrefetchableIterators.maybePrefetchable(encoded.iterator()));
Object[] actual = Iterators.toArray(decoder, Object.class);
assertArrayEquals(expected, actual);
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
index 6131634..9ada175 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java
@@ -120,10 +120,14 @@
"F");
}
- private static class NeverReady implements PrefetchableIterator<String> {
- PrefetchableIterator<String> delegate = PrefetchableIterators.fromArray("A", "B");
+ public static class NeverReady<T> implements PrefetchableIterator<T> {
+ private final Iterator<T> delegate;
int prefetchCalled;
+ public NeverReady(Iterator<T> delegate) {
+ this.delegate = delegate;
+ }
+
@Override
public boolean isReady() {
return false;
@@ -140,74 +144,117 @@
}
@Override
- public String next() {
+ public T next() {
return delegate.next();
}
+
+ public int getNumPrefetchCalls() {
+ return prefetchCalled;
+ }
}
- private static class ReadyAfterPrefetch extends NeverReady {
+ public static class ReadyAfterPrefetch<T> extends NeverReady<T> {
+
+ public ReadyAfterPrefetch(Iterator<T> delegate) {
+ super(delegate);
+ }
+
@Override
public boolean isReady() {
return prefetchCalled > 0;
}
}
+ public static class ReadyAfterPrefetchUntilNext<T> extends ReadyAfterPrefetch<T> {
+ boolean advancedSincePrefetch;
+
+ public ReadyAfterPrefetchUntilNext(Iterator<T> delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public boolean isReady() {
+ return !advancedSincePrefetch && super.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ advancedSincePrefetch = false;
+ super.prefetch();
+ }
+
+ @Override
+ public T next() {
+ advancedSincePrefetch = true;
+ return super.next();
+ }
+
+ @Override
+ public boolean hasNext() {
+ advancedSincePrefetch = true;
+ return super.hasNext();
+ }
+ }
+
@Test
public void testConcatIsReadyAdvancesToNextIteratorWhenAble() {
- NeverReady readyAfterPrefetch1 = new NeverReady();
- ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch();
- ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch();
+ NeverReady<String> readyAfterPrefetch1 =
+ new NeverReady<>(PrefetchableIterators.fromArray("A", "B"));
+ ReadyAfterPrefetch<String> readyAfterPrefetch2 =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
+ ReadyAfterPrefetch<String> readyAfterPrefetch3 =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
PrefetchableIterator<String> iterator =
PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3);
// Expect no prefetches yet
- assertEquals(0, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
// We expect to attempt to prefetch for the first time.
iterator.prefetch();
- assertEquals(1, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// We expect to attempt to prefetch again since we aren't ready.
iterator.prefetch();
- assertEquals(2, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The current iterator is done but is never ready so we can't advance to the next one and
// expect another prefetch to go to the current iterator.
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(0, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// Now that we know the last iterator is done and have advanced to the next one we expect
// prefetch to go through
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(0, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The last iterator is done so we should be able to prefetch the next one before advancing
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
// The current iterator is ready so no additional prefetch is necessary
iterator.prefetch();
- assertEquals(3, readyAfterPrefetch1.prefetchCalled);
- assertEquals(1, readyAfterPrefetch2.prefetchCalled);
- assertEquals(1, readyAfterPrefetch3.prefetchCalled);
+ assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
+ assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
}
diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle
index 3c859ae..6337cd4 100644
--- a/sdks/java/harness/build.gradle
+++ b/sdks/java/harness/build.gradle
@@ -72,6 +72,7 @@
testCompile library.java.mockito_core
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
testCompile project(":runners:core-construction-java")
+ testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime")
shadowTestRuntimeClasspath library.java.slf4j_jdk14
jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest")
jmhRuntime library.java.slf4j_jdk14
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 5ddf0ae..777036a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -26,6 +26,8 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -49,7 +51,7 @@
private final BeamFnStateClient beamFnStateClient;
private final StateRequest request;
private final Coder<T> valueCoder;
- private Iterable<T> oldValues;
+ private PrefetchableIterable<T> oldValues;
private ArrayList<T> newValues;
private boolean isClosed;
@@ -80,19 +82,19 @@
this.newValues = new ArrayList<>();
}
- public Iterable<T> get() {
+ public PrefetchableIterable<T> get() {
checkState(
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
if (oldValues == null) {
// If we were cleared we should disregard old values.
- return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size());
+ return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size());
} else if (newValues.isEmpty()) {
// If we have no new values then just return the old values.
return oldValues;
}
- return Iterables.concat(
+ return PrefetchableIterables.concat(
oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()));
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 2f2789e..5a931c5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -38,7 +38,6 @@
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.StateBinder;
import org.apache.beam.sdk.state.StateContext;
@@ -264,7 +263,7 @@
@Override
public ValueState<T> readLater() {
- // TODO(BEAM-12802): Support prefetching.
+ impl.get().iterator().prefetch();
return this;
}
};
@@ -310,7 +309,7 @@
@Override
public BagState<T> readLater() {
- // TODO(BEAM-12802): Support prefetching.
+ impl.get().iterator().prefetch();
return this;
}
@@ -391,6 +390,7 @@
@Override
public CombiningState<ElementT, AccumT, ResultT> readLater() {
+ impl.get().iterator().prefetch();
return this;
}
@@ -412,7 +412,18 @@
@Override
public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
+ return new ReadableState<Boolean>() {
+ @Override
+ public @Nullable Boolean read() {
+ return !impl.get().iterator().hasNext();
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ impl.get().iterator().prefetch();
+ return this;
+ }
+ };
}
@Override
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index cfc76cf..7828f93 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -21,6 +21,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
+import java.util.Objects;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -28,30 +31,42 @@
* Converts an iterator to an iterable lazily loading values from the underlying iterator and
* caching them to support reiteration.
*/
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
-class LazyCachingIteratorToIterable<T> implements Iterable<T> {
+class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
private final List<T> cachedElements;
- private final Iterator<T> iterator;
+ private final PrefetchableIterator<T> iterator;
- public LazyCachingIteratorToIterable(Iterator<T> iterator) {
+ public LazyCachingIteratorToIterable(PrefetchableIterator<T> iterator) {
this.cachedElements = new ArrayList<>();
this.iterator = iterator;
}
@Override
- public Iterator<T> iterator() {
+ public PrefetchableIterator<T> iterator() {
return new CachingIterator();
}
/** An {@link Iterator} which adds fetched values to the cached elements list. */
- private class CachingIterator implements Iterator<T> {
+ private class CachingIterator implements PrefetchableIterator<T> {
private int position = 0;
private CachingIterator() {}
@Override
+ public boolean isReady() {
+ if (position < cachedElements.size()) {
+ return true;
+ }
+ return iterator.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ iterator.prefetch();
+ }
+ }
+
+ @Override
public boolean hasNext() {
// The order of the short circuit is important below.
return position < cachedElements.size() || iterator.hasNext();
@@ -76,7 +91,7 @@
@Override
public int hashCode() {
- return iterator.hasNext() ? iterator.next().hashCode() : -1789023489;
+ return iterator.hasNext() ? Objects.hashCode(iterator.next()) : -1789023489;
}
@Override
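
The caching pattern here — pull from a one-shot prefetchable iterator once, cache the values, and let every `iterator()` replay the cache — can be summarized in a short Go sketch (illustrative, using Go generics and closures rather than the Java types):

```go
package main

import "fmt"

// cachingIterable turns a one-shot source into a re-iterable sequence by
// caching every value it pulls, mirroring LazyCachingIteratorToIterable.
type cachingIterable[T any] struct {
	cached []T
	next   func() (T, bool) // one-shot source; second result is false when exhausted
}

// iterator returns an independent cursor; values already cached are
// replayed, and new values are pulled from the source exactly once.
func (c *cachingIterable[T]) iterator() func() (T, bool) {
	pos := 0
	return func() (T, bool) {
		if pos < len(c.cached) {
			v := c.cached[pos]
			pos++
			return v, true
		}
		v, ok := c.next()
		if !ok {
			var zero T
			return zero, false
		}
		c.cached = append(c.cached, v)
		pos++
		return v, true
	}
}

func main() {
	src := []string{"A", "B", "C"}
	i := 0
	it := &cachingIterable[string]{next: func() (string, bool) {
		if i >= len(src) {
			return "", false
		}
		v := src[i]
		i++
		return v, true
	}}
	first, second := it.iterator(), it.iterator()
	v1, _ := first()  // pulls "A" from the source and caches it
	v2, _ := second() // replays the cached "A"; the source is not re-read
	fmt.Println(v1, v2)
}
```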
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 22be306..1026ba5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -17,20 +17,21 @@
*/
package org.apache.beam.fn.harness.state;
-import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
/**
* Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -54,7 +55,7 @@
* only) chunk of a state stream. This state request will be populated with a continuation
* token to request further chunks of the stream if required.
*/
- public static Iterator<ByteString> readAllStartingFrom(
+ public static PrefetchableIterator<ByteString> readAllStartingFrom(
BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
return new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk);
}
@@ -74,94 +75,142 @@
* token to request further chunks of the stream if required.
* @param valueCoder A coder for decoding the state stream.
*/
- public static <T> Iterable<T> readAllAndDecodeStartingFrom(
+ public static <T> PrefetchableIterable<T> readAllAndDecodeStartingFrom(
BeamFnStateClient beamFnStateClient,
StateRequest stateRequestForFirstChunk,
Coder<T> valueCoder) {
- FirstPageAndRemainder firstPageAndRemainder =
- new FirstPageAndRemainder(beamFnStateClient, stateRequestForFirstChunk);
- return Iterables.concat(
- new LazyCachingIteratorToIterable<T>(
- new DataStreams.DataStreamDecoder<>(
- valueCoder, new LazySingletonIterator<>(firstPageAndRemainder::firstPage))),
- () -> new DataStreams.DataStreamDecoder<>(valueCoder, firstPageAndRemainder.remainder()));
- }
-
- /** A iterable that contains a single element, provided by a Supplier which is invoked lazily. */
- static class LazySingletonIterator<T> implements Iterator<T> {
-
- private final Supplier<T> supplier;
- private boolean hasNext;
-
- private LazySingletonIterator(Supplier<T> supplier) {
- this.supplier = supplier;
- hasNext = true;
- }
-
- @Override
- public boolean hasNext() {
- return hasNext;
- }
-
- @Override
- public T next() {
- hasNext = false;
- return supplier.get();
- }
+ return new FirstPageAndRemainder<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder);
}
/**
- * An helper class that (lazily) gives the first page of a paginated state request separately from
+ * A helper class that (lazily) gives the first page of a paginated state request separately from
* all the remaining pages.
*/
- static class FirstPageAndRemainder {
+ @VisibleForTesting
+ static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
private final BeamFnStateClient beamFnStateClient;
private final StateRequest stateRequestForFirstChunk;
- private ByteString firstPage = null;
+ private final Coder<T> valueCoder;
+ private LazyCachingIteratorToIterable<T> firstPage;
+ private CompletableFuture<StateResponse> firstPageResponseFuture;
private ByteString continuationToken;
- private FirstPageAndRemainder(
- BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) {
+ FirstPageAndRemainder(
+ BeamFnStateClient beamFnStateClient,
+ StateRequest stateRequestForFirstChunk,
+ Coder<T> valueCoder) {
this.beamFnStateClient = beamFnStateClient;
this.stateRequestForFirstChunk = stateRequestForFirstChunk;
+ this.valueCoder = valueCoder;
}
- public ByteString firstPage() {
- if (firstPage == null) {
- CompletableFuture<StateResponse> stateResponseFuture = new CompletableFuture<>();
+ @Override
+ public PrefetchableIterator<T> iterator() {
+ return new PrefetchableIterator<T>() {
+ PrefetchableIterator<T> delegate;
+
+ private void ensureDelegateExists() {
+ if (delegate == null) {
+ // Fetch the first page if necessary
+ prefetchFirstPage();
+ if (firstPage == null) {
+ StateResponse stateResponse;
+ try {
+ stateResponse = firstPageResponseFuture.get();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IllegalStateException(e);
+ } catch (ExecutionException e) {
+ if (e.getCause() == null) {
+ throw new IllegalStateException(e);
+ }
+ Throwables.throwIfUnchecked(e.getCause());
+ throw new IllegalStateException(e.getCause());
+ }
+ continuationToken = stateResponse.getGet().getContinuationToken();
+ firstPage =
+ new LazyCachingIteratorToIterable<>(
+ new DataStreamDecoder<>(
+ valueCoder,
+ PrefetchableIterators.fromArray(stateResponse.getGet().getData())));
+ }
+
+          if (ByteString.EMPTY.equals(continuationToken)) {
+ delegate = firstPage.iterator();
+ } else {
+ delegate =
+ PrefetchableIterators.concat(
+ firstPage.iterator(),
+ new DataStreamDecoder<>(
+ valueCoder,
+ new LazyBlockingStateFetchingIterator(
+ beamFnStateClient,
+ stateRequestForFirstChunk
+ .toBuilder()
+ .setGet(
+ StateGetRequest.newBuilder()
+ .setContinuationToken(continuationToken))
+ .build())));
+ }
+ }
+ }
+
+ @Override
+ public boolean isReady() {
+ if (delegate == null) {
+ if (firstPageResponseFuture != null) {
+ return firstPageResponseFuture.isDone();
+ }
+ return false;
+ }
+ return delegate.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (firstPageResponseFuture == null) {
+ prefetchFirstPage();
+ } else if (delegate != null && !delegate.isReady()) {
+ delegate.prefetch();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (delegate == null) {
+ // Ensure that we prefetch the second page after the first has been accessed.
+ // Prefetching subsequent pages after the first will be handled by the
+ // LazyBlockingStateFetchingIterator
+ ensureDelegateExists();
+ boolean rval = delegate.hasNext();
+ delegate.prefetch();
+ return rval;
+ }
+ return delegate.hasNext();
+ }
+
+ @Override
+ public T next() {
+ if (delegate == null) {
+ // Ensure that we prefetch the second page after the first has been accessed.
+ // Prefetching subsequent pages after the first will be handled by the
+ // LazyBlockingStateFetchingIterator
+ ensureDelegateExists();
+ T rval = delegate.next();
+ delegate.prefetch();
+ return rval;
+ }
+ return delegate.next();
+ }
+ };
+ }
+
+ private void prefetchFirstPage() {
+ if (firstPageResponseFuture == null) {
+ firstPageResponseFuture = new CompletableFuture<>();
beamFnStateClient.handle(
stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()),
- stateResponseFuture);
- StateResponse stateResponse;
- try {
- stateResponse = stateResponseFuture.get();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new IllegalStateException(e);
- } catch (ExecutionException e) {
- if (e.getCause() == null) {
- throw new IllegalStateException(e);
- }
- Throwables.throwIfUnchecked(e.getCause());
- throw new IllegalStateException(e.getCause());
- }
- continuationToken = stateResponse.getGet().getContinuationToken();
- firstPage = stateResponse.getGet().getData();
- }
- return firstPage;
- }
-
- public Iterator<ByteString> remainder() {
- firstPage();
- if (ByteString.EMPTY.equals(continuationToken)) {
- return Collections.emptyIterator();
- } else {
- return new LazyBlockingStateFetchingIterator(
- beamFnStateClient,
- stateRequestForFirstChunk
- .toBuilder()
- .setGet(StateGetRequest.newBuilder().setContinuationToken(continuationToken))
- .build());
+ firstPageResponseFuture);
}
}
}
@@ -169,10 +218,11 @@
/**
* An {@link Iterator} which fetches {@link ByteString} chunks using the State API.
*
- * <p>This iterator will only request a chunk on first access. Subsiquently it eagerly pre-fetches
- * one future chunks at a time.
+ * <p>This iterator will only request a chunk on first access. Subsequently it eagerly pre-fetches
+ * one future chunk at a time.
*/
- static class LazyBlockingStateFetchingIterator implements Iterator<ByteString> {
+ @VisibleForTesting
+ static class LazyBlockingStateFetchingIterator implements PrefetchableIterator<ByteString> {
private enum State {
READ_REQUIRED,
@@ -195,8 +245,17 @@
this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
}
- private void prefetch() {
- if (prefetchedResponse == null && currentState == State.READ_REQUIRED) {
+ @Override
+ public boolean isReady() {
+ if (prefetchedResponse == null) {
+ return currentState != State.READ_REQUIRED;
+ }
+ return prefetchedResponse.isDone();
+ }
+
+ @Override
+ public void prefetch() {
+ if (currentState == State.READ_REQUIRED && prefetchedResponse == null) {
prefetchedResponse = new CompletableFuture<>();
beamFnStateClient.handle(
stateRequestForFirstChunk
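
The paging protocol these iterators implement is: issue a StateGetRequest, consume the returned data, then keep issuing requests with the returned continuation token until it comes back empty. A self-contained Go sketch of that loop against a toy paged store (the names and store layout are illustrative):

```go
package main

import "fmt"

// page is a toy stand-in for a StateGetResponse: one data chunk plus the
// continuation token for the next chunk ("" means end of stream).
type page struct {
	data              string
	continuationToken string
}

// stateGet simulates the runner side of the State API.
func stateGet(store map[string]page, token string) page {
	return store[token]
}

func main() {
	store := map[string]page{
		"":   {data: "chunk-1", continuationToken: "t1"},
		"t1": {data: "chunk-2", continuationToken: "t2"},
		"t2": {data: "chunk-3", continuationToken: ""},
	}

	// The first request uses an empty token; each subsequent request reuses
	// the token from the previous response, as LazyBlockingStateFetchingIterator
	// does when it rebuilds the StateGetRequest.
	token := ""
	for {
		resp := stateGet(store, token)
		fmt.Println(resp.data)
		token = resp.continuationToken
		if token == "" {
			break
		}
	}
}
```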
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 7597128..0914b01 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -25,8 +25,11 @@
import java.util.Iterator;
import java.util.NoSuchElementException;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
+import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetch;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -40,7 +43,8 @@
@Test
public void testEmptyIterator() {
- Iterable<Object> iterable = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+ Iterable<Object> iterable =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.emptyIterator());
assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
// iterate multiple times
assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class));
@@ -52,7 +56,7 @@
@Test
public void testInterleavedIteration() {
Iterable<String> iterable =
- new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
Iterator<String> iterator1 = iterable.iterator();
assertTrue(iterator1.hasNext());
@@ -77,14 +81,45 @@
@Test
public void testEqualsAndHashCode() {
- Iterable<String> iterA = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
- Iterable<String> iterB = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
- Iterable<String> iterC = new LazyCachingIteratorToIterable<>(Iterators.forArray());
- Iterable<String> iterD = new LazyCachingIteratorToIterable<>(Iterators.forArray());
+ Iterable<String> iterA =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ Iterable<String> iterB =
+ new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ Iterable<String> iterC = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
+ Iterable<String> iterD = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray());
assertEquals(iterA, iterB);
assertEquals(iterC, iterD);
assertNotEquals(iterA, iterC);
assertEquals(iterA.hashCode(), iterB.hashCode());
assertEquals(iterC.hashCode(), iterD.hashCode());
}
+
+ @Test
+ public void testPrefetch() {
+ ReadyAfterPrefetch<String> underlying =
+ new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B", "C"));
+ PrefetchableIterable<String> iterable = new LazyCachingIteratorToIterable<>(underlying);
+ PrefetchableIterator<String> iterator1 = iterable.iterator();
+ PrefetchableIterator<String> iterator2 = iterable.iterator();
+
+ // Check that the lazy iterable doesn't do any prefetch/access on instantiation
+ assertFalse(underlying.isReady());
+ assertFalse(iterator1.isReady());
+ assertFalse(iterator2.isReady());
+
+ // Check that if both iterators prefetch, there is only one prefetch for the underlying
+ // iterator.
+ iterator1.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+ iterator2.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+
+ // Check that once one iterator has advanced, the second doesn't perform any prefetch since
+ // the element is now cached.
+ iterator1.next();
+ iterator1.next();
+ iterator2.next();
+ iterator2.prefetch();
+ assertEquals(1, underlying.getNumPrefetchCalls());
+ }
}
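
ReadyAfterPrefetch is imported from PrefetchableIteratorsTest and is not shown in this diff. Inferred from its usage above, it is roughly a delegating wrapper that counts prefetch() calls and only reports readiness once a prefetch has happened; a hypothetical sketch, which may differ from the real helper:

static class ReadyAfterPrefetch<T> implements PrefetchableIterator<T> {
  private final PrefetchableIterator<T> underlying;
  private int numPrefetchCalls;

  ReadyAfterPrefetch(PrefetchableIterator<T> underlying) {
    this.underlying = underlying;
  }

  int getNumPrefetchCalls() {
    return numPrefetchCalls;
  }

  @Override
  public boolean isReady() {
    // Not ready until someone has explicitly prefetched.
    return numPrefetchCalls > 0 && underlying.isReady();
  }

  @Override
  public void prefetch() {
    numPrefetchCalls++;
    underlying.prefetch();
  }

  @Override
  public boolean hasNext() {
    return underlying.hasNext();
  }

  @Override
  public T next() {
    return underlying.next();
  }
}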
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index fc729cc..384d2df 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -19,12 +19,16 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Iterator;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
+import org.apache.beam.fn.harness.state.StateFetchingIterators.FirstPageAndRemainder;
import org.apache.beam.fn.harness.state.StateFetchingIterators.LazyBlockingStateFetchingIterator;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -32,16 +36,56 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link StateFetchingIterators}. */
+@RunWith(Enclosed.class)
public class StateFetchingIteratorsTest {
+
+ private static BeamFnStateClient fakeStateClient(
+ AtomicInteger callCount, ByteString... expected) {
+ return (requestBuilder, response) -> {
+ callCount.incrementAndGet();
+ if (expected.length == 0) {
+ response.complete(
+ StateResponse.newBuilder()
+ .setId(requestBuilder.getId())
+ .setGet(StateGetResponse.newBuilder())
+ .build());
+ return;
+ }
+
+ ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
+
+ int requestedPosition = 0; // Default position is 0
+ if (!ByteString.EMPTY.equals(continuationToken)) {
+ requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
+ }
+
+ // Compute the new continuation token
+ ByteString newContinuationToken = ByteString.EMPTY;
+ if (requestedPosition != expected.length - 1) {
+ newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
+ }
+ response.complete(
+ StateResponse.newBuilder()
+ .setId(requestBuilder.getId())
+ .setGet(
+ StateGetResponse.newBuilder()
+ .setData(expected[requestedPosition])
+ .setContinuationToken(newContinuationToken))
+ .build());
+ };
+ }
+
/** Tests for {@link StateFetchingIterators.LazyBlockingStateFetchingIterator}. */
@RunWith(JUnit4.class)
public static class LazyBlockingStateFetchingIteratorTest {
@@ -77,49 +121,55 @@
ByteString.EMPTY);
}
- private BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString... expected) {
- return (requestBuilder, response) -> {
- callCount.incrementAndGet();
- if (expected.length == 0) {
- response.complete(
- StateResponse.newBuilder()
- .setId(requestBuilder.getId())
- .setGet(StateGetResponse.newBuilder())
- .build());
- return;
- }
-
- ByteString continuationToken = requestBuilder.getGet().getContinuationToken();
-
- int requestedPosition = 0; // Default position is 0
- if (!ByteString.EMPTY.equals(continuationToken)) {
- requestedPosition = Integer.parseInt(continuationToken.toStringUtf8());
- }
-
- // Compute the new continuation token
- ByteString newContinuationToken = ByteString.EMPTY;
- if (requestedPosition != expected.length - 1) {
- newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1));
- }
- response.complete(
- StateResponse.newBuilder()
- .setId(requestBuilder.getId())
- .setGet(
- StateGetResponse.newBuilder()
- .setData(expected[requestedPosition])
- .setContinuationToken(newContinuationToken))
- .build());
- };
+ @Test
+ public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
+ AtomicInteger callCount = new AtomicInteger();
+ BeamFnStateClient fakeStateClient =
+ new BeamFnStateClient() {
+ @Override
+ public void handle(
+ StateRequest.Builder requestBuilder, CompletableFuture<StateResponse> response) {
+ callCount.incrementAndGet();
+ }
+ };
+ PrefetchableIterator<ByteString> byteStrings =
+ new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
+ assertEquals(0, callCount.get());
+ byteStrings.prefetch();
+ assertEquals(1, callCount.get()); // first prefetch
+ byteStrings.prefetch();
+ assertEquals(1, callCount.get()); // subsequent is ignored
}
private void testFetch(ByteString... expected) {
AtomicInteger callCount = new AtomicInteger();
BeamFnStateClient fakeStateClient = fakeStateClient(callCount, expected);
- Iterator<ByteString> byteStrings =
+ PrefetchableIterator<ByteString> byteStrings =
new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance());
assertEquals(0, callCount.get()); // Ensure it's fully lazy.
- assertArrayEquals(expected, Iterators.toArray(byteStrings, Object.class));
+ assertFalse(byteStrings.isReady());
+
+ // Prefetch every second element in the iterator, capturing the results
+ List<ByteString> results = new ArrayList<>();
+ for (int i = 0; i < expected.length; ++i) {
+ if (i % 2 == 0) {
+ // Ensure that prefetch performs the call
+ byteStrings.prefetch();
+ assertEquals(i + 1, callCount.get());
+ assertTrue(byteStrings.isReady());
+ }
+ assertTrue(byteStrings.hasNext());
+ results.add(byteStrings.next());
+ }
+ assertFalse(byteStrings.hasNext());
+ assertTrue(byteStrings.isReady());
+
+ assertEquals(Arrays.asList(expected), results);
}
+ }
+
+ @RunWith(JUnit4.class)
+ public static class FirstPageAndRemainderTest {
@Test
public void testEmptyValues() throws Exception {
@@ -133,7 +183,7 @@
@Test
public void testManyValues() throws Exception {
- testFetchValues(VarIntCoder.of(), 11, 37, 389, 5077);
+ testFetchValues(VarIntCoder.of(), 1, 22, 333, 4444, 55555, 666666);
}
private <T> void testFetchValues(Coder<T> coder, T... expected) {
@@ -153,35 +203,42 @@
AtomicInteger callCount = new AtomicInteger();
BeamFnStateClient fakeStateClient =
fakeStateClient(callCount, Iterables.toArray(byteStrings, ByteString.class));
- Iterable<T> values =
- StateFetchingIterators.readAllAndDecodeStartingFrom(
- fakeStateClient, StateRequest.getDefaultInstance(), coder);
+ PrefetchableIterable<T> values =
+ new FirstPageAndRemainder<>(fakeStateClient, StateRequest.getDefaultInstance(), coder);
// Ensure it's fully lazy.
assertEquals(0, callCount.get());
- Iterator<T> valuesIter = values.iterator();
+ PrefetchableIterator<T> valuesIter = values.iterator();
+ assertFalse(valuesIter.isReady());
assertEquals(0, callCount.get());
- // No more is read than necessary.
- if (valuesIter.hasNext()) {
- valuesIter.next();
- }
+ // Ensure that the first page result is cached across multiple iterators, so subsequent
+ // iterators are ready and their prefetch does nothing
+ valuesIter.prefetch();
+ assertTrue(valuesIter.isReady());
assertEquals(1, callCount.get());
- // The first page is cached.
- Iterator<T> valuesIter2 = values.iterator();
- assertEquals(1, callCount.get());
- if (valuesIter2.hasNext()) {
- valuesIter2.next();
- }
+ PrefetchableIterator<T> valuesIter2 = values.iterator();
+ assertTrue(valuesIter2.isReady());
+ valuesIter2.prefetch();
assertEquals(1, callCount.get());
- if (valuesIter.hasNext()) {
- valuesIter.next();
- // Subsequent pages are pre-fetched, so after accessing the second page,
- // the third should be requested.
- assertEquals(3, callCount.get());
+ // Prefetch every second element in the iterator, capturing the results
+ List<T> results = new ArrayList<>();
+ for (int i = 0; i < expected.length; ++i) {
+ if (i % 2 == 1) {
+ // Ensure that prefetch performs the call
+ valuesIter2.prefetch();
+ assertTrue(valuesIter2.isReady());
+ // Note that this is i+2 because we expect to prefetch the page after the current one.
+ // We also have to bound it by the maximum number of pages.
+ assertEquals(Math.min(i + 2, expected.length), callCount.get());
+ }
+ assertTrue(valuesIter2.hasNext());
+ results.add(valuesIter2.next());
}
+ assertFalse(valuesIter2.hasNext());
+ assertTrue(valuesIter2.isReady());
// The contents agree.
assertArrayEquals(expected, Iterables.toArray(values, Object.class));
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java
index 5c2f1f5..69d5d19 100644
--- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java
@@ -18,9 +18,11 @@
package org.apache.beam.sdk.io.aws.options;
import com.amazonaws.ClientConfiguration;
+import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
+import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper;
@@ -66,6 +68,7 @@
private static final String AWS_ACCESS_KEY_ID = "awsAccessKeyId";
private static final String AWS_SECRET_KEY = "awsSecretKey";
+ private static final String SESSION_TOKEN = "sessionToken";
private static final String CREDENTIALS_FILE_PATH = "credentialsFilePath";
public static final String CLIENT_EXECUTION_TIMEOUT = "clientExecutionTimeout";
public static final String CONNECTION_MAX_IDLE_TIME = "connectionMaxIdleTime";
@@ -119,8 +122,17 @@
}
if (typeName.equals(AWSStaticCredentialsProvider.class.getSimpleName())) {
- return new AWSStaticCredentialsProvider(
- new BasicAWSCredentials(asMap.get(AWS_ACCESS_KEY_ID), asMap.get(AWS_SECRET_KEY)));
+ boolean isSession = asMap.containsKey(SESSION_TOKEN);
+ if (isSession) {
+ return new AWSStaticCredentialsProvider(
+ new BasicSessionCredentials(
+ asMap.get(AWS_ACCESS_KEY_ID),
+ asMap.get(AWS_SECRET_KEY),
+ asMap.get(SESSION_TOKEN)));
+ } else {
+ return new AWSStaticCredentialsProvider(
+ new BasicAWSCredentials(asMap.get(AWS_ACCESS_KEY_ID), asMap.get(AWS_SECRET_KEY)));
+ }
} else if (typeName.equals(PropertiesFileCredentialsProvider.class.getSimpleName())) {
return new PropertiesFileCredentialsProvider(asMap.get(CREDENTIALS_FILE_PATH));
} else if (typeName.equals(
@@ -179,11 +191,16 @@
typeSerializer.writeTypePrefixForObject(credentialsProvider, jsonGenerator);
if (credentialsProvider.getClass().equals(AWSStaticCredentialsProvider.class)) {
- jsonGenerator.writeStringField(
- AWS_ACCESS_KEY_ID, credentialsProvider.getCredentials().getAWSAccessKeyId());
- jsonGenerator.writeStringField(
- AWS_SECRET_KEY, credentialsProvider.getCredentials().getAWSSecretKey());
-
+ AWSCredentials credentials = credentialsProvider.getCredentials();
+ if (credentials.getClass().equals(BasicSessionCredentials.class)) {
+ BasicSessionCredentials sessionCredentials = (BasicSessionCredentials) credentials;
+ jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, sessionCredentials.getAWSAccessKeyId());
+ jsonGenerator.writeStringField(AWS_SECRET_KEY, sessionCredentials.getAWSSecretKey());
+ jsonGenerator.writeStringField(SESSION_TOKEN, sessionCredentials.getSessionToken());
+ } else {
+ jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, credentials.getAWSAccessKeyId());
+ jsonGenerator.writeStringField(AWS_SECRET_KEY, credentials.getAWSSecretKey());
+ }
} else if (credentialsProvider.getClass().equals(PropertiesFileCredentialsProvider.class)) {
try {
PropertiesFileCredentialsProvider specificProvider =
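
With this change a session-token provider survives the Jackson round trip, so temporary STS credentials can be passed through pipeline options. A sketch, assuming the --awsCredentialsProvider option from this module's AwsOptions; the key material is placeholder:

import org.apache.beam.sdk.io.aws.options.AwsOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

class SessionCredentialsOptionsExample {
  static AwsOptions parse() {
    // With "sessionToken" present, deserialization now yields
    // BasicSessionCredentials; without it, BasicAWSCredentials as before.
    return PipelineOptionsFactory.fromArgs(
            "--awsCredentialsProvider={\"@type\":\"AWSStaticCredentialsProvider\","
                + "\"awsAccessKeyId\":\"ACCESS_KEY_ID\",\"awsSecretKey\":\"SECRET_KEY\","
                + "\"sessionToken\":\"SESSION_TOKEN\"}")
        .as(AwsOptions.class);
  }
}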
diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java
index 9651803..0e318c9 100644
--- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java
+++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java
@@ -25,6 +25,7 @@
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
+import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper;
@@ -78,6 +79,20 @@
assertEquals(
credentialsProvider.getCredentials().getAWSSecretKey(),
deserializedCredentialsProvider.getCredentials().getAWSSecretKey());
+
+ String sessionToken = "session-token";
+ BasicSessionCredentials sessionCredentials =
+ new BasicSessionCredentials(awsKeyId, awsSecretKey, sessionToken);
+ credentialsProvider = new AWSStaticCredentialsProvider(sessionCredentials);
+ serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider);
+ deserializedCredentialsProvider =
+ objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class);
+ BasicSessionCredentials deserializedCredentials =
+ (BasicSessionCredentials) deserializedCredentialsProvider.getCredentials();
+ assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass());
+ assertEquals(deserializedCredentials.getAWSAccessKeyId(), awsKeyId);
+ assertEquals(deserializedCredentials.getAWSSecretKey(), awsSecretKey);
+ assertEquals(deserializedCredentials.getSessionToken(), sessionToken);
}
@Test
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/options/AwsModule.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/options/AwsModule.java
index 5051438..7453d76 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/options/AwsModule.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/options/AwsModule.java
@@ -41,7 +41,9 @@
import org.apache.beam.sdk.io.aws2.s3.SSECustomerKey;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
@@ -64,6 +66,7 @@
public class AwsModule extends SimpleModule {
private static final String ACCESS_KEY_ID = "accessKeyId";
private static final String SECRET_ACCESS_KEY = "secretAccessKey";
+ private static final String SESSION_TOKEN = "sessionToken";
public static final String CONNECTION_ACQUIRE_TIMEOUT = "connectionAcquisitionTimeout";
public static final String CONNECTION_MAX_IDLE_TIMEOUT = "connectionMaxIdleTime";
public static final String CONNECTION_TIMEOUT = "connectionTimeout";
@@ -107,10 +110,18 @@
throw new IOException(
String.format("AWS credentials provider type name key '%s' not found", typeNameKey));
}
-
if (typeName.equals(StaticCredentialsProvider.class.getSimpleName())) {
- return StaticCredentialsProvider.create(
- AwsBasicCredentials.create(asMap.get(ACCESS_KEY_ID), asMap.get(SECRET_ACCESS_KEY)));
+ boolean isSession = asMap.containsKey(SESSION_TOKEN);
+ if (isSession) {
+ return StaticCredentialsProvider.create(
+ AwsSessionCredentials.create(
+ asMap.get(ACCESS_KEY_ID),
+ asMap.get(SECRET_ACCESS_KEY),
+ asMap.get(SESSION_TOKEN)));
+ } else {
+ return StaticCredentialsProvider.create(
+ AwsBasicCredentials.create(asMap.get(ACCESS_KEY_ID), asMap.get(SECRET_ACCESS_KEY)));
+ }
} else if (typeName.equals(DefaultCredentialsProvider.class.getSimpleName())) {
return DefaultCredentialsProvider.create();
} else if (typeName.equals(EnvironmentVariableCredentialsProvider.class.getSimpleName())) {
@@ -158,10 +169,16 @@
// BEAM-11958 Use deprecated Jackson APIs to be compatible with older versions of jackson
typeSerializer.writeTypePrefixForObject(credentialsProvider, jsonGenerator);
if (credentialsProvider.getClass().equals(StaticCredentialsProvider.class)) {
- jsonGenerator.writeStringField(
- ACCESS_KEY_ID, credentialsProvider.resolveCredentials().accessKeyId());
- jsonGenerator.writeStringField(
- SECRET_ACCESS_KEY, credentialsProvider.resolveCredentials().secretAccessKey());
+ AwsCredentials credentials = credentialsProvider.resolveCredentials();
+ if (credentials.getClass().equals(AwsSessionCredentials.class)) {
+ AwsSessionCredentials sessionCredentials = (AwsSessionCredentials) credentials;
+ jsonGenerator.writeStringField(ACCESS_KEY_ID, sessionCredentials.accessKeyId());
+ jsonGenerator.writeStringField(SECRET_ACCESS_KEY, sessionCredentials.secretAccessKey());
+ jsonGenerator.writeStringField(SESSION_TOKEN, sessionCredentials.sessionToken());
+ } else {
+ jsonGenerator.writeStringField(ACCESS_KEY_ID, credentials.accessKeyId());
+ jsonGenerator.writeStringField(SECRET_ACCESS_KEY, credentials.secretAccessKey());
+ }
} else if (!SINGLETON_CREDENTIAL_PROVIDERS.contains(credentialsProvider.getClass())) {
throw new IllegalArgumentException(
"Unsupported AWS credentials provider type " + credentialsProvider.getClass());
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/options/AwsModuleTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/options/AwsModuleTest.java
index 2f0e038..9414e24 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/options/AwsModuleTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/options/AwsModuleTest.java
@@ -33,6 +33,7 @@
import org.junit.runners.JUnit4;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
@@ -68,6 +69,20 @@
assertEquals(
credentialsProvider.resolveCredentials().secretAccessKey(),
deserializedCredentialsProvider.resolveCredentials().secretAccessKey());
+
+ AwsSessionCredentials sessionCredentials =
+ AwsSessionCredentials.create("key-id", "secret-key", "session-token");
+ credentialsProvider = StaticCredentialsProvider.create(sessionCredentials);
+ serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider);
+ deserializedCredentialsProvider =
+ objectMapper.readValue(serializedCredentialsProvider, AwsCredentialsProvider.class);
+
+ assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass());
+ AwsSessionCredentials deserializedCredentials =
+ (AwsSessionCredentials) deserializedCredentialsProvider.resolveCredentials();
+ assertEquals(sessionCredentials.accessKeyId(), deserializedCredentials.accessKeyId());
+ assertEquals(sessionCredentials.secretAccessKey(), deserializedCredentials.secretAccessKey());
+ assertEquals(sessionCredentials.sessionToken(), deserializedCredentials.sessionToken());
}
@Test
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 215a66b..731cf33 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -91,6 +91,7 @@
compile library.java.grpc_netty
compile library.java.grpc_netty_shaded
permitUnusedDeclared library.java.grpc_netty_shaded // BEAM-11761
+ compile library.java.grpc_protobuf
compile library.java.grpc_stub
permitUnusedDeclared library.java.grpc_stub // BEAM-11761
compile library.java.grpc_google_cloud_pubsub_v1
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
index afa5ae7..a636b04 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
@@ -83,6 +83,11 @@
import com.google.cloud.hadoop.util.ChainingHttpRequestInitializer;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Int64Value;
+import com.google.rpc.RetryInfo;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.Status.Code;
+import io.grpc.protobuf.ProtoUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
@@ -162,6 +167,9 @@
MonitoringInfoConstants.Labels.SERVICE, "BigQuery",
MonitoringInfoConstants.Labels.METHOD, "BigQueryBatchWrite");
+ private static final Metadata.Key<RetryInfo> KEY_RETRY_INFO =
+ ProtoUtils.keyForProto(RetryInfo.getDefaultInstance());
+
@Override
public JobService getJobService(BigQueryOptions options) {
return new JobServiceImpl(options);
@@ -1304,6 +1312,32 @@
static class StorageClientImpl implements StorageClient {
+ // If the client retries a ReadRows request due to a RESOURCE_EXHAUSTED error, bump
+ // throttlingMsecs according to the server-suggested delay. The runtime can use this
+ // information for autoscaling decisions.
+ @VisibleForTesting
+ public static class RetryAttemptCounter implements BigQueryReadSettings.RetryAttemptListener {
+ public final Counter throttlingMsecs =
+ Metrics.counter(StorageClientImpl.class, "throttling-msecs");
+
+ @SuppressWarnings("ProtoDurationGetSecondsGetNano")
+ @Override
+ public void onRetryAttempt(Status status, Metadata metadata) {
+ if (status != null
+ && status.getCode() == Code.RESOURCE_EXHAUSTED
+ && metadata != null
+ && metadata.containsKey(KEY_RETRY_INFO)) {
+ RetryInfo retryInfo = metadata.get(KEY_RETRY_INFO);
+ if (retryInfo.hasRetryDelay()) {
+ long delay =
+ retryInfo.getRetryDelay().getSeconds() * 1000
+ + retryInfo.getRetryDelay().getNanos() / 1000000;
+ throttlingMsecs.inc(delay);
+ }
+ }
+ }
+ }
+
private static final HeaderProvider USER_AGENT_HEADER_PROVIDER =
FixedHeaderProvider.create(
"user-agent", "Apache_Beam_Java/" + ReleaseInfo.getReleaseInfo().getVersion());
@@ -1317,7 +1351,8 @@
.setTransportChannelProvider(
BigQueryReadSettings.defaultGrpcTransportProviderBuilder()
.setHeaderProvider(USER_AGENT_HEADER_PROVIDER)
- .build());
+ .build())
+ .setReadRowsRetryAttemptListener(new RetryAttemptCounter());
UnaryCallSettings.Builder<CreateReadSessionRequest, ReadSession> createReadSessionSettings =
settingsBuilder.getStubSettingsBuilder().createReadSessionSettings();
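
The listener converts the RetryInfo delay to milliseconds as seconds * 1000 + nanos / 1_000_000 before bumping throttlingMsecs. A worked sketch with illustrative values:

import com.google.protobuf.Duration;
import com.google.rpc.RetryInfo;

// Illustrative: a server-suggested delay of 2.5s becomes 2500ms.
RetryInfo retryInfo =
    RetryInfo.newBuilder()
        .setRetryDelay(Duration.newBuilder().setSeconds(2).setNanos(500_000_000))
        .build();
long delayMs =
    retryInfo.getRetryDelay().getSeconds() * 1000
        + retryInfo.getRetryDelay().getNanos() / 1_000_000; // == 2500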
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
index dd5202e..74d6636 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
@@ -1218,6 +1218,10 @@
List<Cursor> cursors = new ArrayList<>(partitionQueryResponse.getPartitionsList());
cursors.sort(CURSOR_REFERENCE_VALUE_COMPARATOR);
final int size = cursors.size();
+ if (size == 0) {
+ emit(c, dbRoot, structuredQuery.toBuilder());
+ return;
+ }
final int lastIdx = size - 1;
for (int i = 0; i < size; i++) {
Cursor curr = cursors.get(i);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/healthcare/HL7v2Message.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/healthcare/HL7v2Message.java
index 1a2a0b0..b9a72f3 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/healthcare/HL7v2Message.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/healthcare/HL7v2Message.java
@@ -80,7 +80,9 @@
out.setCreateTime(this.getCreateTime());
out.setData(this.getData());
out.setSendFacility(this.getSendFacility());
- out.setSchematizedData(new SchematizedData().setData(this.schematizedData));
+ if (this.schematizedData != null) {
+ out.setSchematizedData(new SchematizedData().setData(this.schematizedData));
+ }
out.setLabels(this.labels);
return out;
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
index b0cc681..bf6a288 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
@@ -24,25 +24,36 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
/** Common util functions for converting between PubsubMessage proto and {@link PubsubMessage}. */
-public class PubsubMessages {
+public final class PubsubMessages {
+ private PubsubMessages() {}
+
+ public static com.google.pubsub.v1.PubsubMessage toProto(PubsubMessage input) {
+ Map<String, String> attributes = input.getAttributeMap();
+ com.google.pubsub.v1.PubsubMessage.Builder message =
+ com.google.pubsub.v1.PubsubMessage.newBuilder()
+ .setData(ByteString.copyFrom(input.getPayload()));
+ // TODO(BEAM-8085) this should not be null
+ if (attributes != null) {
+ message.putAllAttributes(attributes);
+ }
+ String messageId = input.getMessageId();
+ if (messageId != null) {
+ message.setMessageId(messageId);
+ }
+ return message.build();
+ }
+
+ public static PubsubMessage fromProto(com.google.pubsub.v1.PubsubMessage input) {
+ return new PubsubMessage(
+ input.getData().toByteArray(), input.getAttributesMap(), input.getMessageId());
+ }
+
// Convert the PubsubMessage to a PubsubMessage proto, then return its serialized representation.
public static class ParsePayloadAsPubsubMessageProto
implements SerializableFunction<PubsubMessage, byte[]> {
@Override
public byte[] apply(PubsubMessage input) {
- Map<String, String> attributes = input.getAttributeMap();
- com.google.pubsub.v1.PubsubMessage.Builder message =
- com.google.pubsub.v1.PubsubMessage.newBuilder()
- .setData(ByteString.copyFrom(input.getPayload()));
- // TODO(BEAM-8085) this should not be null
- if (attributes != null) {
- message.putAllAttributes(attributes);
- }
- String messageId = input.getMessageId();
- if (messageId != null) {
- message.setMessageId(messageId);
- }
- return message.build().toByteArray();
+ return toProto(input).toByteArray();
}
}
@@ -54,8 +65,7 @@
try {
com.google.pubsub.v1.PubsubMessage message =
com.google.pubsub.v1.PubsubMessage.parseFrom(input);
- return new PubsubMessage(
- message.getData().toByteArray(), message.getAttributesMap(), message.getMessageId());
+ return fromProto(message);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("Could not decode Pubsub message", e);
}
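
Extracting toProto/fromProto makes the conversion reusable outside the two SerializableFunction wrappers. A minimal round-trip sketch using only the methods introduced above:

import java.nio.charset.StandardCharsets;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;

PubsubMessage original =
    new PubsubMessage(
        "payload".getBytes(StandardCharsets.UTF_8), ImmutableMap.of("key", "value"));
com.google.pubsub.v1.PubsubMessage proto = PubsubMessages.toProto(original);
PubsubMessage roundTripped = PubsubMessages.fromProto(proto);
// roundTripped has the same payload and attributes; an unset message id
// round-trips as the empty string.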
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java
deleted file mode 100644
index 6dc1516..0000000
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubChecks.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.gcp.pubsublite;
-
-import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsPublishTransformer;
-
-import com.google.cloud.pubsublite.Message;
-import com.google.cloud.pubsublite.proto.PubSubMessage;
-import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.TypeDescriptor;
-
-/**
- * A class providing a conversion validity check between Cloud Pub/Sub and Pub/Sub Lite message
- * types.
- */
-public final class CloudPubsubChecks {
- private CloudPubsubChecks() {}
-
- /**
- * Ensure that all messages that pass through can be converted to Cloud Pub/Sub messages using the
- * standard transformation methods in the client library.
- *
- * <p>Will fail the pipeline if a message has multiple attributes per key.
- */
- public static PTransform<PCollection<? extends PubSubMessage>, PCollection<PubSubMessage>>
- ensureUsableAsCloudPubsub() {
- return MapElements.into(TypeDescriptor.of(PubSubMessage.class))
- .via(
- message -> {
- Object unused = toCpsPublishTransformer().transform(Message.fromProto(message));
- return message;
- });
- }
-}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java
new file mode 100644
index 0000000..1140c11
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/CloudPubsubTransforms.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.fromCpsPublishTransformer;
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsPublishTransformer;
+import static com.google.cloud.pubsublite.cloudpubsub.MessageTransforms.toCpsSubscribeTransformer;
+
+import com.google.cloud.pubsublite.Message;
+import com.google.cloud.pubsublite.cloudpubsub.KeyExtractor;
+import com.google.cloud.pubsublite.proto.PubSubMessage;
+import com.google.cloud.pubsublite.proto.SequencedMessage;
+import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage;
+import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+
+/** A class providing transforms between Cloud Pub/Sub and Pub/Sub Lite message types. */
+public final class CloudPubsubTransforms {
+ private CloudPubsubTransforms() {}
+ /**
+ * Ensure that all messages that pass through can be converted to Cloud Pub/Sub messages using the
+ * standard transformation methods in the client library.
+ *
+ * <p>Will fail the pipeline if a message has multiple attributes per key.
+ */
+ public static PTransform<PCollection<PubSubMessage>, PCollection<PubSubMessage>>
+ ensureUsableAsCloudPubsub() {
+ return new PTransform<PCollection<PubSubMessage>, PCollection<PubSubMessage>>() {
+ @Override
+ public PCollection<PubSubMessage> expand(PCollection<PubSubMessage> input) {
+ return input.apply(
+ MapElements.into(TypeDescriptor.of(PubSubMessage.class))
+ .via(
+ message -> {
+ Object unused =
+ toCpsPublishTransformer().transform(Message.fromProto(message));
+ return message;
+ }));
+ }
+ };
+ }
+
+ /**
+ * Transform messages read from Pub/Sub Lite into the equivalent Cloud Pub/Sub messages that
+ * would have been read from PubsubIO.
+ *
+ * <p>Will fail the pipeline if a message has multiple attributes per map key.
+ */
+ public static PTransform<PCollection<SequencedMessage>, PCollection<PubsubMessage>>
+ toCloudPubsubMessages() {
+ return new PTransform<PCollection<SequencedMessage>, PCollection<PubsubMessage>>() {
+ @Override
+ public PCollection<PubsubMessage> expand(PCollection<SequencedMessage> input) {
+ return input.apply(
+ MapElements.into(TypeDescriptor.of(PubsubMessage.class))
+ .via(
+ message ->
+ PubsubMessages.fromProto(
+ toCpsSubscribeTransformer()
+ .transform(
+ com.google.cloud.pubsublite.SequencedMessage.fromProto(
+ message)))));
+ }
+ };
+ }
+
+ /**
+ * Transform messages that could be published with PubsubIO into the equivalent Pub/Sub Lite
+ * messages for publishing.
+ */
+ public static PTransform<PCollection<PubsubMessage>, PCollection<PubSubMessage>>
+ fromCloudPubsubMessages() {
+ return new PTransform<PCollection<PubsubMessage>, PCollection<PubSubMessage>>() {
+ @Override
+ public PCollection<PubSubMessage> expand(PCollection<PubsubMessage> input) {
+ return input.apply(
+ MapElements.into(TypeDescriptor.of(PubSubMessage.class))
+ .via(
+ message ->
+ fromCpsPublishTransformer(KeyExtractor.DEFAULT)
+ .transform(PubsubMessages.toProto(message))
+ .toProto()));
+ }
+ };
+ }
+}
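
A sketch of wiring the new transforms into a pipeline; PubsubLiteIO.read and its subscriber options are assumed from the surrounding package and are not part of this diff:

import com.google.cloud.pubsublite.proto.SequencedMessage;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage;
import org.apache.beam.sdk.io.gcp.pubsublite.CloudPubsubTransforms;
import org.apache.beam.sdk.values.PCollection;

// Assumed: `pipeline` is a Pipeline and `subscriberOptions` configures the
// Pub/Sub Lite subscription; PubsubLiteIO.read(...) is from this package.
PCollection<SequencedMessage> liteMessages =
    pipeline.apply(PubsubLiteIO.read(subscriberOptions));
PCollection<PubsubMessage> cpsMessages =
    liteMessages.apply(CloudPubsubTransforms.toCloudPubsubMessages());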
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java
new file mode 100644
index 0000000..de0cf43
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactory.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import java.io.Serializable;
+
+/**
+ * A ManagedBacklogReaderFactory produces TopicBacklogReaders and tears down any produced readers
+ * when it is itself closed.
+ *
+ * <p>close() should never be called on produced readers.
+ */
+public interface ManagedBacklogReaderFactory extends AutoCloseable, Serializable {
+ TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition);
+
+ @Override
+ void close();
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java
new file mode 100644
index 0000000..9a337bf
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/ManagedBacklogReaderFactoryImpl.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import com.google.api.gax.rpc.ApiException;
+import com.google.cloud.pubsublite.Offset;
+import com.google.cloud.pubsublite.proto.ComputeMessageStatsResponse;
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+public class ManagedBacklogReaderFactoryImpl implements ManagedBacklogReaderFactory {
+ private final SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader;
+
+ @GuardedBy("this")
+ private final Map<SubscriptionPartition, TopicBacklogReader> readers = new HashMap<>();
+
+ ManagedBacklogReaderFactoryImpl(
+ SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader) {
+ this.newReader = newReader;
+ }
+
+ private static final class NonCloseableTopicBacklogReader implements TopicBacklogReader {
+ private final TopicBacklogReader underlying;
+
+ NonCloseableTopicBacklogReader(TopicBacklogReader underlying) {
+ this.underlying = underlying;
+ }
+
+ @Override
+ public ComputeMessageStatsResponse computeMessageStats(Offset offset) throws ApiException {
+ return underlying.computeMessageStats(offset);
+ }
+
+ @Override
+ public void close() {
+ throw new IllegalArgumentException(
+ "Cannot call close() on a reader returned from ManagedBacklogReaderFactory.");
+ }
+ }
+
+ @Override
+ public synchronized TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
+ return new NonCloseableTopicBacklogReader(
+ readers.computeIfAbsent(subscriptionPartition, newReader::apply));
+ }
+
+ @Override
+ public synchronized void close() {
+ readers.values().forEach(TopicBacklogReader::close);
+ }
+}
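
The factory caches one reader per SubscriptionPartition and hands out non-closeable views, so each underlying reader is shut down exactly once in close(). A usage sketch; createReader(...) is a hypothetical stand-in for whatever reader construction is wired in:

// Hypothetical usage; `subscriptionPartition` and createReader(...) are
// assumed from the caller's context.
ManagedBacklogReaderFactory factory =
    new ManagedBacklogReaderFactoryImpl(partition -> createReader(partition));
TopicBacklogReader reader = factory.newReader(subscriptionPartition);
ComputeMessageStatsResponse stats = reader.computeMessageStats(Offset.of(0));
// Calling reader.close() would throw; the factory owns the lifecycle:
factory.close();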
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java
new file mode 100644
index 0000000..b39d87e
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRange.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.DefaultCoder;
+import org.apache.beam.sdk.io.range.OffsetRange;
+
+@AutoValue
+@DefaultCoder(OffsetByteRangeCoder.class)
+abstract class OffsetByteRange {
+ abstract OffsetRange getRange();
+
+ abstract long getByteCount();
+
+ static OffsetByteRange of(OffsetRange range, long byteCount) {
+ return new AutoValue_OffsetByteRange(range, byteCount);
+ }
+
+ static OffsetByteRange of(OffsetRange range) {
+ return of(range, 0);
+ }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java
new file mode 100644
index 0000000..076cda1
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeCoder.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderProvider;
+import org.apache.beam.sdk.coders.CoderProviders;
+import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TypeDescriptor;
+
+public class OffsetByteRangeCoder extends AtomicCoder<OffsetByteRange> {
+ private static final Coder<OffsetByteRange> CODER =
+ DelegateCoder.of(
+ KvCoder.of(OffsetRange.Coder.of(), VarLongCoder.of()),
+ OffsetByteRangeCoder::toKv,
+ OffsetByteRangeCoder::fromKv);
+
+ private static KV<OffsetRange, Long> toKv(OffsetByteRange value) {
+ return KV.of(value.getRange(), value.getByteCount());
+ }
+
+ private static OffsetByteRange fromKv(KV<OffsetRange, Long> kv) {
+ return OffsetByteRange.of(kv.getKey(), kv.getValue());
+ }
+
+ @Override
+ public void encode(OffsetByteRange value, OutputStream outStream) throws IOException {
+ CODER.encode(value, outStream);
+ }
+
+ @Override
+ public OffsetByteRange decode(InputStream inStream) throws IOException {
+ return CODER.decode(inStream);
+ }
+
+ public static CoderProvider getCoderProvider() {
+ return CoderProviders.forCoder(
+ TypeDescriptor.of(OffsetByteRange.class), new OffsetByteRangeCoder());
+ }
+}
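
A round-trip sketch for the coder, reusing CoderUtils as the tests in this module already do; CoderException propagation is kept explicit:

import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.util.CoderUtils;

class OffsetByteRangeCoderExample {
  static OffsetByteRange roundTrip() throws CoderException {
    OffsetByteRange range = OffsetByteRange.of(new OffsetRange(5, 100), 1024);
    byte[] encoded = CoderUtils.encodeToByteArray(new OffsetByteRangeCoder(), range);
    // The decoded value covers [5, 100) with a byte count of 1024.
    return CoderUtils.decodeFromByteArray(new OffsetByteRangeCoder(), encoded);
  }
}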
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
index 608af8f..da9aaaa 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTracker.java
@@ -26,8 +26,6 @@
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.range.OffsetRange;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Stopwatch;
import org.joda.time.Duration;
@@ -44,24 +42,27 @@
* received. IMPORTANT: minTrackingTime must be strictly smaller than the SDF read timeout when it
* would return ProcessContinuation.resume().
*/
-class OffsetByteRangeTracker extends RestrictionTracker<OffsetRange, OffsetByteProgress>
- implements HasProgress {
- private final TopicBacklogReader backlogReader;
+class OffsetByteRangeTracker extends TrackerWithProgress {
+ private final TopicBacklogReader unownedBacklogReader;
private final Duration minTrackingTime;
private final long minBytesReceived;
private final Stopwatch stopwatch;
- private OffsetRange range;
+ private OffsetByteRange range;
private @Nullable Long lastClaimed;
- private long byteCount = 0;
public OffsetByteRangeTracker(
- OffsetRange range,
- TopicBacklogReader backlogReader,
+ OffsetByteRange range,
+ TopicBacklogReader unownedBacklogReader,
Stopwatch stopwatch,
Duration minTrackingTime,
long minBytesReceived) {
- checkArgument(range.getTo() == Long.MAX_VALUE);
- this.backlogReader = backlogReader;
+ checkArgument(
+ range.getRange().getTo() == Long.MAX_VALUE,
+ "May only construct OffsetByteRangeTracker with an unbounded range.");
+ checkArgument(
+ range.getByteCount() == 0L,
+ "May only construct OffsetByteRangeTracker with a range that has made no progress.");
+ this.unownedBacklogReader = unownedBacklogReader;
this.minTrackingTime = minTrackingTime;
this.minBytesReceived = minBytesReceived;
this.stopwatch = stopwatch.reset().start();
@@ -69,11 +70,6 @@
}
@Override
- public void finalize() {
- this.backlogReader.close();
- }
-
- @Override
public IsBounded isBounded() {
return IsBounded.UNBOUNDED;
}
@@ -87,32 +83,32 @@
position.lastOffset().value(),
lastClaimed);
checkArgument(
- toClaim >= range.getFrom(),
+ toClaim >= range.getRange().getFrom(),
"Trying to claim offset %s before start of the range %s",
toClaim,
range);
// split() has already been called, truncating this range. No more offsets may be claimed.
- if (range.getTo() != Long.MAX_VALUE) {
- boolean isRangeEmpty = range.getTo() == range.getFrom();
- boolean isValidClosedRange = nextOffset() == range.getTo();
+ if (range.getRange().getTo() != Long.MAX_VALUE) {
+ boolean isRangeEmpty = range.getRange().getTo() == range.getRange().getFrom();
+ boolean isValidClosedRange = nextOffset() == range.getRange().getTo();
checkState(
isRangeEmpty || isValidClosedRange,
"Violated class precondition: offset range improperly split. Please report a beam bug.");
return false;
}
lastClaimed = toClaim;
- byteCount += position.batchBytes();
+ range = OffsetByteRange.of(range.getRange(), range.getByteCount() + position.batchBytes());
return true;
}
@Override
- public OffsetRange currentRestriction() {
+ public OffsetByteRange currentRestriction() {
return range;
}
private long nextOffset() {
checkState(lastClaimed == null || lastClaimed < Long.MAX_VALUE);
- return lastClaimed == null ? currentRestriction().getFrom() : lastClaimed + 1;
+ return lastClaimed == null ? currentRestriction().getRange().getFrom() : lastClaimed + 1;
}
/**
@@ -124,29 +120,33 @@
if (duration.isLongerThan(minTrackingTime)) {
return true;
}
- if (byteCount >= minBytesReceived) {
+ if (currentRestriction().getByteCount() >= minBytesReceived) {
return true;
}
return false;
}
@Override
- public @Nullable SplitResult<OffsetRange> trySplit(double fractionOfRemainder) {
+ public @Nullable SplitResult<OffsetByteRange> trySplit(double fractionOfRemainder) {
// Cannot split a bounded range. This should already be completely claimed.
- if (range.getTo() != Long.MAX_VALUE) {
+ if (range.getRange().getTo() != Long.MAX_VALUE) {
return null;
}
if (!receivedEnough()) {
return null;
}
- range = new OffsetRange(currentRestriction().getFrom(), nextOffset());
- return SplitResult.of(this.range, new OffsetRange(nextOffset(), Long.MAX_VALUE));
+ range =
+ OffsetByteRange.of(
+ new OffsetRange(currentRestriction().getRange().getFrom(), nextOffset()),
+ range.getByteCount());
+ return SplitResult.of(
+ this.range, OffsetByteRange.of(new OffsetRange(nextOffset(), Long.MAX_VALUE), 0));
}
@Override
@SuppressWarnings("unboxing.of.nullable")
public void checkDone() throws IllegalStateException {
- if (range.getFrom() == range.getTo()) {
+ if (range.getRange().getFrom() == range.getRange().getTo()) {
return;
}
checkState(
@@ -155,18 +155,18 @@
range);
long lastClaimedNotNull = checkNotNull(lastClaimed);
checkState(
- lastClaimedNotNull >= range.getTo() - 1,
+ lastClaimedNotNull >= range.getRange().getTo() - 1,
"Last attempted offset was %s in range %s, claiming work in [%s, %s) was not attempted",
lastClaimedNotNull,
range,
lastClaimedNotNull + 1,
- range.getTo());
+ range.getRange().getTo());
}
@Override
public Progress getProgress() {
ComputeMessageStatsResponse stats =
- this.backlogReader.computeMessageStats(Offset.of(nextOffset()));
- return Progress.from(byteCount, stats.getMessageBytes());
+ this.unownedBacklogReader.computeMessageStats(Offset.of(nextOffset()));
+ return Progress.from(range.getByteCount(), stats.getMessageBytes());
}
}
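
To make the split arithmetic concrete, a commented walk-through with illustrative values:

// Illustrative: the tracker starts over OffsetByteRange.of([10, MAX), 0).
// After claiming offsets 10..14 (messages totalling 2_000 bytes):
//   currentRestriction() == OffsetByteRange.of([10, MAX), 2_000)
// A successful trySplit(...) then truncates at nextOffset() == 15:
//   primary  == OffsetByteRange.of([10, 15), 2_000)
//   residual == OffsetByteRange.of([15, MAX), 0)
// checkDone() subsequently passes for the primary because offset 14, the
// last claimed, equals range.getTo() - 1.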
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
index 623e20c..d7526d8 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerServerPublisherCache.java
@@ -27,4 +27,8 @@
private PerServerPublisherCache() {}
static final PublisherCache PUBLISHER_CACHE = new PublisherCache();
+
+ static {
+ Runtime.getRuntime().addShutdownHook(new Thread(PUBLISHER_CACHE::close));
+ }
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
index a9f7a43..fdf7920 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdf.java
@@ -17,13 +17,12 @@
*/
package org.apache.beam.sdk.io.gcp.pubsublite;
-import static com.google.cloud.pubsublite.internal.ExtractStatus.toCanonical;
-import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static com.google.cloud.pubsublite.internal.wire.ApiServiceUtils.blockingShutdown;
import com.google.cloud.pubsublite.Offset;
+import com.google.cloud.pubsublite.internal.ExtractStatus;
import com.google.cloud.pubsublite.internal.wire.Committer;
import com.google.cloud.pubsublite.proto.SequencedMessage;
-import java.util.concurrent.ExecutionException;
import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableBiFunction;
@@ -35,31 +34,35 @@
class PerSubscriptionPartitionSdf extends DoFn<SubscriptionPartition, SequencedMessage> {
private final Duration maxSleepTime;
+ private final ManagedBacklogReaderFactory backlogReaderFactory;
private final SubscriptionPartitionProcessorFactory processorFactory;
private final SerializableFunction<SubscriptionPartition, InitialOffsetReader>
offsetReaderFactory;
- private final SerializableBiFunction<
- SubscriptionPartition, OffsetRange, RestrictionTracker<OffsetRange, OffsetByteProgress>>
+ private final SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
trackerFactory;
private final SerializableFunction<SubscriptionPartition, Committer> committerFactory;
PerSubscriptionPartitionSdf(
Duration maxSleepTime,
+ ManagedBacklogReaderFactory backlogReaderFactory,
SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory,
- SerializableBiFunction<
- SubscriptionPartition,
- OffsetRange,
- RestrictionTracker<OffsetRange, OffsetByteProgress>>
+ SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
trackerFactory,
SubscriptionPartitionProcessorFactory processorFactory,
SerializableFunction<SubscriptionPartition, Committer> committerFactory) {
this.maxSleepTime = maxSleepTime;
+ this.backlogReaderFactory = backlogReaderFactory;
this.processorFactory = processorFactory;
this.offsetReaderFactory = offsetReaderFactory;
this.trackerFactory = trackerFactory;
this.committerFactory = committerFactory;
}
+ @Teardown
+ public void teardown() {
+ backlogReaderFactory.close();
+ }
+
@GetInitialWatermarkEstimatorState
public Instant getInitialWatermarkState() {
return Instant.EPOCH;
@@ -72,7 +75,7 @@
@ProcessElement
public ProcessContinuation processElement(
- RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+ RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
@Element SubscriptionPartition subscriptionPartition,
OutputReceiver<SequencedMessage> receiver)
throws Exception {
@@ -83,38 +86,44 @@
processor
.lastClaimed()
.ifPresent(
- lastClaimedOffset ->
- /* TODO(boyuanzz): When default dataflow can use finalizers, undo this.
- finalizer.afterBundleCommit(
- Instant.ofEpochMilli(Long.MAX_VALUE),
- () -> */ {
+ lastClaimedOffset -> {
Committer committer = committerFactory.apply(subscriptionPartition);
committer.startAsync().awaitRunning();
// Commit the next-to-deliver offset.
try {
committer.commitOffset(Offset.of(lastClaimedOffset.value() + 1)).get();
- } catch (ExecutionException e) {
- throw toCanonical(checkArgumentNotNull(e.getCause())).underlying;
} catch (Exception e) {
- throw toCanonical(e).underlying;
+ throw ExtractStatus.toCanonical(e).underlying;
}
- committer.stopAsync().awaitTerminated();
+ blockingShutdown(committer);
});
return result;
}
}
@GetInitialRestriction
- public OffsetRange getInitialRestriction(@Element SubscriptionPartition subscriptionPartition) {
+ public OffsetByteRange getInitialRestriction(
+ @Element SubscriptionPartition subscriptionPartition) {
try (InitialOffsetReader reader = offsetReaderFactory.apply(subscriptionPartition)) {
Offset offset = reader.read();
- return new OffsetRange(offset.value(), Long.MAX_VALUE /* open interval */);
+ return OffsetByteRange.of(
+ new OffsetRange(offset.value(), Long.MAX_VALUE /* open interval */));
}
}
@NewTracker
- public RestrictionTracker<OffsetRange, OffsetByteProgress> newTracker(
- @Element SubscriptionPartition subscriptionPartition, @Restriction OffsetRange range) {
- return trackerFactory.apply(subscriptionPartition, range);
+ public TrackerWithProgress newTracker(
+ @Element SubscriptionPartition subscriptionPartition, @Restriction OffsetByteRange range) {
+ return trackerFactory.apply(backlogReaderFactory.newReader(subscriptionPartition), range);
+ }
+
+ @GetSize
+ public double getSize(
+ @Element SubscriptionPartition subscriptionPartition,
+ @Restriction OffsetByteRange restriction) {
+ if (restriction.getRange().getTo() != Long.MAX_VALUE) {
+ return restriction.getByteCount();
+ }
+ return newTracker(subscriptionPartition, restriction).getProgress().getWorkRemaining();
}
}
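
The SDF now sizes a bounded restriction by its accumulated byte count and delegates unbounded ones to the tracker's backlog-based progress. OffsetByteRange itself is added elsewhere in this PR; a simplified stand-in (hypothetical class, the real one appears to be an AutoValue with of() factories) illustrates the shape the getSize() branch above relies on:

import org.apache.beam.sdk.io.range.OffsetRange;

// Hedged sketch: a restriction pairing an offset range with bytes claimed so far.
final class OffsetByteRangeSketch {
  private final OffsetRange range;
  private final long byteCount; // bytes claimed within `range` so far

  OffsetByteRangeSketch(OffsetRange range, long byteCount) {
    this.range = range;
    this.byteCount = byteCount;
  }

  OffsetRange getRange() { return range; }

  long getByteCount() { return byteCount; }

  boolean isBounded() {
    // Mirrors getSize() above: a finite `to` means byteCount is an exact size;
    // an open interval needs a backlog estimate from the tracker instead.
    return range.getTo() != Long.MAX_VALUE;
  }
}
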
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
index f8dc24b..3dbdec6 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PublisherCache.java
@@ -23,52 +23,50 @@
import com.google.api.core.ApiService.State;
import com.google.api.gax.rpc.ApiException;
import com.google.cloud.pubsublite.MessageMetadata;
-import com.google.cloud.pubsublite.internal.CloseableMonitor;
import com.google.cloud.pubsublite.internal.Publisher;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.util.HashMap;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
/** A map of working publishers by PublisherOptions. */
-class PublisherCache {
- private final CloseableMonitor monitor = new CloseableMonitor();
-
- private final Executor listenerExecutor = Executors.newSingleThreadExecutor();
-
- @GuardedBy("monitor.monitor")
+class PublisherCache implements AutoCloseable {
+ @GuardedBy("this")
private final HashMap<PublisherOptions, Publisher<MessageMetadata>> livePublishers =
new HashMap<>();
- Publisher<MessageMetadata> get(PublisherOptions options) throws ApiException {
+ private synchronized void evict(PublisherOptions options) {
+ livePublishers.remove(options);
+ }
+
+ synchronized Publisher<MessageMetadata> get(PublisherOptions options) throws ApiException {
checkArgument(options.usesCache());
- try (CloseableMonitor.Hold h = monitor.enter()) {
- Publisher<MessageMetadata> publisher = livePublishers.get(options);
- if (publisher != null) {
- return publisher;
- }
- publisher = Publishers.newPublisher(options);
- livePublishers.put(options, publisher);
- publisher.addListener(
- new Listener() {
- @Override
- public void failed(State s, Throwable t) {
- try (CloseableMonitor.Hold h = monitor.enter()) {
- livePublishers.remove(options);
- }
- }
- },
- listenerExecutor);
- publisher.startAsync().awaitRunning();
+ Publisher<MessageMetadata> publisher = livePublishers.get(options);
+ if (publisher != null) {
return publisher;
}
+ publisher = Publishers.newPublisher(options);
+ livePublishers.put(options, publisher);
+ publisher.addListener(
+ new Listener() {
+ @Override
+ public void failed(State s, Throwable t) {
+ evict(options);
+ }
+ },
+ SystemExecutors.getFuturesExecutor());
+ publisher.startAsync().awaitRunning();
+ return publisher;
}
@VisibleForTesting
- void set(PublisherOptions options, Publisher<MessageMetadata> toCache) {
- try (CloseableMonitor.Hold h = monitor.enter()) {
- livePublishers.put(options, toCache);
- }
+ synchronized void set(PublisherOptions options, Publisher<MessageMetadata> toCache) {
+ livePublishers.put(options, toCache);
+ }
+
+ @Override
+ public synchronized void close() {
+ livePublishers.forEach((options, publisher) -> publisher.stopAsync());
+ livePublishers.clear();
}
}
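
With PublisherCache now AutoCloseable (and closed by the shutdown hook above), an owner can also scope it explicitly. A hedged usage sketch; `options` and `message` are placeholders, and Publisher.publish(Message) is assumed from the Pub/Sub Lite client's internal Publisher interface:

// Hedged usage sketch: explicit lifetime management for the cache.
try (PublisherCache cache = new PublisherCache()) {
  Publisher<MessageMetadata> publisher = cache.get(options); // created and started if absent
  publisher.publish(message); // assumed API: returns ApiFuture<MessageMetadata>
} // close() stops every live publisher and clears the map
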
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
index 34012f7..67ea6cf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/Publishers.java
@@ -17,17 +17,27 @@
*/
package org.apache.beam.sdk.io.gcp.pubsublite;
+import static com.google.cloud.pubsublite.internal.ExtractStatus.toCanonical;
import static com.google.cloud.pubsublite.internal.UncheckedApiPreconditions.checkArgument;
+import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultMetadata;
+import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultSettings;
import com.google.api.gax.rpc.ApiException;
import com.google.cloud.pubsublite.AdminClient;
import com.google.cloud.pubsublite.AdminClientSettings;
import com.google.cloud.pubsublite.MessageMetadata;
-import com.google.cloud.pubsublite.TopicPath;
+import com.google.cloud.pubsublite.Partition;
+import com.google.cloud.pubsublite.cloudpubsub.PublisherSettings;
import com.google.cloud.pubsublite.internal.Publisher;
import com.google.cloud.pubsublite.internal.wire.PartitionCountWatchingPublisherSettings;
+import com.google.cloud.pubsublite.internal.wire.PubsubContext;
import com.google.cloud.pubsublite.internal.wire.PubsubContext.Framework;
+import com.google.cloud.pubsublite.internal.wire.RoutingMetadata;
import com.google.cloud.pubsublite.internal.wire.SinglePartitionPublisherBuilder;
+import com.google.cloud.pubsublite.v1.AdminServiceClient;
+import com.google.cloud.pubsublite.v1.AdminServiceSettings;
+import com.google.cloud.pubsublite.v1.PublisherServiceClient;
+import com.google.cloud.pubsublite.v1.PublisherServiceSettings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.reflect.TypeToken;
class Publishers {
@@ -35,6 +45,38 @@
private Publishers() {}
+ private static AdminClient newAdminClient(PublisherOptions options) throws ApiException {
+ try {
+ return AdminClient.create(
+ AdminClientSettings.newBuilder()
+ .setServiceClient(
+ AdminServiceClient.create(
+ addDefaultSettings(
+ options.topicPath().location().extractRegion(),
+ AdminServiceSettings.newBuilder())))
+ .setRegion(options.topicPath().location().extractRegion())
+ .build());
+ } catch (Throwable t) {
+ throw toCanonical(t).underlying;
+ }
+ }
+
+ private static PublisherServiceClient newServiceClient(
+ PublisherOptions options, Partition partition) {
+ PublisherServiceSettings.Builder settingsBuilder = PublisherServiceSettings.newBuilder();
+ settingsBuilder =
+ addDefaultMetadata(
+ PubsubContext.of(FRAMEWORK),
+ RoutingMetadata.of(options.topicPath(), partition),
+ settingsBuilder);
+ try {
+ return PublisherServiceClient.create(
+ addDefaultSettings(options.topicPath().location().extractRegion(), settingsBuilder));
+ } catch (Throwable t) {
+ throw toCanonical(t).underlying;
+ }
+ }
+
@SuppressWarnings("unchecked")
static Publisher<MessageMetadata> newPublisher(PublisherOptions options) throws ApiException {
SerializableSupplier<Object> supplier = options.publisherSupplier();
@@ -44,20 +86,18 @@
checkArgument(token.isSupertypeOf(supplied.getClass()));
return (Publisher<MessageMetadata>) supplied;
}
-
- TopicPath topic = options.topicPath();
- PartitionCountWatchingPublisherSettings.Builder publisherSettings =
- PartitionCountWatchingPublisherSettings.newBuilder()
- .setTopic(topic)
- .setPublisherFactory(
- partition ->
- SinglePartitionPublisherBuilder.newBuilder()
- .setTopic(topic)
- .setPartition(partition)
- .build())
- .setAdminClient(
- AdminClient.create(
- AdminClientSettings.newBuilder().setRegion(topic.location().region()).build()));
- return publisherSettings.build().instantiate();
+ return PartitionCountWatchingPublisherSettings.newBuilder()
+ .setTopic(options.topicPath())
+ .setPublisherFactory(
+ partition ->
+ SinglePartitionPublisherBuilder.newBuilder()
+ .setTopic(options.topicPath())
+ .setPartition(partition)
+ .setServiceClient(newServiceClient(options, partition))
+ .setBatchingSettings(PublisherSettings.DEFAULT_BATCHING_SETTINGS)
+ .build())
+ .setAdminClient(newAdminClient(options))
+ .build()
+ .instantiate();
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
index ca1f2be..b93ac61 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteIO.java
@@ -107,7 +107,7 @@
* }</pre>
*/
public static PTransform<PCollection<PubSubMessage>, PDone> write(PublisherOptions options) {
- return new PTransform<PCollection<PubSubMessage>, PDone>("PubsubLiteIO") {
+ return new PTransform<PCollection<PubSubMessage>, PDone>() {
@Override
public PDone expand(PCollection<PubSubMessage> input) {
PubsubLiteSink sink = new PubsubLiteSink(options);
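
Dropping the "PubsubLiteIO" constructor argument removes a hard-coded node name from the anonymous PTransform; pipelines that want a stable step name can still supply one where the transform is applied:

// Naming at the call site; `messages` and `options` are assumed to exist.
messages.apply("PubsubLiteIO.write", PubsubLiteIO.write(options));
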
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
index d3acdfa..d0e3afa 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteSink.java
@@ -28,16 +28,14 @@
import com.google.cloud.pubsublite.internal.CheckedApiException;
import com.google.cloud.pubsublite.internal.ExtractStatus;
import com.google.cloud.pubsublite.internal.Publisher;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
import com.google.cloud.pubsublite.proto.PubSubMessage;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.util.ArrayDeque;
import java.util.Deque;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
import java.util.function.Consumer;
import org.apache.beam.sdk.io.gcp.pubsublite.PublisherOrError.Kind;
import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
/** A sink which publishes messages to Pub/Sub Lite. */
@SuppressWarnings({
@@ -56,8 +54,6 @@
@GuardedBy("this")
private transient Deque<CheckedApiException> errorsSinceLastFinish;
- private static final Executor executor = Executors.newCachedThreadPool();
-
PubsubLiteSink(PublisherOptions options) {
this.options = options;
}
@@ -89,7 +85,7 @@
onFailure.accept(t);
}
},
- MoreExecutors.directExecutor());
+ SystemExecutors.getFuturesExecutor());
if (!options.usesCache()) {
publisher.startAsync();
}
@@ -130,7 +126,7 @@
onFailure.accept(t);
}
},
- executor);
+ SystemExecutors.getFuturesExecutor());
}
// Intentionally don't flush on bundle finish to allow multi-sink client reuse.
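
Both callback registrations above now share the client library's SystemExecutors.getFuturesExecutor() rather than a direct executor or a sink-owned thread pool. A hedged sketch of the pattern with the GAX future utilities (a simplified stand-in for the sink's actual callbacks):

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.cloud.pubsublite.MessageMetadata;
import com.google.cloud.pubsublite.internal.wire.SystemExecutors;

// Hedged sketch: completions run on the shared executor instead of the
// publishing thread (directExecutor) or a private cached pool.
final class PublishCallbacks {
  static void whenDone(ApiFuture<MessageMetadata> future, Runnable onOk) {
    ApiFutures.addCallback(
        future,
        new ApiFutureCallback<MessageMetadata>() {
          @Override
          public void onSuccess(MessageMetadata metadata) {
            onOk.run(); // e.g. decrement the outstanding-message count
          }

          @Override
          public void onFailure(Throwable t) {
            // e.g. record t for the next bundle-finish error check
          }
        },
        SystemExecutors.getFuturesExecutor());
  }
}
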
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
index 9875880..b6a9f5d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscribeTransform.java
@@ -23,6 +23,7 @@
import com.google.api.gax.rpc.ApiException;
import com.google.cloud.pubsublite.AdminClient;
import com.google.cloud.pubsublite.AdminClientSettings;
+import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.Partition;
import com.google.cloud.pubsublite.TopicPath;
import com.google.cloud.pubsublite.internal.wire.Committer;
@@ -31,8 +32,6 @@
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
-import org.apache.beam.sdk.io.range.OffsetRange;
-import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
@@ -54,10 +53,11 @@
checkArgument(subscriptionPartition.subscription().equals(options.subscriptionPath()));
}
- private Subscriber newSubscriber(Partition partition, Consumer<List<SequencedMessage>> consumer) {
+ private Subscriber newSubscriber(
+ Partition partition, Offset initialOffset, Consumer<List<SequencedMessage>> consumer) {
try {
return options
- .getSubscriberFactory(partition)
+ .getSubscriberFactory(partition, initialOffset)
.newSubscriber(
messages ->
consumer.accept(
@@ -71,23 +71,31 @@
private SubscriptionPartitionProcessor newPartitionProcessor(
SubscriptionPartition subscriptionPartition,
- RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+ RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
OutputReceiver<SequencedMessage> receiver)
throws ApiException {
checkSubscription(subscriptionPartition);
return new SubscriptionPartitionProcessorImpl(
tracker,
receiver,
- consumer -> newSubscriber(subscriptionPartition.partition(), consumer),
+ consumer ->
+ newSubscriber(
+ subscriptionPartition.partition(),
+ Offset.of(tracker.currentRestriction().getRange().getFrom()),
+ consumer),
options.flowControlSettings());
}
- private RestrictionTracker<OffsetRange, OffsetByteProgress> newRestrictionTracker(
- SubscriptionPartition subscriptionPartition, OffsetRange initial) {
+ private TopicBacklogReader newBacklogReader(SubscriptionPartition subscriptionPartition) {
checkSubscription(subscriptionPartition);
+ return options.getBacklogReader(subscriptionPartition.partition());
+ }
+
+ private TrackerWithProgress newRestrictionTracker(
+ TopicBacklogReader backlogReader, OffsetByteRange initial) {
return new OffsetByteRangeTracker(
initial,
- options.getBacklogReader(subscriptionPartition.partition()),
+ backlogReader,
Stopwatch.createUnstarted(),
options.minBundleTimeout(),
LongMath.saturatedMultiply(options.flowControlSettings().bytesOutstanding(), 10));
@@ -107,7 +115,7 @@
try (AdminClient admin =
AdminClient.create(
AdminClientSettings.newBuilder()
- .setRegion(options.subscriptionPath().location().region())
+ .setRegion(options.subscriptionPath().location().extractRegion())
.build())) {
return TopicPath.parse(admin.getSubscription(options.subscriptionPath()).get().getTopic());
} catch (Throwable t) {
@@ -118,25 +126,15 @@
@Override
public PCollection<SequencedMessage> expand(PBegin input) {
PCollection<SubscriptionPartition> subscriptionPartitions;
- if (options.partitions().isEmpty()) {
- subscriptionPartitions =
- input.apply(new SubscriptionPartitionLoader(getTopicPath(), options.subscriptionPath()));
- } else {
- subscriptionPartitions =
- input.apply(
- Create.of(
- options.partitions().stream()
- .map(
- partition ->
- SubscriptionPartition.of(options.subscriptionPath(), partition))
- .collect(Collectors.toList())));
- }
+ subscriptionPartitions =
+ input.apply(new SubscriptionPartitionLoader(getTopicPath(), options.subscriptionPath()));
return subscriptionPartitions.apply(
ParDo.of(
new PerSubscriptionPartitionSdf(
// Ensure we read for at least 5 seconds more than the bundle timeout.
options.minBundleTimeout().plus(Duration.standardSeconds(5)),
+ new ManagedBacklogReaderFactoryImpl(this::newBacklogReader),
this::newInitialOffsetReader,
this::newRestrictionTracker,
this::newPartitionProcessor,
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
index 0d3afe2..a9625be 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriberOptions.java
@@ -23,6 +23,7 @@
import com.google.api.gax.rpc.ApiException;
import com.google.auto.value.AutoValue;
+import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.Partition;
import com.google.cloud.pubsublite.SubscriptionPath;
import com.google.cloud.pubsublite.cloudpubsub.FlowControlSettings;
@@ -35,13 +36,13 @@
import com.google.cloud.pubsublite.internal.wire.RoutingMetadata;
import com.google.cloud.pubsublite.internal.wire.SubscriberBuilder;
import com.google.cloud.pubsublite.internal.wire.SubscriberFactory;
+import com.google.cloud.pubsublite.proto.Cursor;
+import com.google.cloud.pubsublite.proto.SeekRequest;
import com.google.cloud.pubsublite.v1.CursorServiceClient;
import com.google.cloud.pubsublite.v1.CursorServiceSettings;
import com.google.cloud.pubsublite.v1.SubscriberServiceClient;
import com.google.cloud.pubsublite.v1.SubscriberServiceSettings;
import java.io.Serializable;
-import java.util.Set;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
@@ -69,11 +70,6 @@
public abstract FlowControlSettings flowControlSettings();
/**
- * A set of partitions. If empty, continuously poll the set of partitions using an admin client.
- */
- public abstract Set<Partition> partitions();
-
- /**
* The minimum wall time to pass before allowing bundle closure.
*
* <p>Setting this to too small of a value will result in increased compute costs and lower
@@ -108,7 +104,6 @@
public static Builder newBuilder() {
Builder builder = new AutoValue_SubscriberOptions.Builder();
return builder
- .setPartitions(ImmutableSet.of())
.setFlowControlSettings(DEFAULT_FLOW_CONTROL)
.setMinBundleTimeout(MIN_BUNDLE_TIMEOUT);
}
@@ -119,20 +114,19 @@
throws ApiException {
try {
SubscriberServiceSettings.Builder settingsBuilder = SubscriberServiceSettings.newBuilder();
-
settingsBuilder =
addDefaultMetadata(
PubsubContext.of(FRAMEWORK),
RoutingMetadata.of(subscriptionPath(), partition),
settingsBuilder);
return SubscriberServiceClient.create(
- addDefaultSettings(subscriptionPath().location().region(), settingsBuilder));
+ addDefaultSettings(subscriptionPath().location().extractRegion(), settingsBuilder));
} catch (Throwable t) {
throw toCanonical(t).underlying;
}
}
- SubscriberFactory getSubscriberFactory(Partition partition) {
+ SubscriberFactory getSubscriberFactory(Partition partition, Offset initialOffset) {
SubscriberFactory factory = subscriberFactory();
if (factory != null) {
return factory;
@@ -143,6 +137,10 @@
.setSubscriptionPath(subscriptionPath())
.setPartition(partition)
.setServiceClient(newSubscriberServiceClient(partition))
+ .setInitialLocation(
+ SeekRequest.newBuilder()
+ .setCursor(Cursor.newBuilder().setOffset(initialOffset.value()))
+ .build())
.build();
}
@@ -150,7 +148,7 @@
try {
return CursorServiceClient.create(
addDefaultSettings(
- subscriptionPath().location().region(), CursorServiceSettings.newBuilder()));
+ subscriptionPath().location().extractRegion(), CursorServiceSettings.newBuilder()));
} catch (Throwable t) {
throw toCanonical(t).underlying;
}
@@ -189,7 +187,7 @@
return new InitialOffsetReaderImpl(
CursorClient.create(
CursorClientSettings.newBuilder()
- .setRegion(subscriptionPath().location().region())
+ .setRegion(subscriptionPath().location().extractRegion())
.build()),
subscriptionPath(),
partition);
@@ -201,8 +199,6 @@
public abstract Builder setSubscriptionPath(SubscriptionPath path);
// Optional parameters
- public abstract Builder setPartitions(Set<Partition> partitions);
-
public abstract Builder setFlowControlSettings(FlowControlSettings flowControlSettings);
public abstract Builder setMinBundleTimeout(Duration minBundleTimeout);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
index 866e922..e411d80 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionLoader.java
@@ -92,7 +92,9 @@
})
.withPollInterval(pollDuration)
.withTerminationPerInput(
- terminate ? Watch.Growth.afterIterations(10) : Watch.Growth.never()));
+ terminate
+ ? Watch.Growth.afterTotalOf(pollDuration.multipliedBy(10))
+ : Watch.Growth.never()));
return partitions.apply(
MapElements.into(TypeDescriptor.of(SubscriptionPartition.class))
.via(kv -> SubscriptionPartition.of(subscription, kv.getValue())));
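
The termination condition changes from a fixed number of poll iterations to a wall-clock bound derived from the poll interval, which keeps terminating runs time-boxed no matter how quickly individual polls return. A hedged, simplified sketch of the two conditions (hypothetical trivial poller, not the loader's real one):

import java.util.Collections;
import org.apache.beam.sdk.transforms.Watch;
import org.apache.beam.sdk.transforms.Watch.Growth.PollResult;
import org.joda.time.Duration;
import org.joda.time.Instant;

// Hedged sketch: a no-op poller terminated by total elapsed time.
final class PollSketch {
  static Watch.Growth<String, Integer, Integer> timeBoxedGrowth(Duration pollDuration) {
    return Watch.growthOf(
            new Watch.Growth.PollFn<String, Integer>() {
              @Override
              public PollResult<Integer> apply(String element, Context c) {
                // Report no new outputs; a real poller would list partitions here.
                return PollResult.incomplete(Instant.now(), Collections.emptyList());
              }
            })
        .withPollInterval(pollDuration)
        // afterTotalOf bounds wall-clock time; afterIterations(10) would have
        // bounded the number of poll rounds instead.
        .withTerminationPerInput(Watch.Growth.afterTotalOf(pollDuration.multipliedBy(10)));
  }
}
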
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
index 6bf3623..530c180 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorFactory.java
@@ -19,7 +19,6 @@
import com.google.cloud.pubsublite.proto.SequencedMessage;
import java.io.Serializable;
-import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
@@ -28,6 +27,6 @@
SubscriptionPartitionProcessor newProcessor(
SubscriptionPartition subscriptionPartition,
- RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+ RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
OutputReceiver<SequencedMessage> receiver);
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
index 8d2a137..a086d18 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImpl.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.gcp.pubsublite;
+import static com.google.cloud.pubsublite.internal.wire.ApiServiceUtils.blockingShutdown;
+
import com.google.api.core.ApiService.Listener;
import com.google.api.core.ApiService.State;
import com.google.cloud.pubsublite.Offset;
@@ -24,9 +26,8 @@
import com.google.cloud.pubsublite.internal.CheckedApiException;
import com.google.cloud.pubsublite.internal.ExtractStatus;
import com.google.cloud.pubsublite.internal.wire.Subscriber;
-import com.google.cloud.pubsublite.proto.Cursor;
+import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
import com.google.cloud.pubsublite.proto.FlowControlRequest;
-import com.google.cloud.pubsublite.proto.SeekRequest;
import com.google.cloud.pubsublite.proto.SequencedMessage;
import com.google.protobuf.util.Timestamps;
import java.util.List;
@@ -36,19 +37,17 @@
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
-import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture;
import org.joda.time.Duration;
import org.joda.time.Instant;
class SubscriptionPartitionProcessorImpl extends Listener
implements SubscriptionPartitionProcessor {
- private final RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+ private final RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker;
private final OutputReceiver<SequencedMessage> receiver;
private final Subscriber subscriber;
private final SettableFuture<Void> completionFuture = SettableFuture.create();
@@ -57,7 +56,7 @@
@SuppressWarnings("methodref.receiver.bound.invalid")
SubscriptionPartitionProcessorImpl(
- RestrictionTracker<OffsetRange, OffsetByteProgress> tracker,
+ RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker,
OutputReceiver<SequencedMessage> receiver,
Function<Consumer<List<SequencedMessage>>, Subscriber> subscriberFactory,
FlowControlSettings flowControlSettings) {
@@ -70,23 +69,15 @@
@Override
@SuppressWarnings("argument.type.incompatible")
public void start() throws CheckedApiException {
- this.subscriber.addListener(this, MoreExecutors.directExecutor());
+ this.subscriber.addListener(this, SystemExecutors.getFuturesExecutor());
this.subscriber.startAsync();
this.subscriber.awaitRunning();
try {
- this.subscriber
- .seek(
- SeekRequest.newBuilder()
- .setCursor(Cursor.newBuilder().setOffset(tracker.currentRestriction().getFrom()))
- .build())
- .get();
this.subscriber.allowFlow(
FlowControlRequest.newBuilder()
.setAllowedBytes(flowControlSettings.bytesOutstanding())
.setAllowedMessages(flowControlSettings.messagesOutstanding())
.build());
- } catch (ExecutionException e) {
- throw ExtractStatus.toCanonical(e.getCause());
} catch (Throwable t) {
throw ExtractStatus.toCanonical(t);
}
@@ -125,7 +116,7 @@
@Override
public void close() {
- subscriber.stopAsync().awaitTerminated();
+ blockingShutdown(subscriber);
}
@Override
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
index 8c1dd94..79db0f1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TopicBacklogReaderSettings.java
@@ -62,7 +62,7 @@
try (AdminClient adminClient =
AdminClient.create(
AdminClientSettings.newBuilder()
- .setRegion(subscriptionPath.location().region())
+ .setRegion(subscriptionPath.location().extractRegion())
.build())) {
return setTopicPath(
TopicPath.parse(adminClient.getSubscription(subscriptionPath).get().getTopic()));
@@ -81,7 +81,9 @@
TopicBacklogReader instantiate() throws ApiException {
TopicStatsClientSettings settings =
- TopicStatsClientSettings.newBuilder().setRegion(topicPath().location().region()).build();
+ TopicStatsClientSettings.newBuilder()
+ .setRegion(topicPath().location().extractRegion())
+ .build();
TopicBacklogReader impl =
new TopicBacklogReaderImpl(TopicStatsClient.create(settings), topicPath(), partition());
return new LimitingTopicBacklogReader(impl, Ticker.systemTicker());
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java
new file mode 100644
index 0000000..7f0d030
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/TrackerWithProgress.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
+
+public abstract class TrackerWithProgress
+ extends RestrictionTracker<OffsetByteRange, OffsetByteProgress> implements HasProgress {}
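
This abstract class exists so factories can return a single type that is both a RestrictionTracker over OffsetByteRange and a HasProgress; the SDF's getSize() above relies on that second interface. A hedged consumption sketch:

import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress;

// Hedged sketch: because TrackerWithProgress implements HasProgress, callers
// can ask for backlog without knowing the concrete tracker class.
final class SizeEstimator {
  static double workRemaining(TrackerWithProgress tracker) {
    Progress progress = tracker.getProgress();
    return progress.getWorkRemaining();
  }
}
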
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
index 53d8db0..a9a3609 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
@@ -73,6 +73,10 @@
import com.google.cloud.bigquery.storage.v1.SplitReadStreamResponse;
import com.google.cloud.hadoop.util.ApiErrorExtractor;
import com.google.cloud.hadoop.util.RetryBoundedBackOff;
+import com.google.protobuf.Parser;
+import com.google.rpc.RetryInfo;
+import io.grpc.Metadata;
+import io.grpc.Status;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
@@ -90,6 +94,7 @@
import org.apache.beam.sdk.extensions.gcp.util.Transport;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl.DatasetServiceImpl;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl.JobServiceImpl;
+import org.apache.beam.sdk.metrics.MetricName;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.ExpectedLogs;
@@ -152,6 +157,7 @@
// Setup the ProcessWideContainer for testing metrics are set.
MetricsContainerImpl container = new MetricsContainerImpl(null);
MetricsEnvironment.setProcessWideContainer(container);
+ MetricsEnvironment.setCurrentContainer(container);
}
@FunctionalInterface
@@ -1560,4 +1566,65 @@
client.splitReadStream(request, "myproject:mydataset.mytable");
verifyReadMetricWasSet("myproject", "mydataset", "mytable", "resource_exhausted", 1);
}
+
+ @Test
+ public void testRetryAttemptCounter() {
+ BigQueryServicesImpl.StorageClientImpl.RetryAttemptCounter counter =
+ new BigQueryServicesImpl.StorageClientImpl.RetryAttemptCounter();
+
+ RetryInfo retryInfo =
+ RetryInfo.newBuilder()
+ .setRetryDelay(
+ com.google.protobuf.Duration.newBuilder()
+ .setSeconds(123)
+ .setNanos(456000000)
+ .build())
+ .build();
+
+ Metadata metadata = new Metadata();
+ metadata.put(
+ Metadata.Key.of(
+ "google.rpc.retryinfo-bin",
+ new Metadata.BinaryMarshaller<RetryInfo>() {
+ @Override
+ public byte[] toBytes(RetryInfo value) {
+ return value.toByteArray();
+ }
+
+ @Override
+ public RetryInfo parseBytes(byte[] serialized) {
+ try {
+ Parser<RetryInfo> parser = (RetryInfo.newBuilder().build()).getParserForType();
+ return parser.parseFrom(serialized);
+ } catch (Exception e) {
+ return null;
+ }
+ }
+ }),
+ retryInfo);
+
+ MetricName metricName =
+ MetricName.named(
+ "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$StorageClientImpl",
+ "throttling-msecs");
+ MetricsContainerImpl container =
+ (MetricsContainerImpl) MetricsEnvironment.getCurrentContainer();
+
+ // Nulls don't bump the counter.
+ counter.onRetryAttempt(null, null);
+ assertEquals(0, (long) container.getCounter(metricName).getCumulative());
+
+ // Resource exhausted with empty metadata doesn't bump the counter.
+ counter.onRetryAttempt(
+ Status.RESOURCE_EXHAUSTED.withDescription("You have consumed some quota"), new Metadata());
+ assertEquals(0, (long) container.getCounter(metricName).getCumulative());
+
+ // Resource exhausted with retry info bumps the counter.
+ counter.onRetryAttempt(Status.RESOURCE_EXHAUSTED.withDescription("Stop for a while"), metadata);
+ assertEquals(123456, (long) container.getCounter(metricName).getCumulative());
+
+ // Other errors with retry info doesn't bump the counter.
+ counter.onRetryAttempt(Status.UNAVAILABLE.withDescription("Server is gone"), metadata);
+ assertEquals(123456, (long) container.getCounter(metricName).getCumulative());
+ }
}
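
The test hand-rolls a binary marshaller for the google.rpc.retryinfo-bin trailer; decoding it back out is symmetric. A hedged sketch using grpc-protobuf's ProtoUtils (an assumption: the helper lives in the grpc-protobuf artifact, and the test above does not itself use it):

import com.google.protobuf.util.Durations;
import com.google.rpc.RetryInfo;
import io.grpc.Metadata;
import io.grpc.protobuf.ProtoUtils;

// Hedged sketch: read the server-suggested backoff out of response trailers.
final class RetryDelays {
  private static final Metadata.Key<RetryInfo> RETRY_INFO_KEY =
      ProtoUtils.keyForProto(RetryInfo.getDefaultInstance());

  // For the test's 123s/456ms RetryInfo this yields 123456, the value the
  // throttling-msecs counter is expected to accumulate.
  static long suggestedDelayMillis(Metadata trailers) {
    RetryInfo info = trailers.get(RETRY_INFO_KEY);
    return info == null ? 0 : Durations.toMillis(info.getRetryDelay());
  }
}
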
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
index 0c9bbf1..1f29883 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnPartitionQueryTest.java
@@ -99,6 +99,41 @@
assertEquals(expected, allValues);
}
+ @Test
+ public void endToEnd_emptyCursors() throws Exception {
+ // First page of the response
+ PartitionQueryRequest request1 =
+ PartitionQueryRequest.newBuilder()
+ .setParent(String.format("projects/%s/databases/(default)/documents", projectId))
+ .build();
+ PartitionQueryResponse response1 = PartitionQueryResponse.newBuilder().build();
+ when(callable.call(request1)).thenReturn(pagedResponse1);
+ when(page1.getResponse()).thenReturn(response1);
+ when(pagedResponse1.iteratePages()).thenReturn(ImmutableList.of(page1));
+
+ when(stub.partitionQueryPagedCallable()).thenReturn(callable);
+
+ when(ff.getFirestoreStub(any())).thenReturn(stub);
+ RpcQosOptions options = RpcQosOptions.defaultOptions();
+ when(ff.getRpcQos(any()))
+ .thenReturn(FirestoreStatefulComponentFactory.INSTANCE.getRpcQos(options));
+
+ ArgumentCaptor<PartitionQueryPair> responses =
+ ArgumentCaptor.forClass(PartitionQueryPair.class);
+
+ doNothing().when(processContext).output(responses.capture());
+
+ when(processContext.element()).thenReturn(request1);
+
+ PartitionQueryFn fn = new PartitionQueryFn(clock, ff, options);
+
+ runFunction(fn);
+
+ List<PartitionQueryPair> expected = newArrayList(new PartitionQueryPair(request1, response1));
+ List<PartitionQueryPair> allValues = responses.getAllValues();
+ assertEquals(expected, allValues);
+ }
+
@Override
public void resumeFromLastReadValue() throws Exception {
when(ff.getFirestoreStub(any())).thenReturn(stub);
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
index 25ed63c..ed789da 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/PartitionQueryResponseToRunQueryRequestTest.java
@@ -121,6 +121,39 @@
assertEquals(expectedQueries, actualQueries);
}
+ @Test
+ public void ensureCursorPairingWorks_emptyCursorsInResponse() {
+ StructuredQuery query =
+ StructuredQuery.newBuilder()
+ .addFrom(
+ CollectionSelector.newBuilder()
+ .setAllDescendants(true)
+ .setCollectionId("c1")
+ .build())
+ .build();
+
+ List<StructuredQuery> expectedQueries = newArrayList(query);
+
+ PartitionQueryPair partitionQueryPair =
+ new PartitionQueryPair(
+ PartitionQueryRequest.newBuilder().setStructuredQuery(query).build(),
+ PartitionQueryResponse.newBuilder().build());
+
+ ArgumentCaptor<RunQueryRequest> captor = ArgumentCaptor.forClass(RunQueryRequest.class);
+ when(processContext.element()).thenReturn(partitionQueryPair);
+ doNothing().when(processContext).output(captor.capture());
+
+ PartitionQueryResponseToRunQueryRequest fn = new PartitionQueryResponseToRunQueryRequest();
+ fn.processElement(processContext);
+
+ List<StructuredQuery> actualQueries =
+ captor.getAllValues().stream()
+ .map(RunQueryRequest::getStructuredQuery)
+ .collect(Collectors.toList());
+
+ assertEquals(expectedQueries, actualQueries);
+ }
+
private static Cursor referenceValueCursor(String referenceValue) {
return Cursor.newBuilder()
.addValues(Value.newBuilder().setReferenceValue(referenceValue).build())
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
index f34ebb6..5a31f4f 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/OffsetByteRangeTrackerTest.java
@@ -49,7 +49,7 @@
private static final double IGNORED_FRACTION = -10000000.0;
private static final long MIN_BYTES = 1000;
private static final OffsetRange RANGE = new OffsetRange(123L, Long.MAX_VALUE);
- private final TopicBacklogReader reader = mock(TopicBacklogReader.class);
+ private final TopicBacklogReader unownedBacklogReader = mock(TopicBacklogReader.class);
@Spy Ticker ticker;
private OffsetByteRangeTracker tracker;
@@ -60,14 +60,18 @@
when(ticker.read()).thenReturn(0L);
tracker =
new OffsetByteRangeTracker(
- RANGE, reader, Stopwatch.createUnstarted(ticker), Duration.millis(500), MIN_BYTES);
+ OffsetByteRange.of(RANGE, 0),
+ unownedBacklogReader,
+ Stopwatch.createUnstarted(ticker),
+ Duration.millis(500),
+ MIN_BYTES);
}
@Test
public void progressTracked() {
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(123), 10)));
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(124), 11)));
- when(reader.computeMessageStats(Offset.of(125)))
+ when(unownedBacklogReader.computeMessageStats(Offset.of(125)))
.thenReturn(ComputeMessageStatsResponse.newBuilder().setMessageBytes(1000).build());
Progress progress = tracker.getProgress();
assertEquals(21, progress.getWorkCompleted(), .0001);
@@ -76,7 +80,7 @@
@Test
public void getProgressStatsFailure() {
- when(reader.computeMessageStats(Offset.of(123)))
+ when(unownedBacklogReader.computeMessageStats(Offset.of(123)))
.thenThrow(new CheckedApiException(Code.INTERNAL).underlying);
assertThrows(ApiException.class, tracker::getProgress);
}
@@ -86,11 +90,15 @@
public void claimSplitSuccess() {
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(1_000), MIN_BYTES)));
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(10_000), MIN_BYTES)));
- SplitResult<OffsetRange> splits = tracker.trySplit(IGNORED_FRACTION);
- assertEquals(RANGE.getFrom(), splits.getPrimary().getFrom());
- assertEquals(10_001, splits.getPrimary().getTo());
- assertEquals(10_001, splits.getResidual().getFrom());
- assertEquals(Long.MAX_VALUE, splits.getResidual().getTo());
+ SplitResult<OffsetByteRange> splits = tracker.trySplit(IGNORED_FRACTION);
+ OffsetByteRange primary = splits.getPrimary();
+ assertEquals(RANGE.getFrom(), primary.getRange().getFrom());
+ assertEquals(10_001, primary.getRange().getTo());
+ assertEquals(MIN_BYTES * 2, primary.getByteCount());
+ OffsetByteRange residual = splits.getResidual();
+ assertEquals(10_001, residual.getRange().getFrom());
+ assertEquals(Long.MAX_VALUE, residual.getRange().getTo());
+ assertEquals(0, residual.getByteCount());
assertEquals(splits.getPrimary(), tracker.currentRestriction());
tracker.checkDone();
assertNull(tracker.trySplit(IGNORED_FRACTION));
@@ -100,10 +108,10 @@
@SuppressWarnings({"dereference.of.nullable", "argument.type.incompatible"})
public void splitWithoutClaimEmpty() {
when(ticker.read()).thenReturn(100000000000000L);
- SplitResult<OffsetRange> splits = tracker.trySplit(IGNORED_FRACTION);
- assertEquals(RANGE.getFrom(), splits.getPrimary().getFrom());
- assertEquals(RANGE.getFrom(), splits.getPrimary().getTo());
- assertEquals(RANGE, splits.getResidual());
+ SplitResult<OffsetByteRange> splits = tracker.trySplit(IGNORED_FRACTION);
+ assertEquals(RANGE.getFrom(), splits.getPrimary().getRange().getFrom());
+ assertEquals(RANGE.getFrom(), splits.getPrimary().getRange().getTo());
+ assertEquals(RANGE, splits.getResidual().getRange());
assertEquals(splits.getPrimary(), tracker.currentRestriction());
tracker.checkDone();
assertNull(tracker.trySplit(IGNORED_FRACTION));
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
index 598037e..0a4e3e7 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/PerSubscriptionPartitionSdfTest.java
@@ -28,6 +28,7 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;
@@ -51,6 +52,8 @@
import org.apache.beam.sdk.transforms.SerializableBiFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.math.DoubleMath;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Test;
@@ -65,22 +68,24 @@
public class PerSubscriptionPartitionSdfTest {
private static final Duration MAX_SLEEP_TIME =
Duration.standardMinutes(10).plus(Duration.millis(10));
- private static final OffsetRange RESTRICTION = new OffsetRange(1, Long.MAX_VALUE);
+ private static final OffsetByteRange RESTRICTION =
+ OffsetByteRange.of(new OffsetRange(1, Long.MAX_VALUE), 0);
private static final SubscriptionPartition PARTITION =
SubscriptionPartition.of(example(SubscriptionPath.class), example(Partition.class));
@Mock SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory;
+ @Mock ManagedBacklogReaderFactory backlogReaderFactory;
+ @Mock TopicBacklogReader backlogReader;
+
@Mock
- SerializableBiFunction<
- SubscriptionPartition, OffsetRange, RestrictionTracker<OffsetRange, OffsetByteProgress>>
- trackerFactory;
+ SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress> trackerFactory;
@Mock SubscriptionPartitionProcessorFactory processorFactory;
@Mock SerializableFunction<SubscriptionPartition, Committer> committerFactory;
@Mock InitialOffsetReader initialOffsetReader;
- @Spy RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+ @Spy TrackerWithProgress tracker;
@Mock OutputReceiver<SequencedMessage> output;
@Mock SubscriptionPartitionProcessor processor;
@@ -98,9 +103,11 @@
when(trackerFactory.apply(any(), any())).thenReturn(tracker);
when(committerFactory.apply(any())).thenReturn(committer);
when(tracker.currentRestriction()).thenReturn(RESTRICTION);
+ when(backlogReaderFactory.newReader(any())).thenReturn(backlogReader);
sdf =
new PerSubscriptionPartitionSdf(
MAX_SLEEP_TIME,
+ backlogReaderFactory,
offsetReaderFactory,
trackerFactory,
processorFactory,
@@ -110,9 +117,10 @@
@Test
public void getInitialRestrictionReadSuccess() {
when(initialOffsetReader.read()).thenReturn(example(Offset.class));
- OffsetRange range = sdf.getInitialRestriction(PARTITION);
- assertEquals(example(Offset.class).value(), range.getFrom());
- assertEquals(Long.MAX_VALUE, range.getTo());
+ OffsetByteRange range = sdf.getInitialRestriction(PARTITION);
+ assertEquals(example(Offset.class).value(), range.getRange().getFrom());
+ assertEquals(Long.MAX_VALUE, range.getRange().getTo());
+ assertEquals(0, range.getByteCount());
verify(offsetReaderFactory).apply(PARTITION);
}
@@ -125,7 +133,13 @@
@Test
public void newTrackerCallsFactory() {
assertSame(tracker, sdf.newTracker(PARTITION, RESTRICTION));
- verify(trackerFactory).apply(PARTITION, RESTRICTION);
+ verify(trackerFactory).apply(backlogReader, RESTRICTION);
+ }
+
+ @Test
+ public void tearDownClosesBacklogReaderFactory() {
+ sdf.teardown();
+ verify(backlogReaderFactory).close();
}
@Test
@@ -159,12 +173,48 @@
order2.verify(committer).awaitTerminated();
}
+ private static final class NoopManagedBacklogReaderFactory
+ implements ManagedBacklogReaderFactory {
+ @Override
+ public TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
+ return null;
+ }
+
+ @Override
+ public void close() {}
+ }
+
@Test
@SuppressWarnings("return.type.incompatible")
public void dofnIsSerializable() throws Exception {
ObjectOutputStream output = new ObjectOutputStream(new ByteArrayOutputStream());
output.writeObject(
new PerSubscriptionPartitionSdf(
- MAX_SLEEP_TIME, x -> null, (x, y) -> null, (x, y, z) -> null, (x) -> null));
+ MAX_SLEEP_TIME,
+ new NoopManagedBacklogReaderFactory(),
+ x -> null,
+ (x, y) -> null,
+ (x, y, z) -> null,
+ (x) -> null));
+ }
+
+ @Test
+ public void getProgressUnboundedRangeDelegates() {
+ Progress progress = Progress.from(0, 0.2);
+ when(tracker.getProgress()).thenReturn(progress);
+ assertTrue(
+ DoubleMath.fuzzyEquals(
+ progress.getWorkRemaining(), sdf.getSize(PARTITION, RESTRICTION), .0001));
+ verify(tracker).getProgress();
+ }
+
+ @Test
+ public void getProgressBoundedReturnsBytes() {
+ assertTrue(
+ DoubleMath.fuzzyEquals(
+ 123.0,
+ sdf.getSize(PARTITION, OffsetByteRange.of(new OffsetRange(87, 8000), 123)),
+ .0001));
+ verifyNoInteractions(tracker);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java
new file mode 100644
index 0000000..e242942
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/ReadWriteIT.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsublite;
+
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static org.junit.Assert.fail;
+
+import com.google.cloud.pubsublite.AdminClient;
+import com.google.cloud.pubsublite.AdminClientSettings;
+import com.google.cloud.pubsublite.BacklogLocation;
+import com.google.cloud.pubsublite.CloudZone;
+import com.google.cloud.pubsublite.Message;
+import com.google.cloud.pubsublite.ProjectId;
+import com.google.cloud.pubsublite.SubscriptionName;
+import com.google.cloud.pubsublite.SubscriptionPath;
+import com.google.cloud.pubsublite.TopicName;
+import com.google.cloud.pubsublite.TopicPath;
+import com.google.cloud.pubsublite.proto.PubSubMessage;
+import com.google.cloud.pubsublite.proto.SequencedMessage;
+import com.google.cloud.pubsublite.proto.Subscription;
+import com.google.cloud.pubsublite.proto.Subscription.DeliveryConfig.DeliveryRequirement;
+import com.google.cloud.pubsublite.proto.Topic;
+import com.google.cloud.pubsublite.proto.Topic.PartitionConfig.Capacity;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import com.google.protobuf.ByteString;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.FlatMapElements;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Duration;
+import org.junit.After;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@RunWith(JUnit4.class)
+public class ReadWriteIT {
+ private static final Logger LOG = LoggerFactory.getLogger(ReadWriteIT.class);
+ private static final CloudZone ZONE = CloudZone.parse("us-central1-b");
+ private static final int MESSAGE_COUNT = 90;
+
+ @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+ private static ProjectId getProject(PipelineOptions options) {
+ return ProjectId.of(checkArgumentNotNull(options.as(GcpOptions.class).getProject()));
+ }
+
+ private static String randomName() {
+ return "beam_it_resource_" + ThreadLocalRandom.current().nextLong();
+ }
+
+ private static AdminClient newAdminClient() {
+ return AdminClient.create(AdminClientSettings.newBuilder().setRegion(ZONE.region()).build());
+ }
+
+ private final Deque<Runnable> cleanupActions = new ArrayDeque<>();
+
+ private TopicPath createTopic(ProjectId id) throws Exception {
+ TopicPath toReturn =
+ TopicPath.newBuilder()
+ .setProject(id)
+ .setLocation(ZONE)
+ .setName(TopicName.of(randomName()))
+ .build();
+ Topic.Builder topic = Topic.newBuilder().setName(toReturn.toString());
+ topic
+ .getPartitionConfigBuilder()
+ .setCount(2)
+ .setCapacity(Capacity.newBuilder().setPublishMibPerSec(4).setSubscribeMibPerSec(4));
+ topic.getRetentionConfigBuilder().setPerPartitionBytes(30 * (1L << 30));
+ cleanupActions.addLast(
+ () -> {
+ try (AdminClient client = newAdminClient()) {
+ client.deleteTopic(toReturn).get();
+ } catch (Throwable t) {
+ LOG.error("Failed to clean up topic.", t);
+ }
+ });
+ try (AdminClient client = newAdminClient()) {
+ client.createTopic(topic.build()).get();
+ }
+ return toReturn;
+ }
+
+ private SubscriptionPath createSubscription(TopicPath topic) throws Exception {
+ SubscriptionPath toReturn =
+ SubscriptionPath.newBuilder()
+ .setProject(topic.project())
+ .setLocation(ZONE)
+ .setName(SubscriptionName.of(randomName()))
+ .build();
+ Subscription.Builder subscription = Subscription.newBuilder().setName(toReturn.toString());
+ subscription
+ .getDeliveryConfigBuilder()
+ .setDeliveryRequirement(DeliveryRequirement.DELIVER_IMMEDIATELY);
+ subscription.setTopic(topic.toString());
+ cleanupActions.addLast(
+ () -> {
+ try (AdminClient client = newAdminClient()) {
+ client.deleteSubscription(toReturn).get();
+ } catch (Throwable t) {
+ LOG.error("Failed to clean up subscription.", t);
+ }
+ });
+ try (AdminClient client = newAdminClient()) {
+ client.createSubscription(subscription.build(), BacklogLocation.BEGINNING).get();
+ }
+ return toReturn;
+ }
+
+ @After
+ public void tearDown() {
+ while (!cleanupActions.isEmpty()) {
+ cleanupActions.removeLast().run();
+ }
+ }
+
+ // Workaround for BEAM-12867
+ // TODO(BEAM-12867): Remove this.
+ private static class CustomCreate extends PTransform<PCollection<Void>, PCollection<Integer>> {
+ @Override
+ public PCollection<Integer> expand(PCollection<Void> input) {
+ return input.apply(
+ "createIndexes",
+ FlatMapElements.via(
+ new SimpleFunction<Void, Iterable<Integer>>() {
+ @Override
+ public Iterable<Integer> apply(Void input) {
+ return IntStream.range(0, MESSAGE_COUNT).boxed().collect(Collectors.toList());
+ }
+ }));
+ }
+ }
+
+ public static void writeMessages(TopicPath topicPath, Pipeline pipeline) {
+ PCollection<Void> trigger = pipeline.apply(Create.of((Void) null));
+ PCollection<Integer> indexes = trigger.apply("createIndexes", new CustomCreate());
+ PCollection<PubSubMessage> messages =
+ indexes.apply(
+ "createMessages",
+ MapElements.via(
+ new SimpleFunction<Integer, PubSubMessage>(
+ index ->
+ Message.builder()
+ .setData(ByteString.copyFromUtf8(index.toString()))
+ .build()
+ .toProto()) {}));
+ // Add UUIDs to messages for later deduplication.
+ messages = messages.apply("addUuids", PubsubLiteIO.addUuids());
+ messages.apply(
+ "writeMessages",
+ PubsubLiteIO.write(PublisherOptions.newBuilder().setTopicPath(topicPath).build()));
+ }
+
+ public static PCollection<SequencedMessage> readMessages(
+ SubscriptionPath subscriptionPath, Pipeline pipeline) {
+ PCollection<SequencedMessage> messages =
+ pipeline.apply(
+ "readMessages",
+ PubsubLiteIO.read(
+ SubscriberOptions.newBuilder()
+ .setSubscriptionPath(subscriptionPath)
+ // setMinBundleTimeout INTENDED FOR TESTING ONLY
+ // This sacrifices efficiency to make tests run faster. Do not use this in a
+ // real pipeline!
+ .setMinBundleTimeout(Duration.standardSeconds(5))
+ .build()));
+ // Deduplicate messages based on the uuids added in PubsubLiteIO.addUuids() when writing.
+ return messages.apply(
+ "dedupeMessages", PubsubLiteIO.deduplicate(UuidDeduplicationOptions.newBuilder().build()));
+ }
+
+ // This static out of band communication is needed to retain serializability.
+ @GuardedBy("ReadWriteIT.class")
+ private static final List<SequencedMessage> received = new ArrayList<>();
+
+ private static synchronized void addMessageReceived(SequencedMessage message) {
+ received.add(message);
+ }
+
+ private static synchronized List<SequencedMessage> getTestQuickstartReceived() {
+ return ImmutableList.copyOf(received);
+ }
+
+ private static PTransform<PCollection<? extends SequencedMessage>, PCollection<Void>>
+ collectTestQuickstart() {
+ return MapElements.via(
+ new SimpleFunction<SequencedMessage, Void>() {
+ @Override
+ public Void apply(SequencedMessage input) {
+ addMessageReceived(input);
+ return null;
+ }
+ });
+ }
+
+ @Test
+ public void testReadWrite() throws Exception {
+ pipeline.getOptions().as(StreamingOptions.class).setStreaming(true);
+ pipeline.getOptions().as(TestPipelineOptions.class).setBlockOnRun(false);
+
+ TopicPath topic = createTopic(getProject(pipeline.getOptions()));
+ SubscriptionPath subscription = createSubscription(topic);
+
+ // Publish some messages
+ writeMessages(topic, pipeline);
+
+ // Read some messages. They should be deduplicated by the time we see them, so there should be
+ // exactly MESSAGE_COUNT messages, one for every index in [0, MESSAGE_COUNT).
+ PCollection<SequencedMessage> messages = readMessages(subscription, pipeline);
+ messages.apply("messageReceiver", collectTestQuickstart());
+ pipeline.run();
+ LOG.info("Running!");
+ for (int round = 0; round < 120; ++round) {
+ Thread.sleep(1000);
+ Map<Integer, Integer> receivedCounts = new HashMap<>();
+ for (SequencedMessage message : getTestQuickstartReceived()) {
+ int id = Integer.parseInt(message.getMessage().getData().toStringUtf8());
+ receivedCounts.put(id, receivedCounts.getOrDefault(id, 0) + 1);
+ }
+ LOG.info("Performing comparison round {}.\n", round);
+ boolean done = true;
+ List<Integer> missing = new ArrayList<>();
+ for (int id = 0; id < MESSAGE_COUNT; id++) {
+ int idCount = receivedCounts.getOrDefault(id, 0);
+ if (idCount == 0) {
+ missing.add(id);
+ done = false;
+ }
+ if (idCount > 1) {
+ fail(String.format("Failed to deduplicate message with id %s.", id));
+ }
+ }
+ LOG.info("Still messing messages: {}.\n", missing);
+ if (done) {
+ return;
+ }
+ }
+ fail(
+ String.format(
+ "Failed to receive all messages after 2 minutes. Received %s messages.",
+ getTestQuickstartReceived().size()));
+ }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
index dbf3b93..3d743758 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/SubscriptionPartitionProcessorImplTest.java
@@ -31,7 +31,6 @@
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;
-import com.google.api.core.ApiFutures;
import com.google.api.gax.rpc.ApiException;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.cloud.pubsublite.Offset;
@@ -40,7 +39,6 @@
import com.google.cloud.pubsublite.internal.wire.Subscriber;
import com.google.cloud.pubsublite.proto.Cursor;
import com.google.cloud.pubsublite.proto.FlowControlRequest;
-import com.google.cloud.pubsublite.proto.SeekRequest;
import com.google.cloud.pubsublite.proto.SequencedMessage;
import com.google.protobuf.util.Timestamps;
import java.util.List;
@@ -64,7 +62,7 @@
@RunWith(JUnit4.class)
@SuppressWarnings("initialization.fields.uninitialized")
public class SubscriptionPartitionProcessorImplTest {
- @Spy RestrictionTracker<OffsetRange, OffsetByteProgress> tracker;
+ @Spy RestrictionTracker<OffsetByteRange, OffsetByteProgress> tracker;
@Mock OutputReceiver<SequencedMessage> receiver;
@Mock Function<Consumer<List<SequencedMessage>>, Subscriber> subscriberFactory;
@@ -83,6 +81,10 @@
.build();
}
+ private OffsetByteRange initialRange() {
+ return OffsetByteRange.of(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
+ }
+
@Before
public void setUp() {
initMocks(this);
@@ -100,18 +102,11 @@
@Test
public void lifecycle() throws Exception {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+ when(tracker.currentRestriction()).thenReturn(initialRange());
processor.start();
verify(subscriber).startAsync();
verify(subscriber).awaitRunning();
verify(subscriber)
- .seek(
- SeekRequest.newBuilder()
- .setCursor(Cursor.newBuilder().setOffset(example(Offset.class).value()))
- .build());
- verify(subscriber)
.allowFlow(
FlowControlRequest.newBuilder()
.setAllowedBytes(DEFAULT_FLOW_CONTROL.bytesOutstanding())
@@ -123,29 +118,15 @@
}
@Test
- public void lifecycleSeekThrows() throws Exception {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any()))
- .thenReturn(ApiFutures.immediateFailedFuture(new CheckedApiException(Code.OUT_OF_RANGE)));
+ public void lifecycleFlowControlThrows() throws Exception {
+ when(tracker.currentRestriction()).thenReturn(initialRange());
doThrow(new CheckedApiException(Code.OUT_OF_RANGE)).when(subscriber).allowFlow(any());
assertThrows(CheckedApiException.class, () -> processor.start());
}
@Test
- public void lifecycleFlowControlThrows() {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any()))
- .thenReturn(ApiFutures.immediateFailedFuture(new CheckedApiException(Code.OUT_OF_RANGE)));
- assertThrows(CheckedApiException.class, () -> processor.start());
- }
-
- @Test
public void lifecycleSubscriberAwaitThrows() throws Exception {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+ when(tracker.currentRestriction()).thenReturn(initialRange());
processor.start();
doThrow(new CheckedApiException(Code.INTERNAL).underlying).when(subscriber).awaitTerminated();
assertThrows(ApiException.class, () -> processor.close());
@@ -155,21 +136,19 @@
@Test
public void subscriberFailureFails() throws Exception {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+ when(tracker.currentRestriction()).thenReturn(initialRange());
processor.start();
subscriber.fail(new CheckedApiException(Code.OUT_OF_RANGE));
ApiException e =
- assertThrows(ApiException.class, () -> processor.waitForCompletion(Duration.ZERO));
+ assertThrows(
+ // Longer wait is needed due to listener asynchrony.
+ ApiException.class, () -> processor.waitForCompletion(Duration.standardSeconds(1)));
assertEquals(Code.OUT_OF_RANGE, e.getStatusCode().getCode());
}
@Test
public void allowFlowFailureFails() throws Exception {
- when(tracker.currentRestriction())
- .thenReturn(new OffsetRange(example(Offset.class).value(), Long.MAX_VALUE));
- when(subscriber.seek(any())).thenReturn(ApiFutures.immediateFuture(example(Offset.class)));
+ when(tracker.currentRestriction()).thenReturn(initialRange());
processor.start();
when(tracker.tryClaim(any())).thenReturn(true);
doThrow(new CheckedApiException(Code.OUT_OF_RANGE)).when(subscriber).allowFlow(any());
diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle
index 69cd88f..3f9d0c7 100644
--- a/sdks/java/io/jms/build.gradle
+++ b/sdks/java/io/jms/build.gradle
@@ -36,6 +36,7 @@
testCompile library.java.activemq_kahadb_store
testCompile library.java.activemq_client
testCompile library.java.junit
+ testCompile library.java.mockito_core
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
new file mode 100644
index 0000000..0e023d1
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.io.UnboundedSource;
+
+/**
+ * Enables users to specify their own JMS backlog reporters, enabling {@link JmsIO} to report
+ * {@link UnboundedSource.UnboundedReader#getTotalBacklogBytes()}.
+ */
+public interface AutoScaler extends Serializable {
+
+ /** The {@link AutoScaler} is started when the {@link JmsIO.UnboundedJmsReader} is started. */
+ void start();
+
+ /**
+ * Returns the size of the backlog of unread data in the underlying data source represented by all
+ * splits of this source.
+ */
+ long getTotalBacklogBytes();
+
+ /** The {@link AutoScaler} is stopped when the {@link JmsIO.UnboundedJmsReader} is closed. */
+ void stop();
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
new file mode 100644
index 0000000..2b05cf6
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+
+/**
+ * Default implementation of {@link AutoScaler}. Returns {@link
+ * org.apache.beam.sdk.io.UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} as the default value.
+ */
+public class DefaultAutoscaler implements AutoScaler {
+ @Override
+ public void start() {}
+
+ @Override
+ public long getTotalBacklogBytes() {
+ return BACKLOG_UNKNOWN;
+ }
+
+ @Override
+ public void stop() {}
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index 4999e10..9fa4492 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -196,6 +196,8 @@
abstract @Nullable Coder<T> getCoder();
+ abstract @Nullable AutoScaler getAutoScaler();
+
abstract Builder<T> builder();
@AutoValue.Builder
@@ -218,6 +220,8 @@
abstract Builder<T> setCoder(Coder<T> coder);
+ abstract Builder<T> setAutoScaler(AutoScaler autoScaler);
+
abstract Read<T> build();
}
@@ -344,6 +348,14 @@
return builder().setCoder(coder).build();
}
+ /**
+ * Sets the {@link AutoScaler} to use for reporting backlog during the execution of this source.
+ */
+ public Read<T> withAutoScaler(AutoScaler autoScaler) {
+ checkArgument(autoScaler != null, "autoScaler cannot be null");
+ return builder().setAutoScaler(autoScaler).build();
+ }
+
@Override
public PCollection<T> expand(PBegin input) {
checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
@@ -447,6 +459,7 @@
private Connection connection;
private Session session;
private MessageConsumer consumer;
+ private AutoScaler autoScaler;
private T currentMessage;
private Instant currentTimestamp;
@@ -474,6 +487,12 @@
}
connection.start();
this.connection = connection;
+ if (spec.getAutoScaler() == null) {
+ this.autoScaler = new DefaultAutoscaler();
+ } else {
+ this.autoScaler = spec.getAutoScaler();
+ }
+ this.autoScaler.start();
} catch (Exception e) {
throw new IOException("Error connecting to JMS", e);
}
@@ -545,6 +564,11 @@
}
@Override
+ public long getTotalBacklogBytes() {
+ return this.autoScaler.getTotalBacklogBytes();
+ }
+
+ @Override
public UnboundedSource<T, ?> getCurrentSource() {
return source;
}
@@ -565,6 +589,10 @@
connection.close();
connection = null;
}
+ if (autoScaler != null) {
+ autoScaler.stop();
+ autoScaler = null;
+ }
} catch (Exception e) {
throw new IOException(e);
}
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
index c335f8a..a9f3c3f 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
@@ -17,12 +17,17 @@
*/
package org.apache.beam.sdk.io.jms;
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import java.io.IOException;
import java.lang.reflect.Proxy;
@@ -421,6 +426,50 @@
CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark);
}
+ @Test
+ public void testDefaultAutoscaler() throws IOException {
+ JmsIO.Read spec =
+ JmsIO.read()
+ .withConnectionFactory(connectionFactory)
+ .withUsername(USERNAME)
+ .withPassword(PASSWORD)
+ .withQueue(QUEUE);
+ JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+ JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+ // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+ reader.start();
+ assertEquals(BACKLOG_UNKNOWN, reader.getSplitBacklogBytes());
+ assertEquals(BACKLOG_UNKNOWN, reader.getTotalBacklogBytes());
+ reader.close();
+ }
+
+ @Test
+ public void testCustomAutoscaler() throws IOException {
+ long expectedTotalBacklogBytes = 1111L;
+
+ AutoScaler autoScaler = mock(DefaultAutoscaler.class);
+ when(autoScaler.getTotalBacklogBytes()).thenReturn(expectedTotalBacklogBytes);
+ JmsIO.Read spec =
+ JmsIO.read()
+ .withConnectionFactory(connectionFactory)
+ .withUsername(USERNAME)
+ .withPassword(PASSWORD)
+ .withQueue(QUEUE)
+ .withAutoScaler(autoScaler);
+
+ JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+ JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+ // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+ reader.start();
+ verify(autoScaler, times(1)).start();
+ assertEquals(expectedTotalBacklogBytes, reader.getTotalBacklogBytes());
+ verify(autoScaler, times(1)).getTotalBacklogBytes();
+ reader.close();
+ verify(autoScaler, times(1)).stop();
+ }
+
private int count(String queue) throws Exception {
Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
connection.start();
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 370110d..05f7a9d 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -531,6 +531,13 @@
return coder_impl.MapCoderImpl(
self._key_coder.get_impl(), self._value_coder.get_impl())
+ @classmethod
+ def from_type_hint(cls, typehint, registry):
+ # type: (typehints.DictConstraint, CoderRegistry) -> MapCoder
+ return cls(
+ registry.get_coder(typehint.key_type),
+ registry.get_coder(typehint.value_type))
+
def to_type_hint(self):
return typehints.Dict[self._key_coder.to_type_hint(),
self._value_coder.to_type_hint()]
diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py
index 8b13b46..4d67f8b 100644
--- a/sdks/python/apache_beam/coders/row_coder.py
+++ b/sdks/python/apache_beam/coders/row_coder.py
@@ -87,10 +87,10 @@
def from_runner_api_parameter(schema, components, unused_context):
return RowCoder(schema)
- @staticmethod
- def from_type_hint(type_hint, registry):
+ @classmethod
+ def from_type_hint(cls, type_hint, registry):
schema = schema_from_element_type(type_hint)
- return RowCoder(schema)
+ return cls(schema)
@staticmethod
def from_payload(payload):
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 4fe5f3b..03b6cee 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -87,12 +87,14 @@
def register_standard_coders(self, fallback_coder):
"""Register coders for all basic and composite types."""
+ # Coders without subclasses.
self._register_coder_internal(int, coders.VarIntCoder)
self._register_coder_internal(float, coders.FloatCoder)
self._register_coder_internal(bytes, coders.BytesCoder)
self._register_coder_internal(bool, coders.BooleanCoder)
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
+ self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [coders.ProtoCoder, coders.FastPrimitivesCoder]
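
With `DictConstraint` registered, `Dict[K, V]` type hints now resolve to a `MapCoder` built from recursively resolved key and value coders instead of falling through to the fallback coders. A sketch of the lookup, assuming the standard registrations above:

```python
from apache_beam import coders
from apache_beam.typehints import typehints

# Resolves via MapCoder.from_type_hint rather than FastPrimitivesCoder.
coder = coders.registry.get_coder(typehints.Dict[str, int])
# Expected: a MapCoder wrapping StrUtf8Coder and VarIntCoder.
print(coder)
```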
diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 638c639..35de4aa 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -1593,10 +1593,12 @@
rows with transient errors (e.g. timeouts). Rows with permanent errors
will be output to dead letter queue under `'FailedRows'` tag.
- additional_bq_parameters (callable): A function that returns a dictionary
- with additional parameters to pass to BQ when creating / loading data
- into a table. These can be 'timePartitioning', 'clustering', etc. They
- are passed directly to the job load configuration. See
+ additional_bq_parameters (dict, callable): Additional parameters to pass
+ to BQ when creating / loading data into a table. If a callable, it
+ should be a function that receives a table reference indicating
+ the destination and returns a dictionary.
+ These can be 'timePartitioning', 'clustering', etc. They are passed
+ directly to the job load configuration. See
https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#jobconfigurationload
table_side_inputs (tuple): A tuple with ``AsSideInput`` PCollections to be
passed to the table callable (if one is provided).
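
Both accepted forms of the clarified parameter, sketched with placeholder table and field names (not taken from this PR):

```python
import apache_beam as beam

# Static form: one dict applied to every destination table.
write_static = beam.io.WriteToBigQuery(
    'my-project:my_dataset.my_table',  # placeholder destination
    additional_bq_parameters={'timePartitioning': {'type': 'DAY'}})

# Callable form: receives the destination table reference and returns the
# parameter dict, enabling per-destination settings.
write_per_table = beam.io.WriteToBigQuery(
    'my-project:my_dataset.my_table',
    additional_bq_parameters=lambda table: {
        'clustering': {'fields': ['customer_id']}  # placeholder field
    })
```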
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
index bb8c5b8..d50b3a8 100644
--- a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
@@ -302,7 +302,7 @@
return snapshot_options
-@with_input_types(ReadOperation, typing.Dict[typing.Any, typing.Any])
+@with_input_types(ReadOperation, _SPANNER_TRANSACTION)
@with_output_types(typing.List[typing.Any])
class _NaiveSpannerReadDoFn(DoFn):
def __init__(self, spanner_configuration):
@@ -422,7 +422,7 @@
@with_input_types(int)
-@with_output_types(typing.Dict[typing.Any, typing.Any])
+@with_output_types(_SPANNER_TRANSACTION)
class _CreateTransactionFn(DoFn):
"""
A DoFn to create the transaction of cloud spanner.
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
index 4cdf294..e3d1965 100644
--- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
@@ -19,6 +19,7 @@
import logging
import random
import string
+import typing
import unittest
import mock
@@ -336,7 +337,10 @@
self, mock_batch_snapshot_class, mock_client_class):
with self.assertRaises(ValueError):
p = TestPipeline()
- transaction = (p | beam.Create([{"invalid": "transaction"}]))
+ transaction = (
+ p | beam.Create([{
+ "invalid": "transaction"
+ }]).with_output_types(typing.Any))
_ = (
p | 'with query' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index d0aaa2f..287ec49 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -300,6 +300,7 @@
original_transform_node.full_label,
original_transform_node.main_inputs)
+ # TODO(BEAM-12854): Merge rather than override.
replacement_transform_node.resource_hints = (
original_transform_node.resource_hints)
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index bbaf52c..29a6016 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -508,6 +508,13 @@
# in the proto representation of the graph.
pipeline.replace_all(DataflowRunner._NON_PORTABLE_PTRANSFORM_OVERRIDES)
+    # Always upload the graph out-of-band when explicitly using runner v2 with
+    # use_portable_job_submission, to avoid graph size limits that are
+    # irrelevant when the graph is not embedded in the job request.
+ if (apiclient._use_unified_worker(debug_options) and
+ debug_options.lookup_experiment('use_portable_job_submission') and
+ not debug_options.lookup_experiment('upload_graph')):
+ debug_options.add_experiment("upload_graph")
+
# Add setup_options for all the BeamPlugin imports
setup_options = options.view_as(SetupOptions)
plugins = BeamPlugin.get_all_plugin_paths()
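
For reference, the new branch is exercised by options along these lines (project and region values are placeholders); `upload_graph` is then appended automatically, so the job graph is uploaded out-of-band rather than embedded in the job request:

```python
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions([
    '--runner=DataflowRunner',
    '--project=my-project',      # placeholder
    '--region=us-central1',      # placeholder
    '--experiments=use_runner_v2',
    '--experiments=use_portable_job_submission',
])
```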
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 1c02f71..98d2faa 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -185,9 +185,12 @@
@ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
class CombinePerKeyTransform(ptransform.PTransform):
def expand(self, pcoll):
- return pcoll \
- | beam.CombinePerKey(sum).with_output_types(
- typing.Tuple[str, int])
+ output = pcoll \
+ | beam.CombinePerKey(sum)
+ # TODO: Use `with_output_types` instead of explicitly
+ # assigning to `.element_type` after fixing BEAM-12872
+ output.element_type = beam.typehints.Tuple[str, int]
+ return output
def to_runner_api_parameter(self, unused_context):
return TEST_COMPK_URN, None
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
index bebde98..9b0d6ff 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
@@ -63,12 +63,16 @@
if not os.path.exists(self._executable_jar):
parsed = urllib.parse.urlparse(self._executable_jar)
if not parsed.scheme:
+ try:
+ flink_version = self.flink_version()
+ except Exception:
+ flink_version = '$FLINK_VERSION'
raise ValueError(
'Unable to parse jar URL "%s". If using a full URL, make sure '
'the scheme is specified. If using a local file path, make sure '
'the file exists; you may have to first build the job server '
'using `./gradlew runners:flink:%s:job-server:shadowJar`.' %
- (self._executable_jar, self._flink_version))
+ (self._executable_jar, flink_version))
url = self._executable_jar
else:
url = job_server.JavaJarJobServer.path_to_beam_jar(
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
index 281e2d9..1294f46 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py
@@ -192,6 +192,36 @@
self.assertEqual(
options_proto['beam:option:unknown_option_foo:v1'], 'some_value')
+ @requests_mock.mock()
+ def test_bad_url_flink_version(self, http_mock):
+ http_mock.get('http://flink/v1/config', json={'flink-version': '1.2.3.4'})
+ options = pipeline_options.FlinkRunnerOptions()
+ options.flink_job_server_jar = "bad url"
+ job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
+ 'http://flink', options)
+ with self.assertRaises(ValueError) as context:
+ job_server.executable_jar()
+ self.assertEqual(
+ 'Unable to parse jar URL "bad url". If using a full URL, make sure '
+ 'the scheme is specified. If using a local file path, make sure '
+ 'the file exists; you may have to first build the job server '
+ 'using `./gradlew runners:flink:1.2:job-server:shadowJar`.',
+ str(context.exception))
+
+ def test_bad_url_placeholder_version(self):
+ options = pipeline_options.FlinkRunnerOptions()
+ options.flink_job_server_jar = "bad url"
+ job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
+ 'http://example.com/bad', options)
+ with self.assertRaises(ValueError) as context:
+ job_server.executable_jar()
+ self.assertEqual(
+ 'Unable to parse jar URL "bad url". If using a full URL, make sure '
+ 'the scheme is specified. If using a local file path, make sure '
+ 'the file exists; you may have to first build the job server '
+ 'using `./gradlew runners:flink:$FLINK_VERSION:job-server:shadowJar`.',
+ str(context.exception))
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index bcedd86..41ad3df 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -21,6 +21,7 @@
import copy
import heapq
+import itertools
import operator
import random
from typing import Any
@@ -168,7 +169,8 @@
"""Combiners for obtaining extremal elements."""
# pylint: disable=no-self-argument
-
+ @with_input_types(T)
+ @with_output_types(List[T])
class Of(CombinerWithoutDefaults):
"""Obtain a list of the compare-most N elements in a PCollection.
@@ -202,10 +204,12 @@
# This is a more efficient global algorithm.
top_per_bundle = pcoll | core.ParDo(
_TopPerBundle(self._n, self._key, self._reverse))
- # If pcoll is empty, we can't guerentee that top_per_bundle
+ # If pcoll is empty, we can't guarantee that top_per_bundle
# won't be empty, so inject at least one empty accumulator
- # so that downstream is guerenteed to produce non-empty output.
- empty_bundle = pcoll.pipeline | core.Create([(None, [])])
+ # so that downstream is guaranteed to produce non-empty output.
+ empty_bundle = (
+ pcoll.pipeline | core.Create([(None, [])]).with_output_types(
+ top_per_bundle.element_type))
return ((top_per_bundle, empty_bundle) | core.Flatten()
| core.GroupByKey()
| core.ParDo(
@@ -219,6 +223,8 @@
TopCombineFn(self._n, self._key,
self._reverse)).without_defaults()
+ @with_input_types(Tuple[K, V])
+ @with_output_types(Tuple[K, List[V]])
class PerKey(ptransform.PTransform):
"""Identifies the compare-most N elements associated with each key.
@@ -524,6 +530,8 @@
# pylint: disable=no-self-argument
+ @with_input_types(T)
+ @with_output_types(List[T])
class FixedSizeGlobally(CombinerWithoutDefaults):
"""Sample n elements from the input PCollection without replacement."""
def __init__(self, n):
@@ -543,6 +551,8 @@
def default_label(self):
return 'FixedSizeGlobally(%d)' % self._n
+ @with_input_types(Tuple[K, V])
+ @with_output_types(Tuple[K, List[V]])
class FixedSizePerKey(ptransform.PTransform):
"""Sample n elements associated with each key without replacement."""
def __init__(self, n):
@@ -597,16 +607,25 @@
class _TupleCombineFnBase(core.CombineFn):
- def __init__(self, *combiners):
+ def __init__(self, *combiners, merge_accumulators_batch_size=None):
self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
self._named_combiners = combiners
+ # If the `merge_accumulators_batch_size` value is not specified, we choose a
+ # bounded default that is inversely proportional to the number of
+ # accumulators in merged tuples.
+ num_combiners = max(1, len(combiners))
+ self._merge_accumulators_batch_size = (
+ merge_accumulators_batch_size or max(10, 1000 // num_combiners))
def display_data(self):
combiners = [
c.__name__ if hasattr(c, '__name__') else c.__class__.__name__
for c in self._named_combiners
]
- return {'combiners': str(combiners)}
+ return {
+ 'combiners': str(combiners),
+ 'merge_accumulators_batch_size': self._merge_accumulators_batch_size
+ }
def setup(self, *args, **kwargs):
for c in self._combiners:
@@ -616,10 +635,22 @@
return [c.create_accumulator(*args, **kwargs) for c in self._combiners]
def merge_accumulators(self, accumulators, *args, **kwargs):
- return [
- c.merge_accumulators(a, *args, **kwargs) for c,
- a in zip(self._combiners, zip(*accumulators))
- ]
+ # Make sure that `accumulators` is an iterator (so that the position is
+ # remembered).
+ accumulators = iter(accumulators)
+ result = next(accumulators)
+ while True:
+ # Load accumulators into memory and merge in batches to decrease peak
+ # memory usage.
+ accumulators_batch = [result] + list(
+ itertools.islice(accumulators, self._merge_accumulators_batch_size))
+ if len(accumulators_batch) == 1:
+ break
+ result = [
+ c.merge_accumulators(a, *args, **kwargs) for c,
+ a in zip(self._combiners, zip(*accumulators_batch))
+ ]
+ return result
def compact(self, accumulator, *args, **kwargs):
return [
@@ -670,6 +701,8 @@
]
+@with_input_types(T)
+@with_output_types(List[T])
class ToList(CombinerWithoutDefaults):
"""A global CombineFn that condenses a PCollection into a single list."""
def expand(self, pcoll):
@@ -698,6 +731,8 @@
return accumulator
+@with_input_types(Tuple[K, V])
+@with_output_types(Dict[K, V])
class ToDict(CombinerWithoutDefaults):
"""A global CombineFn that condenses a PCollection into a single dict.
@@ -735,6 +770,8 @@
return accumulator
+@with_input_types(T)
+@with_output_types(Set[T])
class ToSet(CombinerWithoutDefaults):
"""A global CombineFn that condenses a PCollection into a set."""
def expand(self, pcoll):
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index d826287..7e0e835 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -249,7 +249,8 @@
dd = DisplayData.create_from(transform)
expected_items = [
DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
- DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']")
+ DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
+ DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -358,6 +359,49 @@
max).with_common_input()).without_defaults())
assert_that(result, equal_to([(1, 7.0 / 4, 3)]))
+ def test_empty_tuple_combine_fn(self):
+ with TestPipeline() as p:
+ result = (
+ p
+ | Create([(), (), ()])
+ | beam.CombineGlobally(combine.TupleCombineFn()))
+ assert_that(result, equal_to([()]))
+
+ def test_tuple_combine_fn_batched_merge(self):
+ num_combine_fns = 10
+ max_num_accumulators_in_memory = 30
+    # Subtract one so that the running merge result also fits within the
+    # in-memory accumulator budget.
+ merge_accumulators_batch_size = (
+ max_num_accumulators_in_memory // num_combine_fns - 1)
+ num_accumulator_tuples_to_merge = 20
+
+ class CountedAccumulator:
+ count = 0
+ oom = False
+
+ def __init__(self):
+ if CountedAccumulator.count > max_num_accumulators_in_memory:
+ CountedAccumulator.oom = True
+ else:
+ CountedAccumulator.count += 1
+
+ class CountedAccumulatorCombineFn(beam.CombineFn):
+ def create_accumulator(self):
+ return CountedAccumulator()
+
+ def merge_accumulators(self, accumulators):
+ CountedAccumulator.count += 1
+ for _ in accumulators:
+ CountedAccumulator.count -= 1
+
+ combine_fn = combine.TupleCombineFn(
+ *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)],
+ merge_accumulators_batch_size=merge_accumulators_batch_size)
+ combine_fn.merge_accumulators(
+ combine_fn.create_accumulator()
+ for _ in range(num_accumulator_tuples_to_merge))
+ assert not CountedAccumulator.oom
+
def test_to_list_and_to_dict1(self):
with TestPipeline() as pipeline:
the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index ec001e2..9ce33f5 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -2333,9 +2333,8 @@
self.assertStartswith(
e.exception.args[0],
- "Type hint violation for 'CombinePerKey(TopCombineFn)': "
- "requires Tuple[TypeVariable[K], TypeVariable[T]] "
- "but got {} for element".format(int))
+ "Input type hint violation at TopMod: expected Tuple[TypeVariable[K], "
+ "TypeVariable[V]], got {}".format(int))
def test_per_key_pipeline_checking_satisfied(self):
d = (
@@ -2492,10 +2491,8 @@
self.assertStartswith(
e.exception.args[0],
- "Type hint violation for 'CombinePerKey': "
- "requires "
- "Tuple[TypeVariable[K], Tuple[TypeVariable[K], TypeVariable[V]]] "
- "but got Tuple[None, int] for element")
+ "Input type hint violation at ToDict: expected Tuple[TypeVariable[K], "
+ "TypeVariable[V]], got {}".format(int))
def test_to_dict_pipeline_check_satisfied(self):
d = (
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 28024fa..63bbece 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -178,7 +178,8 @@
pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)}
restore_tags = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags))
- result = pcolls_dict | _CoGBKImpl(pipeline=self.pipeline)
+ result = (
+ pcolls_dict | 'CoGroupByKeyImpl' >> _CoGBKImpl(pipeline=self.pipeline))
if restore_tags:
return result | 'RestoreTags' >> MapTuple(
lambda k, vs: (k, restore_tags(vs)))
diff --git a/website/www/site/content/en/documentation/dsls/dataframes/overview.md b/website/www/site/content/en/documentation/dsls/dataframes/overview.md
index 294f925..e08c61b 100644
--- a/website/www/site/content/en/documentation/dsls/dataframes/overview.md
+++ b/website/www/site/content/en/documentation/dsls/dataframes/overview.md
@@ -45,7 +45,7 @@
In this example, the only traditional Beam type is the `Pipeline` instance. Otherwise the example is written completely with the DataFrame API. This is possible because the Beam DataFrame API includes its own IO operations (for example, `read_csv` and `to_csv`) based on the pandas native implementations. `read_*` and `to_*` operations support file patterns and any Beam-compatible file system. The grouping is accomplished with a group-by-key, and arbitrary pandas operations (in this case, `sum`) can be applied before the final write that occurs with `to_csv`.
-The Beam DataFrame API aims to be compatible with the native pandas implementation, with a few caveats detailed below in [Differences from standard pandas]({{< ref "#differences_from_standard_pandas" >}}).
+The Beam DataFrame API aims to be compatible with the native pandas implementation, with a few caveats detailed below in [Differences from standard pandas](/documentation/dsls/dataframes/differences-from-pandas/).
## Embedding DataFrames in a pipeline
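
For orientation, the shape of pipeline the paragraph above describes, with placeholder paths and column name (`read_csv`/`to_csv` are the pandas-based IO operations mentioned):

```python
import apache_beam as beam
from apache_beam.dataframe.io import read_csv

with beam.Pipeline() as p:
  # read_csv yields a deferred DataFrame; groupby/sum are applied lazily
  # and execute as Beam transforms when the pipeline runs.
  df = p | read_csv('gs://my-bucket/input*.csv')  # placeholder path
  agg = df.groupby('key').sum()  # 'key' is a placeholder column name
  agg.to_csv('gs://my-bucket/output')  # placeholder path
```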
diff --git a/website/www/site/content/en/documentation/transforms/python/overview.md b/website/www/site/content/en/documentation/transforms/python/overview.md
index b3ba546..71d5f1e 100644
--- a/website/www/site/content/en/documentation/transforms/python/overview.md
+++ b/website/www/site/content/en/documentation/transforms/python/overview.md
@@ -54,6 +54,7 @@
<tr><td><a href="/documentation/transforms/python/aggregation/count">Count</a></td><td>Counts the number of elements within each aggregation.</td></tr>
<tr><td><a href="/documentation/transforms/python/aggregation/distinct">Distinct</a></td><td>Produces a collection containing distinct elements from the input collection.</td></tr>
<tr><td><a href="/documentation/transforms/python/aggregation/groupbykey">GroupByKey</a></td><td>Takes a keyed collection of elements and produces a collection where each element consists of a key and all values associated with that key.</td></tr>
+ <tr><td><a href="/documentation/transforms/python/aggregation/groupby">GroupBy</a></td><td>Takes a collection of elements and produces a collection grouped by properties of those elements. Unlike GroupByKey, the key is dynamically created from the elements themselves.</td></tr>
<tr><td><a href="/documentation/transforms/python/aggregation/groupintobatches">GroupIntoBatches</a></td><td>Batches the input into desired batch size.</td></tr>
<tr><td><a href="/documentation/transforms/python/aggregation/latest">Latest</a></td><td>Gets the element with the latest timestamp.</td></tr>
<tr><td><a href="/documentation/transforms/python/aggregation/max">Max</a></td><td>Gets the element with the maximum value within each aggregation.</td></tr>