/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;
import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.createOutputMap;
import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ReadTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.fnexecution.wire.WireCoders;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.io.SourceRDD;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.BiMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Sets;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;
/** Translates a bounded portable pipeline into a Spark job. */
public class SparkBatchPortablePipelineTranslator {
private final ImmutableMap<String, PTransformTranslator> urnToTransformTranslator;
interface PTransformTranslator {
/** Translates transformNode from Beam into the Spark context. */
void translate(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context);
}
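  /** Returns the set of URNs this translator advertises as supported. */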
public Set<String> knownUrns() {
// Do not expose Read as a known URN because we only want to support Read
// through the Java ExpansionService. We can't translate Reads for other
// languages.
return Sets.difference(
urnToTransformTranslator.keySet(),
ImmutableSet.of(PTransformTranslation.READ_TRANSFORM_URN));
}
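  /**
   * Creates a translator with the default set of URN translators registered.
   *
   * <p>A rough usage sketch (obtaining the {@code RunnerApi.Pipeline} proto and constructing the
   * {@link SparkTranslationContext} are runner-internal steps assumed here):
   *
   * <pre>{@code
   * SparkBatchPortablePipelineTranslator translator = new SparkBatchPortablePipelineTranslator();
   * translator.translate(pipelineProto, context);
   * }</pre>
   */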
public SparkBatchPortablePipelineTranslator() {
ImmutableMap.Builder<String, PTransformTranslator> translatorMap = ImmutableMap.builder();
translatorMap.put(
PTransformTranslation.IMPULSE_TRANSFORM_URN,
SparkBatchPortablePipelineTranslator::translateImpulse);
translatorMap.put(
PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
SparkBatchPortablePipelineTranslator::translateGroupByKey);
translatorMap.put(
ExecutableStage.URN, SparkBatchPortablePipelineTranslator::translateExecutableStage);
translatorMap.put(
PTransformTranslation.FLATTEN_TRANSFORM_URN,
SparkBatchPortablePipelineTranslator::translateFlatten);
translatorMap.put(
PTransformTranslation.READ_TRANSFORM_URN,
SparkBatchPortablePipelineTranslator::translateRead);
this.urnToTransformTranslator = translatorMap.build();
}
  /**
   * Translates the given portable pipeline into the Spark context, visiting its transforms in
   * topological order.
   */
public void translate(final RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
QueryablePipeline p =
QueryablePipeline.forTransforms(
pipeline.getRootTransformIdsList(), pipeline.getComponents());
for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
urnToTransformTranslator
.getOrDefault(
transformNode.getTransform().getSpec().getUrn(),
SparkBatchPortablePipelineTranslator::urnNotFound)
.translate(transformNode, pipeline, context);
}
}
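  /** Fails translation when a transform's URN has no registered translator. */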
private static void urnNotFound(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
throw new IllegalArgumentException(
String.format(
"Transform %s has unknown URN %s",
transformNode.getId(), transformNode.getTransform().getSpec().getUrn()));
}
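  /** Translates an Impulse transform into a dataset holding a single empty byte array. */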
private static void translateImpulse(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
BoundedDataset<byte[]> output =
new BoundedDataset<>(
Collections.singletonList(new byte[0]), context.getSparkContext(), ByteArrayCoder.of());
context.pushDataset(getOutputId(transformNode), output);
}
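  /**
   * Translates a GroupByKey transform. For non-merging windows with the default timestamp
   * combiner, a memory-sensitive group-by-key-and-window translation is used; otherwise the input
   * is grouped by key only and then grouped by window using an in-memory GroupAlsoByWindow.
   */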
private static <K, V> void translateGroupByKey(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
RunnerApi.Components components = pipeline.getComponents();
String inputId = getInputId(transformNode);
PCollection inputPCollection = components.getPcollectionsOrThrow(inputId);
Dataset inputDataset = context.popDataset(inputId);
JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
PCollectionNode inputPCollectionNode = PipelineNode.pCollection(inputId, inputPCollection);
WindowedValueCoder<KV<K, V>> inputCoder;
try {
inputCoder =
(WindowedValueCoder)
WireCoders.instantiateRunnerWireCoder(inputPCollectionNode, components);
} catch (IOException e) {
throw new RuntimeException(e);
}
KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
WindowingStrategy windowingStrategy = getWindowingStrategy(inputId, components);
WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
WindowedValue.WindowedValueCoder<V> wvCoder =
WindowedValue.FullWindowedValueCoder.of(inputValueCoder, windowFn.windowCoder());
JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedByKeyAndWindow;
if (windowingStrategy.getWindowFn().isNonMerging()
&& windowingStrategy.getTimestampCombiner() == TimestampCombiner.END_OF_WINDOW) {
      // We can use a memory-sensitive translation for non-merging windows.
groupedByKeyAndWindow =
GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
inputRdd, inputKeyCoder, inputValueCoder, windowingStrategy);
} else {
Partitioner partitioner = getPartitioner(context);
JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKeyOnly =
GroupCombineFunctions.groupByKeyOnly(inputRdd, inputKeyCoder, wvCoder, partitioner);
// for batch, GroupAlsoByWindow uses an in-memory StateInternals.
groupedByKeyAndWindow =
groupedByKeyOnly.flatMap(
new SparkGroupAlsoByWindowViaOutputBufferFn<>(
windowingStrategy,
new TranslationUtils.InMemoryStateInternalsFactory<>(),
SystemReduceFn.buffering(inputValueCoder),
context.serializablePipelineOptions,
AggregatorsAccumulator.getInstance()));
}
context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(groupedByKeyAndWindow));
}
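  /**
   * Translates an ExecutableStage by applying a {@link SparkExecutableStageFunction} over the
   * partitions of the input RDD. Side inputs are collected and broadcast up front, and each
   * runner-visible output of the stage is extracted from the union-tagged result into its own
   * dataset. If the stage has no runner-visible outputs, an empty sink is registered so Spark
   * still evaluates the stage.
   */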
private static <InputT, OutputT, SideInputT> void translateExecutableStage(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
RunnerApi.ExecutableStagePayload stagePayload;
try {
stagePayload =
RunnerApi.ExecutableStagePayload.parseFrom(
transformNode.getTransform().getSpec().getPayload());
} catch (IOException e) {
throw new RuntimeException(e);
}
String inputPCollectionId = stagePayload.getInput();
Dataset inputDataset = context.popDataset(inputPCollectionId);
JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
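    // Assign an integer union tag to each output PCollection; the stage emits RawUnionValues
    // carrying these tags, which are split back into per-output datasets below.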
ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
broadcastVariablesBuilder = ImmutableMap.builder();
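    // Materialize and broadcast each side input so the executable stage can read it on workers.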
for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
RunnerApi.Components components = stagePayload.getComponents();
String collectionId =
components
.getTransformsOrThrow(sideInputId.getTransformId())
.getInputsOrThrow(sideInputId.getLocalName());
Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
broadcastSideInput(collectionId, components, context);
broadcastVariablesBuilder.put(collectionId, tuple2);
}
SparkExecutableStageFunction<InputT, SideInputT> function =
new SparkExecutableStageFunction<>(
stagePayload,
context.jobInfo,
outputMap,
broadcastVariablesBuilder.build(),
MetricsAccumulator.getInstance());
JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
for (String outputId : outputs.values()) {
JavaRDD<WindowedValue<OutputT>> outputRdd =
staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
}
if (outputs.isEmpty()) {
// After pipeline translation, we traverse the set of unconsumed PCollections and add a
// no-op sink to each to make sure they are materialized by Spark. However, some SDK-executed
// stages have no runner-visible output after fusion. We handle this case by adding a sink
// here.
JavaRDD<WindowedValue<OutputT>> outputRdd =
staged.flatMap((rawUnionValue) -> Collections.emptyIterator());
context.pushDataset(
String.format("EmptyOutputSink_%d", context.nextSinkId()),
new BoundedDataset<>(outputRdd));
}
}
/**
   * Collect and serialize the data and then broadcast the result. <b>This can be expensive.</b>
*
* @return Spark broadcast variable and coder to decode its contents
*/
private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
PCollection collection = components.getPcollectionsOrThrow(collectionId);
@SuppressWarnings("unchecked")
BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
WindowedValueCoder<T> coder;
try {
coder =
(WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
} catch (IOException e) {
throw new RuntimeException(e);
}
List<byte[]> bytes = dataset.getBytes(coder);
Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
return new Tuple2<>(broadcast, coder);
}
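  /**
   * Translates a Flatten transform into a union of its input RDDs, or into an empty RDD when the
   * transform has no inputs.
   */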
private static <T> void translateFlatten(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
Map<String, String> inputsMap = transformNode.getTransform().getInputsMap();
JavaRDD<WindowedValue<T>> unionRDD;
if (inputsMap.isEmpty()) {
unionRDD = context.getSparkContext().emptyRDD();
} else {
JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[inputsMap.size()];
int index = 0;
for (String inputId : inputsMap.values()) {
rdds[index] = ((BoundedDataset<T>) context.popDataset(inputId)).getRDD();
index++;
}
unionRDD = context.getSparkContext().union(rdds);
}
context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(unionRDD));
}
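  /**
   * Translates a bounded Read by deserializing its {@link BoundedSource} and wrapping it in a
   * {@link SourceRDD.Bounded}; see {@link #knownUrns()} for why Read is not advertised.
   */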
private static <T> void translateRead(
PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
String stepName = transformNode.getTransform().getUniqueName();
final JavaSparkContext jsc = context.getSparkContext();
BoundedSource boundedSource;
try {
boundedSource =
ReadTranslation.boundedSourceFromProto(
RunnerApi.ReadPayload.parseFrom(transformNode.getTransform().getSpec().getPayload()));
} catch (IOException e) {
throw new RuntimeException("Failed to extract BoundedSource from ReadPayload.", e);
}
// create an RDD from a BoundedSource.
JavaRDD<WindowedValue<T>> input =
new SourceRDD.Bounded<>(
jsc.sc(), boundedSource, context.serializablePipelineOptions, stepName)
.toJavaRDD();
context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(input));
}
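  /**
   * Returns a {@link HashPartitioner} based on the Spark context's default parallelism, or
   * {@code null} (no explicit partitioner) when a bundle size is configured via
   * {@link SparkPipelineOptions#getBundleSize()}.
   */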
@Nullable
private static Partitioner getPartitioner(SparkTranslationContext context) {
Long bundleSize =
context.serializablePipelineOptions.get().as(SparkPipelineOptions.class).getBundleSize();
return (bundleSize > 0)
? null
: new HashPartitioner(context.getSparkContext().defaultParallelism());
}
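  /** Returns the id of the transform's single input PCollection. */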
private static String getInputId(PTransformNode transformNode) {
return Iterables.getOnlyElement(transformNode.getTransform().getInputsMap().values());
}
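  /** Returns the id of the transform's single output PCollection. */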
private static String getOutputId(PTransformNode transformNode) {
return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
}
}