Adds support for Splittable DoFn (SDF) execution in the ULR (the portable reference runner) and the Java SDK.
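
For reference, here is a condensed sketch of the kind of splittable DoFn this enables the ULR to run end-to-end. It is based on the PairStringWithIndexToLength DoFn added to ReferenceRunnerTest below, minus the periodic resume(), so treat it as illustrative rather than part of the change:

```java
import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;

import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
import org.apache.beam.sdk.values.KV;

// Condensed sketch of the test DoFn added below; unlike the test version it always
// runs its restriction to completion instead of periodically returning resume().
class PairStringWithIndexToLength extends DoFn<String, KV<String, Integer>> {
  @ProcessElement
  public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) {
    // Claim each offset in the restriction and emit (element, offset).
    for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) {
      c.output(KV.of(c.element(), (int) i));
    }
    return stop();
  }

  @GetInitialRestriction
  public OffsetRange getInitialRange(String element) {
    // One offset per character of the input string.
    return new OffsetRange(0, element.length());
  }
}
```

As in the new testSDF case, the pipeline is rewritten with JavaReadViaImpulse.boundedOverride() and the (now public) ParDoMultiOverrideFactory before translation, which routes the SDF through the SPLITTABLE_PROCESS_KEYED / SPLITTABLE_PROCESS_ELEMENTS path.
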
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index b769552..a4017d8 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -216,13 +216,22 @@
google.protobuf.DoubleValue fraction_of_work = 5;
}
+ // An Application that should be scheduled after a delay.
+ message DelayedApplication {
+ // (Required) The delay in seconds.
+ double delay_sec = 1;
+
+ // (Required) The application that should be scheduled.
+ Application application = 2;
+ }
+
// Root applications that should replace the current bundle.
repeated Application primary_roots = 1;
// Root applications that have been removed from the current bundle and
// have to be executed in a separate bundle (e.g. in parallel on a different
// worker, or after the current bundle completes, etc.)
- repeated Application residual_roots = 2;
+ repeated DelayedApplication residual_roots = 2;
}
// A request to process a given bundle.
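
Since residual_roots now carry a relative delay, a runner that receives a split is expected to reschedule each residual after roughly delay_sec seconds. The sketch below illustrates the idea; ResidualHandler and scheduleResume(...) are hypothetical names, and the ULR's actual handling lives in SDFFeederViaStateAndTimers.commit() later in this diff:

```java
import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
import org.joda.time.Duration;
import org.joda.time.Instant;

// Sketch only: act on the residuals returned in a BundleSplit.
abstract class ResidualHandler {
  void handleSplit(BundleSplit split) {
    for (DelayedApplication residual : split.getResidualRootsList()) {
      // delay_sec is a relative delay, so turn it into an absolute wake-up time.
      Instant wakeUpTime =
          Instant.now().plus(Duration.millis((long) (residual.getDelaySec() * 1000)));
      // residual.getApplication() carries the encoded element/restriction to resume from.
      scheduleResume(residual, wakeUpTime);
    }
  }

  // Hypothetical hook: persist the residual and fire a timer at wakeUpTime.
  abstract void scheduleResume(DelayedApplication residual, Instant wakeUpTime);
}
```
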
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index 44e3f42..a58769a 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -263,9 +263,27 @@
}
// Payload for all of these: ParDoPayload containing the user's SDF
enum SplittableParDoComponents {
+ // Pairs the input element with its initial restriction.
+ // Input: element; output: KV(element, restriction).
PAIR_WITH_RESTRICTION = 0 [(beam_urn) = "beam:transform:sdf_pair_with_restriction:v1"];
+
+ // Splits the restriction inside an element/restriction pair.
+ // Input: KV(element, restriction); output: KV(element, restriction).
SPLIT_RESTRICTION = 1 [(beam_urn) = "beam:transform:sdf_split_restriction:v1"];
+
+ // Applies the DoFn to every element/restriction pair in a uniquely keyed
+ // collection, in a splittable fashion.
+ // Input: KV(bytes, KV(element, restriction)); output: DoFn's output.
+ // The first "bytes" is an opaque unique key using the standard bytes coder.
+ // Typically a runner would rewrite this into a runner-specific grouping
+ // operation supporting state and timers, followed by PROCESS_ELEMENTS,
+ // with some runner-specific glue code in between.
PROCESS_KEYED_ELEMENTS = 2 [(beam_urn) = "beam:transform:sdf_process_keyed_elements:v1"];
+
+ // Like PROCESS_KEYED_ELEMENTS, but without the unique key - just elements
+ // and restrictions.
+ // Input: KV(element, restriction); output: DoFn's output.
+ PROCESS_ELEMENTS = 3 [(beam_urn) = "beam:transform:sdf_process_elements:v1"];
}
}
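
To make the new component concrete, the sketch below shows how a runner's translation code might recognize it and pull out the user's DoFn; the SdfUrns class and its methods are illustrative only, while the URN string and the ParDoPayload convention come from the comments above:

```java
import org.apache.beam.model.pipeline.v1.RunnerApi;

// Illustrative only: how a runner's translator might recognize the new component.
class SdfUrns {
  static final String PROCESS_ELEMENTS_URN = "beam:transform:sdf_process_elements:v1";

  static boolean isSplittableProcessElements(RunnerApi.PTransform transform) {
    return PROCESS_ELEMENTS_URN.equals(transform.getSpec().getUrn());
  }

  // Per the enum comment, the payload of every SplittableParDoComponents transform
  // is the user's ParDoPayload.
  static RunnerApi.ParDoPayload parDoPayloadOf(RunnerApi.PTransform transform)
      throws Exception {
    return RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload());
  }
}
```
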
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index 8db73df..8db3156 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -470,14 +470,14 @@
if (doFn instanceof ProcessFn) {
@SuppressWarnings("unchecked")
- StateInternalsFactory<String> stateInternalsFactory =
- (StateInternalsFactory<String>) this.currentKeyStateInternals.getFactory();
+ StateInternalsFactory<byte[]> stateInternalsFactory =
+ (StateInternalsFactory<byte[]>) this.currentKeyStateInternals.getFactory();
@SuppressWarnings({ "rawtypes", "unchecked" })
ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object, Object>>
splittableDoFn = (ProcessFn) doFn;
splittableDoFn.setStateInternalsFactory(stateInternalsFactory);
- TimerInternalsFactory<String> timerInternalsFactory = key -> currentKeyTimerInternals;
+ TimerInternalsFactory<byte[]> timerInternalsFactory = key -> currentKeyTimerInternals;
splittableDoFn.setTimerInternalsFactory(timerInternalsFactory);
splittableDoFn.setProcessElementInvoker(
new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
index f57eb99..00554f7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/Environments.java
@@ -40,6 +40,7 @@
ImmutableMap.<String, EnvironmentIdExtractor>builder()
.put(PTransformTranslation.COMBINE_TRANSFORM_URN, Environments::combineExtractor)
.put(PTransformTranslation.PAR_DO_TRANSFORM_URN, Environments::parDoExtractor)
+ .put(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, Environments::parDoExtractor)
.put(PTransformTranslation.READ_TRANSFORM_URN, Environments::readExtractor)
.put(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, Environments::windowExtractor)
.build();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index b1d6015..d3a6281 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -85,6 +85,8 @@
getUrn(StandardPTransforms.Composites.WRITE_FILES);
public static final String SPLITTABLE_PROCESS_KEYED_URN =
getUrn(SplittableParDoComponents.PROCESS_KEYED_ELEMENTS);
+ public static final String SPLITTABLE_PROCESS_ELEMENTS_URN =
+ getUrn(SplittableParDoComponents.PROCESS_ELEMENTS);
private static final Map<Class<? extends PTransform>, TransformPayloadTranslator>
KNOWN_PAYLOAD_TRANSLATORS = loadTransformPayloadTranslators();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
index a8c6bb6..c00b63d 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
@@ -172,4 +172,8 @@
public Environment getEnvironment(String environmentId) {
return components.getEnvironmentsOrThrow(environmentId);
}
+
+ public Components getComponents() {
+ return components;
+ }
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
index 8bd3960..f8a43e9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
@@ -23,6 +23,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -148,7 +149,7 @@
.invokeGetRestrictionCoder(input.getPipeline().getCoderRegistry());
Coder<KV<InputT, RestrictionT>> splitCoder = KvCoder.of(input.getCoder(), restrictionCoder);
- PCollection<KV<String, KV<InputT, RestrictionT>>> keyedRestrictions =
+ PCollection<KV<byte[], KV<InputT, RestrictionT>>> keyedRestrictions =
input
.apply(
"Pair with initial restriction",
@@ -198,7 +199,7 @@
* {@link KV KVs} keyed with arbitrary but globally unique keys.
*/
public static class ProcessKeyedElements<InputT, OutputT, RestrictionT>
- extends PTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+ extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
private final DoFn<InputT, OutputT> fn;
private final Coder<InputT> elementCoder;
private final Coder<RestrictionT> restrictionCoder;
@@ -269,7 +270,7 @@
}
@Override
- public PCollectionTuple expand(PCollection<KV<String, KV<InputT, RestrictionT>>> input) {
+ public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>> input) {
return createPrimitiveOutputFor(
input, fn, mainOutputTag, additionalOutputTags, outputTagsToCoders, windowingStrategy);
}
@@ -395,10 +396,10 @@
* collection is effectively the same elements as input, but the per-key state and timers are now
* effectively per-element.
*/
- private static class RandomUniqueKeyFn<T> implements SerializableFunction<T, String> {
+ private static class RandomUniqueKeyFn<T> implements SerializableFunction<T, byte[]> {
@Override
- public String apply(T input) {
- return UUID.randomUUID().toString();
+ public byte[] apply(T input) {
+ return UUID.randomUUID().toString().getBytes(StandardCharsets.UTF_8);
}
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
index 4e913f9..25129dc 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/PipelineValidator.java
@@ -248,7 +248,7 @@
}
private static void validateExecutableStage(
- String id, PTransform transform, Components outerComponentsIgnored) throws Exception {
+ String id, PTransform transform, Components outerComponents) throws Exception {
ExecutableStagePayload payload =
ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
index 8c360ef..f560b51 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
@@ -39,13 +39,13 @@
*/
public class ProcessFnRunner<InputT, OutputT, RestrictionT>
implements PushbackSideInputDoFnRunner<
- KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
- private final DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> underlying;
+ KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
+ private final DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> underlying;
private final Collection<PCollectionView<?>> views;
private final ReadyCheckingSideInputReader sideInputReader;
public ProcessFnRunner(
- DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> underlying,
+ DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> underlying,
Collection<PCollectionView<?>> views,
ReadyCheckingSideInputReader sideInputReader) {
this.underlying = underlying;
@@ -54,7 +54,7 @@
}
@Override
- public DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> getFn() {
+ public DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> getFn() {
return underlying.getFn();
}
@@ -64,9 +64,9 @@
}
@Override
- public Iterable<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>>
+ public Iterable<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>>
processElementInReadyWindows(
- WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>> windowedKWI) {
+ WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> windowedKWI) {
checkTrivialOuterWindows(windowedKWI);
BoundedWindow window = getUnderlyingWindow(windowedKWI.getValue());
if (!isReady(window)) {
@@ -88,7 +88,7 @@
}
private static <T> void checkTrivialOuterWindows(
- WindowedValue<KeyedWorkItem<String, T>> windowedKWI) {
+ WindowedValue<KeyedWorkItem<byte[], T>> windowedKWI) {
// In practice it will be in 0 or 1 windows (ValueInEmptyWindows or ValueInGlobalWindow)
Collection<? extends BoundedWindow> outerWindows = windowedKWI.getWindows();
if (!outerWindows.isEmpty()) {
@@ -104,7 +104,7 @@
}
}
- private static <T> BoundedWindow getUnderlyingWindow(KeyedWorkItem<String, T> kwi) {
+ private static <T> BoundedWindow getUnderlyingWindow(KeyedWorkItem<byte[], T> kwi) {
if (Iterables.isEmpty(kwi.elementsIterable())) {
// ProcessFn sets only a single timer.
TimerData timer = Iterables.getOnlyElement(kwi.timersIterable());
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
index 45f8b4b..fe40afd 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
@@ -28,9 +28,9 @@
import org.apache.beam.runners.core.construction.ReplacementOutputs;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.core.construction.SplittableParDo.ProcessKeyedElements;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -100,14 +100,14 @@
/** Overrides a {@link ProcessKeyedElements} into {@link SplittableProcessViaKeyedWorkItems}. */
public static class OverrideFactory<InputT, OutputT, RestrictionT>
implements PTransformOverrideFactory<
- PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+ PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
ProcessKeyedElements<InputT, OutputT, RestrictionT>> {
@Override
public PTransformReplacement<
- PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple>
+ PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple>
getReplacementTransform(
AppliedPTransform<
- PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+ PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
ProcessKeyedElements<InputT, OutputT, RestrictionT>>
transform) {
return PTransformReplacement.of(
@@ -127,7 +127,7 @@
* method for a splittable {@link DoFn}.
*/
public static class SplittableProcessViaKeyedWorkItems<InputT, OutputT, RestrictionT>
- extends PTransform<PCollection<KV<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+ extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
public SplittableProcessViaKeyedWorkItems(
@@ -136,13 +136,13 @@
}
@Override
- public PCollectionTuple expand(PCollection<KV<String, KV<InputT, RestrictionT>>> input) {
+ public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>> input) {
return input
.apply(new GBKIntoKeyedWorkItems<>())
.setCoder(
KeyedWorkItemCoder.of(
- StringUtf8Coder.of(),
- ((KvCoder<String, KV<InputT, RestrictionT>>) input.getCoder()).getValueCoder(),
+ ByteArrayCoder.of(),
+ ((KvCoder<byte[], KV<InputT, RestrictionT>>) input.getCoder()).getValueCoder(),
input.getWindowingStrategy().getWindowFn().windowCoder()))
.apply(new ProcessElements<>(original));
}
@@ -152,7 +152,7 @@
public static class ProcessElements<
InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
extends PTransform<
- PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, PCollectionTuple> {
+ PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
public ProcessElements(ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
@@ -186,7 +186,7 @@
@Override
public PCollectionTuple expand(
- PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>> input) {
+ PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> input) {
return ProcessKeyedElements.createPrimitiveOutputFor(
input,
original.getFn(),
@@ -212,7 +212,7 @@
@VisibleForTesting
public static class ProcessFn<
InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
- extends DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
+ extends DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
/**
* The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is
* acquired during the first {@link DoFn.ProcessElement} call for each element and restriction,
@@ -245,8 +245,8 @@
private final Coder<RestrictionT> restrictionCoder;
private final WindowingStrategy<InputT, ?> inputWindowingStrategy;
- private transient @Nullable StateInternalsFactory<String> stateInternalsFactory;
- private transient @Nullable TimerInternalsFactory<String> timerInternalsFactory;
+ private transient @Nullable StateInternalsFactory<byte[]> stateInternalsFactory;
+ private transient @Nullable TimerInternalsFactory<byte[]> timerInternalsFactory;
private transient @Nullable SplittableProcessElementInvoker<
InputT, OutputT, RestrictionT, TrackerT>
processElementInvoker;
@@ -270,11 +270,11 @@
this.restrictionTag = StateTags.value("restriction", restrictionCoder);
}
- public void setStateInternalsFactory(StateInternalsFactory<String> stateInternalsFactory) {
+ public void setStateInternalsFactory(StateInternalsFactory<byte[]> stateInternalsFactory) {
this.stateInternalsFactory = stateInternalsFactory;
}
- public void setTimerInternalsFactory(TimerInternalsFactory<String> timerInternalsFactory) {
+ public void setTimerInternalsFactory(TimerInternalsFactory<byte[]> timerInternalsFactory) {
this.timerInternalsFactory = timerInternalsFactory;
}
@@ -322,7 +322,7 @@
@ProcessElement
public void processElement(final ProcessContext c) {
- String key = c.element().key();
+ byte[] key = c.element().key();
StateInternals stateInternals = stateInternalsFactory.stateInternalsForKey(key);
TimerInternals timerInternals = timerInternalsFactory.timerInternalsForKey(key);
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
index 617e557..c0b70d4 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
@@ -31,6 +31,7 @@
import static org.junit.Assert.assertTrue;
import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -124,7 +125,7 @@
PositionT,
TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
implements AutoCloseable {
- private final DoFnTester<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> tester;
+ private final DoFnTester<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> tester;
private Instant currentProcessingTime;
private InMemoryTimerInternals timerInternals;
@@ -200,7 +201,8 @@
void startElement(WindowedValue<KV<InputT, RestrictionT>> windowedValue) throws Exception {
tester.processElement(
- KeyedWorkItems.elementsWorkItem("key", Collections.singletonList(windowedValue)));
+ KeyedWorkItems.elementsWorkItem(
+ "key".getBytes(StandardCharsets.UTF_8), Collections.singletonList(windowedValue)));
}
/**
@@ -219,7 +221,8 @@
if (timers.isEmpty()) {
return false;
}
- tester.processElement(KeyedWorkItems.timersWorkItem("key", timers));
+ tester.processElement(
+ KeyedWorkItems.timersWorkItem("key".getBytes(StandardCharsets.UTF_8), timers));
return true;
}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 6fc14a2..00ab794 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -61,7 +61,7 @@
* in the direct runner. Currently overrides applications of <a
* href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a>.
*/
-class ParDoMultiOverrideFactory<InputT, OutputT>
+public class ParDoMultiOverrideFactory<InputT, OutputT>
implements PTransformOverrideFactory<
PCollection<? extends InputT>, PCollectionTuple,
PTransform<PCollection<? extends InputT>, PCollectionTuple>> {
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
index 58adb94..0b7cf33 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
@@ -53,7 +53,7 @@
PositionT,
TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
implements TransformEvaluatorFactory {
- private final ParDoEvaluatorFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>
+ private final ParDoEvaluatorFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT>
delegateFactory;
private final ScheduledExecutorService ses;
private final EvaluationContext evaluationContext;
@@ -107,9 +107,9 @@
}
@SuppressWarnings({"unchecked", "rawtypes"})
- private TransformEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> createEvaluator(
+ private TransformEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> createEvaluator(
AppliedPTransform<
- PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, PCollectionTuple,
+ PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple,
ProcessElements<InputT, OutputT, RestrictionT, TrackerT>>
application,
CommittedBundle<InputT> inputBundle)
@@ -118,15 +118,15 @@
application.getTransform();
final DoFnLifecycleManagerRemovingTransformEvaluator
- <KeyedWorkItem<String, KV<InputT, RestrictionT>>> evaluator =
+ <KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> evaluator =
delegateFactory.createEvaluator(
(AppliedPTransform) application,
- (PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>) inputBundle.getPCollection(),
+ (PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>) inputBundle.getPCollection(),
inputBundle.getKey(),
application.getTransform().getSideInputs(),
application.getTransform().getMainOutputTag(),
application.getTransform().getAdditionalOutputTags().getAll());
- final ParDoEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> pde =
+ final ParDoEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> pde =
evaluator.getParDoEvaluator();
final ProcessFn<InputT, OutputT, RestrictionT, TrackerT> processFn =
(ProcessFn<InputT, OutputT, RestrictionT, TrackerT>)
@@ -176,7 +176,7 @@
}
private static <InputT, OutputT, RestrictionT>
- ParDoEvaluator.DoFnRunnerFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>
+ ParDoEvaluator.DoFnRunnerFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT>
processFnRunnerFactory() {
return (options,
fn,
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
index ab612a6..d115cd5 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/CopyOnAccessInMemoryStateInternals.java
@@ -95,7 +95,7 @@
*
* @return this table
*/
- public CopyOnAccessInMemoryStateInternals commit() {
+ public CopyOnAccessInMemoryStateInternals<K> commit() {
table.commit();
return this;
}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
index 8f9127b..a01f1cd 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ImmutableListBundleFactory.java
@@ -21,6 +21,7 @@
import static com.google.common.base.Preconditions.checkState;
import com.google.auto.value.AutoValue;
+import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import java.util.Iterator;
import javax.annotation.Nonnull;
@@ -117,6 +118,13 @@
return CommittedImmutableListBundle.create(
pcollection, key, committedElements, minSoFar, synchronizedCompletionTime);
}
+
+ @Override
+ public String toString() {
+ return MoreObjects.toStringHelper(this)
+ .add("elements", elements.build())
+ .toString();
+ }
}
@AutoValue
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
index 1e575f2..e6748ee 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/ReferenceRunner.java
@@ -19,17 +19,23 @@
package org.apache.beam.runners.direct.portable;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static org.apache.beam.runners.core.construction.SyntheticComponents.uniqueId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.protobuf.Struct;
import java.io.File;
+import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.model.fnexecution.v1.ProvisionApi.ProvisionInfo;
import org.apache.beam.model.fnexecution.v1.ProvisionApi.Resources;
@@ -41,17 +47,20 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform.Builder;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.runners.core.construction.ModelCoders;
import org.apache.beam.runners.core.construction.ModelCoders.KvCoderComponents;
import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
import org.apache.beam.runners.core.construction.graph.PipelineValidator;
import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
import org.apache.beam.runners.core.construction.graph.ProtoOverrides.TransformReplacement;
+import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.direct.ExecutableGraph;
import org.apache.beam.runners.direct.portable.artifact.LocalFileSystemArtifactRetrievalService;
import org.apache.beam.runners.direct.portable.artifact.UnsupportedArtifactRetrievalService;
@@ -73,6 +82,7 @@
import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter;
import org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService;
import org.apache.beam.runners.fnexecution.state.GrpcStateService;
+import org.apache.beam.runners.fnexecution.wire.LengthPrefixUnknownCoders;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -106,15 +116,18 @@
private RunnerApi.Pipeline executable(RunnerApi.Pipeline original) {
RunnerApi.Pipeline p = original;
-
+ PipelineValidator.validate(p);
p =
ProtoOverrides.updateTransform(
- PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
+ PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
p,
- new PortableGroupByKeyReplacer());
- PipelineValidator.validate(p);
-
+ new SplittableProcessKeyedReplacer());
+ p =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, p, new PortableGroupByKeyReplacer());
p = GreedyPipelineFuser.fuse(p).toPipeline();
+
+ p = foldFeedSDFIntoExecutableStage(p);
PipelineValidator.validate(p);
return p;
@@ -325,6 +338,219 @@
}
}
+ /**
+ * Replaces the {@link PTransformTranslation#SPLITTABLE_PROCESS_KEYED_URN} with a {@link
+ * DirectGroupByKey#DIRECT_GBKO_URN} (construct keyed work items) followed by a {@link
+ * SplittableRemoteStageEvaluatorFactory#FEED_SDF_URN} (convert the keyed work items to
+ * element/restriction pairs that later go into {@link
+ * PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN}).
+ */
+ @VisibleForTesting
+ static class SplittableProcessKeyedReplacer implements TransformReplacement {
+ @Override
+ public MessageWithComponents getReplacement(String spkId, ComponentsOrBuilder components) {
+ PTransform spk = components.getTransformsOrThrow(spkId);
+ checkArgument(
+ PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN.equals(spk.getSpec().getUrn()),
+ "URN must be %s, got %s",
+ PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+ spk.getSpec().getUrn());
+
+ Components.Builder newComponents = Components.newBuilder();
+ newComponents.putAllCoders(components.getCodersMap());
+
+ Builder newPTransform = spk.toBuilder();
+
+ String inputId = getOnlyElement(spk.getInputsMap().values());
+ PCollection input = components.getPcollectionsOrThrow(inputId);
+
+ // This is a Coder<KV<byte[], KV<ElementT, RestrictionT>>>
+ Coder inputCoder = components.getCodersOrThrow(input.getCoderId());
+ KvCoderComponents kvComponents = ModelCoders.getKvCoderComponents(inputCoder);
+ String windowCoderId =
+ components
+ .getWindowingStrategiesOrThrow(input.getWindowingStrategyId())
+ .getWindowCoderId();
+
+ // === Construct a raw GBK returning KeyedWorkItem's ===
+ String kwiCollectionId =
+ uniqueId(String.format("%s.kwi", spkId), components::containsPcollections);
+ {
+ // This coder isn't actually required for the pipeline to function properly - the KWIs can
+ // be passed around as pure Java objects with no coding of the values, but it approximates a
+ // full pipeline.
+ Coder kwiCoder =
+ Coder.newBuilder()
+ .setSpec(
+ SdkFunctionSpec.newBuilder()
+ .setSpec(FunctionSpec.newBuilder().setUrn("beam:direct:keyedworkitem:v1")))
+ .addAllComponentCoderIds(
+ ImmutableList.of(
+ kvComponents.keyCoderId(), kvComponents.valueCoderId(), windowCoderId))
+ .build();
+ String kwiCoderId =
+ uniqueId(
+ String.format(
+ "keyed_work_item(%s:%s)",
+ kvComponents.keyCoderId(), kvComponents.valueCoderId()),
+ components::containsCoders);
+
+ PCollection kwiCollection =
+ input.toBuilder().setUniqueName(kwiCollectionId).setCoderId(kwiCoderId).build();
+ String rawGbkId =
+ uniqueId(String.format("%s/RawGBK", spkId), components::containsTransforms);
+ PTransform rawGbk =
+ PTransform.newBuilder()
+ .setUniqueName(String.format("%s/RawGBK", spk.getUniqueName()))
+ .putAllInputs(spk.getInputsMap())
+ .setSpec(FunctionSpec.newBuilder().setUrn(DirectGroupByKey.DIRECT_GBKO_URN))
+ .putOutputs("output", kwiCollectionId)
+ .build();
+
+ newComponents
+ .putCoders(kwiCoderId, kwiCoder)
+ .putPcollections(kwiCollectionId, kwiCollection)
+ .putTransforms(rawGbkId, rawGbk);
+ newPTransform.addSubtransforms(rawGbkId);
+ }
+
+ // === Construct a "Feed SDF" operation that converts KWI to KV<ElementT, RestrictionT> ===
+ String feedSDFCollectionId =
+ uniqueId(String.format("%s.feed", spkId), components::containsPcollections);
+ {
+ String feedSDFCoderId =
+ uniqueId(String.format("%s/FeedSDF-wire", spkId), components::containsCoders);
+ String elementRestrictionCoderId = kvComponents.valueCoderId();
+ MessageWithComponents feedSDFCoder =
+ LengthPrefixUnknownCoders.forCoder(
+ elementRestrictionCoderId, newComponents.build(), false);
+
+ PCollection feedSDFCollection =
+ input.toBuilder().setUniqueName(feedSDFCollectionId).setCoderId(feedSDFCoderId).build();
+ String feedSDFId =
+ uniqueId(String.format("%s/FeedSDF", spkId), components::containsTransforms);
+ PTransform feedSDF =
+ PTransform.newBuilder()
+ .setUniqueName(String.format("%s/FeedSDF", spk.getUniqueName()))
+ .putInputs("input", kwiCollectionId)
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(SplittableRemoteStageEvaluatorFactory.FEED_SDF_URN))
+ .putOutputs("output", feedSDFCollectionId)
+ .build();
+
+ newComponents
+ .putCoders(feedSDFCoderId, feedSDFCoder.getCoder())
+ .putAllCoders(feedSDFCoder.getComponents().getCodersMap())
+ .putPcollections(feedSDFCollectionId, feedSDFCollection)
+ .putTransforms(feedSDFId, feedSDF);
+ newPTransform.addSubtransforms(feedSDFId);
+ }
+
+ // === Construct the SPLITTABLE_PROCESS_ELEMENTS operation
+ {
+ String runSDFId =
+ uniqueId(String.format("%s/RunSDF", spkId), components::containsTransforms);
+ PTransform runSDF =
+ PTransform.newBuilder()
+ .setUniqueName(String.format("%s/RunSDF", spk.getUniqueName()))
+ .putInputs("input", feedSDFCollectionId)
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN)
+ .setPayload(spk.getSpec().getPayload()))
+ .putAllOutputs(spk.getOutputsMap())
+ .build();
+ newComponents.putTransforms(runSDFId, runSDF);
+ newPTransform.addSubtransforms(runSDFId);
+ }
+
+ return MessageWithComponents.newBuilder()
+ .setPtransform(newPTransform.build())
+ .setComponents(newComponents)
+ .build();
+ }
+ }
+
+ /**
+ * Finds FEED_SDF nodes followed by an ExecutableStage and replaces the pair with a single {@link
+ * SplittableRemoteStageEvaluatorFactory#URN} stage that feeds the ExecutableStage, knowing that
+ * the first instruction in the stage is an SDF.
+ */
+ private static Pipeline foldFeedSDFIntoExecutableStage(Pipeline p) {
+ Pipeline.Builder newPipeline = p.toBuilder();
+ Components.Builder newPipelineComponents = newPipeline.getComponentsBuilder();
+
+ QueryablePipeline q = QueryablePipeline.forPipeline(p);
+ String feedSdfUrn = SplittableRemoteStageEvaluatorFactory.FEED_SDF_URN;
+ List<PTransformNode> feedSDFNodes =
+ q.getTransforms()
+ .stream()
+ .filter(node -> node.getTransform().getSpec().getUrn().equals(feedSdfUrn))
+ .collect(Collectors.toList());
+ Map<String, PTransformNode> stageToFeeder = Maps.newHashMap();
+ for (PTransformNode node : feedSDFNodes) {
+ PCollectionNode output = Iterables.getOnlyElement(q.getOutputPCollections(node));
+ PTransformNode consumer = Iterables.getOnlyElement(q.getPerElementConsumers(output));
+ String consumerUrn = consumer.getTransform().getSpec().getUrn();
+ checkState(
+ consumerUrn.equals(ExecutableStage.URN),
+ "Expected all FeedSDF nodes to be consumed by an ExecutableStage, "
+ + "but %s is consumed by %s which is %s",
+ node.getId(),
+ consumer.getId(),
+ consumerUrn);
+ stageToFeeder.put(consumer.getId(), node);
+ }
+
+ // Copy over root transforms except for the excluded FEED_SDF transforms.
+ Set<String> feedSDFIds =
+ feedSDFNodes.stream().map(PTransformNode::getId).collect(Collectors.toSet());
+ newPipeline.clearRootTransformIds();
+ for (String rootId : p.getRootTransformIdsList()) {
+ if (!feedSDFIds.contains(rootId)) {
+ newPipeline.addRootTransformIds(rootId);
+ }
+ }
+ // Copy over all transforms, except FEED_SDF transforms are skipped, and ExecutableStage's
+ // feeding from them are replaced.
+ for (PTransformNode node : q.getTransforms()) {
+ if (feedSDFNodes.contains(node)) {
+ // These transforms are skipped and handled separately.
+ continue;
+ }
+ if (!stageToFeeder.containsKey(node.getId())) {
+ // This transform is unchanged
+ newPipelineComponents.putTransforms(node.getId(), node.getTransform());
+ continue;
+ }
+ // "node" is an ExecutableStage transform feeding from an SDF.
+ PTransformNode feedSDFNode = stageToFeeder.get(node.getId());
+ PCollectionNode rawGBKOutput =
+ Iterables.getOnlyElement(q.getPerElementInputPCollections(feedSDFNode));
+
+ // Replace the ExecutableStage transform.
+ newPipelineComponents.putTransforms(
+ node.getId(),
+ node.getTransform()
+ .toBuilder()
+ .mergeSpec(
+ // Change URN from ExecutableStage.URN to URN of the ULR's splittable executable
+ // stage evaluator.
+ FunctionSpec.newBuilder()
+ .setUrn(SplittableRemoteStageEvaluatorFactory.URN)
+ .build())
+ .putInputs(
+ // The splittable executable stage now reads from the raw GBK, instead of
+ // from the now non-existent FEED_SDF.
+ Iterables.getOnlyElement(node.getTransform().getInputsMap().keySet()),
+ rawGBKOutput.getId())
+ .build());
+ }
+
+ return newPipeline.build();
+ }
+
private enum EnvironmentType {
DOCKER,
IN_PROCESS
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java
new file mode 100644
index 0000000..fb424c5
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/SplittableRemoteStageEvaluatorFactory.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct.portable;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
+import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
+import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
+import org.apache.beam.runners.fnexecution.control.RemoteBundle;
+import org.apache.beam.runners.fnexecution.splittabledofn.SDFFeederViaStateAndTimers;
+import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
+import org.apache.beam.runners.fnexecution.wire.WireCoders;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+
+/**
+ * The {@link TransformEvaluatorFactory} for {@link #URN}, which reads from a {@link
+ * DirectGroupByKey#DIRECT_GBKO_URN} and feeds the data, using state and timers, to an {@link
+ * ExecutableStage} whose first instruction is an SDF.
+ */
+class SplittableRemoteStageEvaluatorFactory implements TransformEvaluatorFactory {
+ public static final String URN = "urn:beam:directrunner:transforms:splittable_remote_stage:v1";
+
+ // A synthetic transform that converts a KWI<unique key, KV<element, restriction>>
+ // into the plain KV<element, restriction> consumed by the SDF inside the ExecutableStage.
+ public static final String FEED_SDF_URN = "urn:beam:directrunner:transforms:feed_sdf:v1";
+
+ private final BundleFactory bundleFactory;
+ private final JobBundleFactory jobBundleFactory;
+ private final StepStateAndTimers.Provider stp;
+
+ SplittableRemoteStageEvaluatorFactory(
+ BundleFactory bundleFactory,
+ JobBundleFactory jobBundleFactory,
+ StepStateAndTimers.Provider stepStateAndTimers) {
+ this.bundleFactory = bundleFactory;
+ this.jobBundleFactory = jobBundleFactory;
+ this.stp = stepStateAndTimers;
+ }
+
+ @Nullable
+ @Override
+ public <InputT> TransformEvaluator<InputT> forApplication(
+ PTransformNode application, CommittedBundle<?> inputBundle) throws Exception {
+ return new SplittableRemoteStageEvaluator(
+ bundleFactory,
+ jobBundleFactory,
+ stp.forStepAndKey(application, inputBundle.getKey()),
+ application);
+ }
+
+ @Override
+ public void cleanup() throws Exception {
+ jobBundleFactory.close();
+ }
+
+ private static class SplittableRemoteStageEvaluator<InputT, RestrictionT>
+ implements TransformEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> {
+ private final PTransformNode transform;
+ private final ExecutableStage stage;
+
+ private final CopyOnAccessInMemoryStateInternals<byte[]> stateInternals;
+ private final DirectTimerInternals timerInternals;
+ private final RemoteBundle<KV<InputT, RestrictionT>> bundle;
+ private final Collection<UncommittedBundle<?>> outputs;
+
+ private final SDFFeederViaStateAndTimers<InputT, RestrictionT> feeder;
+
+ private SplittableRemoteStageEvaluator(
+ BundleFactory bundleFactory,
+ JobBundleFactory jobBundleFactory,
+ StepStateAndTimers<byte[]> stp,
+ PTransformNode transform)
+ throws Exception {
+ this.stateInternals = stp.stateInternals();
+ this.timerInternals = stp.timerInternals();
+ this.transform = transform;
+ this.stage =
+ ExecutableStage.fromPayload(
+ ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload()));
+ this.outputs = new ArrayList<>();
+ this.bundle =
+ jobBundleFactory
+ .<KV<InputT, RestrictionT>>forStage(stage)
+ .getBundle(
+ BundleFactoryOutputReceiverFactory.create(
+ bundleFactory, stage.getComponents(), outputs::add),
+ StateRequestHandler.unsupported(),
+ new BundleProgressHandler() {
+ @Override
+ public void onProgress(ProcessBundleProgressResponse progress) {
+ if (progress.hasSplit()) {
+ feeder.split(progress.getSplit());
+ }
+ }
+
+ @Override
+ public void onCompleted(ProcessBundleResponse response) {
+ if (response.hasSplit()) {
+ feeder.split(response.getSplit());
+ }
+ }
+ });
+
+ FullWindowedValueCoder<KV<InputT, RestrictionT>> windowedValueCoder =
+ (FullWindowedValueCoder<KV<InputT, RestrictionT>>)
+ WireCoders.<KV<InputT, RestrictionT>>instantiateRunnerWireCoder(
+ stage.getInputPCollection(), stage.getComponents());
+ KvCoder<InputT, RestrictionT> kvCoder =
+ ((KvCoder<InputT, RestrictionT>) windowedValueCoder.getValueCoder());
+
+ this.feeder =
+ new SDFFeederViaStateAndTimers<>(
+ stateInternals,
+ timerInternals,
+ kvCoder.getKeyCoder(),
+ kvCoder.getValueCoder(),
+ (Coder<BoundedWindow>) windowedValueCoder.getWindowCoder());
+ }
+
+ @Override
+ public void processElement(
+ WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> windowedWorkItem)
+ throws Exception {
+ KeyedWorkItem<byte[], KV<InputT, RestrictionT>> kwi = windowedWorkItem.getValue();
+ WindowedValue<KV<InputT, RestrictionT>> elementRestriction =
+ Iterables.getOnlyElement(kwi.elementsIterable(), null);
+ if (elementRestriction != null) {
+ feeder.seed(elementRestriction);
+ } else {
+ elementRestriction = feeder.resume(Iterables.getOnlyElement(kwi.timersIterable()));
+ }
+ bundle.getInputReceiver().accept(elementRestriction);
+ }
+
+ @Override
+ public TransformResult<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> finishBundle()
+ throws Exception {
+ bundle.close();
+ feeder.commit();
+ CopyOnAccessInMemoryStateInternals<byte[]> state = stateInternals.commit();
+ StepTransformResult.Builder<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> result =
+ StepTransformResult.withHold(transform, state.getEarliestWatermarkHold());
+ return result
+ .addOutput(outputs)
+ .withState(state)
+ .withTimerUpdate(timerInternals.getTimerUpdate())
+ .build();
+ }
+ }
+}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
index a27bc07..4a502c7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/portable/TransformEvaluatorRegistry.java
@@ -67,6 +67,10 @@
.put(
ExecutableStage.URN,
new RemoteStageEvaluatorFactory(bundleFactory, jobBundleFactory))
+ .put(
+ SplittableRemoteStageEvaluatorFactory.URN,
+ new SplittableRemoteStageEvaluatorFactory(
+ bundleFactory, jobBundleFactory, stepStateAndTimers))
.build());
}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
index 2599a89..6855fca 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/portable/ReferenceRunnerTest.java
@@ -18,6 +18,8 @@
package org.apache.beam.runners.direct.portable;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
@@ -25,19 +27,28 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.Collections;
import java.util.Set;
import org.apache.beam.runners.core.construction.JavaReadViaImpulse;
+import org.apache.beam.runners.core.construction.PTransformMatchers;
import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.PTransformOverride;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
@@ -65,12 +76,12 @@
ParDo.of(
new DoFn<Integer, KV<String, Integer>>() {
@ProcessElement
- public void process(@Element Integer e,
- MultiOutputReceiver r) {
+ public void process(@Element Integer e, MultiOutputReceiver r) {
for (int i = 0; i < e; i++) {
- r.get(food).outputWithTimestamp(
- KV.of("foo", e),
- new Instant(0).plus(Duration.standardHours(i)));
+ r.get(food)
+ .outputWithTimestamp(
+ KV.of("foo", e),
+ new Instant(0).plus(Duration.standardHours(i)));
}
r.get(originals).output(e);
}
@@ -86,10 +97,10 @@
ParDo.of(
new DoFn<KV<String, Iterable<Integer>>, KV<String, Set<Integer>>>() {
@ProcessElement
- public void process(@Element KV<String, Iterable<Integer>> e,
- OutputReceiver<KV<String, Set<Integer>>> r) {
- r.output(
- KV.of(e.getKey(), ImmutableSet.copyOf(e.getValue())));
+ public void process(
+ @Element KV<String, Iterable<Integer>> e,
+ OutputReceiver<KV<String, Set<Integer>>> r) {
+ r.output(KV.of(e.getKey(), ImmutableSet.copyOf(e.getValue())));
}
}));
@@ -136,4 +147,66 @@
PipelineOptionsTranslation.toProto(PipelineOptionsFactory.create()));
runner.execute();
}
+
+ static class PairStringWithIndexToLength extends DoFn<String, KV<String, Integer>> {
+ @ProcessElement
+ public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) {
+ for (long i = tracker.currentRestriction().getFrom(), numIterations = 0;
+ tracker.tryClaim(i);
+ ++i, ++numIterations) {
+ c.output(KV.of(c.element(), (int) i));
+ if (numIterations % 3 == 0) {
+ return resume();
+ }
+ }
+ return stop();
+ }
+
+ @GetInitialRestriction
+ public OffsetRange getInitialRange(String element) {
+ return new OffsetRange(0, element.length());
+ }
+
+ @SplitRestriction
+ public void splitRange(
+ String element, OffsetRange range, OutputReceiver<OffsetRange> receiver) {
+ long middle = (range.getFrom() + range.getTo()) / 2;
+ receiver.output(new OffsetRange(range.getFrom(), middle));
+ receiver.output(new OffsetRange(middle, range.getTo()));
+ }
+ }
+
+ @Test
+ public void testSDF() throws Exception {
+ Pipeline p = Pipeline.create();
+
+ PCollection<KV<String, Integer>> res =
+ p.apply(Create.of("a", "bb", "ccccc"))
+ .apply(ParDo.of(new PairStringWithIndexToLength()))
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()));
+
+ PAssert.that(res)
+ .containsInAnyOrder(
+ Arrays.asList(
+ KV.of("a", 0),
+ KV.of("bb", 0),
+ KV.of("bb", 1),
+ KV.of("ccccc", 0),
+ KV.of("ccccc", 1),
+ KV.of("ccccc", 2),
+ KV.of("ccccc", 3),
+ KV.of("ccccc", 4)));
+
+ p.replaceAll(
+ Arrays.asList(
+ JavaReadViaImpulse.boundedOverride(),
+ PTransformOverride.of(
+ PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())));
+
+ ReferenceRunner runner =
+ ReferenceRunner.forInProcessPipeline(
+ PipelineTranslation.toProto(p),
+ PipelineOptionsTranslation.toProto(PipelineOptionsFactory.create()));
+ runner.execute();
+ }
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
index df008a8..be9abc3 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
@@ -60,14 +60,14 @@
*/
public class SplittableDoFnOperator<
InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
- extends DoFnOperator<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> {
+ extends DoFnOperator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
private transient ScheduledExecutorService executorService;
public SplittableDoFnOperator(
- DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> doFn,
+ DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> doFn,
String stepName,
- Coder<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>> inputCoder,
+ Coder<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>> inputCoder,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
OutputManagerFactory<OutputT> outputManagerFactory,
@@ -76,7 +76,7 @@
Collection<PCollectionView<?>> sideInputs,
PipelineOptions options,
Coder<?> keyCoder,
- KeySelector<WindowedValue<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, ?> keySelector) {
+ KeySelector<WindowedValue<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, ?> keySelector) {
super(
doFn,
stepName,
@@ -94,8 +94,8 @@
@Override
protected DoFnRunner<
- KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> createWrappingDoFnRunner(
- DoFnRunner<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> wrappedRunner) {
+ KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> createWrappingDoFnRunner(
+ DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> wrappedRunner) {
// don't wrap in anything because we don't need state cleanup because ProcessFn does
// all that
return wrappedRunner;
@@ -162,7 +162,7 @@
doFnRunner.processElement(
WindowedValue.valueInGlobalWindow(
KeyedWorkItems.timersWorkItem(
- (String) keyedStateInternals.getKey(),
+ (byte[]) keyedStateInternals.getKey(),
Collections.singletonList(timer.getNamespace()))));
}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java
new file mode 100644
index 0000000..9af2605
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.splittabledofn;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternals.TimerData;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Helper class for feeding element/restriction pairs into a {@link
+ * PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform, implementing checkpointing
+ * only, by using state and timers for storing the last element/restriction pair, similarly to
+ * {@link org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessFn} but in a portable
+ * fashion.
+ */
+public class SDFFeederViaStateAndTimers<InputT, RestrictionT> {
+ private final Coder<BoundedWindow> windowCoder;
+ private final Coder<WindowedValue<KV<InputT, RestrictionT>>> elementRestrictionWireCoder;
+
+ private final StateInternals stateInternals;
+ private final TimerInternals timerInternals;
+
+ private StateNamespace stateNamespace;
+
+ private final StateTag<ValueState<WindowedValue<KV<InputT, RestrictionT>>>> seedTag;
+ private ValueState<WindowedValue<KV<InputT, RestrictionT>>> seedState;
+
+ private final StateTag<ValueState<RestrictionT>> restrictionTag;
+ private ValueState<RestrictionT> restrictionState;
+
+ private StateTag<WatermarkHoldState> watermarkHoldTag =
+ StateTags.makeSystemTagInternal(
+ StateTags.<GlobalWindow>watermarkStateInternal("hold", TimestampCombiner.LATEST));
+ private WatermarkHoldState holdState;
+
+ private Instant inputTimestamp;
+ private BundleSplit split;
+
+ /** Initializes the feeder. */
+ public SDFFeederViaStateAndTimers(
+ StateInternals stateInternals,
+ TimerInternals timerInternals,
+ Coder<InputT> elementWireCoder,
+ Coder<RestrictionT> restrictionWireCoder,
+ Coder<BoundedWindow> windowCoder) {
+ this.stateInternals = stateInternals;
+ this.timerInternals = timerInternals;
+ this.windowCoder = windowCoder;
+ this.elementRestrictionWireCoder =
+ FullWindowedValueCoder.of(KvCoder.of(elementWireCoder, restrictionWireCoder), windowCoder);
+ this.seedTag = StateTags.value("seed", elementRestrictionWireCoder);
+ this.restrictionTag = StateTags.value("restriction", restrictionWireCoder);
+ }
+
+ /** Seeds the feeder with the initial element/restriction pair. */
+ public void seed(WindowedValue<KV<InputT, RestrictionT>> elementRestriction) {
+ initState(
+ StateNamespaces.window(
+ windowCoder, Iterables.getOnlyElement(elementRestriction.getWindows())));
+ seedState.write(elementRestriction);
+ inputTimestamp = elementRestriction.getTimestamp();
+ }
+
+ /**
+ * Resumes from a timer and returns the current element/restriction pair (with an up-to-date value
+ * of the restriction).
+ */
+ public WindowedValue<KV<InputT, RestrictionT>> resume(TimerData timer) {
+ initState(timer.getNamespace());
+ WindowedValue<KV<InputT, RestrictionT>> seed = seedState.read();
+ inputTimestamp = seed.getTimestamp();
+ return seed.withValue(KV.of(seed.getValue().getKey(), restrictionState.read()));
+ }
+
+ /**
+ * Commits the state and timers: clears both if no checkpoint happened, or adjusts the restriction
+ * and sets a wake-up timer if a checkpoint happened.
+ */
+ public void commit() throws IOException {
+ if (split == null) {
+ // No split - the call terminated.
+ seedState.clear();
+ restrictionState.clear();
+ holdState.clear();
+ return;
+ }
+
+ // For now this can only happen on the first instruction, which is SPLITTABLE_PROCESS_ELEMENTS.
+ List<DelayedApplication> residuals = split.getResidualRootsList();
+ checkArgument(residuals.size() == 1, "More than 1 residual is unsupported for now");
+ DelayedApplication residual = residuals.get(0);
+
+ ByteString encodedResidual = residual.getApplication().getElement();
+ WindowedValue<KV<InputT, RestrictionT>> decodedResidual =
+ CoderUtils.decodeFromByteArray(elementRestrictionWireCoder, encodedResidual.toByteArray());
+
+ restrictionState.write(decodedResidual.getValue().getValue());
+
+ Instant watermarkHold =
+ residual.getApplication().getOutputWatermarksMap().isEmpty()
+ ? inputTimestamp
+ : new Instant(
+ Iterables.getOnlyElement(
+ residual.getApplication().getOutputWatermarksMap().values()));
+ checkArgument(
+ !watermarkHold.isBefore(inputTimestamp),
+ "Watermark hold %s can not be before input timestamp %s",
+ watermarkHold,
+ inputTimestamp);
+ holdState.add(watermarkHold);
+
+ Duration resumeDelay = new Duration((long) (1000L * residual.getDelaySec()));
+ Instant wakeupTime = timerInternals.currentProcessingTime().plus(resumeDelay);
+
+ // Set a timer to continue processing this element.
+ timerInternals.setTimer(
+ stateNamespace, "sdfContinuation", wakeupTime, TimeDomain.PROCESSING_TIME);
+ }
+
+ /** Signals that a split happened. */
+ public void split(BundleSplit split) {
+ checkState(
+ this.split == null,
+ "At most 1 split supported, however got new split %s in addition to existing %s",
+ split,
+ this.split);
+ this.split = split;
+ }
+
+ private void initState(StateNamespace ns) {
+ stateNamespace = ns;
+ seedState = stateInternals.state(ns, seedTag);
+ restrictionState = stateInternals.state(ns, restrictionTag);
+ holdState = stateInternals.state(ns, watermarkHoldTag);
+ }
+}
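
As a hypothetical illustration (not part of this change), a runner could drive the feeder roughly as sketched below. The FnDataReceiver stands in for the runner's Fn API data plane, and the BundleSplit argument is whatever split, if any, the SDK harness reported back for this element; all names here are made up for the sketch.

    // Hedged sketch only; assumes the runner supplies the data-plane receiver and the harness's
    // reported BundleSplit (null if the ProcessElement call ran to completion).
    static <InputT, RestrictionT> void feedOnce(
        SDFFeederViaStateAndTimers<InputT, RestrictionT> feeder,
        FnDataReceiver<WindowedValue<KV<InputT, RestrictionT>>> sendToHarness,
        WindowedValue<KV<InputT, RestrictionT>> elementRestriction, // used only on the first call
        TimerInternals.TimerData firedTimer,                        // null on the first call
        BundleSplit maybeSplit)
        throws Exception {
      WindowedValue<KV<InputT, RestrictionT>> toProcess;
      if (firedTimer == null) {
        // First encounter: remember the pair and its timestamp in state.
        feeder.seed(elementRestriction);
        toProcess = elementRestriction;
      } else {
        // Continuation timer fired: re-read the element with the stored residual restriction.
        toProcess = feeder.resume(firedTimer);
      }
      sendToHarness.accept(toProcess);
      if (maybeSplit != null) {
        feeder.split(maybeSplit);
      }
      // Clears state if the call terminated, or writes the residual restriction, holds the
      // watermark, and sets a processing-time wake-up timer if a checkpoint happened.
      feeder.commit();
    }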
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java
new file mode 100644
index 0000000..e50f14d
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Utilities for a Beam runner to interact with a remotely running splittable DoFn. */
+package org.apache.beam.runners.fnexecution.splittabledofn;
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index db25581..b2f8195 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -28,6 +28,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -93,7 +94,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference(pTransformId)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
index 34d704d..a4f6e54 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -84,7 +85,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference(pTransformId)
.setName(getOnlyElement(pTransform.getInputsMap().keySet()))
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
index 5df178f..185fc11 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
@@ -28,6 +28,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.control.ProcessBundleHandler;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -81,7 +82,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) {
ImmutableList.Builder<FnDataReceiver<WindowedValue<?>>> consumers = ImmutableList.builder();
for (String pCollectionId : pTransform.getOutputsMap().values()) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
new file mode 100644
index 0000000..9cbe34c
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.SideInputSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/** A {@link PTransformRunnerFactory} for transforms invoking a {@link DoFn}. */
+abstract class DoFnPTransformRunnerFactory<
+ TransformInputT,
+ FnInputT,
+ OutputT,
+ RunnerT extends DoFnPTransformRunnerFactory.DoFnPTransformRunner<TransformInputT>>
+ implements PTransformRunnerFactory<RunnerT> {
+ interface DoFnPTransformRunner<T> {
+ void startBundle() throws Exception;
+
+ void processElement(WindowedValue<T> input) throws Exception;
+
+ void finishBundle() throws Exception;
+ }
+
+ @Override
+ public final RunnerT createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ BeamFnStateClient beamFnStateClient,
+ String ptransformId,
+ PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) {
+ Context<FnInputT, OutputT> context =
+ new Context<>(
+ pipelineOptions,
+ beamFnStateClient,
+ ptransformId,
+ pTransform,
+ processBundleInstructionId,
+ pCollections,
+ coders,
+ windowingStrategies,
+ pCollectionIdsToConsumers,
+ splitListener);
+
+ RunnerT runner = createRunner(context);
+
+ // Register the appropriate handlers.
+ addStartFunction.accept(runner::startBundle);
+ Iterable<String> mainInput =
+ Sets.difference(
+ pTransform.getInputsMap().keySet(), context.parDoPayload.getSideInputsMap().keySet());
+ for (String localInputName : mainInput) {
+ pCollectionIdsToConsumers.put(
+ pTransform.getInputsOrThrow(localInputName),
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<TransformInputT>>) runner::processElement);
+ }
+ addFinishFunction.accept(runner::finishBundle);
+ return runner;
+ }
+
+ abstract RunnerT createRunner(Context<FnInputT, OutputT> context);
+
+ static class Context<InputT, OutputT> {
+ final PipelineOptions pipelineOptions;
+ final BeamFnStateClient beamFnStateClient;
+ final String ptransformId;
+ final PTransform pTransform;
+ final Supplier<String> processBundleInstructionId;
+ final RehydratedComponents rehydratedComponents;
+ final DoFn<InputT, OutputT> doFn;
+ final TupleTag<OutputT> mainOutputTag;
+ final Coder<?> inputCoder;
+ final Coder<?> keyCoder;
+ final Coder<? extends BoundedWindow> windowCoder;
+ final WindowingStrategy<InputT, ?> windowingStrategy;
+ final Map<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap;
+ final ParDoPayload parDoPayload;
+ final ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer;
+ final BundleSplitListener splitListener;
+
+ Context(
+ PipelineOptions pipelineOptions,
+ BeamFnStateClient beamFnStateClient,
+ String ptransformId,
+ PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+ BundleSplitListener splitListener) {
+ this.pipelineOptions = pipelineOptions;
+ this.beamFnStateClient = beamFnStateClient;
+ this.ptransformId = ptransformId;
+ this.pTransform = pTransform;
+ this.processBundleInstructionId = processBundleInstructionId;
+ ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMapBuilder =
+ ImmutableMap.builder();
+ try {
+ rehydratedComponents =
+ RehydratedComponents.forComponents(
+ RunnerApi.Components.newBuilder()
+ .putAllCoders(coders)
+ .putAllWindowingStrategies(windowingStrategies)
+ .build());
+ parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
+ doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
+ mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
+ String mainInputTag =
+ Iterables.getOnlyElement(
+ Sets.difference(
+ pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
+ PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
+ inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
+ if (inputCoder instanceof KvCoder
+ // TODO: Stop passing windowed value coders within PCollections.
+ || (inputCoder instanceof WindowedValue.WindowedValueCoder
+ && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder))) {
+ this.keyCoder =
+ inputCoder instanceof WindowedValueCoder
+ ? ((KvCoder) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
+ : ((KvCoder) inputCoder).getKeyCoder();
+ } else {
+ this.keyCoder = null;
+ }
+
+ windowingStrategy =
+ (WindowingStrategy)
+ rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+ windowCoder = windowingStrategy.getWindowFn().windowCoder();
+
+ // Build the map from tag id to side input specification
+ for (Map.Entry<String, RunnerApi.SideInput> entry :
+ parDoPayload.getSideInputsMap().entrySet()) {
+ String sideInputTag = entry.getKey();
+ RunnerApi.SideInput sideInput = entry.getValue();
+ checkArgument(
+ Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+ sideInput.getAccessPattern().getUrn()),
+ "This SDK is only capable of dealing with %s materializations "
+ + "but was asked to handle %s for PCollectionView with tag %s.",
+ Materializations.MULTIMAP_MATERIALIZATION_URN,
+ sideInput.getAccessPattern().getUrn(),
+ sideInputTag);
+
+ PCollection sideInputPCollection =
+ pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
+ WindowingStrategy sideInputWindowingStrategy =
+ rehydratedComponents.getWindowingStrategy(
+ sideInputPCollection.getWindowingStrategyId());
+ tagToSideInputSpecMapBuilder.put(
+ new TupleTag<>(entry.getKey()),
+ SideInputSpec.create(
+ rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+ sideInputWindowingStrategy.getWindowFn().windowCoder(),
+ PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
+ PCollectionViewTranslation.windowMappingFnFromProto(
+ entry.getValue().getWindowMappingFn())));
+ }
+ } catch (IOException exn) {
+ throw new IllegalArgumentException("Malformed ParDoPayload", exn);
+ }
+
+ ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
+ tagToConsumerBuilder = ImmutableListMultimap.builder();
+ for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
+ tagToConsumerBuilder.putAll(
+ new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
+ }
+ tagToConsumer = tagToConsumerBuilder.build();
+ tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build();
+ this.splitListener = splitListener;
+ }
+ }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
index 52d9435..6ac860e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -69,7 +70,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
// Give each input a MultiplexingFnDataReceiver to all outputs of the flatten.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 9778aac..4cfb5d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,75 +17,28 @@
*/
package org.apache.beam.fn.harness;
-import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Preconditions.checkState;
import com.google.auto.service.AutoService;
-import com.google.auto.value.AutoValue;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.ListMultimap;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Sets;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collection;
-import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.state.BagUserState;
-import org.apache.beam.fn.harness.state.BeamFnStateClient;
-import org.apache.beam.fn.harness.state.MultimapSideInput;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
-import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.fn.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
-import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateBinder;
-import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFn.StartBundleContext;
import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
-import org.apache.beam.sdk.transforms.Materializations;
-import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -94,399 +47,145 @@
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
-import org.apache.beam.sdk.util.CombineFnUtil;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Instant;
/**
- * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers
- * of abstraction caused by StateInternals/TimerInternals since they model state and timer
- * concepts differently.
+ * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers of
+ * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
+ * differently.
*/
-public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
-
- private final ProcessBundleContext processContext;
- private final FinishBundleContext finishBundleContext;
- private StartBundleContext startBundleContext;
-
- /**
- * A registrar which provides a factory to handle Java {@link DoFn}s.
- */
+public class FnApiDoFnRunner<InputT, OutputT>
+ implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<InputT> {
+ /** A registrar which provides a factory to handle Java {@link DoFn}s. */
@AutoService(PTransformRunnerFactory.Registrar.class)
- public static class Registrar implements
- PTransformRunnerFactory.Registrar {
-
+ public static class Registrar implements PTransformRunnerFactory.Registrar {
@Override
public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
- return ImmutableMap.of(
- PTransformTranslation.PAR_DO_TRANSFORM_URN, new NewFactory(),
- ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory());
+ return ImmutableMap.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, new Factory());
}
}
- /** A factory for {@link FnApiDoFnRunner}. */
static class Factory<InputT, OutputT>
- implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
+ extends DoFnPTransformRunnerFactory<
+ InputT, InputT, OutputT, FnApiDoFnRunner<InputT, OutputT>> {
@Override
- public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
-
- // For every output PCollection, create a map from output name to Consumer
- ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
- tagToOutputMapBuilder = ImmutableListMultimap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- tagToOutputMapBuilder.putAll(
- new TupleTag<>(entry.getKey()),
- pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap =
- tagToOutputMapBuilder.build();
-
- // Get the DoFnInfo from the serialized blob.
- ByteString serializedFn = pTransform.getSpec().getPayload();
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray(
- serializedFn.toByteArray(), "DoFnInfo");
-
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnRunner<InputT, OutputT> runner =
- new FnApiDoFnRunner<>(
- pipelineOptions,
- beamFnStateClient,
- ptransformId,
- processBundleInstructionId,
- doFnInfo.getDoFn(),
- doFnInfo.getInputCoder(),
- (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
- (Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
- tagToOutputMap,
- ImmutableMap.of(),
- doFnInfo.getWindowingStrategy());
-
- registerHandlers(
- runner,
- pTransform,
- ImmutableSet.of(),
- addStartFunction,
- addFinishFunction,
- pCollectionIdsToConsumers);
- return runner;
+ public FnApiDoFnRunner<InputT, OutputT> createRunner(Context<InputT, OutputT> context) {
+ return new FnApiDoFnRunner<>(context);
}
}
- static class NewFactory<InputT, OutputT>
- implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
- @Override
- public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
-
- DoFn<InputT, OutputT> doFn;
- TupleTag<OutputT> mainOutputTag;
- Coder<InputT> inputCoder;
- WindowingStrategy<InputT, ?> windowingStrategy;
-
- ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
- ImmutableMap.builder();
- ParDoPayload parDoPayload;
- try {
- RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(
- RunnerApi.Components.newBuilder()
- .putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
- parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
- doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
- mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
- String mainInputTag = Iterables.getOnlyElement(Sets.difference(
- pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
- RunnerApi.PCollection mainInput =
- pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
- inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
- mainInput.getCoderId());
- windowingStrategy = (WindowingStrategy) rehydratedComponents.getWindowingStrategy(
- mainInput.getWindowingStrategyId());
-
- // Build the map from tag id to side input specification
- for (Map.Entry<String, RunnerApi.SideInput> entry
- : parDoPayload.getSideInputsMap().entrySet()) {
- String sideInputTag = entry.getKey();
- RunnerApi.SideInput sideInput = entry.getValue();
- checkArgument(
- Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
- sideInput.getAccessPattern().getUrn()),
- "This SDK is only capable of dealing with %s materializations "
- + "but was asked to handle %s for PCollectionView with tag %s.",
- Materializations.MULTIMAP_MATERIALIZATION_URN,
- sideInput.getAccessPattern().getUrn(),
- sideInputTag);
-
- RunnerApi.PCollection sideInputPCollection =
- pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
- WindowingStrategy sideInputWindowingStrategy =
- rehydratedComponents.getWindowingStrategy(
- sideInputPCollection.getWindowingStrategyId());
- tagToSideInputSpecMap.put(
- new TupleTag<>(entry.getKey()),
- SideInputSpec.create(
- rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
- sideInputWindowingStrategy.getWindowFn().windowCoder(),
- PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
- PCollectionViewTranslation.windowMappingFnFromProto(
- entry.getValue().getWindowMappingFn())));
- }
- } catch (InvalidProtocolBufferException exn) {
- throw new IllegalArgumentException("Malformed ParDoPayload", exn);
- } catch (IOException exn) {
- throw new IllegalArgumentException("Malformed ParDoPayload", exn);
- }
-
- ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
- tagToConsumerBuilder = ImmutableListMultimap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- tagToConsumerBuilder.putAll(
- new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer =
- tagToConsumerBuilder.build();
-
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
- pipelineOptions,
- beamFnStateClient,
- ptransformId,
- processBundleInstructionId,
- doFn,
- inputCoder,
- (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
- tagToConsumer.get(mainOutputTag),
- tagToConsumer,
- tagToSideInputSpecMap.build(),
- windowingStrategy);
- registerHandlers(
- runner,
- pTransform,
- parDoPayload.getSideInputsMap().keySet(),
- addStartFunction,
- addFinishFunction,
- pCollectionIdsToConsumers);
- return runner;
- }
- }
-
- private static <InputT, OutputT> void registerHandlers(
- DoFnRunner<InputT, OutputT> runner,
- RunnerApi.PTransform pTransform,
- Set<String> sideInputLocalNames,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers) {
- // Register the appropriate handlers.
- addStartFunction.accept(runner::startBundle);
- for (String localInputName
- : Sets.difference(pTransform.getInputsMap().keySet(), sideInputLocalNames)) {
- pCollectionIdsToConsumers.put(
- pTransform.getInputsOrThrow(localInputName),
- (FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
- }
- addFinishFunction.accept(runner::finishBundle);
- }
-
//////////////////////////////////////////////////////////////////////////////////////////////////
- private final PipelineOptions pipelineOptions;
- private final BeamFnStateClient beamFnStateClient;
- private final String ptransformId;
- private final Supplier<String> processBundleInstructionId;
- private final DoFn<InputT, OutputT> doFn;
- private final Coder<InputT> inputCoder;
+ private final Context<InputT, OutputT> context;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
- private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
- private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
- private final Map<StateKey, Object> stateKeyObjectCache;
- private final WindowingStrategy windowingStrategy;
+ private FnApiStateAccessor stateAccessor;
private final DoFnSignature doFnSignature;
private final DoFnInvoker<InputT, OutputT> doFnInvoker;
- private final StateBinder stateBinder;
- private final Collection<ThrowingRunnable> stateFinalizers;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
+ private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+ private final ProcessBundleContext processContext;
+ private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+ /** Only valid during {@link #processElement}, null otherwise. */
private WindowedValue<InputT> currentElement;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
private BoundedWindow currentWindow;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and only when processing a {@link KV} and is null otherwise.
- */
- private ByteString encodedCurrentKey;
-
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
- private ByteString encodedCurrentWindow;
-
- FnApiDoFnRunner(
- PipelineOptions pipelineOptions,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- Supplier<String> processBundleInstructionId,
- DoFn<InputT, OutputT> doFn,
- Coder<InputT> inputCoder,
- Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
- Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
- Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
- WindowingStrategy windowingStrategy) {
- this.pipelineOptions = pipelineOptions;
- this.beamFnStateClient = beamFnStateClient;
- this.ptransformId = ptransformId;
- this.processBundleInstructionId = processBundleInstructionId;
- this.doFn = doFn;
- this.inputCoder = inputCoder;
- this.mainOutputConsumers = mainOutputConsumers;
- this.outputMap = outputMap;
- this.sideInputSpecMap = sideInputSpecMap;
- this.stateKeyObjectCache = new HashMap<>();
- this.windowingStrategy = windowingStrategy;
- this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
- this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
+ FnApiDoFnRunner(Context<InputT, OutputT> context) {
+ this.context = context;
+ this.mainOutputConsumers =
+ (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+ (Collection) context.tagToConsumer.get(context.mainOutputTag);
+ this.doFnSignature = DoFnSignatures.signatureForDoFn(context.doFn);
+ this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
this.doFnInvoker.invokeSetup();
- this.stateBinder = new BeamFnStateBinder();
- this.stateFinalizers = new ArrayList<>();
- this.startBundleContext = doFn.new StartBundleContext() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
- };
+ this.startBundleContext =
+ this.context.doFn.new StartBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
+ };
this.processContext = new ProcessBundleContext();
- finishBundleContext = doFn.new FinishBundleContext() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
+ finishBundleContext =
+ this.context.doFn.new FinishBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
- @Override
- public void output(OutputT output, Instant timestamp, BoundedWindow window) {
- outputTo(mainOutputConsumers,
- WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
- }
+ @Override
+ public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+ outputTo(
+ mainOutputConsumers,
+ WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+ }
- @Override
- public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
- if (consumers == null) {
- throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
- }
- outputTo(consumers,
- WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
- }
- };
+ @Override
+ public <T> void output(
+ TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
+ if (consumers == null) {
+ throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
+ }
+ outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+ }
+ };
}
@Override
public void startBundle() {
+ this.stateAccessor =
+ new FnApiStateAccessor(
+ context.pipelineOptions,
+ context.ptransformId,
+ context.processBundleInstructionId,
+ context.tagToSideInputSpecMap,
+ context.beamFnStateClient,
+ context.keyCoder,
+ (Coder<BoundedWindow>) context.windowCoder);
+
doFnInvoker.invokeStartBundle(startBundleContext);
}
@Override
public void processElement(WindowedValue<InputT> elem) {
currentElement = elem;
+ stateAccessor.setCurrentElement(elem);
try {
Iterator<BoundedWindow> windowIterator =
(Iterator<BoundedWindow>) elem.getWindows().iterator();
while (windowIterator.hasNext()) {
currentWindow = windowIterator.next();
+ stateAccessor.setCurrentWindow(currentWindow);
doFnInvoker.invokeProcessElement(processContext);
}
} finally {
currentElement = null;
currentWindow = null;
- encodedCurrentKey = null;
- encodedCurrentWindow = null;
+ stateAccessor.setCurrentElement(null);
+ stateAccessor.setCurrentWindow(null);
}
}
@Override
- public void onTimer(
- String timerId,
- BoundedWindow window,
- Instant timestamp,
- TimeDomain timeDomain) {
- throw new UnsupportedOperationException("TODO: Add support for timers");
- }
-
- @Override
public void finishBundle() {
doFnInvoker.invokeFinishBundle(finishBundleContext);
- // Persist all dirty state cells
- try {
- for (ThrowingRunnable runnable : stateFinalizers) {
- runnable.run();
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new IllegalStateException(e);
- } catch (Exception e) {
- throw new IllegalStateException(e);
- }
-
// TODO: Support caching state data across bundle boundaries.
- stateKeyObjectCache.clear();
+ this.stateAccessor.finalizeState();
+ this.stateAccessor = null;
}
- @Override
- public DoFn<InputT, OutputT> getFn() {
- return doFnInvoker.getFn();
- }
-
- /**
- * Outputs the given element to the specified set of consumers wrapping any exceptions.
- */
+ /** Outputs the given element to the specified set of consumers, wrapping any exceptions. */
private <T> void outputTo(
- Collection<FnDataReceiver<WindowedValue<T>>> consumers,
- WindowedValue<T> output) {
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
try {
for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
consumer.accept(output);
@@ -499,12 +198,11 @@
/**
* Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}.
*/
- private class ProcessBundleContext
- extends DoFn<InputT, OutputT>.ProcessContext
+ private class ProcessBundleContext extends DoFn<InputT, OutputT>.ProcessContext
implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
private ProcessBundleContext() {
- doFn.super();
+ context.doFn.super();
}
@Override
@@ -577,11 +275,11 @@
checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
StateSpec<?> spec;
try {
- spec = (StateSpec<?>) stateDeclaration.field().get(doFn);
+ spec = (StateSpec<?>) stateDeclaration.field().get(context.doFn);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
- return spec.bind(stateId, stateBinder);
+ return spec.bind(stateId, stateAccessor);
}
@Override
@@ -591,60 +289,51 @@
@Override
public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
+ return context.pipelineOptions;
}
@Override
public PipelineOptions pipelineOptions() {
- return pipelineOptions;
+ return context.pipelineOptions;
}
@Override
public void output(OutputT output) {
- outputTo(mainOutputConsumers,
+ outputTo(
+ mainOutputConsumers,
WindowedValue.of(
- output,
- currentElement.getTimestamp(),
- currentWindow,
- currentElement.getPane()));
+ output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
}
@Override
public void outputWithTimestamp(OutputT output, Instant timestamp) {
- outputTo(mainOutputConsumers,
- WindowedValue.of(
- output,
- timestamp,
- currentWindow,
- currentElement.getPane()));
+ outputTo(
+ mainOutputConsumers,
+ WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}
@Override
public <T> void output(TupleTag<T> tag, T output) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
- outputTo(consumers,
+ outputTo(
+ consumers,
WindowedValue.of(
- output,
- currentElement.getTimestamp(),
- currentWindow,
- currentElement.getPane()));
+ output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
}
@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
- outputTo(consumers,
- WindowedValue.of(
- output,
- timestamp,
- currentWindow,
- currentElement.getPane()));
+ outputTo(
+ consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}
@Override
@@ -654,7 +343,7 @@
@Override
public <T> T sideInput(PCollectionView<T> view) {
- return (T) bindSideInputView(view.getTagInternal());
+ return stateAccessor.get(view, currentWindow);
}
@Override
@@ -672,359 +361,4 @@
throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
}
}
-
- /**
- * A {@link StateBinder} that uses the Beam Fn State API to read and write user state.
- *
- * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that
- * {@link #bindWatermark} should never be implemented.
- */
- private class BeamFnStateBinder implements StateBinder {
- @Override
- public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
- return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- return new ValueState<T>() {
- private final BagUserState<T> impl = createBagUserState(id, coder);
-
- @Override
- public void clear() {
- impl.clear();
- }
-
- @Override
- public void write(T input) {
- impl.clear();
- impl.append(input);
- }
-
- @Override
- public T read() {
- Iterator<T> value = impl.get().iterator();
- if (value.hasNext()) {
- return value.next();
- } else {
- return null;
- }
- }
-
- @Override
- public ValueState<T> readLater() {
- // TODO: Support prefetching.
- return this;
- }
- };
- }
- });
- }
-
- @Override
- public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
- return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- return new BagState<T>() {
- private final BagUserState<T> impl = createBagUserState(id, elemCoder);
-
- @Override
- public void add(T value) {
- impl.append(value);
- }
-
- @Override
- public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
-
- @Override
- public Iterable<T> read() {
- return impl.get();
- }
-
- @Override
- public BagState<T> readLater() {
- // TODO: Support prefetching.
- return this;
- }
-
- @Override
- public void clear() {
- impl.clear();
- }
- };
- }
- });
- }
-
- @Override
- public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
- throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
- }
-
- @Override
- public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id,
- StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder,
- Coder<ValueT> mapValueCoder) {
- throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
- }
-
- @Override
- public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
- String id,
- StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, Coder<AccumT> accumCoder,
- CombineFn<ElementT, AccumT, ResultT> combineFn) {
- return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- // TODO: Support squashing accumulators depending on whether we know of all
- // remote accumulators and local accumulators or just local accumulators.
- return new CombiningState<ElementT, AccumT, ResultT>() {
- private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
-
- @Override
- public AccumT getAccum() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return iterator.next();
- }
- return combineFn.createAccumulator();
- }
-
- @Override
- public void addAccum(AccumT accum) {
- Iterator<AccumT> iterator = impl.get().iterator();
-
- // Only merge if there was a prior value
- if (iterator.hasNext()) {
- accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
- // Since there was a prior value, we need to clear.
- impl.clear();
- }
-
- impl.append(accum);
- }
-
- @Override
- public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- return combineFn.mergeAccumulators(accumulators);
- }
-
- @Override
- public CombiningState<ElementT, AccumT, ResultT> readLater() {
- return this;
- }
-
- @Override
- public ResultT read() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return combineFn.extractOutput(iterator.next());
- }
- return combineFn.defaultValue();
- }
-
- @Override
- public void add(ElementT value) {
- AccumT newAccumulator = combineFn.addInput(getAccum(), value);
- impl.clear();
- impl.append(newAccumulator);
- }
-
- @Override
- public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
-
- @Override
- public void clear() {
- impl.clear();
- }
- };
- }
- });
- }
-
- @Override
- public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT>
- bindCombiningWithContext(
- String id,
- StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
- Coder<AccumT> accumCoder,
- CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
- return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- key -> bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn,
- new StateContext<BoundedWindow>() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
-
- @Override
- public <T> T sideInput(PCollectionView<T> view) {
- return (T) bindSideInputView(view.getTagInternal());
- }
-
- @Override
- public BoundedWindow window() {
- return currentWindow;
- }
- })));
- }
-
- /**
- * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing
- * and is waiting on resolution of BEAM-2535.
- */
- @Override
- @Deprecated
- public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec,
- TimestampCombiner timestampCombiner) {
- throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
- }
-
- private <T> BagUserState<T> createBagUserState(
- String stateId, Coder<T> valueCoder) {
- BagUserState<T> rval = new BagUserState<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- stateId,
- encodedCurrentWindow,
- encodedCurrentKey,
- valueCoder);
- stateFinalizers.add(rval::asyncClose);
- return rval;
- }
- }
-
- private StateKey createBagUserStateKey(String stateId) {
- cacheEncodedKeyAndWindowForKeyedContext();
- StateKey.Builder builder = StateKey.newBuilder();
- builder.getBagUserStateBuilder()
- .setWindow(encodedCurrentWindow)
- .setKey(encodedCurrentKey)
- .setPtransformId(ptransformId)
- .setUserStateId(stateId);
- return builder.build();
- }
-
- /**
- * Memoizes an encoded key and window for the current element being processed saving on the
- * encoding cost of the key and window across multiple state cells for the lifetime of
- * {@link #processElement}.
- *
- * <p>This should only be called during {@link #processElement}.
- */
- private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
- if (encodedCurrentKey == null) {
- checkState(currentElement.getValue() instanceof KV,
- "Accessing state in unkeyed context. Current element is not a KV: %s.",
- currentElement);
- checkState(
- // TODO: Stop passing windowed value coders within PCollections.
- inputCoder instanceof KvCoder
- || (inputCoder instanceof WindowedValueCoder
- && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder)),
- "Accessing state in unkeyed context. Keyed coder expected but found %s.",
- inputCoder);
-
- ByteString.Output encodedKeyOut = ByteString.newOutput();
-
- Coder<K> keyCoder = inputCoder instanceof WindowedValueCoder
- ? ((KvCoder<K, ?>) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
- : ((KvCoder<K, ?>) inputCoder).getKeyCoder();
- try {
- keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- encodedCurrentKey = encodedKeyOut.toByteString();
- }
-
- if (encodedCurrentWindow == null) {
- ByteString.Output encodedWindowOut = ByteString.newOutput();
- try {
- windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- encodedCurrentWindow = encodedWindowOut.toByteString();
- }
- }
-
- /**
- * A specification for side inputs containing a value {@link Coder},
- * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
- * @param <W>
- */
- @AutoValue
- abstract static class SideInputSpec<W extends BoundedWindow> {
- static <W extends BoundedWindow> SideInputSpec create(
- Coder<?> coder,
- Coder<W> windowCoder,
- ViewFn<?, ?> viewFn,
- WindowMappingFn<W> windowMappingFn) {
- return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
- coder, windowCoder, viewFn, windowMappingFn);
- }
-
- abstract Coder<?> getCoder();
-
- abstract Coder<W> getWindowCoder();
-
- abstract ViewFn<?, ?> getViewFn();
-
- abstract WindowMappingFn<W> getWindowMappingFn();
- }
-
- private <K, V> Object bindSideInputView(TupleTag<?> view) {
- SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
- checkArgument(sideInputSpec != null,
- "Attempting to access unknown side input %s.",
- view);
- KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
-
- ByteString.Output encodedWindowOut = ByteString.newOutput();
- try {
- sideInputSpec.getWindowCoder().encode(
- sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow), encodedWindowOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- ByteString encodedWindow = encodedWindowOut.toByteString();
-
- StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
- cacheKeyBuilder.getMultimapSideInputBuilder()
- .setPtransformId(ptransformId)
- .setSideInputId(view.getId())
- .setWindow(encodedWindow);
- return stateKeyObjectCache.computeIfAbsent(
- cacheKeyBuilder.build(),
- key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
- view.getId(), encodedWindow, kvCoder.getKeyCoder(), kvCoder.getValueCoder())));
- }
-
- private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
- String sideInputId,
- ByteString encodedWindow,
- Coder<K> keyCoder,
- Coder<V> valueCoder) {
-
- return new MultimapSideInput<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- sideInputId,
- encodedWindow,
- keyCoder,
- valueCoder);
- }
}
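
For context, a minimal sketch of a user-level stateful DoFn (hypothetical, not part of this change): its @StateId-declared ValueState is now bound through FnApiStateAccessor via spec.bind(stateId, stateAccessor) above and serviced over the Beam Fn State API, while the user-facing state API is unchanged.

    // Illustrative only; standard Beam state API, names are made up for the example.
    static class CountPerKey extends DoFn<KV<String, Long>, KV<String, Long>> {
      @StateId("count")
      private final StateSpec<ValueState<Long>> countSpec = StateSpecs.value(VarLongCoder.of());

      @ProcessElement
      public void process(ProcessContext c, @StateId("count") ValueState<Long> count) {
        Long previous = count.read();   // read is served over the Beam Fn State API
        long updated = (previous == null ? 0L : previous) + 1;
        count.write(updated);           // persisted when finishBundle() calls stateAccessor.finalizeState()
        c.output(KV.of(c.element().getKey(), updated));
      }
    }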
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
index 7819726..e2c61de 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -109,7 +110,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
index c130d4d..efd4a38 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
@@ -22,6 +22,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -37,7 +38,6 @@
* A factory able to instantiate an appropriate handler for a given PTransform.
*/
public interface PTransformRunnerFactory<T> {
-
/**
* Creates and returns a handler for a given PTransform. Note that the handler must support
* processing multiple bundles. The handler will be discarded if an error is thrown during element
@@ -60,6 +60,7 @@
* registered within this multimap.
* @param addStartFunction A consumer to register a start bundle handler with.
* @param addFinishFunction A consumer to register a finish bundle handler with.
+ * @param splitListener A listener to be invoked when the PTransform splits itself.
*/
T createRunnerForPTransform(
PipelineOptions pipelineOptions,
@@ -73,7 +74,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException;
/**
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
new file mode 100644
index 0000000..c8cac40
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.SplittableProcessElementInvoker;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/** Runs the {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform. */
+public class SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>
+ implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<KV<InputT, RestrictionT>> {
+  /** A registrar which provides a factory for the splittable ProcessElements transform. */
+ @AutoService(PTransformRunnerFactory.Registrar.class)
+ public static class Registrar implements PTransformRunnerFactory.Registrar {
+ @Override
+ public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+ return ImmutableMap.of(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, new Factory());
+ }
+ }
+
+ static class Factory<InputT, RestrictionT, OutputT>
+ extends DoFnPTransformRunnerFactory<
+ KV<InputT, RestrictionT>,
+ InputT,
+ OutputT,
+ SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>> {
+
+ @Override
+ SplittableProcessElementsRunner<InputT, RestrictionT, OutputT> createRunner(
+ Context<InputT, OutputT> context) {
+ Coder<WindowedValue<KV<InputT, RestrictionT>>> windowedCoder =
+ FullWindowedValueCoder.of(
+ (Coder<KV<InputT, RestrictionT>>) context.inputCoder, context.windowCoder);
+
+ return new SplittableProcessElementsRunner<>(
+ context,
+ windowedCoder,
+ (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+ (Collection) context.tagToConsumer.get(context.mainOutputTag),
+ Iterables.getOnlyElement(context.pTransform.getInputsMap().keySet()));
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ private final Context<InputT, OutputT> context;
+ private final String mainInputId;
+ private final Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder;
+ private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
+ private final DoFnInvoker<InputT, OutputT> doFnInvoker;
+ private final ScheduledExecutorService executor;
+
+ private FnApiStateAccessor stateAccessor;
+
+ private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+ private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+ SplittableProcessElementsRunner(
+ Context<InputT, OutputT> context,
+ Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder,
+ Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
+ String mainInputId) {
+ this.context = context;
+ this.mainInputId = mainInputId;
+ this.inputCoder = inputCoder;
+ this.mainOutputConsumers = mainOutputConsumers;
+ this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
+ this.doFnInvoker.invokeSetup();
+ this.executor = Executors.newSingleThreadScheduledExecutor();
+
+ this.startBundleContext =
+ context.doFn.new StartBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
+ };
+ this.finishBundleContext =
+ context.doFn.new FinishBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
+
+ @Override
+ public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> void output(
+ TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ @Override
+ public void startBundle() {
+ this.stateAccessor =
+ new FnApiStateAccessor(
+ context.pipelineOptions,
+ context.ptransformId,
+ context.processBundleInstructionId,
+ context.tagToSideInputSpecMap,
+ context.beamFnStateClient,
+ context.keyCoder,
+ (Coder<BoundedWindow>) context.windowCoder);
+ doFnInvoker.invokeStartBundle(startBundleContext);
+ }
+
+ @Override
+ public void processElement(WindowedValue<KV<InputT, RestrictionT>> elem) {
+ processElementTyped(elem);
+ }
+
+ private <PositionT, TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
+ void processElementTyped(WindowedValue<KV<InputT, RestrictionT>> elem) {
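+    // Splitting operates on a single element/restriction pair in a single window, so the input
+    // is expected to have been exploded into individual windows upstream.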
+ checkArgument(
+ elem.getWindows().size() == 1,
+ "SPLITTABLE_PROCESS_ELEMENTS expects its input to be in 1 window, but got %s windows",
+ elem.getWindows().size());
+ this.stateAccessor.setCurrentWindow(elem.getWindows().iterator().next());
+ WindowedValue<InputT> element = elem.withValue(elem.getValue().getKey());
+ TrackerT tracker = doFnInvoker.invokeNewTracker(elem.getValue().getValue());
+ OutputAndTimeBoundedSplittableProcessElementInvoker<
+ InputT, OutputT, RestrictionT, PositionT, TrackerT>
+ processElementInvoker =
+ new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
+ context.doFn,
+ context.pipelineOptions,
+ new OutputWindowedValue<OutputT>() {
+ @Override
+ public void outputWindowedValue(
+ OutputT output,
+ Instant timestamp,
+ Collection<? extends BoundedWindow> windows,
+ PaneInfo pane) {
+ outputTo(
+ mainOutputConsumers, WindowedValue.of(output, timestamp, windows, pane));
+ }
+
+ @Override
+ public <AdditionalOutputT> void outputWindowedValue(
+ TupleTag<AdditionalOutputT> tag,
+ AdditionalOutputT output,
+ Instant timestamp,
+ Collection<? extends BoundedWindow> windows,
+ PaneInfo pane) {
+ Collection<FnDataReceiver<WindowedValue<AdditionalOutputT>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
+ if (consumers == null) {
+ throw new IllegalArgumentException(
+ String.format("Unknown output tag %s", tag));
+ }
+ outputTo(consumers, WindowedValue.of(output, timestamp, windows, pane));
+ }
+ },
+ stateAccessor,
+ executor,
+                10000, // checkpoint after at most this many outputs from one invocation
+                Duration.standardSeconds(10)); // ... or after this much processing time
+
+ SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, TrackerT>.Result result =
+ processElementInvoker.invokeProcessElement(doFnInvoker, element, tracker);
+ if (result.getContinuation().shouldResume()) {
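+      // Processing stopped before the restriction was exhausted: build a primary application
+      // (the portion of the work claimed by this invocation) and a residual application (the
+      // remainder), each re-encoded as an element of this transform's main input.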
+ WindowedValue<KV<InputT, RestrictionT>> primary =
+ element.withValue(KV.of(element.getValue(), tracker.currentRestriction()));
+ WindowedValue<KV<InputT, RestrictionT>> residual =
+ element.withValue(KV.of(element.getValue(), result.getResidualRestriction()));
+ ByteString.Output primaryBytes = ByteString.newOutput();
+ ByteString.Output residualBytes = ByteString.newOutput();
+ try {
+ inputCoder.encode(primary, primaryBytes);
+ inputCoder.encode(residual, residualBytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ Application primaryApplication =
+ Application.newBuilder()
+ .setPtransformId(context.ptransformId)
+ .setInputId(mainInputId)
+ .setElement(primaryBytes.toByteString())
+ .build();
+ Application residualApplication =
+ Application.newBuilder()
+ .setPtransformId(context.ptransformId)
+ .setInputId(mainInputId)
+ .setElement(residualBytes.toByteString())
+ .build();
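+      // Report the split; the residual should be rescheduled after the delay requested by the
+      // continuation, converted here from milliseconds to the proto's seconds.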
+ context.splitListener.split(
+ ImmutableList.of(primaryApplication),
+ ImmutableList.of(
+ DelayedApplication.newBuilder()
+ .setApplication(residualApplication)
+ .setDelaySec(0.001 * result.getContinuation().resumeDelay().getMillis())
+ .build()));
+ }
+ }
+
+ @Override
+ public void finishBundle() {
+ doFnInvoker.invokeFinishBundle(finishBundleContext);
+ }
+
+  /** Outputs the given element to the specified set of consumers, wrapping any exceptions. */
+ private <T> void outputTo(
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
+ try {
+ for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
+ consumer.accept(output);
+ }
+ } catch (Throwable t) {
+ throw UserCodeException.wrap(t);
+ }
+ }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
new file mode 100644
index 0000000..db73447
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+
+/**
+ * Listens to splits happening to a single bundle. See <a
+ * href="https://s.apache.org/beam-breaking-fusion">Breaking the Fusion Barrier</a> for a
+ * discussion of the design.
+ */
+public interface BundleSplitListener {
+ /**
+ * Signals that the current bundle should be split into the given set of primary and residual
+ * roots.
+ *
+ * <p>Primary roots are the new decomposition of the bundle's work into transform applications
+ * that have happened or will happen as part of this bundle (modulo future splits). Residual roots
+ * are a decomposition of the work that this bundle has given away; the runner must schedule it
+ * to be executed elsewhere (e.g. in a separate bundle).
+ */
+ void split(List<Application> primaryRoots, List<DelayedApplication> residualRoots);
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 64d713e..ee907a8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -46,8 +46,12 @@
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
@@ -144,7 +148,8 @@
SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
ListMultimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
// Recursively ensure that all consumers of the output PCollection have been created.
@@ -166,7 +171,8 @@
pCollectionIdsToConsumingPTransforms,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction);
+ addFinishFunction,
+ splitListener);
}
}
@@ -196,15 +202,12 @@
processBundleDescriptor.getWindowingStrategiesMap(),
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction);
+ addFinishFunction,
+ splitListener);
}
public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
throws Exception {
- BeamFnApi.InstructionResponse.Builder response =
- BeamFnApi.InstructionResponse.newBuilder()
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
-
String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
(BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);
@@ -223,6 +226,8 @@
}
}
+ ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder();
+
// Instantiate a State API call handler depending on whether a State Api service descriptor
// was specified.
try (HandleStateCallsForBundle beamFnStateClient =
@@ -230,6 +235,23 @@
? new BlockTillStateCallsFinish(beamFnStateGrpcClientCache.forApiServiceDescriptor(
bundleDescriptor.getStateApiServiceDescriptor()))
: new FailAllStateCallsForBundle(request.getProcessBundle())) {
+ Multimap<String, Application> allPrimaries = ArrayListMultimap.create();
+ Multimap<String, DelayedApplication> allResiduals = ArrayListMultimap.create();
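+      // Primaries are replaced with the latest decomposition reported by a split, while
+      // residuals accumulate across all splits seen during the bundle.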
+ BundleSplitListener splitListener =
+ (List<Application> primaries, List<DelayedApplication> residuals) -> {
+ // Reset primaries and accumulate residuals.
+ Multimap<String, Application> newPrimaries = ArrayListMultimap.create();
+ for (Application primary : primaries) {
+ newPrimaries.put(primary.getPtransformId(), primary);
+ }
+ allPrimaries.clear();
+ allPrimaries.putAll(newPrimaries);
+
+ for (DelayedApplication residual : residuals) {
+ allResiduals.put(residual.getApplication().getPtransformId(), residual);
+ }
+ };
+
// Create a BeamFnStateClient
for (Map.Entry<String, RunnerApi.PTransform> entry
: bundleDescriptor.getTransformsMap().entrySet()) {
@@ -251,7 +273,8 @@
pCollectionIdsToConsumingPTransforms,
pCollectionIdsToConsumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ splitListener);
}
// Already in reverse topological order so we don't need to do anything.
@@ -265,9 +288,16 @@
LOG.debug("Finishing function {}", finishFunction);
finishFunction.run();
}
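+      // Only attach a BundleSplit to the response if at least one split was reported.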
+ if (!allPrimaries.isEmpty()) {
+ response.setSplit(
+ BundleSplit.newBuilder()
+ .addAllPrimaryRoots(allPrimaries.values())
+ .addAllResidualRoots(allResiduals.values())
+ .build());
+ }
}
- return response;
+ return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response);
}
/**
@@ -355,7 +385,8 @@
Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) {
String message =
String.format(
"No factory registered for %s, known factories %s",
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
new file mode 100644
index 0000000..e7a043b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.sdk.state.SetState;
+import org.apache.beam.sdk.state.StateBinder;
+import org.apache.beam.sdk.state.StateContext;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CombineFnUtil;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+
+/** Provides access to side inputs and state via a {@link BeamFnStateClient}. */
+public class FnApiStateAccessor implements SideInputReader, StateBinder {
+ private final PipelineOptions pipelineOptions;
+ private final Map<StateKey, Object> stateKeyObjectCache;
+ private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+ private final BeamFnStateClient beamFnStateClient;
+ private final String ptransformId;
+ private final Supplier<String> processBundleInstructionId;
+ private final Collection<ThrowingRunnable> stateFinalizers;
+
+ private final Coder<?> keyCoder;
+ private final Coder<BoundedWindow> windowCoder;
+
+ private WindowedValue<?> currentElement;
+ private BoundedWindow currentWindow;
+ private ByteString encodedCurrentKey;
+ private ByteString encodedCurrentWindow;
+
+ public FnApiStateAccessor(
+ PipelineOptions pipelineOptions,
+ String ptransformId,
+ Supplier<String> processBundleInstructionId,
+ Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
+ BeamFnStateClient beamFnStateClient,
+ Coder<?> keyCoder,
+ Coder<BoundedWindow> windowCoder) {
+ this.pipelineOptions = pipelineOptions;
+ this.stateKeyObjectCache = Maps.newHashMap();
+ this.sideInputSpecMap = sideInputSpecMap;
+ this.beamFnStateClient = beamFnStateClient;
+ this.ptransformId = ptransformId;
+ this.processBundleInstructionId = processBundleInstructionId;
+ this.stateFinalizers = new ArrayList<>();
+
+ this.keyCoder = keyCoder;
+ this.windowCoder = windowCoder;
+ }
+
+ public void setCurrentElement(WindowedValue<?> currentElement) {
+ this.currentElement = currentElement;
+ this.encodedCurrentKey = null;
+ }
+
+ public void setCurrentWindow(BoundedWindow currentWindow) {
+ this.currentWindow = currentWindow;
+ this.encodedCurrentWindow = null;
+ }
+
+ @Override
+ @Nullable
+ public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+ TupleTag<?> tag = view.getTagInternal();
+
+ SideInputSpec sideInputSpec = sideInputSpecMap.get(tag);
+ checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view);
+ KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
+
+ ByteString.Output encodedWindowOut = ByteString.newOutput();
+ try {
+ sideInputSpec
+ .getWindowCoder()
+ .encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut);
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ ByteString encodedWindow = encodedWindowOut.toByteString();
+
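+    // Cache the materialized side input view per (ptransform, side input id, window) so that
+    // repeated accesses reuse the same object.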
+ StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+ cacheKeyBuilder
+ .getMultimapSideInputBuilder()
+ .setPtransformId(ptransformId)
+ .setSideInputId(tag.getId())
+ .setWindow(encodedWindow);
+ return (T)
+ stateKeyObjectCache.computeIfAbsent(
+ cacheKeyBuilder.build(),
+ key ->
+ sideInputSpec
+ .getViewFn()
+ .apply(
+ new MultimapSideInput<>(
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ ptransformId,
+ tag.getId(),
+ encodedWindow,
+ kvCoder.getKeyCoder(),
+ kvCoder.getValueCoder())));
+ }
+
+ @Override
+ public <T> boolean contains(PCollectionView<T> view) {
+ return sideInputSpecMap.containsKey(view.getTagInternal());
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return sideInputSpecMap.isEmpty();
+ }
+
+ @Override
+ public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
+ return (ValueState<T>)
+ stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
+ @Override
+ public Object apply(StateKey key) {
+ return new ValueState<T>() {
+ private final BagUserState<T> impl = createBagUserState(id, coder);
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+
+ @Override
+ public void write(T input) {
+ impl.clear();
+ impl.append(input);
+ }
+
+ @Override
+ public T read() {
+ Iterator<T> value = impl.get().iterator();
+ if (value.hasNext()) {
+ return value.next();
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public ValueState<T> readLater() {
+ // TODO: Support prefetching.
+ return this;
+ }
+ };
+ }
+ });
+ }
+
+ @Override
+ public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
+ return (BagState<T>)
+ stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
+ @Override
+ public Object apply(StateKey key) {
+ return new BagState<T>() {
+ private final BagUserState<T> impl = createBagUserState(id, elemCoder);
+
+ @Override
+ public void add(T value) {
+ impl.append(value);
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return ReadableStates.immediate(!impl.get().iterator().hasNext());
+ }
+
+ @Override
+ public Iterable<T> read() {
+ return impl.get();
+ }
+
+ @Override
+ public BagState<T> readLater() {
+ // TODO: Support prefetching.
+ return this;
+ }
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+ };
+ }
+ });
+ }
+
+ @Override
+ public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
+ throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+ }
+
+ @Override
+ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+ String id,
+ StateSpec<MapState<KeyT, ValueT>> spec,
+ Coder<KeyT> mapKeyCoder,
+ Coder<ValueT> mapValueCoder) {
+ throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+ }
+
+ @Override
+ public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
+ String id,
+ StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+ Coder<AccumT> accumCoder,
+ CombineFn<ElementT, AccumT, ResultT> combineFn) {
+ return (CombiningState<ElementT, AccumT, ResultT>)
+ stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
+ @Override
+ public Object apply(StateKey key) {
+ // TODO: Support squashing accumulators depending on whether we know of all
+ // remote accumulators and local accumulators or just local accumulators.
+ return new CombiningState<ElementT, AccumT, ResultT>() {
+ private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
+
+ @Override
+ public AccumT getAccum() {
+ Iterator<AccumT> iterator = impl.get().iterator();
+ if (iterator.hasNext()) {
+ return iterator.next();
+ }
+ return combineFn.createAccumulator();
+ }
+
+ @Override
+ public void addAccum(AccumT accum) {
+ Iterator<AccumT> iterator = impl.get().iterator();
+
+ // Only merge if there was a prior value
+ if (iterator.hasNext()) {
+ accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
+ // Since there was a prior value, we need to clear.
+ impl.clear();
+ }
+
+ impl.append(accum);
+ }
+
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+ return combineFn.mergeAccumulators(accumulators);
+ }
+
+ @Override
+ public CombiningState<ElementT, AccumT, ResultT> readLater() {
+ return this;
+ }
+
+ @Override
+ public ResultT read() {
+ Iterator<AccumT> iterator = impl.get().iterator();
+ if (iterator.hasNext()) {
+ return combineFn.extractOutput(iterator.next());
+ }
+ return combineFn.defaultValue();
+ }
+
+ @Override
+ public void add(ElementT value) {
+ AccumT newAccumulator = combineFn.addInput(getAccum(), value);
+ impl.clear();
+ impl.append(newAccumulator);
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return ReadableStates.immediate(!impl.get().iterator().hasNext());
+ }
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+ };
+ }
+ });
+ }
+
+ @Override
+ public <ElementT, AccumT, ResultT>
+ CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext(
+ String id,
+ StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+ Coder<AccumT> accumCoder,
+ CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
+ return (CombiningState<ElementT, AccumT, ResultT>)
+ stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ key ->
+ bindCombining(
+ id,
+ spec,
+ accumCoder,
+ CombineFnUtil.bindContext(
+ combineFn,
+ new StateContext<BoundedWindow>() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return pipelineOptions;
+ }
+
+ @Override
+ public <T> T sideInput(PCollectionView<T> view) {
+ return get(view, currentWindow);
+ }
+
+ @Override
+ public BoundedWindow window() {
+ return currentWindow;
+ }
+ })));
+ }
+
+ /**
+   * @deprecated As of this writing, the Fn API has no plans to implement WatermarkHoldState,
+   *     pending resolution of BEAM-2535.
+ */
+ @Override
+ @Deprecated
+ public WatermarkHoldState bindWatermark(
+ String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
+ throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
+ }
+
+ private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) {
+ BagUserState<T> rval =
+ new BagUserState<>(
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ ptransformId,
+ stateId,
+ encodedCurrentWindow,
+ encodedCurrentKey,
+ valueCoder);
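+    // Register the cell so that its pending mutations are flushed when finalizeState() is called.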
+ stateFinalizers.add(rval::asyncClose);
+ return rval;
+ }
+
+ private StateKey createBagUserStateKey(String stateId) {
+ cacheEncodedKeyAndWindowForKeyedContext();
+ StateKey.Builder builder = StateKey.newBuilder();
+ builder
+ .getBagUserStateBuilder()
+ .setWindow(encodedCurrentWindow)
+ .setKey(encodedCurrentKey)
+ .setPtransformId(ptransformId)
+ .setUserStateId(stateId);
+ return builder.build();
+ }
+
+ /**
+   * Memoizes the encoded key and window for the current element being processed, saving the
+   * encoding cost of the key and window across multiple state cells for the lifetime of {@link
+   * DoFn.ProcessElement}.
+ *
+ * <p>This should only be called during {@link DoFn.ProcessElement}.
+ */
+  private void cacheEncodedKeyAndWindowForKeyedContext() {
+ if (encodedCurrentKey == null) {
+ checkState(
+ currentElement.getValue() instanceof KV,
+ "Accessing state in unkeyed context. Current element is not a KV: %s.",
+ currentElement);
+ checkState(keyCoder != null, "Accessing state in unkeyed context, no key coder available");
+
+ ByteString.Output encodedKeyOut = ByteString.newOutput();
+ try {
+ ((Coder) keyCoder).encode(((KV<?, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ encodedCurrentKey = encodedKeyOut.toByteString();
+ }
+
+ if (encodedCurrentWindow == null) {
+ ByteString.Output encodedWindowOut = ByteString.newOutput();
+ try {
+ windowCoder.encode(currentWindow, encodedWindowOut);
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ encodedCurrentWindow = encodedWindowOut.toByteString();
+ }
+ }
+
+ public void finalizeState() {
+ // Persist all dirty state cells
+ try {
+ for (ThrowingRunnable runnable : stateFinalizers) {
+ runnable.run();
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IllegalStateException(e);
+ } catch (Exception e) {
+ throw new IllegalStateException(e);
+ }
+ }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
new file mode 100644
index 0000000..8aec3a1
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+
+/**
+ * A specification for side inputs containing a value {@link Coder}, the window {@link Coder},
+ * {@link ViewFn}, and the {@link WindowMappingFn}.
+ *
+ * @param <W> the window type of the side input
+ */
+@AutoValue
+public abstract class SideInputSpec<W extends BoundedWindow> {
+ public static <W extends BoundedWindow> SideInputSpec create(
+ Coder<?> coder,
+ Coder<W> windowCoder,
+ ViewFn<?, ?> viewFn,
+ WindowMappingFn<W> windowMappingFn) {
+ return new AutoValue_SideInputSpec<>(
+ coder, windowCoder, viewFn, windowMappingFn);
+ }
+
+ abstract Coder<?> getCoder();
+
+ abstract Coder<W> getWindowCoder();
+
+ abstract ViewFn<?, ?> getViewFn();
+
+ abstract WindowMappingFn<W> getWindowMappingFn();
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
index d72753b..9871c37 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
@@ -201,7 +201,8 @@
null /* windowingStrategies */,
receivers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+            null /* addFinishFunction */,
+ null /* splitListener */);
WindowedValue<Integer> value =
WindowedValue.of(
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index f138677..4b4c64f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -152,7 +152,8 @@
COMPONENTS.getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
verifyZeroInteractions(mockBeamFnDataClient);
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
index 12540c9..a70aedd 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -141,7 +141,8 @@
COMPONENTS.getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
verifyZeroInteractions(mockBeamFnDataClient);
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
index 044de10..e438bd5 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
@@ -151,7 +151,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
// This is testing a deprecated way of running sources and should be removed
// once all source definitions are instead propagated along the input edge.
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
index 9b1bd75..3b58517 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
@@ -87,7 +87,8 @@
Collections.emptyMap(),
consumers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+            null /* addFinishFunction */,
+ null /* splitListener */);
mainOutputValues.clear();
assertThat(consumers.keySet(), containsInAnyOrder(
@@ -149,7 +150,8 @@
Collections.emptyMap(),
consumers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+            null /* addFinishFunction */,
+ null /* splitListener */);
mainOutputValues.clear();
assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget"));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 85aa564..ee04cf8 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -18,8 +18,6 @@
package org.apache.beam.fn.harness;
-import static com.google.common.base.Preconditions.checkState;
-import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -27,29 +25,23 @@
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.protobuf.ByteString;
import java.io.IOException;
+import java.io.Serializable;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
-import java.util.ServiceLoader;
-import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -60,8 +52,6 @@
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
-import org.apache.beam.sdk.transforms.CombineWithContext.Context;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
@@ -73,15 +63,13 @@
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.hamcrest.collection.IsMapContaining;
+import org.apache.beam.sdk.values.TupleTagList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
@@ -90,137 +78,10 @@
/** Tests for {@link FnApiDoFnRunner}. */
@RunWith(JUnit4.class)
-public class FnApiDoFnRunnerTest {
+public class FnApiDoFnRunnerTest implements Serializable {
public static final String TEST_PTRANSFORM_ID = "pTransformId";
- private static class TestDoFn extends DoFn<String, String> {
- private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
- private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
-
- private enum State {
- NOT_SET_UP,
- OUTSIDE_BUNDLE,
- INSIDE_BUNDLE,
- }
-
- private State state = State.NOT_SET_UP;
-
- private BoundedWindow window;
-
- @Setup
- public void setUp() {
- checkState(State.NOT_SET_UP.equals(state), "Unexpected state: %s", state);
- state = State.OUTSIDE_BUNDLE;
- }
-
- // No testing for TearDown - it's currently not supported by FnHarness.
-
- @StartBundle
- public void startBundle() {
- checkState(State.OUTSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- state = State.INSIDE_BUNDLE;
- }
-
- @ProcessElement
- public void processElement(ProcessContext context, BoundedWindow window) {
- checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- context.output("MainOutput" + context.element());
- context.output(additionalOutput, "AdditionalOutput" + context.element());
- this.window = window;
- }
-
- @FinishBundle
- public void finishBundle(FinishBundleContext context) {
- checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- state = State.OUTSIDE_BUNDLE;
- if (window != null) {
- context.output("FinishBundle", window.maxTimestamp(), window);
- window = null;
- }
- }
- }
-
- /**
- * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
- * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
- * are directed to the correct consumers.
- */
- @Test
- public void testCreatingAndProcessingDoFn() throws Exception {
- String pTransformId = "pTransformId";
-
- DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
- new TestDoFn(),
- WindowingStrategy.globalDefault(),
- ImmutableList.of(),
- StringUtf8Coder.of(),
- TestDoFn.mainOutput);
- RunnerApi.FunctionSpec functionSpec =
- RunnerApi.FunctionSpec.newBuilder()
- .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
- .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
- .build();
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("inputA", "inputATarget")
- .putInputs("inputB", "inputBTarget")
- .putOutputs(TestDoFn.mainOutput.getId(), "mainOutputTarget")
- .putOutputs(TestDoFn.additionalOutput.getId(), "additionalOutputTarget")
- .build();
-
- List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
- List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
- Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("mainOutputTarget",
- (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
- consumers.put("additionalOutputTarget",
- (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
- PipelineOptionsFactory.create(),
- null /* beamFnDataClient */,
- null /* beamFnStateClient */,
- pTransformId,
- pTransform,
- Suppliers.ofInstance("57L")::get,
- Collections.emptyMap(),
- Collections.emptyMap(),
- Collections.emptyMap(),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- Iterables.getOnlyElement(startFunctions).run();
- mainOutputValues.clear();
-
- assertThat(consumers.keySet(), containsInAnyOrder(
- "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
-
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
- Iterables.getOnlyElement(consumers.get("inputBTarget")).accept(valueInGlobalWindow("B"));
- assertThat(mainOutputValues, contains(
- valueInGlobalWindow("MainOutputA1"),
- valueInGlobalWindow("MainOutputA2"),
- valueInGlobalWindow("MainOutputB")));
- assertThat(additionalOutputValues, contains(
- valueInGlobalWindow("AdditionalOutputA1"),
- valueInGlobalWindow("AdditionalOutputA2"),
- valueInGlobalWindow("AdditionalOutputB")));
- mainOutputValues.clear();
- additionalOutputValues.clear();
-
- Iterables.getOnlyElement(finishFunctions).run();
- assertThat(
- mainOutputValues,
- contains(
- timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
- mainOutputValues.clear();
- }
-
private static class ConcatCombineFn extends CombineFn<String, String, String> {
@Override
public String createAccumulator() {
@@ -247,33 +108,6 @@
}
}
- private static class ConcatCombineFnWithContext
- extends CombineFnWithContext<String, String, String> {
- @Override
- public String createAccumulator(Context c) {
- return "";
- }
-
- @Override
- public String addInput(String accumulator, String input, Context c) {
- return accumulator.concat(input);
- }
-
- @Override
- public String mergeAccumulators(Iterable<String> accumulators, Context c) {
- StringBuilder builder = new StringBuilder();
- for (String value : accumulators) {
- builder.append(value);
- }
- return builder.toString();
- }
-
- @Override
- public String extractOutput(String accumulator, Context c) {
- return accumulator;
- }
- }
-
private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -287,17 +121,12 @@
@StateId("combine")
private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
- @StateId("combineWithContext")
- private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec =
- StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext());
@ProcessElement
public void processElement(ProcessContext context,
@StateId("value") ValueState<String> valueState,
@StateId("bag") BagState<String> bagState,
- @StateId("combine") CombiningState<String, String, String> combiningState,
- @StateId("combineWithContext")
- CombiningState<String, String, String> combiningWithContextState) {
+ @StateId("combine") CombiningState<String, String, String> combiningState) {
context.output("value:" + valueState.read());
valueState.write(context.element().getValue());
@@ -306,41 +135,34 @@
context.output("combine:" + combiningState.read());
combiningState.add(context.element().getValue());
-
- context.output("combineWithContext:" + combiningWithContextState.read());
- combiningWithContextState.add(context.element().getValue());
}
}
@Test
public void testUsingUserState() throws Exception {
- DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
- new TestStatefulDoFn(),
- WindowingStrategy.globalDefault(),
- ImmutableList.of(),
- KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
- new TupleTag<>("mainOutput"));
- RunnerApi.FunctionSpec functionSpec =
- RunnerApi.FunctionSpec.newBuilder()
- .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
- .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
- .build();
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("input", "inputTarget")
- .putOutputs("mainOutput", "mainOutputTarget")
- .build();
+ Pipeline p = Pipeline.create();
+ PCollection<KV<String, String>> valuePCollection =
+ p.apply(Create.of(KV.of("unused", "unused")));
+ PCollection<String> outputPCollection =
+ valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
+
+ SdkComponents sdkComponents = SdkComponents.create();
+ RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+ String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+ String outputPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection);
+ RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
+ pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of(
bagUserStateKey("value", "X"), encode("X0"),
bagUserStateKey("bag", "X"), encode("X0"),
- bagUserStateKey("combine", "X"), encode("X0"),
- bagUserStateKey("combineWithContext", "X"), encode("X0")
+ bagUserStateKey("combine", "X"), encode("X0")
));
List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("mainOutputTarget",
+ consumers.put(outputPCollectionId,
(FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
@@ -352,22 +174,23 @@
TEST_PTRANSFORM_ID,
pTransform,
Suppliers.ofInstance("57L")::get,
- Collections.emptyMap(),
- Collections.emptyMap(),
- Collections.emptyMap(),
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
- assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget"));
+ assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
// Ensure that bag user state that is initially empty or populated works.
// Ensure that the key order does not matter when we traverse over KV pairs.
FnDataReceiver<WindowedValue<?>> mainInput =
- Iterables.getOnlyElement(consumers.get("inputTarget"));
+ Iterables.getOnlyElement(consumers.get(inputPCollectionId));
mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
@@ -376,19 +199,15 @@
valueInGlobalWindow("value:X0"),
valueInGlobalWindow("bag:[X0]"),
valueInGlobalWindow("combine:X0"),
- valueInGlobalWindow("combineWithContext:X0"),
valueInGlobalWindow("value:null"),
valueInGlobalWindow("bag:[]"),
valueInGlobalWindow("combine:"),
- valueInGlobalWindow("combineWithContext:"),
valueInGlobalWindow("value:X1"),
valueInGlobalWindow("bag:[X0, X1]"),
valueInGlobalWindow("combine:X0X1"),
- valueInGlobalWindow("combineWithContext:X0X1"),
valueInGlobalWindow("value:Y1"),
valueInGlobalWindow("bag:[Y1]"),
- valueInGlobalWindow("combine:Y1"),
- valueInGlobalWindow("combineWithContext:Y1")));
+ valueInGlobalWindow("combine:Y1")));
mainOutputValues.clear();
Iterables.getOnlyElement(finishFunctions).run();
@@ -399,11 +218,9 @@
.put(bagUserStateKey("value", "X"), encode("X2"))
.put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
.put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
- .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
.put(bagUserStateKey("value", "Y"), encode("Y2"))
.put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
.put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
- .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
.build(),
fakeClient.getData());
mainOutputValues.clear();
@@ -425,13 +242,17 @@
private final PCollectionView<String> defaultSingletonSideInput;
private final PCollectionView<String> singletonSideInput;
private final PCollectionView<Iterable<String>> iterableSideInput;
+ private final TupleTag<String> additionalOutput;
+
private TestSideInputDoFn(
PCollectionView<String> defaultSingletonSideInput,
PCollectionView<String> singletonSideInput,
- PCollectionView<Iterable<String>> iterableSideInput) {
+ PCollectionView<Iterable<String>> iterableSideInput,
+ TupleTag<String> additionalOutput) {
this.defaultSingletonSideInput = defaultSingletonSideInput;
this.singletonSideInput = singletonSideInput;
this.iterableSideInput = iterableSideInput;
+ this.additionalOutput = additionalOutput;
}
@ProcessElement
@@ -441,6 +262,7 @@
for (String sideInputValue : context.sideInput(iterableSideInput)) {
context.output(context.element() + ":" + sideInputValue);
}
+ context.output(additionalOutput, context.element() + ":additional");
}
}
@@ -453,21 +275,31 @@
PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
PCollectionView<Iterable<String>> iterableSideInputView =
valuePCollection.apply(View.asIterable());
- PCollection<String> outputPCollection = valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
- new TestSideInputDoFn(
- defaultSingletonSideInputView,
- singletonSideInputView,
- iterableSideInputView))
- .withSideInputs(
- defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView));
+ TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+ TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+ PCollectionTuple outputPCollection =
+ valuePCollection.apply(
+ TEST_PTRANSFORM_ID,
+ ParDo.of(
+ new TestSideInputDoFn(
+ defaultSingletonSideInputView,
+ singletonSideInputView,
+ iterableSideInputView,
+ additionalOutput))
+ .withSideInputs(
+ defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
+ .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
SdkComponents sdkComponents = SdkComponents.create();
RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
- String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+ String outputPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+ String additionalPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
- RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
- pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+ RunnerApi.PTransform pTransform =
+ pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID);
ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
@@ -478,13 +310,16 @@
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+ consumers.put(outputPCollectionId,
(FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+ consumers.put(additionalPCollectionId,
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
- new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
PipelineOptionsFactory.create(),
null /* beamFnDataClient */,
fakeClient,
@@ -496,12 +331,15 @@
pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
- assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+ assertThat(
+ consumers.keySet(),
+ containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
// Ensure that bag user state that is initially empty or populated works.
// Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
@@ -520,6 +358,9 @@
valueInGlobalWindow("Y:iterableValue1"),
valueInGlobalWindow("Y:iterableValue2"),
valueInGlobalWindow("Y:iterableValue3")));
+ assertThat(
+ additionalOutputValues,
+ contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
mainOutputValues.clear();
Iterables.getOnlyElement(finishFunctions).run();
@@ -589,7 +430,7 @@
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
- new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
PipelineOptionsFactory.create(),
null /* beamFnDataClient */,
fakeClient,
@@ -601,7 +442,8 @@
pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
@@ -658,17 +500,4 @@
}
return out.toByteString();
}
-
- @Test
- public void testRegistration() {
- for (Registrar registrar :
- ServiceLoader.load(Registrar.class)) {
- if (registrar instanceof FnApiDoFnRunner.Registrar) {
- assertThat(registrar.getPTransformRunnerFactories(),
- IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN));
- return;
- }
- }
- fail("Expected registrar not found.");
- }
}
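
Note on the multi-output wiring the updated FnApiDoFnRunnerTest now exercises: the test tags a main and an additional output, registers a PCollection id for each, and attaches one consumer per id. Below is a minimal, standalone sketch of the same pattern using only public Beam SDK APIs; the class name, tag ids, and input values are illustrative and not taken from the test.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class AdditionalOutputSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Anonymous subclasses so the tags carry their element type at runtime.
    final TupleTag<String> mainOutput = new TupleTag<String>("main") {};
    final TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};

    PCollectionTuple outputs =
        p.apply(Create.of("X", "Y"))
            .apply(
                ParDo.of(
                        new DoFn<String, String>() {
                          @ProcessElement
                          public void process(ProcessContext context) {
                            // Untagged output goes to the main output.
                            context.output(context.element() + ":main");
                            // The extra output is addressed explicitly by its tag.
                            context.output(additionalOutput, context.element() + ":additional");
                          }
                        })
                    .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));

    // Each tagged output is a separate PCollection, as in the test's
    // outputPCollection.get(mainOutput) / get(additionalOutput) calls.
    PCollection<String> main = outputs.get(mainOutput);
    PCollection<String> additional = outputs.get(additionalOutput);

    // Assumes a runner (e.g. the direct runner) is on the classpath.
    p.run().waitUntilFinish();
  }
}
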
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
index 906bdeb..93ca808 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
@@ -82,7 +82,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
@@ -116,7 +117,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
@@ -150,7 +152,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 48d56fd..eafef99 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -50,6 +50,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingConsumer;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -115,7 +116,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
assertThat(processBundleInstructionId.get(), equalTo("999L"));
transformsProcessed.add(pTransform);
@@ -176,7 +178,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
throw new IllegalStateException("TestException");
@@ -217,7 +220,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
addStartFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -261,7 +265,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
addFinishFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -335,10 +340,11 @@
Supplier<String> processBundleInstructionId,
Map<String, PCollection> pCollections,
Map<String, Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
return null;
}
@@ -386,10 +392,11 @@
Supplier<String> processBundleInstructionId,
Map<String, PCollection> pCollections,
Map<String, Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
return null;
}
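
Note on the start/finish hooks that these factory overrides register work into (and that now sit alongside the new splitListener argument): the tests pass startFunctions::add and finishFunctions::add as those consumers and later run the collected callbacks themselves. Below is a minimal, standalone sketch of that callback pattern only; the class name, createRunner method, and the local ThrowingRunnable stand-in (for org.apache.beam.sdk.fn.function.ThrowingRunnable) are illustrative, and the real createRunnerForPTransform takes many more parameters than shown here.

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public class LifecycleCallbackSketch {

  // Stand-in for org.apache.beam.sdk.fn.function.ThrowingRunnable (illustrative only).
  interface ThrowingRunnable {
    void run() throws Exception;
  }

  // Illustrative "factory": it registers its start/finish work through the consumers it is
  // handed rather than running anything itself; the real method also receives the pipeline
  // components, the consumers multimap, and the new splitListener, omitted here.
  static void createRunner(
      Consumer<ThrowingRunnable> addStartFunction, Consumer<ThrowingRunnable> addFinishFunction) {
    addStartFunction.accept(() -> System.out.println("start-bundle work"));
    addFinishFunction.accept(() -> System.out.println("finish-bundle work"));
  }

  public static void main(String[] args) throws Exception {
    List<ThrowingRunnable> startFunctions = new ArrayList<>();
    List<ThrowingRunnable> finishFunctions = new ArrayList<>();

    // The tests pass method references like startFunctions::add for these consumers...
    createRunner(startFunctions::add, finishFunctions::add);

    // ...and then drive the collected callbacks explicitly, e.g. via
    // Iterables.getOnlyElement(startFunctions).run().
    for (ThrowingRunnable start : startFunctions) {
      start.run();
    }
    for (ThrowingRunnable finish : finishFunctions) {
      finish.run();
    }
  }
}
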