blob: d4bbcc5e3d06a7b7f70548de68d16b6503eef751 [file] [log] [blame]
/*
* Copyright (C) 2015 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.cloud.dataflow.sdk.transforms;
import com.google.cloud.dataflow.sdk.annotations.Experimental;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext;
import com.google.cloud.dataflow.sdk.util.DirectSideInputReader;
import com.google.cloud.dataflow.sdk.util.DoFnRunner;
import com.google.cloud.dataflow.sdk.util.DoFnRunnerBase;
import com.google.cloud.dataflow.sdk.util.DoFnRunners;
import com.google.cloud.dataflow.sdk.util.PTuple;
import com.google.cloud.dataflow.sdk.util.SerializableUtils;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
import com.google.cloud.dataflow.sdk.util.common.Counter;
import com.google.cloud.dataflow.sdk.util.common.CounterSet;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.cloud.dataflow.sdk.values.TupleTagList;
import com.google.common.base.Function;
import com.google.common.base.Objects;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.joda.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A harness for unit-testing a {@link DoFn}.
*
* <p>For example:
*
* <pre> {@code
* DoFn<InputT, OutputT> fn = ...;
*
* DoFnTester<InputT, OutputT> fnTester = DoFnTester.of(fn);
*
* // Set arguments shared across all batches:
* fnTester.setSideInputs(...); // If fn takes side inputs.
* fnTester.setSideOutputTags(...); // If fn writes to side outputs.
*
* // Process a batch containing a single input element:
* Input testInput = ...;
* List<OutputT> testOutputs = fnTester.processBatch(testInput);
* Assert.assertThat(testOutputs, Matchers.hasItems(...));
*
* // Process a bigger batch:
* Assert.assertThat(fnTester.processBatch(i1, i2, ...), Matchers.hasItems(...));
* } </pre>
*
* @param <InputT> the type of the {@code DoFn}'s (main) input elements
* @param <OutputT> the type of the {@code DoFn}'s (main) output elements
*/
public class DoFnTester<InputT, OutputT> {
/**
* Returns a {@code DoFnTester} supporting unit-testing of the given
* {@link DoFn}.
*/
@SuppressWarnings("unchecked")
public static <InputT, OutputT> DoFnTester<InputT, OutputT> of(DoFn<InputT, OutputT> fn) {
return new DoFnTester<InputT, OutputT>(fn);
}
/**
* Returns a {@code DoFnTester} supporting unit-testing of the given
* {@link DoFn}.
*/
@SuppressWarnings("unchecked")
public static <InputT, OutputT> DoFnTester<InputT, OutputT>
of(DoFnWithContext<InputT, OutputT> fn) {
return new DoFnTester<InputT, OutputT>(DoFnReflector.of(fn.getClass()).toDoFn(fn));
}
/**
* Registers the tuple of values of the side input {@link PCollectionView}s to
* pass to the {@link DoFn} under test.
*
* <p>If needed, first creates a fresh instance of the {@link DoFn}
* under test.
*
* <p>If this isn't called, {@code DoFnTester} assumes the
* {@link DoFn} takes no side inputs.
*/
public void setSideInputs(Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs) {
this.sideInputs = sideInputs;
resetState();
}
/**
* Registers the values of a side input {@link PCollectionView} to
* pass to the {@link DoFn} under test.
*
* <p>If needed, first creates a fresh instance of the {@code DoFn}
* under test.
*
* <p>If this isn't called, {@code DoFnTester} assumes the
* {@code DoFn} takes no side inputs.
*/
public void setSideInput(PCollectionView<?> sideInput, Iterable<WindowedValue<?>> value) {
sideInputs.put(sideInput, value);
}
/**
* Registers the values for a side input {@link PCollectionView} to
* pass to the {@link DoFn} under test. All values are placed
* in the global window.
*/
public void setSideInputInGlobalWindow(
PCollectionView<?> sideInput,
Iterable<?> value) {
sideInputs.put(
sideInput,
Iterables.transform(value, new Function<Object, WindowedValue<?>>() {
@Override
public WindowedValue<?> apply(Object input) {
return WindowedValue.valueInGlobalWindow(input);
}
}));
}
/**
* Registers the list of {@code TupleTag}s that can be used by the
* {@code DoFn} under test to output to side output
* {@code PCollection}s.
*
* <p>If needed, first creates a fresh instance of the DoFn under test.
*
* <p>If this isn't called, {@code DoFnTester} assumes the
* {@code DoFn} doesn't emit to any side outputs.
*/
public void setSideOutputTags(TupleTagList sideOutputTags) {
this.sideOutputTags = sideOutputTags.getAll();
resetState();
}
/**
* Whether or not a {@link DoFnTester} should clone the {@link DoFn} under test.
*/
public enum CloningBehavior {
CLONE,
DO_NOT_CLONE;
}
/**
* Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test.
*/
public void setCloningBehavior(CloningBehavior newValue) {
this.cloningBehavior = newValue;
}
/**
* Indicates whether this {@link DoFnTester} will clone the {@link DoFn} under test.
*/
public CloningBehavior getCloningBehavior() {
return cloningBehavior;
}
/**
* A convenience operation that first calls {@link #startBundle},
* then calls {@link #processElement} on each of the input elements, then
* calls {@link #finishBundle}, then returns the result of
* {@link #takeOutputElements}.
*/
public List<OutputT> processBatch(Iterable <? extends InputT> inputElements) {
startBundle();
for (InputT inputElement : inputElements) {
processElement(inputElement);
}
finishBundle();
return takeOutputElements();
}
/**
* A convenience method for testing {@link DoFn DoFns} with bundles of elements.
* Logic proceeds as follows:
*
* <ol>
* <li>Calls {@link #startBundle}.</li>
* <li>Calls {@link #processElement} on each of the arguments.<li>
* <li>Calls {@link #finishBundle}.</li>
* <li>Returns the result of {@link #takeOutputElements}.</li>
* </ol>
*/
@SafeVarargs
public final List<OutputT> processBatch(InputT... inputElements) {
return processBatch(Arrays.asList(inputElements));
}
/**
* Calls {@link DoFn#startBundle} on the {@code DoFn} under test.
*
* <p>If needed, first creates a fresh instance of the DoFn under test.
*/
public void startBundle() {
resetState();
initializeState();
fnRunner.startBundle();
state = State.STARTED;
}
/**
* Calls {@link DoFn#processElement} on the {@code DoFn} under test, in a
* context where {@link DoFn.ProcessContext#element} returns the
* given element.
*
* <p>Will call {@link #startBundle} automatically, if it hasn't
* already been called.
*
* @throws IllegalStateException if the {@code DoFn} under test has already
* been finished
*/
public void processElement(InputT element) {
if (state == State.FINISHED) {
throw new IllegalStateException("finishBundle() has already been called");
}
if (state == State.UNSTARTED) {
startBundle();
}
fnRunner.processElement(WindowedValue.valueInGlobalWindow(element));
}
/**
* Calls {@link DoFn#finishBundle} of the {@code DoFn} under test.
*
* <p>Will call {@link #startBundle} automatically, if it hasn't
* already been called.
*
* @throws IllegalStateException if the {@code DoFn} under test has already
* been finished
*/
public void finishBundle() {
if (state == State.FINISHED) {
throw new IllegalStateException("finishBundle() has already been called");
}
if (state == State.UNSTARTED) {
startBundle();
}
fnRunner.finishBundle();
state = State.FINISHED;
}
/**
* Returns the elements output so far to the main output. Does not
* clear them, so subsequent calls will continue to include these
* elements.
*
* @see #takeOutputElements
* @see #clearOutputElements
*
*/
public List<OutputT> peekOutputElements() {
// TODO: Should we return an unmodifiable list?
return Lists.transform(
peekOutputElementsWithTimestamp(),
new Function<OutputElementWithTimestamp<OutputT>, OutputT>() {
@Override
@SuppressWarnings("unchecked")
public OutputT apply(OutputElementWithTimestamp<OutputT> input) {
return input.getValue();
}
});
}
/**
* Returns the elements output so far to the main output with associated timestamps. Does not
* clear them, so subsequent calls will continue to include these.
* elements.
*
* @see #takeOutputElementsWithTimestamp
* @see #clearOutputElements
*/
@Experimental
public List<OutputElementWithTimestamp<OutputT>> peekOutputElementsWithTimestamp() {
// TODO: Should we return an unmodifiable list?
return Lists.transform(
outputManager.getOutput(mainOutputTag),
new Function<Object, OutputElementWithTimestamp<OutputT>>() {
@Override
@SuppressWarnings("unchecked")
public OutputElementWithTimestamp<OutputT> apply(Object input) {
return new OutputElementWithTimestamp<OutputT>(
((WindowedValue<OutputT>) input).getValue(),
((WindowedValue<OutputT>) input).getTimestamp());
}
});
}
/**
* Clears the record of the elements output so far to the main output.
*
* @see #peekOutputElements
*/
public void clearOutputElements() {
peekOutputElements().clear();
}
/**
* Returns the elements output so far to the main output.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekOutputElements
*/
public List<OutputT> takeOutputElements() {
List<OutputT> resultElems = new ArrayList<>(peekOutputElements());
clearOutputElements();
return resultElems;
}
/**
* Returns the elements output so far to the main output with associated timestamps.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekOutputElementsWithTimestamp
* @see #takeOutputElements
* @see #clearOutputElements
*/
@Experimental
public List<OutputElementWithTimestamp<OutputT>> takeOutputElementsWithTimestamp() {
List<OutputElementWithTimestamp<OutputT>> resultElems =
new ArrayList<>(peekOutputElementsWithTimestamp());
clearOutputElements();
return resultElems;
}
/**
* Returns the elements output so far to the side output with the
* given tag. Does not clear them, so subsequent calls will
* continue to include these elements.
*
* @see #takeSideOutputElements
* @see #clearSideOutputElements
*/
public <T> List<T> peekSideOutputElements(TupleTag<T> tag) {
// TODO: Should we return an unmodifiable list?
return Lists.transform(
outputManager.getOutput(tag),
new Function<WindowedValue<T>, T>() {
@SuppressWarnings("unchecked")
@Override
public T apply(WindowedValue<T> input) {
return input.getValue();
}});
}
/**
* Clears the record of the elements output so far to the side
* output with the given tag.
*
* @see #peekSideOutputElements
*/
public <T> void clearSideOutputElements(TupleTag<T> tag) {
peekSideOutputElements(tag).clear();
}
/**
* Returns the elements output so far to the side output with the given tag.
* Clears the list so these elements don't appear in future calls.
*
* @see #peekSideOutputElements
*/
public <T> List<T> takeSideOutputElements(TupleTag<T> tag) {
List<T> resultElems = new ArrayList<>(peekSideOutputElements(tag));
clearSideOutputElements(tag);
return resultElems;
}
/**
* Returns the value of the provided {@link Aggregator}.
*/
public <AggregateT> AggregateT getAggregatorValue(Aggregator<?, AggregateT> agg) {
@SuppressWarnings("unchecked")
Counter<AggregateT> counter =
(Counter<AggregateT>)
counterSet.getExistingCounter("user-" + STEP_NAME + "-" + agg.getName());
return counter.getAggregate();
}
/**
* Holder for an OutputElement along with its associated timestamp.
*/
@Experimental
public static class OutputElementWithTimestamp<OutputT> {
private final OutputT value;
private final Instant timestamp;
OutputElementWithTimestamp(OutputT value, Instant timestamp) {
this.value = value;
this.timestamp = timestamp;
}
OutputT getValue() {
return value;
}
Instant getTimestamp() {
return timestamp;
}
@Override
public boolean equals(Object obj) {
if (!(obj instanceof OutputElementWithTimestamp)) {
return false;
}
OutputElementWithTimestamp<?> other = (OutputElementWithTimestamp<?>) obj;
return Objects.equal(other.value, value) && Objects.equal(other.timestamp, timestamp);
}
@Override
public int hashCode() {
return Objects.hashCode(value, timestamp);
}
}
/////////////////////////////////////////////////////////////////////////////
/** The possible states of processing a DoFn. */
enum State {
UNSTARTED,
STARTED,
FINISHED
}
/** The name of the step of a DoFnTester. */
static final String STEP_NAME = "stepName";
/** The name of the enclosing DoFn PTransform for a DoFnTester. */
static final String TRANSFORM_NAME = "transformName";
final PipelineOptions options = PipelineOptionsFactory.create();
/** The original DoFn under test. */
final DoFn<InputT, OutputT> origFn;
/** The side input values to provide to the DoFn under test. */
private Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs =
new HashMap<>();
/**
* Whether to clone the original {@link DoFn} or just use it as-is.
*
* <p>Worker-side {@link DoFn DoFns} may not be serializable, and are not required to be.
*/
private CloningBehavior cloningBehavior = CloningBehavior.CLONE;
/** The output tags used by the DoFn under test. */
TupleTag<OutputT> mainOutputTag = new TupleTag<>();
List<TupleTag<?>> sideOutputTags = new ArrayList<>();
/** The original DoFn under test, if started. */
DoFn<InputT, OutputT> fn;
/** The ListOutputManager to examine the outputs. */
DoFnRunnerBase.ListOutputManager outputManager;
/** The DoFnRunner if processing is in progress. */
DoFnRunner<InputT, OutputT> fnRunner;
/** Counters for user-defined Aggregators if processing is in progress. */
CounterSet counterSet;
/** The state of processing of the DoFn under test. */
State state;
DoFnTester(DoFn<InputT, OutputT> origFn) {
this.origFn = origFn;
resetState();
}
void resetState() {
fn = null;
outputManager = null;
fnRunner = null;
counterSet = null;
state = State.UNSTARTED;
}
@SuppressWarnings("unchecked")
void initializeState() {
if (cloningBehavior.equals(CloningBehavior.DO_NOT_CLONE)) {
fn = origFn;
} else {
fn = (DoFn<InputT, OutputT>)
SerializableUtils.deserializeFromByteArray(
SerializableUtils.serializeToByteArray(origFn),
origFn.toString());
}
counterSet = new CounterSet();
PTuple runnerSideInputs = PTuple.empty();
for (Map.Entry<PCollectionView<?>, Iterable<WindowedValue<?>>> entry
: sideInputs.entrySet()) {
runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue());
}
outputManager = new DoFnRunnerBase.ListOutputManager();
fnRunner = DoFnRunners.createDefault(
options,
fn,
DirectSideInputReader.of(runnerSideInputs),
outputManager,
mainOutputTag,
sideOutputTags,
DirectModeExecutionContext.create().getOrCreateStepContext(STEP_NAME, TRANSFORM_NAME, null),
counterSet.getAddCounterMutator(),
WindowingStrategy.globalDefault());
}
}