Merge pull request #10013: [BEAM-8554] Use WorkItemCommitRequest protobuf fields to signal that …

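When a work item's serialized commit exceeds the configured maximum commit bytes (or
overflows protobuf's int serialized size), the worker now reports a
KeyCommitTooLargeException and, instead of attempting the oversized commit, sends a
minimal WorkItemCommitRequest carrying only the work item's identifying fields plus
the new exceeds_max_work_item_commit_bytes and estimated_work_item_commit_bytes
fields, so the Streaming Engine backend can truncate the work item without receiving
the full payload. Roughly, the replacement commit looks like this (values are
illustrative):

    key: "large_key"
    sharding_key: 17
    work_token: 1
    cache_token: 3
    exceeds_max_work_item_commit_bytes: true
    estimated_work_item_commit_bytes: 2048
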
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index a603d16..4def350 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -1148,6 +1148,15 @@
     }
   }
 
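+  /**
+   * Returns a {@link Windmill.WorkItemCommitRequest.Builder} pre-populated with the fields that
+   * identify the work item: key, sharding key, work token, and cache token.
+   */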
+  private Windmill.WorkItemCommitRequest.Builder initializeOutputBuilder(
+      final ByteString key, final Windmill.WorkItem workItem) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(key)
+        .setShardingKey(workItem.getShardingKey())
+        .setWorkToken(workItem.getWorkToken())
+        .setCacheToken(workItem.getCacheToken());
+  }
+
   private void process(
       final SdkWorkerHarness worker,
       final ComputationState computationState,
@@ -1164,12 +1173,7 @@
     DataflowWorkerLoggingMDC.setStageName(computationId);
     LOG.debug("Starting processing for {}:\n{}", computationId, work);
 
-    Windmill.WorkItemCommitRequest.Builder outputBuilder =
-        Windmill.WorkItemCommitRequest.newBuilder()
-            .setKey(key)
-            .setShardingKey(workItem.getShardingKey())
-            .setWorkToken(workItem.getWorkToken())
-            .setCacheToken(workItem.getCacheToken());
+    Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
 
     // Before any processing starts, call any pending OnCommit callbacks.  Nothing that requires
     // cleanup should be done before this, since we might exit early here.
@@ -1334,20 +1338,22 @@
       WorkItemCommitRequest commitRequest = outputBuilder.build();
       int byteLimit = maxWorkItemCommitBytes;
       int commitSize = commitRequest.getSerializedSize();
+      int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize;
+
       // Detect overflow of integer serialized size or if the byte limit was exceeded.
-      windmillMaxObservedWorkItemCommitBytes.addValue(
-          commitSize < 0 ? Integer.MAX_VALUE : commitSize);
-      if (commitSize < 0) {
-        throw KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
-      } else if (commitSize > byteLimit) {
-        // Once supported, we should communicate the desired truncation for the commit to the
-        // streaming engine. For now we report the error but attempt the commit so that it will be
-        // truncated by the streaming engine backend.
+      windmillMaxObservedWorkItemCommitBytes.addValue(estimatedCommitSize);
+      if (estimatedCommitSize > byteLimit) {
         KeyCommitTooLargeException e =
             KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
         reportFailure(computationId, workItem, e);
         LOG.error(e.toString());
+
+        // Drop the current request in favor of a new, minimal one requesting truncation.
+        // Messages, timers, counters, and other commit content will not be used by the
+        // service, so we're purposefully dropping them here.
+        commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
       }
+
       commitQueue.put(new Commit(commitRequest, computationState, work));
 
      // Compute shuffle and state byte statistics; these will be flushed asynchronously.
@@ -1442,6 +1448,14 @@
     }
   }
 
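+  /**
+   * Builds a minimal commit request that carries no output, only the work item's identifying
+   * fields plus flags signaling that the original commit exceeded the maximum commit bytes, so
+   * the service can truncate the work item instead of receiving the oversized payload.
+   */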
+  private WorkItemCommitRequest buildWorkItemTruncationRequest(
+      final ByteString key, final Windmill.WorkItem workItem, final int estimatedCommitSize) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
+    outputBuilder.setExceedsMaxWorkItemCommitBytes(true);
+    outputBuilder.setEstimatedWorkItemCommitBytes(estimatedCommitSize);
+    return outputBuilder.build();
+  }
+
   private void commitLoop() {
     Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> computationRequestMap =
         new HashMap<>();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 356dddb..2cb5ada 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -562,37 +562,70 @@
 
   private WorkItemCommitRequest.Builder makeExpectedOutput(
       int index, long timestamp, String key, String outKey) throws Exception {
+    StringBuilder expectedCommitRequestBuilder = initializeExpectedCommitRequest(key, index);
+    appendCommitOutputMessages(expectedCommitRequestBuilder, index, timestamp, outKey);
+
     return setMessagesMetadata(
         PaneInfo.NO_FIRING,
         intervalWindowBytes(DEFAULT_WINDOW),
-        parseCommitRequest(
-            "key: \""
-                + key
-                + "\" "
-                + "sharding_key: 17 "
-                + "work_token: "
-                + index
-                + " "
-                + "cache_token: 3 "
-                + "output_messages {"
-                + "  destination_stream_id: \""
-                + DEFAULT_DESTINATION_STREAM_ID
-                + "\""
-                + "  bundles {"
-                + "    key: \""
-                + outKey
-                + "\""
-                + "    messages {"
-                + "      timestamp: "
-                + timestamp
-                + "      data: \""
-                + dataStringForIndex(index)
-                + "\""
-                + "      metadata: \"\""
-                + "    }"
-                + "    messages_ids: \"\""
-                + "  }"
-                + "}"));
+        parseCommitRequest(expectedCommitRequestBuilder.toString()));
+  }
+
+  private WorkItemCommitRequest.Builder makeExpectedTruncationRequestOutput(
+      int index, String key, long estimatedSize) throws Exception {
+    StringBuilder expectedCommitRequestBuilder = initializeExpectedCommitRequest(key, index);
+    appendCommitTruncationFields(expectedCommitRequestBuilder, estimatedSize);
+
+    return parseCommitRequest(expectedCommitRequestBuilder.toString());
+  }
+
+  private StringBuilder initializeExpectedCommitRequest(String key, int index) {
+    StringBuilder requestBuilder = new StringBuilder();
+
+    requestBuilder.append("key: \"");
+    requestBuilder.append(key);
+    requestBuilder.append("\" ");
+    requestBuilder.append("sharding_key: 17 ");
+    requestBuilder.append("work_token: ");
+    requestBuilder.append(index);
+    requestBuilder.append(" ");
+    requestBuilder.append("cache_token: 3 ");
+
+    return requestBuilder;
+  }
+
+  private StringBuilder appendCommitOutputMessages(
+      StringBuilder requestBuilder, int index, long timestamp, String outKey) {
+    requestBuilder.append("output_messages {");
+    requestBuilder.append("  destination_stream_id: \"");
+    requestBuilder.append(DEFAULT_DESTINATION_STREAM_ID);
+    requestBuilder.append("\"");
+    requestBuilder.append("  bundles {");
+    requestBuilder.append("    key: \"");
+    requestBuilder.append(outKey);
+    requestBuilder.append("\"");
+    requestBuilder.append("    messages {");
+    requestBuilder.append("      timestamp: ");
+    requestBuilder.append(timestamp);
+    requestBuilder.append("      data: \"");
+    requestBuilder.append(dataStringForIndex(index));
+    requestBuilder.append("\"");
+    requestBuilder.append("      metadata: \"\"");
+    requestBuilder.append("    }");
+    requestBuilder.append("    messages_ids: \"\"");
+    requestBuilder.append("  }");
+    requestBuilder.append("}");
+
+    return requestBuilder;
+  }
+
+  private StringBuilder appendCommitTruncationFields(
+      StringBuilder requestBuilder, long estimatedSize) {
+    requestBuilder.append("exceeds_max_work_item_commit_bytes: true ");
+    requestBuilder.append("estimated_work_item_commit_bytes: ");
+    requestBuilder.append(estimatedSize);
+
+    return requestBuilder;
   }
 
   private StreamingComputationConfig makeDefaultStreamingComputationConfig(
@@ -948,64 +981,19 @@
 
     assertEquals(2, result.size());
     assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
+
     assertTrue(result.containsKey(1L));
-    assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
-    assertTrue(result.get(1L).getSerializedSize() > 1000);
+    WorkItemCommitRequest largeCommit = result.get(1L);
+    assertEquals("large_key", largeCommit.getKey().toStringUtf8());
+    assertEquals(
+        makeExpectedTruncationRequestOutput(
+                1, "large_key", largeCommit.getEstimatedWorkItemCommitBytes())
+            .build(),
+        largeCommit);
 
-    // Spam worker updates a few times.
-    int maxTries = 10;
-    while (--maxTries > 0) {
-      worker.reportPeriodicWorkerUpdates();
-      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
-    }
-
-    // We should see an exception reported for the large commit but not the small one.
-    ArgumentCaptor<WorkItemStatus> workItemStatusCaptor =
-        ArgumentCaptor.forClass(WorkItemStatus.class);
-    verify(mockWorkUnitClient, atLeast(2)).reportWorkItemStatus(workItemStatusCaptor.capture());
-    List<WorkItemStatus> capturedStatuses = workItemStatusCaptor.getAllValues();
-    boolean foundErrors = false;
-    for (WorkItemStatus status : capturedStatuses) {
-      if (!status.getErrors().isEmpty()) {
-        assertFalse(foundErrors);
-        foundErrors = true;
-        String errorMessage = status.getErrors().get(0).getMessage();
-        assertThat(errorMessage, Matchers.containsString("KeyCommitTooLargeException"));
-      }
-    }
-    assertTrue(foundErrors);
-  }
-
-  @Test
-  public void testKeyCommitTooLargeException_StreamingEngine() throws Exception {
-    KvCoder<String, String> kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of());
-
-    List<ParallelInstruction> instructions =
-        Arrays.asList(
-            makeSourceInstruction(kvCoder),
-            makeDoFnInstruction(new LargeCommitFn(), 0, kvCoder),
-            makeSinkInstruction(kvCoder, 1));
-
-    FakeWindmillServer server = new FakeWindmillServer(errorCollector);
-    server.setExpectedExceptionCount(1);
-
-    StreamingDataflowWorkerOptions options =
-        createTestingPipelineOptions(server, "--experiments=enable_streaming_engine");
-    StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
-    worker.setMaxWorkItemCommitBytes(1000);
-    worker.start();
-
-    server.addWorkToOffer(makeInput(1, 0, "large_key"));
-    server.addWorkToOffer(makeInput(2, 0, "key"));
-    server.waitForEmptyWorkQueue();
-
-    Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(1);
-
-    assertEquals(2, result.size());
-    assertEquals(makeExpectedOutput(2, 0, "key", "key").build(), result.get(2L));
-    assertTrue(result.containsKey(1L));
-    assertEquals("large_key", result.get(1L).getKey().toStringUtf8());
-    assertTrue(result.get(1L).getSerializedSize() > 1000);
+    // Check this explicitly, since the estimated commit bytes weren't actually
+    // checked against an expected value in the previous step.
+    assertTrue(largeCommit.getEstimatedWorkItemCommitBytes() > 1000);
 
     // Spam worker updates a few times.
     int maxTries = 10;
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 5310902..598164e 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -272,13 +272,16 @@
   optional string state_family = 3;
 }
 
-// next id: 19
+// next id: 24
 message WorkItemCommitRequest {
   required bytes key = 1;
   required fixed64 work_token = 2;
   optional fixed64 sharding_key = 15;
   optional fixed64 cache_token = 16;
 
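+  // Signals that the worker dropped an oversized commit for this work item;
+  // estimated_work_item_commit_bytes carries the serialized size the dropped
+  // commit would have had.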
+  optional bool exceeds_max_work_item_commit_bytes = 20;
+  optional int64 estimated_work_item_commit_bytes = 21;
+
   repeated OutputMessageBundle output_messages = 3;
   repeated PubSubMessageBundle pubsub_messages = 7;
   repeated Timer output_timers = 4;
@@ -290,12 +293,14 @@
   optional SourceState source_state_updates = 12;
   optional int64 source_watermark = 13 [default=-0x8000000000000000];
   optional int64 source_backlog_bytes = 17 [default=-1];
+  optional int64 source_bytes_processed = 22;
+
   repeated WatermarkHold watermark_holds = 14;
 
   // DEPRECATED
   repeated GlobalDataId global_data_id_requests = 9;
 
-  reserved 6;
+  reserved 6, 19, 23;
 }
 
 message ComputationCommitWorkRequest {