blob: 73382e3e8e839030bcf2d15c91102b395d032a5b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.apex.translation;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import com.datatorrent.api.DAG;
import com.datatorrent.api.Sink;
import com.datatorrent.lib.util.KryoCloneUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import org.apache.beam.runners.apex.ApexPipelineOptions;
import org.apache.beam.runners.apex.ApexRunner;
import org.apache.beam.runners.apex.ApexRunnerResult;
import org.apache.beam.runners.apex.TestApexRunner;
import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator;
import org.apache.beam.runners.apex.translation.operators.ApexReadUnboundedInputOperator;
import org.apache.beam.runners.apex.translation.utils.ApexStateInternals;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Integration tests for {@link ParDoTranslator}.
*/
@RunWith(JUnit4.class)
public class ParDoTranslatorTest {
// Logger used to report progress while a test polls for pipeline output.
private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslatorTest.class);
// Pause between two successive checks of the collected results, in milliseconds.
private static final long SLEEP_MILLIS = 500;
// Upper bound on how long a test keeps polling for the expected results, in milliseconds.
private static final long TIMEOUT_MILLIS = 30000;
/**
 * Runs a {@code Create -> ParDo(Add) -> ParDo(EmbeddedCollector)} pipeline on the non-blocking
 * {@link ApexRunner}, checks that the translated DAG contains the expected operator types and
 * then polls (for at most {@code TIMEOUT_MILLIS}) until the collector has seen every expected
 * value.
 *
 * <p>The running pipeline is always cancelled before the final comparison, even when one of the
 * DAG checks fails or the polling loop is interrupted.
 *
 * @throws Exception if the pipeline cannot be run or cancelled, or the wait is interrupted
 */
@Test
public void test() throws Exception {
  ApexPipelineOptions options = PipelineOptionsFactory.create()
      .as(ApexPipelineOptions.class);
  options.setApplicationName("ParDoBound");
  options.setRunner(ApexRunner.class);

  Pipeline p = Pipeline.create(options);

  List<Integer> collection = Lists.newArrayList(1, 2, 3, 4, 5);
  List<Integer> expected = Lists.newArrayList(6, 7, 8, 9, 10);
  p.apply(Create.of(collection).withCoder(SerializableCoder.of(Integer.class)))
      .apply(ParDo.of(new Add(5)))
      .apply(ParDo.of(new EmbeddedCollector()));

  ApexRunnerResult result = (ApexRunnerResult) p.run();
  try {
    DAG dag = result.getApexDAG();

    // JUnit's assertEquals takes (expected, actual); keeping that order makes failure
    // messages read correctly.
    DAG.OperatorMeta om = dag.getOperatorMeta("Create.Values");
    Assert.assertNotNull(om);
    Assert.assertEquals(ApexReadUnboundedInputOperator.class, om.getOperator().getClass());

    om = dag.getOperatorMeta("ParDo(Add)/ParMultiDo(Add)");
    Assert.assertNotNull(om);
    Assert.assertEquals(ApexParDoOperator.class, om.getOperator().getClass());

    // The runner does not block, so poll the collector until the results show up or we time out.
    long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS;
    while (System.currentTimeMillis() < timeout) {
      if (EmbeddedCollector.RESULTS.containsAll(expected)) {
        break;
      }
      LOG.info("Waiting for expected results.");
      Thread.sleep(SLEEP_MILLIS);
    }
  } finally {
    // Stop the pipeline regardless of the outcome so it does not outlive this test.
    result.cancel();
  }
  Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.RESULTS);
}
/**
 * {@link DoFn} that adds a number to every input element.
 *
 * <p>The number is either the constant passed to {@link #Add(Integer)} or, when the function was
 * built with {@link #Add(PCollectionView)}, the value read from that side input for the element
 * being processed.
 */
private static class Add extends DoFn<Integer, Integer> {
  private static final long serialVersionUID = 1L;
  private Integer number;
  private PCollectionView<Integer> sideInputView;

  /**
   * Creates a function that adds a fixed number.
   *
   * @param number the value added to every element
   */
  private Add(Integer number) {
    this.number = number;
  }

  /**
   * Creates a function that adds the value of a side input.
   *
   * @param sideInputView view whose value is added to every element
   */
  private Add(PCollectionView<Integer> sideInputView) {
    this.sideInputView = sideInputView;
  }

  /**
   * Emits the current element plus the configured addend.
   *
   * @param c context giving access to the element, the side input and the output
   * @throws Exception declared for compatibility with the original signature
   */
  @ProcessElement
  public void processElement(ProcessContext c) throws Exception {
    // Resolve the addend into a local rather than overwriting the field: processing an element
    // must not change the state of this serializable function.
    Integer addend = (sideInputView != null) ? c.sideInput(sideInputView) : number;
    c.output(c.element() + addend);
  }
}
/**
 * {@link DoFn} that stores every element it receives in a static, synchronized set which the
 * tests in this class read to check what a pipeline produced.
 *
 * <p>Constructing a new instance empties the shared set.
 */
private static class EmbeddedCollector extends DoFn<Object, Void> {
  private static final long serialVersionUID = 1L;

  /** Elements seen so far, shared by all instances. */
  private static final Set<Object> RESULTS =
      Collections.synchronizedSet(new HashSet<Object>());

  public EmbeddedCollector() {
    // Start every newly built collector from a clean slate.
    RESULTS.clear();
  }

  /**
   * Records the current element.
   *
   * @param context context providing the element to record
   */
  @ProcessElement
  public void processElement(ProcessContext context) throws Exception {
    Object received = context.element();
    RESULTS.add(received);
  }
}
/**
 * Runs the pipeline and hands back the {@link AssertionError} it raises.
 *
 * <p>{@code thrown.expect(AssertionError.class)} cannot be used for this, because the
 * AssertionError is first caught by JUnit and causes a test failure.
 *
 * @param pipeline the pipeline that is expected to fail an assertion when run
 * @return the assertion error thrown by {@code pipeline.run()}
 */
private static Throwable runExpectingAssertionFailure(Pipeline pipeline) {
  Throwable observed = null;
  try {
    pipeline.run();
  } catch (AssertionError assertionFailure) {
    observed = assertionFailure;
  }
  if (observed != null) {
    return observed;
  }
  // The run completed without the anticipated assertion error.
  fail("assertion should have failed");
  throw new RuntimeException("unreachable");
}
/**
 * Checks that a {@link PAssert} which cannot hold makes the pipeline run fail with an
 * {@link AssertionError} whose message lists the expected elements.
 */
@Test
public void testAssertionFailure() throws Exception {
  ApexPipelineOptions opts = PipelineOptionsFactory.create().as(ApexPipelineOptions.class);
  opts.setRunner(TestApexRunner.class);
  Pipeline pipeline = Pipeline.create(opts);

  // The input holds 1..4 only, so expecting an additional 7 has to fail.
  PCollection<Integer> numbers = pipeline.apply(Create.of(1, 2, 3, 4));
  PAssert.that(numbers).containsInAnyOrder(2, 1, 4, 3, 7);

  Throwable failure = runExpectingAssertionFailure(pipeline);

  // A loose pattern, but should get the job done.
  Pattern expectedPattern = Pattern.compile(
      "Expected: iterable over \\[((<4>|<7>|<3>|<2>|<1>)(, )?){5}\\] in any order");
  String explanation =
      "Expected error message from PAssert with substring matching "
          + expectedPattern
          + " but the message was \""
          + failure.getMessage()
          + "\"";
  boolean patternFound = expectedPattern.matcher(failure.getMessage()).find();
  assertTrue(explanation, patternFound);
}
/**
 * Checks that a satisfiable {@link PAssert#that} / {@code containsInAnyOrder} assertion lets the
 * pipeline run complete without an error.
 */
@Test
public void testContainsInAnyOrder() throws Exception {
  ApexPipelineOptions opts = PipelineOptionsFactory.create()
      .as(ApexPipelineOptions.class);
  opts.setRunner(TestApexRunner.class);

  Pipeline testPipeline = Pipeline.create(opts);
  PCollection<Integer> numbers = testPipeline
      .apply(Create.of(1, 2, 3, 4));
  PAssert.that(numbers).containsInAnyOrder(2, 1, 4, 3);

  // TODO: terminate faster based on processed assertion vs. auto-shutdown
  testPipeline.run();
}
/**
 * Drives an {@link ApexParDoOperator} by hand and clones it with Kryo between steps to check
 * that its state survives serialization.
 *
 * <p>Sequence: a main-input element is fed before any side input; the operator is cloned; the
 * side input (22) is delivered and the held-back element must come out as 1 + 22. The operator is
 * cloned again and a second element must come out as 2 + 22 without the side input being
 * delivered a second time.
 */
@Test
public void testSerialization() throws Exception {
  ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class);
  options.setRunner(TestApexRunner.class);
  Pipeline pipeline = Pipeline.create(options);

  PCollectionView<Integer> singletonView =
      pipeline
          .apply(Create.of(1))
          .apply(Sum.integersGlobally().asSingletonView());

  ApexParDoOperator<Integer, Integer> operator =
      new ApexParDoOperator<>(
          options,
          new Add(singletonView),
          new TupleTag<Integer>(),
          TupleTagList.empty().getAll(),
          WindowingStrategy.globalDefault(),
          Collections.<PCollectionView<?>>singletonList(singletonView),
          VarIntCoder.of(),
          new ApexStateInternals.ApexStateBackend());
  operator.setup(null);
  operator.beginWindow(0);

  // Feed a main-input element while no side input has been delivered yet.
  WindowedValue<Integer> firstElement = WindowedValue.valueInGlobalWindow(1);
  operator.input.process(ApexStreamTuple.DataTuple.of(firstElement)); // pushed back input

  // Sink that simply remembers everything the operator emits.
  final List<Object> emitted = Lists.newArrayList();
  Sink<Object> recordingSink = new Sink<Object>() {
    @Override
    public void put(Object tuple) {
      emitted.add(tuple);
    }

    @Override
    public int getCount(boolean reset) {
      return 0;
    }
  };

  // verify pushed back input checkpointing
  operator = KryoCloneUtils.cloneObject(operator);
  Assert.assertNotNull("Serialization", operator);
  operator.output.setSink(recordingSink);
  operator.setup(null);
  operator.beginWindow(1);
  WindowedValue<Iterable<?>> sideInput =
      WindowedValue.<Iterable<?>>valueInGlobalWindow(Lists.<Integer>newArrayList(22));
  operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput));
  Assert.assertEquals("number outputs", 1, emitted.size());
  Object firstOutput = ((ApexStreamTuple.DataTuple<?>) emitted.get(0)).getValue();
  Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(23), firstOutput);

  // verify side input checkpointing
  emitted.clear();
  operator = KryoCloneUtils.cloneObject(operator);
  Assert.assertNotNull("Serialization", operator);
  operator.output.setSink(recordingSink);
  operator.setup(null);
  operator.beginWindow(2);
  WindowedValue<Integer> secondElement = WindowedValue.valueInGlobalWindow(2);
  operator.input.process(ApexStreamTuple.DataTuple.of(secondElement));
  Assert.assertEquals("number outputs", 1, emitted.size());
  Object secondOutput = ((ApexStreamTuple.DataTuple<?>) emitted.get(0)).getValue();
  Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(24), secondOutput);
}
/**
 * Runs a multi-output {@link ParDo} that declares three side inputs but reads only two of them,
 * and checks that the main output carries, for every input element, the element together with
 * the values of the two side inputs that were read.
 *
 * <p>The pipeline runs on the non-blocking {@link ApexRunner}; the test polls the collector for
 * at most {@code TIMEOUT_MILLIS} and always cancels the pipeline when polling ends, including
 * when the wait is interrupted.
 *
 * @throws Exception if the pipeline cannot be run or cancelled, or the wait is interrupted
 */
@Test
public void testMultiOutputParDoWithSideInputs() throws Exception {
  ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class);
  options.setRunner(ApexRunner.class); // non-blocking run
  Pipeline pipeline = Pipeline.create(options);

  List<Integer> inputs = Arrays.asList(3, -42, 666);
  final TupleTag<String> mainOutputTag = new TupleTag<>("main");
  final TupleTag<Void> additionalOutputTag = new TupleTag<>("output");

  PCollectionView<Integer> sideInput1 = pipeline
      .apply("CreateSideInput1", Create.of(11))
      .apply("ViewSideInput1", View.<Integer>asSingleton());
  // Declared on the ParDo below but never read by the function.
  PCollectionView<Integer> sideInputUnread = pipeline
      .apply("CreateSideInputUnread", Create.of(-3333))
      .apply("ViewSideInputUnread", View.<Integer>asSingleton());
  PCollectionView<Integer> sideInput2 = pipeline
      .apply("CreateSideInput2", Create.of(222))
      .apply("ViewSideInput2", View.<Integer>asSingleton());

  PCollectionTuple outputs = pipeline
      .apply(Create.of(inputs))
      .apply(ParDo
          .of(new TestMultiOutputWithSideInputsFn(
              Arrays.asList(sideInput1, sideInput2),
              Arrays.<TupleTag<String>>asList()))
          .withSideInputs(sideInput1)
          .withSideInputs(sideInputUnread)
          .withSideInputs(sideInput2)
          .withOutputTags(mainOutputTag, TupleTagList.of(additionalOutputTag)));

  outputs.get(mainOutputTag).apply(ParDo.of(new EmbeddedCollector()));
  outputs.get(additionalOutputTag).setCoder(VoidCoder.of());
  ApexRunnerResult result = (ApexRunnerResult) pipeline.run();

  Set<String> expected = Sets.newHashSet("processing: 3: [11, 222]",
      "processing: -42: [11, 222]", "processing: 666: [11, 222]");
  try {
    // The runner does not block, so poll the collector until the results show up or we time out.
    long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS;
    while (System.currentTimeMillis() < timeout) {
      if (EmbeddedCollector.RESULTS.containsAll(expected)) {
        break;
      }
      LOG.info("Waiting for expected results.");
      Thread.sleep(SLEEP_MILLIS);
    }
  } finally {
    // Cancel even when the wait above throws, so the pipeline does not outlive this test.
    result.cancel();
  }
  Assert.assertEquals(expected, EmbeddedCollector.RESULTS);
}
/**
 * {@link DoFn} that emits {@code "processing: <element>"} for every input element, followed by
 * the list of values read from the configured side inputs when there are any. The same text,
 * prefixed with the tag id, is also written to every configured additional output tag.
 */
private static class TestMultiOutputWithSideInputsFn extends DoFn<Integer, String> {
  private static final long serialVersionUID = 1L;

  final List<PCollectionView<Integer>> sideInputViews;
  final List<TupleTag<String>> additionalOutputTupleTags;

  /**
   * Copies both lists so later changes to the arguments do not affect this function.
   *
   * @param sideInputViews side inputs whose values are appended to each output
   * @param additionalOutputTupleTags extra output tags that receive a copy of each output
   */
  public TestMultiOutputWithSideInputsFn(List<PCollectionView<Integer>> sideInputViews,
      List<TupleTag<String>> additionalOutputTupleTags) {
    this.sideInputViews = new ArrayList<>(sideInputViews);
    this.additionalOutputTupleTags = new ArrayList<>(additionalOutputTupleTags);
  }

  @ProcessElement
  public void processElement(ProcessContext c) throws Exception {
    String base = "processing: " + c.element();
    outputToAllWithSideInputs(c, base);
  }

  /**
   * Appends the side input values (if any) to {@code value} and writes the result to the main
   * output and to each additional output tag.
   */
  private void outputToAllWithSideInputs(ProcessContext c, String value) {
    String message = value;
    if (!sideInputViews.isEmpty()) {
      List<Integer> observed = new ArrayList<>();
      for (PCollectionView<Integer> view : sideInputViews) {
        observed.add(c.sideInput(view));
      }
      message = value + ": " + observed;
    }

    c.output(message);
    for (TupleTag<String> tag : additionalOutputTupleTags) {
      String tagged = tag.getId() + ": " + message;
      c.output(tag, tagged);
    }
  }
}
}