blob: cf71efb815d0b75ce61b6673a8089005634b977c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
/**
* The EvaluationContext allows us to define pipeline instructions and translate between {@code
* PObject<T>}s or {@code PCollection<T>}s and Ts or DStreams/RDDs of Ts.
*/
public class EvaluationContext {
  /** Spark context supplied at construction time. */
  private final JavaSparkContext jsc;

  /** Streaming context; only assigned by the four-argument constructor, otherwise {@code null}. */
  private JavaStreamingContext jssc;

  private final Pipeline pipeline;

  /** Datasets registered through {@code putDataset}, keyed by the {@link PValue} they back. */
  private final Map<PValue, Dataset> datasets = new LinkedHashMap<>();

  /**
   * NOTE(review): nothing in this class ever adds an entry to this map, so the lookup performed on
   * it by {@link #get(PValue)} can never succeed. It is kept as-is to preserve current behavior;
   * confirm whether {@link #datasets} was the intended source.
   */
  private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>();

  /**
   * Datasets that were registered but not yet handed out via {@code borrowDataset} or {@link
   * #getWindowedValues(PCollection)}; {@link #computeOutputs()} triggers an action on each of them.
   */
  private final Set<Dataset> leaves = new LinkedHashSet<>();

  /** Values already resolved by {@link #get(PValue)}. */
  private final Map<PValue, Object> pobjects = new LinkedHashMap<>();

  /** Transform currently being translated; may be {@code null} between translations. */
  private AppliedPTransform<?, ?, ?> currentTransform;

  private final SparkPCollectionView pviews = new SparkPCollectionView();

  /**
   * Counters consulted by {@link #shouldCache(PTransform, PValue)}. The map is filled by callers
   * through {@link #getCacheCandidates()}; presumably each value is the number of times the
   * PCollection is consumed in the pipeline (verify against the code that populates it).
   */
  private final Map<PCollection, Long> cacheCandidates = new HashMap<>();

  private final PipelineOptions options;
  private final SerializablePipelineOptions serializableOptions;

  /**
   * Creates a context without a streaming context.
   *
   * @param jsc Spark context to expose through {@link #getSparkContext()}
   * @param pipeline pipeline to expose through {@link #getPipeline()}
   * @param options pipeline options; also wrapped into a {@link SerializablePipelineOptions}
   */
  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options) {
    this.jsc = jsc;
    this.pipeline = pipeline;
    this.options = options;
    this.serializableOptions = new SerializablePipelineOptions(options);
  }

  /**
   * Creates a context that additionally carries a streaming context.
   *
   * @param jsc Spark context to expose through {@link #getSparkContext()}
   * @param pipeline pipeline to expose through {@link #getPipeline()}
   * @param options pipeline options; also wrapped into a {@link SerializablePipelineOptions}
   * @param jssc streaming context to expose through {@link #getStreamingContext()}
   */
  public EvaluationContext(
      JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options, JavaStreamingContext jssc) {
    this(jsc, pipeline, options);
    this.jssc = jssc;
  }

  public JavaSparkContext getSparkContext() {
    return jsc;
  }

  /** Returns the streaming context, or {@code null} if none was passed to the constructor. */
  public JavaStreamingContext getStreamingContext() {
    return jssc;
  }

  public Pipeline getPipeline() {
    return pipeline;
  }

  public PipelineOptions getOptions() {
    return options;
  }

  public SerializablePipelineOptions getSerializableOptions() {
    return serializableOptions;
  }

  public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
    this.currentTransform = transform;
  }

  public AppliedPTransform<?, ?, ?> getCurrentTransform() {
    return currentTransform;
  }

  /**
   * Returns the single non-additional input of the current transform.
   *
   * @param transform used only to infer the returned type; it is not compared with the current
   *     transform
   * @param <T> type of the input
   * @return the only non-additional input of the current transform
   * @throws IllegalArgumentException if no current transform is set
   */
  public <T extends PValue> T getInput(PTransform<T, ?> transform) {
    // Same guard as getInputs/getOutputs, so a missing current transform fails with a clear message.
    checkArgument(currentTransform != null, "can only be called with non-null currentTransform");
    @SuppressWarnings("unchecked")
    T input = (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform));
    return input;
  }

  /**
   * Returns all inputs of the current transform.
   *
   * @param transform must be the very same instance as the current transform
   * @return the inputs of the current transform, keyed by tag
   * @throws IllegalArgumentException if no current transform is set or {@code transform} is not it
   */
  public <T> Map<TupleTag<?>, PValue> getInputs(PTransform<?, ?> transform) {
    checkArgument(currentTransform != null, "can only be called with non-null currentTransform");
    checkArgument(
        currentTransform.getTransform() == transform, "can only be called with current transform");
    return currentTransform.getInputs();
  }

  /**
   * Returns the single output of the current transform.
   *
   * @param transform must be the very same instance as the current transform
   * @param <T> type of the output
   * @return the only output of the current transform
   * @throws IllegalArgumentException if no current transform is set or {@code transform} is not it
   */
  public <T extends PValue> T getOutput(PTransform<?, T> transform) {
    @SuppressWarnings("unchecked")
    T output = (T) Iterables.getOnlyElement(getOutputs(transform).values());
    return output;
  }

  /**
   * Returns all outputs of the current transform.
   *
   * @param transform must be the very same instance as the current transform
   * @return the outputs of the current transform, keyed by tag
   * @throws IllegalArgumentException if no current transform is set or {@code transform} is not it
   */
  public Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) {
    checkArgument(currentTransform != null, "can only be called with non-null currentTransform");
    checkArgument(
        currentTransform.getTransform() == transform, "can only be called with current transform");
    return currentTransform.getOutputs();
  }

  /**
   * Returns the coders of those outputs of the current transform that are {@link PCollection}s;
   * outputs of any other kind are skipped.
   *
   * @return coder per output tag
   * @throws IllegalArgumentException if no current transform is set
   */
  public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
    // Same guard as getInputs/getOutputs instead of a bare NullPointerException.
    checkArgument(currentTransform != null, "can only be called with non-null currentTransform");
    return currentTransform.getOutputs().entrySet().stream()
        .filter(e -> e.getValue() instanceof PCollection)
        .collect(
            Collectors.toMap(Map.Entry::getKey, e -> ((PCollection<?>) e.getValue()).getCoder()));
  }

  /**
   * Decides whether the given output should be cached. All of the following must hold: caching is
   * not disabled ({@code SparkPipelineOptions.isCacheDisabled} is false), the transform is not a
   * {@link GroupByKey}, and the value is a {@link PCollection} whose entry in the cache candidates
   * is greater than one.
   *
   * <p>PCollection is not cached in GroupByKey transformation, because Spark automatically persists
   * some intermediate data in shuffle operations, even without users calling persist.
   *
   * @param transform the transform that produced {@code pvalue}; may be {@code null} when the
   *     producing transform is not known to the caller
   * @param pvalue output of transform
   * @return whether the PCollection will be cached
   */
  public boolean shouldCache(@Nullable PTransform<?, ? extends PValue> transform, PValue pvalue) {
    if (serializableOptions.get().as(SparkPipelineOptions.class).isCacheDisabled()
        || transform instanceof GroupByKey) {
      return false;
    }
    return pvalue instanceof PCollection && cacheCandidates.getOrDefault(pvalue, 0L) > 1;
  }

  /**
   * Add single output of transform to context map and possibly cache it if it satisfies {@link
   * #shouldCache(PTransform, PValue)}.
   *
   * @param transform from which Dataset was created; must be the current transform
   * @param dataset created Dataset from transform
   */
  public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
    putDataset(transform, getOutput(transform), dataset);
  }

  /**
   * Add output of transform to context map and possibly cache it if it satisfies {@link
   * #shouldCache(PTransform, PValue)}. Used when PTransform has multiple outputs. No transform is
   * passed to the cache decision, so the GroupByKey exclusion does not apply on this path.
   *
   * @param pvalue one of multiple outputs of transform
   * @param dataset created Dataset from transform
   */
  public void putDataset(PValue pvalue, Dataset dataset) {
    putDataset(null, pvalue, dataset);
  }

  /**
   * Registers the dataset for the given output, marks it as a leaf and possibly caches it if it
   * satisfies {@link #shouldCache(PTransform, PValue)}.
   *
   * @param transform from which Dataset was created, or {@code null} if not known
   * @param pvalue output of transform
   * @param dataset created Dataset from transform
   */
  private void putDataset(
      @Nullable PTransform<?, ? extends PValue> transform, PValue pvalue, Dataset dataset) {
    try {
      dataset.setName(pvalue.getName());
    } catch (IllegalStateException ignored) {
      // name not set, ignore
    }
    if (shouldCache(transform, pvalue)) {
      // shouldCache only returns true for PCollections, so the casts below are safe.
      PCollection<?> pcollection = (PCollection<?>) pvalue;
      Coder<?> coder = pcollection.getCoder();
      Coder<? extends BoundedWindow> wCoder =
          pcollection.getWindowingStrategy().getWindowFn().windowCoder();
      dataset.cache(storageLevel(), WindowedValue.getFullCoder(coder, wCoder));
    }
    datasets.put(pvalue, dataset);
    leaves.add(dataset);
  }

  /**
   * Returns the dataset backing the single non-additional input of the current transform and stops
   * treating it as a leaf.
   *
   * @param transform used only to infer the input type, see {@link #getInput(PTransform)}
   * @return the registered dataset, or {@code null} if none was registered for that input
   */
  public Dataset borrowDataset(PTransform<? extends PValue, ?> transform) {
    return borrowDataset(getInput(transform));
  }

  /**
   * Returns the dataset registered for the given value and stops treating it as a leaf.
   *
   * @param pvalue value whose dataset is requested
   * @return the registered dataset, or {@code null} if none was registered for {@code pvalue}
   */
  public Dataset borrowDataset(PValue pvalue) {
    Dataset dataset = datasets.get(pvalue);
    leaves.remove(dataset);
    return dataset;
  }

  /**
   * Computes the outputs for all RDDs that are leaves in the DAG and do not have any actions (like
   * saving to a file) registered on them (i.e. they are performed for side effects).
   */
  public void computeOutputs() {
    for (Dataset dataset : leaves) {
      dataset.action(); // force computation.
    }
  }

  /**
   * Retrieve an object of Type T associated with the PValue passed in.
   *
   * <p>NOTE(review): the only map that can currently hold a match is {@code pobjects}, and it is
   * only filled from the {@code pcollections} branch below. As {@code pcollections} is never
   * populated by this class, this method presently always ends in the {@link
   * IllegalStateException}.
   *
   * @param value PValue to retrieve associated data for.
   * @param <T> Type of object to return.
   * @return Native object.
   * @throws IllegalStateException if no object can be resolved for {@code value}
   */
  @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"})
  public <T> T get(PValue value) {
    if (pobjects.containsKey(value)) {
      return (T) pobjects.get(value);
    }
    if (pcollections.containsKey(value)) {
      JavaRDD<?> rdd = ((BoundedDataset) pcollections.get(value)).getRDD();
      T res = (T) Iterables.getOnlyElement(rdd.collect());
      // Remember the resolved value so the RDD is collected only once per PValue.
      pobjects.put(value, res);
      return res;
    }
    throw new IllegalStateException("Cannot resolve un-known PObject: " + value);
  }

  /**
   * Return the views created so far in the pipeline.
   *
   * @return SparkPCollectionView
   */
  public SparkPCollectionView getPViews() {
    return pviews;
  }

  /**
   * Adds/Replaces a view in the views created so far in the pipeline.
   *
   * @param view - Identifier of the view
   * @param value - Actual value of the view
   * @param coder - Coder of the value
   */
  public void putPView(
      PCollectionView<?> view,
      Iterable<WindowedValue<?>> value,
      Coder<Iterable<WindowedValue<?>>> coder) {
    pviews.putPView(view, value, coder);
  }

  /**
   * Get the map of cache candidates held by the evaluation context. The returned map is the live
   * internal instance, not a copy, so modifications affect {@link #shouldCache(PTransform,
   * PValue)}.
   *
   * @return The current {@link Map} of cache candidates.
   */
  public Map<PCollection, Long> getCacheCandidates() {
    return this.cacheCandidates;
  }

  /**
   * Returns the values of the bounded dataset registered for the given PCollection and stops
   * treating that dataset as a leaf.
   *
   * @param pcollection PCollection whose values are requested
   * @param <T> element type
   * @return the windowed values of the dataset
   * @throws IllegalStateException if no dataset was registered for {@code pcollection}
   * @throws ClassCastException if the registered dataset is not a {@code BoundedDataset}
   */
  <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
    Dataset dataset = datasets.get(pcollection);
    if (dataset == null) {
      // Fail with the offending PCollection in the message rather than a bare NPE further down.
      throw new IllegalStateException("No dataset registered for PCollection: " + pcollection);
    }
    @SuppressWarnings("unchecked")
    BoundedDataset<T> boundedDataset = (BoundedDataset<T>) dataset;
    leaves.remove(boundedDataset);
    return boundedDataset.getValues(pcollection);
  }

  /** Returns the storage level configured in {@link SparkPipelineOptions}. */
  public String storageLevel() {
    return serializableOptions.get().as(SparkPipelineOptions.class).getStorageLevel();
  }
}