blob: abecc4bdccc3999c6a5b62eb70a45aa405c4d6c8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItems;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TimestampedValue;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link DoFnTester}. */
@RunWith(JUnit4.class)
public class DoFnTesterTest {
  @Rule public final TestPipeline p = TestPipeline.create();

  @Rule public ExpectedException thrown = ExpectedException.none();

  /**
   * Processes one element under every {@link DoFnTester.CloningBehavior} and checks that, after
   * {@code takeOutputElements()}, both take and peek return nothing.
   */
  @Test
  public void processElement() throws Exception {
    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
        tester.setCloningBehavior(cloning);
        tester.processElement(1L);
        List<String> take = tester.takeOutputElements();
        assertThat(take, hasItems("1"));
        // Following takeOutputElements(), neither takeOutputElements()
        // nor peekOutputElements() return anything.
        assertTrue(tester.takeOutputElements().isEmpty());
        assertTrue(tester.peekOutputElements().isEmpty());
      }
    }
  }

  /**
   * Interleaves processElement with peek and take calls inside an explicit bundle, checking that
   * peek accumulates outputs while take clears them.
   */
  @Test
  public void processElementsWithPeeks() throws Exception {
    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
        tester.setCloningBehavior(cloning);
        // Explicitly call startBundle().
        tester.startBundle();
        // Process a couple of elements.
        tester.processElement(1L);
        tester.processElement(2L);
        // Peek the first 2 outputs.
        List<String> peek = tester.peekOutputElements();
        assertThat(peek, hasItems("1", "2"));
        // Process a couple more.
        tester.processElement(3L);
        tester.processElement(4L);
        // Peek all the outputs so far.
        peek = tester.peekOutputElements();
        assertThat(peek, hasItems("1", "2", "3", "4"));
        // Take the outputs.
        List<String> take = tester.takeOutputElements();
        assertThat(take, hasItems("1", "2", "3", "4"));
        // Following takeOutputElements(), neither takeOutputElements()
        // nor peekOutputElements() return anything.
        assertTrue(tester.peekOutputElements().isEmpty());
        assertTrue(tester.takeOutputElements().isEmpty());
        // Process a couple more.
        tester.processElement(5L);
        tester.processElement(6L);
        // Peek and take now have only the 2 last outputs.
        peek = tester.peekOutputElements();
        assertThat(peek, hasItems("5", "6"));
        take = tester.takeOutputElements();
        assertThat(take, hasItems("5", "6"));
        tester.finishBundle();
      }
    }
  }

  /** Checks that processBundle returns the bundle's output and leaves nothing to peek. */
  @Test
  public void processBundle() throws Exception {
    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
        tester.setCloningBehavior(cloning);
        // processBundle() returns all the output like takeOutputElements().
        assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4"));
        // Peek now returns nothing.
        assertTrue(tester.peekOutputElements().isEmpty());
      }
    }
  }

  /** Checks that consecutive processBundle calls each return only their own bundle's output. */
  @Test
  public void processMultipleBundles() throws Exception {
    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
        tester.setCloningBehavior(cloning);
        // processBundle() returns all the output like takeOutputElements().
        assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4"));
        assertThat(tester.processBundle(5L, 6L, 7L), hasItems("5", "6", "7"));
        assertThat(tester.processBundle(8L, 9L), hasItems("8", "9"));
        // Peek now returns nothing.
        assertTrue(tester.peekOutputElements().isEmpty());
      }
    }
  }

  /**
   * With {@code DO_NOT_CLONE}, expects {@code @Setup} and {@code @Teardown} to run exactly once
   * across three bundles.
   */
  @Test
  public void doNotClone() throws Exception {
    final AtomicInteger numSetupCalls = new AtomicInteger();
    final AtomicInteger numTeardownCalls = new AtomicInteger();
    DoFn<Long, String> fn =
        new DoFn<Long, String>() {
          @ProcessElement
          public void process(ProcessContext context) {}

          @Setup
          public void setup() {
            numSetupCalls.incrementAndGet();
          }

          @Teardown
          public void teardown() {
            numTeardownCalls.incrementAndGet();
          }
        };
    try (DoFnTester<Long, String> tester = DoFnTester.of(fn)) {
      tester.setCloningBehavior(DoFnTester.CloningBehavior.DO_NOT_CLONE);
      tester.processBundle(1L, 2L, 3L);
      tester.processBundle(4L, 5L);
      tester.processBundle(6L);
    }
    assertEquals(1, numSetupCalls.get());
    assertEquals(1, numTeardownCalls.get());
  }

  /**
   * Emits {@code "<startBundle calls>/<finishBundle calls>"} for every element. This exposes how
   * many bundle lifecycle callbacks the current instance has already received.
   */
  private static class CountBundleCallsFn extends DoFn<Long, String> {
    private int numStartBundleCalls = 0;
    private int numFinishBundleCalls = 0;

    @ProcessElement
    public void process(ProcessContext context) {
      context.output(numStartBundleCalls + "/" + numFinishBundleCalls);
    }

    @StartBundle
    public void startBundle() {
      ++numStartBundleCalls;
    }

    @FinishBundle
    public void finishBundle() {
      ++numFinishBundleCalls;
    }
  }

  /**
   * With {@code CLONE_ONCE}, the expected counter strings grow from bundle to bundle. This shows
   * a single instance is reused across bundles.
   */
  @Test
  public void cloneOnce() throws Exception {
    try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) {
      tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_ONCE);
      assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0"));
      assertThat(tester.processBundle(4L, 5L), contains("2/1", "2/1"));
      assertThat(tester.processBundle(6L), contains("3/2"));
    }
  }

  /**
   * With {@code CLONE_PER_BUNDLE}, every bundle is expected to see freshly reset counters
   * ({@code "1/0"}).
   */
  @Test
  public void clonePerBundle() throws Exception {
    try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) {
      tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_PER_BUNDLE);
      assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0"));
      assertThat(tester.processBundle(4L, 5L), contains("1/0", "1/0"));
      assertThat(tester.processBundle(6L), contains("1/0"));
    }
  }

  /** Checks that the timestamp passed to processTimestampedElement is visible to the DoFn. */
  @Test
  public void processTimestampedElement() throws Exception {
    try (DoFnTester<Long, TimestampedValue<Long>> tester = DoFnTester.of(new ReifyTimestamps())) {
      TimestampedValue<Long> input = TimestampedValue.of(1L, new Instant(100));
      tester.processTimestampedElement(input);
      assertThat(tester.takeOutputElements(), contains(input));
    }
  }

  /** Outputs each element paired with the timestamp it was processed at. */
  static class ReifyTimestamps extends DoFn<Long, TimestampedValue<Long>> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      c.output(TimestampedValue.of(c.element(), c.timestamp()));
    }
  }

  /**
   * Checks the timestamped peek and take variants. {@link CounterDoFn} stamps element {@code n}
   * with {@code n * 1000} millis.
   */
  @Test
  public void processElementWithOutputTimestamp() throws Exception {
    try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
      tester.processElement(1L);
      tester.processElement(2L);
      List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp();
      TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L));
      TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L));
      assertThat(peek, hasItems(one, two));
      tester.processElement(3L);
      tester.processElement(4L);
      TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L));
      TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L));
      peek = tester.peekOutputElementsWithTimestamp();
      assertThat(peek, hasItems(one, two, three, four));
      List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp();
      assertThat(take, hasItems(one, two, three, four));
      // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp()
      // nor peekOutputElementsWithTimestamp() return anything.
      assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty());
      assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty());
      // peekOutputElements() and takeOutputElements() also return nothing.
      assertTrue(tester.peekOutputElements().isEmpty());
      assertTrue(tester.takeOutputElements().isEmpty());
    }
  }

  /**
   * Expects outputs in the global window and no outputs in an unrelated {@link IntervalWindow}.
   */
  @Test
  public void peekValuesInWindow() throws Exception {
    try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
      tester.startBundle();
      tester.processElement(1L);
      tester.processElement(2L);
      tester.finishBundle();
      assertThat(
          tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE),
          containsInAnyOrder(
              TimestampedValue.of("1", new Instant(1000L)),
              TimestampedValue.of("2", new Instant(2000L))));
      assertThat(
          tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))),
          Matchers.emptyIterable());
    }
  }

  /** Without an explicit side input value, expects the view's default (0) to be read. */
  @Test
  public void fnWithSideInputDefault() throws Exception {
    PCollection<Integer> pCollection = p.apply(Create.empty(VarIntCoder.of()));
    final PCollectionView<Integer> value =
        pCollection.apply(View.<Integer>asSingleton().withDefaultValue(0));
    try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) {
      tester.processElement(1);
      tester.processElement(2);
      tester.processElement(4);
      tester.processElement(8);
      assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0));
    }
  }

  /** Expects the value supplied via setSideInput to be read by the DoFn. */
  @Test
  public void fnWithSideInputExplicit() throws Exception {
    PCollection<Integer> pCollection = p.apply(Create.of(-2));
    final PCollectionView<Integer> value =
        pCollection.apply(View.<Integer>asSingleton().withDefaultValue(0));
    try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) {
      tester.setSideInput(value, GlobalWindow.INSTANCE, -2);
      tester.processElement(16);
      tester.processElement(32);
      tester.processElement(64);
      tester.processElement(128);
      tester.finishBundle();
      assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2));
    }
  }

  /** Checks that a {@link BoundedWindow} parameter receives the window of each element. */
  @Test
  public void testSupportsWindowParameter() throws Exception {
    Instant now = Instant.now();
    try (DoFnTester<Integer, KV<Integer, BoundedWindow>> tester =
        DoFnTester.of(new DoFnWithWindowParameter())) {
      BoundedWindow firstWindow = new IntervalWindow(now, now.plus(Duration.standardMinutes(1)));
      tester.processWindowedElement(1, now, firstWindow);
      tester.processWindowedElement(2, now, firstWindow);
      BoundedWindow secondWindow = new IntervalWindow(now, now.plus(Duration.standardMinutes(4)));
      tester.processWindowedElement(3, now, secondWindow);
      tester.finishBundle();
      assertThat(
          tester.peekOutputElementsInWindow(firstWindow),
          containsInAnyOrder(
              TimestampedValue.of(KV.of(1, firstWindow), now),
              TimestampedValue.of(KV.of(2, firstWindow), now)));
      assertThat(
          tester.peekOutputElementsInWindow(secondWindow),
          containsInAnyOrder(TimestampedValue.of(KV.of(3, secondWindow), now)));
    }
  }

  /** Outputs each element paired with the window it was processed in. */
  private static class DoFnWithWindowParameter extends DoFn<Integer, KV<Integer, BoundedWindow>> {
    @ProcessElement
    public void processElement(ProcessContext c, BoundedWindow window) {
      c.output(KV.of(c.element(), window));
    }
  }

  /** Checks that output emitted from {@code @FinishBundle} is returned by processBundle. */
  @Test
  public void testSupportsFinishBundleOutput() throws Exception {
    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
      try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new BundleCounterDoFn())) {
        tester.setCloningBehavior(cloning);
        assertThat(tester.processBundle(1, 2, 3, 4), contains(4));
        assertThat(tester.processBundle(5, 6, 7), contains(3));
        assertThat(tester.processBundle(8, 9), contains(2));
      }
    }
  }

  /**
   * Counts elements per bundle and outputs only that count, from {@code @FinishBundle}, into the
   * global window.
   */
  private static class BundleCounterDoFn extends DoFn<Integer, Integer> {
    private int elements;

    @StartBundle
    public void startBundle() {
      elements = 0;
    }

    @ProcessElement
    public void processElement(ProcessContext c) {
      elements++;
    }

    @FinishBundle
    public void finishBundle(FinishBundleContext c) {
      c.output(elements, Instant.now(), GlobalWindow.INSTANCE);
    }
  }

  /** Outputs the singleton side input's value once per element, ignoring the element itself. */
  private static class SideInputDoFn extends DoFn<Integer, Integer> {
    private final PCollectionView<Integer> value;

    private SideInputDoFn(PCollectionView<Integer> value) {
      this.value = value;
    }

    @ProcessElement
    public void processElement(ProcessContext c) throws Exception {
      c.output(c.sideInput(value));
    }
  }

  /**
   * A {@link DoFn} that adds values to a user metric and converts input to String in {@link
   * DoFn.ProcessElement @ProcessElement}.
   *
   * <p>Each output is timestamped at {@code element * 1000} millis. The DoFn also asserts, via
   * {@code checkState}, that its lifecycle methods run in the order setup, startBundle,
   * processElement*, finishBundle, ..., teardown.
   */
  private static class CounterDoFn extends DoFn<Long, String> {
    private final Counter agg = Metrics.counter(CounterDoFn.class, "ctr");
    private final Counter startBundleCalls = Metrics.counter(CounterDoFn.class, "startBundleCalls");
    private final Counter finishBundleCalls =
        Metrics.counter(CounterDoFn.class, "finishBundleCalls");

    /** Lifecycle phases used to validate callback ordering. */
    private enum LifecycleState {
      UNINITIALIZED,
      SET_UP,
      INSIDE_BUNDLE,
      TORN_DOWN
    }

    private LifecycleState state = LifecycleState.UNINITIALIZED;

    @Setup
    public void setup() {
      checkState(state == LifecycleState.UNINITIALIZED, "Wrong state: %s", state);
      state = LifecycleState.SET_UP;
    }

    @StartBundle
    public void startBundle() {
      checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state);
      state = LifecycleState.INSIDE_BUNDLE;
      startBundleCalls.inc();
    }

    @ProcessElement
    public void processElement(ProcessContext c) throws Exception {
      checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state);
      agg.inc(c.element());
      Instant instant = new Instant(1000L * c.element());
      c.outputWithTimestamp(c.element().toString(), instant);
    }

    @FinishBundle
    public void finishBundle() {
      checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state);
      state = LifecycleState.SET_UP;
      finishBundleCalls.inc();
    }

    @Teardown
    public void teardown() {
      checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state);
      state = LifecycleState.TORN_DOWN;
    }
  }
}