/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.spark.stateful;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.metrics.MetricsContainer;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.streaming.State;
import org.apache.spark.streaming.StateSpec;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Tuple2;
import scala.runtime.AbstractFunction3;

/**
 * A class containing {@link org.apache.spark.streaming.StateSpec} mapping functions.
 */
public class StateSpecFunctions {

  private static final Logger LOG = LoggerFactory.getLogger(StateSpecFunctions.class);

  /**
   * A helper class that is essentially a {@link Serializable} {@link AbstractFunction3},
   * allowing instances to be shipped to Spark executors as part of a closure.
   */
private abstract static class SerializableFunction3<T1, T2, T3, T4>
extends AbstractFunction3<T1, T2, T3, T4> implements Serializable {
  }

/**
* A {@link org.apache.spark.streaming.StateSpec} function to support reading from
* an {@link UnboundedSource}.
*
* <p>This StateSpec function expects the following:
* <ul>
* <li>Key: The (partitioned) Source to read from.</li>
* <li>Value: An optional {@link UnboundedSource.CheckpointMark} to start from.</li>
* <li>State: A byte representation of the (previously) persisted CheckpointMark.</li>
* </ul>
   * It returns an iterator over all the values read in the micro-batch.
*
* <p>This stateful operation could be described as a flatMap over a single-element stream, which
* outputs all the elements read from the {@link UnboundedSource} for this micro-batch.
* Since micro-batches are bounded, the provided UnboundedSource is wrapped by a
* {@link MicrobatchSource} that applies bounds in the form of duration and max records
* (per micro-batch).
   *
   * <p>In order to avoid using the Guava classes bundled with Spark, which pollute the
   * classpath, we use the {@link StateSpec#function(scala.Function3)} signature, which employs
   * Scala's native {@link scala.Option}, instead of the
   * {@link StateSpec#function(org.apache.spark.api.java.function.Function3)} signature,
   * which employs Guava's {@link com.google.common.base.Optional}.
*
   * <p>See also <a href="https://issues.apache.org/jira/browse/SPARK-4819">SPARK-4819</a>.
*
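   * <p>A minimal wiring sketch (illustrative only, not part of this class): {@code inputDStream}
   * stands in for a {@code JavaPairInputDStream<Source<T>, CheckpointMarkT>} in the calling
   * context, and {@code rc} and {@code stepName} for this function's arguments.
   * <pre>{@code
   * JavaMapWithStateDStream<Source<T>, CheckpointMarkT, Tuple2<byte[], Instant>,
   *     Tuple2<Iterable<byte[]>, Metadata>> mapped =
   *         inputDStream.mapWithState(
   *             StateSpec.function(
   *                 StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc, stepName)));
   * }</pre>
   *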
   * @param runtimeContext A serializable {@link SparkRuntimeContext}.
   * @param stepName The name of the step reading from this source, used as the key for
   *     reported metrics.
* @param <T> The type of the input stream elements.
* @param <CheckpointMarkT> The type of the {@link UnboundedSource.CheckpointMark}.
* @return The appropriate {@link org.apache.spark.streaming.StateSpec} function.
*/
public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, State<Tuple2<byte[], Instant>>,
Tuple2<Iterable<byte[]>, Metadata>> mapSourceFunction(
final SparkRuntimeContext runtimeContext, final String stepName) {
return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>,
State<Tuple2<byte[], Instant>>, Tuple2<Iterable<byte[]>, Metadata>>() {
@Override
public Tuple2<Iterable<byte[]>, Metadata> apply(
Source<T> source,
scala.Option<CheckpointMarkT> startCheckpointMark,
State<Tuple2<byte[], Instant>> state) {
MetricsContainerStepMap metricsContainers = new MetricsContainerStepMap();
MetricsContainer metricsContainer = metricsContainers.getContainer(stepName);
// Add metrics container to the scope of org.apache.beam.sdk.io.Source.Reader methods
// since they may report metrics.
try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
          // Cast the source to the MicrobatchSource that wraps the underlying UnboundedSource.
MicrobatchSource<T, CheckpointMarkT> microbatchSource =
(MicrobatchSource<T, CheckpointMarkT>) source;
// Initial high/low watermarks.
Instant lowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
final Instant highWatermark;
          // If state exists, use it; otherwise this is the first time, so use the
          // startCheckpointMark. startCheckpointMark may be an EmptyCheckpointMark (the Spark
          // Java API tries to apply Optional(null)), which is handled by the UnboundedSource
          // implementation.
Coder<CheckpointMarkT> checkpointCoder = microbatchSource.getCheckpointMarkCoder();
CheckpointMarkT checkpointMark;
if (state.exists()) {
// previous (output) watermark is now the low watermark.
lowWatermark = state.get()._2();
checkpointMark = CoderHelpers.fromByteArray(state.get()._1(), checkpointCoder);
LOG.info("Continue reading from an existing CheckpointMark.");
} else if (startCheckpointMark.isDefined()
&& !startCheckpointMark.get().equals(EmptyCheckpointMark.get())) {
checkpointMark = startCheckpointMark.get();
LOG.info("Start reading from a provided CheckpointMark.");
} else {
checkpointMark = null;
LOG.info("No CheckpointMark provided, start reading from default.");
}
// create reader.
final MicrobatchSource.Reader/*<T>*/ microbatchReader;
final Stopwatch stopwatch = Stopwatch.createStarted();
long readDurationMillis = 0;
try {
microbatchReader =
(MicrobatchSource.Reader)
microbatchSource.getOrCreateReader(runtimeContext.getPipelineOptions(),
checkpointMark);
} catch (IOException e) {
throw new RuntimeException(e);
}
// read microbatch as a serialized collection.
final List<byte[]> readValues = new ArrayList<>();
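          // Elements read from the source belong in the global window (with no pane info)
          // until the pipeline applies its own windowing, so encode each value with a full
          // WindowedValue coder over the GlobalWindow.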
WindowedValue.FullWindowedValueCoder<T> coder =
WindowedValue.FullWindowedValueCoder.of(
source.getDefaultOutputCoder(),
GlobalWindow.Coder.INSTANCE);
try {
// measure how long a read takes per-partition.
boolean finished = !microbatchReader.start();
while (!finished) {
final WindowedValue<T> wv =
WindowedValue.of((T) microbatchReader.getCurrent(),
microbatchReader.getCurrentTimestamp(),
GlobalWindow.INSTANCE,
PaneInfo.NO_FIRING);
readValues.add(CoderHelpers.toByteArray(wv, coder));
finished = !microbatchReader.advance();
}
            // The end-of-read watermark is the high watermark, but never let it decrease
            // below the previous (low) watermark.
final Instant sourceWatermark = microbatchReader.getWatermark();
highWatermark = sourceWatermark.isAfter(lowWatermark) ? sourceWatermark : lowWatermark;
readDurationMillis = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
LOG.info(
"Source id {} spent {} millis on reading.",
microbatchSource.getId(),
readDurationMillis);
            // If the Source does not supply a CheckpointMark, persist an empty checkpoint;
            // the watermark is still updated below.
@SuppressWarnings("unchecked")
final CheckpointMarkT finishedReadCheckpointMark =
(CheckpointMarkT) microbatchReader.getCheckpointMark();
byte[] codedCheckpoint = new byte[0];
if (finishedReadCheckpointMark != null) {
codedCheckpoint = CoderHelpers.toByteArray(finishedReadCheckpointMark, checkpointCoder);
} else {
LOG.info("Skipping checkpoint marking because the reader failed to supply one.");
}
            // Persist the end-of-read (high) watermark for the following read, where it will
            // become the next low watermark.
state.update(new Tuple2<>(codedCheckpoint, highWatermark));
} catch (IOException e) {
throw new RuntimeException("Failed to read from reader.", e);
}
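          // Materialize the read values as a standalone list before returning them along
          // with the read metadata.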
final ArrayList<byte[]> payload =
Lists.newArrayList(Iterators.unmodifiableIterator(readValues.iterator()));
return new Tuple2<>(
(Iterable<byte[]>) payload,
new Metadata(
readValues.size(),
lowWatermark,
highWatermark,
readDurationMillis,
metricsContainers));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
};
}
}