| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.samza.translation; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.ServiceLoader; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.stream.Collectors; |
| import org.apache.beam.model.pipeline.v1.RunnerApi; |
| import org.apache.beam.runners.core.construction.ParDoTranslation; |
| import org.apache.beam.runners.core.construction.graph.PipelineNode; |
| import org.apache.beam.runners.core.construction.graph.QueryablePipeline; |
| import org.apache.beam.runners.samza.SamzaPipelineOptions; |
| import org.apache.beam.runners.samza.runtime.DoFnOp; |
| import org.apache.beam.runners.samza.runtime.Op; |
| import org.apache.beam.runners.samza.runtime.OpAdapter; |
| import org.apache.beam.runners.samza.runtime.OpEmitter; |
| import org.apache.beam.runners.samza.runtime.OpMessage; |
| import org.apache.beam.runners.samza.runtime.SamzaDoFnInvokerRegistrar; |
| import org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.runners.TransformHierarchy; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.DoFnSchemaInformation; |
| import org.apache.beam.sdk.transforms.ParDo; |
| import org.apache.beam.sdk.transforms.join.RawUnionValue; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; |
| import org.apache.beam.sdk.util.WindowedValue; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionView; |
| import org.apache.beam.sdk.values.PValue; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators; |
| import org.apache.samza.operators.MessageStream; |
| import org.apache.samza.operators.functions.FlatMapFunction; |
| import org.apache.samza.operators.functions.WatermarkFunction; |
| import org.joda.time.Instant; |
| |
| /** |
| * Translates {@link org.apache.beam.sdk.transforms.ParDo.MultiOutput} or ExecutableStage in |
| * portable api to Samza {@link DoFnOp}. |
| */ |
| class ParDoBoundMultiTranslator<InT, OutT> |
| implements TransformTranslator<ParDo.MultiOutput<InT, OutT>>, |
| TransformConfigGenerator<ParDo.MultiOutput<InT, OutT>> { |
| |
| private final SamzaDoFnInvokerRegistrar doFnInvokerRegistrar; |
| |
| ParDoBoundMultiTranslator() { |
| final Iterator<SamzaDoFnInvokerRegistrar> invokerReg = |
| ServiceLoader.load(SamzaDoFnInvokerRegistrar.class).iterator(); |
| doFnInvokerRegistrar = invokerReg.hasNext() ? Iterators.getOnlyElement(invokerReg) : null; |
| } |
| |
| @Override |
| public void translate( |
| ParDo.MultiOutput<InT, OutT> transform, |
| TransformHierarchy.Node node, |
| TranslationContext ctx) { |
| doTranslate(transform, node, ctx); |
| } |
| |
| // static for serializing anonymous functions |
| private static <InT, OutT> void doTranslate( |
| ParDo.MultiOutput<InT, OutT> transform, |
| TransformHierarchy.Node node, |
| TranslationContext ctx) { |
| final PCollection<? extends InT> input = ctx.getInput(transform); |
| final Map<TupleTag<?>, Coder<?>> outputCoders = |
| ctx.getCurrentTransform().getOutputs().entrySet().stream() |
| .filter(e -> e.getValue() instanceof PCollection) |
| .collect( |
| Collectors.toMap(e -> e.getKey(), e -> ((PCollection<?>) e.getValue()).getCoder())); |
| |
| final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); |
| final Coder<?> keyCoder = |
| signature.usesState() ? ((KvCoder<?, ?>) input.getCoder()).getKeyCoder() : null; |
| |
| if (signature.processElement().isSplittable()) { |
| throw new UnsupportedOperationException("Splittable DoFn is not currently supported"); |
| } |
| |
| final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStream(input); |
| final List<MessageStream<OpMessage<InT>>> sideInputStreams = |
| transform.getSideInputs().stream() |
| .map(ctx::<InT>getViewStream) |
| .collect(Collectors.toList()); |
| final ArrayList<Map.Entry<TupleTag<?>, PValue>> outputs = |
| new ArrayList<>(node.getOutputs().entrySet()); |
| |
| final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>(); |
| final Map<Integer, PCollection<?>> indexToPCollectionMap = new HashMap<>(); |
| |
| for (int index = 0; index < outputs.size(); ++index) { |
| final Map.Entry<TupleTag<?>, PValue> taggedOutput = outputs.get(index); |
| tagToIndexMap.put(taggedOutput.getKey(), index); |
| |
| if (!(taggedOutput.getValue() instanceof PCollection)) { |
| throw new IllegalArgumentException( |
| "Expected side output to be PCollection, but was: " + taggedOutput.getValue()); |
| } |
| final PCollection<?> sideOutputCollection = (PCollection<?>) taggedOutput.getValue(); |
| indexToPCollectionMap.put(index, sideOutputCollection); |
| } |
| |
| final HashMap<String, PCollectionView<?>> idToPValueMap = new HashMap<>(); |
| for (PCollectionView<?> view : transform.getSideInputs()) { |
| idToPValueMap.put(ctx.getViewId(view), view); |
| } |
| |
| DoFnSchemaInformation doFnSchemaInformation; |
| doFnSchemaInformation = ParDoTranslation.getSchemaInformation(ctx.getCurrentTransform()); |
| |
| final DoFnOp<InT, OutT, RawUnionValue> op = |
| new DoFnOp<>( |
| transform.getMainOutputTag(), |
| transform.getFn(), |
| keyCoder, |
| (Coder<InT>) input.getCoder(), |
| outputCoders, |
| transform.getSideInputs(), |
| transform.getAdditionalOutputTags().getAll(), |
| input.getWindowingStrategy(), |
| idToPValueMap, |
| new DoFnOp.MultiOutputManagerFactory(tagToIndexMap), |
| node.getFullName(), |
| // TODO: infer a fixed id from the name |
| String.valueOf(ctx.getCurrentTopologicalId()), |
| input.isBounded(), |
| false, |
| null, |
| Collections.emptyMap(), |
| doFnSchemaInformation); |
| |
| final MessageStream<OpMessage<InT>> mergedStreams; |
| if (sideInputStreams.isEmpty()) { |
| mergedStreams = inputStream; |
| } else { |
| MessageStream<OpMessage<InT>> mergedSideInputStreams = |
| MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn()); |
| mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams)); |
| } |
| |
| final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream = |
| mergedStreams.flatMap(OpAdapter.adapt(op)); |
| |
| for (int outputIndex : tagToIndexMap.values()) { |
| @SuppressWarnings("unchecked") |
| final MessageStream<OpMessage<OutT>> outputStream = |
| taggedOutputStream |
| .filter( |
| message -> |
| message.getType() != OpMessage.Type.ELEMENT |
| || message.getElement().getValue().getUnionTag() == outputIndex) |
| .flatMap(OpAdapter.adapt(new RawUnionValueToValue())); |
| |
| ctx.registerMessageStream(indexToPCollectionMap.get(outputIndex), outputStream); |
| } |
| } |
| |
| /* |
| * We reuse ParDo translator to translate ExecutableStage |
| */ |
| @Override |
| public void translatePortable( |
| PipelineNode.PTransformNode transform, |
| QueryablePipeline pipeline, |
| PortableTranslationContext ctx) { |
| doTranslatePortable(transform, pipeline, ctx); |
| } |
| |
| // static for serializing anonymous functions |
| private static <InT, OutT> void doTranslatePortable( |
| PipelineNode.PTransformNode transform, |
| QueryablePipeline pipeline, |
| PortableTranslationContext ctx) { |
| Map<String, String> outputs = transform.getTransform().getOutputsMap(); |
| |
| final RunnerApi.ExecutableStagePayload stagePayload; |
| try { |
| stagePayload = |
| RunnerApi.ExecutableStagePayload.parseFrom( |
| transform.getTransform().getSpec().getPayload()); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| String inputId = stagePayload.getInput(); |
| final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId); |
| // TODO: support side input |
| final List<MessageStream<OpMessage<InT>>> sideInputStreams = Collections.emptyList(); |
| |
| final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>(); |
| final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>(); |
| |
| // first output as the main output |
| final TupleTag<OutT> mainOutputTag = |
| outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next()); |
| |
| AtomicInteger index = new AtomicInteger(0); |
| outputs |
| .keySet() |
| .iterator() |
| .forEachRemaining( |
| outputName -> { |
| TupleTag<?> tupleTag = new TupleTag<>(outputName); |
| tagToIndexMap.put(tupleTag, index.get()); |
| index.incrementAndGet(); |
| String collectionId = outputs.get(outputName); |
| idToTupleTagMap.put(collectionId, tupleTag); |
| }); |
| |
| WindowedValue.WindowedValueCoder<InT> windowedInputCoder = |
| SamzaPipelineTranslatorUtils.instantiateCoder(inputId, pipeline.getComponents()); |
| final String nodeFullname = transform.getTransform().getUniqueName(); |
| |
| final DoFnSchemaInformation doFnSchemaInformation; |
| doFnSchemaInformation = ParDoTranslation.getSchemaInformation(transform.getTransform()); |
| |
| final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId); |
| final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input); |
| |
| final DoFnOp<InT, OutT, RawUnionValue> op = |
| new DoFnOp<>( |
| mainOutputTag, |
| new NoOpDoFn<>(), |
| null, // key coder not in use |
| windowedInputCoder.getValueCoder(), // input coder not in use |
| Collections.emptyMap(), // output coders not in use |
| Collections.emptyList(), // sideInputs not in use until side input support |
| new ArrayList<>(idToTupleTagMap.values()), // used by java runner only |
| SamzaPipelineTranslatorUtils.getPortableWindowStrategy(transform, pipeline), |
| Collections.emptyMap(), // idToViewMap not in use until side input support |
| new DoFnOp.MultiOutputManagerFactory(tagToIndexMap), |
| nodeFullname, |
| // TODO: infer a fixed id from the name |
| String.valueOf(ctx.getCurrentTopologicalId()), |
| isBounded, |
| true, |
| stagePayload, |
| idToTupleTagMap, |
| doFnSchemaInformation); |
| |
| final MessageStream<OpMessage<InT>> mergedStreams; |
| if (sideInputStreams.isEmpty()) { |
| mergedStreams = inputStream; |
| } else { |
| MessageStream<OpMessage<InT>> mergedSideInputStreams = |
| MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn()); |
| mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams)); |
| } |
| |
| final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream = |
| mergedStreams.flatMap(OpAdapter.adapt(op)); |
| |
| for (int outputIndex : tagToIndexMap.values()) { |
| final MessageStream<OpMessage<OutT>> outputStream = |
| taggedOutputStream |
| .filter( |
| message -> |
| message.getType() != OpMessage.Type.ELEMENT |
| || message.getElement().getValue().getUnionTag() == outputIndex) |
| .flatMap(OpAdapter.adapt(new RawUnionValueToValue())); |
| |
| ctx.registerMessageStream(ctx.getOutputId(transform), outputStream); |
| } |
| } |
| |
| @Override |
| public Map<String, String> createConfig( |
| ParDo.MultiOutput<InT, OutT> transform, TransformHierarchy.Node node, ConfigContext ctx) { |
| final Map<String, String> config = new HashMap<>(); |
| final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); |
| final SamzaPipelineOptions options = ctx.getPipelineOptions(); |
| |
| if (signature.usesState()) { |
| // set up user state configs |
| for (DoFnSignature.StateDeclaration state : signature.stateDeclarations().values()) { |
| final String storeId = state.id(); |
| config.put( |
| "stores." + storeId + ".factory", |
| "org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory"); |
| config.put("stores." + storeId + ".key.serde", "byteSerde"); |
| config.put("stores." + storeId + ".msg.serde", "byteSerde"); |
| |
| if (options.getStateDurable()) { |
| config.put( |
| "stores." + storeId + ".changelog", |
| ConfigBuilder.getChangelogTopic(options, storeId)); |
| } |
| } |
| } |
| |
| if (doFnInvokerRegistrar != null) { |
| config.putAll(doFnInvokerRegistrar.configFor(transform.getFn())); |
| } |
| |
| return config; |
| } |
| |
| private static class SideInputWatermarkFn<InT> |
| implements FlatMapFunction<OpMessage<InT>, OpMessage<InT>>, |
| WatermarkFunction<OpMessage<InT>> { |
| |
| @Override |
| public Collection<OpMessage<InT>> apply(OpMessage<InT> message) { |
| return Collections.singletonList(message); |
| } |
| |
| @Override |
| public Collection<OpMessage<InT>> processWatermark(long watermark) { |
| return Collections.singletonList(OpMessage.ofSideInputWatermark(new Instant(watermark))); |
| } |
| |
| @Override |
| public Long getOutputWatermark() { |
| // Always return max so the side input watermark will not be aggregated with main inputs. |
| return Long.MAX_VALUE; |
| } |
| } |
| |
| private static class RawUnionValueToValue<OutT> implements Op<RawUnionValue, OutT, Void> { |
| @Override |
| public void processElement(WindowedValue<RawUnionValue> inputElement, OpEmitter<OutT> emitter) { |
| @SuppressWarnings("unchecked") |
| final OutT value = (OutT) inputElement.getValue().getValue(); |
| emitter.emitElement(inputElement.withValue(value)); |
| } |
| } |
| |
| private static class NoOpDoFn<InT, OutT> extends DoFn<InT, OutT> { |
| @ProcessElement |
| public void doNothing(ProcessContext context) {} |
| } |
| } |