blob: ccabb2e75b179a38f9ef33a5141e7e8a4af8d303 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.fnexecution.control;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import java.io.Serializable;
import java.util.Map;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
import org.apache.beam.runners.core.construction.ModelCoders;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Test;
/** Tests for {@link ProcessBundleDescriptors}. */
public class ProcessBundleDescriptorsTest implements Serializable {
/**
* Tests that a stateful stage will wrap the key coder of a stateful transform in a
* LengthPrefixCoder.
*/
@Test
public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
// Add another stateful stage with a non-standard key coder
Pipeline p = Pipeline.create();
Coder<Void> keycoder = VoidCoder.of();
assertThat(ModelCoderRegistrar.isKnownCoder(keycoder), is(false));
p.apply("impulse", Impulse.create())
.apply(
"create",
ParDo.of(
new DoFn<byte[], KV<Void, String>>() {
@ProcessElement
public void process(ProcessContext ctxt) {}
}))
.setCoder(KvCoder.of(keycoder, StringUtf8Coder.of()))
.apply(
"userState",
ParDo.of(
new DoFn<KV<Void, String>, KV<Void, String>>() {
@StateId("stateId")
private final StateSpec<BagState<String>> bufferState =
StateSpecs.bag(StringUtf8Coder.of());
@ProcessElement
public void processElement(
@Element KV<Void, String> element,
@StateId("stateId") BagState<String> state,
OutputReceiver<KV<Void, String>> r) {
for (String value : state.read()) {
r.output(KV.of(element.getKey(), value));
}
state.add(element.getValue());
}
}))
// Force the output to be materialized
.apply("gbk", GroupByKey.create());
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
Optional<ExecutableStage> optionalStage =
Iterables.tryFind(
fused.getFusedStages(),
(ExecutableStage stage) ->
stage.getUserStates().stream()
.anyMatch(spec -> spec.localName().equals("stateId")));
checkState(optionalStage.isPresent(), "Expected a stage with user state.");
ExecutableStage stage = optionalStage.get();
PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection();
// Ensure original key coder is not a LengthPrefixCoder
Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
RunnerApi.Coder originalCoder =
stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
assertThat(
stageCoderMap.get(originalKeyCoderId).getSpec().getUrn(),
is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
// Now create ProcessBundleDescriptor and check for the LengthPrefixCoder around the key coder
BeamFnApi.ProcessBundleDescriptor pbDescriptor =
ProcessBundleDescriptors.fromExecutableStage(
"test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
.getProcessBundleDescriptor();
String inputPCollectionId = inputPCollection.getId();
String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputPCollectionId).getCoderId();
Map<String, RunnerApi.Coder> pbCoderMap = pbDescriptor.getCodersMap();
RunnerApi.Coder coder = pbCoderMap.get(inputCoderId);
String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
RunnerApi.Coder keyCoder = pbCoderMap.get(keyCoderId);
// Ensure length prefix
assertThat(keyCoder.getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
String lengthPrefixWrappedCoderId = keyCoder.getComponentCoderIds(0);
// Check that the wrapped coder is unchanged
assertThat(lengthPrefixWrappedCoderId, is(originalKeyCoderId));
assertThat(
pbCoderMap.get(lengthPrefixWrappedCoderId), is(stageCoderMap.get(originalKeyCoderId)));
}
}