| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.core.construction; |
| |
| import static org.hamcrest.Matchers.equalTo; |
| import static org.hamcrest.Matchers.instanceOf; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertThat; |
| |
| import java.util.HashMap; |
| import java.util.Map; |
| import org.apache.beam.model.pipeline.v1.RunnerApi; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.coders.StringUtf8Coder; |
| import org.apache.beam.sdk.coders.VarIntCoder; |
| import org.apache.beam.sdk.coders.VarLongCoder; |
| import org.apache.beam.sdk.coders.VoidCoder; |
| import org.apache.beam.sdk.io.GenerateSequence; |
| import org.apache.beam.sdk.runners.AppliedPTransform; |
| import org.apache.beam.sdk.state.BagState; |
| import org.apache.beam.sdk.state.CombiningState; |
| import org.apache.beam.sdk.state.StateSpec; |
| import org.apache.beam.sdk.state.StateSpecs; |
| import org.apache.beam.sdk.state.TimeDomain; |
| import org.apache.beam.sdk.state.Timer; |
| import org.apache.beam.sdk.state.TimerSpec; |
| import org.apache.beam.sdk.state.TimerSpecs; |
| import org.apache.beam.sdk.testing.TestPipeline; |
| import org.apache.beam.sdk.transforms.Combine.BinaryCombineLongFn; |
| import org.apache.beam.sdk.transforms.Create; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.DoFnSchemaInformation; |
| import org.apache.beam.sdk.transforms.ParDo; |
| import org.apache.beam.sdk.transforms.ParDo.MultiOutput; |
| import org.apache.beam.sdk.transforms.View; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; |
| import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; |
| import org.apache.beam.sdk.transforms.windowing.BoundedWindow; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionTuple; |
| import org.apache.beam.sdk.values.PCollectionView; |
| import org.apache.beam.sdk.values.PValue; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.beam.sdk.values.TupleTagList; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.Parameterized; |
| import org.junit.runners.Parameterized.Parameter; |
| import org.junit.runners.Parameterized.Parameters; |
| |
| /** Tests for {@link ParDoTranslation}. */ |
| public class ParDoTranslationTest { |
| |
| /** Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. */ |
| @RunWith(Parameterized.class) |
| public static class TestParDoPayloadTranslation { |
| public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); |
| |
| private static PCollectionView<Long> singletonSideInput = |
| p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)).apply(View.asSingleton()); |
| private static PCollectionView<Map<Long, Iterable<String>>> multimapSideInput = |
| p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam"))) |
| .setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())) |
| .apply(View.asMultimap()); |
| |
| private static PCollection<KV<Long, String>> mainInput = |
| p.apply( |
| "CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))); |
| |
| @Parameters(name = "{index}: {0}") |
| public static Iterable<ParDo.MultiOutput<?, ?>> data() { |
| return ImmutableList.of( |
| ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag<>(), TupleTagList.empty()), |
| ParDo.of(new DropElementsFn()) |
| .withOutputTags(new TupleTag<>(), TupleTagList.empty()) |
| .withSideInputs(singletonSideInput, multimapSideInput), |
| ParDo.of(new DropElementsFn()) |
| .withOutputTags( |
| new TupleTag<>(), |
| TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {})) |
| .withSideInputs(singletonSideInput, multimapSideInput), |
| ParDo.of(new DropElementsFn()) |
| .withOutputTags( |
| new TupleTag<>(), |
| TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {})), |
| ParDo.of(new SplittableDropElementsFn()) |
| .withOutputTags(new TupleTag<>(), TupleTagList.empty()), |
| ParDo.of(new StateTimerDropElementsFn()) |
| .withOutputTags(new TupleTag<>(), TupleTagList.empty())); |
| } |
| |
| @Parameter(0) |
| public ParDo.MultiOutput<KV<Long, String>, Void> parDo; |
| |
| @Test |
| public void testToProto() throws Exception { |
| SdkComponents components = SdkComponents.create(); |
| components.registerEnvironment(Environments.createDockerEnvironment("java")); |
| ParDoPayload payload = |
| ParDoTranslation.translateParDo(parDo, DoFnSchemaInformation.create(), p, components); |
| |
| assertThat(ParDoTranslation.getDoFn(payload), equalTo(parDo.getFn())); |
| assertThat(ParDoTranslation.getMainOutputTag(payload), equalTo(parDo.getMainOutputTag())); |
| for (PCollectionView<?> view : parDo.getSideInputs().values()) { |
| payload.getSideInputsOrThrow(view.getTagInternal().getId()); |
| } |
| } |
| |
| @Test |
| public void toTransformProto() throws Exception { |
| Map<TupleTag<?>, PValue> inputs = new HashMap<>(); |
| inputs.put(new TupleTag<KV<Long, String>>("mainInputName") {}, mainInput); |
| inputs.putAll(parDo.getAdditionalInputs()); |
| PCollectionTuple output = mainInput.apply(parDo); |
| |
| SdkComponents sdkComponents = SdkComponents.create(); |
| sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); |
| |
| // Encode |
| RunnerApi.PTransform protoTransform = |
| PTransformTranslation.toProto( |
| AppliedPTransform.<PCollection<KV<Long, String>>, PCollection<Void>, MultiOutput>of( |
| "foo", inputs, output.expand(), parDo, p), |
| sdkComponents); |
| RunnerApi.Components components = sdkComponents.toComponents(); |
| RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components); |
| |
| // Decode |
| ParDoPayload parDoPayload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload()); |
| for (PCollectionView<?> view : parDo.getSideInputs().values()) { |
| SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); |
| PCollectionView<?> restoredView = |
| PCollectionViewTranslation.viewFromProto( |
| sideInput, |
| view.getTagInternal().getId(), |
| view.getPCollection(), |
| protoTransform, |
| rehydratedComponents); |
| assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); |
| assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); |
| assertThat( |
| restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass())); |
| assertThat( |
| restoredView.getWindowingStrategyInternal(), |
| equalTo(view.getWindowingStrategyInternal().fixDefaults())); |
| assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); |
| } |
| String mainInputId = sdkComponents.registerPCollection(mainInput); |
| assertThat( |
| ParDoTranslation.getMainInput(protoTransform, components), |
| equalTo(components.getPcollectionsOrThrow(mainInputId))); |
| assertThat(ParDoTranslation.getMainInputName(protoTransform), equalTo("mainInputName")); |
| |
| // Validate that the timer PCollections are added correctly. |
| DoFnSignature signature = DoFnSignatures.signatureForDoFn(parDo.getFn()); |
| |
| for (String localTimerName : signature.timerDeclarations().keySet()) { |
| RunnerApi.PCollection timerPCollection = |
| components.getPcollectionsOrThrow(String.format("foo.%s", localTimerName)); |
| assertEquals( |
| components.getPcollectionsOrThrow(mainInputId).getIsBounded(), |
| timerPCollection.getIsBounded()); |
| assertEquals( |
| components.getPcollectionsOrThrow(mainInputId).getWindowingStrategyId(), |
| timerPCollection.getWindowingStrategyId()); |
| ModelCoders.KvCoderComponents timerKvCoderComponents = |
| ModelCoders.getKvCoderComponents( |
| components.getCodersOrThrow(timerPCollection.getCoderId())); |
| Coder<?> timerKeyCoder = |
| CoderTranslation.fromProto( |
| components.getCodersOrThrow(timerKvCoderComponents.keyCoderId()), |
| rehydratedComponents); |
| assertEquals(VarLongCoder.of(), timerKeyCoder); |
| Coder<?> timerValueCoder = |
| CoderTranslation.fromProto( |
| components.getCodersOrThrow(timerKvCoderComponents.valueCoderId()), |
| rehydratedComponents); |
| assertEquals( |
| org.apache.beam.runners.core.construction.Timer.Coder.of(VoidCoder.of()), |
| timerValueCoder); |
| } |
| } |
| } |
| |
| /** |
| * Parameterized tests that translate a {@link StateSpec} to its proto form and back. |
| * |
| * <p>The parameters are value, bag, set and map state specs. |
| */ |
| @RunWith(Parameterized.class) |
| public static class TestStateAndTimerTranslation { |
| |
| /** Returns the {@link StateSpec}s used as test parameters. */ |
| @Parameters(name = "{index}: {0}") |
| public static Iterable<StateSpec<?>> stateSpecs() { |
| return ImmutableList.of( |
| StateSpecs.value(VarIntCoder.of()), |
| StateSpecs.bag(VarIntCoder.of()), |
| StateSpecs.set(VarIntCoder.of()), |
| StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of())); |
| } |
| |
| /** The state spec under test, injected by the {@link Parameterized} runner. */ |
| @Parameter public StateSpec<?> stateSpec; |
| |
| /** Encodes {@link #stateSpec} to a proto, decodes it again and expects an equal spec. */ |
| @Test |
| public void testStateSpecToFromProto() throws Exception { |
| // Encode |
| SdkComponents sdkComponents = SdkComponents.create(); |
| sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); |
| RunnerApi.StateSpec stateSpecProto = |
| ParDoTranslation.translateStateSpec(stateSpec, sdkComponents); |
| |
| // Decode |
| RehydratedComponents rehydratedComponents = |
| RehydratedComponents.forComponents(sdkComponents.toComponents()); |
| StateSpec<?> deserializedStateSpec = |
| ParDoTranslation.fromProto(stateSpecProto, rehydratedComponents); |
| |
| // The decoded spec is the actual value and the original spec is the expectation, so that a |
| // failure message reports "Expected" and "but: was" the right way round. |
| assertThat(deserializedStateSpec, equalTo(stateSpec)); |
| } |
| } |
| |
| /** |
| * Test {@link DoFn} over {@code KV<Long, String>} whose process method outputs {@code null}. |
| * |
| * <p>Equality is by type only, so two separately created instances compare equal. |
| */ |
| private static class DropElementsFn extends DoFn<KV<Long, String>, Void> { |
| @ProcessElement |
| public void proc(ProcessContext c, BoundedWindow w) { |
| c.output(null); |
| } |
| |
| /** Returns true for any {@link DropElementsFn}, false for everything else including null. */ |
| @Override |
| public boolean equals(Object that) { |
| return DropElementsFn.class.isInstance(that); |
| } |
| |
| /** Returns the hash of the class object, the same value for every instance. */ |
| @Override |
| public int hashCode() { |
| return DropElementsFn.class.hashCode(); |
| } |
| } |
| |
| /** |
| * Test {@link DoFn} whose process method takes a {@link RestrictionTracker}, with an {@code |
| * Integer} restriction type. |
| * |
| * <p>Equality is by type only, so two separately created instances compare equal. |
| */ |
| private static class SplittableDropElementsFn extends DoFn<KV<Long, String>, Void> { |
| @ProcessElement |
| public void proc(ProcessContext c, RestrictionTracker<Integer, ?> tracker) { |
| c.output(null); |
| } |
| |
| /** Returns a fixed initial restriction regardless of the element. */ |
| @GetInitialRestriction |
| public Integer restriction(KV<Long, String> element) { |
| return 42; |
| } |
| |
| /** Always throws; this method only has to exist for the translation tests. */ |
| @NewTracker |
| public RestrictionTracker<Integer, ?> newTracker(Integer initialRestriction) { |
| throw new UnsupportedOperationException("Should never be called; only to test translation"); |
| } |
| |
| /** Returns true for any {@link SplittableDropElementsFn}, false otherwise including null. */ |
| @Override |
| public boolean equals(Object that) { |
| return SplittableDropElementsFn.class.isInstance(that); |
| } |
| |
| /** Returns the hash of the class object, the same value for every instance. */ |
| @Override |
| public int hashCode() { |
| return SplittableDropElementsFn.class.hashCode(); |
| } |
| } |
| |
| /** |
| * Test {@link DoFn} that declares two state cells (a bag and a combining state) and two timers |
| * (event time and processing time), and outputs {@code null} from its process method. |
| * |
| * <p>Equality is by type only, so two separately created instances compare equal. |
| */ |
| @SuppressWarnings("unused") |
| private static class StateTimerDropElementsFn extends DoFn<KV<Long, String>, Void> { |
| // Ids shared between the spec declarations and the annotated method parameters below. |
| private static final String BAG_STATE_ID = "bagState"; |
| private static final String COMBINING_STATE_ID = "combiningState"; |
| private static final String EVENT_TIMER_ID = "eventTimer"; |
| private static final String PROCESSING_TIMER_ID = "processingTimer"; |
| |
| @StateId(BAG_STATE_ID) |
| private final StateSpec<BagState<String>> bagSpec = StateSpecs.bag(StringUtf8Coder.of()); |
| |
| // Combining state that keeps the maximum of the values added to it. |
| @StateId(COMBINING_STATE_ID) |
| private final StateSpec<CombiningState<Long, long[], Long>> combiningSpec = |
| StateSpecs.combining( |
| new BinaryCombineLongFn() { |
| @Override |
| public long apply(long left, long right) { |
| return left >= right ? left : right; |
| } |
| |
| @Override |
| public long identity() { |
| return Long.MIN_VALUE; |
| } |
| }); |
| |
| @TimerId(EVENT_TIMER_ID) |
| private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME); |
| |
| @TimerId(PROCESSING_TIMER_ID) |
| private final TimerSpec processingTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); |
| |
| /** Outputs {@code null}; the state and timer parameters are declared but not used. */ |
| @ProcessElement |
| public void dropInput( |
| ProcessContext c, |
| BoundedWindow w, |
| @StateId(BAG_STATE_ID) BagState<String> bag, |
| @StateId(COMBINING_STATE_ID) CombiningState<Long, long[], Long> combining, |
| @TimerId(EVENT_TIMER_ID) Timer eventTimer, |
| @TimerId(PROCESSING_TIMER_ID) Timer processingTimer) { |
| c.output(null); |
| } |
| |
| /** Event-time timer callback; intentionally empty. */ |
| @OnTimer(EVENT_TIMER_ID) |
| public void onEventTime(OnTimerContext c) {} |
| |
| /** Processing-time timer callback; intentionally empty. */ |
| @OnTimer(PROCESSING_TIMER_ID) |
| public void onProcessingTime(OnTimerContext c) {} |
| |
| /** Returns true for any {@link StateTimerDropElementsFn}, false otherwise including null. */ |
| @Override |
| public boolean equals(Object that) { |
| return StateTimerDropElementsFn.class.isInstance(that); |
| } |
| |
| /** Returns the hash of the class object, the same value for every instance. */ |
| @Override |
| public int hashCode() { |
| return StateTimerDropElementsFn.class.hashCode(); |
| } |
| } |
| } |