/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.runners.spark.util.SparkSideInputReader;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.spark.api.java.function.Function;
import org.joda.time.Instant;
/**
* A {@link org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn} with a {@link
* org.apache.beam.sdk.transforms.CombineWithContext.Context} for the SparkRunner.
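*
* <p>A minimal usage sketch for a keyed combine (the {@code combineFn}, {@code opts}, {@code
* sideInputs}, {@code strategy}, {@code first}, and {@code next} names are hypothetical):
*
* <pre>{@code
* SparkCombineFn<KV<String, Integer>, Integer, long[], Long> fn =
*     SparkCombineFn.keyed(combineFn, opts, sideInputs, strategy);
* WindowedAccumulator<KV<String, Integer>, Integer, long[], ?> acc = fn.createCombiner(first);
* acc = fn.mergeValue(acc, next);
* Iterable<WindowedValue<Long>> outputs = fn.extractOutput(acc);
* }</pre>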
*/
public class SparkCombineFn<InputT, ValueT, AccumT, OutputT> implements Serializable {
/** Accumulator of {@link WindowedValue WindowedValues} holding values for different windows. */
public interface WindowedAccumulator<
InputT, ValueT, AccumT, ImplT extends WindowedAccumulator<InputT, ValueT, AccumT, ImplT>> {
/**
* Type of the accumulator. The type depends (mostly) on {@link WindowingStrategy}, and more
* specialized versions enable more optimizations.
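*
* <p>For example, merging {@link WindowFn WindowFns} such as session windows require {@code
* MERGING}, while non-merging ones such as fixed windows can use {@code NON_MERGING}, {@code
* EXPLODE_WINDOWS}, or {@code SINGLE_WINDOW}.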
*/
enum Type {
MERGING,
NON_MERGING,
EXPLODE_WINDOWS,
SINGLE_WINDOW;
boolean isMapBased() {
return this == NON_MERGING || this == MERGING;
}
}
/**
* Check if this accumulator is empty.
*
* @return {@code true} if this accumulator is empty
*/
boolean isEmpty();
/** Add a value with unexploded windows into the accumulator. */
void add(WindowedValue<InputT> value, SparkCombineFn<InputT, ValueT, AccumT, ?> context)
throws Exception;
/**
* Merge another accumulator into this one.
*
* @param other the other accumulator to merge
*/
void merge(ImplT other, SparkCombineFn<?, ?, AccumT, ?> context) throws Exception;
/** Extract output. */
Collection<WindowedValue<AccumT>> extractOutput();
/** Create a concrete accumulator for the given type. */
static <InputT, ValueT, AccumT> WindowedAccumulator<InputT, ValueT, AccumT, ?> create(
SparkCombineFn<InputT, ValueT, AccumT, ?> context,
Function<InputT, ValueT> toValue,
WindowingStrategy<?, ?> windowingStrategy,
Comparator<BoundedWindow> windowComparator) {
return create(toValue, context.getType(windowingStrategy), windowComparator);
}
static <InputT, ValueT, AccumT> WindowedAccumulator<InputT, ValueT, AccumT, ?> create(
Function<InputT, ValueT> toValue, Type type, Comparator<BoundedWindow> windowComparator) {
switch (type) {
case MERGING:
return MergingWindowedAccumulator.create(toValue, windowComparator);
case NON_MERGING:
return NonMergingWindowedAccumulator.create(toValue);
case SINGLE_WINDOW:
case EXPLODE_WINDOWS:
return SingleWindowWindowedAccumulator.create(toValue);
default:
throw new IllegalArgumentException("Unknown type: " + type);
}
}
/** Create a concrete accumulator for the given type. */
static <InputT, ValueT, AccumT> WindowedAccumulator<InputT, ValueT, AccumT, ?> create(
Function<InputT, ValueT> toValue,
Type type,
Iterable<WindowedValue<AccumT>> values,
Comparator<BoundedWindow> windowComparator) {
switch (type) {
case MERGING:
return MergingWindowedAccumulator.from(toValue, values, windowComparator);
case NON_MERGING:
return NonMergingWindowedAccumulator.from(toValue, values);
case SINGLE_WINDOW:
case EXPLODE_WINDOWS:
Iterator<WindowedValue<AccumT>> iter = values.iterator();
if (iter.hasNext()) {
return SingleWindowWindowedAccumulator.create(toValue, iter.next());
}
return SingleWindowWindowedAccumulator.create(toValue);
default:
throw new IllegalArgumentException("Unknown type: " + type);
}
}
}
/**
* Accumulator for {@link WindowFn WindowFns} which create only a single window per key, and
* therefore don't need a map to hold different windows per key.
*
* @param <AccumT> type of accumulator
*/
static class SingleWindowWindowedAccumulator<InputT, ValueT, AccumT>
implements WindowedAccumulator<
InputT, ValueT, AccumT, SingleWindowWindowedAccumulator<InputT, ValueT, AccumT>> {
static <InputT, ValueT, AccumT> SingleWindowWindowedAccumulator<InputT, ValueT, AccumT> create(
Function<InputT, ValueT> toValue) {
return new SingleWindowWindowedAccumulator<>(toValue);
}
static <InputT, ValueT, AccumT> WindowedAccumulator<InputT, ValueT, AccumT, ?> create(
Function<InputT, ValueT> toValue, WindowedValue<AccumT> accumulator) {
return new SingleWindowWindowedAccumulator<>(toValue, accumulator);
}
final Function<InputT, ValueT> toValue;
AccumT windowAccumulator = null;
Instant accTimestamp = null;
BoundedWindow accWindow = null;
SingleWindowWindowedAccumulator(Function<InputT, ValueT> toValue) {
this.toValue = toValue;
}
SingleWindowWindowedAccumulator(
Function<InputT, ValueT> toValue, WindowedValue<AccumT> accumulator) {
this.toValue = toValue;
this.windowAccumulator = accumulator.getValue();
this.accTimestamp =
accumulator.getTimestamp().equals(BoundedWindow.TIMESTAMP_MIN_VALUE)
? null
: accumulator.getTimestamp();
this.accWindow = getWindow(accumulator);
}
@Override
public void add(WindowedValue<InputT> value, SparkCombineFn<InputT, ValueT, AccumT, ?> context)
throws Exception {
BoundedWindow window = getWindow(value);
SparkCombineContext ctx = context.ctxtForValue(value);
TimestampCombiner combiner = context.windowingStrategy.getTimestampCombiner();
Instant windowTimestamp =
combiner.assign(
window,
context.windowingStrategy.getWindowFn().getOutputTime(value.getTimestamp(), window));
final AccumT acc;
final Instant timestamp;
if (windowAccumulator == null) {
acc = context.combineFn.createAccumulator(ctx);
timestamp = windowTimestamp;
} else {
acc = windowAccumulator;
timestamp = accTimestamp;
}
AccumT result = context.combineFn.addInput(acc, toValue(value), ctx);
Instant timestampCombined = combiner.combine(windowTimestamp, timestamp);
windowAccumulator = result;
accTimestamp = timestampCombined;
accWindow = window;
}
@Override
public void merge(
SingleWindowWindowedAccumulator<InputT, ValueT, AccumT> other,
SparkCombineFn<?, ?, AccumT, ?> context) {
if (windowAccumulator != null && other.windowAccumulator != null) {
List<AccumT> accumulators = Arrays.asList(windowAccumulator, other.windowAccumulator);
AccumT merged =
context.combineFn.mergeAccumulators(
accumulators, context.ctxtForWindows(Arrays.asList(accWindow)));
Instant combined =
context
.windowingStrategy
.getTimestampCombiner()
.combine(accTimestamp, other.accTimestamp);
windowAccumulator = merged;
accTimestamp = combined;
} else if (windowAccumulator == null) {
windowAccumulator = other.windowAccumulator;
accTimestamp = other.accTimestamp;
accWindow = other.accWindow;
}
}
@Override
public Collection<WindowedValue<AccumT>> extractOutput() {
if (windowAccumulator != null) {
return Arrays.asList(
WindowedValue.of(
windowAccumulator, accTimestamp, accWindow, PaneInfo.ON_TIME_AND_ONLY_FIRING));
}
return Collections.emptyList();
}
private ValueT toValue(WindowedValue<InputT> input) {
try {
return toValue.call(input.getValue());
} catch (Exception ex) {
throw UserCodeException.wrap(ex);
}
}
@Override
public boolean isEmpty() {
return windowAccumulator == null;
}
}
/**
* Abstract base class for {@link WindowedAccumulator WindowedAccumulators} which hold
* accumulators in a map keyed by window.
*
* @param <AccumT> type of accumulator
* @param <ImplT> type of final subclass
*/
abstract static class MapBasedWindowedAccumulator<
InputT,
ValueT,
AccumT,
ImplT extends MapBasedWindowedAccumulator<InputT, ValueT, AccumT, ImplT>>
implements WindowedAccumulator<InputT, ValueT, AccumT, ImplT> {
final Function<InputT, ValueT> toValue;
final Map<BoundedWindow, WindowedValue<AccumT>> map;
MapBasedWindowedAccumulator(
Function<InputT, ValueT> toValue, Map<BoundedWindow, WindowedValue<AccumT>> map) {
this.toValue = toValue;
this.map = map;
}
@Override
public void add(WindowedValue<InputT> value, SparkCombineFn<InputT, ValueT, AccumT, ?> context)
throws Exception {
for (WindowedValue<InputT> v : value.explodeWindows()) {
SparkCombineContext ctx = context.ctxtForValue(v);
BoundedWindow window = getWindow(v);
TimestampCombiner combiner = context.windowingStrategy.getTimestampCombiner();
Instant windowTimestamp =
combiner.assign(
window,
context.windowingStrategy.getWindowFn().getOutputTime(v.getTimestamp(), window));
map.compute(
window,
(w, windowAccumulator) -> {
final AccumT acc;
final Instant timestamp;
if (windowAccumulator == null) {
acc = context.combineFn.createAccumulator(ctx);
timestamp = windowTimestamp;
} else {
acc = windowAccumulator.getValue();
timestamp = windowAccumulator.getTimestamp();
}
AccumT result = context.combineFn.addInput(acc, toValue(v), ctx);
Instant timestampCombined = combiner.combine(windowTimestamp, timestamp);
return WindowedValue.of(result, timestampCombined, window, PaneInfo.NO_FIRING);
});
}
mergeWindows(context);
}
@Override
public void merge(ImplT other, SparkCombineFn<?, ?, AccumT, ?> context) throws Exception {
other.map.forEach(
(window, acc) -> {
WindowedValue<AccumT> thisAcc = this.map.get(window);
if (thisAcc == null) {
// just copy
this.map.put(window, acc);
} else {
// merge
this.map.put(
window,
WindowedValue.of(
context.combineFn.mergeAccumulators(
Lists.newArrayList(thisAcc.getValue(), acc.getValue()),
context.ctxtForValue(acc)),
context
.windowingStrategy
.getTimestampCombiner()
.combine(acc.getTimestamp(), thisAcc.getTimestamp()),
window,
PaneInfo.NO_FIRING));
}
});
mergeWindows(context);
}
@Override
public Collection<WindowedValue<AccumT>> extractOutput() {
return map.values();
}
@Override
public boolean isEmpty() {
return map.isEmpty();
}
void mergeWindows(SparkCombineFn<?, ?, AccumT, ?> fn) throws Exception {}
private ValueT toValue(WindowedValue<InputT> value) {
try {
return toValue.call(value.getValue());
} catch (Exception ex) {
throw UserCodeException.wrap(ex);
}
}
}
/** Accumulator for non-merging windows. */
static class NonMergingWindowedAccumulator<InputT, ValueT, AccumT>
extends MapBasedWindowedAccumulator<
InputT, ValueT, AccumT, NonMergingWindowedAccumulator<InputT, ValueT, AccumT>> {
static <InputT, ValueT, AccumT> NonMergingWindowedAccumulator<InputT, ValueT, AccumT> create(
Function<InputT, ValueT> toValue) {
return new NonMergingWindowedAccumulator<>(toValue);
}
static <InputT, ValueT, AccumT> NonMergingWindowedAccumulator<InputT, ValueT, AccumT> from(
Function<InputT, ValueT> toValue, Iterable<WindowedValue<AccumT>> values) {
return new NonMergingWindowedAccumulator<>(toValue, values);
}
@SuppressWarnings("unchecked")
private NonMergingWindowedAccumulator(Function<InputT, ValueT> toValue) {
super(toValue, new HashMap<>());
}
private NonMergingWindowedAccumulator(
Function<InputT, ValueT> toValue, Iterable<WindowedValue<AccumT>> values) {
super(toValue, asMap(values, new HashMap<>()));
}
}
/** Accumulator for merging windows. */
static class MergingWindowedAccumulator<InputT, ValueT, AccumT>
extends MapBasedWindowedAccumulator<
InputT, ValueT, AccumT, MergingWindowedAccumulator<InputT, ValueT, AccumT>> {
static <InputT, ValueT, AccumT> MergingWindowedAccumulator<InputT, ValueT, AccumT> create(
Function<InputT, ValueT> toValue, Comparator<BoundedWindow> windowComparator) {
return new MergingWindowedAccumulator<>(toValue, windowComparator);
}
static <InputT, ValueT, AccumT> MergingWindowedAccumulator<InputT, ValueT, AccumT> from(
Function<InputT, ValueT> toValue,
Iterable<WindowedValue<AccumT>> values,
Comparator<BoundedWindow> windowComparator) {
return new MergingWindowedAccumulator<>(toValue, values, windowComparator);
}
@SuppressWarnings("unchecked")
private MergingWindowedAccumulator(
Function<InputT, ValueT> toValue, Comparator<BoundedWindow> windowComparator) {
super(toValue, new TreeMap<>(windowComparator));
}
private MergingWindowedAccumulator(
Function<InputT, ValueT> toValue,
Iterable<WindowedValue<AccumT>> values,
Comparator<BoundedWindow> windowComparator) {
super(toValue, asMap(values, new TreeMap<>(windowComparator)));
}
@Override
void mergeWindows(SparkCombineFn<?, ?, AccumT, ?> fn) throws Exception {
SparkCombineContext ctx = fn.ctxtForWindows(this.map.keySet());
@SuppressWarnings("unchecked")
WindowFn<Object, BoundedWindow> windowFn = (WindowFn) fn.windowingStrategy.getWindowFn();
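// Let the WindowFn merge overlapping windows; for every merge, the callback combines the
// per-window accumulators and timestamps, then replaces the merged entries in the map with
// the single merged result.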
windowFn.mergeWindows(
asMergeContext(
windowFn,
(a, b) -> fn.combineFn.mergeAccumulators(Lists.newArrayList(a, b), ctx),
(toBeMerged, mergeResult) -> {
Instant mergedInstant =
fn.windowingStrategy
.getTimestampCombiner()
.merge(
mergeResult.getKey(),
toBeMerged.stream()
.map(w -> map.get(w).getTimestamp())
.collect(Collectors.toList()));
toBeMerged.forEach(this.map::remove);
this.map.put(
mergeResult.getKey(),
WindowedValue.of(
mergeResult.getValue(),
mergedInstant,
mergeResult.getKey(),
PaneInfo.NO_FIRING));
},
map));
}
private WindowFn<Object, BoundedWindow>.MergeContext asMergeContext(
WindowFn<Object, BoundedWindow> windowFn,
BiFunction<AccumT, AccumT, AccumT> mergeFn,
BiConsumer<Collection<BoundedWindow>, KV<BoundedWindow, AccumT>> afterMerge,
Map<BoundedWindow, WindowedValue<AccumT>> map) {
return windowFn.new MergeContext() {
@Override
public Collection<BoundedWindow> windows() {
return map.keySet();
}
@Override
public void merge(Collection<BoundedWindow> toBeMerged, BoundedWindow mergeResult) {
AccumT accumulator = null;
for (BoundedWindow w : toBeMerged) {
WindowedValue<AccumT> windowAccumulator = Objects.requireNonNull(map.get(w));
if (accumulator == null) {
accumulator = windowAccumulator.getValue();
} else {
accumulator = mergeFn.apply(accumulator, windowAccumulator.getValue());
}
}
afterMerge.accept(toBeMerged, KV.of(mergeResult, accumulator));
}
};
}
@Override
public String toString() {
return "MergingWindowedAccumulator(" + this.map + ")";
}
}
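/**
* {@link Coder} of {@link WindowedAccumulator WindowedAccumulators}. Map-based accumulators are
* encoded as the iterable of their per-window values, single-window accumulators as an optional
* single value.
*
* <p>A minimal round-trip sketch (the {@code coder} and {@code accumulator} names are
* hypothetical):
*
* <pre>{@code
* ByteArrayOutputStream bos = new ByteArrayOutputStream();
* coder.encode(accumulator, bos);
* WindowedAccumulator<?, ?, ?, ?> decoded =
*     coder.decode(new ByteArrayInputStream(bos.toByteArray()));
* }</pre>
*/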
static class WindowedAccumulatorCoder<InputT, ValueT, AccumT>
extends Coder<WindowedAccumulator<InputT, ValueT, AccumT, ?>> {
private final Function<InputT, ValueT> toValue;
private final IterableCoder<WindowedValue<AccumT>> wrap;
private final Coder<WindowedValue<AccumT>> accumCoder;
private final Comparator<BoundedWindow> windowComparator;
private final WindowedAccumulator.Type type;
@SuppressWarnings("unchecked")
WindowedAccumulatorCoder(
Function<InputT, ValueT> toValue,
Coder<BoundedWindow> windowCoder,
Comparator<BoundedWindow> windowComparator,
Coder<AccumT> accumCoder,
WindowedAccumulator.Type type) {
this.toValue = toValue;
this.accumCoder = WindowedValue.FullWindowedValueCoder.of(accumCoder, windowCoder);
this.windowComparator = windowComparator;
this.wrap = IterableCoder.of(this.accumCoder);
this.type = type;
}
@Override
public void encode(WindowedAccumulator<InputT, ValueT, AccumT, ?> value, OutputStream outStream)
throws CoderException, IOException {
if (type.isMapBased()) {
wrap.encode(((MapBasedWindowedAccumulator<?, ?, AccumT, ?>) value).map.values(), outStream);
} else {
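// Single-window accumulators are prefixed with a one-byte marker: 0 if empty, 1 if a value
// follows.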
SingleWindowWindowedAccumulator<?, ?, AccumT> swwa =
(SingleWindowWindowedAccumulator<?, ?, AccumT>) value;
if (swwa.isEmpty()) {
outStream.write(0);
} else {
outStream.write(1);
accumCoder.encode(
WindowedValue.of(
swwa.windowAccumulator, swwa.accTimestamp, swwa.accWindow, PaneInfo.NO_FIRING),
outStream);
}
}
}
@Override
public WindowedAccumulator<InputT, ValueT, AccumT, ?> decode(InputStream inStream)
throws CoderException, IOException {
if (type.isMapBased()) {
return WindowedAccumulator.create(toValue, type, wrap.decode(inStream), windowComparator);
}
boolean empty = inStream.read() == 0;
if (empty) {
return WindowedAccumulator.create(toValue, type, windowComparator);
}
return WindowedAccumulator.create(
toValue, type, Arrays.asList(accumCoder.decode(inStream)), windowComparator);
}
@Override
public List<? extends Coder<?>> getCoderArguments() {
return wrap.getComponents();
}
@Override
public void verifyDeterministic() throws NonDeterministicException {}
}
/** An implementation of {@link CombineWithContext.Context} for the SparkRunner. */
static class SparkCombineContext extends CombineWithContext.Context {
private final PipelineOptions pipelineOptions;
private final SideInputReader sideInputReader;
SparkCombineContext(PipelineOptions pipelineOptions, SideInputReader sideInputReader) {
this.pipelineOptions = pipelineOptions;
this.sideInputReader = sideInputReader;
}
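// Windows of the element currently being processed. Set through forInput(...) before each
// access, so that a single context instance can be reused across elements within a task.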
Collection<? extends BoundedWindow> windows = null;
SparkCombineContext forInput(final Collection<? extends BoundedWindow> windows) {
this.windows = Objects.requireNonNull(windows);
return this;
}
@Override
public PipelineOptions getPipelineOptions() {
return pipelineOptions;
}
@Override
public <T> T sideInput(PCollectionView<T> view) {
// validate element window.
Preconditions.checkState(
windows.size() == 1,
"sideInput can only be called when the main " + "input element is in exactly one window");
return sideInputReader.get(view, windows.iterator().next());
}
}
@VisibleForTesting
static <K, V, AccumT, OutputT> SparkCombineFn<KV<K, V>, V, AccumT, OutputT> keyed(
CombineWithContext.CombineFnWithContext<V, AccumT, OutputT> combineFn,
SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
WindowedAccumulator.Type nonMergingStrategy) {
return new SparkCombineFn<>(
false, KV::getValue, combineFn, options, sideInputs, windowingStrategy, nonMergingStrategy);
}
public static <K, V, AccumT, OutputT> SparkCombineFn<KV<K, V>, V, AccumT, OutputT> keyed(
CombineWithContext.CombineFnWithContext<V, AccumT, OutputT> combineFn,
SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
return new SparkCombineFn<>(
false, KV::getValue, combineFn, options, sideInputs, windowingStrategy);
}
public static <InputT, AccumT, OutputT> SparkCombineFn<InputT, InputT, AccumT, OutputT> globally(
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn,
SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
return new SparkCombineFn<>(true, e -> e, combineFn, options, sideInputs, windowingStrategy);
}
private static <T> Map<BoundedWindow, WindowedValue<T>> asMap(
Iterable<WindowedValue<T>> values, Map<BoundedWindow, WindowedValue<T>> res) {
for (WindowedValue<T> v : values) {
res.put(getWindow(v), v);
}
return res;
}
/**
* Create a comparator for the given window type descriptor. Note that the returned {@link
* Comparator} has to be {@link Serializable}.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
private static Comparator<BoundedWindow> asWindowComparator(
@Nullable TypeDescriptor<?> windowType) {
final Comparator<BoundedWindow> comparator;
if (windowType != null
&& StreamSupport.stream(windowType.getInterfaces().spliterator(), false)
.anyMatch(t -> t.isSubtypeOf(TypeDescriptor.of(Comparable.class)))) {
comparator = (Comparator) Comparator.naturalOrder();
} else {
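// BoundedWindow is not Comparable here, so order windows by their maxTimestamp instead;
// the intersection cast keeps the extracted key function Serializable.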
java.util.function.Function<BoundedWindow, Instant> keyExtractor =
(java.util.function.Function<BoundedWindow, Instant> & Serializable)
BoundedWindow::maxTimestamp;
comparator = Comparator.comparing(keyExtractor);
}
return comparator;
}
private final boolean globalCombine;
private final SerializablePipelineOptions options;
private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
final WindowingStrategy<?, BoundedWindow> windowingStrategy;
private final Function<InputT, ValueT> toValue;
private final WindowedAccumulator.Type defaultNonMergingCombineStrategy;
private final CombineWithContext.CombineFnWithContext<ValueT, AccumT, OutputT> combineFn;
private final Comparator<BoundedWindow> windowComparator;
// each Spark task should get its own copy of this SparkCombineFn, and since Spark tasks
// are single-threaded, it is safe to reuse the context.
// the combine context is not Serializable so we'll use lazy initialization.
// ** DO NOT attempt to turn this into a Singleton as Spark may run multiple tasks in parallel
// in the same JVM (Executor). **
// ** DO NOT use combineContext directly inside this class, use ctxtForValue instead. **
private transient SparkCombineContext combineContext;
SparkCombineFn(
boolean global,
Function<InputT, ValueT> toValue,
CombineWithContext.CombineFnWithContext<ValueT, AccumT, OutputT> combineFn,
SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
this(
global,
toValue,
combineFn,
options,
sideInputs,
windowingStrategy,
WindowedAccumulator.Type.EXPLODE_WINDOWS);
}
@VisibleForTesting
SparkCombineFn(
boolean global,
Function<InputT, ValueT> toValue,
CombineWithContext.CombineFnWithContext<ValueT, AccumT, OutputT> combineFn,
SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
WindowedAccumulator.Type defaultNonMergingCombineStrategy) {
this.globalCombine = global;
this.options = options;
this.sideInputs = sideInputs;
@SuppressWarnings("unchecked")
WindowingStrategy<?, BoundedWindow> castStrategy = (WindowingStrategy) windowingStrategy;
this.windowingStrategy = castStrategy;
this.toValue = toValue;
this.defaultNonMergingCombineStrategy = defaultNonMergingCombineStrategy;
this.combineFn = combineFn;
@SuppressWarnings("unchecked")
TypeDescriptor<BoundedWindow> untyped =
(TypeDescriptor<BoundedWindow>) windowingStrategy.getWindowFn().getWindowTypeDescriptor();
this.windowComparator = asWindowComparator(untyped);
}
/** Create an empty combiner. Implements Spark's {@code zeroValue} for {@code aggregateFn}. */
WindowedAccumulator<InputT, ValueT, AccumT, ?> createCombiner() {
return WindowedAccumulator.create(this, toValue, windowingStrategy, windowComparator);
}
/**
* Implements Spark's createCombiner function in:
*
* <p>{@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}.
*/
WindowedAccumulator<InputT, ValueT, AccumT, ?> createCombiner(WindowedValue<InputT> value) {
try {
WindowedAccumulator<InputT, ValueT, AccumT, ?> accumulator =
WindowedAccumulator.create(this, toValue, windowingStrategy, windowComparator);
accumulator.add(value, this);
return accumulator;
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
/**
* Implements Spark's mergeValue function in:
*
* <p>{@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}.
*/
WindowedAccumulator<InputT, ValueT, AccumT, ?> mergeValue(
WindowedAccumulator<InputT, ValueT, AccumT, ?> accumulator, WindowedValue<InputT> value) {
try {
accumulator.add(value, this);
return accumulator;
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
/**
* Implements Spark's mergeCombiners function in:
*
* <p>{@link org.apache.spark.rdd.PairRDDFunctions#combineByKey}.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
WindowedAccumulator<InputT, ValueT, AccumT, ?> mergeCombiners(
WindowedAccumulator ac1, WindowedAccumulator ac2) {
try {
ac1.merge(ac2, this);
return ac1;
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
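// A minimal sketch of how the three functions above plug into Spark's combineByKey (the
// "rdd" and "fn" names are hypothetical):
//
//   rdd.combineByKey(fn::createCombiner, fn::mergeValue, fn::mergeCombiners);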
Iterable<WindowedValue<OutputT>> extractOutput(WindowedAccumulator<?, ?, AccumT, ?> accumulator) {
return extractOutputStream(accumulator).collect(Collectors.toList());
}
/** Extracts the stream of accumulated values. */
@Internal
public Stream<WindowedValue<OutputT>> extractOutputStream(
WindowedAccumulator<?, ?, AccumT, ?> accumulator) {
return accumulator.extractOutput().stream()
.filter(Objects::nonNull)
.map(
windowAcc ->
windowAcc.withValue(
combineFn.extractOutput(windowAcc.getValue(), ctxtForValue(windowAcc))));
}
WindowedAccumulatorCoder<InputT, ValueT, AccumT> accumulatorCoder(
Coder<BoundedWindow> windowCoder,
Coder<AccumT> accumulatorCoder,
WindowingStrategy<?, ?> windowingStrategy) {
return new WindowedAccumulatorCoder<>(
toValue, windowCoder, windowComparator, accumulatorCoder, getType(windowingStrategy));
}
CombineWithContext.CombineFnWithContext<ValueT, AccumT, OutputT> getCombineFn() {
return combineFn;
}
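/**
* Returns {@code true} when the accumulator type does not track windows itself (the
* single-window types), in which case the window has to become part of the key before
* combining.
*/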
boolean mustBringWindowToKey() {
return !getType(windowingStrategy).isMapBased();
}
private WindowedAccumulator.Type getType(WindowingStrategy<?, ?> windowingStrategy) {
if (windowingStrategy.getWindowFn().isNonMerging()) {
if (globalCombine) {
/* global combine must use a map-based accumulator to incorporate multiple windows */
return WindowedAccumulator.Type.NON_MERGING;
}
if (windowingStrategy.getWindowFn().assignsToOneWindow()
&& GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
return WindowedAccumulator.Type.SINGLE_WINDOW;
}
return defaultNonMergingCombineStrategy;
}
return WindowedAccumulator.Type.MERGING;
}
private static BoundedWindow getWindow(WindowedValue<?> value) {
if (value.isSingleWindowedValue()) {
return ((WindowedValue.SingleWindowedValue) value).getWindow();
}
return Iterables.getOnlyElement(value.getWindows());
}
@SuppressWarnings("unchecked")
SparkCombineContext ctxtForValue(WindowedValue<?> input) {
return ctxtForWindows((Collection) input.getWindows());
}
SparkCombineContext ctxtForWindows(Collection<BoundedWindow> windows) {
if (combineContext == null) {
combineContext = new SparkCombineContext(options.get(), new SparkSideInputReader(sideInputs));
}
return combineContext.forInput(windows);
}
}