/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;
import com.google.common.base.Optional;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
/** A set of group/combine functions to apply to Spark {@link org.apache.spark.rdd.RDD}s. */
public class GroupCombineFunctions {
/**
* An implementation of {@link
* org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly} for the Spark runner.
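*
* <p>A minimal usage sketch; the Spark context, input RDD, coders, and window below are
* illustrative placeholders, not part of this class:
*
* <pre>{@code
* JavaSparkContext jsc = ...; // an existing JavaSparkContext
* JavaRDD<WindowedValue<KV<String, Integer>>> input =
*     jsc.parallelize(
*         Arrays.asList(
*             WindowedValue.valueInGlobalWindow(KV.of("a", 1)),
*             WindowedValue.valueInGlobalWindow(KV.of("a", 2))));
* WindowedValueCoder<Integer> wvCoder =
*     WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);
* JavaRDD<WindowedValue<KV<String, Iterable<WindowedValue<Integer>>>>> grouped =
*     GroupCombineFunctions.groupByKeyOnly(input, StringUtf8Coder.of(), wvCoder);
* }</pre>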
*/
public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupByKeyOnly(
JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
// Use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
JavaPairRDD<ByteArray, byte[]> pairRDD =
rdd.map(new ReifyTimestampsAndWindowsFunction<>())
.map(WindowingHelpers.unwindowFunction())
.mapToPair(TranslationUtils.toPairFunction())
.mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));
// Use a HashPartitioner with the default parallelism.
Partitioner partitioner = new HashPartitioner(rdd.rdd().sparkContext().defaultParallelism());
// Using mapPartitions allows us to preserve the partitioner
// and avoid an unnecessary shuffle downstream.
return pairRDD
.groupByKey(partitioner)
.mapPartitionsToPair(
TranslationUtils.pairFunctionToPairFlatMapFunction(
CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)),
true)
.mapPartitions(TranslationUtils.fromPairFlatMapFunction(), true)
.mapPartitions(
TranslationUtils.functionToFlatMapFunction(WindowingHelpers.windowFunction()), true);
}
/**
* Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation.
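*
* <p>A minimal usage sketch; {@code input}, {@code sparkCombineFn}, and {@code windowingStrategy}
* are illustrative placeholders assumed to exist in the surrounding translation context:
*
* <pre>{@code
* JavaRDD<WindowedValue<Long>> input = ...;
* SparkGlobalCombineFn<Long, Long, Long> sparkCombineFn = ...; // wraps a Beam CombineFn
* Optional<Iterable<WindowedValue<Long>>> accumulated =
*     GroupCombineFunctions.combineGlobally(
*         input, sparkCombineFn, VarLongCoder.of(), VarLongCoder.of(), windowingStrategy);
* if (accumulated.isPresent()) {
*   // Non-empty input: accumulated.get() holds one accumulator per window.
* }
* }</pre>
*/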
public static <InputT, AccumT> Optional<Iterable<WindowedValue<AccumT>>> combineGlobally(
JavaRDD<WindowedValue<InputT>> rdd,
final SparkGlobalCombineFn<InputT, AccumT, ?> sparkCombineFn,
final Coder<InputT> iCoder,
final Coder<AccumT> aCoder,
final WindowingStrategy<?, ?> windowingStrategy) {
// coders.
final WindowedValue.FullWindowedValueCoder<InputT> wviCoder =
WindowedValue.FullWindowedValueCoder.of(
iCoder, windowingStrategy.getWindowFn().windowCoder());
final WindowedValue.FullWindowedValueCoder<AccumT> wvaCoder =
WindowedValue.FullWindowedValueCoder.of(
aCoder, windowingStrategy.getWindowFn().windowCoder());
final IterableCoder<WindowedValue<AccumT>> iterAccumCoder = IterableCoder.of(wvaCoder);
// Use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
// For readability, we add comments with the actual type next to byte[].
// To shorten line length, we use the following abbreviations:
//---- WV: WindowedValue
//---- Itr: Iterable
//---- A: AccumT
//---- I: InputT
JavaRDD<byte[]> inputRDDBytes = rdd.map(CoderHelpers.toByteFunction(wviCoder));
if (inputRDDBytes.isEmpty()) {
return Optional.absent();
}
/*Itr<WV<A>>*/ byte[] accumulatedBytes =
    inputRDDBytes.aggregate(
        CoderHelpers.toByteArray(sparkCombineFn.zeroValue(), iterAccumCoder),
        (/*Itr<WV<A>>*/ ab, /*WV<I>*/ ib) -> {
          /*Itr<WV<A>>*/ Iterable<WindowedValue<AccumT>> a =
              CoderHelpers.fromByteArray(ab, iterAccumCoder);
          /*WV<I>*/ WindowedValue<InputT> i = CoderHelpers.fromByteArray(ib, wviCoder);
          return CoderHelpers.toByteArray(sparkCombineFn.seqOp(a, i), iterAccumCoder);
        },
        (/*Itr<WV<A>>*/ a1b, /*Itr<WV<A>>*/ a2b) -> {
          /*Itr<WV<A>>*/ Iterable<WindowedValue<AccumT>> a1 =
              CoderHelpers.fromByteArray(a1b, iterAccumCoder);
          /*Itr<WV<A>>*/ Iterable<WindowedValue<AccumT>> a2 =
              CoderHelpers.fromByteArray(a2b, iterAccumCoder);
          /*Itr<WV<A>>*/ Iterable<WindowedValue<AccumT>> merged = sparkCombineFn.combOp(a1, a2);
          return CoderHelpers.toByteArray(merged, iterAccumCoder);
        });
return Optional.of(CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder));
}
/**
* Apply a composite {@link org.apache.beam.sdk.transforms.Combine.PerKey} transformation.
*
* <p>This aggregation will apply Beam's {@link org.apache.beam.sdk.transforms.Combine.CombineFn}
* via Spark's {@link JavaPairRDD#combineByKey(Function, Function2, Function2)} aggregation. For
* streaming, this will be called from within a serialized context (a DStream's transform
* callback), so the arguments passed to it need to be Serializable.
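*
* <p>A minimal usage sketch; {@code input}, {@code sparkCombineFn}, and {@code windowingStrategy}
* are illustrative placeholders assumed to exist in the surrounding translation context:
*
* <pre>{@code
* JavaRDD<WindowedValue<KV<String, Integer>>> input = ...;
* SparkKeyedCombineFn<String, Integer, Long, Long> sparkCombineFn = ...;
* JavaPairRDD<String, Iterable<WindowedValue<KV<String, Long>>>> accumulatedPerKey =
*     GroupCombineFunctions.combinePerKey(
*         input, sparkCombineFn, StringUtf8Coder.of(), VarIntCoder.of(), VarLongCoder.of(),
*         windowingStrategy);
* }</pre>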
*/
public static <K, InputT, AccumT>
JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> combinePerKey(
JavaRDD<WindowedValue<KV<K, InputT>>> rdd,
final SparkKeyedCombineFn<K, InputT, AccumT, ?> sparkCombineFn,
final Coder<K> keyCoder,
final Coder<InputT> iCoder,
final Coder<AccumT> aCoder,
final WindowingStrategy<?, ?> windowingStrategy) {
// coders.
final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder =
WindowedValue.FullWindowedValueCoder.of(
KvCoder.of(keyCoder, iCoder), windowingStrategy.getWindowFn().windowCoder());
final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder =
WindowedValue.FullWindowedValueCoder.of(
KvCoder.of(keyCoder, aCoder), windowingStrategy.getWindowFn().windowCoder());
final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = IterableCoder.of(wkvaCoder);
// We need to duplicate K as both the key of the JavaPairRDD as well as inside the value,
// since the functions passed to combineByKey don't receive the associated key of each
// value, and we need to map back into methods in Combine.KeyedCombineFn, which each
// require the key in addition to the InputT's and AccumT's being merged/accumulated.
// Once Spark provides a way to include keys in the arguments of combine/merge functions,
// we won't need to duplicate the keys anymore.
// The key has to be windowed in order to group by window as well.
JavaPairRDD<K, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair =
rdd.mapToPair(TranslationUtils.toPairByKeyInWindowedValue());
// Use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
// For readability, we add comments with the actual type next to byte[].
// To shorten line length, we use the following abbreviations:
//---- WV: WindowedValue
//---- Itr: Iterable
//---- A: AccumT
//---- I: InputT
JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes =
inRddDuplicatedKeyPair.mapToPair(CoderHelpers.toByteFunction(keyCoder, wkviCoder));
JavaPairRDD</*K*/ ByteArray, /*Itr<WV<KV<K, A>>>*/ byte[]> accumulatedBytes =
    inRddDuplicatedKeyPairBytes.combineByKey(
        (/*WV<KV<K, I>>*/ input) -> {
          /*WV<KV<K, I>>*/ WindowedValue<KV<K, InputT>> wkvi =
              CoderHelpers.fromByteArray(input, wkviCoder);
          return CoderHelpers.toByteArray(sparkCombineFn.createCombiner(wkvi), iterAccumCoder);
        },
        (/*Itr<WV<KV<K, A>>>*/ acc, /*WV<KV<K, I>>*/ input) -> {
          /*Itr<WV<KV<K, A>>>*/ Iterable<WindowedValue<KV<K, AccumT>>> wkvas =
              CoderHelpers.fromByteArray(acc, iterAccumCoder);
          /*WV<KV<K, I>>*/ WindowedValue<KV<K, InputT>> wkvi =
              CoderHelpers.fromByteArray(input, wkviCoder);
          return CoderHelpers.toByteArray(
              sparkCombineFn.mergeValue(wkvi, wkvas), iterAccumCoder);
        },
        (/*Itr<WV<KV<K, A>>>*/ acc1, /*Itr<WV<KV<K, A>>>*/ acc2) -> {
          /*Itr<WV<KV<K, A>>>*/ Iterable<WindowedValue<KV<K, AccumT>>> wkvas1 =
              CoderHelpers.fromByteArray(acc1, iterAccumCoder);
          /*Itr<WV<KV<K, A>>>*/ Iterable<WindowedValue<KV<K, AccumT>>> wkvas2 =
              CoderHelpers.fromByteArray(acc2, iterAccumCoder);
          return CoderHelpers.toByteArray(
              sparkCombineFn.mergeCombiners(wkvas1, wkvas2), iterAccumCoder);
        });
return accumulatedBytes.mapToPair(CoderHelpers.fromByteFunction(keyCoder, iterAccumCoder));
}
/**
* An implementation of {@link Reshuffle} for the Spark runner.
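*
* <p>A minimal usage sketch, reusing the illustrative {@code input} and {@code wvCoder} from the
* {@link #groupByKeyOnly} example; the elements are unchanged, only their partitioning is:
*
* <pre>{@code
* JavaRDD<WindowedValue<KV<String, Integer>>> reshuffled =
*     GroupCombineFunctions.reshuffle(input, StringUtf8Coder.of(), wvCoder);
* }</pre>
*/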
public static <K, V> JavaRDD<WindowedValue<KV<K, V>>> reshuffle(
JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
// Use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
return rdd.map(new ReifyTimestampsAndWindowsFunction<>())
.map(WindowingHelpers.unwindowFunction())
.mapToPair(TranslationUtils.toPairFunction())
.mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
.repartition(rdd.getNumPartitions())
.mapToPair(CoderHelpers.fromByteFunction(keyCoder, wvCoder))
.map(TranslationUtils.fromPairFunction())
.map(TranslationUtils.toKVByWindowInValue());
}
}