/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.spark.structuredstreaming.translation;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.ForeachWriter;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.DataStreamWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Base class that provides the context for {@link PTransform} translation: it keeps track of the
* datasets, the {@link SparkSession} and the current transform being translated.
*/
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public abstract class AbstractTranslationContext {
private static final Logger LOG = LoggerFactory.getLogger(AbstractTranslationContext.class);
/** All the datasets of the DAG. */
private final Map<PValue, Dataset<?>> datasets;
/** Datasets that are not used as input to other datasets (leaves of the DAG). */
private final Set<Dataset<?>> leaves;
private final SerializablePipelineOptions serializablePipelineOptions;
@SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
private AppliedPTransform<?, ?, ?> currentTransform;
@SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
private final SparkSession sparkSession;
private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
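/** Creates a translation context whose {@link SparkSession} is obtained from the given options. */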
public AbstractTranslationContext(SparkStructuredStreamingPipelineOptions options) {
this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
this.serializablePipelineOptions = new SerializablePipelineOptions(options);
this.datasets = new HashMap<>();
this.leaves = new HashSet<>();
this.broadcastDataSets = new HashMap<>();
}
public SparkSession getSparkSession() {
return sparkSession;
}
public SerializablePipelineOptions getSerializableOptions() {
return serializablePipelineOptions;
}
// --------------------------------------------------------------------------------------------
// Transforms methods
// --------------------------------------------------------------------------------------------
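/** Sets the {@link AppliedPTransform} currently being translated. */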
public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
this.currentTransform = currentTransform;
}
public AppliedPTransform<?, ?, ?> getCurrentTransform() {
return currentTransform;
}
// --------------------------------------------------------------------------------------------
// Datasets methods
// --------------------------------------------------------------------------------------------
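/** Returns an empty {@link Dataset} backed by a {@link VoidCoder} based encoder. */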
@SuppressWarnings("unchecked")
public <T> Dataset<T> emptyDataset() {
return (Dataset<T>) sparkSession.emptyDataset(EncoderHelpers.fromBeamCoder(VoidCoder.of()));
}
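/**
 * Returns the {@link Dataset} registered for the given {@link PValue} and marks it as consumed,
 * i.e. no longer a leaf of the DAG.
 */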
@SuppressWarnings("unchecked")
public <T> Dataset<WindowedValue<T>> getDataset(PValue value) {
Dataset<?> dataset = datasets.get(value);
// Assume the Dataset is used as an input if it is retrieved here, so it is no longer a leaf.
leaves.remove(dataset);
return (Dataset<WindowedValue<T>>) dataset;
}
/**
 * TODO: All three putDataset* methods are temporary and only exist for generics type checking.
 * They should be unified in the future.
 */
public void putDatasetWildcard(PValue value, Dataset<WindowedValue<?>> dataset) {
if (!datasets.containsKey(value)) {
datasets.put(value, dataset);
leaves.add(dataset);
}
}
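/** Registers the {@link Dataset} for the given {@link PValue} and marks it as a leaf of the DAG. */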
public <T> void putDataset(PValue value, Dataset<WindowedValue<T>> dataset) {
if (!datasets.containsKey(value)) {
datasets.put(value, dataset);
leaves.add(dataset);
}
}
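/** Registers the broadcast {@link Dataset} for the given side input {@link PCollectionView}. */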
public <ViewT, ElemT> void setSideInputDataset(
PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
if (!broadcastDataSets.containsKey(value)) {
broadcastDataSets.put(value, set);
}
}
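/** Returns the broadcast {@link Dataset} registered for the given side input {@link PCollectionView}. */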
@SuppressWarnings("unchecked")
public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
return (Dataset<T>) broadcastDataSets.get(value);
}
// --------------------------------------------------------------------------------------------
// PCollections methods
// --------------------------------------------------------------------------------------------
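/** Returns the single main (non-additional) input of the current transform. */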
public PValue getInput() {
return Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform));
}
public Map<TupleTag<?>, PCollection<?>> getInputs() {
return currentTransform.getInputs();
}
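/** Returns the single output of the current transform. */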
public PValue getOutput() {
return Iterables.getOnlyElement(currentTransform.getOutputs().values());
}
public Map<TupleTag<?>, PCollection<?>> getOutputs() {
return currentTransform.getOutputs();
}
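/**
 * Returns the {@link Coder}s of the current transform's {@link PCollection} outputs, keyed by
 * their {@link TupleTag}.
 */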
@SuppressWarnings("unchecked")
public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
return currentTransform.getOutputs().entrySet().stream()
.filter(e -> e.getValue() instanceof PCollection)
.collect(Collectors.toMap(Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder()));
}
// --------------------------------------------------------------------------------------------
// Pipeline methods
// --------------------------------------------------------------------------------------------
/** Starts the pipeline. */
public void startPipeline() {
SparkStructuredStreamingPipelineOptions options =
serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class);
int datasetIndex = 0;
for (Dataset<?> dataset : leaves) {
if (options.isStreaming()) {
// TODO: Handle the Beam Discarding, Accumulating, and Accumulating & Retracting output modes
// with DataStreamWriter.outputMode
DataStreamWriter<?> dataStreamWriter = dataset.writeStream();
// Spark sets a default checkpoint dir if none is configured.
if (options.getCheckpointDir() != null) {
dataStreamWriter =
dataStreamWriter.option("checkpointLocation", options.getCheckpointDir());
}
launchStreaming(dataStreamWriter.foreach(new NoOpForeachWriter<>()));
} else {
if (options.getTestMode()) {
LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex);
dataset.explain(true);
}
// Apply a no-op function via a foreach action just to force evaluation and trigger the
// pipeline run in Spark.
dataset.foreach((ForeachFunction) t -> {});
}
}
}
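/** Starts the streaming pipeline represented by the given {@link DataStreamWriter}. */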
public abstract void launchStreaming(DataStreamWriter<?> dataStreamWriter);
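/** Collects the dataset to the driver and logs its content for debugging. */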
public static void printDatasetContent(Dataset<WindowedValue> dataset) {
// Cannot use dataset.show() because the dataset schema is binary, so it would print
// binary data.
List<WindowedValue> windowedValues = dataset.collectAsList();
for (WindowedValue windowedValue : windowedValues) {
LOG.debug("**** dataset content {} ****", windowedValue.toString());
}
}
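/**
 * {@link ForeachWriter} that does nothing: it only serves as a sink so the streaming query has an
 * action that triggers its execution.
 */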
private static class NoOpForeachWriter<T> extends ForeachWriter<T> {
@Override
public boolean open(long partitionId, long epochId) {
return false;
}
@Override
public void process(T value) {
// do nothing
}
@Override
public void close(Throwable errorOrNull) {
// do nothing
}
}
}