Merge pull request #10013: [BEAM-8554] Use WorkItemCommitRequest protobuf fields to signal that …
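
Summary (paraphrased from the hunks below): previously an int-overflowing serialized commit size
threw a KeyCommitTooLargeException, and an over-limit commit was reported but still sent in full
for the backend to truncate. With this change both cases are reported and the worker instead
commits a minimal request carrying two new protobuf fields, exceeds_max_work_item_commit_bytes
and estimated_work_item_commit_bytes. A condensed sketch of that minimal request, assembled from
the code in this patch (an illustration, not a verbatim excerpt):

    // Minimal truncation request emitted for an oversized commit; messages, timers,
    // counters, and other commit content are intentionally omitted.
    Windmill.WorkItemCommitRequest truncationRequest =
        Windmill.WorkItemCommitRequest.newBuilder()
            .setKey(key)
            .setShardingKey(workItem.getShardingKey())
            .setWorkToken(workItem.getWorkToken())
            .setCacheToken(workItem.getCacheToken())
            .setExceedsMaxWorkItemCommitBytes(true)
            .setEstimatedWorkItemCommitBytes(estimatedCommitSize)
            .build();
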
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index a603d16..4def350 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -1148,6 +1148,15 @@
}
}
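+ /** Returns a commit request builder pre-populated with the work item's identifying fields. */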
+ private Windmill.WorkItemCommitRequest.Builder initializeOutputBuilder(
+ final ByteString key, final Windmill.WorkItem workItem) {
+ return Windmill.WorkItemCommitRequest.newBuilder()
+ .setKey(key)
+ .setShardingKey(workItem.getShardingKey())
+ .setWorkToken(workItem.getWorkToken())
+ .setCacheToken(workItem.getCacheToken());
+ }
+
private void process(
final SdkWorkerHarness worker,
final ComputationState computationState,
@@ -1164,12 +1173,7 @@
DataflowWorkerLoggingMDC.setStageName(computationId);
LOG.debug("Starting processing for {}:\n{}", computationId, work);
- Windmill.WorkItemCommitRequest.Builder outputBuilder =
- Windmill.WorkItemCommitRequest.newBuilder()
- .setKey(key)
- .setShardingKey(workItem.getShardingKey())
- .setWorkToken(workItem.getWorkToken())
- .setCacheToken(workItem.getCacheToken());
+ Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
// Before any processing starts, call any pending OnCommit callbacks. Nothing that requires
// cleanup should be done before this, since we might exit early here.
@@ -1334,20 +1338,22 @@
WorkItemCommitRequest commitRequest = outputBuilder.build();
int byteLimit = maxWorkItemCommitBytes;
int commitSize = commitRequest.getSerializedSize();
+ int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize;
+
// Detect whether the integer serialized size overflowed or the byte limit was exceeded.
- windmillMaxObservedWorkItemCommitBytes.addValue(
- commitSize < 0 ? Integer.MAX_VALUE : commitSize);
- if (commitSize < 0) {
- throw KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
- } else if (commitSize > byteLimit) {
- // Once supported, we should communicate the desired truncation for the commit to the
- // streaming engine. For now we report the error but attempt the commit so that it will be
- // truncated by the streaming engine backend.
+ windmillMaxObservedWorkItemCommitBytes.addValue(estimatedCommitSize);
+ if (estimatedCommitSize > byteLimit) {
KeyCommitTooLargeException e =
KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
reportFailure(computationId, workItem, e);
LOG.error(e.toString());
+
+ // Drop the current request in favor of a new, minimal one requesting truncation.
+ // Messages, timers, counters, and other commit content will not be used by the service,
+ // so we're purposefully dropping them here.
+ commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
}
+
commitQueue.put(new Commit(commitRequest, computationState, work));
// Compute shuffle and state byte statistics; these will be flushed asynchronously.
@@ -1442,6 +1448,14 @@
}
}
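+ /**
+  * Builds a minimal commit request signaling that this work item's commit exceeded the byte
+  * limit; only the identifying fields and the truncation fields are set, since the oversized
+  * messages, timers, and counters are intentionally dropped.
+  */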
+ private WorkItemCommitRequest buildWorkItemTruncationRequest(
+ final ByteString key, final Windmill.WorkItem workItem, final int estimatedCommitSize) {
+ Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
+ outputBuilder.setExceedsMaxWorkItemCommitBytes(true);
+ outputBuilder.setEstimatedWorkItemCommitBytes(estimatedCommitSize);
+ return outputBuilder.build();
+ }
+
private void commitLoop() {
Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> computationRequestMap =
new HashMap<>();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 356dddb..2cb5ada 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -562,37 +562,70 @@
private WorkItemCommitRequest.Builder makeExpectedOutput(
int index, long timestamp, String key, String outKey) throws Exception {
+ StringBuilder expectedCommitRequestBuilder = initializeExpectedCommitRequest(key, index);
+ appendCommitOutputMessages(expectedCommitRequestBuilder, index, timestamp, outKey);
+
return setMessagesMetadata(
PaneInfo.NO_FIRING,
intervalWindowBytes(DEFAULT_WINDOW),
- parseCommitRequest(
- "key: \""
- + key
- + "\" "
- + "sharding_key: 17 "
- + "work_token: "
- + index
- + " "
- + "cache_token: 3 "
- + "output_messages {"
- + " destination_stream_id: \""
- + DEFAULT_DESTINATION_STREAM_ID
- + "\""
- + " bundles {"
- + " key: \""
- + outKey
- + "\""
- + " messages {"
- + " timestamp: "
- + timestamp
- + " data: \""
- + dataStringForIndex(index)
- + "\""
- + " metadata: \"\""
- + " }"
- + " messages_ids: \"\""
- + " }"
- + "}"));
+ parseCommitRequest(expectedCommitRequestBuilder.toString()));
+ }
+
+ private WorkItemCommitRequest.Builder makeExpectedTruncationRequestOutput(
+ int index, String key, long estimatedSize) throws Exception {
+ StringBuilder expectedCommitRequestBuilder = initializeExpectedCommitRequest(key, index);
+ appendCommitTruncationFields(expectedCommitRequestBuilder, estimatedSize);
+
+ return parseCommitRequest(expectedCommitRequestBuilder.toString());
+ }
+
+ private StringBuilder initializeExpectedCommitRequest(String key, int index) {
+ StringBuilder requestBuilder = new StringBuilder();
+
+ requestBuilder.append("key: \"");
+ requestBuilder.append(key);
+ requestBuilder.append("\" ");
+ requestBuilder.append("sharding_key: 17 ");
+ requestBuilder.append("work_token: ");
+ requestBuilder.append(index);
+ requestBuilder.append(" ");
+ requestBuilder.append("cache_token: 3 ");
+
+ return requestBuilder;
+ }
+
+ private StringBuilder appendCommitOutputMessages(
+ StringBuilder requestBuilder, int index, long timestamp, String outKey) {
+ requestBuilder.append("output_messages {");
+ requestBuilder.append(" destination_stream_id: \"");
+ requestBuilder.append(DEFAULT_DESTINATION_STREAM_ID);
+ requestBuilder.append("\"");
+ requestBuilder.append(" bundles {");
+ requestBuilder.append(" key: \"");
+ requestBuilder.append(outKey);
+ requestBuilder.append("\"");
+ requestBuilder.append(" messages {");
+ requestBuilder.append(" timestamp: ");
+ requestBuilder.append(timestamp);
+ requestBuilder.append(" data: \"");
+ requestBuilder.append(dataStringForIndex(index));
+ requestBuilder.append("\"");
+ requestBuilder.append(" metadata: \"\"");
+ requestBuilder.append(" }");
+ requestBuilder.append(" messages_ids: \"\"");
+ requestBuilder.append(" }");
+ requestBuilder.append("}");
+
+ return requestBuilder;
+ }
+
+ private StringBuilder appendCommitTruncationFields(
+ StringBuilder requestBuilder, long estimatedSize) {
+ requestBuilder.append("exceeds_max_work_item_commit_bytes: true ");
+ requestBuilder.append("estimated_work_item_commit_bytes: ");
+ requestBuilder.append(estimatedSize);
+
+ return requestBuilder;
}
private StreamingComputationConfig makeDefaultStreamingComputationConfig(
@@ -948,64 +981,19 @@
assertEquals(2, result.size());
assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
+
assertTrue(result.containsKey(1L));
- assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
- assertTrue(result.get(1L).getSerializedSize() > 1000);
+ WorkItemCommitRequest largeCommit = result.get(1L);
+ assertEquals("large_key", largeCommit.getKey().toStringUtf8());
+ assertEquals(
+ makeExpectedTruncationRequestOutput(
+ 1, "large_key", largeCommit.getEstimatedWorkItemCommitBytes())
+ .build(),
+ largeCommit);
- // Spam worker updates a few times.
- int maxTries = 10;
- while (--maxTries > 0) {
- worker.reportPeriodicWorkerUpdates();
- Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
- }
-
- // We should see an exception reported for the large commit but not the small one.
- ArgumentCaptor<WorkItemStatus> workItemStatusCaptor =
- ArgumentCaptor.forClass(WorkItemStatus.class);
- verify(mockWorkUnitClient, atLeast(2)).reportWorkItemStatus(workItemStatusCaptor.capture());
- List<WorkItemStatus> capturedStatuses = workItemStatusCaptor.getAllValues();
- boolean foundErrors = false;
- for (WorkItemStatus status : capturedStatuses) {
- if (!status.getErrors().isEmpty()) {
- assertFalse(foundErrors);
- foundErrors = true;
- String errorMessage = status.getErrors().get(0).getMessage();
- assertThat(errorMessage, Matchers.containsString("KeyCommitTooLargeException"));
- }
- }
- assertTrue(foundErrors);
- }
-
- @Test
- public void testKeyCommitTooLargeException_StreamingEngine() throws Exception {
- KvCoder<String, String> kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of());
-
- List<ParallelInstruction> instructions =
- Arrays.asList(
- makeSourceInstruction(kvCoder),
- makeDoFnInstruction(new LargeCommitFn(), 0, kvCoder),
- makeSinkInstruction(kvCoder, 1));
-
- FakeWindmillServer server = new FakeWindmillServer(errorCollector);
- server.setExpectedExceptionCount(1);
-
- StreamingDataflowWorkerOptions options =
- createTestingPipelineOptions(server, "--experiments=enable_streaming_engine");
- StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
- worker.setMaxWorkItemCommitBytes(1000);
- worker.start();
-
- server.addWorkToOffer(makeInput(1, 0, "large_key"));
- server.addWorkToOffer(makeInput(2, 0, "key"));
- server.waitForEmptyWorkQueue();
-
- Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(1);
-
- assertEquals(2, result.size());
- assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
- assertTrue(result.containsKey(1L));
- assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
- assertTrue(result.get(1L).getSerializedSize() > 1000);
+ // Check this explicitly, since the estimated commit bytes weren't actually
+ // checked against an expected value in the previous assertion.
+ assertTrue(largeCommit.getEstimatedWorkItemCommitBytes() > 1000);
// Spam worker updates a few times.
int maxTries = 10;
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 5310902..598164e 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -272,13 +272,16 @@
optional string state_family = 3;
}
-// next id: 19
+// next id: 24
message WorkItemCommitRequest {
required bytes key = 1;
required fixed64 work_token = 2;
optional fixed64 sharding_key = 15;
optional fixed64 cache_token = 16;
+ optional bool exceeds_max_work_item_commit_bytes = 20;
+ optional int64 estimated_work_item_commit_bytes = 21;
+
repeated OutputMessageBundle output_messages = 3;
repeated PubSubMessageBundle pubsub_messages = 7;
repeated Timer output_timers = 4;
@@ -290,12 +293,14 @@
optional SourceState source_state_updates = 12;
optional int64 source_watermark = 13 [default=-0x8000000000000000];
optional int64 source_backlog_bytes = 17 [default=-1];
+ optional int64 source_bytes_processed = 22;
+
repeated WatermarkHold watermark_holds = 14;
// DEPRECATED
repeated GlobalDataId global_data_id_requests = 9;
- reserved 6;
+ reserved 6, 19, 23;
}
message ComputationCommitWorkRequest {
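
For context, a hypothetical sketch of how a consumer of WorkItemCommitRequest might react to the
new fields. The service-side behavior is not part of this patch, so handleCommit, truncateCommit,
and applyCommit below are assumed names used purely for illustration; only the two getters
generated from the fields above come from this change:

    // Illustrative only: when the worker signals an oversized commit, the service is expected
    // to perform the truncation itself, using the reported size estimate for accounting.
    void handleCommit(Windmill.WorkItemCommitRequest request) {
      if (request.getExceedsMaxWorkItemCommitBytes()) {
        long estimatedBytes = request.getEstimatedWorkItemCommitBytes();
        truncateCommit( // assumed helper
            request.getKey(), request.getShardingKey(), request.getWorkToken(), estimatedBytes);
      } else {
        applyCommit(request); // assumed helper
      }
    }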