[BEAM-8157] Follow-up to adjust comment and improve ProcessBundleDescriptorsTest

This just extends the logic of the test to also check if the LengthPrefixCoder
wrapped key coder remains unchanged.
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index a2129c4..f2d374a 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -164,10 +164,10 @@
   }
 
   /**
-   * Patches the input coder of a stateful Executable transform to ensure that the byte
-   * representation of a key used to partition the input element at the Runner, matches the key byte
-   * representation received for state requests and timers from the SDK Harness. Stateful transforms
-   * always have a KvCoder as input.
+   * Patches the input coder of a stateful transform to ensure that the byte representation of a key
+   * used to partition the input element at the Runner, matches the key byte representation received
+   * for state requests and timers from the SDK Harness. Stateful transforms always have a KvCoder
+   * as input.
    */
   private static void lengthPrefixKeyCoder(
       String inputColId, Components.Builder componentsBuilder) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
index b558b00..ccabb2e 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -33,6 +33,7 @@
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.FusedPipeline;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.construction.graph.PipelineNode;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -54,7 +55,7 @@
 public class ProcessBundleDescriptorsTest implements Serializable {
 
   /**
-   * Tests that a stateful Executable stage will wrap a key coder of a stateful transform in a
+   * Tests that a stateful stage will wrap the key coder of a stateful transform in a
    * LengthPrefixCoder.
    */
   @Test
@@ -104,27 +105,40 @@
                 stage.getUserStates().stream()
                     .anyMatch(spec -> spec.localName().equals("stateId")));
     checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+
     ExecutableStage stage = optionalStage.get();
+    PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection();
 
-    ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+    // Ensure original key coder is not a LengthPrefixCoder
+    Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
+    RunnerApi.Coder originalCoder =
+        stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
+    String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+    assertThat(
+        stageCoderMap.get(originalKeyCoderId).getSpec().getUrn(),
+        is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+
+    // Now create ProcessBundleDescriptor and check for the LengthPrefixCoder around the key coder
+    BeamFnApi.ProcessBundleDescriptor pbDescriptor =
         ProcessBundleDescriptors.fromExecutableStage(
-            "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+                "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
+            .getProcessBundleDescriptor();
 
-    BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
-    String inputColId = stage.getInputPCollection().getId();
-    String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+    String inputPCollectionId = inputPCollection.getId();
+    String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputPCollectionId).getCoderId();
 
-    Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
-    RunnerApi.Coder coder = codersMap.get(inputCoderId);
+    Map<String, RunnerApi.Coder> pbCoderMap = pbDescriptor.getCodersMap();
+    RunnerApi.Coder coder = pbCoderMap.get(inputCoderId);
     String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
 
-    assertThat(
-        codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+    RunnerApi.Coder keyCoder = pbCoderMap.get(keyCoderId);
+    // Ensure length prefix
+    assertThat(keyCoder.getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+    String lengthPrefixWrappedCoderId = keyCoder.getComponentCoderIds(0);
 
-    RunnerApi.Coder orignalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
-    String originalKeyCoderId = ModelCoders.getKvCoderComponents(orignalCoder).keyCoderId();
+    // Check that the wrapped coder is unchanged
+    assertThat(lengthPrefixWrappedCoderId, is(originalKeyCoderId));
     assertThat(
-        stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
-        is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+        pbCoderMap.get(lengthPrefixWrappedCoderId), is(stageCoderMap.get(originalKeyCoderId)));
   }
 }