/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import org.apache.nemo.common.KeyExtractor;
import org.apache.nemo.common.dag.Edge;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Optimizes the PartitionSet property of shuffle edges to handle data skews using the SamplingVertex.
 * <p>
 * This pass effectively partitions the IRDAG by non-oneToOne edges, clones each sub-DAG partition using SamplingVertex
 * to process sampled data, and executes each cloned partition prior to executing the corresponding original partition.
 * <p>
 * Suppose the IRDAG is partitioned into three sub-DAG partitions with shuffle dependencies as follows:
 * P1 - P2 - P3
 * <p>
 * Then, this pass will produce something like:
 * P1' - P1
 *     - P2' - P2 - P3
 * where Px' consists of SamplingVertex objects that clone the execution of Px.
 * (P3 is not cloned here because it is a sink partition, and none of the outgoing edges of its vertices needs to be
 * optimized.)
 * <p>
 * For each Px', this pass also inserts a TriggerVertex, so that the data statistics of the sampled execution can be
 * used to dynamically optimize the execution behaviors of Px.
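 * <p>
 * A minimal usage sketch (illustrative only; any compile-time optimization policy can invoke this pass the same way):
 * <pre>{@code
 *   final IRDAG dag = ...;  // an IRDAG that contains at least one Shuffle edge
 *   final IRDAG reshaped = new SamplingSkewReshapingPass().apply(dag);
 *   // 'reshaped' now contains the SamplingVertex clones plus a TriggerVertex/MessageAggregatorVertex pair per
 *   // shuffle boundary, so the runtime can optimize the original shuffle edges using the sampled statistics.
 * }</pre>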
 */
@Requires(CommunicationPatternProperty.class)
public final class SamplingSkewReshapingPass extends ReshapingPass {
  private static final Logger LOG = LoggerFactory.getLogger(SamplingSkewReshapingPass.class.getName());
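  // Fraction of the original execution that each SamplingVertex runs (here, 10%) to estimate the data distribution.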
  private static final float SAMPLE_RATE = 0.1f;

  /**
   * Default constructor.
   */
  public SamplingSkewReshapingPass() {
    super(SamplingSkewReshapingPass.class);
  }
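
  /**
   * For each shuffle edge, clones the upstream sub-DAG partition with SamplingVertex objects and inserts a
   * TriggerVertex/MessageAggregatorVertex pair, so that key statistics gathered from the sampled execution can be
   * used to optimize the original shuffle edge at runtime.
   *
   * @param dag the IRDAG to reshape.
   * @return the reshaped IRDAG.
   */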
  @Override
  public IRDAG apply(final IRDAG dag) {
    dag.topologicalDo(v -> {
      for (final IREdge e : dag.getIncomingEdgesOf(v)) {
        if (CommunicationPatternProperty.Value.Shuffle.equals(
          e.getPropertyValue(CommunicationPatternProperty.class).get())) {
          // Compute the partition and its source vertices
          final IRVertex shuffleWriter = e.getSrc();
          final Set<IRVertex> partitionAll = recursivelyBuildPartition(shuffleWriter, dag);
          final Set<IRVertex> partitionSources = partitionAll.stream().filter(vertexInPartition ->
            !dag.getIncomingEdgesOf(vertexInPartition).stream()
              .map(Edge::getSrc)
              .anyMatch(partitionAll::contains)
          ).collect(Collectors.toSet());

          // Check if the partition is a sink, in which case we do not create sampling vertices
          final boolean isSinkPartition = partitionAll.stream()
            .flatMap(vertexInPartition -> dag.getOutgoingEdgesOf(vertexInPartition).stream())
            .map(Edge::getDst)
            .allMatch(partitionAll::contains);
          if (isSinkPartition) {
            break;
          }

          // Insert sampling vertices.
          final Set<SamplingVertex> samplingVertices = partitionAll
            .stream()
            .map(vertexInPartition -> new SamplingVertex(vertexInPartition, SAMPLE_RATE))
            .collect(Collectors.toSet());
          dag.insert(samplingVertices, partitionSources);

          // Insert the message vertex.
          // We first obtain a clonedShuffleEdge to analyze the data statistics of the shuffle outputs of
          // the sampling vertex right before shuffle.
          final SamplingVertex rightBeforeShuffle = samplingVertices.stream()
            .filter(sv -> sv.getOriginalVertexId().equals(e.getSrc().getId()))
            .findFirst()
            .orElseThrow(() -> new IllegalStateException());
          final IREdge clonedShuffleEdge = rightBeforeShuffle.getCloneOfOriginalEdge(e);
          final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
          dag.insert(
            new TriggerVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
            new MessageAggregatorVertex(() -> new HashMap<>(), SkewHandlingUtil.getMessageAggregator()),
            SkewHandlingUtil.getEncoder(e),
            SkewHandlingUtil.getDecoder(e),
            new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
            new HashSet<>(Arrays.asList(e))); // we want to optimize the original edge, not the clone
        }
      }
    });
    return dag;
  }
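
  /**
   * Recursively collects the vertices that belong to the same sub-DAG partition as the given vertex, by following
   * incoming edges that are OneToOne, keep their data in memory (MemoryStore), and are the only incoming edge of
   * the vertex being visited.
   *
   * @param curVertex the vertex to start the backward traversal from.
   * @param dag       the IRDAG to traverse.
   * @return the set of vertices in the partition that contains {@code curVertex}.
   */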
  private Set<IRVertex> recursivelyBuildPartition(final IRVertex curVertex, final IRDAG dag) {
    final Set<IRVertex> unionSet = new HashSet<>();
    unionSet.add(curVertex);
    for (final IREdge inEdge : dag.getIncomingEdgesOf(curVertex)) {
      if (CommunicationPatternProperty.Value.OneToOne
        .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())
        && DataStoreProperty.Value.MemoryStore
        .equals(inEdge.getPropertyValue(DataStoreProperty.class).get())
        && dag.getIncomingEdgesOf(curVertex).size() == 1) {
        unionSet.addAll(recursivelyBuildPartition(inEdge.getSrc(), dag));
      }
    }
    return unionSet;
  }
}