blob: caff55aad827c0131850a0ec88b44dcd028121f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine.BinaryCombineLongFn;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/** Tests for {@link ParDoTranslation}. */
public class ParDoTranslationTest {
/** Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. */
@RunWith(Parameterized.class)
public static class TestParDoPayloadTranslation {
public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);
private static PCollectionView<Long> singletonSideInput =
p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)).apply(View.asSingleton());
private static PCollectionView<Map<Long, Iterable<String>>> multimapSideInput =
p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam")))
.setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))
.apply(View.asMultimap());
private static PCollection<KV<Long, String>> mainInput =
p.apply(
"CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())));
@Parameters(name = "{index}: {0}")
public static Iterable<ParDo.MultiOutput<?, ?>> data() {
return ImmutableList.of(
ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag<>(), TupleTagList.empty()),
ParDo.of(new DropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty())
.withSideInputs(singletonSideInput, multimapSideInput),
ParDo.of(new DropElementsFn())
.withOutputTags(
new TupleTag<>(),
TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {}))
.withSideInputs(singletonSideInput, multimapSideInput),
ParDo.of(new DropElementsFn())
.withOutputTags(
new TupleTag<>(),
TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {})),
ParDo.of(new SplittableDropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()),
ParDo.of(new StateTimerDropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()));
}
@Parameter(0)
public ParDo.MultiOutput<KV<Long, String>, Void> parDo;
@Test
public void testToProto() throws Exception {
SdkComponents components = SdkComponents.create();
components.registerEnvironment(Environments.createDockerEnvironment("java"));
ParDoPayload payload =
ParDoTranslation.translateParDo(parDo, DoFnSchemaInformation.create(), p, components);
assertThat(ParDoTranslation.getDoFn(payload), equalTo(parDo.getFn()));
assertThat(ParDoTranslation.getMainOutputTag(payload), equalTo(parDo.getMainOutputTag()));
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
payload.getSideInputsOrThrow(view.getTagInternal().getId());
}
}
@Test
public void toTransformProto() throws Exception {
Map<TupleTag<?>, PValue> inputs = new HashMap<>();
inputs.put(new TupleTag<KV<Long, String>>("mainInputName") {}, mainInput);
inputs.putAll(parDo.getAdditionalInputs());
PCollectionTuple output = mainInput.apply(parDo);
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
// Encode
RunnerApi.PTransform protoTransform =
PTransformTranslation.toProto(
AppliedPTransform.<PCollection<KV<Long, String>>, PCollection<Void>, MultiOutput>of(
"foo", inputs, output.expand(), parDo, p),
sdkComponents);
RunnerApi.Components components = sdkComponents.toComponents();
RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);
// Decode
ParDoPayload parDoPayload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload());
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId());
PCollectionView<?> restoredView =
PCollectionViewTranslation.viewFromProto(
sideInput,
view.getTagInternal().getId(),
view.getPCollection(),
protoTransform,
rehydratedComponents);
assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal()));
assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass()));
assertThat(
restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass()));
assertThat(
restoredView.getWindowingStrategyInternal(),
equalTo(view.getWindowingStrategyInternal().fixDefaults()));
assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal()));
}
String mainInputId = sdkComponents.registerPCollection(mainInput);
assertThat(
ParDoTranslation.getMainInput(protoTransform, components),
equalTo(components.getPcollectionsOrThrow(mainInputId)));
assertThat(ParDoTranslation.getMainInputName(protoTransform), equalTo("mainInputName"));
// Validate that the timer PCollections are added correctly.
DoFnSignature signature = DoFnSignatures.signatureForDoFn(parDo.getFn());
for (String localTimerName : signature.timerDeclarations().keySet()) {
RunnerApi.PCollection timerPCollection =
components.getPcollectionsOrThrow(String.format("foo.%s", localTimerName));
assertEquals(
components.getPcollectionsOrThrow(mainInputId).getIsBounded(),
timerPCollection.getIsBounded());
assertEquals(
components.getPcollectionsOrThrow(mainInputId).getWindowingStrategyId(),
timerPCollection.getWindowingStrategyId());
ModelCoders.KvCoderComponents timerKvCoderComponents =
ModelCoders.getKvCoderComponents(
components.getCodersOrThrow(timerPCollection.getCoderId()));
Coder<?> timerKeyCoder =
CoderTranslation.fromProto(
components.getCodersOrThrow(timerKvCoderComponents.keyCoderId()),
rehydratedComponents);
assertEquals(VarLongCoder.of(), timerKeyCoder);
Coder<?> timerValueCoder =
CoderTranslation.fromProto(
components.getCodersOrThrow(timerKvCoderComponents.valueCoderId()),
rehydratedComponents);
assertEquals(
org.apache.beam.runners.core.construction.Timer.Coder.of(VoidCoder.of()),
timerValueCoder);
}
}
}
/** Tests for translating state and timer bits to/from protos. */
@RunWith(Parameterized.class)
public static class TestStateAndTimerTranslation {
@Parameters(name = "{index}: {0}")
public static Iterable<StateSpec<?>> stateSpecs() {
return ImmutableList.of(
StateSpecs.value(VarIntCoder.of()),
StateSpecs.bag(VarIntCoder.of()),
StateSpecs.set(VarIntCoder.of()),
StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()));
}
@Parameter public StateSpec<?> stateSpec;
@Test
public void testStateSpecToFromProto() throws Exception {
// Encode
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
RunnerApi.StateSpec stateSpecProto =
ParDoTranslation.translateStateSpec(stateSpec, sdkComponents);
// Decode
RehydratedComponents rehydratedComponents =
RehydratedComponents.forComponents(sdkComponents.toComponents());
StateSpec<?> deserializedStateSpec =
ParDoTranslation.fromProto(stateSpecProto, rehydratedComponents);
assertThat(stateSpec, equalTo(deserializedStateSpec));
}
}
private static class DropElementsFn extends DoFn<KV<Long, String>, Void> {
@ProcessElement
public void proc(ProcessContext context, BoundedWindow window) {
context.output(null);
}
@Override
public boolean equals(Object other) {
return other instanceof DropElementsFn;
}
@Override
public int hashCode() {
return DropElementsFn.class.hashCode();
}
}
private static class SplittableDropElementsFn extends DoFn<KV<Long, String>, Void> {
@ProcessElement
public void proc(ProcessContext context, RestrictionTracker<Integer, ?> restriction) {
context.output(null);
}
@GetInitialRestriction
public Integer restriction(KV<Long, String> elem) {
return 42;
}
@NewTracker
public RestrictionTracker<Integer, ?> newTracker(Integer restriction) {
throw new UnsupportedOperationException("Should never be called; only to test translation");
}
@Override
public boolean equals(Object other) {
return other instanceof SplittableDropElementsFn;
}
@Override
public int hashCode() {
return SplittableDropElementsFn.class.hashCode();
}
}
@SuppressWarnings("unused")
private static class StateTimerDropElementsFn extends DoFn<KV<Long, String>, Void> {
private static final String BAG_STATE_ID = "bagState";
private static final String COMBINING_STATE_ID = "combiningState";
private static final String EVENT_TIMER_ID = "eventTimer";
private static final String PROCESSING_TIMER_ID = "processingTimer";
@StateId(BAG_STATE_ID)
private final StateSpec<BagState<String>> bagState = StateSpecs.bag(StringUtf8Coder.of());
@StateId(COMBINING_STATE_ID)
private final StateSpec<CombiningState<Long, long[], Long>> combiningState =
StateSpecs.combining(
new BinaryCombineLongFn() {
@Override
public long apply(long left, long right) {
return Math.max(left, right);
}
@Override
public long identity() {
return Long.MIN_VALUE;
}
});
@TimerId(EVENT_TIMER_ID)
private final TimerSpec eventTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);
@TimerId(PROCESSING_TIMER_ID)
private final TimerSpec processingTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
@ProcessElement
public void dropInput(
ProcessContext context,
BoundedWindow window,
@StateId(BAG_STATE_ID) BagState<String> bagStateState,
@StateId(COMBINING_STATE_ID) CombiningState<Long, long[], Long> combiningStateState,
@TimerId(EVENT_TIMER_ID) Timer eventTimerTimer,
@TimerId(PROCESSING_TIMER_ID) Timer processingTimerTimer) {
context.output(null);
}
@OnTimer(EVENT_TIMER_ID)
public void onEventTime(OnTimerContext context) {}
@OnTimer(PROCESSING_TIMER_ID)
public void onProcessingTime(OnTimerContext context) {}
@Override
public boolean equals(Object other) {
return other instanceof StateTimerDropElementsFn;
}
@Override
public int hashCode() {
return StateTimerDropElementsFn.class.hashCode();
}
}
}