Adds support for Splittable DoFn (SDF) in the Universal Local Runner (ULR) and the Java SDK.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index db25581..b2f8195 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -28,6 +28,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -93,7 +94,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference(pTransformId)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
index 34d704d..a4f6e54 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -84,7 +85,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
.setPrimitiveTransformReference(pTransformId)
.setName(getOnlyElement(pTransform.getInputsMap().keySet()))
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
index 5df178f..185fc11 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java
@@ -28,6 +28,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.control.ProcessBundleHandler;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -81,7 +82,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) {
ImmutableList.Builder<FnDataReceiver<WindowedValue<?>>> consumers = ImmutableList.builder();
for (String pCollectionId : pTransform.getOutputsMap().values()) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
new file mode 100644
index 0000000..9cbe34c
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.SideInputSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/**
+ * A {@link PTransformRunnerFactory} for transforms invoking a {@link DoFn}.
+ *
+ * <p>Subclasses only implement {@link #createRunner}. This base class decodes the transform's
+ * {@link ParDoPayload} into a {@link Context} and registers the created runner's bundle
+ * callbacks through the functions supplied to {@link #createRunnerForPTransform}.
+ *
+ * @param <TransformInputT> element type of the transform's main input PCollection
+ * @param <FnInputT> input type of the {@link DoFn} carried in the payload
+ * @param <OutputT> main output type of the {@link DoFn}
+ * @param <RunnerT> type of runner produced by {@link #createRunner}
+ */
+abstract class DoFnPTransformRunnerFactory<
+        TransformInputT,
+        FnInputT,
+        OutputT,
+        RunnerT extends DoFnPTransformRunnerFactory.DoFnPTransformRunner<TransformInputT>>
+    implements PTransformRunnerFactory<RunnerT> {
+  /** Bundle callbacks registered by {@code createRunnerForPTransform}. */
+  interface DoFnPTransformRunner<T> {
+    /** Registered through {@code addStartFunction}. */
+    void startBundle() throws Exception;
+
+    /** Registered as the consumer of every main (non side) input PCollection. */
+    void processElement(WindowedValue<T> input) throws Exception;
+
+    /** Registered through {@code addFinishFunction}. */
+    void finishBundle() throws Exception;
+  }
+
+  /**
+   * Builds a {@link Context} for {@code pTransform}, asks the subclass for a runner and registers
+   * that runner's bundle callbacks. {@code beamFnDataClient} is not used.
+   *
+   * @return the runner created by {@link #createRunner}
+   * @throws IllegalArgumentException if the transform's payload or its references are invalid
+   */
+  @Override
+  public final RunnerT createRunnerForPTransform(
+      PipelineOptions pipelineOptions,
+      BeamFnDataClient beamFnDataClient,
+      BeamFnStateClient beamFnStateClient,
+      String ptransformId,
+      PTransform pTransform,
+      Supplier<String> processBundleInstructionId,
+      Map<String, PCollection> pCollections,
+      Map<String, RunnerApi.Coder> coders,
+      Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+      Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+      Consumer<ThrowingRunnable> addStartFunction,
+      Consumer<ThrowingRunnable> addFinishFunction,
+      BundleSplitListener splitListener) {
+    Context<FnInputT, OutputT> context =
+        new Context<>(
+            pipelineOptions,
+            beamFnStateClient,
+            ptransformId,
+            pTransform,
+            processBundleInstructionId,
+            pCollections,
+            coders,
+            windowingStrategies,
+            pCollectionIdsToConsumers,
+            splitListener);
+
+    RunnerT runner = createRunner(context);
+
+    // Register the appropriate handlers. Every input that is not a side input is a main input.
+    addStartFunction.accept(runner::startBundle);
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    FnDataReceiver<WindowedValue<?>> mainInputReceiver =
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<TransformInputT>>) runner::processElement;
+    Iterable<String> mainInputs =
+        Sets.difference(
+            pTransform.getInputsMap().keySet(), context.parDoPayload.getSideInputsMap().keySet());
+    for (String localInputName : mainInputs) {
+      pCollectionIdsToConsumers.put(
+          pTransform.getInputsOrThrow(localInputName), mainInputReceiver);
+    }
+    addFinishFunction.accept(runner::finishBundle);
+    return runner;
+  }
+
+  /**
+   * Creates the runner for a single transform.
+   *
+   * @param context the decoded payload and wiring of the transform
+   * @return the runner whose bundle callbacks this factory registers
+   */
+  abstract RunnerT createRunner(Context<FnInputT, OutputT> context);
+
+  /**
+   * Everything a {@link DoFnPTransformRunner} needs, decoded from a {@link PTransform} and its
+   * {@link ParDoPayload}.
+   *
+   * @param <InputT> input type of the {@link DoFn}
+   * @param <OutputT> main output type of the {@link DoFn}
+   */
+  static class Context<InputT, OutputT> {
+    final PipelineOptions pipelineOptions;
+    final BeamFnStateClient beamFnStateClient;
+    final String ptransformId;
+    final PTransform pTransform;
+    final Supplier<String> processBundleInstructionId;
+    final RehydratedComponents rehydratedComponents;
+    final DoFn<InputT, OutputT> doFn;
+    final TupleTag<OutputT> mainOutputTag;
+    /** Coder of the main input PCollection, possibly wrapped in a {@link WindowedValueCoder}. */
+    final Coder<?> inputCoder;
+    /** Key coder when the main input's elements are {@code KV}s, {@code null} otherwise. */
+    final Coder<?> keyCoder;
+    /** Window coder of the main input's windowing strategy. */
+    final Coder<? extends BoundedWindow> windowCoder;
+    final WindowingStrategy<InputT, ?> windowingStrategy;
+    /** Side input specifications keyed by the side input's local name. */
+    final Map<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap;
+    final ParDoPayload parDoPayload;
+    /** Downstream consumers keyed by the transform's local output name. */
+    final ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer;
+    final BundleSplitListener splitListener;
+
+    /**
+     * Decodes {@code pTransform}'s {@link ParDoPayload} and resolves the coders, windowing
+     * strategies and side inputs it references.
+     *
+     * @throws IllegalArgumentException if the payload cannot be decoded, a referenced PCollection
+     *     is not in {@code pCollections}, or a side input uses an unsupported materialization
+     */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    Context(
+        PipelineOptions pipelineOptions,
+        BeamFnStateClient beamFnStateClient,
+        String ptransformId,
+        PTransform pTransform,
+        Supplier<String> processBundleInstructionId,
+        Map<String, PCollection> pCollections,
+        Map<String, RunnerApi.Coder> coders,
+        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+        Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
+        BundleSplitListener splitListener) {
+      this.pipelineOptions = pipelineOptions;
+      this.beamFnStateClient = beamFnStateClient;
+      this.ptransformId = ptransformId;
+      this.pTransform = pTransform;
+      this.processBundleInstructionId = processBundleInstructionId;
+      this.splitListener = splitListener;
+      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMapBuilder =
+          ImmutableMap.builder();
+      try {
+        rehydratedComponents =
+            RehydratedComponents.forComponents(
+                RunnerApi.Components.newBuilder()
+                    .putAllCoders(coders)
+                    .putAllWindowingStrategies(windowingStrategies)
+                    .build());
+        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
+        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
+        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
+
+        // The main input is the single input that is not declared as a side input.
+        String mainInputTag =
+            Iterables.getOnlyElement(
+                Sets.difference(
+                    pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
+        PCollection mainInput = lookupInputPCollection(pCollections, pTransform, mainInputTag);
+        inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
+
+        // TODO: Stop passing windowed value coders within PCollections. Until then, unwrap them
+        // so the key coder is found whether or not the element coder is windowed.
+        Coder<?> elementCoder =
+            inputCoder instanceof WindowedValueCoder
+                ? ((WindowedValueCoder<?>) inputCoder).getValueCoder()
+                : inputCoder;
+        keyCoder =
+            elementCoder instanceof KvCoder ? ((KvCoder<?, ?>) elementCoder).getKeyCoder() : null;
+
+        windowingStrategy =
+            (WindowingStrategy)
+                rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+        windowCoder = windowingStrategy.getWindowFn().windowCoder();
+
+        // Build the map from tag id to side input specification.
+        for (Map.Entry<String, RunnerApi.SideInput> entry :
+            parDoPayload.getSideInputsMap().entrySet()) {
+          String sideInputTag = entry.getKey();
+          RunnerApi.SideInput sideInput = entry.getValue();
+          checkArgument(
+              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+                  sideInput.getAccessPattern().getUrn()),
+              "This SDK is only capable of dealing with %s materializations "
+                  + "but was asked to handle %s for PCollectionView with tag %s.",
+              Materializations.MULTIMAP_MATERIALIZATION_URN,
+              sideInput.getAccessPattern().getUrn(),
+              sideInputTag);
+
+          PCollection sideInputPCollection =
+              lookupInputPCollection(pCollections, pTransform, sideInputTag);
+          WindowingStrategy sideInputWindowingStrategy =
+              rehydratedComponents.getWindowingStrategy(
+                  sideInputPCollection.getWindowingStrategyId());
+          tagToSideInputSpecMapBuilder.put(
+              new TupleTag<>(sideInputTag),
+              SideInputSpec.create(
+                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
+                  PCollectionViewTranslation.viewFnFromProto(sideInput.getViewFn()),
+                  PCollectionViewTranslation.windowMappingFnFromProto(
+                      sideInput.getWindowMappingFn())));
+        }
+      } catch (IOException exn) {
+        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
+      }
+      tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build();
+
+      // Group the consumers of every output PCollection by the output's local name.
+      ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
+          tagToConsumerBuilder = ImmutableListMultimap.builder();
+      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
+        tagToConsumerBuilder.putAll(
+            new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
+      }
+      tagToConsumer = tagToConsumerBuilder.build();
+    }
+
+    /**
+     * Returns the PCollection bound to input {@code localName} of {@code pTransform}.
+     *
+     * @throws IllegalArgumentException if {@code pTransform} has no such input or the PCollection
+     *     it names is not in {@code pCollections}
+     */
+    private static PCollection lookupInputPCollection(
+        Map<String, PCollection> pCollections, PTransform pTransform, String localName) {
+      String pCollectionId = pTransform.getInputsOrThrow(localName);
+      PCollection pCollection = pCollections.get(pCollectionId);
+      checkArgument(
+          pCollection != null,
+          "PCollection %s used as input %s of PTransform %s is not defined.",
+          pCollectionId,
+          localName,
+          pTransform.getUniqueName());
+      return pCollection;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
index 52d9435..6ac860e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -69,7 +70,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
// Give each input a MultiplexingFnDataReceiver to all outputs of the flatten.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 9778aac..4cfb5d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,75 +17,28 @@
*/
package org.apache.beam.fn.harness;
-import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Preconditions.checkState;
import com.google.auto.service.AutoService;
-import com.google.auto.value.AutoValue;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.ListMultimap;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Sets;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collection;
-import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.state.BagUserState;
-import org.apache.beam.fn.harness.state.BeamFnStateClient;
-import org.apache.beam.fn.harness.state.MultimapSideInput;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
-import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
-import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.fn.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.ReadableState;
-import org.apache.beam.sdk.state.ReadableStates;
-import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateBinder;
-import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFn.StartBundleContext;
import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
-import org.apache.beam.sdk.transforms.Materializations;
-import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -94,399 +47,145 @@
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
-import org.apache.beam.sdk.util.CombineFnUtil;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Instant;
/**
- * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers
- * of abstraction caused by StateInternals/TimerInternals since they model state and timer
- * concepts differently.
+ * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers of
+ * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
+ * differently.
*/
-public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
-
- private final ProcessBundleContext processContext;
- private final FinishBundleContext finishBundleContext;
- private StartBundleContext startBundleContext;
-
- /**
- * A registrar which provides a factory to handle Java {@link DoFn}s.
- */
+public class FnApiDoFnRunner<InputT, OutputT>
+ implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<InputT> {
+ /** A registrar which provides a factory to handle Java {@link DoFn}s. */
@AutoService(PTransformRunnerFactory.Registrar.class)
- public static class Registrar implements
- PTransformRunnerFactory.Registrar {
-
+ public static class Registrar implements PTransformRunnerFactory.Registrar {
@Override
public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
- return ImmutableMap.of(
- PTransformTranslation.PAR_DO_TRANSFORM_URN, new NewFactory(),
- ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory());
+ return ImmutableMap.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, new Factory());
}
}
- /** A factory for {@link FnApiDoFnRunner}. */
static class Factory<InputT, OutputT>
- implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
+ extends DoFnPTransformRunnerFactory<
+ InputT, InputT, OutputT, FnApiDoFnRunner<InputT, OutputT>> {
@Override
- public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
-
- // For every output PCollection, create a map from output name to Consumer
- ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
- tagToOutputMapBuilder = ImmutableListMultimap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- tagToOutputMapBuilder.putAll(
- new TupleTag<>(entry.getKey()),
- pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToOutputMap =
- tagToOutputMapBuilder.build();
-
- // Get the DoFnInfo from the serialized blob.
- ByteString serializedFn = pTransform.getSpec().getPayload();
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray(
- serializedFn.toByteArray(), "DoFnInfo");
-
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnRunner<InputT, OutputT> runner =
- new FnApiDoFnRunner<>(
- pipelineOptions,
- beamFnStateClient,
- ptransformId,
- processBundleInstructionId,
- doFnInfo.getDoFn(),
- doFnInfo.getInputCoder(),
- (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
- (Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
- tagToOutputMap,
- ImmutableMap.of(),
- doFnInfo.getWindowingStrategy());
-
- registerHandlers(
- runner,
- pTransform,
- ImmutableSet.of(),
- addStartFunction,
- addFinishFunction,
- pCollectionIdsToConsumers);
- return runner;
+ public FnApiDoFnRunner<InputT, OutputT> createRunner(Context<InputT, OutputT> context) {
+ return new FnApiDoFnRunner<>(context);
}
}
- static class NewFactory<InputT, OutputT>
- implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
- @Override
- public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
-
- DoFn<InputT, OutputT> doFn;
- TupleTag<OutputT> mainOutputTag;
- Coder<InputT> inputCoder;
- WindowingStrategy<InputT, ?> windowingStrategy;
-
- ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
- ImmutableMap.builder();
- ParDoPayload parDoPayload;
- try {
- RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(
- RunnerApi.Components.newBuilder()
- .putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
- parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
- doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
- mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
- String mainInputTag = Iterables.getOnlyElement(Sets.difference(
- pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
- RunnerApi.PCollection mainInput =
- pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
- inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
- mainInput.getCoderId());
- windowingStrategy = (WindowingStrategy) rehydratedComponents.getWindowingStrategy(
- mainInput.getWindowingStrategyId());
-
- // Build the map from tag id to side input specification
- for (Map.Entry<String, RunnerApi.SideInput> entry
- : parDoPayload.getSideInputsMap().entrySet()) {
- String sideInputTag = entry.getKey();
- RunnerApi.SideInput sideInput = entry.getValue();
- checkArgument(
- Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
- sideInput.getAccessPattern().getUrn()),
- "This SDK is only capable of dealing with %s materializations "
- + "but was asked to handle %s for PCollectionView with tag %s.",
- Materializations.MULTIMAP_MATERIALIZATION_URN,
- sideInput.getAccessPattern().getUrn(),
- sideInputTag);
-
- RunnerApi.PCollection sideInputPCollection =
- pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
- WindowingStrategy sideInputWindowingStrategy =
- rehydratedComponents.getWindowingStrategy(
- sideInputPCollection.getWindowingStrategyId());
- tagToSideInputSpecMap.put(
- new TupleTag<>(entry.getKey()),
- SideInputSpec.create(
- rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
- sideInputWindowingStrategy.getWindowFn().windowCoder(),
- PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
- PCollectionViewTranslation.windowMappingFnFromProto(
- entry.getValue().getWindowMappingFn())));
- }
- } catch (InvalidProtocolBufferException exn) {
- throw new IllegalArgumentException("Malformed ParDoPayload", exn);
- } catch (IOException exn) {
- throw new IllegalArgumentException("Malformed ParDoPayload", exn);
- }
-
- ImmutableListMultimap.Builder<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
- tagToConsumerBuilder = ImmutableListMultimap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- tagToConsumerBuilder.putAll(
- new TupleTag<>(entry.getKey()), pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ListMultimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> tagToConsumer =
- tagToConsumerBuilder.build();
-
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
- pipelineOptions,
- beamFnStateClient,
- ptransformId,
- processBundleInstructionId,
- doFn,
- inputCoder,
- (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
- tagToConsumer.get(mainOutputTag),
- tagToConsumer,
- tagToSideInputSpecMap.build(),
- windowingStrategy);
- registerHandlers(
- runner,
- pTransform,
- parDoPayload.getSideInputsMap().keySet(),
- addStartFunction,
- addFinishFunction,
- pCollectionIdsToConsumers);
- return runner;
- }
- }
-
- private static <InputT, OutputT> void registerHandlers(
- DoFnRunner<InputT, OutputT> runner,
- RunnerApi.PTransform pTransform,
- Set<String> sideInputLocalNames,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction,
- Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers) {
- // Register the appropriate handlers.
- addStartFunction.accept(runner::startBundle);
- for (String localInputName
- : Sets.difference(pTransform.getInputsMap().keySet(), sideInputLocalNames)) {
- pCollectionIdsToConsumers.put(
- pTransform.getInputsOrThrow(localInputName),
- (FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
- }
- addFinishFunction.accept(runner::finishBundle);
- }
-
//////////////////////////////////////////////////////////////////////////////////////////////////
- private final PipelineOptions pipelineOptions;
- private final BeamFnStateClient beamFnStateClient;
- private final String ptransformId;
- private final Supplier<String> processBundleInstructionId;
- private final DoFn<InputT, OutputT> doFn;
- private final Coder<InputT> inputCoder;
+ private final Context<InputT, OutputT> context;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
- private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
- private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
- private final Map<StateKey, Object> stateKeyObjectCache;
- private final WindowingStrategy windowingStrategy;
+ private FnApiStateAccessor stateAccessor;
private final DoFnSignature doFnSignature;
private final DoFnInvoker<InputT, OutputT> doFnInvoker;
- private final StateBinder stateBinder;
- private final Collection<ThrowingRunnable> stateFinalizers;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
+ private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+ private final ProcessBundleContext processContext;
+ private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+ /** Only valid during {@link #processElement}, null otherwise. */
private WindowedValue<InputT> currentElement;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
private BoundedWindow currentWindow;
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and only when processing a {@link KV} and is null otherwise.
- */
- private ByteString encodedCurrentKey;
-
- /**
- * The lifetime of this member is only valid during {@link #processElement}
- * and is null otherwise.
- */
- private ByteString encodedCurrentWindow;
-
- FnApiDoFnRunner(
- PipelineOptions pipelineOptions,
- BeamFnStateClient beamFnStateClient,
- String ptransformId,
- Supplier<String> processBundleInstructionId,
- DoFn<InputT, OutputT> doFn,
- Coder<InputT> inputCoder,
- Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
- Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
- Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
- WindowingStrategy windowingStrategy) {
- this.pipelineOptions = pipelineOptions;
- this.beamFnStateClient = beamFnStateClient;
- this.ptransformId = ptransformId;
- this.processBundleInstructionId = processBundleInstructionId;
- this.doFn = doFn;
- this.inputCoder = inputCoder;
- this.mainOutputConsumers = mainOutputConsumers;
- this.outputMap = outputMap;
- this.sideInputSpecMap = sideInputSpecMap;
- this.stateKeyObjectCache = new HashMap<>();
- this.windowingStrategy = windowingStrategy;
- this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
- this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
+ FnApiDoFnRunner(Context<InputT, OutputT> context) {
+ this.context = context;
+ this.mainOutputConsumers =
+ (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+ (Collection) context.tagToConsumer.get(context.mainOutputTag);
+ this.doFnSignature = DoFnSignatures.signatureForDoFn(context.doFn);
+ this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
this.doFnInvoker.invokeSetup();
- this.stateBinder = new BeamFnStateBinder();
- this.stateFinalizers = new ArrayList<>();
- this.startBundleContext = doFn.new StartBundleContext() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
- };
+ this.startBundleContext =
+ this.context.doFn.new StartBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
+ };
this.processContext = new ProcessBundleContext();
- finishBundleContext = doFn.new FinishBundleContext() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
+ finishBundleContext =
+ this.context.doFn.new FinishBundleContext() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return context.pipelineOptions;
+ }
- @Override
- public void output(OutputT output, Instant timestamp, BoundedWindow window) {
- outputTo(mainOutputConsumers,
- WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
- }
+ @Override
+ public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+ outputTo(
+ mainOutputConsumers,
+ WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+ }
- @Override
- public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
- if (consumers == null) {
- throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
- }
- outputTo(consumers,
- WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
- }
- };
+ @Override
+ public <T> void output(
+ TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
+ if (consumers == null) {
+ throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
+ }
+ outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
+ }
+ };
}
@Override
public void startBundle() {
+ this.stateAccessor =
+ new FnApiStateAccessor(
+ context.pipelineOptions,
+ context.ptransformId,
+ context.processBundleInstructionId,
+ context.tagToSideInputSpecMap,
+ context.beamFnStateClient,
+ context.keyCoder,
+ (Coder<BoundedWindow>) context.windowCoder);
+
doFnInvoker.invokeStartBundle(startBundleContext);
}
@Override
public void processElement(WindowedValue<InputT> elem) {
currentElement = elem;
+ stateAccessor.setCurrentElement(elem);
try {
Iterator<BoundedWindow> windowIterator =
(Iterator<BoundedWindow>) elem.getWindows().iterator();
while (windowIterator.hasNext()) {
currentWindow = windowIterator.next();
+ stateAccessor.setCurrentWindow(currentWindow);
doFnInvoker.invokeProcessElement(processContext);
}
} finally {
currentElement = null;
currentWindow = null;
- encodedCurrentKey = null;
- encodedCurrentWindow = null;
+ stateAccessor.setCurrentElement(null);
+ stateAccessor.setCurrentWindow(null);
}
}
@Override
- public void onTimer(
- String timerId,
- BoundedWindow window,
- Instant timestamp,
- TimeDomain timeDomain) {
- throw new UnsupportedOperationException("TODO: Add support for timers");
- }
-
- @Override
public void finishBundle() {
doFnInvoker.invokeFinishBundle(finishBundleContext);
- // Persist all dirty state cells
- try {
- for (ThrowingRunnable runnable : stateFinalizers) {
- runnable.run();
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new IllegalStateException(e);
- } catch (Exception e) {
- throw new IllegalStateException(e);
- }
-
// TODO: Support caching state data across bundle boundaries.
- stateKeyObjectCache.clear();
+ this.stateAccessor.finalizeState();
+ this.stateAccessor = null;
}
- @Override
- public DoFn<InputT, OutputT> getFn() {
- return doFnInvoker.getFn();
- }
-
- /**
- * Outputs the given element to the specified set of consumers wrapping any exceptions.
- */
+ /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
private <T> void outputTo(
- Collection<FnDataReceiver<WindowedValue<T>>> consumers,
- WindowedValue<T> output) {
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
try {
for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
consumer.accept(output);
@@ -499,12 +198,11 @@
/**
* Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}.
*/
- private class ProcessBundleContext
- extends DoFn<InputT, OutputT>.ProcessContext
+ private class ProcessBundleContext extends DoFn<InputT, OutputT>.ProcessContext
implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
private ProcessBundleContext() {
- doFn.super();
+ context.doFn.super();
}
@Override
@@ -577,11 +275,11 @@
checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
StateSpec<?> spec;
try {
- spec = (StateSpec<?>) stateDeclaration.field().get(doFn);
+ spec = (StateSpec<?>) stateDeclaration.field().get(context.doFn);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
- return spec.bind(stateId, stateBinder);
+ return spec.bind(stateId, stateAccessor);
}
@Override
@@ -591,60 +289,51 @@
@Override
public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
+ return context.pipelineOptions;
}
@Override
public PipelineOptions pipelineOptions() {
- return pipelineOptions;
+ return context.pipelineOptions;
}
@Override
public void output(OutputT output) {
- outputTo(mainOutputConsumers,
+ outputTo(
+ mainOutputConsumers,
WindowedValue.of(
- output,
- currentElement.getTimestamp(),
- currentWindow,
- currentElement.getPane()));
+ output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
}
@Override
public void outputWithTimestamp(OutputT output, Instant timestamp) {
- outputTo(mainOutputConsumers,
- WindowedValue.of(
- output,
- timestamp,
- currentWindow,
- currentElement.getPane()));
+ outputTo(
+ mainOutputConsumers,
+ WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}
@Override
public <T> void output(TupleTag<T> tag, T output) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
- outputTo(consumers,
+ outputTo(
+ consumers,
WindowedValue.of(
- output,
- currentElement.getTimestamp(),
- currentWindow,
- currentElement.getPane()));
+ output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
}
@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
- Collection<FnDataReceiver<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag);
+ Collection<FnDataReceiver<WindowedValue<T>>> consumers =
+ (Collection) context.tagToConsumer.get(tag);
if (consumers == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
- outputTo(consumers,
- WindowedValue.of(
- output,
- timestamp,
- currentWindow,
- currentElement.getPane()));
+ outputTo(
+ consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}
@Override
@@ -654,7 +343,7 @@
@Override
public <T> T sideInput(PCollectionView<T> view) {
- return (T) bindSideInputView(view.getTagInternal());
+ return stateAccessor.get(view, currentWindow);
}
@Override
@@ -672,359 +361,4 @@
throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
}
}
-
- /**
- * A {@link StateBinder} that uses the Beam Fn State API to read and write user state.
- *
- * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that
- * {@link #bindWatermark} should never be implemented.
- */
- private class BeamFnStateBinder implements StateBinder {
- @Override
- public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
- return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- return new ValueState<T>() {
- private final BagUserState<T> impl = createBagUserState(id, coder);
-
- @Override
- public void clear() {
- impl.clear();
- }
-
- @Override
- public void write(T input) {
- impl.clear();
- impl.append(input);
- }
-
- @Override
- public T read() {
- Iterator<T> value = impl.get().iterator();
- if (value.hasNext()) {
- return value.next();
- } else {
- return null;
- }
- }
-
- @Override
- public ValueState<T> readLater() {
- // TODO: Support prefetching.
- return this;
- }
- };
- }
- });
- }
-
- @Override
- public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
- return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- return new BagState<T>() {
- private final BagUserState<T> impl = createBagUserState(id, elemCoder);
-
- @Override
- public void add(T value) {
- impl.append(value);
- }
-
- @Override
- public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
-
- @Override
- public Iterable<T> read() {
- return impl.get();
- }
-
- @Override
- public BagState<T> readLater() {
- // TODO: Support prefetching.
- return this;
- }
-
- @Override
- public void clear() {
- impl.clear();
- }
- };
- }
- });
- }
-
- @Override
- public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
- throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
- }
-
- @Override
- public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id,
- StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder,
- Coder<ValueT> mapValueCoder) {
- throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
- }
-
- @Override
- public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
- String id,
- StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec, Coder<AccumT> accumCoder,
- CombineFn<ElementT, AccumT, ResultT> combineFn) {
- return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- new Function<StateKey, Object>() {
- @Override
- public Object apply(StateKey key) {
- // TODO: Support squashing accumulators depending on whether we know of all
- // remote accumulators and local accumulators or just local accumulators.
- return new CombiningState<ElementT, AccumT, ResultT>() {
- private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
-
- @Override
- public AccumT getAccum() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return iterator.next();
- }
- return combineFn.createAccumulator();
- }
-
- @Override
- public void addAccum(AccumT accum) {
- Iterator<AccumT> iterator = impl.get().iterator();
-
- // Only merge if there was a prior value
- if (iterator.hasNext()) {
- accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
- // Since there was a prior value, we need to clear.
- impl.clear();
- }
-
- impl.append(accum);
- }
-
- @Override
- public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- return combineFn.mergeAccumulators(accumulators);
- }
-
- @Override
- public CombiningState<ElementT, AccumT, ResultT> readLater() {
- return this;
- }
-
- @Override
- public ResultT read() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return combineFn.extractOutput(iterator.next());
- }
- return combineFn.defaultValue();
- }
-
- @Override
- public void add(ElementT value) {
- AccumT newAccumulator = combineFn.addInput(getAccum(), value);
- impl.clear();
- impl.append(newAccumulator);
- }
-
- @Override
- public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
-
- @Override
- public void clear() {
- impl.clear();
- }
- };
- }
- });
- }
-
- @Override
- public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT>
- bindCombiningWithContext(
- String id,
- StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
- Coder<AccumT> accumCoder,
- CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
- return (CombiningState<ElementT, AccumT, ResultT>) stateKeyObjectCache.computeIfAbsent(
- createBagUserStateKey(id),
- key -> bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn,
- new StateContext<BoundedWindow>() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
-
- @Override
- public <T> T sideInput(PCollectionView<T> view) {
- return (T) bindSideInputView(view.getTagInternal());
- }
-
- @Override
- public BoundedWindow window() {
- return currentWindow;
- }
- })));
- }
-
- /**
- * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing
- * and is waiting on resolution of BEAM-2535.
- */
- @Override
- @Deprecated
- public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec,
- TimestampCombiner timestampCombiner) {
- throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
- }
-
- private <T> BagUserState<T> createBagUserState(
- String stateId, Coder<T> valueCoder) {
- BagUserState<T> rval = new BagUserState<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- stateId,
- encodedCurrentWindow,
- encodedCurrentKey,
- valueCoder);
- stateFinalizers.add(rval::asyncClose);
- return rval;
- }
- }
-
- private StateKey createBagUserStateKey(String stateId) {
- cacheEncodedKeyAndWindowForKeyedContext();
- StateKey.Builder builder = StateKey.newBuilder();
- builder.getBagUserStateBuilder()
- .setWindow(encodedCurrentWindow)
- .setKey(encodedCurrentKey)
- .setPtransformId(ptransformId)
- .setUserStateId(stateId);
- return builder.build();
- }
-
- /**
- * Memoizes an encoded key and window for the current element being processed saving on the
- * encoding cost of the key and window across multiple state cells for the lifetime of
- * {@link #processElement}.
- *
- * <p>This should only be called during {@link #processElement}.
- */
- private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
- if (encodedCurrentKey == null) {
- checkState(currentElement.getValue() instanceof KV,
- "Accessing state in unkeyed context. Current element is not a KV: %s.",
- currentElement);
- checkState(
- // TODO: Stop passing windowed value coders within PCollections.
- inputCoder instanceof KvCoder
- || (inputCoder instanceof WindowedValueCoder
- && (((WindowedValueCoder) inputCoder).getValueCoder() instanceof KvCoder)),
- "Accessing state in unkeyed context. Keyed coder expected but found %s.",
- inputCoder);
-
- ByteString.Output encodedKeyOut = ByteString.newOutput();
-
- Coder<K> keyCoder = inputCoder instanceof WindowedValueCoder
- ? ((KvCoder<K, ?>) ((WindowedValueCoder) inputCoder).getValueCoder()).getKeyCoder()
- : ((KvCoder<K, ?>) inputCoder).getKeyCoder();
- try {
- keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- encodedCurrentKey = encodedKeyOut.toByteString();
- }
-
- if (encodedCurrentWindow == null) {
- ByteString.Output encodedWindowOut = ByteString.newOutput();
- try {
- windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- encodedCurrentWindow = encodedWindowOut.toByteString();
- }
- }
-
- /**
- * A specification for side inputs containing a value {@link Coder},
- * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
- * @param <W>
- */
- @AutoValue
- abstract static class SideInputSpec<W extends BoundedWindow> {
- static <W extends BoundedWindow> SideInputSpec create(
- Coder<?> coder,
- Coder<W> windowCoder,
- ViewFn<?, ?> viewFn,
- WindowMappingFn<W> windowMappingFn) {
- return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
- coder, windowCoder, viewFn, windowMappingFn);
- }
-
- abstract Coder<?> getCoder();
-
- abstract Coder<W> getWindowCoder();
-
- abstract ViewFn<?, ?> getViewFn();
-
- abstract WindowMappingFn<W> getWindowMappingFn();
- }
-
- private <K, V> Object bindSideInputView(TupleTag<?> view) {
- SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
- checkArgument(sideInputSpec != null,
- "Attempting to access unknown side input %s.",
- view);
- KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
-
- ByteString.Output encodedWindowOut = ByteString.newOutput();
- try {
- sideInputSpec.getWindowCoder().encode(
- sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow), encodedWindowOut);
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- ByteString encodedWindow = encodedWindowOut.toByteString();
-
- StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
- cacheKeyBuilder.getMultimapSideInputBuilder()
- .setPtransformId(ptransformId)
- .setSideInputId(view.getId())
- .setWindow(encodedWindow);
- return stateKeyObjectCache.computeIfAbsent(
- cacheKeyBuilder.build(),
- key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
- view.getId(), encodedWindow, kvCoder.getKeyCoder(), kvCoder.getValueCoder())));
- }
-
- private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
- String sideInputId,
- ByteString encodedWindow,
- Coder<K> keyCoder,
- Coder<V> valueCoder) {
-
- return new MultimapSideInput<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- sideInputId,
- encodedWindow,
- keyCoder,
- valueCoder);
- }
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
index 7819726..e2c61de 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.MultiplexingFnDataReceiver;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
@@ -109,7 +110,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
Collection<FnDataReceiver<WindowedValue<OutputT>>> consumers =
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
index c130d4d..efd4a38 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java
@@ -22,6 +22,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -37,7 +38,6 @@
* A factory able to instantiate an appropriate handler for a given PTransform.
*/
public interface PTransformRunnerFactory<T> {
-
/**
* Creates and returns a handler for a given PTransform. Note that the handler must support
* processing multiple bundles. The handler will be discarded if an error is thrown during element
@@ -60,6 +60,7 @@
* registered within this multimap.
* @param addStartFunction A consumer to register a start bundle handler with.
* @param addFinishFunction A consumer to register a finish bundle handler with.
+ * @param splitListener A listener to be invoked when the PTransform splits itself.
*/
T createRunnerForPTransform(
PipelineOptions pipelineOptions,
@@ -73,7 +74,8 @@
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException;
/**
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
new file mode 100644
index 0000000..c8cac40
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
+import org.apache.beam.fn.harness.state.FnApiStateAccessor;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.SplittableProcessElementInvoker;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Runs the {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} transform.
+ *
+ * <p>Each input is a {@code KV<element, restriction>} that must be in exactly one window. The
+ * {@link DoFn} processes the restriction until it finishes or the invoker checkpoints it (after
+ * {@code MAX_OUTPUTS_PER_ELEMENT} outputs or {@code MAX_PROCESSING_DURATION}). On a checkpoint,
+ * the processed (primary) and remaining (residual) restrictions are reported as a split through
+ * {@code context.splitListener}.
+ */
+public class SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>
+    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<KV<InputT, RestrictionT>> {
+  /**
+   * Maximum number of outputs a single {@code @ProcessElement} call may produce before the
+   * invoker checkpoints it.
+   */
+  private static final int MAX_OUTPUTS_PER_ELEMENT = 10000;
+
+  /** Maximum time a single {@code @ProcessElement} call may run before it is checkpointed. */
+  private static final Duration MAX_PROCESSING_DURATION = Duration.standardSeconds(10);
+
+  /**
+   * Schedules the time-based checkpoints requested by the invoker. Shared by all runners and
+   * backed by a daemon thread because this class has no teardown hook to shut an executor down:
+   * a per-runner executor would leak a non-daemon thread per runner and could keep the JVM alive.
+   */
+  private static final ScheduledExecutorService CHECKPOINT_EXECUTOR =
+      Executors.newSingleThreadScheduledExecutor(
+          runnable -> {
+            Thread thread = new Thread(runnable, "sdf-checkpoint-scheduler");
+            thread.setDaemon(true);
+            return thread;
+          });
+
+  /** A registrar which provides a factory for the splittable process-elements transform. */
+  @AutoService(PTransformRunnerFactory.Registrar.class)
+  public static class Registrar implements PTransformRunnerFactory.Registrar {
+    @Override
+    public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+      return ImmutableMap.of(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN, new Factory());
+    }
+  }
+
+  /** Builds a {@link SplittableProcessElementsRunner} from a transform's {@link Context}. */
+  static class Factory<InputT, RestrictionT, OutputT>
+      extends DoFnPTransformRunnerFactory<
+          KV<InputT, RestrictionT>,
+          InputT,
+          OutputT,
+          SplittableProcessElementsRunner<InputT, RestrictionT, OutputT>> {
+
+    @Override
+    SplittableProcessElementsRunner<InputT, RestrictionT, OutputT> createRunner(
+        Context<InputT, OutputT> context) {
+      // NOTE(review): assumes context.inputCoder encodes KV<element, restriction> (unchecked cast).
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> windowedCoder =
+          FullWindowedValueCoder.of(
+              (Coder<KV<InputT, RestrictionT>>) context.inputCoder, context.windowCoder);
+
+      return new SplittableProcessElementsRunner<>(
+          context,
+          windowedCoder,
+          (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
+              (Collection) context.tagToConsumer.get(context.mainOutputTag),
+          Iterables.getOnlyElement(context.pTransform.getInputsMap().keySet()));
+    }
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private final Context<InputT, OutputT> context;
+  /** Local name of the single main input, used in reported split applications. */
+  private final String mainInputId;
+  /** Encodes primary and residual {@code KV<element, restriction>} values for split reports. */
+  private final Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder;
+  private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
+  private final DoFnInvoker<InputT, OutputT> doFnInvoker;
+  /** Forwards the invoker's outputs to the consumers registered for each output tag. */
+  private final OutputWindowedValue<OutputT> outputWindowedValue;
+
+  /** Only valid between {@link #startBundle} and {@link #finishBundle}, null otherwise. */
+  private FnApiStateAccessor stateAccessor;
+
+  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
+  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;
+
+  SplittableProcessElementsRunner(
+      Context<InputT, OutputT> context,
+      Coder<WindowedValue<KV<InputT, RestrictionT>>> inputCoder,
+      Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
+      String mainInputId) {
+    this.context = context;
+    this.mainInputId = mainInputId;
+    this.inputCoder = inputCoder;
+    this.mainOutputConsumers = mainOutputConsumers;
+    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
+    this.doFnInvoker.invokeSetup();
+
+    this.outputWindowedValue =
+        new OutputWindowedValue<OutputT>() {
+          @Override
+          public void outputWindowedValue(
+              OutputT output,
+              Instant timestamp,
+              Collection<? extends BoundedWindow> windows,
+              PaneInfo pane) {
+            outputTo(mainOutputConsumers, WindowedValue.of(output, timestamp, windows, pane));
+          }
+
+          @Override
+          public <AdditionalOutputT> void outputWindowedValue(
+              TupleTag<AdditionalOutputT> tag,
+              AdditionalOutputT output,
+              Instant timestamp,
+              Collection<? extends BoundedWindow> windows,
+              PaneInfo pane) {
+            Collection<FnDataReceiver<WindowedValue<AdditionalOutputT>>> consumers =
+                (Collection) context.tagToConsumer.get(tag);
+            if (consumers == null) {
+              throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
+            }
+            outputTo(consumers, WindowedValue.of(output, timestamp, windows, pane));
+          }
+        };
+
+    this.startBundleContext =
+        context.doFn.new StartBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+        };
+    // Emitting output from @FinishBundle is not supported by this runner.
+    this.finishBundleContext =
+        context.doFn.new FinishBundleContext() {
+          @Override
+          public PipelineOptions getPipelineOptions() {
+            return context.pipelineOptions;
+          }
+
+          @Override
+          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public <T> void output(
+              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
+            throw new UnsupportedOperationException();
+          }
+        };
+  }
+
+  @Override
+  public void startBundle() {
+    this.stateAccessor =
+        new FnApiStateAccessor(
+            context.pipelineOptions,
+            context.ptransformId,
+            context.processBundleInstructionId,
+            context.tagToSideInputSpecMap,
+            context.beamFnStateClient,
+            context.keyCoder,
+            (Coder<BoundedWindow>) context.windowCoder);
+    doFnInvoker.invokeStartBundle(startBundleContext);
+  }
+
+  @Override
+  public void processElement(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    processElementTyped(elem);
+  }
+
+  /**
+   * Processes one {@code KV<element, restriction>} and reports a split if the invoker
+   * checkpointed it. Kept separate from {@link #processElement} to capture the tracker's type
+   * parameters.
+   */
+  private <PositionT, TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
+      void processElementTyped(WindowedValue<KV<InputT, RestrictionT>> elem) {
+    checkArgument(
+        elem.getWindows().size() == 1,
+        "SPLITTABLE_PROCESS_ELEMENTS expects its input to be in 1 window, but got %s windows",
+        elem.getWindows().size());
+    stateAccessor.setCurrentWindow(elem.getWindows().iterator().next());
+    try {
+      WindowedValue<InputT> element = elem.withValue(elem.getValue().getKey());
+      TrackerT tracker = doFnInvoker.invokeNewTracker(elem.getValue().getValue());
+      OutputAndTimeBoundedSplittableProcessElementInvoker<
+              InputT, OutputT, RestrictionT, PositionT, TrackerT>
+          processElementInvoker =
+              new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
+                  context.doFn,
+                  context.pipelineOptions,
+                  outputWindowedValue,
+                  stateAccessor,
+                  CHECKPOINT_EXECUTOR,
+                  MAX_OUTPUTS_PER_ELEMENT,
+                  MAX_PROCESSING_DURATION);
+
+      SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, TrackerT>.Result result =
+          processElementInvoker.invokeProcessElement(doFnInvoker, element, tracker);
+      if (result.getContinuation().shouldResume()) {
+        reportCheckpoint(
+            element,
+            tracker.currentRestriction(),
+            result.getResidualRestriction(),
+            result.getContinuation().resumeDelay());
+      }
+    } finally {
+      // The current window is only meaningful while this element is being processed.
+      stateAccessor.setCurrentWindow(null);
+    }
+  }
+
+  /**
+   * Reports a checkpoint of {@code element} to the split listener: the primary root carries the
+   * restriction that was processed in this bundle, the residual root carries the remaining
+   * restriction, to be resumed after {@code resumeDelay}.
+   */
+  private void reportCheckpoint(
+      WindowedValue<InputT> element,
+      RestrictionT primaryRestriction,
+      RestrictionT residualRestriction,
+      Duration resumeDelay) {
+    Application primaryApplication =
+        toApplication(element.withValue(KV.of(element.getValue(), primaryRestriction)));
+    Application residualApplication =
+        toApplication(element.withValue(KV.of(element.getValue(), residualRestriction)));
+    context.splitListener.split(
+        ImmutableList.of(primaryApplication),
+        ImmutableList.of(
+            DelayedApplication.newBuilder()
+                .setApplication(residualApplication)
+                // The Fn API expresses the delay in (fractional) seconds.
+                .setDelaySec(0.001 * resumeDelay.getMillis())
+                .build()));
+  }
+
+  /** Encodes {@code value} as an {@link Application} of this transform's main input. */
+  private Application toApplication(WindowedValue<KV<InputT, RestrictionT>> value) {
+    ByteString.Output bytes = ByteString.newOutput();
+    try {
+      inputCoder.encode(value, bytes);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    return Application.newBuilder()
+        .setPtransformId(context.ptransformId)
+        .setInputId(mainInputId)
+        .setElement(bytes.toByteString())
+        .build();
+  }
+
+  /**
+   * Calls {@code @FinishBundle}, then finalizes and releases the state accessor created in
+   * {@link #startBundle}, as {@code FnApiDoFnRunner#finishBundle} does.
+   */
+  @Override
+  public void finishBundle() {
+    doFnInvoker.invokeFinishBundle(finishBundleContext);
+    stateAccessor.finalizeState();
+    stateAccessor = null;
+  }
+
+  /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
+  private <T> void outputTo(
+      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
+    try {
+      for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
+        consumer.accept(output);
+      }
+    } catch (Throwable t) {
+      throw UserCodeException.wrap(t);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
new file mode 100644
index 0000000..db73447
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import java.util.List;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
+
+/**
+ * Listens to splits happening to a single bundle. See <a
+ * href="https://s.apache.org/beam-breaking-fusion">Breaking the Fusion Barrier</a> for a
+ * discussion of the design.
+ *
+ * <p>This is a functional interface so that listeners can be supplied as lambdas.
+ */
+@FunctionalInterface
+public interface BundleSplitListener {
+  /**
+   * Signals that the current bundle should be split into the given set of primary and residual
+   * roots.
+   *
+   * <p>Primary roots are the new decomposition of the bundle's work into transform applications
+   * that have happened or will happen as part of this bundle (modulo future splits). Residual roots
+   * are a decomposition of work that has been given away by the bundle, so the runner must delegate
+   * it for someone else to execute.
+   *
+   * @param primaryRoots the work that remains part of the current bundle
+   * @param residualRoots the work given away by the current bundle, each of which may carry a delay
+   *     before it should be resumed
+   */
+  void split(List<Application> primaryRoots, List<DelayedApplication> residualRoots);
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 64d713e..ee907a8 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -46,8 +46,12 @@
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.Application;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleSplit.DelayedApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
@@ -144,7 +148,8 @@
SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
ListMultimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction)
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener)
throws IOException {
// Recursively ensure that all consumers of the output PCollection have been created.
@@ -166,7 +171,8 @@
pCollectionIdsToConsumingPTransforms,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction);
+ addFinishFunction,
+ splitListener);
}
}
@@ -196,15 +202,12 @@
processBundleDescriptor.getWindowingStrategiesMap(),
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction);
+ addFinishFunction,
+ splitListener);
}
public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
throws Exception {
- BeamFnApi.InstructionResponse.Builder response =
- BeamFnApi.InstructionResponse.newBuilder()
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
-
String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
(BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);
@@ -223,6 +226,8 @@
}
}
+ ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder();
+
// Instantiate a State API call handler depending on whether a State Api service descriptor
// was specified.
try (HandleStateCallsForBundle beamFnStateClient =
@@ -230,6 +235,23 @@
? new BlockTillStateCallsFinish(beamFnStateGrpcClientCache.forApiServiceDescriptor(
bundleDescriptor.getStateApiServiceDescriptor()))
: new FailAllStateCallsForBundle(request.getProcessBundle())) {
+ Multimap<String, Application> allPrimaries = ArrayListMultimap.create();
+ Multimap<String, DelayedApplication> allResiduals = ArrayListMultimap.create();
+ BundleSplitListener splitListener =
+ (List<Application> primaries, List<DelayedApplication> residuals) -> {
+ // Reset primaries and accumulate residuals.
+ Multimap<String, Application> newPrimaries = ArrayListMultimap.create();
+ for (Application primary : primaries) {
+ newPrimaries.put(primary.getPtransformId(), primary);
+ }
+ allPrimaries.clear();
+ allPrimaries.putAll(newPrimaries);
+
+ for (DelayedApplication residual : residuals) {
+ allResiduals.put(residual.getApplication().getPtransformId(), residual);
+ }
+ };
+
// Create a BeamFnStateClient
for (Map.Entry<String, RunnerApi.PTransform> entry
: bundleDescriptor.getTransformsMap().entrySet()) {
@@ -251,7 +273,8 @@
pCollectionIdsToConsumingPTransforms,
pCollectionIdsToConsumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ splitListener);
}
// Already in reverse topological order so we don't need to do anything.
@@ -265,9 +288,16 @@
LOG.debug("Finishing function {}", finishFunction);
finishFunction.run();
}
+ if (!allPrimaries.isEmpty()) {
+ response.setSplit(
+ BundleSplit.newBuilder()
+ .addAllPrimaryRoots(allPrimaries.values())
+ .addAllResidualRoots(allResiduals.values())
+ .build());
+ }
}
- return response;
+ return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response);
}
/**
@@ -355,7 +385,8 @@
Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) {
String message =
String.format(
"No factory registered for %s, known factories %s",
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
new file mode 100644
index 0000000..e7a043b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -0,0 +1,535 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.sdk.state.SetState;
+import org.apache.beam.sdk.state.StateBinder;
+import org.apache.beam.sdk.state.StateContext;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.CombineFnUtil;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ * Provides access to side inputs and state via a {@link BeamFnStateClient}.
+ *
+ * <p>Side input views and user state cells are cached by their {@link StateKey} for the lifetime
+ * of this accessor. User state is scoped to the key of the current element, which must be a
+ * {@link KV}, and to the current window; {@link #finalizeState} must be called to close every
+ * user state cell that was created.
+ */
+public class FnApiStateAccessor implements SideInputReader, StateBinder {
+  private final PipelineOptions pipelineOptions;
+  private final Map<StateKey, Object> stateKeyObjectCache;
+  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+  private final BeamFnStateClient beamFnStateClient;
+  private final String ptransformId;
+  private final Supplier<String> processBundleInstructionId;
+  private final Collection<ThrowingRunnable> stateFinalizers;
+
+  private final Coder<?> keyCoder;
+  private final Coder<BoundedWindow> windowCoder;
+
+  private WindowedValue<?> currentElement;
+  private BoundedWindow currentWindow;
+  // Lazily computed encodings of the current key and window. Each is reset to null by the
+  // corresponding setter and recomputed on the next user state access.
+  private ByteString encodedCurrentKey;
+  private ByteString encodedCurrentWindow;
+
+  /**
+   * Creates an accessor serving the side inputs and user state of a single transform.
+   *
+   * @param pipelineOptions options exposed to {@link CombineFnWithContext}s via their context
+   * @param ptransformId the id of the transform whose side inputs and state are accessed
+   * @param processBundleInstructionId supplies the id of the bundle instruction being processed
+   * @param sideInputSpecMap the specifications of the accessible side inputs, keyed by tag
+   * @param beamFnStateClient the client used to access side inputs and user state
+   * @param keyCoder the coder of element keys; if {@code null}, accessing user state fails
+   * @param windowCoder the coder used to encode the current window into user state keys
+   */
+  public FnApiStateAccessor(
+      PipelineOptions pipelineOptions,
+      String ptransformId,
+      Supplier<String> processBundleInstructionId,
+      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
+      BeamFnStateClient beamFnStateClient,
+      Coder<?> keyCoder,
+      Coder<BoundedWindow> windowCoder) {
+    this.pipelineOptions = pipelineOptions;
+    this.stateKeyObjectCache = Maps.newHashMap();
+    this.sideInputSpecMap = sideInputSpecMap;
+    this.beamFnStateClient = beamFnStateClient;
+    this.ptransformId = ptransformId;
+    this.processBundleInstructionId = processBundleInstructionId;
+    this.stateFinalizers = new ArrayList<>();
+
+    this.keyCoder = keyCoder;
+    this.windowCoder = windowCoder;
+  }
+
+  /**
+   * Sets the element currently being processed. The encoding of its key is computed lazily on the
+   * next user state access.
+   */
+  public void setCurrentElement(WindowedValue<?> currentElement) {
+    this.currentElement = currentElement;
+    this.encodedCurrentKey = null;
+  }
+
+  /**
+   * Sets the window currently being processed. Its encoding is computed lazily on the next user
+   * state access.
+   */
+  public void setCurrentWindow(BoundedWindow currentWindow) {
+    this.currentWindow = currentWindow;
+    this.encodedCurrentWindow = null;
+  }
+
+  /**
+   * Returns the value of {@code view} in the side input window that {@code window} maps to.
+   * Materialized views are cached per side input and side input window.
+   *
+   * @throws IllegalArgumentException if no side input is registered for the view's tag
+   */
+  @Override
+  @Nullable
+  public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+    TupleTag<?> tag = view.getTagInternal();
+
+    SideInputSpec sideInputSpec = sideInputSpecMap.get(tag);
+    checkArgument(sideInputSpec != null, "Attempting to access unknown side input %s.", view);
+    KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
+
+    ByteString.Output encodedWindowOut = ByteString.newOutput();
+    try {
+      sideInputSpec
+          .getWindowCoder()
+          .encode(sideInputSpec.getWindowMappingFn().getSideInputWindow(window), encodedWindowOut);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+    ByteString encodedWindow = encodedWindowOut.toByteString();
+
+    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+    cacheKeyBuilder
+        .getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(tag.getId())
+        .setWindow(encodedWindow);
+    return (T)
+        stateKeyObjectCache.computeIfAbsent(
+            cacheKeyBuilder.build(),
+            key ->
+                sideInputSpec
+                    .getViewFn()
+                    .apply(
+                        new MultimapSideInput<>(
+                            beamFnStateClient,
+                            processBundleInstructionId.get(),
+                            ptransformId,
+                            tag.getId(),
+                            encodedWindow,
+                            kvCoder.getKeyCoder(),
+                            kvCoder.getValueCoder())));
+  }
+
+  /** Returns whether a side input is registered for the tag of {@code view}. */
+  @Override
+  public <T> boolean contains(PCollectionView<T> view) {
+    return sideInputSpecMap.containsKey(view.getTagInternal());
+  }
+
+  /** Returns whether this accessor has no side inputs registered. */
+  @Override
+  public boolean isEmpty() {
+    return sideInputSpecMap.isEmpty();
+  }
+
+  /**
+   * Binds a {@link ValueState} stored as a bag user state cell that holds at most one value;
+   * {@link ValueState#read()} returns {@code null} when the cell is empty.
+   */
+  @Override
+  public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
+    return (ValueState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                return new ValueState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, coder);
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+
+                  @Override
+                  public void write(T input) {
+                    impl.clear();
+                    impl.append(input);
+                  }
+
+                  @Override
+                  public T read() {
+                    Iterator<T> value = impl.get().iterator();
+                    if (value.hasNext()) {
+                      return value.next();
+                    } else {
+                      return null;
+                    }
+                  }
+
+                  @Override
+                  public ValueState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+                };
+              }
+            });
+  }
+
+  /** Binds a {@link BagState} stored directly as a bag user state cell. */
+  @Override
+  public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
+    return (BagState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                return new BagState<T>() {
+                  private final BagUserState<T> impl = createBagUserState(id, elemCoder);
+
+                  @Override
+                  public void add(T value) {
+                    impl.append(value);
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  @Override
+                  public Iterable<T> read() {
+                    return impl.get();
+                  }
+
+                  @Override
+                  public BagState<T> readLater() {
+                    // TODO: Support prefetching.
+                    return this;
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                };
+              }
+            });
+  }
+
+  /** Not yet supported; always throws {@link UnsupportedOperationException}. */
+  @Override
+  public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
+    throw new UnsupportedOperationException("TODO: Add support for a set state to the Fn API.");
+  }
+
+  /** Not yet supported; always throws {@link UnsupportedOperationException}. */
+  @Override
+  public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+      String id,
+      StateSpec<MapState<KeyT, ValueT>> spec,
+      Coder<KeyT> mapKeyCoder,
+      Coder<ValueT> mapValueCoder) {
+    throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API.");
+  }
+
+  /**
+   * Binds a {@link CombiningState} that stores a single accumulator in a bag user state cell. When
+   * the cell is empty, reads fall back to {@link CombineFn#createAccumulator()} and {@link
+   * CombineFn#defaultValue()}.
+   */
+  @Override
+  public <ElementT, AccumT, ResultT> CombiningState<ElementT, AccumT, ResultT> bindCombining(
+      String id,
+      StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+      Coder<AccumT> accumCoder,
+      CombineFn<ElementT, AccumT, ResultT> combineFn) {
+    return (CombiningState<ElementT, AccumT, ResultT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createBagUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                // TODO: Support squashing accumulators depending on whether we know of all
+                // remote accumulators and local accumulators or just local accumulators.
+                return new CombiningState<ElementT, AccumT, ResultT>() {
+                  private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
+
+                  @Override
+                  public AccumT getAccum() {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+                    if (iterator.hasNext()) {
+                      return iterator.next();
+                    }
+                    return combineFn.createAccumulator();
+                  }
+
+                  @Override
+                  public void addAccum(AccumT accum) {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+
+                    // Only merge if there was a prior value
+                    if (iterator.hasNext()) {
+                      accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
+                      // Since there was a prior value, we need to clear.
+                      impl.clear();
+                    }
+
+                    impl.append(accum);
+                  }
+
+                  @Override
+                  public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+                    return combineFn.mergeAccumulators(accumulators);
+                  }
+
+                  @Override
+                  public CombiningState<ElementT, AccumT, ResultT> readLater() {
+                    return this;
+                  }
+
+                  @Override
+                  public ResultT read() {
+                    Iterator<AccumT> iterator = impl.get().iterator();
+                    if (iterator.hasNext()) {
+                      return combineFn.extractOutput(iterator.next());
+                    }
+                    return combineFn.defaultValue();
+                  }
+
+                  @Override
+                  public void add(ElementT value) {
+                    AccumT newAccumulator = combineFn.addInput(getAccum(), value);
+                    impl.clear();
+                    impl.append(newAccumulator);
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                };
+              }
+            });
+  }
+
+  /**
+   * Binds a {@link CombiningState} for a {@link CombineFnWithContext}, whose context exposes this
+   * accessor's pipeline options, the current window, and side inputs in the current window.
+   */
+  @Override
+  public <ElementT, AccumT, ResultT>
+      CombiningState<ElementT, AccumT, ResultT> bindCombiningWithContext(
+          String id,
+          StateSpec<CombiningState<ElementT, AccumT, ResultT>> spec,
+          Coder<AccumT> accumCoder,
+          CombineFnWithContext<ElementT, AccumT, ResultT> combineFn) {
+    // Delegate directly: bindCombining already caches the state cell under the same StateKey.
+    // Calling it from inside a stateKeyObjectCache.computeIfAbsent mapping function would modify
+    // the HashMap while it is computing that key, which HashMap does not support (Java 9+ throws
+    // ConcurrentModificationException).
+    return bindCombining(
+        id,
+        spec,
+        accumCoder,
+        CombineFnUtil.bindContext(
+            combineFn,
+            new StateContext<BoundedWindow>() {
+              @Override
+              public PipelineOptions getPipelineOptions() {
+                return pipelineOptions;
+              }
+
+              @Override
+              public <T> T sideInput(PCollectionView<T> view) {
+                return get(view, currentWindow);
+              }
+
+              @Override
+              public BoundedWindow window() {
+                return currentWindow;
+              }
+            }));
+  }
+
+  /**
+   * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing and is
+   *     waiting on resolution of BEAM-2535.
+   */
+  @Override
+  @Deprecated
+  public WatermarkHoldState bindWatermark(
+      String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
+    throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API.");
+  }
+
+  /**
+   * Creates a bag user state cell for the current key and window and registers it to be closed by
+   * {@link #finalizeState}. Relies on {@link #createBagUserStateKey} having been called first so
+   * that the encoded key and window are populated.
+   */
+  private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) {
+    BagUserState<T> rval =
+        new BagUserState<>(
+            beamFnStateClient,
+            processBundleInstructionId.get(),
+            ptransformId,
+            stateId,
+            encodedCurrentWindow,
+            encodedCurrentKey,
+            valueCoder);
+    stateFinalizers.add(rval::asyncClose);
+    return rval;
+  }
+
+  /**
+   * Builds the cache key of user state cell {@code stateId} for the current key and window,
+   * encoding the key and window first if needed.
+   */
+  private StateKey createBagUserStateKey(String stateId) {
+    cacheEncodedKeyAndWindowForKeyedContext();
+    StateKey.Builder builder = StateKey.newBuilder();
+    builder
+        .getBagUserStateBuilder()
+        .setWindow(encodedCurrentWindow)
+        .setKey(encodedCurrentKey)
+        .setPtransformId(ptransformId)
+        .setUserStateId(stateId);
+    return builder.build();
+  }
+
+  /**
+   * Memoizes the encoded key and window of the current element so that their encoding cost is paid
+   * once across all state cells accessed while processing that element.
+   *
+   * <p>This should only be called during {@link DoFn.ProcessElement}.
+   *
+   * @throws IllegalStateException if there is no current element, the current element is not a
+   *     {@link KV}, or no key coder is available
+   */
+  private void cacheEncodedKeyAndWindowForKeyedContext() {
+    if (encodedCurrentKey == null) {
+      checkState(
+          currentElement != null,
+          "Accessing state outside of element processing, no current element is set.");
+      checkState(
+          currentElement.getValue() instanceof KV,
+          "Accessing state in unkeyed context. Current element is not a KV: %s.",
+          currentElement);
+      checkState(keyCoder != null, "Accessing state in unkeyed context, no key coder available");
+
+      ByteString.Output encodedKeyOut = ByteString.newOutput();
+      try {
+        ((Coder) keyCoder).encode(((KV<?, ?>) currentElement.getValue()).getKey(), encodedKeyOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentKey = encodedKeyOut.toByteString();
+    }
+
+    if (encodedCurrentWindow == null) {
+      ByteString.Output encodedWindowOut = ByteString.newOutput();
+      try {
+        windowCoder.encode(currentWindow, encodedWindowOut);
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+      encodedCurrentWindow = encodedWindowOut.toByteString();
+    }
+  }
+
+  /**
+   * Runs the close action of every user state cell created by this accessor, persisting any
+   * modified state.
+   *
+   * @throws IllegalStateException if any close action fails; if it failed because the thread was
+   *     interrupted, the thread's interrupt flag is restored first
+   */
+  public void finalizeState() {
+    // Persist all dirty state cells.
+    // NOTE(review): stateFinalizers is never cleared, so a second call re-runs every close
+    // action; presumably one accessor is created per bundle (confirm).
+    try {
+      for (ThrowingRunnable runnable : stateFinalizers) {
+        runnable.run();
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IllegalStateException(e);
+    } catch (Exception e) {
+      throw new IllegalStateException(e);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
new file mode 100644
index 0000000..8aec3a1
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/SideInputSpec.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.state;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+
+/**
+ * A specification for side inputs containing a value {@link Coder}, the window {@link Coder},
+ * {@link ViewFn}, and the {@link WindowMappingFn}.
+ *
+ * @param <W> the type of the side input's windows
+ */
+@AutoValue
+public abstract class SideInputSpec<W extends BoundedWindow> {
+  /**
+   * Creates a side input specification.
+   *
+   * @param coder the coder of the side input's elements
+   * @param windowCoder the coder of the side input's windows
+   * @param viewFn produces the view's value from the side input's contents
+   * @param windowMappingFn maps a main input window to the side input window to read
+   */
+  public static <W extends BoundedWindow> SideInputSpec<W> create(
+      Coder<?> coder,
+      Coder<W> windowCoder,
+      ViewFn<?, ?> viewFn,
+      WindowMappingFn<W> windowMappingFn) {
+    return new AutoValue_SideInputSpec<>(coder, windowCoder, viewFn, windowMappingFn);
+  }
+
+  /** Returns the coder of the side input's elements. */
+  abstract Coder<?> getCoder();
+
+  /** Returns the coder of the side input's windows. */
+  abstract Coder<W> getWindowCoder();
+
+  /** Returns the function that produces the view's value from the side input's contents. */
+  abstract ViewFn<?, ?> getViewFn();
+
+  /** Returns the function mapping main input windows to side input windows. */
+  abstract WindowMappingFn<W> getWindowMappingFn();
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
index d72753b..9871c37 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java
@@ -201,7 +201,8 @@
null /* windowingStrategies */,
receivers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+ null, /* addFinishFunction */
+ null /* splitListener */);
WindowedValue<Integer> value =
WindowedValue.of(
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index f138677..4b4c64f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -152,7 +152,8 @@
COMPONENTS.getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
verifyZeroInteractions(mockBeamFnDataClient);
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
index 12540c9..a70aedd 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -141,7 +141,8 @@
COMPONENTS.getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
verifyZeroInteractions(mockBeamFnDataClient);
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
index 044de10..e438bd5 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
@@ -151,7 +151,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
// This is testing a deprecated way of running sources and should be removed
// once all source definitions are instead propagated along the input edge.
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
index 9b1bd75..3b58517 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java
@@ -87,7 +87,8 @@
Collections.emptyMap(),
consumers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+ null, /* addFinishFunction */
+ null /* splitListener */);
mainOutputValues.clear();
assertThat(consumers.keySet(), containsInAnyOrder(
@@ -149,7 +150,8 @@
Collections.emptyMap(),
consumers,
null /* addStartFunction */,
- null /* addFinishFunction */);
+ null, /* addFinishFunction */
+ null /* splitListener */);
mainOutputValues.clear();
assertThat(consumers.keySet(), containsInAnyOrder("inputATarget", "mainOutputTarget"));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 85aa564..ee04cf8 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -18,8 +18,6 @@
package org.apache.beam.fn.harness;
-import static com.google.common.base.Preconditions.checkState;
-import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -27,29 +25,23 @@
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.protobuf.ByteString;
import java.io.IOException;
+import java.io.Serializable;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
-import java.util.ServiceLoader;
-import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -60,8 +52,6 @@
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
-import org.apache.beam.sdk.transforms.CombineWithContext.Context;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
@@ -73,15 +63,13 @@
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.DoFnInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.hamcrest.collection.IsMapContaining;
+import org.apache.beam.sdk.values.TupleTagList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
@@ -90,137 +78,10 @@
/** Tests for {@link FnApiDoFnRunner}. */
@RunWith(JUnit4.class)
-public class FnApiDoFnRunnerTest {
+public class FnApiDoFnRunnerTest implements Serializable {
public static final String TEST_PTRANSFORM_ID = "pTransformId";
- private static class TestDoFn extends DoFn<String, String> {
- private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
- private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
-
- private enum State {
- NOT_SET_UP,
- OUTSIDE_BUNDLE,
- INSIDE_BUNDLE,
- }
-
- private State state = State.NOT_SET_UP;
-
- private BoundedWindow window;
-
- @Setup
- public void setUp() {
- checkState(State.NOT_SET_UP.equals(state), "Unexpected state: %s", state);
- state = State.OUTSIDE_BUNDLE;
- }
-
- // No testing for TearDown - it's currently not supported by FnHarness.
-
- @StartBundle
- public void startBundle() {
- checkState(State.OUTSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- state = State.INSIDE_BUNDLE;
- }
-
- @ProcessElement
- public void processElement(ProcessContext context, BoundedWindow window) {
- checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- context.output("MainOutput" + context.element());
- context.output(additionalOutput, "AdditionalOutput" + context.element());
- this.window = window;
- }
-
- @FinishBundle
- public void finishBundle(FinishBundleContext context) {
- checkState(State.INSIDE_BUNDLE.equals(state), "Unexpected state: %s", state);
- state = State.OUTSIDE_BUNDLE;
- if (window != null) {
- context.output("FinishBundle", window.maxTimestamp(), window);
- window = null;
- }
- }
- }
-
- /**
- * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
- * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
- * are directed to the correct consumers.
- */
- @Test
- public void testCreatingAndProcessingDoFn() throws Exception {
- String pTransformId = "pTransformId";
-
- DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
- new TestDoFn(),
- WindowingStrategy.globalDefault(),
- ImmutableList.of(),
- StringUtf8Coder.of(),
- TestDoFn.mainOutput);
- RunnerApi.FunctionSpec functionSpec =
- RunnerApi.FunctionSpec.newBuilder()
- .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
- .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
- .build();
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("inputA", "inputATarget")
- .putInputs("inputB", "inputBTarget")
- .putOutputs(TestDoFn.mainOutput.getId(), "mainOutputTarget")
- .putOutputs(TestDoFn.additionalOutput.getId(), "additionalOutputTarget")
- .build();
-
- List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
- List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
- Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("mainOutputTarget",
- (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
- consumers.put("additionalOutputTarget",
- (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
- PipelineOptionsFactory.create(),
- null /* beamFnDataClient */,
- null /* beamFnStateClient */,
- pTransformId,
- pTransform,
- Suppliers.ofInstance("57L")::get,
- Collections.emptyMap(),
- Collections.emptyMap(),
- Collections.emptyMap(),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- Iterables.getOnlyElement(startFunctions).run();
- mainOutputValues.clear();
-
- assertThat(consumers.keySet(), containsInAnyOrder(
- "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
-
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
- Iterables.getOnlyElement(consumers.get("inputBTarget")).accept(valueInGlobalWindow("B"));
- assertThat(mainOutputValues, contains(
- valueInGlobalWindow("MainOutputA1"),
- valueInGlobalWindow("MainOutputA2"),
- valueInGlobalWindow("MainOutputB")));
- assertThat(additionalOutputValues, contains(
- valueInGlobalWindow("AdditionalOutputA1"),
- valueInGlobalWindow("AdditionalOutputA2"),
- valueInGlobalWindow("AdditionalOutputB")));
- mainOutputValues.clear();
- additionalOutputValues.clear();
-
- Iterables.getOnlyElement(finishFunctions).run();
- assertThat(
- mainOutputValues,
- contains(
- timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
- mainOutputValues.clear();
- }
-
private static class ConcatCombineFn extends CombineFn<String, String, String> {
@Override
public String createAccumulator() {
@@ -247,33 +108,6 @@
}
}
- private static class ConcatCombineFnWithContext
- extends CombineFnWithContext<String, String, String> {
- @Override
- public String createAccumulator(Context c) {
- return "";
- }
-
- @Override
- public String addInput(String accumulator, String input, Context c) {
- return accumulator.concat(input);
- }
-
- @Override
- public String mergeAccumulators(Iterable<String> accumulators, Context c) {
- StringBuilder builder = new StringBuilder();
- for (String value : accumulators) {
- builder.append(value);
- }
- return builder.toString();
- }
-
- @Override
- public String extractOutput(String accumulator, Context c) {
- return accumulator;
- }
- }
-
private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -287,17 +121,12 @@
@StateId("combine")
private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
- @StateId("combineWithContext")
- private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec =
- StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext());
@ProcessElement
public void processElement(ProcessContext context,
@StateId("value") ValueState<String> valueState,
@StateId("bag") BagState<String> bagState,
- @StateId("combine") CombiningState<String, String, String> combiningState,
- @StateId("combineWithContext")
- CombiningState<String, String, String> combiningWithContextState) {
+ @StateId("combine") CombiningState<String, String, String> combiningState) {
context.output("value:" + valueState.read());
valueState.write(context.element().getValue());
@@ -306,41 +135,34 @@
context.output("combine:" + combiningState.read());
combiningState.add(context.element().getValue());
-
- context.output("combineWithContext:" + combiningWithContextState.read());
- combiningWithContextState.add(context.element().getValue());
}
}
@Test
public void testUsingUserState() throws Exception {
- DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
- new TestStatefulDoFn(),
- WindowingStrategy.globalDefault(),
- ImmutableList.of(),
- KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
- new TupleTag<>("mainOutput"));
- RunnerApi.FunctionSpec functionSpec =
- RunnerApi.FunctionSpec.newBuilder()
- .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
- .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
- .build();
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("input", "inputTarget")
- .putOutputs("mainOutput", "mainOutputTarget")
- .build();
+ Pipeline p = Pipeline.create();
+ PCollection<KV<String, String>> valuePCollection =
+ p.apply(Create.of(KV.of("unused", "unused")));
+ PCollection<String> outputPCollection =
+ valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
+
+ SdkComponents sdkComponents = SdkComponents.create();
+ RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+ String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+ String outputPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection);
+ RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
+ pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of(
bagUserStateKey("value", "X"), encode("X0"),
bagUserStateKey("bag", "X"), encode("X0"),
- bagUserStateKey("combine", "X"), encode("X0"),
- bagUserStateKey("combineWithContext", "X"), encode("X0")
+ bagUserStateKey("combine", "X"), encode("X0")
));
List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("mainOutputTarget",
+ consumers.put(outputPCollectionId,
(FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
@@ -352,22 +174,23 @@
TEST_PTRANSFORM_ID,
pTransform,
Suppliers.ofInstance("57L")::get,
- Collections.emptyMap(),
- Collections.emptyMap(),
- Collections.emptyMap(),
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
- assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget"));
+ assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
// Ensure that bag user state that is initially empty or populated works.
// Ensure that the key order does not matter when we traverse over KV pairs.
FnDataReceiver<WindowedValue<?>> mainInput =
- Iterables.getOnlyElement(consumers.get("inputTarget"));
+ Iterables.getOnlyElement(consumers.get(inputPCollectionId));
mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
@@ -376,19 +199,15 @@
valueInGlobalWindow("value:X0"),
valueInGlobalWindow("bag:[X0]"),
valueInGlobalWindow("combine:X0"),
- valueInGlobalWindow("combineWithContext:X0"),
valueInGlobalWindow("value:null"),
valueInGlobalWindow("bag:[]"),
valueInGlobalWindow("combine:"),
- valueInGlobalWindow("combineWithContext:"),
valueInGlobalWindow("value:X1"),
valueInGlobalWindow("bag:[X0, X1]"),
valueInGlobalWindow("combine:X0X1"),
- valueInGlobalWindow("combineWithContext:X0X1"),
valueInGlobalWindow("value:Y1"),
valueInGlobalWindow("bag:[Y1]"),
- valueInGlobalWindow("combine:Y1"),
- valueInGlobalWindow("combineWithContext:Y1")));
+ valueInGlobalWindow("combine:Y1")));
mainOutputValues.clear();
Iterables.getOnlyElement(finishFunctions).run();
@@ -399,11 +218,9 @@
.put(bagUserStateKey("value", "X"), encode("X2"))
.put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
.put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
- .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
.put(bagUserStateKey("value", "Y"), encode("Y2"))
.put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
.put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
- .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
.build(),
fakeClient.getData());
mainOutputValues.clear();
@@ -425,13 +242,17 @@
private final PCollectionView<String> defaultSingletonSideInput;
private final PCollectionView<String> singletonSideInput;
private final PCollectionView<Iterable<String>> iterableSideInput;
+ private final TupleTag<String> additionalOutput;
+
private TestSideInputDoFn(
PCollectionView<String> defaultSingletonSideInput,
PCollectionView<String> singletonSideInput,
- PCollectionView<Iterable<String>> iterableSideInput) {
+ PCollectionView<Iterable<String>> iterableSideInput,
+ TupleTag<String> additionalOutput) {
this.defaultSingletonSideInput = defaultSingletonSideInput;
this.singletonSideInput = singletonSideInput;
this.iterableSideInput = iterableSideInput;
+ this.additionalOutput = additionalOutput;
}
@ProcessElement
@@ -441,6 +262,7 @@
for (String sideInputValue : context.sideInput(iterableSideInput)) {
context.output(context.element() + ":" + sideInputValue);
}
+ context.output(additionalOutput, context.element() + ":additional");
}
}
@@ -453,21 +275,31 @@
PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
PCollectionView<Iterable<String>> iterableSideInputView =
valuePCollection.apply(View.asIterable());
- PCollection<String> outputPCollection = valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
- new TestSideInputDoFn(
- defaultSingletonSideInputView,
- singletonSideInputView,
- iterableSideInputView))
- .withSideInputs(
- defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView));
+ TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+ TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+ PCollectionTuple outputPCollection =
+ valuePCollection.apply(
+ TEST_PTRANSFORM_ID,
+ ParDo.of(
+ new TestSideInputDoFn(
+ defaultSingletonSideInputView,
+ singletonSideInputView,
+ iterableSideInputView,
+ additionalOutput))
+ .withSideInputs(
+ defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
+ .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
SdkComponents sdkComponents = SdkComponents.create();
RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
- String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+ String outputPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+ String additionalPCollectionId =
+ sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
- RunnerApi.PTransform pTransform = pProto.getComponents().getTransformsOrThrow(
- pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+ RunnerApi.PTransform pTransform =
+ pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID);
ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
@@ -478,13 +310,16 @@
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+ consumers.put(outputPCollectionId,
(FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+ consumers.put(additionalPCollectionId,
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
- new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
PipelineOptionsFactory.create(),
null /* beamFnDataClient */,
fakeClient,
@@ -496,12 +331,15 @@
pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
- assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+ assertThat(
+ consumers.keySet(),
+ containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
// Ensure that bag user state that is initially empty or populated works.
// Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
@@ -520,6 +358,9 @@
valueInGlobalWindow("Y:iterableValue1"),
valueInGlobalWindow("Y:iterableValue2"),
valueInGlobalWindow("Y:iterableValue3")));
+ assertThat(
+ additionalOutputValues,
+ contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
mainOutputValues.clear();
Iterables.getOnlyElement(finishFunctions).run();
@@ -589,7 +430,7 @@
List<ThrowingRunnable> startFunctions = new ArrayList<>();
List<ThrowingRunnable> finishFunctions = new ArrayList<>();
- new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
PipelineOptionsFactory.create(),
null /* beamFnDataClient */,
fakeClient,
@@ -601,7 +442,8 @@
pProto.getComponents().getWindowingStrategiesMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
Iterables.getOnlyElement(startFunctions).run();
mainOutputValues.clear();
@@ -658,17 +500,4 @@
}
return out.toByteString();
}
-
- @Test
- public void testRegistration() {
- for (Registrar registrar :
- ServiceLoader.load(Registrar.class)) {
- if (registrar instanceof FnApiDoFnRunner.Registrar) {
- assertThat(registrar.getPTransformRunnerFactories(),
- IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN));
- return;
- }
- }
- fail("Expected registrar not found.");
- }
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
index 906bdeb..93ca808 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java
@@ -82,7 +82,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
@@ -116,7 +117,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
@@ -150,7 +152,8 @@
Collections.emptyMap(),
consumers,
startFunctions::add,
- finishFunctions::add);
+ finishFunctions::add,
+ null /* splitListener */);
assertThat(startFunctions, empty());
assertThat(finishFunctions, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 48d56fd..eafef99 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -50,6 +50,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingConsumer;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
@@ -115,7 +116,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
assertThat(processBundleInstructionId.get(), equalTo("999L"));
transformsProcessed.add(pTransform);
@@ -176,7 +178,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
throw new IllegalStateException("TestException");
@@ -217,7 +220,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
addStartFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -261,7 +265,8 @@
windowingStrategies,
pCollectionIdsToConsumers,
addStartFunction,
- addFinishFunction) -> {
+ addFinishFunction,
+ splitListener) -> {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("TestException");
addFinishFunction.accept(ProcessBundleHandlerTest::throwException);
@@ -335,10 +340,11 @@
Supplier<String> processBundleInstructionId,
Map<String, PCollection> pCollections,
Map<String, Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
return null;
}
@@ -386,10 +392,11 @@
Supplier<String> processBundleInstructionId,
Map<String, PCollection> pCollections,
Map<String, Coder> coders,
- Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
+ Map<String, WindowingStrategy> windowingStrategies,
Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ Consumer<ThrowingRunnable> addFinishFunction,
+ BundleSplitListener splitListener) throws IOException {
addStartFunction.accept(() -> doStateCalls(beamFnStateClient));
return null;
}