blob: 9773301f7255fc20d701f49d5ed6b4b839650e06 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
/**
 * Dataflow-specific replacements for {@link GroupIntoBatches} and {@link
 * GroupIntoBatches.WithShardedKey}.
 *
 * <p>For bounded pipelines the transforms are expanded into a {@link GroupByKey} followed by a
 * {@link DoFn} that splits each key's grouped values into batches. For unbounded pipelines {@link
 * GroupIntoBatches.WithShardedKey} keeps its original expansion, but its original output is
 * additionally handed to the {@link DataflowRunner} during expansion.
 */
@SuppressWarnings({
  "rawtypes" // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
public class GroupIntoBatchesOverride {

  /**
   * Replaces an application of {@link GroupIntoBatches} with {@link BatchGroupIntoBatches}.
   *
   * <p>NOTE(review): only {@code getBatchingParams().getBatchSize()} is forwarded to the
   * replacement; any other batching parameters carried by the original transform are not applied
   * here. Confirm this is intended.
   */
  static class BatchGroupIntoBatchesOverrideFactory<K, V>
      implements PTransformOverrideFactory<
          PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>> {

    @Override
    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
        getReplacementTransform(
            AppliedPTransform<
                    PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>>
                transform) {
      return PTransformReplacement.of(
          PTransformReplacements.getSingletonMainInput(transform),
          new BatchGroupIntoBatches<>(transform.getTransform().getBatchingParams().getBatchSize()));
    }

    @Override
    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
        Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<K, Iterable<V>>> newOutput) {
      return ReplacementOutputs.singleton(outputs, newOutput);
    }
  }

  /** Specialized implementation of {@link GroupIntoBatches} for bounded Dataflow pipelines. */
  static class BatchGroupIntoBatches<K, V>
      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {

    /** Maximum number of elements per emitted batch. */
    private final long batchSize;

    private BatchGroupIntoBatches(long batchSize) {
      this.batchSize = batchSize;
    }

    @Override
    public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
      return input
          .apply("GroupAll", GroupByKey.create())
          .apply("SplitIntoBatches", ParDo.of(new SplitIntoBatchesFn<K, V>(batchSize)));
    }
  }

  /**
   * Splits the grouped values of each key into consecutive batches of at most {@code batchSize}
   * elements, preserving the iteration order of the grouped values.
   *
   * <p>This is a static nested class rather than an anonymous inner class so that serializing the
   * {@link DoFn} does not also serialize an enclosing transform instance.
   */
  private static class SplitIntoBatchesFn<K, V>
      extends DoFn<KV<K, Iterable<V>>, KV<K, Iterable<V>>> {

    private final long batchSize;

    /**
     * @param batchSize maximum number of elements per batch; must be positive.
     * @throws IllegalArgumentException if {@code batchSize} is not positive.
     */
    SplitIntoBatchesFn(long batchSize) {
      if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive, but was " + batchSize);
      }
      this.batchSize = batchSize;
    }

    @ProcessElement
    public void process(ProcessContext c) {
      K key = c.element().getKey();
      Iterator<V> values = c.element().getValue().iterator();
      // Batches grow with the number of elements actually seen instead of being pre-sized to
      // batchSize. This keeps the full long range of batchSize usable (no narrowing int cast) and
      // avoids a large up-front allocation per key when the limit is big but groups are small.
      // Values are still consumed lazily, one batch at a time.
      List<V> batch = new ArrayList<>();
      while (values.hasNext()) {
        batch.add(values.next());
        if (batch.size() >= batchSize) {
          c.output(KV.of(key, batch));
          // Start a fresh list; the emitted one is never mutated afterwards.
          batch = new ArrayList<>();
        }
      }
      // Note that GroupIntoBatches only outputs when the batch is non-empty.
      if (!batch.isEmpty()) {
        c.output(KV.of(key, batch));
      }
    }
  }

  /**
   * Replaces an application of {@link GroupIntoBatches.WithShardedKey} with {@link
   * BatchGroupIntoBatchesWithShardedKey}, forwarding only the configured batch size.
   */
  static class BatchGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
      implements PTransformOverrideFactory<
          PCollection<KV<K, V>>,
          PCollection<KV<ShardedKey<K>, Iterable<V>>>,
          GroupIntoBatches<K, V>.WithShardedKey> {

    @Override
    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
        getReplacementTransform(
            AppliedPTransform<
                    PCollection<KV<K, V>>,
                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
                    GroupIntoBatches<K, V>.WithShardedKey>
                transform) {
      return PTransformReplacement.of(
          PTransformReplacements.getSingletonMainInput(transform),
          new BatchGroupIntoBatchesWithShardedKey<>(
              transform.getTransform().getBatchingParams().getBatchSize()));
    }

    @Override
    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
        Map<TupleTag<?>, PCollection<?>> outputs,
        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
      return ReplacementOutputs.singleton(outputs, newOutput);
    }
  }

  /**
   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for bounded Dataflow
   * pipelines. Keys are first wrapped in {@link ShardedKey}s by {@link #shardKeys} and the result
   * is then batched by {@link BatchGroupIntoBatches}.
   */
  static class BatchGroupIntoBatchesWithShardedKey<K, V>
      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {

    /** Maximum number of elements per emitted batch. */
    private final long batchSize;

    private BatchGroupIntoBatchesWithShardedKey(long batchSize) {
      this.batchSize = batchSize;
    }

    @Override
    public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
      return shardKeys(input).apply(new BatchGroupIntoBatches<>(batchSize));
    }
  }

  /**
   * Replaces an application of {@link GroupIntoBatches.WithShardedKey} with {@link
   * StreamingGroupIntoBatchesWithShardedKey}, passing along the runner, the original transform and
   * the original transform's single main output.
   */
  static class StreamingGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
      implements PTransformOverrideFactory<
          PCollection<KV<K, V>>,
          PCollection<KV<ShardedKey<K>, Iterable<V>>>,
          GroupIntoBatches<K, V>.WithShardedKey> {

    private final DataflowRunner runner;

    StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(DataflowRunner runner) {
      this.runner = runner;
    }

    @Override
    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
        getReplacementTransform(
            AppliedPTransform<
                    PCollection<KV<K, V>>,
                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
                    GroupIntoBatches<K, V>.WithShardedKey>
                transform) {
      return PTransformReplacement.of(
          PTransformReplacements.getSingletonMainInput(transform),
          new StreamingGroupIntoBatchesWithShardedKey<>(
              runner,
              transform.getTransform(),
              PTransformReplacements.getSingletonMainOutput(transform)));
    }

    @Override
    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
        Map<TupleTag<?>, PCollection<?>> outputs,
        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
      return ReplacementOutputs.singleton(outputs, newOutput);
    }
  }

  /**
   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for unbounded Dataflow
   * pipelines. The override does the same thing as the original transform but additionally records
   * the output in order to append required step properties during the graph translation.
   */
  static class StreamingGroupIntoBatchesWithShardedKey<K, V>
      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {

    /** Runner notified about the original output; not serialized. */
    private final transient DataflowRunner runner;

    /** The user's transform, re-applied unchanged during expansion. */
    private final GroupIntoBatches<K, V>.WithShardedKey originalTransform;

    /** Output of the original application, which is what downstream nodes stay wired to. */
    private final transient PCollection<KV<ShardedKey<K>, Iterable<V>>> originalOutput;

    public StreamingGroupIntoBatchesWithShardedKey(
        DataflowRunner runner,
        GroupIntoBatches<K, V>.WithShardedKey original,
        PCollection<KV<ShardedKey<K>, Iterable<V>>> output) {
      this.runner = runner;
      this.originalTransform = original;
      this.originalOutput = output;
    }

    @Override
    public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
      // Record the output PCollection of the original transform since the new output will be
      // replaced by the original one when the replacement transform is wired to other nodes in the
      // graph, although the old and the new outputs are effectively the same.
      runner.maybeRecordPCollectionWithAutoSharding(originalOutput);
      return input.apply(originalTransform);
    }
  }

  /**
   * Random identifier generated once when this class is initialized. It is shared by every thread
   * that uses this copy of the class and forms part of each shard id built by {@link #shardKeys}.
   */
  private static final UUID workerUuid = UUID.randomUUID();

  /**
   * Wraps every key of {@code input} in a {@link ShardedKey} while leaving values untouched.
   *
   * <p>The shard id is 24 bytes: the most and least significant halves of {@link #workerUuid}
   * followed by the id of the thread applying the function, each written with {@link
   * ByteBuffer#putLong} (big-endian by default). Elements with the same user key that are
   * processed on different threads, or on JVMs with different {@code workerUuid}s, therefore get
   * different sharded keys and are grouped separately by a downstream {@link GroupByKey}.
   *
   * <p>NOTE(review): assumes {@code input} uses a {@link KvCoder}; any other coder fails with a
   * {@link ClassCastException} at the cast below.
   */
  private static <K, V> PCollection<KV<ShardedKey<K>, V>> shardKeys(PCollection<KV<K, V>> input) {
    KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
    // Typed accessors avoid the unchecked casts needed with getCoderArguments().
    org.apache.beam.sdk.coders.Coder<K> keyCoder = inputCoder.getKeyCoder();
    org.apache.beam.sdk.coders.Coder<V> valueCoder = inputCoder.getValueCoder();
    return input
        .apply(
            "Shard Keys",
            MapElements.via(
                new SimpleFunction<KV<K, V>, KV<ShardedKey<K>, V>>() {
                  @Override
                  public KV<ShardedKey<K>, V> apply(KV<K, V> input) {
                    long tid = Thread.currentThread().getId();
                    // A fresh array per element: ShardedKey keeps a reference to it.
                    ByteBuffer buffer = ByteBuffer.allocate(3 * Long.BYTES);
                    buffer.putLong(workerUuid.getMostSignificantBits());
                    buffer.putLong(workerUuid.getLeastSignificantBits());
                    buffer.putLong(tid);
                    return KV.of(ShardedKey.of(input.getKey(), buffer.array()), input.getValue());
                  }
                }))
        .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder));
  }
}