// blob: f4322ee8f94757c6c7d30397f13fd1afff50fc76 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.direct;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.beam.sdk.testing.PCollectionViewTesting.materializeValuesFor;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;

import java.util.Collection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
import org.apache.beam.runners.direct.WatermarkManager.FiredTimers;
import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate;
import org.apache.beam.runners.local.StructuralKey;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.joda.time.Instant;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
 * Tests for {@link EvaluationContext}.
 *
 * <p>Every test runs against a fresh context built in {@link #setup()} over a small pipeline: a
 * {@code Create} of three integers, a keyed collection derived from it, an iterable view of it,
 * and an unbounded {@code GenerateSequence}.
 */
@RunWith(JUnit4.class)
public class EvaluationContextTest {
  /** The context under test; rebuilt before every test. */
  private EvaluationContext context;

  /** Executor handed to {@link #context}; retained so it can be shut down after each test. */
  private ExecutorService executor;

  private PCollection<Integer> created;
  private PCollection<KV<String, Integer>> downstream;
  private PCollectionView<Iterable<Integer>> view;
  private PCollection<Long> unbounded;

  private DirectGraph graph;

  private AppliedPTransform<?, ?, ?> createdProducer;
  private AppliedPTransform<?, ?, ?> downstreamProducer;
  private AppliedPTransform<?, ?, ?> viewProducer;
  private AppliedPTransform<?, ?, ?> unboundedProducer;

  @Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);

  /**
   * Builds the test pipeline, applies the direct runner's overrides, and creates the {@link
   * EvaluationContext} together with the producers of each collection used by the tests.
   */
  @Before
  public void setup() {
    DirectRunner runner = DirectRunner.fromOptions(PipelineOptionsFactory.create());

    created = p.apply(Create.of(1, 2, 3));
    downstream = created.apply(WithKeys.of("foo"));
    view = created.apply(View.asIterable());
    unbounded = p.apply(GenerateSequence.from(0));

    p.replaceAll(runner.defaultTransformOverrides());

    KeyedPValueTrackingVisitor keyedPValueTrackingVisitor = KeyedPValueTrackingVisitor.create();
    p.traverseTopologically(keyedPValueTrackingVisitor);

    BundleFactory bundleFactory = ImmutableListBundleFactory.create();
    DirectGraphs.performDirectOverrides(p);
    graph = DirectGraphs.getGraph(p);

    // Keep a handle on the executor so tearDown() can stop its worker thread.
    executor = Executors.newSingleThreadExecutor();
    context =
        EvaluationContext.create(
            NanosOffsetClock.create(),
            bundleFactory,
            graph,
            keyedPValueTrackingVisitor.getKeyedPValues(),
            executor);

    createdProducer = graph.getProducer(created);
    downstreamProducer = graph.getProducer(downstream);
    viewProducer = graph.getProducer(view);
    unboundedProducer = graph.getProducer(unbounded);
  }

  /**
   * Shuts down the executor created in {@link #setup()} so that test runs do not accumulate idle
   * worker threads.
   */
  @After
  public void tearDown() {
    // setup() may have failed before the executor was created.
    if (executor != null) {
      executor.shutdownNow();
    }
  }

  /**
   * Writes values for two windows through a view writer and checks that a side input reader sees
   * them. After a later pane is written for the second window, the existing reader keeps
   * returning the old contents while a newly created reader returns the new ones.
   */
  @Test
  public void writeToViewWriterThenReadReads() {
    PCollectionViewWriter<?, Iterable<Integer>> viewWriter =
        context.createPCollectionViewWriter(
            PCollection.createPrimitiveOutputInternal(
                p,
                WindowingStrategy.globalDefault(),
                IsBounded.BOUNDED,
                IterableCoder.of(KvCoder.of(VoidCoder.of(), VarIntCoder.of()))),
            view);

    BoundedWindow window = new TestBoundedWindow(new Instant(1024L));
    BoundedWindow second = new TestBoundedWindow(new Instant(899999L));

    ImmutableList.Builder<WindowedValue<?>> valuesBuilder = ImmutableList.builder();
    for (Object materializedValue : materializeValuesFor(View.asIterable(), 1)) {
      valuesBuilder.add(
          WindowedValue.of(
              materializedValue, new Instant(1222), window, PaneInfo.ON_TIME_AND_ONLY_FIRING));
    }
    for (Object materializedValue : materializeValuesFor(View.asIterable(), 2)) {
      valuesBuilder.add(
          WindowedValue.of(
              materializedValue,
              new Instant(8766L),
              second,
              PaneInfo.createPane(true, false, Timing.ON_TIME, 0, 0)));
    }
    viewWriter.add((Iterable) valuesBuilder.build());

    SideInputReader reader = context.createSideInputReader(ImmutableList.of(view));
    assertThat(reader.get(view, window), containsInAnyOrder(1));
    assertThat(reader.get(view, second), containsInAnyOrder(2));

    ImmutableList.Builder<WindowedValue<?>> overwrittenValuesBuilder = ImmutableList.builder();
    for (Object materializedValue : materializeValuesFor(View.asIterable(), 4444)) {
      overwrittenValuesBuilder.add(
          WindowedValue.of(
              materializedValue,
              new Instant(8677L),
              second,
              PaneInfo.createPane(false, true, Timing.LATE, 1, 1)));
    }
    viewWriter.add((Iterable) overwrittenValuesBuilder.build());

    // The cached value is served in the earlier reader
    assertThat(reader.get(view, second), containsInAnyOrder(2));

    // A reader created after the second write observes the overwritten contents
    reader = context.createSideInputReader(ImmutableList.of(view));
    assertThat(reader.get(view, second), containsInAnyOrder(4444));
  }

  /**
   * State committed through {@code handleResult} for a step and key is visible from a later
   * execution context obtained for the same transform and an equal key.
   */
  @Test
  public void getExecutionContextSameStepSameKeyState() {
    DirectExecutionContext fooContext =
        context.getExecutionContext(createdProducer, StructuralKey.of("foo", StringUtf8Coder.of()));

    StateTag<BagState<Integer>> intBag = StateTags.bag("myBag", VarIntCoder.of());

    DirectStepContext stepContext = fooContext.getStepContext("s1");
    stepContext.stateInternals().state(StateNamespaces.global(), intBag).add(1);

    context.handleResult(
        ImmutableListBundleFactory.create()
            .createKeyedBundle(StructuralKey.of("foo", StringUtf8Coder.of()), created)
            .commit(Instant.now()),
        ImmutableList.of(),
        StepTransformResult.withoutHold(createdProducer)
            .withState(stepContext.commitState())
            .build());

    DirectExecutionContext secondFooContext =
        context.getExecutionContext(createdProducer, StructuralKey.of("foo", StringUtf8Coder.of()));
    assertThat(
        secondFooContext
            .getStepContext("s1")
            .stateInternals()
            .state(StateNamespaces.global(), intBag)
            .read(),
        contains(1));
  }

  /**
   * Execution contexts obtained for the same transform but different keys are distinct, and state
   * added under one key is not readable under the other.
   */
  @Test
  public void getExecutionContextDifferentKeysIndependentState() {
    DirectExecutionContext fooContext =
        context.getExecutionContext(createdProducer, StructuralKey.of("foo", StringUtf8Coder.of()));

    StateTag<BagState<Integer>> intBag = StateTags.bag("myBag", VarIntCoder.of());

    fooContext.getStepContext("s1").stateInternals().state(StateNamespaces.global(), intBag).add(1);

    DirectExecutionContext barContext =
        context.getExecutionContext(createdProducer, StructuralKey.of("bar", StringUtf8Coder.of()));
    assertThat(barContext, not(equalTo(fooContext)));
    assertThat(
        barContext
            .getStepContext("s1")
            .stateInternals()
            .state(StateNamespaces.global(), intBag)
            .read(),
        emptyIterable());
  }

  /**
   * State added under one transform is not readable from the execution context of a different
   * transform, even when the key and step name are the same.
   */
  @Test
  public void getExecutionContextDifferentStepsIndependentState() {
    StructuralKey<?> myKey = StructuralKey.of("foo", StringUtf8Coder.of());
    DirectExecutionContext fooContext = context.getExecutionContext(createdProducer, myKey);

    StateTag<BagState<Integer>> intBag = StateTags.bag("myBag", VarIntCoder.of());

    fooContext.getStepContext("s1").stateInternals().state(StateNamespaces.global(), intBag).add(1);

    DirectExecutionContext barContext = context.getExecutionContext(downstreamProducer, myKey);
    assertThat(
        barContext
            .getStepContext("s1")
            .stateInternals()
            .state(StateNamespaces.global(), intBag)
            .read(),
        emptyIterable());
  }

  /**
   * State attached to a transform result is stored by {@code handleResult} and can be read back,
   * in insertion order, from a subsequently obtained execution context.
   */
  @Test
  public void handleResultStoresState() {
    StructuralKey<?> myKey = StructuralKey.of("foo".getBytes(UTF_8), ByteArrayCoder.of());
    DirectExecutionContext fooContext = context.getExecutionContext(downstreamProducer, myKey);

    StateTag<BagState<Integer>> intBag = StateTags.bag("myBag", VarIntCoder.of());

    CopyOnAccessInMemoryStateInternals<?> state = fooContext.getStepContext("s1").stateInternals();
    BagState<Integer> bag = state.state(StateNamespaces.global(), intBag);
    bag.add(1);
    bag.add(2);
    bag.add(4);

    TransformResult<?> stateResult =
        StepTransformResult.withoutHold(downstreamProducer).withState(state).build();

    context.handleResult(
        context.createKeyedBundle(myKey, created).commit(Instant.now()),
        ImmutableList.of(),
        stateResult);

    DirectExecutionContext afterResultContext =
        context.getExecutionContext(downstreamProducer, myKey);

    CopyOnAccessInMemoryStateInternals<?> afterResultState =
        afterResultContext.getStepContext("s1").stateInternals();
    assertThat(afterResultState.state(StateNamespaces.global(), intBag).read(), contains(1, 2, 4));
  }

  /**
   * A callback scheduled for after the output of the global window is not invoked while the
   * upstream transform reports a hold, and is invoked once a result without a hold has been
   * handled and the context has been refreshed.
   */
  @Test
  public void callAfterOutputMustHaveBeenProducedAfterEndOfWatermarkCallsback() throws Exception {
    final CountDownLatch callLatch = new CountDownLatch(1);
    Runnable callback = callLatch::countDown;

    // Should call back after the end of the global window
    context.scheduleAfterOutputWouldBeProduced(
        downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback);

    TransformResult<?> result =
        StepTransformResult.withHold(createdProducer, new Instant(0)).build();

    context.handleResult(null, ImmutableList.of(), result);
    // It is difficult to demonstrate that no action was taken in a multithreaded world, so wait
    // briefly for the callback instead. If this logic is broken, the assertion will likely be
    // flaky rather than fail reliably.
    assertThat(callLatch.await(500L, TimeUnit.MILLISECONDS), is(false));

    TransformResult<?> finishedResult = StepTransformResult.withoutHold(createdProducer).build();
    context.handleResult(null, ImmutableList.of(), finishedResult);
    context.forceRefresh();
    // Obtain the value via blocking call
    assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true));
  }

  /**
   * A callback scheduled after a result without a hold has already been handled is invoked within
   * the one-second wait.
   */
  @Test
  public void callAfterOutputMustHaveBeenProducedAlreadyAfterCallsImmediately() throws Exception {
    TransformResult<?> finishedResult = StepTransformResult.withoutHold(createdProducer).build();
    context.handleResult(null, ImmutableList.of(), finishedResult);

    final CountDownLatch callLatch = new CountDownLatch(1);
    // The returned timers are intentionally ignored here.
    context.extractFiredTimers();
    Runnable callback = callLatch::countDown;
    context.scheduleAfterOutputWouldBeProduced(
        downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback);
    assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true));
  }

  /**
   * A timer set through a transform result is not returned by {@code extractFiredTimers} while the
   * upstream hold is in place, is returned exactly once after the hold is released, and is not
   * returned again on a later extraction.
   */
  @Test
  public void extractFiredTimersExtractsTimers() {
    TransformResult<?> holdResult =
        StepTransformResult.withHold(createdProducer, new Instant(0)).build();
    context.handleResult(null, ImmutableList.of(), holdResult);

    StructuralKey<?> key = StructuralKey.of("foo".length(), VarIntCoder.of());
    TimerData toFire =
        TimerData.of(StateNamespaces.global(), new Instant(100L), TimeDomain.EVENT_TIME);
    TransformResult<?> timerResult =
        StepTransformResult.withoutHold(downstreamProducer)
            .withState(CopyOnAccessInMemoryStateInternals.withUnderlying(key, null))
            .withTimerUpdate(TimerUpdate.builder(key).setTimer(toFire).build())
            .build();

    // haven't added any timers, must be empty
    assertThat(context.extractFiredTimers(), emptyIterable());
    context.handleResult(
        context.createKeyedBundle(key, created).commit(Instant.now()),
        ImmutableList.of(),
        timerResult);

    // timer hasn't fired
    assertThat(context.extractFiredTimers(), emptyIterable());

    TransformResult<?> advanceResult = StepTransformResult.withoutHold(createdProducer).build();
    // Should cause the downstream timer to fire
    context.handleResult(null, ImmutableList.of(), advanceResult);

    Collection<FiredTimers<AppliedPTransform<?, ?, ?>>> fired = context.extractFiredTimers();
    // Exactly one group of fired timers is expected; getOnlyElement fails otherwise.
    FiredTimers<AppliedPTransform<?, ?, ?>> firedForKey = Iterables.getOnlyElement(fired);
    assertThat(firedForKey.getKey(), equalTo(key));
    // Contains exclusively the fired timer
    assertThat(firedForKey.getTimers(), contains(toFire));
    // Don't reextract timers
    assertThat(context.extractFiredTimers(), emptyIterable());
  }

  /** A keyed bundle created through the context reports the key it was created with. */
  @Test
  public void createKeyedBundleKeyed() {
    StructuralKey<String> key = StructuralKey.of("foo", StringUtf8Coder.of());
    CommittedBundle<KV<String, Integer>> keyedBundle =
        context.createKeyedBundle(key, downstream).commit(Instant.now());
    assertThat(keyedBundle.getKey(), equalTo(key));
  }

  /**
   * The producer of the unbounded collection is not done initially and is done after a result
   * without a hold has been handled for it.
   */
  @Test
  public void isDoneWithUnboundedPCollection() {
    assertThat(context.isDone(unboundedProducer), is(false));

    context.handleResult(
        null, ImmutableList.of(), StepTransformResult.withoutHold(unboundedProducer).build());
    // The returned timers are intentionally ignored here.
    context.extractFiredTimers();
    assertThat(context.isDone(unboundedProducer), is(true));
  }

  /**
   * The context as a whole is not done while the bundle output by the root transform is still
   * unconsumed, and is done once every per-element consumer of that collection has handled it.
   */
  @Test
  public void isDoneWithPartiallyDone() {
    assertThat(context.isDone(), is(false));

    UncommittedBundle<Integer> rootBundle = context.createBundle(created);
    rootBundle.add(WindowedValue.valueInGlobalWindow(1));
    CommittedResult handleResult =
        context.handleResult(
            null,
            ImmutableList.of(),
            StepTransformResult.<Integer>withoutHold(createdProducer)
                .addOutput(rootBundle)
                .build());
    @SuppressWarnings("unchecked")
    CommittedBundle<Integer> committedBundle =
        (CommittedBundle<Integer>) Iterables.getOnlyElement(handleResult.getOutputs());
    context.handleResult(
        null, ImmutableList.of(), StepTransformResult.withoutHold(unboundedProducer).build());
    assertThat(context.isDone(), is(false));

    for (AppliedPTransform<?, ?, ?> consumer : graph.getPerElementConsumers(created)) {
      context.handleResult(
          committedBundle, ImmutableList.of(), StepTransformResult.withoutHold(consumer).build());
    }
    // The returned timers are intentionally ignored here.
    context.extractFiredTimers();
    assertThat(context.isDone(), is(true));
  }

  /** A {@link BoundedWindow} whose maximum timestamp is the instant supplied at construction. */
  private static class TestBoundedWindow extends BoundedWindow {
    private final Instant ts;

    public TestBoundedWindow(Instant ts) {
      this.ts = ts;
    }

    @Override
    public Instant maxTimestamp() {
      return ts;
    }
  }
}