// blob: ed35aad8b4d309dc14d84aa70d2a4aa73f25c377 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.direct;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.runners.direct.CommittedResult.OutputType;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Optional;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Tests for {@link DirectTransformExecutor}. */
@RunWith(JUnit4.class)
public class DirectTransformExecutorTest {
/** Rule used by the tests that expect an exception to escape the test method. */
@Rule public ExpectedException thrown = ExpectedException.none();
/** Output of the {@code Create} transform applied in {@link #setup()}. */
private PCollection<String> created;
/** Producer of {@link #created}, looked up from the direct graph in {@link #setup()}. */
private AppliedPTransform<?, ?, ?> createdProducer;
/** Producer of the {@code WithKeys} output applied to {@link #created} in {@link #setup()}. */
private AppliedPTransform<?, ?, ?> downstreamProducer;
/** Latch with a count of one, counted down by {@link #completionCallback}. */
private CountDownLatch evaluatorCompleted;
/** Callback handed to every executor; records what it was told so tests can assert on it. */
private RegisteringCompletionCallback completionCallback;
/** Executor service passed to every executor; built in {@link #setup()} over a direct executor. */
private TransformExecutorService transformEvaluationState;
/** Factory used by the tests to build committed input bundles. */
private BundleFactory bundleFactory;
// The mocks below are initialised by MockitoAnnotations.initMocks(this) in setup().
/** Returned from {@code evaluationContext.getMetrics()} (stubbed in {@link #setup()}). */
@Mock private DirectMetrics metrics;
@Mock private EvaluationContext evaluationContext;
/** Stubbed per test to hand out the evaluator (or {@code null}) for a given application. */
@Mock private TransformEvaluatorRegistry registry;
/** Test pipeline with abandoned-node enforcement switched off. */
@Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);
/**
 * Initialises the mocks, builds a {@code Create -> WithKeys} pipeline, resolves the producers of
 * both steps from the direct graph, and creates the fixtures shared by every test.
 */
@Before
public void setup() {
  MockitoAnnotations.initMocks(this);

  // Pipeline shape: Create("foo", "spam", "third") followed by WithKeys(3).
  created = p.apply(Create.of("foo", "spam", "third"));
  PCollection<KV<Integer, String>> keyed = created.apply(WithKeys.of(3));
  DirectGraphs.performDirectOverrides(p);
  DirectGraph directGraph = DirectGraphs.getGraph(p);
  createdProducer = directGraph.getProducer(created);
  downstreamProducer = directGraph.getProducer(keyed);

  // Execution fixtures handed to each DirectTransformExecutor under test.
  bundleFactory = ImmutableListBundleFactory.create();
  transformEvaluationState =
      TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService());
  evaluatorCompleted = new CountDownLatch(1);
  completionCallback = new RegisteringCompletionCallback(evaluatorCompleted);

  when(evaluationContext.getMetrics()).thenReturn(metrics);
}
/**
 * Runs an executor that has no input bundle and checks that {@code finishBundle} was invoked and
 * that its result, and no exception, reached the completion callback.
 */
@Test
public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception {
  final AtomicBoolean finishInvoked = new AtomicBoolean(false);
  final TransformResult<Object> expected =
      StepTransformResult.withoutHold(createdProducer).build();
  TransformEvaluator<Object> finishOnlyEvaluator =
      new TransformEvaluator<Object>() {
        @Override
        public void processElement(WindowedValue<Object> element) throws Exception {
          // With a null input bundle there is nothing to process; fail loudly if called.
          throw new IllegalArgumentException("Shouldn't be called");
        }

        @Override
        public TransformResult<Object> finishBundle() throws Exception {
          finishInvoked.set(true);
          return expected;
        }
      };
  when(registry.forApplication(createdProducer, null)).thenReturn(finishOnlyEvaluator);

  DirectTransformExecutor<Object> underTest =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.emptyList(),
          null,
          createdProducer,
          completionCallback,
          transformEvaluationState);
  underTest.run();

  assertThat(finishInvoked.get(), is(true));
  assertThat(completionCallback.handledResult, equalTo(expected));
  assertThat(completionCallback.handledException, is(nullValue()));
}
/**
 * When the registry yields no evaluator, the executor should report "empty" to the callback
 * without producing a result or an exception.
 */
@Test
public void nullTransformEvaluatorTerminates() throws Exception {
  when(registry.forApplication(createdProducer, null)).thenReturn(null);

  DirectTransformExecutor<Object> underTest =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.emptyList(),
          null,
          createdProducer,
          completionCallback,
          transformEvaluationState);
  underTest.run();

  assertThat(completionCallback.handledResult, is(nullValue()));
  assertThat(completionCallback.handledEmpty, equalTo(true));
  assertThat(completionCallback.handledException, is(nullValue()));
}
/**
 * Submits an executor with a three-element input bundle to a worker thread and checks that every
 * element was handed to the evaluator and that the evaluator's result reached the callback.
 *
 * <p>The worker executor is shut down in a {@code finally} block so its thread is released even
 * if the task or an assertion fails.
 */
@Test
public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception {
  final TransformResult<String> result =
      StepTransformResult.<String>withoutHold(downstreamProducer).build();
  final Collection<WindowedValue<String>> elementsProcessed = new ArrayList<>();
  TransformEvaluator<String> evaluator =
      new TransformEvaluator<String>() {
        @Override
        public void processElement(WindowedValue<String> element) throws Exception {
          elementsProcessed.add(element);
        }

        @Override
        public TransformResult<String> finishBundle() throws Exception {
          return result;
        }
      };
  WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
  WindowedValue<String> spam = WindowedValue.valueInGlobalWindow("spam");
  WindowedValue<String> third = WindowedValue.valueInGlobalWindow("third");
  CommittedBundle<String> inputBundle =
      bundleFactory.createBundle(created).add(foo).add(spam).add(third).commit(Instant.now());
  when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
  DirectTransformExecutor<String> executor =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.emptyList(),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);

  ExecutorService worker = Executors.newSingleThreadExecutor();
  try {
    Future<?> future = worker.submit(executor);
    evaluatorCompleted.await();
    // Waiting on the future as well surfaces any failure raised inside the executor.
    future.get();
  } finally {
    // shutdown() (not shutdownNow()) lets an in-flight task finish and then frees the thread.
    worker.shutdown();
  }

  assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo));
  assertThat(completionCallback.handledResult, equalTo(result));
  assertThat(completionCallback.handledException, is(nullValue()));
}
/**
 * Uses an evaluator whose {@code processElement} throws and checks that the completion callback
 * receives that same exception and no result.
 *
 * <p>The worker executor is shut down in a {@code finally} block so its thread does not outlive
 * the test.
 */
@Test
@SuppressWarnings("FutureReturnValueIgnored") // expected exception checked via completionCallback
public void processElementThrowsExceptionCallsback() throws Exception {
  final TransformResult<String> result =
      StepTransformResult.<String>withoutHold(downstreamProducer).build();
  final Exception exception = new Exception();
  TransformEvaluator<String> evaluator =
      new TransformEvaluator<String>() {
        @Override
        public void processElement(WindowedValue<String> element) throws Exception {
          throw exception;
        }

        @Override
        public TransformResult<String> finishBundle() throws Exception {
          return result;
        }
      };
  WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
  CommittedBundle<String> inputBundle =
      bundleFactory.createBundle(created).add(foo).commit(Instant.now());
  when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
  DirectTransformExecutor<String> executor =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.emptyList(),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);

  ExecutorService worker = Executors.newSingleThreadExecutor();
  try {
    worker.submit(executor);
    // The callback counts the latch down once it has been told about the failure.
    evaluatorCompleted.await();
  } finally {
    // shutdown() does not interrupt the task; it only lets the thread exit once it is idle.
    worker.shutdown();
  }

  assertThat(completionCallback.handledResult, is(nullValue()));
  assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
}
/**
 * Uses an empty input bundle and an evaluator whose {@code finishBundle} throws, and checks that
 * the completion callback receives that same exception and no result.
 *
 * <p>The worker executor is shut down in a {@code finally} block so its thread does not outlive
 * the test.
 */
@Test
@SuppressWarnings("FutureReturnValueIgnored") // expected exception checked via completionCallback
public void finishBundleThrowsExceptionCallsback() throws Exception {
  final Exception exception = new Exception();
  TransformEvaluator<String> evaluator =
      new TransformEvaluator<String>() {
        @Override
        public void processElement(WindowedValue<String> element) throws Exception {}

        @Override
        public TransformResult<String> finishBundle() throws Exception {
          throw exception;
        }
      };
  CommittedBundle<String> inputBundle = bundleFactory.createBundle(created).commit(Instant.now());
  when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
  DirectTransformExecutor<String> executor =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.emptyList(),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);

  ExecutorService worker = Executors.newSingleThreadExecutor();
  try {
    worker.submit(executor);
    // The callback counts the latch down once it has been told about the failure.
    evaluatorCompleted.await();
  } finally {
    // shutdown() does not interrupt the task; it only lets the thread exit once it is idle.
    worker.shutdown();
  }

  assertThat(completionCallback.handledResult, is(nullValue()));
  assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
}
/**
 * Runs an executor with a recording enforcement installed and checks that the enforcement saw
 * every input element before and after processing, and saw the finished bundle's result.
 */
@Test
public void callWithEnforcementAppliesEnforcement() throws Exception {
  final TransformResult<Object> expected =
      StepTransformResult.withoutHold(downstreamProducer).build();
  TransformEvaluator<Object> noOpEvaluator =
      new TransformEvaluator<Object>() {
        @Override
        public void processElement(WindowedValue<Object> element) throws Exception {}

        @Override
        public TransformResult<Object> finishBundle() throws Exception {
          return expected;
        }
      };

  WindowedValue<String> fooElem = WindowedValue.valueInGlobalWindow("foo");
  WindowedValue<String> barElem = WindowedValue.valueInGlobalWindow("bar");
  CommittedBundle<String> inputBundle =
      bundleFactory.createBundle(created).add(fooElem).add(barElem).commit(Instant.now());
  when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(noOpEvaluator);

  TestEnforcementFactory recordingFactory = new TestEnforcementFactory();
  DirectTransformExecutor<String> underTest =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.<ModelEnforcementFactory>singleton(recordingFactory),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);
  underTest.run();

  // The factory keeps a handle on the enforcement it created for the bundle.
  TestEnforcement<?> recorded = recordingFactory.instance;
  assertThat(recorded.beforeElements, containsInAnyOrder(barElem, fooElem));
  assertThat(recorded.afterElements, containsInAnyOrder(barElem, fooElem));
  assertThat(recorded.finishedBundles, Matchers.contains(expected));
}
/**
 * Installs an enforcement that throws from {@code afterFinish} and expects {@link Future#get()}
 * to fail with a {@link RuntimeException} cause whose message mentions {@code afterFinish}.
 *
 * <p>Because {@code task.get()} is expected to throw, the worker executor is shut down in a
 * {@code finally} block so its thread is always released.
 */
@Test
public void callWithEnforcementThrowsOnFinishPropagates() throws Exception {
  final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();
  TransformEvaluator<Object> evaluator =
      new TransformEvaluator<Object>() {
        @Override
        public void processElement(WindowedValue<Object> element) throws Exception {}

        @Override
        public TransformResult<Object> finishBundle() throws Exception {
          return result;
        }
      };
  WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
  CommittedBundle<String> inputBundle =
      bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
  when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
  DirectTransformExecutor<String> executor =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.<ModelEnforcementFactory>singleton(
              new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_BUNDLE)),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);

  ExecutorService worker = Executors.newSingleThreadExecutor();
  try {
    Future<?> task = worker.submit(executor);
    thrown.expectCause(isA(RuntimeException.class));
    thrown.expectMessage("afterFinish");
    task.get();
  } finally {
    // Runs even though task.get() throws; the expected exception still propagates afterwards.
    worker.shutdown();
  }
}
/**
 * Installs an enforcement that throws from {@code afterElement} and expects {@link Future#get()}
 * to fail with a {@link RuntimeException} cause whose message mentions {@code afterElement}.
 *
 * <p>Because {@code task.get()} is expected to throw, the worker executor is shut down in a
 * {@code finally} block so its thread is always released.
 */
@Test
public void callWithEnforcementThrowsOnElementPropagates() throws Exception {
  final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();
  TransformEvaluator<Object> evaluator =
      new TransformEvaluator<Object>() {
        @Override
        public void processElement(WindowedValue<Object> element) throws Exception {}

        @Override
        public TransformResult<Object> finishBundle() throws Exception {
          return result;
        }
      };
  WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
  CommittedBundle<String> inputBundle =
      bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
  when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
  DirectTransformExecutor<String> executor =
      new DirectTransformExecutor<>(
          evaluationContext,
          registry,
          Collections.<ModelEnforcementFactory>singleton(
              new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_ELEMENT)),
          inputBundle,
          downstreamProducer,
          completionCallback,
          transformEvaluationState);

  ExecutorService worker = Executors.newSingleThreadExecutor();
  try {
    Future<?> task = worker.submit(executor);
    thrown.expectCause(isA(RuntimeException.class));
    thrown.expectMessage("afterElement");
    task.get();
  } finally {
    // Runs even though task.get() throws; the expected exception still propagates afterwards.
    worker.shutdown();
  }
}
/**
 * A {@link CompletionCallback} that remembers the result, empty notification, or exception it was
 * handed, and counts down the supplied latch each time one of those is reported.
 */
private static class RegisteringCompletionCallback implements CompletionCallback {
  private final CountDownLatch onMethod;
  private TransformResult<?> handledResult = null;
  private boolean handledEmpty = false;
  private Exception handledException = null;

  private RegisteringCompletionCallback(CountDownLatch onMethod) {
    this.onMethod = onMethod;
  }

  /**
   * Records {@code result}, counts down the latch, and returns a {@link CommittedResult} with no
   * outputs. Any unprocessed elements are re-wrapped in a bundle derived from {@code inputBundle}.
   */
  @Override
  public CommittedResult handleResult(CommittedBundle<?> inputBundle, TransformResult<?> result) {
    // Record first, then release whoever is waiting on the latch.
    handledResult = result;
    onMethod.countDown();

    @SuppressWarnings("rawtypes")
    Iterable leftover = result.getUnprocessedElements();
    if (leftover == null) {
      leftover = Collections.emptyList();
    }

    Optional<? extends CommittedBundle<?>> unprocessedBundle;
    if (inputBundle != null && !Iterables.isEmpty(leftover)) {
      unprocessedBundle = Optional.<CommittedBundle<?>>of(inputBundle.withElements(leftover));
    } else {
      unprocessedBundle = Optional.absent();
    }

    return CommittedResult.create(
        result, unprocessedBundle, Collections.emptyList(), EnumSet.noneOf(OutputType.class));
  }

  /** Notes that the transform produced nothing and counts down the latch. */
  @Override
  public void handleEmpty(AppliedPTransform<?, ?, ?> transform) {
    handledEmpty = true;
    onMethod.countDown();
  }

  /** Records the exception and counts down the latch. */
  @Override
  public void handleException(CommittedBundle<?> inputBundle, Exception e) {
    handledException = e;
    onMethod.countDown();
  }

  /** Rethrows the error unchanged; the latch is not touched. */
  @Override
  public void handleError(Error err) {
    throw err;
  }
}
/**
 * A {@link ModelEnforcementFactory} that creates a recording {@link TestEnforcement} per bundle
 * and keeps a reference to the most recently created one in {@code instance}.
 */
private static class TestEnforcementFactory implements ModelEnforcementFactory {
  private TestEnforcement<?> instance;

  @Override
  public <T> TestEnforcement<T> forBundle(
      CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
    TestEnforcement<T> recorder = new TestEnforcement<>();
    // Expose the enforcement so the test can inspect what it recorded.
    this.instance = recorder;
    return recorder;
  }
}
/**
 * A {@link ModelEnforcement} that never fails; it appends every element and bundle result it is
 * shown to a list, in call order.
 */
private static class TestEnforcement<T> implements ModelEnforcement<T> {
  private final List<WindowedValue<T>> beforeElements = new ArrayList<>();
  private final List<WindowedValue<T>> afterElements = new ArrayList<>();
  private final List<TransformResult<?>> finishedBundles = new ArrayList<>();

  /** Appends {@code element} to {@code beforeElements}. */
  @Override
  public void beforeElement(WindowedValue<T> element) {
    this.beforeElements.add(element);
  }

  /** Appends {@code element} to {@code afterElements}. */
  @Override
  public void afterElement(WindowedValue<T> element) {
    this.afterElements.add(element);
  }

  /** Appends {@code result} to {@code finishedBundles}; the other arguments are ignored. */
  @Override
  public void afterFinish(
      CommittedBundle<T> input,
      TransformResult<T> result,
      Iterable<? extends CommittedBundle<?>> outputs) {
    this.finishedBundles.add(result);
  }
}
/**
 * A {@link ModelEnforcementFactory} that throws a {@link RuntimeException} at exactly one
 * configured stage ({@link When}) and does nothing at every other stage.
 */
private static class ThrowingEnforcementFactory implements ModelEnforcementFactory {
  /** The stage at which the factory, or the enforcement it creates, throws. */
  enum When {
    BEFORE_BUNDLE,
    BEFORE_ELEMENT,
    AFTER_ELEMENT,
    AFTER_BUNDLE
  }

  private final When when;

  private ThrowingEnforcementFactory(When when) {
    this.when = when;
  }

  @Override
  public <T> ModelEnforcement<T> forBundle(
      CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
    failIfConfiguredFor(When.BEFORE_BUNDLE, "forBundle");
    return new ThrowingEnforcement<>();
  }

  /** Throws a {@link RuntimeException} carrying {@code message} when {@code stage} is selected. */
  private void failIfConfiguredFor(When stage, String message) {
    if (when == stage) {
      throw new RuntimeException(message);
    }
  }

  /** Enforcement whose callbacks throw when their stage matches the factory's {@code when}. */
  private class ThrowingEnforcement<T> implements ModelEnforcement<T> {
    @Override
    public void beforeElement(WindowedValue<T> element) {
      failIfConfiguredFor(When.BEFORE_ELEMENT, "beforeElement");
    }

    @Override
    public void afterElement(WindowedValue<T> element) {
      failIfConfiguredFor(When.AFTER_ELEMENT, "afterElement");
    }

    @Override
    public void afterFinish(
        CommittedBundle<T> input,
        TransformResult<T> result,
        Iterable<? extends CommittedBundle<?>> outputs) {
      failIfConfiguredFor(When.AFTER_BUNDLE, "afterFinish");
    }
  }
}
}