/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.sdk.io;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Distinct;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.NameUtils;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.ValueWithRecordId;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.Uninterruptibles;
import org.joda.time.Duration;
import org.joda.time.Instant;

/**
 * {@link PTransform} that reads a bounded amount of data from an {@link UnboundedSource}. The
 * bound is specified as a maximum number of elements, a maximum period of time to read, or both.
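 *
 * <p>A minimal usage sketch; {@code MyUnboundedSource} is a hypothetical source, and in practice
 * an instance of this transform is usually obtained via {@link Read#from}:
 *
 * <pre>{@code
 * PCollection<String> bounded =
 *     pipeline.apply(
 *         Read.from(new MyUnboundedSource())
 *             .withMaxNumRecords(1000)
 *             .withMaxReadTime(Duration.standardMinutes(5)));
 * }</pre>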
*/
public class BoundedReadFromUnboundedSource<T> extends PTransform<PBegin, PCollection<T>> {
private final UnboundedSource<T, ?> source;
private final long maxNumRecords;
private final @Nullable Duration maxReadTime;
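
  // Backoff used when a reader reports no data available: sleeps start at 10ms and grow
  // exponentially, capped at 10 seconds per attempt (see ReadFn#advanceWithBackoff).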
private static final FluentBackoff BACKOFF_FACTORY =
FluentBackoff.DEFAULT
.withInitialBackoff(Duration.millis(10))
.withMaxBackoff(Duration.standardSeconds(10));

  /**
* Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount of data from
* the given {@link UnboundedSource}. The bound is specified as a number of records to read.
*
* <p>This may take a long time to execute if the splits of this source are slow to read records.
*/
public BoundedReadFromUnboundedSource<T> withMaxNumRecords(long maxNumRecords) {
return new BoundedReadFromUnboundedSource<>(source, maxNumRecords, maxReadTime);
}

  /**
* Returns a new {@link BoundedReadFromUnboundedSource} that reads a bounded amount of data from
* the given {@link UnboundedSource}. The bound is specified as an amount of time to read for.
* Each split of the source will read for this much time.
*/
public BoundedReadFromUnboundedSource<T> withMaxReadTime(Duration maxReadTime) {
return new BoundedReadFromUnboundedSource<>(source, maxNumRecords, maxReadTime);
}

  BoundedReadFromUnboundedSource(
UnboundedSource<T, ?> source, long maxNumRecords, @Nullable Duration maxReadTime) {
this.source = source;
this.maxNumRecords = maxNumRecords;
this.maxReadTime = maxReadTime;
}

  @Override
public PCollection<T> expand(PBegin input) {
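    // The expansion proceeds in stages: wrap the source and its bounds in a single Shard,
    // split that shard into per-reader shards, reshuffle so shards are distributed across
    // workers, read each shard, then (if the source requires it) deduplicate by record id
    // before stripping the ids.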
Coder<Shard<T>> shardCoder = SerializableCoder.of((Class<Shard<T>>) (Class) Shard.class);
PCollection<ValueWithRecordId<T>> read =
input
.apply(
"Create",
Create.of(
new AutoValue_BoundedReadFromUnboundedSource_Shard.Builder<T>()
.setSource(source)
.setMaxNumRecords(maxNumRecords)
.setMaxReadTime(maxReadTime)
.build())
.withCoder(shardCoder))
.apply("Split", ParDo.of(new SplitFn<>()))
.setCoder(shardCoder)
.apply("Reshuffle", Reshuffle.viaRandomKey())
.apply("Read", ParDo.of(new ReadFn<>()))
.setCoder(ValueWithRecordId.ValueWithRecordIdCoder.of(source.getOutputCoder()));
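    // Sources that may emit duplicates advertise requiresDeduping(); deduplicate such reads
    // by the source-assigned record id.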
if (source.requiresDeduping()) {
read =
read.apply(
Distinct.<ValueWithRecordId<T>, byte[]>withRepresentativeValueFn(
ValueWithRecordId::getId)
.withRepresentativeType(TypeDescriptor.of(byte[].class)));
}
return read.apply("StripIds", ParDo.of(new ValueWithRecordId.StripIdsDoFn<>()))
.setCoder(source.getOutputCoder());
}

  @Override
public String getKindString() {
return String.format("Read(%s)", NameUtils.approximateSimpleName(source));
}

  @Override
public void populateDisplayData(DisplayData.Builder builder) {
    // We explicitly do not register base-class data; instead we use the delegate inner source.
builder
.add(DisplayData.item("source", source.getClass()).withLabel("Read Source"))
.addIfNotDefault(
DisplayData.item("maxRecords", maxNumRecords).withLabel("Maximum Read Records"),
Long.MAX_VALUE)
.addIfNotNull(DisplayData.item("maxReadTime", maxReadTime).withLabel("Maximum Read Time"))
.include("source", source);
}

  private static class SplitFn<T> extends DoFn<Shard<T>, Shard<T>> {
/**
     * Divide {@code numRecords} into {@code numSplits} approximately equal parts that sum to
     * {@code numRecords}.
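     *
     * <p>For example, {@code splitNumRecords(10, 3)} returns {@code [4, 3, 3]}.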
*/
private static long[] splitNumRecords(long numRecords, int numSplits) {
long[] splitNumRecords = new long[numSplits];
for (int i = 0; i < numSplits; i++) {
splitNumRecords[i] = numRecords / numSplits;
}
for (int i = 0; i < numRecords % numSplits; i++) {
splitNumRecords[i] = splitNumRecords[i] + 1;
}
return splitNumRecords;
}

    /** Pick a number of initial splits based on the number of records expected to be processed. */
private static int numInitialSplits(long numRecords) {
final int maxSplits = 100;
final long recordsPerSplit = 10000;
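      // e.g. an expected 25,000 records yields min(100, 25000 / 10000 + 1) = 3 initial splits.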
return (int) Math.min(maxSplits, numRecords / recordsPerSplit + 1);
}

    @ProcessElement
public void process(
@Element Shard<T> shard, OutputReceiver<Shard<T>> out, PipelineOptions options)
throws Exception {
int numInitialSplits = numInitialSplits(shard.getMaxNumRecords());
List<? extends UnboundedSource<T, ?>> splits =
shard.getSource().split(numInitialSplits, options);
int numSplits = splits.size();
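      // split() may return fewer splits than requested; divide the record budget across
      // however many splits the source actually produced.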
long[] numRecords = splitNumRecords(shard.getMaxNumRecords(), numSplits);
for (int i = 0; i < numSplits; i++) {
out.output(
shard
.toBuilder()
.setSource(splits.get(i))
.setMaxNumRecords(numRecords[i])
.setMaxReadTime(shard.getMaxReadTime())
.build());
}
}
}

  private static class ReadFn<T> extends DoFn<Shard<T>, ValueWithRecordId<T>> {
@ProcessElement
public void process(
@Element Shard<T> shard, OutputReceiver<ValueWithRecordId<T>> out, PipelineOptions options)
throws Exception {
Instant endTime =
shard.getMaxReadTime() == null ? null : Instant.now().plus(shard.getMaxReadTime());
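      // Nothing to read if the record budget is empty or the time budget is zero.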
if (shard.getMaxNumRecords() <= 0
|| (shard.getMaxReadTime() != null && shard.getMaxReadTime().getMillis() == 0)) {
return;
}
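      // Read from a private clone of the source; the null checkpoint argument starts the
      // reader from scratch rather than resuming.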
try (UnboundedSource.UnboundedReader<T> reader =
SerializableUtils.clone(shard.getSource()).createReader(options, null)) {
for (long i = 0L; i < shard.getMaxNumRecords(); ++i) {
boolean available = (i == 0) ? reader.start() : reader.advance();
if (!available && !advanceWithBackoff(reader, endTime)) {
break;
}
out.outputWithTimestamp(
new ValueWithRecordId<T>(reader.getCurrent(), reader.getCurrentRecordId()),
reader.getCurrentTimestamp());
}
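        // Finalize the checkpoint so the source can acknowledge or discard the records read above.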
reader.getCheckpointMark().finalizeCheckpoint();
}
}

    private boolean advanceWithBackoff(UnboundedReader<T> reader, Instant endTime)
throws IOException {
// Try reading from the source with exponential backoff
BackOff backoff = BACKOFF_FACTORY.backoff();
long nextSleep = backoff.nextBackOffMillis();
while (true) {
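        // Stop once the backoff policy is exhausted or the read deadline has passed.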
if (nextSleep == BackOff.STOP || (endTime != null && Instant.now().isAfter(endTime))) {
return false;
}
if (reader.advance()) {
return true;
}
Uninterruptibles.sleepUninterruptibly(nextSleep, TimeUnit.MILLISECONDS);
nextSleep = backoff.nextBackOffMillis();
}
}
}

  /**
   * Serializable value pairing an {@link UnboundedSource} with the bounds on the number of
   * records to read and on the read time that apply to it.
   */
@AutoValue
abstract static class Shard<T> implements Serializable {
    @Nullable
    abstract UnboundedSource<T, ?> getSource();

    abstract long getMaxNumRecords();

    @Nullable
    abstract Duration getMaxReadTime();

    abstract Builder<T> toBuilder();

    @AutoValue.Builder
abstract static class Builder<T> {
      abstract Builder<T> setSource(UnboundedSource<T, ?> source);

      abstract Builder<T> setMaxNumRecords(long maxNumRecords);

      abstract Builder<T> setMaxReadTime(@Nullable Duration maxReadTime);

      abstract Shard<T> build();
}
}
}