blob: 10a2b022dcbb1b58ddff0c394f2b6c2e4724d87b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine.BinaryCombineLongFn;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValues;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/** Tests for {@link ParDoTranslation}. */
@RunWith(Enclosed.class)
@SuppressWarnings({
"rawtypes" // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
public class ParDoTranslationTest {
/** Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. */
@RunWith(Parameterized.class)
public static class TestParDoPayloadTranslation {
public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);
private static PCollectionView<Long> singletonSideInput =
p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)).apply(View.asSingleton());
private static PCollectionView<Map<Long, Iterable<String>>> multimapSideInput =
p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam")))
.setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))
.apply(View.asMultimap());
private static PCollection<KV<Long, String>> mainInput =
p.apply(
"CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())));
@Parameters(name = "{index}: {0}")
public static Iterable<ParDo.MultiOutput<?, ?>> data() {
return ImmutableList.of(
ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag<>(), TupleTagList.empty()),
ParDo.of(new DropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty())
.withSideInputs(singletonSideInput, multimapSideInput),
ParDo.of(new DropElementsFn())
.withOutputTags(
new TupleTag<>(),
TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {}))
.withSideInputs(singletonSideInput, multimapSideInput),
ParDo.of(new DropElementsFn())
.withOutputTags(
new TupleTag<>(),
TupleTagList.of(new TupleTag<byte[]>() {}).and(new TupleTag<Integer>() {})),
ParDo.of(new SplittableDropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()),
ParDo.of(new StateTimerDropElementsFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()));
}
@Parameter(0)
public ParDo.MultiOutput<KV<Long, String>, Void> parDo;
@Test
public void testToProto() throws Exception {
SdkComponents components = SdkComponents.create();
components.registerEnvironment(Environments.createDockerEnvironment("java"));
ParDoPayload payload =
ParDoTranslation.translateParDo(
parDo,
PCollection.createPrimitiveOutputInternal(
p,
WindowingStrategy.globalDefault(),
IsBounded.BOUNDED,
KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())),
DoFnSchemaInformation.create(),
p,
components);
assertThat(ParDoTranslation.getDoFn(payload), equalTo(parDo.getFn()));
assertThat(ParDoTranslation.getMainOutputTag(payload), equalTo(parDo.getMainOutputTag()));
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
payload.getSideInputsOrThrow(view.getTagInternal().getId());
}
assertFalse(payload.getRequestsFinalization());
assertEquals(
parDo.getFn() instanceof StateTimerDropElementsFn,
components.requirements().contains(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN));
}
@Test
public void toTransformProto() throws Exception {
Map<TupleTag<?>, PCollection<?>> inputs = new HashMap<>();
inputs.put(new TupleTag<KV<Long, String>>("mainInputName") {}, mainInput);
inputs.putAll(PValues.fullyExpand(parDo.getAdditionalInputs()));
PCollectionTuple output = mainInput.apply(parDo);
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
// Encode
RunnerApi.PTransform protoTransform =
PTransformTranslation.toProto(
AppliedPTransform.<PCollection<KV<Long, String>>, PCollection<Void>, MultiOutput>of(
"foo", inputs, PValues.expandOutput(output), parDo, p),
sdkComponents);
RunnerApi.Components components = sdkComponents.toComponents();
RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);
// Decode
ParDoPayload parDoPayload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload());
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId());
PCollectionView<?> restoredView =
PCollectionViewTranslation.viewFromProto(
sideInput,
view.getTagInternal().getId(),
view.getPCollection(),
protoTransform,
rehydratedComponents);
assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal()));
assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass()));
assertThat(
restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass()));
assertThat(
restoredView.getWindowingStrategyInternal(),
equalTo(view.getWindowingStrategyInternal().fixDefaults()));
assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal()));
}
String mainInputId = sdkComponents.registerPCollection(mainInput);
assertThat(
ParDoTranslation.getMainInput(protoTransform, components),
equalTo(components.getPcollectionsOrThrow(mainInputId)));
assertThat(ParDoTranslation.getMainInputName(protoTransform), equalTo("mainInputName"));
// Ensure the correct timer coder components are used from the main input PCollection's key
// and window coders.
for (RunnerApi.TimerFamilySpec timerFamilySpec :
parDoPayload.getTimerFamilySpecsMap().values()) {
Coder<?> timerCoder =
CoderTranslation.fromProto(
components.getCodersOrThrow(timerFamilySpec.getTimerFamilyCoderId()),
rehydratedComponents,
TranslationContext.DEFAULT);
assertEquals(
org.apache.beam.runners.core.construction.Timer.Coder.of(
VarLongCoder.of(), GlobalWindow.Coder.INSTANCE),
timerCoder);
}
}
}
/** Tests for translating state and timer bits to/from protos. */
@RunWith(Parameterized.class)
public static class TestStateAndTimerTranslation {
@Parameters(name = "{index}: {0}")
public static Iterable<StateSpec<?>> stateSpecs() {
return ImmutableList.of(
StateSpecs.value(VarIntCoder.of()),
StateSpecs.bag(VarIntCoder.of()),
StateSpecs.set(VarIntCoder.of()),
StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()));
}
@Parameter public StateSpec<?> stateSpec;
@Test
public void testStateSpecToFromProto() throws Exception {
// Encode
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
RunnerApi.StateSpec stateSpecProto =
ParDoTranslation.translateStateSpec(stateSpec, sdkComponents);
// Decode
RehydratedComponents rehydratedComponents =
RehydratedComponents.forComponents(sdkComponents.toComponents());
StateSpec<?> deserializedStateSpec =
ParDoTranslation.fromProto(stateSpecProto, rehydratedComponents);
assertThat(stateSpec, equalTo(deserializedStateSpec));
}
}
private static class DropElementsFn extends DoFn<KV<Long, String>, Void> {
@ProcessElement
public void proc(ProcessContext context, BoundedWindow window) {
context.output(null);
}
@Override
public boolean equals(@Nullable Object other) {
return other instanceof DropElementsFn;
}
@Override
public int hashCode() {
return DropElementsFn.class.hashCode();
}
}
private static class SplittableDropElementsFn extends DoFn<KV<Long, String>, Void> {
@ProcessElement
public void proc(ProcessContext context, RestrictionTracker<Integer, ?> restriction) {
context.output(null);
}
@GetInitialRestriction
public Integer restriction(@Element KV<Long, String> elem) {
return 42;
}
@NewTracker
public RestrictionTracker<Integer, ?> newTracker(@Restriction Integer restriction) {
throw new UnsupportedOperationException("Should never be called; only to test translation");
}
@Override
public boolean equals(@Nullable Object other) {
return other instanceof SplittableDropElementsFn;
}
@Override
public int hashCode() {
return SplittableDropElementsFn.class.hashCode();
}
}
@SuppressWarnings("unused")
private static class StateTimerDropElementsFn extends DoFn<KV<Long, String>, Void> {
private static final String BAG_STATE_ID = "bagState";
private static final String COMBINING_STATE_ID = "combiningState";
private static final String EVENT_TIMER_ID = "eventTimer";
private static final String PROCESSING_TIMER_ID = "processingTimer";
@StateId(BAG_STATE_ID)
private final StateSpec<BagState<String>> bagState = StateSpecs.bag(StringUtf8Coder.of());
@StateId(COMBINING_STATE_ID)
private final StateSpec<CombiningState<Long, long[], Long>> combiningState =
StateSpecs.combining(
new BinaryCombineLongFn() {
@Override
public long apply(long left, long right) {
return Math.max(left, right);
}
@Override
public long identity() {
return Long.MIN_VALUE;
}
});
@TimerId(EVENT_TIMER_ID)
private final TimerSpec eventTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);
@TimerId(PROCESSING_TIMER_ID)
private final TimerSpec processingTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
@ProcessElement
public void dropInput(
ProcessContext context,
BoundedWindow window,
@StateId(BAG_STATE_ID) BagState<String> bagStateState,
@StateId(COMBINING_STATE_ID) CombiningState<Long, long[], Long> combiningStateState,
@TimerId(EVENT_TIMER_ID) Timer eventTimerTimer,
@TimerId(PROCESSING_TIMER_ID) Timer processingTimerTimer) {
context.output(null);
}
@OnTimer(EVENT_TIMER_ID)
public void onEventTime(OnTimerContext context) {}
@OnTimer(PROCESSING_TIMER_ID)
public void onProcessingTime(OnTimerContext context) {}
@Override
public boolean equals(@Nullable Object other) {
return other instanceof StateTimerDropElementsFn;
}
@Override
public int hashCode() {
return StateTimerDropElementsFn.class.hashCode();
}
}
@RunWith(JUnit4.class)
public static class BundleFinalizerTranslation {
private static class StartBundleDoFn extends DoFn<String, String> {
@StartBundle
public void startBundle(BundleFinalizer bundleFinalizer) {}
@ProcessElement
public void processElement() {}
}
private static class ProcessContextDoFn extends DoFn<String, String> {
@ProcessElement
public void processElement(BundleFinalizer finalizer) {}
}
private static class FinishBundleDoFn extends DoFn<String, String> {
@FinishBundle
public void finishBundle(BundleFinalizer bundleFinalizer) {}
@ProcessElement
public void processElement(BundleFinalizer finalizer) {}
}
@Test
public void testStartBundle() throws Exception {
Pipeline p = Pipeline.create();
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
ParDoPayload payload =
ParDoTranslation.translateParDo(
ParDo.of(new StartBundleDoFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()),
PCollection.createPrimitiveOutputInternal(
p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of()),
DoFnSchemaInformation.create(),
TestPipeline.create(),
sdkComponents);
assertTrue(payload.getRequestsFinalization());
}
@Test
public void testProcessContext() throws Exception {
Pipeline p = Pipeline.create();
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
ParDoPayload payload =
ParDoTranslation.translateParDo(
ParDo.of(new ProcessContextDoFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()),
PCollection.createPrimitiveOutputInternal(
p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of()),
DoFnSchemaInformation.create(),
TestPipeline.create(),
sdkComponents);
assertTrue(payload.getRequestsFinalization());
}
@Test
public void testFinishBundle() throws Exception {
Pipeline p = Pipeline.create();
SdkComponents sdkComponents = SdkComponents.create();
sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));
ParDoPayload payload =
ParDoTranslation.translateParDo(
ParDo.of(new FinishBundleDoFn())
.withOutputTags(new TupleTag<>(), TupleTagList.empty()),
PCollection.createPrimitiveOutputInternal(
p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of()),
DoFnSchemaInformation.create(),
TestPipeline.create(),
sdkComponents);
assertTrue(payload.getRequestsFinalization());
}
}
}