blob: 0af0df93d19d33d0796feeff483fdc784eaa5a68 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.samza.translation;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.runners.core.KeyedWorkItemCoder;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.samza.runtime.DoFnOp;
import org.apache.beam.runners.samza.runtime.GroupByKeyOp;
import org.apache.beam.runners.samza.runtime.KvToKeyedWorkItemOp;
import org.apache.beam.runners.samza.runtime.OpAdapter;
import org.apache.beam.runners.samza.runtime.OpMessage;
import org.apache.beam.runners.samza.transforms.GroupWithoutRepartition;
import org.apache.beam.runners.samza.util.SamzaCoders;
import org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.serializers.KVSerde;
/**
 * Translates {@link GroupByKey} (and, on the non-portable path, {@link Combine.PerKey}) into a
 * Samza {@link GroupByKeyOp}.
 *
 * <p>Both entry points collect what they need from their translation context and then hand over
 * to a shared routine that wires the Samza operators together.
 */
class GroupByKeyTranslator<K, InputT, OutputT>
    implements TransformTranslator<
        PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {

  /** Translates a transform taken from the Java SDK transform hierarchy. */
  @Override
  public void translate(
      PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform,
      TransformHierarchy.Node node,
      TranslationContext ctx) {
    doTranslate(transform, node, ctx);
  }

  /**
   * Looks up input, output, coders and reduce function for {@code transform}, builds the grouped
   * stream and registers it with {@code ctx} as the stream of the transform's output.
   */
  private static <K, InputT, OutputT> void doTranslate(
      PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform,
      TransformHierarchy.Node node,
      TranslationContext ctx) {
    final PCollection<KV<K, InputT>> inputPCollection = ctx.getInput(transform);
    final PCollection<KV<K, OutputT>> outputPCollection = ctx.getOutput(transform);
    final TupleTag<KV<K, OutputT>> mainOutputTag = ctx.getOutputTag(transform);

    @SuppressWarnings("unchecked")
    final WindowingStrategy<?, BoundedWindow> strategy =
        (WindowingStrategy<?, BoundedWindow>) inputPCollection.getWindowingStrategy();

    final MessageStream<OpMessage<KV<K, InputT>>> upstream =
        ctx.getMessageStream(inputPCollection);
    final KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) inputPCollection.getCoder();
    final Coder<WindowedValue<KV<K, InputT>>> windowedKvCoder = SamzaCoders.of(inputPCollection);
    final SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> systemReduceFn =
        getSystemReduceFn(transform, inputPCollection.getPipeline(), kvCoder);

    // Remaining inputs of the shared routine, resolved in the order they are passed.
    final boolean repartition = needRepartition(node, ctx);
    final int topologyId = ctx.getCurrentTopologicalId();
    final String stepName = node.getFullName();
    final PCollection.IsBounded boundedness = inputPCollection.isBounded();

    final MessageStream<OpMessage<KV<K, OutputT>>> groupedStream =
        doTranslateGBK(
            upstream,
            repartition,
            systemReduceFn,
            strategy,
            kvCoder,
            windowedKvCoder,
            topologyId,
            stepName,
            mainOutputTag,
            boundedness);

    ctx.registerMessageStream(outputPCollection, groupedStream);
  }

  /** Translates a transform taken from the portable pipeline representation. */
  @Override
  public void translatePortable(
      PipelineNode.PTransformNode transform,
      QueryablePipeline pipeline,
      PortableTranslationContext ctx) {
    doTranslatePortable(transform, pipeline, ctx);
  }

  /**
   * Portable counterpart of {@code doTranslate}: derives coders, windowing strategy, output tag
   * and boundedness from the pipeline components, always uses a buffering reduce function, and
   * registers the grouped stream under the transform's output id.
   */
  private static <K, InputT, OutputT> void doTranslatePortable(
      PipelineNode.PTransformNode transform,
      QueryablePipeline pipeline,
      PortableTranslationContext ctx) {
    final MessageStream<OpMessage<KV<K, InputT>>> upstream =
        ctx.getOneInputMessageStream(transform);

    // NOTE(review): this path repartitions when the parallelism is > 1, whereas the
    // TransformHierarchy-based check tests for == 1; the two only differ for values below 1.
    // Confirm whether that difference is intended.
    final boolean repartition = ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1;

    final WindowingStrategy<?, BoundedWindow> strategy =
        ctx.getPortableWindowStrategy(transform, pipeline);
    final Coder<BoundedWindow> boundedWindowCoder = strategy.getWindowFn().windowCoder();

    final String inputPCollectionId = ctx.getInputId(transform);
    final WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedValueCoder =
        ctx.instantiateCoder(inputPCollectionId, pipeline.getComponents());
    final KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) windowedValueCoder.getValueCoder();
    final Coder<WindowedValue<KV<K, InputT>>> windowedKvCoder =
        WindowedValue.FullWindowedValueCoder.of(kvCoder, boundedWindowCoder);

    final int topologyId = ctx.getCurrentTopologicalId();
    final String stepName = transform.getTransform().getUniqueName();

    // The transform is expected to expose exactly one output; its key becomes the tag id.
    final TupleTag<KV<K, OutputT>> mainOutputTag =
        new TupleTag<>(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().keySet()));

    @SuppressWarnings("unchecked")
    final SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> systemReduceFn =
        (SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow>)
            SystemReduceFn.buffering(kvCoder.getValueCoder());

    final RunnerApi.PCollection inputProto =
        pipeline.getComponents().getPcollectionsOrThrow(inputPCollectionId);
    final PCollection.IsBounded boundedness = SamzaPipelineTranslatorUtils.isBounded(inputProto);

    final MessageStream<OpMessage<KV<K, OutputT>>> groupedStream =
        doTranslateGBK(
            upstream,
            repartition,
            systemReduceFn,
            strategy,
            kvCoder,
            windowedKvCoder,
            topologyId,
            stepName,
            mainOutputTag,
            boundedness);

    ctx.registerMessageStream(ctx.getOutputId(transform), groupedStream);
  }

  /**
   * Builds the grouped output stream shared by both translation paths.
   *
   * <p>The input is first restricted to messages of type {@code ELEMENT}. When {@code
   * needRepartition} is set, the stream is partitioned by element key under the id {@code
   * "gbk-" + topologyId}. The elements are then turned into keyed work items and passed through
   * a {@link GroupByKeyOp}.
   *
   * @return the stream produced by the {@link GroupByKeyOp}
   */
  private static <K, InputT, OutputT> MessageStream<OpMessage<KV<K, OutputT>>> doTranslateGBK(
      MessageStream<OpMessage<KV<K, InputT>>> inputStream,
      boolean needRepartition,
      SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> reduceFn,
      WindowingStrategy<?, BoundedWindow> windowingStrategy,
      KvCoder<K, InputT> kvInputCoder,
      Coder<WindowedValue<KV<K, InputT>>> elementCoder,
      int topologyId,
      String nodeFullname,
      TupleTag<KV<K, OutputT>> outputTag,
      PCollection.IsBounded isBounded) {
    final MessageStream<OpMessage<KV<K, InputT>>> elementsOnly =
        inputStream.filter(message -> message.getType() == OpMessage.Type.ELEMENT);

    final MessageStream<OpMessage<KV<K, InputT>>> keyedByPartition;
    if (needRepartition) {
      final KVSerde<K, WindowedValue<KV<K, InputT>>> partitionSerde =
          KVSerde.of(
              SamzaCoders.toSerde(kvInputCoder.getKeyCoder()), SamzaCoders.toSerde(elementCoder));
      // TODO: infer a fixed id from the name
      final String partitionId = "gbk-" + topologyId;
      keyedByPartition =
          elementsOnly
              .partitionBy(
                  message -> message.getElement().getValue().getKey(),
                  message -> message.getElement(),
                  partitionSerde,
                  partitionId)
              .map(entry -> OpMessage.ofElement(entry.getValue()));
    } else {
      keyedByPartition = elementsOnly;
    }

    final Coder<KeyedWorkItem<K, InputT>> workItemCoder =
        KeyedWorkItemCoder.of(
            kvInputCoder.getKeyCoder(),
            kvInputCoder.getValueCoder(),
            windowingStrategy.getWindowFn().windowCoder());

    final MessageStream<OpMessage<KeyedWorkItem<K, InputT>>> workItems =
        keyedByPartition.flatMap(OpAdapter.adapt(new KvToKeyedWorkItemOp<>()));

    return workItems.flatMap(
        OpAdapter.adapt(
            new GroupByKeyOp<>(
                outputTag,
                workItemCoder,
                reduceFn,
                windowingStrategy,
                new DoFnOp.SingleOutputManagerFactory<>(),
                nodeFullname,
                // TODO: infer a fixed id from the name
                outputTag.getId(),
                isBounded)));
  }

  /**
   * Selects the {@link SystemReduceFn} matching the kind of {@code transform}.
   *
   * <p>A {@link GroupByKey} gets a buffering reduce function over the input value coder. A {@link
   * Combine.PerKey} gets a combining reduce function built from the transform's combine function,
   * the coder registry of {@code pipeline} and {@code kvInputCoder}.
   *
   * @throws RuntimeException if {@code transform} is neither of the two supported kinds
   */
  @SuppressWarnings("unchecked")
  private static <K, InputT, OutputT>
      SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> getSystemReduceFn(
          PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform,
          Pipeline pipeline,
          KvCoder<K, InputT> kvInputCoder) {
    if (transform instanceof GroupByKey) {
      return (SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow>)
          SystemReduceFn.buffering(kvInputCoder.getValueCoder());
    }

    if (transform instanceof Combine.PerKey) {
      final CombineFnBase.GlobalCombineFn<? super InputT, ?, OutputT> perKeyFn =
          ((Combine.PerKey) transform).getFn();
      return SystemReduceFn.combining(
          kvInputCoder.getKeyCoder(),
          AppliedCombineFn.withInputCoder(perKeyFn, pipeline.getCoderRegistry(), kvInputCoder));
    }

    throw new RuntimeException("Transform " + transform + " cannot be translated as GroupByKey.");
  }

  /**
   * Decides whether the input has to be partitioned by key before grouping.
   *
   * @return {@code false} when the maximum source parallelism is 1, or when {@code node} or any
   *     of its enclosing nodes holds a {@link GroupWithoutRepartition} transform; {@code true}
   *     otherwise
   */
  private static boolean needRepartition(TransformHierarchy.Node node, TranslationContext ctx) {
    if (ctx.getPipelineOptions().getMaxSourceParallelism() == 1) {
      // Only one task will be created, no need for repartition
      return false;
    }

    // Walk outwards through the enclosing nodes until one opts out or the chain ends.
    for (TransformHierarchy.Node current = node;
        current != null;
        current = current.getEnclosingNode()) {
      if (current.getTransform() instanceof GroupWithoutRepartition) {
        return false;
      }
    }
    return true;
  }
}