blob: f343f4d8c1abec0512fd88e72b5602a290fd9e2a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.frontend.beam.transform;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.vertex.transform.Transform;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.compiler.frontend.beam.NemoPipelineOptions;
import org.apache.nemo.compiler.frontend.beam.SideInputElement;
import org.apache.reef.io.Tuple;
import org.junit.Before;
import org.junit.Test;
import java.util.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
public final class DoFnTransformTest {
// views for testing side inputs
private PCollectionView<Iterable<String>> view1;
private PCollectionView<Iterable<String>> view2;
private final static Coder NULL_INPUT_CODER = null;
private final static Map<TupleTag<?>, Coder<?>> NULL_OUTPUT_CODERS = null;
/**
 * Builds the two iterable side-input views used by the side-input test.
 * Each view comes from its own freshly created pipeline holding a single string element
 * ({@code "1"} for {@code view1}, {@code "2"} for {@code view2}).
 */
@Before
public void setUp() {
  view1 = Pipeline.create().apply(Create.of("1")).apply(View.asIterable());
  view2 = Pipeline.create().apply(Create.of("2")).apply(View.asIterable());
}
/**
 * Verifies that a {@link DoFnTransform} wrapping an identity {@link DoFn} forwards a single
 * globally-windowed element to the main output unchanged.
 */
@Test
@SuppressWarnings("unchecked")
public void testSingleOutput() {
  final TupleTag<String> outputTag = new TupleTag<>("main-output");

  final DoFnTransform<String, String> doFnTransform =
    new DoFnTransform<>(
      new IdentityDoFn<>(),
      NULL_INPUT_CODER,
      NULL_OUTPUT_CODERS,
      outputTag,
      Collections.emptyList(),
      WindowingStrategy.globalDefault(),
      PipelineOptionsFactory.as(NemoPipelineOptions.class),
      DisplayData.none());

  final Transform.Context context = mock(Transform.Context.class);
  // Keep the concrete collector type so the recorded outputs can be read without a cast.
  final TestOutputCollector<String> oc = new TestOutputCollector<>();
  doFnTransform.prepare(context, oc);

  doFnTransform.onData(WindowedValue.valueInGlobalWindow("Hello"));

  // JUnit expects (expected, actual); the reversed order produces misleading failure messages.
  assertEquals(WindowedValue.valueInGlobalWindow("Hello"), oc.outputs.get(0));

  doFnTransform.close();
}
/**
 * Checks count-based bundling. The maximum bundle size is set to 3 and the maximum bundle time
 * to 10,000,000 ms; after each group of three elements the test expects a finished bundle
 * that reports exactly three processed elements.
 */
@Test
@SuppressWarnings("unchecked")
public void testCountBundle() {
  final int elementsPerBundle = 3;

  final NemoPipelineOptions options = PipelineOptionsFactory.as(NemoPipelineOptions.class);
  options.setMaxBundleSize(3L);
  options.setMaxBundleTimeMills(10000000L);

  // Filled by BundleTestDoFn: one entry per finished bundle.
  final List<Integer> finishedBundleCounts = new ArrayList<>();
  final TupleTag<String> mainTag = new TupleTag<>("main-output");

  final DoFnTransform<String, String> transform = new DoFnTransform<>(
    new BundleTestDoFn(finishedBundleCounts),
    NULL_INPUT_CODER,
    NULL_OUTPUT_CODERS,
    mainTag,
    Collections.emptyList(),
    WindowingStrategy.globalDefault(),
    options,
    DisplayData.none());

  final OutputCollector<WindowedValue<String>> collector = new TestOutputCollector<>();
  transform.prepare(mock(Transform.Context.class), collector);

  // First group of elements.
  for (int i = 0; i < elementsPerBundle; i++) {
    transform.onData(WindowedValue.valueInGlobalWindow("a"));
  }
  assertEquals(elementsPerBundle, finishedBundleCounts.get(0).intValue());

  // Second group, recorded into the freshly cleared list.
  finishedBundleCounts.clear();
  for (int i = 0; i < elementsPerBundle; i++) {
    transform.onData(WindowedValue.valueInGlobalWindow("a"));
  }
  assertEquals(elementsPerBundle, finishedBundleCounts.get(0).intValue());

  transform.close();
}
/**
 * Checks time-based bundling. The maximum bundle time is set to 1000 ms and the maximum bundle
 * size to 10,000,000; the test feeds one element roughly every 10 ms until a finished bundle is
 * reported, then verifies that the reported count equals the number of elements fed and that
 * at least the maximum bundle time has elapsed.
 */
@Test
@SuppressWarnings("unchecked")
public void testTimeBundle() {
  final long maxBundleTimeMills = 1000L;
  // Generous upper bound so a bundle that never finishes fails the test instead of hanging it.
  final long waitLimitMills = 30 * maxBundleTimeMills;

  final TupleTag<String> outputTag = new TupleTag<>("main-output");
  final NemoPipelineOptions pipelineOptions = PipelineOptionsFactory.as(NemoPipelineOptions.class);
  pipelineOptions.setMaxBundleSize(10000000L);
  pipelineOptions.setMaxBundleTimeMills(maxBundleTimeMills);

  final List<Integer> bundleOutput = new ArrayList<>();

  final DoFnTransform<String, String> doFnTransform =
    new DoFnTransform<>(
      new BundleTestDoFn(bundleOutput),
      NULL_INPUT_CODER,
      NULL_OUTPUT_CODERS,
      outputTag,
      Collections.emptyList(),
      WindowingStrategy.globalDefault(),
      pipelineOptions,
      DisplayData.none());

  final Transform.Context context = mock(Transform.Context.class);
  final OutputCollector<WindowedValue<String>> oc = new TestOutputCollector<>();

  final long startTime = System.currentTimeMillis();
  doFnTransform.prepare(context, oc);

  int count = 0;
  while (bundleOutput.isEmpty()) {
    if (System.currentTimeMillis() - startTime > waitLimitMills) {
      fail("No bundle was finished within " + waitLimitMills + " ms");
    }

    doFnTransform.onData(WindowedValue.valueInGlobalWindow("a"));
    count += 1;

    try {
      Thread.sleep(10);
    } catch (final InterruptedException e) {
      // Restore the interrupt flag before aborting so the test runner can observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
  }
  final long endTime = System.currentTimeMillis();

  assertEquals(count, (int) bundleOutput.get(0));
  assertTrue(endTime - startTime >= maxBundleTimeMills);

  doFnTransform.close();
}
/**
 * Verifies routing to the main output and two additional tagged outputs.
 * The elements {@code "one"}, {@code "two"} and {@code "hello"} are fed in that order; the test
 * expects {@code "got: hello"} as the first main output, {@code "extra: one"} and
 * {@code "got: hello"} under the first additional tag, and {@code "extra: two"} and
 * {@code "got: hello"} under the second additional tag.
 */
@Test
@SuppressWarnings("unchecked")
public void testMultiOutputOutput() {
  final TupleTag<String> mainOutput = new TupleTag<>("main-output");
  final TupleTag<String> additionalOutput1 = new TupleTag<>("output-1");
  final TupleTag<String> additionalOutput2 = new TupleTag<>("output-2");

  final ImmutableList<TupleTag<?>> tags = ImmutableList.of(additionalOutput1, additionalOutput2);

  final DoFnTransform<String, String> doFnTransform =
    new DoFnTransform<>(
      new MultiOutputDoFn(additionalOutput1, additionalOutput2),
      NULL_INPUT_CODER,
      NULL_OUTPUT_CODERS,
      mainOutput,
      tags,
      WindowingStrategy.globalDefault(),
      PipelineOptionsFactory.as(NemoPipelineOptions.class),
      DisplayData.none());

  // mock context
  final Transform.Context context = mock(Transform.Context.class);
  // Keep the concrete collector type so outputs can be inspected without repeated casts.
  final TestOutputCollector<String> oc = new TestOutputCollector<>();
  doFnTransform.prepare(context, oc);

  doFnTransform.onData(WindowedValue.valueInGlobalWindow("one"));
  doFnTransform.onData(WindowedValue.valueInGlobalWindow("two"));
  doFnTransform.onData(WindowedValue.valueInGlobalWindow("hello"));

  // main output
  assertEquals(WindowedValue.valueInGlobalWindow("got: hello"), oc.outputs.get(0));

  // additional output 1
  assertTrue(oc.getTaggedOutputs().contains(
    new Tuple<>(additionalOutput1.getId(), WindowedValue.valueInGlobalWindow("extra: one"))
  ));
  assertTrue(oc.getTaggedOutputs().contains(
    new Tuple<>(additionalOutput1.getId(), WindowedValue.valueInGlobalWindow("got: hello"))
  ));

  // additional output 2
  assertTrue(oc.getTaggedOutputs().contains(
    new Tuple<>(additionalOutput2.getId(), WindowedValue.valueInGlobalWindow("extra: two"))
  ));
  assertTrue(oc.getTaggedOutputs().contains(
    new Tuple<>(additionalOutput2.getId(), WindowedValue.valueInGlobalWindow("got: hello"))
  ));

  doFnTransform.close();
}
/**
 * Exercises {@link PushBackDoFnTransform} with two side inputs.
 * <ol>
 *   <li>A main-input element is fed before either side input; after both side inputs arrive
 *       the first output is expected.</li>
 *   <li>A second main-input element is fed after the side inputs; the second output is
 *       expected.</li>
 *   <li>Both views are expected to be reported as ready in the global window.</li>
 *   <li>After a watermark at {@code GlobalWindow.TIMESTAMP_MAX_VALUE}, the side-input reader is
 *       expected to return empty iterables for both views, and closing the transform must not
 *       add further outputs.</li>
 * </ol>
 */
@Test
public void testSideInputs() {
  final WindowedValue<String> mainFirst = WindowedValue.valueInGlobalWindow("first");
  final WindowedValue<String> mainSecond = WindowedValue.valueInGlobalWindow("second");
  final SideInputElement sideFirst = new SideInputElement<>(0, ImmutableList.of("1"));
  final SideInputElement sideSecond = new SideInputElement<>(1, ImmutableList.of("2"));

  // side-input index -> view
  final Map<Integer, PCollectionView<?>> viewsByIndex = new HashMap<>();
  viewsByIndex.put(sideFirst.getSideInputIndex(), view1);
  viewsByIndex.put(sideSecond.getSideInputIndex(), view2);

  final TupleTag<Tuple<String, Iterable<String>>> mainTag = new TupleTag<>("main-output");
  final PushBackDoFnTransform<String, String> transform =
    new PushBackDoFnTransform(
      new SimpleSideInputDoFn<String>(view1, view2),
      NULL_INPUT_CODER,
      NULL_OUTPUT_CODERS,
      mainTag,
      Collections.emptyList(),
      WindowingStrategy.globalDefault(),
      viewsByIndex, /* side inputs */
      PipelineOptionsFactory.as(NemoPipelineOptions.class),
      DisplayData.none());

  final TestOutputCollector<String> collector = new TestOutputCollector<>();
  transform.prepare(mock(Transform.Context.class), collector);

  // Phase 1: main input first, side inputs later.
  transform.onData(mainFirst);
  transform.onData(WindowedValue.valueInGlobalWindow(sideFirst));
  transform.onData(WindowedValue.valueInGlobalWindow(sideSecond));
  final String expectedFirst =
    concat(mainFirst.getValue(), sideFirst.getSideInputValue(), sideSecond.getSideInputValue());
  assertEquals(WindowedValue.valueInGlobalWindow(expectedFirst), collector.getOutput().get(0));

  // Phase 2: side inputs already delivered, main input later.
  transform.onData(mainSecond);
  final String expectedSecond =
    concat(mainSecond.getValue(), sideFirst.getSideInputValue(), sideSecond.getSideInputValue());
  assertEquals(WindowedValue.valueInGlobalWindow(expectedSecond), collector.getOutput().get(1));

  // Exactly two final outputs so far.
  assertEquals(2, collector.getOutput().size());

  // Both side inputs are reported as ready.
  assertTrue(transform.getSideInputReader().isReady(view1, GlobalWindow.INSTANCE));
  assertTrue(transform.getSideInputReader().isReady(view2, GlobalWindow.INSTANCE));

  // After the maximum-timestamp watermark, both materialized side inputs are expected to be empty.
  transform.onWatermark(new Watermark(GlobalWindow.TIMESTAMP_MAX_VALUE.getMillis()));
  final Iterable<?> materializedFirst = transform.getSideInputReader().get(view1, GlobalWindow.INSTANCE);
  final Iterable<?> materializedSecond = transform.getSideInputReader().get(view2, GlobalWindow.INSTANCE);
  assertFalse(materializedFirst.iterator().hasNext());
  assertFalse(materializedSecond.iterator().hasNext());

  // Closing must not produce any additional output.
  transform.close();
  assertEquals(2, collector.getOutput().size());
}
/**
 * A {@link DoFn} for bundle tests: it echoes every element and counts the elements seen since
 * the bundle started. When a bundle finishes, the count is appended to the list handed to the
 * constructor, so the caller can observe one entry per finished bundle.
 */
private static class BundleTestDoFn extends DoFn<String, String> {
  // Receives the per-bundle element count each time a bundle finishes.
  private final List<Integer> bundleOutput;
  // Number of elements processed in the current bundle.
  int count;

  BundleTestDoFn(final List<Integer> bundleOutput) {
    this.bundleOutput = bundleOutput;
  }

  @StartBundle
  public void startBundle(final StartBundleContext c) {
    this.count = 0;
  }

  @ProcessElement
  public void processElement(final ProcessContext c) throws Exception {
    this.count++;
    final String element = c.element();
    c.output(element);
  }

  @FinishBundle
  public void finishBundle(final FinishBundleContext c) {
    this.bundleOutput.add(this.count);
  }
}
/**
 * Identity do fn: emits every input element unchanged.
 *
 * @param <T> element type
 */
private static class IdentityDoFn<T> extends DoFn<T, T> {
  @ProcessElement
  public void processElement(final ProcessContext c) throws Exception {
    final T passthrough = c.element();
    c.output(passthrough);
  }
}
/**
 * Side input do fn: for every main-input element it reads the two side inputs given at
 * construction time and emits {@code concat(element, firstSideInput, secondSideInput)}.
 *
 * @param <T> main-input element type
 */
private static class SimpleSideInputDoFn<T> extends DoFn<T, String> {
  private final PCollectionView<?> firstView;
  private final PCollectionView<?> secondView;

  public SimpleSideInputDoFn(final PCollectionView<?> view1,
                             final PCollectionView<?> view2) {
    this.firstView = view1;
    this.secondView = view2;
  }

  @ProcessElement
  public void processElement(final ProcessContext c) throws Exception {
    // Arguments are evaluated left to right: element, first side input, second side input.
    c.output(concat(c.element(), c.sideInput(firstView), c.sideInput(secondView)));
  }
}
/**
 * Joins the string forms of three objects with {@code " / "} as the separator.
 * A {@code null} argument in any position is rendered as the text {@code "null"}.
 *
 * @param obj1 first value
 * @param obj2 second value
 * @param obj3 third value
 * @return the string {@code "<obj1> / <obj2> / <obj3>"}
 */
private static String concat(final Object obj1, final Object obj2, final Object obj3) {
  // String.valueOf keeps null handling uniform: calling obj1.toString() directly would throw a
  // NullPointerException for the first argument only, while the other two print as "null".
  return String.valueOf(obj1) + " / " + obj2 + " / " + obj3;
}
/**
 * Multi output do fn: routes each input element to the main output and/or two additional
 * tagged outputs.
 * <ul>
 *   <li>{@code "one"} emits {@code "extra: one"} to the first additional output only;</li>
 *   <li>{@code "two"} emits {@code "extra: two"} to the second additional output only;</li>
 *   <li>any other element emits {@code "got: <element>"} to the main output and to both
 *       additional outputs.</li>
 * </ul>
 */
private static class MultiOutputDoFn extends DoFn<String, String> {
  private final TupleTag<String> additionalOutput1;
  private final TupleTag<String> additionalOutput2;

  public MultiOutputDoFn(final TupleTag<String> additionalOutput1,
                         final TupleTag<String> additionalOutput2) {
    this.additionalOutput1 = additionalOutput1;
    this.additionalOutput2 = additionalOutput2;
  }

  @ProcessElement
  public void processElement(final ProcessContext c) throws Exception {
    final String element = c.element();
    if ("one".equals(element)) {
      c.output(additionalOutput1, "extra: one");
    } else if ("two".equals(element)) {
      c.output(additionalOutput2, "extra: two");
    } else {
      final String echoed = "got: " + element;
      c.output(echoed);
      c.output(additionalOutput1, echoed);
      c.output(additionalOutput2, echoed);
    }
  }
}
}