[BEAM-8157] Follow-up to adjust comment and improve ProcessBundleDescriptorsTest
This just extends the logic of the test to also check if the LengthPrefixCoder
wrapped key coder remains unchanged.
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index a2129c4..f2d374a 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -164,10 +164,10 @@
}
/**
- * Patches the input coder of a stateful Executable transform to ensure that the byte
- * representation of a key used to partition the input element at the Runner, matches the key byte
- * representation received for state requests and timers from the SDK Harness. Stateful transforms
- * always have a KvCoder as input.
+ * Patches the input coder of a stateful transform to ensure that the byte representation of a key
+ * used to partition the input element at the Runner, matches the key byte representation received
+ * for state requests and timers from the SDK Harness. Stateful transforms always have a KvCoder
+ * as input.
*/
private static void lengthPrefixKeyCoder(
String inputColId, Components.Builder componentsBuilder) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
index b558b00..ccabb2e 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -33,6 +33,7 @@
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
@@ -54,7 +55,7 @@
public class ProcessBundleDescriptorsTest implements Serializable {
/**
- * Tests that a stateful Executable stage will wrap a key coder of a stateful transform in a
+ * Tests that a stateful stage will wrap the key coder of a stateful transform in a
* LengthPrefixCoder.
*/
@Test
@@ -104,27 +105,40 @@
stage.getUserStates().stream()
.anyMatch(spec -> spec.localName().equals("stateId")));
checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+
ExecutableStage stage = optionalStage.get();
+ PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection();
- ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+ // Ensure original key coder is not a LengthPrefixCoder
+ Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
+ RunnerApi.Coder originalCoder =
+ stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
+ String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+ assertThat(
+ stageCoderMap.get(originalKeyCoderId).getSpec().getUrn(),
+ is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+
+ // Now create ProcessBundleDescriptor and check for the LengthPrefixCoder around the key coder
+ BeamFnApi.ProcessBundleDescriptor pbDescriptor =
ProcessBundleDescriptors.fromExecutableStage(
- "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+ "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
+ .getProcessBundleDescriptor();
- BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
- String inputColId = stage.getInputPCollection().getId();
- String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+ String inputPCollectionId = inputPCollection.getId();
+ String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputPCollectionId).getCoderId();
- Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
- RunnerApi.Coder coder = codersMap.get(inputCoderId);
+ Map<String, RunnerApi.Coder> pbCoderMap = pbDescriptor.getCodersMap();
+ RunnerApi.Coder coder = pbCoderMap.get(inputCoderId);
String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
- assertThat(
- codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+ RunnerApi.Coder keyCoder = pbCoderMap.get(keyCoderId);
+ // Ensure length prefix
+ assertThat(keyCoder.getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+ String lengthPrefixWrappedCoderId = keyCoder.getComponentCoderIds(0);
- RunnerApi.Coder orignalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
- String originalKeyCoderId = ModelCoders.getKvCoderComponents(orignalCoder).keyCoderId();
+ // Check that the wrapped coder is unchanged
+ assertThat(lengthPrefixWrappedCoderId, is(originalKeyCoderId));
assertThat(
- stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
- is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+ pbCoderMap.get(lengthPrefixWrappedCoderId), is(stageCoderMap.get(originalKeyCoderId)));
}
}