/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.samza.translation;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.runners.samza.runtime.DoFnOp;
import org.apache.beam.runners.samza.runtime.Op;
import org.apache.beam.runners.samza.runtime.OpAdapter;
import org.apache.beam.runners.samza.runtime.OpEmitter;
import org.apache.beam.runners.samza.runtime.OpMessage;
import org.apache.beam.runners.samza.runtime.SamzaDoFnInvokerRegistrar;
import org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.functions.FlatMapFunction;
import org.apache.samza.operators.functions.WatermarkFunction;
import org.joda.time.Instant;
/**
* Translates a {@link org.apache.beam.sdk.transforms.ParDo.MultiOutput} transform, or an
* ExecutableStage of the portable API, into a Samza {@link DoFnOp}.
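*
* <p>For reference, a minimal sketch of the kind of multi-output {@link ParDo} this translator
* handles (the DoFn and tag names below are illustrative only, not taken from this file):
*
* <pre>{@code
* PCollectionTuple outputs =
*     words.apply(
*         ParDo.of(new TagLongWordsFn())
*             .withOutputTags(shortWordsTag, TupleTagList.of(longWordsTag)));
* }</pre>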
*/
class ParDoBoundMultiTranslator<InT, OutT>
implements TransformTranslator<ParDo.MultiOutput<InT, OutT>>,
TransformConfigGenerator<ParDo.MultiOutput<InT, OutT>> {
private final SamzaDoFnInvokerRegistrar doFnInvokerRegistrar;
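// Looks up an optional SamzaDoFnInvokerRegistrar via ServiceLoader; stays null when no
// registrar is present on the classpath.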
ParDoBoundMultiTranslator() {
final Iterator<SamzaDoFnInvokerRegistrar> invokerReg =
ServiceLoader.load(SamzaDoFnInvokerRegistrar.class).iterator();
doFnInvokerRegistrar = invokerReg.hasNext() ? Iterators.getOnlyElement(invokerReg) : null;
}
@Override
public void translate(
ParDo.MultiOutput<InT, OutT> transform,
TransformHierarchy.Node node,
TranslationContext ctx) {
doTranslate(transform, node, ctx);
}
// static for serializing anonymous functions
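// Translates a ParDo.MultiOutput from the Java SDK pipeline: builds a DoFnOp over the main
// input (merged with any side inputs) and registers one output stream per output tag.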
private static <InT, OutT> void doTranslate(
ParDo.MultiOutput<InT, OutT> transform,
TransformHierarchy.Node node,
TranslationContext ctx) {
final PCollection<? extends InT> input = ctx.getInput(transform);
final Map<TupleTag<?>, Coder<?>> outputCoders =
ctx.getCurrentTransform().getOutputs().entrySet().stream()
.filter(e -> e.getValue() instanceof PCollection)
.collect(
Collectors.toMap(e -> e.getKey(), e -> ((PCollection<?>) e.getValue()).getCoder()));
final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
final Coder<?> keyCoder =
signature.usesState() ? ((KvCoder<?, ?>) input.getCoder()).getKeyCoder() : null;
if (signature.processElement().isSplittable()) {
throw new UnsupportedOperationException("Splittable DoFn is not currently supported");
}
final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStream(input);
final List<MessageStream<OpMessage<InT>>> sideInputStreams =
transform.getSideInputs().stream()
.map(ctx::<InT>getViewStream)
.collect(Collectors.toList());
final ArrayList<Map.Entry<TupleTag<?>, PValue>> outputs =
new ArrayList<>(node.getOutputs().entrySet());
final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
final Map<Integer, PCollection<?>> indexToPCollectionMap = new HashMap<>();
for (int index = 0; index < outputs.size(); ++index) {
final Map.Entry<TupleTag<?>, PValue> taggedOutput = outputs.get(index);
tagToIndexMap.put(taggedOutput.getKey(), index);
if (!(taggedOutput.getValue() instanceof PCollection)) {
throw new IllegalArgumentException(
"Expected side output to be PCollection, but was: " + taggedOutput.getValue());
}
final PCollection<?> sideOutputCollection = (PCollection<?>) taggedOutput.getValue();
indexToPCollectionMap.put(index, sideOutputCollection);
}
final HashMap<String, PCollectionView<?>> idToPValueMap = new HashMap<>();
for (PCollectionView<?> view : transform.getSideInputs()) {
idToPValueMap.put(ctx.getViewId(view), view);
}
final DoFnSchemaInformation doFnSchemaInformation =
ParDoTranslation.getSchemaInformation(ctx.getCurrentTransform());
final DoFnOp<InT, OutT, RawUnionValue> op =
new DoFnOp<>(
transform.getMainOutputTag(),
transform.getFn(),
keyCoder,
(Coder<InT>) input.getCoder(),
outputCoders,
transform.getSideInputs(),
transform.getAdditionalOutputTags().getAll(),
input.getWindowingStrategy(),
idToPValueMap,
new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
node.getFullName(),
// TODO: infer a fixed id from the name
String.valueOf(ctx.getCurrentTopologicalId()),
input.isBounded(),
false, // not the portable (ExecutableStage) path
null, // no ExecutableStagePayload on this non-portable path
Collections.emptyMap(), // id-to-TupleTag map, only used on the portable path
doFnSchemaInformation);
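// Merge the side-input streams (if any) into the main input stream so that DoFnOp receives
// side-input elements and their watermarks alongside the main input.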
final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
mergedStreams = inputStream;
} else {
MessageStream<OpMessage<InT>> mergedSideInputStreams =
MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
}
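// DoFnOp emits RawUnionValues tagged with an output index; filter on the union tag to split
// them back into one stream per output, then unwrap the underlying value.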
final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
mergedStreams.flatMap(OpAdapter.adapt(op));
for (int outputIndex : tagToIndexMap.values()) {
@SuppressWarnings("unchecked")
final MessageStream<OpMessage<OutT>> outputStream =
taggedOutputStream
.filter(
message ->
message.getType() != OpMessage.Type.ELEMENT
|| message.getElement().getValue().getUnionTag() == outputIndex)
.flatMap(OpAdapter.adapt(new RawUnionValueToValue()));
ctx.registerMessageStream(indexToPCollectionMap.get(outputIndex), outputStream);
}
}
/*
* We reuse the ParDo translator to translate ExecutableStage.
*/
@Override
public void translatePortable(
PipelineNode.PTransformNode transform,
QueryablePipeline pipeline,
PortableTranslationContext ctx) {
doTranslatePortable(transform, pipeline, ctx);
}
// static for serializing anonymous functions
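// Translates a fused ExecutableStage from the portable pipeline: the stage payload is handed to
// DoFnOp together with a placeholder NoOpDoFn, and the filtered per-tag output streams are
// registered with the translation context.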
private static <InT, OutT> void doTranslatePortable(
PipelineNode.PTransformNode transform,
QueryablePipeline pipeline,
PortableTranslationContext ctx) {
Map<String, String> outputs = transform.getTransform().getOutputsMap();
final RunnerApi.ExecutableStagePayload stagePayload;
try {
stagePayload =
RunnerApi.ExecutableStagePayload.parseFrom(
transform.getTransform().getSpec().getPayload());
} catch (IOException e) {
throw new RuntimeException(e);
}
String inputId = stagePayload.getInput();
final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId);
// TODO: support side input
final List<MessageStream<OpMessage<InT>>> sideInputStreams = Collections.emptyList();
final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>();
// first output as the main output
final TupleTag<OutT> mainOutputTag =
outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());
final AtomicInteger index = new AtomicInteger(0);
outputs.forEach(
(outputName, collectionId) -> {
final TupleTag<?> tupleTag = new TupleTag<>(outputName);
tagToIndexMap.put(tupleTag, index.getAndIncrement());
idToTupleTagMap.put(collectionId, tupleTag);
});
WindowedValue.WindowedValueCoder<InT> windowedInputCoder =
SamzaPipelineTranslatorUtils.instantiateCoder(inputId, pipeline.getComponents());
final String nodeFullname = transform.getTransform().getUniqueName();
final DoFnSchemaInformation doFnSchemaInformation =
ParDoTranslation.getSchemaInformation(transform.getTransform());
final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);
final DoFnOp<InT, OutT, RawUnionValue> op =
new DoFnOp<>(
mainOutputTag,
new NoOpDoFn<>(),
null, // key coder not in use
windowedInputCoder.getValueCoder(), // input coder not in use
Collections.emptyMap(), // output coders not in use
Collections.emptyList(), // sideInputs not in use until side input support
new ArrayList<>(idToTupleTagMap.values()), // used by java runner only
SamzaPipelineTranslatorUtils.getPortableWindowStrategy(transform, pipeline),
Collections.emptyMap(), // idToViewMap not in use until side input support
new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
nodeFullname,
// TODO: infer a fixed id from the name
String.valueOf(ctx.getCurrentTopologicalId()),
isBounded,
true,
stagePayload,
idToTupleTagMap,
doFnSchemaInformation);
final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
mergedStreams = inputStream;
} else {
MessageStream<OpMessage<InT>> mergedSideInputStreams =
MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
}
final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
mergedStreams.flatMap(OpAdapter.adapt(op));
for (int outputIndex : tagToIndexMap.values()) {
final MessageStream<OpMessage<OutT>> outputStream =
taggedOutputStream
.filter(
message ->
message.getType() != OpMessage.Type.ELEMENT
|| message.getElement().getValue().getUnionTag() == outputIndex)
.flatMap(OpAdapter.adapt(new RawUnionValueToValue()));
ctx.registerMessageStream(ctx.getOutputId(transform), outputStream);
}
}
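/**
* Generates the Samza store configs that back Beam user state: one RocksDB store per {@link
* DoFnSignature.StateDeclaration}, plus a changelog topic when durable state is enabled, along
* with any configs contributed by a registered {@link SamzaDoFnInvokerRegistrar}. For example, a
* (purely illustrative) state id {@code "count"} yields keys such as {@code stores.count.factory},
* {@code stores.count.key.serde} and {@code stores.count.msg.serde}.
*/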
@Override
public Map<String, String> createConfig(
ParDo.MultiOutput<InT, OutT> transform, TransformHierarchy.Node node, ConfigContext ctx) {
final Map<String, String> config = new HashMap<>();
final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
final SamzaPipelineOptions options = ctx.getPipelineOptions();
if (signature.usesState()) {
// set up user state configs
for (DoFnSignature.StateDeclaration state : signature.stateDeclarations().values()) {
final String storeId = state.id();
config.put(
"stores." + storeId + ".factory",
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
config.put("stores." + storeId + ".key.serde", "byteSerde");
config.put("stores." + storeId + ".msg.serde", "byteSerde");
if (options.getStateDurable()) {
config.put(
"stores." + storeId + ".changelog",
ConfigBuilder.getChangelogTopic(options, storeId));
}
}
}
if (doFnInvokerRegistrar != null) {
config.putAll(doFnInvokerRegistrar.configFor(transform.getFn()));
}
return config;
}
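/**
* Passes side-input elements through unchanged and converts side-input watermarks into {@link
* OpMessage#ofSideInputWatermark} messages; {@link #getOutputWatermark()} returns {@code
* Long.MAX_VALUE} so side inputs never hold back the main-input watermark.
*/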
private static class SideInputWatermarkFn<InT>
implements FlatMapFunction<OpMessage<InT>, OpMessage<InT>>,
WatermarkFunction<OpMessage<InT>> {
@Override
public Collection<OpMessage<InT>> apply(OpMessage<InT> message) {
return Collections.singletonList(message);
}
@Override
public Collection<OpMessage<InT>> processWatermark(long watermark) {
return Collections.singletonList(OpMessage.ofSideInputWatermark(new Instant(watermark)));
}
@Override
public Long getOutputWatermark() {
// Always return max so the side input watermark will not be aggregated with main inputs.
return Long.MAX_VALUE;
}
}
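/**
* Unwraps the tagged {@link RawUnionValue} emitted by {@link DoFnOp} back into the underlying
* output value, preserving the window and timestamp of the input element.
*/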
private static class RawUnionValueToValue<OutT> implements Op<RawUnionValue, OutT, Void> {
@Override
public void processElement(WindowedValue<RawUnionValue> inputElement, OpEmitter<OutT> emitter) {
@SuppressWarnings("unchecked")
final OutT value = (OutT) inputElement.getValue().getValue();
emitter.emitElement(inputElement.withValue(value));
}
}
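/**
* Placeholder {@link DoFn} used on the portable path, where execution is driven by the {@link
* RunnerApi.ExecutableStagePayload} passed to {@link DoFnOp} rather than by this DoFn.
*/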
private static class NoOpDoFn<InT, OutT> extends DoFn<InT, OutT> {
@ProcessElement
public void doNothing(ProcessContext context) {}
}
}