blob: e1800d4ea84dceb1e583e376d1241e310f69b957 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.Util;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.dag.Edge;
import org.apache.nemo.common.dag.Vertex;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DataFlowProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.stream.Collectors;
/**
* A pass for assigning each stages in schedule groups.
* <p>
* TODO #347: IRDAG#partitionAcyclically
* This code can be greatly simplified...
*
* <h3>Rules</h3>
* <ul>
* <li>Vertices connected with push edges must be assigned same ScheduleGroup.</li>
* <li>For pull edges,
* <ul>
* <li>if the destination of the edge depends on multiple ScheduleGroups, split ScheduleGroup by the edge.</li>
* <li>if the edge is broadcast type and {@code allowBroadcastWithinScheduleGroup} is {@code false},
* split ScheduleGroup by the edge.</li>
* <li>if the edge is shuffle type and {@code allowShuffleWithinScheduleGroup} is {@code false},
* split ScheduleGroup by the edge.</li>
* <li>if the destination of the edge has multiple inEdges, split ScheduleGroup by the edge.</li>
* <li>Otherwise, the source and the destination of the edge should be assigned same ScheduleGroup.</li>
* </ul>
* </li>
* </ul>
*/
@Annotates(ScheduleGroupProperty.class)
@Requires({CommunicationPatternProperty.class, DataFlowProperty.class})
public final class DefaultScheduleGroupPass extends AnnotatingPass {
private static final Logger LOG = LoggerFactory.getLogger(DefaultScheduleGroupPass.class.getName());
private final boolean allowBroadcastWithinScheduleGroup;
private final boolean allowShuffleWithinScheduleGroup;
private final boolean allowMultipleInEdgesWithinScheduleGroup;
/**
* Default constructor.
*/
public DefaultScheduleGroupPass() {
this(true, true, true);
}
/**
* Constructor.
*
* @param allowBroadcastWithinScheduleGroup whether to allow Broadcast edges within a ScheduleGroup or not
* @param allowShuffleWithinScheduleGroup whether to allow Shuffle edges within a ScheduleGroup or not
* @param allowMultipleInEdgesWithinScheduleGroup whether to allow vertices with multiple dependencies or not
*/
public DefaultScheduleGroupPass(final boolean allowBroadcastWithinScheduleGroup,
final boolean allowShuffleWithinScheduleGroup,
final boolean allowMultipleInEdgesWithinScheduleGroup) {
super(DefaultScheduleGroupPass.class);
this.allowBroadcastWithinScheduleGroup = allowBroadcastWithinScheduleGroup;
this.allowShuffleWithinScheduleGroup = allowShuffleWithinScheduleGroup;
this.allowMultipleInEdgesWithinScheduleGroup = allowMultipleInEdgesWithinScheduleGroup;
}
@Override
public IRDAG apply(final IRDAG dag) {
final Map<IRVertex, Integer> irVertexToGroupIdMap = new HashMap<>();
final Map<Integer, List<IRVertex>> groupIdToVertices = new HashMap<>();
// Step 1: Compute schedule groups
final MutableInt lastGroupId = new MutableInt(0);
dag.topologicalDo(irVertex -> {
final int curId;
if (!irVertexToGroupIdMap.containsKey(irVertex)) {
lastGroupId.increment();
irVertexToGroupIdMap.put(irVertex, lastGroupId.intValue());
curId = lastGroupId.intValue();
} else {
curId = irVertexToGroupIdMap.get(irVertex);
}
groupIdToVertices.putIfAbsent(curId, new ArrayList<>());
groupIdToVertices.get(curId).add(irVertex);
final List<IREdge> allOutEdges = dag.getOutgoingEdgesOf(irVertex);
final List<IREdge> noCycleOutEdges = allOutEdges.stream().filter(curEdge -> {
final List<IREdge> outgoingEdgesWithoutCurEdge = new ArrayList<>(allOutEdges);
outgoingEdgesWithoutCurEdge.remove(curEdge);
return outgoingEdgesWithoutCurEdge.stream()
.map(IREdge::getDst)
.flatMap(dst -> dag.getDescendants(dst.getId()).stream())
.noneMatch(descendant -> descendant.equals(curEdge.getDst()));
}).collect(Collectors.toList());
final List<IRVertex> pushNoCycleOutEdgeDsts = noCycleOutEdges.stream()
.filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
.map(IREdge::getDst)
.collect(Collectors.toList());
pushNoCycleOutEdgeDsts.forEach(dst -> irVertexToGroupIdMap.put(dst, curId));
});
// Step 2: Topologically sort schedule groups
final Map<Integer, List<Pair>> vIdTogId = irVertexToGroupIdMap.entrySet().stream()
.map(entry -> Pair.of(entry.getKey().getId(), entry.getValue()))
.collect(Collectors.groupingBy(p -> (Integer) ((Pair) p).right()));
final DAGBuilder<ScheduleGroup, ScheduleGroupEdge> builder = new DAGBuilder<>();
final Map<Integer, ScheduleGroup> idToGroup = new HashMap<>();
// ScheduleGroups
groupIdToVertices.forEach((groupId, vertices) -> {
final ScheduleGroup sg = new ScheduleGroup(groupId);
idToGroup.put(groupId, sg);
sg.vertices.addAll(vertices);
builder.addVertex(sg);
});
// ScheduleGroupEdges
irVertexToGroupIdMap.forEach((vertex, groupId) -> {
dag.getIncomingEdgesOf(vertex).stream()
.filter(inEdge -> !groupIdToVertices.get(groupId).contains(inEdge.getSrc()))
.map(inEdge -> new ScheduleGroupEdge(
idToGroup.get(irVertexToGroupIdMap.get(inEdge.getSrc())),
idToGroup.get(irVertexToGroupIdMap.get(inEdge.getDst()))))
.forEach(builder::connectVertices);
});
// Step 3: Actually set new schedule group properties based on topological ordering
final MutableInt actualScheduleGroup = new MutableInt(0);
final DAG<ScheduleGroup, ScheduleGroupEdge> sgDAG = builder.buildWithoutSourceSinkCheck();
final List<ScheduleGroup> sorted = sgDAG.getTopologicalSort();
sorted.stream()
.map(sg -> groupIdToVertices.get(sg.getScheduleGroupId()))
.forEach(vertices -> {
vertices.forEach(vertex -> {
vertex.setPropertyPermanently(ScheduleGroupProperty.of(actualScheduleGroup.intValue()));
});
actualScheduleGroup.increment();
});
return dag;
}
/**
* Vertex in ScheduleGroup DAG.
*/
private static final class ScheduleGroup extends Vertex {
private final Set<IRVertex> vertices = new HashSet<>();
private final Set<ScheduleGroup> scheduleGroupsTo = new HashSet<>();
private final Set<ScheduleGroup> scheduleGroupsFrom = new HashSet<>();
private final int scheduleGroupId;
/**
* Constructor.
*/
ScheduleGroup(final int groupId) {
super(String.format("ScheduleGroup%d", groupId));
this.scheduleGroupId = groupId;
}
public int getScheduleGroupId() {
return scheduleGroupId;
}
@Override
public ObjectNode getPropertiesAsJsonNode() {
final ObjectMapper mapper = new ObjectMapper();
final ObjectNode node = mapper.createObjectNode();
node.put("transform", Util.stringifyIRVertexIds(vertices));
return node;
}
}
/**
* Edge in ScheduleGroup DAG.
*/
private static final class ScheduleGroupEdge extends Edge<ScheduleGroup> {
private static int nextScheduleGroupEdgeId = 0;
/**
* Constructor.
*
* @param src source vertex.
* @param dst destination vertex.
*/
ScheduleGroupEdge(final ScheduleGroup src, final ScheduleGroup dst) {
super(String.format("ScheduleGroupEdge-%d", nextScheduleGroupEdgeId++), src, dst);
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final ScheduleGroupEdge that = (ScheduleGroupEdge) o;
return this.getSrc().equals(that.getSrc()) && this.getDst().equals(that.getDst());
}
@Override
public int hashCode() {
return getSrc().hashCode() + 31 * getDst().hashCode();
}
}
}