blob: 82cf753832d697faebe7f32a10ddff1744166a02 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly;
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Instant;
/**
* {@link PTransformOverrideFactory PTransformOverrideFactories} that expands to correctly implement
* stateful {@link ParDo} using window-unaware {@link GroupByKeyAndSortValuesOnly} to linearize
* processing per key.
*
* <p>For the Fn API, the {@link PTransformOverrideFactory} is only required to perform per key
* grouping and expansion.
*
* <p>This implementation relies on implementation details of the Dataflow runner, specifically
* standard fusion behavior of {@link ParDo} tranforms following a {@link GroupByKey}.
*/
public class BatchStatefulParDoOverrides {
/**
* Returns a {@link PTransformOverrideFactory} that replaces a single-output {@link ParDo} with a
* composite transform specialized for the {@link DataflowRunner}.
*/
public static <K, InputT, OutputT>
PTransformOverrideFactory<
PCollection<KV<K, InputT>>,
PCollection<OutputT>,
ParDo.SingleOutput<KV<K, InputT>, OutputT>>
singleOutputOverrideFactory(DataflowPipelineOptions options) {
return new SingleOutputOverrideFactory<>(isFnApi(options));
}
/**
* Returns a {@link PTransformOverrideFactory} that replaces a multi-output {@link ParDo} with a
* composite transform specialized for the {@link DataflowRunner}.
*/
public static <K, InputT, OutputT>
PTransformOverrideFactory<
PCollection<KV<K, InputT>>,
PCollectionTuple,
ParDo.MultiOutput<KV<K, InputT>, OutputT>>
multiOutputOverrideFactory(DataflowPipelineOptions options) {
return new MultiOutputOverrideFactory<>(isFnApi(options));
}
private static boolean isFnApi(DataflowPipelineOptions options) {
List<String> experiments = options.getExperiments();
return experiments != null && experiments.contains("beam_fn_api");
}
private static class SingleOutputOverrideFactory<K, InputT, OutputT>
implements PTransformOverrideFactory<
PCollection<KV<K, InputT>>,
PCollection<OutputT>,
ParDo.SingleOutput<KV<K, InputT>, OutputT>> {
private final boolean isFnApi;
private SingleOutputOverrideFactory(boolean isFnApi) {
this.isFnApi = isFnApi;
}
@Override
public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<OutputT>>
getReplacementTransform(
AppliedPTransform<
PCollection<KV<K, InputT>>,
PCollection<OutputT>,
SingleOutput<KV<K, InputT>, OutputT>>
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
new StatefulSingleOutputParDo<>(transform.getTransform(), isFnApi));
}
@Override
public Map<PValue, ReplacementOutput> mapOutputs(
Map<TupleTag<?>, PValue> outputs, PCollection<OutputT> newOutput) {
return ReplacementOutputs.singleton(outputs, newOutput);
}
}
private static class MultiOutputOverrideFactory<K, InputT, OutputT>
implements PTransformOverrideFactory<
PCollection<KV<K, InputT>>, PCollectionTuple, ParDo.MultiOutput<KV<K, InputT>, OutputT>> {
private final boolean isFnApi;
private MultiOutputOverrideFactory(boolean isFnApi) {
this.isFnApi = isFnApi;
}
@Override
public PTransformReplacement<PCollection<KV<K, InputT>>, PCollectionTuple>
getReplacementTransform(
AppliedPTransform<
PCollection<KV<K, InputT>>,
PCollectionTuple,
MultiOutput<KV<K, InputT>, OutputT>>
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
new StatefulMultiOutputParDo<>(transform.getTransform(), isFnApi));
}
@Override
public Map<PValue, ReplacementOutput> mapOutputs(
Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) {
return ReplacementOutputs.tagged(outputs, newOutput);
}
}
static class StatefulSingleOutputParDo<K, InputT, OutputT>
extends PTransform<PCollection<KV<K, InputT>>, PCollection<OutputT>> {
private final ParDo.SingleOutput<KV<K, InputT>, OutputT> originalParDo;
private final boolean isFnApi;
StatefulSingleOutputParDo(
ParDo.SingleOutput<KV<K, InputT>, OutputT> originalParDo, boolean isFnApi) {
this.originalParDo = originalParDo;
this.isFnApi = isFnApi;
}
ParDo.SingleOutput<KV<K, InputT>, OutputT> getOriginalParDo() {
return originalParDo;
}
@Override
public PCollection<OutputT> expand(PCollection<KV<K, InputT>> input) {
DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
verifyFnIsStateful(fn);
DataflowRunner.verifyStateSupported(fn);
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());
if (isFnApi) {
return input
.apply(GroupByKey.create())
.apply(ParDo.of(new ExpandGbkFn<>()))
.apply(originalParDo);
}
PTransform<
PCollection<? extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K, InputT>>>>>>,
PCollection<OutputT>>
statefulParDo =
ParDo.of(new BatchStatefulDoFn<>(fn)).withSideInputs(originalParDo.getSideInputs());
return input.apply(new GbkBeforeStatefulParDo<>()).apply(statefulParDo);
}
}
static class StatefulMultiOutputParDo<K, InputT, OutputT>
extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> {
private final ParDo.MultiOutput<KV<K, InputT>, OutputT> originalParDo;
private final boolean isFnApi;
StatefulMultiOutputParDo(
ParDo.MultiOutput<KV<K, InputT>, OutputT> originalParDo, boolean isFnApi) {
this.originalParDo = originalParDo;
this.isFnApi = isFnApi;
}
@Override
public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {
DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
verifyFnIsStateful(fn);
DataflowRunner.verifyStateSupported(fn);
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());
if (isFnApi) {
return input
.apply(GroupByKey.create())
.apply(ParDo.of(new ExpandGbkFn<>()))
.apply(originalParDo);
}
PTransform<
PCollection<? extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K, InputT>>>>>>,
PCollectionTuple>
statefulParDo =
ParDo.of(new BatchStatefulDoFn<>(fn))
.withSideInputs(originalParDo.getSideInputs())
.withOutputTags(
originalParDo.getMainOutputTag(), originalParDo.getAdditionalOutputTags());
return input.apply(new GbkBeforeStatefulParDo<>()).apply(statefulParDo);
}
public ParDo.MultiOutput<KV<K, InputT>, OutputT> getOriginalParDo() {
return originalParDo;
}
}
static class GbkBeforeStatefulParDo<K, V>
extends PTransform<
PCollection<KV<K, V>>,
PCollection<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>>> {
@Override
public PCollection<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>> expand(
PCollection<KV<K, V>> input) {
WindowingStrategy<?, ?> inputWindowingStrategy = input.getWindowingStrategy();
// A KvCoder is required since this goes through GBK. Further, WindowedValueCoder
// is not registered by default, so we explicitly set the relevant coders.
checkState(
input.getCoder() instanceof KvCoder,
"Input to a %s using state requires a %s, but the coder was %s",
ParDo.class.getSimpleName(),
KvCoder.class.getSimpleName(),
input.getCoder());
KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
Coder<K> keyCoder = kvCoder.getKeyCoder();
Coder<? extends BoundedWindow> windowCoder =
inputWindowingStrategy.getWindowFn().windowCoder();
return input
// Stash the original timestamps, etc, for when it is fed to the user's DoFn
.apply("ReifyWindows", ParDo.of(new ReifyWindowedValueFn<>()))
.setCoder(
KvCoder.of(
keyCoder,
KvCoder.of(InstantCoder.of(), WindowedValue.getFullCoder(kvCoder, windowCoder))))
// Group by key and sort by timestamp, dropping windows as they are reified
.apply("PartitionKeys", new GroupByKeyAndSortValuesOnly<>())
// The GBKO sets the windowing strategy to the global default
.setWindowingStrategyInternal(inputWindowingStrategy);
}
}
/** A key-preserving {@link DoFn} that reifies a windowed value. */
static class ReifyWindowedValueFn<K, V>
extends DoFn<KV<K, V>, KV<K, KV<Instant, WindowedValue<KV<K, V>>>>> {
@ProcessElement
public void processElement(final ProcessContext c, final BoundedWindow window) {
c.output(
KV.of(
c.element().getKey(),
KV.of(
c.timestamp(), WindowedValue.of(c.element(), c.timestamp(), window, c.pane()))));
}
}
/**
* A key preserving {@link DoFn} that expands the output of a GBK {@code KV<K, Iterable<V>>} into
* individual KVs.
*/
static class ExpandGbkFn<K, V> extends DoFn<KV<K, Iterable<V>>, KV<K, V>> {
@ProcessElement
public void processElement(ProcessContext c) {
K k = c.element().getKey();
for (V v : c.element().getValue()) {
c.output(KV.of(k, v));
}
}
}
/**
* A key-preserving {@link DoFn} that explodes an iterable that has been grouped by key and
* window.
*/
public static class BatchStatefulDoFn<K, V, OutputT>
extends DoFn<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>, OutputT> {
private final DoFn<KV<K, V>, OutputT> underlyingDoFn;
BatchStatefulDoFn(DoFn<KV<K, V>, OutputT> underlyingDoFn) {
this.underlyingDoFn = underlyingDoFn;
}
public DoFn<KV<K, V>, OutputT> getUnderlyingDoFn() {
return underlyingDoFn;
}
@Setup
public void setup() {
DoFnInvokers.invokerFor(underlyingDoFn).invokeSetup();
}
@ProcessElement
public void processElement(final ProcessContext c, final BoundedWindow window) {
throw new UnsupportedOperationException(
"BatchStatefulDoFn.ProcessElement should never be invoked");
}
@Teardown
public void teardown() {
DoFnInvokers.invokerFor(underlyingDoFn).invokeTeardown();
}
@Override
public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
return underlyingDoFn.getOutputTypeDescriptor();
}
}
private static <InputT, OutputT> void verifyFnIsStateful(DoFn<InputT, OutputT> fn) {
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
// It is still correct to use this without state or timers, but a bad idea.
// Since it is internal it should never be used wrong, so it is OK to crash.
checkState(
signature.usesState() || signature.usesTimers(),
"%s used for %s that does not use state or timers.",
BatchStatefulParDoOverrides.class.getSimpleName(),
ParDo.class.getSimpleName());
}
}