blob: 6a4c04693faed8053cd6ebedaa3ec91cac796ee5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.frontend.beam;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.*;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.*;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.transform.Transform;
import org.apache.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
import org.apache.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
import org.apache.nemo.compiler.frontend.beam.coder.SideInputCoder;
import org.apache.nemo.compiler.frontend.beam.transform.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
/**
 * Mutable state shared while a Beam {@link Pipeline} is translated into an IR DAG.
 * <p>
 * The context owns the {@link DAGBuilder} being populated, remembers which Beam node and which
 * {@link IRVertex} produced each {@link PValue}, and keeps the stack of {@link LoopVertex}es that is
 * pushed/popped as {@link LoopCompositeTransform}s are entered/left.
 */
final class PipelineTranslationContext {
  /**
   * Name of the Beam DoFn class whose outgoing edges are always translated as shuffle edges
   * (see {@link #getCommPattern(IRVertex, IRVertex)}). It is looked up reflectively by name.
   */
  private static final String CONSTRUCT_UNION_TABLE_FN_CLASS_NAME =
    "org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn";

  private final PipelineOptions pipelineOptions;
  private final DAGBuilder<IRVertex, IREdge> builder;
  private final Map<PValue, TransformHierarchy.Node> pValueToProducerBeamNode;
  private final Map<PValue, IRVertex> pValueToProducerVertex;
  private final Map<PValue, TupleTag<?>> pValueToTag;
  private final Stack<LoopVertex> loopVertexStack;
  private final Pipeline pipeline;

  /**
   * Creates an empty context: no vertices, no registered outputs, and an empty loop stack.
   *
   * @param pipeline        the pipeline to translate
   * @param pipelineOptions {@link PipelineOptions}
   */
  PipelineTranslationContext(final Pipeline pipeline,
                             final PipelineOptions pipelineOptions) {
    this.pipeline = pipeline;
    this.pipelineOptions = pipelineOptions;
    this.builder = new DAGBuilder<>();
    this.pValueToProducerBeamNode = new HashMap<>();
    this.pValueToProducerVertex = new HashMap<>();
    this.pValueToTag = new HashMap<>();
    this.loopVertexStack = new Stack<>();
  }

  /**
   * Called when the traversal enters a composite transform. Only {@link LoopCompositeTransform}s
   * have an effect: a new {@link LoopVertex} named after the composite is pushed on the loop stack.
   *
   * @param compositeTransform composite transform.
   */
  void enterCompositeTransform(final TransformHierarchy.Node compositeTransform) {
    if (compositeTransform.getTransform() instanceof LoopCompositeTransform) {
      // NOTE(review): the vertex that is added and immediately removed is a different instance from
      // the one pushed below. Presumably the add/remove pair is kept for a side effect of
      // DAGBuilder#addVertex(vertex, stack) - confirm against DAGBuilder before simplifying.
      final LoopVertex loopVertex = new LoopVertex(compositeTransform.getFullName());
      builder.addVertex(loopVertex, loopVertexStack);
      builder.removeVertex(loopVertex);
      loopVertexStack.push(new LoopVertex(compositeTransform.getFullName()));
    }
  }

  /**
   * Called when the traversal leaves a composite transform. Pops the loop stack if the composite
   * is a {@link LoopCompositeTransform}, mirroring {@link #enterCompositeTransform}.
   *
   * @param compositeTransform composite transform.
   */
  void leaveCompositeTransform(final TransformHierarchy.Node compositeTransform) {
    if (compositeTransform.getTransform() instanceof LoopCompositeTransform) {
      loopVertexStack.pop();
    }
  }

  /**
   * Add IR vertex to the builder, passing along the current loop stack.
   *
   * @param vertex IR vertex to add
   */
  void addVertex(final IRVertex vertex) {
    builder.addVertex(vertex, loopVertexStack);
  }

  /**
   * Say the dstIRVertex consumes three views: view0, view1, and view2.
   * <p>
   * We translate that as the following:
   * view0 -> SideInputTransform(index=0) ->
   * view1 -> SideInputTransform(index=1) -> dstIRVertex(with a map from indices to PCollectionViews)
   * view2 -> SideInputTransform(index=2) ->
   * <p>
   * The producer of each view and the inserted SideInputTransform vertex are permanently set to
   * parallelism 1.
   *
   * @param dstVertex  vertex.
   * @param sideInputs of the vertex.
   * @throws IllegalStateException if no vertex has been registered as the producer of a view.
   */
  void addSideInputEdges(final IRVertex dstVertex, final Map<Integer, PCollectionView<?>> sideInputs) {
    for (final Map.Entry<Integer, PCollectionView<?>> entry : sideInputs.entrySet()) {
      final int index = entry.getKey();
      final PCollectionView view = entry.getValue();

      // Resolve the producer first so that an unregistered view fails before the builder is mutated.
      final IRVertex srcVertex = getProducerVertexOf(view);

      final IRVertex sideInputTransformVertex = new OperatorVertex(new SideInputTransform(index));
      addVertex(sideInputTransformVertex);
      final Coder viewCoder = getCoderForView(view, this);
      final Coder windowCoder = view.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();

      // First edge: view to transform
      final IREdge firstEdge =
        new IREdge(CommunicationPatternProperty.Value.OneToOne, srcVertex, sideInputTransformVertex);
      addEdge(firstEdge, viewCoder, windowCoder);

      // Second edge: transform to the dstIRVertex
      final IREdge secondEdge =
        new IREdge(CommunicationPatternProperty.Value.BroadCast, sideInputTransformVertex, dstVertex);
      final WindowedValue.FullWindowedValueCoder sideInputElementCoder =
        WindowedValue.getFullCoder(SideInputCoder.of(viewCoder), windowCoder);

      // The vertices should be Parallelism=1
      srcVertex.setPropertyPermanently(ParallelismProperty.of(1));
      sideInputTransformVertex.setPropertyPermanently(ParallelismProperty.of(1));

      secondEdge.setProperty(EncoderProperty.of(new BeamEncoderFactory(sideInputElementCoder)));
      secondEdge.setProperty(DecoderProperty.of(new BeamDecoderFactory(sideInputElementCoder)));
      builder.connectVertices(secondEdge);
    }
  }

  /**
   * Add IR edge to the builder, from the registered producer of {@code input} to {@code dst}.
   * The edge carries the additional-output tag of {@code input}, if one was registered.
   *
   * @param dst   the destination IR vertex.
   * @param input the {@link PValue} {@code dst} consumes
   * @throws IllegalStateException if {@code input} is not a {@link PCollection}, or if no vertex has
   *                               been registered as its producer.
   */
  void addEdgeTo(final IRVertex dst, final PValue input) {
    if (!(input instanceof PCollection)) {
      throw new IllegalStateException(input.toString());
    }
    final PCollection inputCollection = (PCollection) input;
    final Coder elementCoder = inputCollection.getCoder();
    final Coder windowCoder = inputCollection.getWindowingStrategy().getWindowFn().windowCoder();

    final IRVertex src = getProducerVertexOf(input);
    final IREdge edge = new IREdge(getCommPattern(src, dst), src, dst);
    if (pValueToTag.containsKey(input)) {
      edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(input).getId()));
    }
    addEdge(edge, elementCoder, windowCoder);
  }

  /**
   * Sets the key extractor and the (key) encoder/decoder properties on the edge, then connects it
   * in the builder. Key coders are only set when the element coder is a {@link KvCoder}.
   *
   * @param edge         IR edge to add.
   * @param elementCoder element coder.
   * @param windowCoder  window coder.
   */
  void addEdge(final IREdge edge, final Coder elementCoder, final Coder windowCoder) {
    edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));

    if (elementCoder instanceof KvCoder) {
      final Coder keyCoder = ((KvCoder) elementCoder).getKeyCoder();
      edge.setProperty(KeyEncoderProperty.of(new BeamEncoderFactory(keyCoder)));
      edge.setProperty(KeyDecoderProperty.of(new BeamDecoderFactory(keyCoder)));
    }

    // Elements travel as windowed values, so the element coder is wrapped together with the window coder.
    final WindowedValue.FullWindowedValueCoder coder = WindowedValue.getFullCoder(elementCoder, windowCoder);
    edge.setProperty(EncoderProperty.of(new BeamEncoderFactory<>(coder)));
    edge.setProperty(DecoderProperty.of(new BeamDecoderFactory<>(coder)));

    builder.connectVertices(edge);
  }

  /**
   * Registers a {@link PValue} as a main output from the specified {@link IRVertex}.
   *
   * @param node     node
   * @param irVertex the IR vertex
   * @param output   the {@link PValue} {@code irVertex} emits as main output
   */
  void registerMainOutputFrom(final TransformHierarchy.Node node,
                              final IRVertex irVertex,
                              final PValue output) {
    pValueToProducerBeamNode.put(output, node);
    pValueToProducerVertex.put(output, irVertex);
  }

  /**
   * Registers a {@link PValue} as an additional output from the specified {@link IRVertex}.
   *
   * @param node     node
   * @param irVertex the IR vertex
   * @param output   the {@link PValue} {@code irVertex} emits as additional output
   * @param tag      the {@link TupleTag} associated with this additional output
   */
  void registerAdditionalOutputFrom(final TransformHierarchy.Node node,
                                    final IRVertex irVertex,
                                    final PValue output,
                                    final TupleTag<?> tag) {
    pValueToProducerBeamNode.put(output, node);
    pValueToTag.put(output, tag);
    pValueToProducerVertex.put(output, irVertex);
  }

  /**
   * @return the pipeline.
   */
  Pipeline getPipeline() {
    return pipeline;
  }

  /**
   * @return the pipeline options.
   */
  PipelineOptions getPipelineOptions() {
    return pipelineOptions;
  }

  /**
   * @return the dag builder of this translation context.
   */
  DAGBuilder<IRVertex, IREdge> getBuilder() {
    return builder;
  }

  /**
   * @param pValue {@link PValue}
   * @return the producer beam node, or {@code null} if none has been registered for {@code pValue}.
   */
  TransformHierarchy.Node getProducerBeamNodeOf(final PValue pValue) {
    return pValueToProducerBeamNode.get(pValue);
  }

  /**
   * Looks up the vertex registered as the producer of the given value.
   *
   * @param pValue {@link PValue}
   * @return the producer IR vertex, never {@code null}.
   * @throws IllegalStateException if no vertex has been registered for {@code pValue}.
   */
  private IRVertex getProducerVertexOf(final PValue pValue) {
    final IRVertex producer = pValueToProducerVertex.get(pValue);
    if (producer == null) {
      throw new IllegalStateException(String.format("Cannot find a vertex that emits pValue %s", pValue));
    }
    return producer;
  }

  /**
   * Chooses the communication pattern of an edge from the transforms at its two ends.
   * The checks are ordered; the first one that matches wins.
   *
   * @param src source IR vertex.
   * @param dst destination IR vertex.
   * @return the communication pattern property value.
   * @throws RuntimeException if the class named by {@link #CONSTRUCT_UNION_TABLE_FN_CLASS_NAME}
   *                          cannot be loaded.
   */
  private CommunicationPatternProperty.Value getCommPattern(final IRVertex src, final IRVertex dst) {
    final Class<?> constructUnionTableFn;
    try {
      constructUnionTableFn = Class.forName(CONSTRUCT_UNION_TABLE_FN_CLASS_NAME);
    } catch (final ClassNotFoundException e) {
      throw new RuntimeException(e);
    }

    final Transform srcTransform = src instanceof OperatorVertex ? ((OperatorVertex) src).getTransform() : null;
    final Transform dstTransform = dst instanceof OperatorVertex ? ((OperatorVertex) dst).getTransform() : null;
    final DoFn srcDoFn = srcTransform instanceof DoFnTransform ? ((DoFnTransform) srcTransform).getDoFn() : null;

    // 1. Output of the union-table DoFn is shuffled regardless of the destination.
    if (srcDoFn != null && srcDoFn.getClass().equals(constructUnionTableFn)) {
      return CommunicationPatternProperty.Value.Shuffle;
    }
    // 2. Output of a flatten is one-to-one, even when the destination is a group-by-key.
    if (srcTransform instanceof FlattenTransform) {
      return CommunicationPatternProperty.Value.OneToOne;
    }
    // 3. Input of a group-by-key is shuffled.
    if (dstTransform instanceof GroupByKeyAndWindowDoFnTransform
      || dstTransform instanceof GroupByKeyTransform) {
      return CommunicationPatternProperty.Value.Shuffle;
    }
    // 4. Input of a create-view is broadcast.
    if (dstTransform instanceof CreateViewTransform) {
      return CommunicationPatternProperty.Value.BroadCast;
    }
    return CommunicationPatternProperty.Value.OneToOne;
  }

  /**
   * Get appropriate coder for {@link PCollectionView}.
   * <p>
   * The coder is derived from the {@link KvCoder} of the first {@link PCollection} output of the
   * Beam node that produces the view, according to the kind of the view's {@link ViewFn}.
   *
   * @param view    {@link PCollectionView}
   * @param context translation context.
   * @return appropriate {@link Coder} for {@link PCollectionView}
   * @throws UnsupportedOperationException if the view's {@link ViewFn} is not one of the handled kinds.
   */
  private static Coder<?> getCoderForView(final PCollectionView view, final PipelineTranslationContext context) {
    final TransformHierarchy.Node src = context.getProducerBeamNodeOf(view);
    final KvCoder<?, ?> inputKVCoder = (KvCoder) src.getOutputs().values().stream()
      .filter(v -> v instanceof PCollection)
      .map(v -> (PCollection) v)
      .findFirst()
      .orElseThrow(() -> new RuntimeException(String.format("No incoming PCollection to %s", src)))
      .getCoder();
    final ViewFn viewFn = view.getViewFn();
    if (viewFn instanceof PCollectionViews.IterableViewFn) {
      return IterableCoder.of(inputKVCoder.getValueCoder());
    } else if (viewFn instanceof PCollectionViews.ListViewFn) {
      return ListCoder.of(inputKVCoder.getValueCoder());
    } else if (viewFn instanceof PCollectionViews.MapViewFn) {
      final KvCoder inputValueKVCoder = (KvCoder) inputKVCoder.getValueCoder();
      return MapCoder.of(inputValueKVCoder.getKeyCoder(), inputValueKVCoder.getValueCoder());
    } else if (viewFn instanceof PCollectionViews.MultimapViewFn) {
      // A multimap view maps each key to an Iterable of values, so the map's value coder must be an
      // IterableCoder; a plain value coder would describe Map<K, V> instead of Map<K, Iterable<V>>.
      final KvCoder inputValueKVCoder = (KvCoder) inputKVCoder.getValueCoder();
      return MapCoder.of(inputValueKVCoder.getKeyCoder(), IterableCoder.of(inputValueKVCoder.getValueCoder()));
    } else if (viewFn instanceof PCollectionViews.SingletonViewFn) {
      return inputKVCoder;
    } else {
      throw new UnsupportedOperationException(String.format("Unsupported viewFn %s", viewFn.getClass()));
    }
  }
}