[NEMO-411] Bug in ScheduleGroupPass, OutputTag, DuplicateEdgeGroup (#232)
JIRA: [NEMO-411: Bug in ScheduleGroupPass, OutputTag, DuplicateEdgeGroup](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-411)
**Major changes:**
- Fixes the bugs described in [NEMO-411](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-411)
> When trying to run ALS with TransientResourcePass, I ran into three bugs: 1. OutputTag grouping, which groups edges by their output tag but currently puts all edges without an output tag into a single group, which is undesirable; 2. DuplicateEdgeGroup, which does not consider the first edge that points to the first iteration of the loop, even though it is clearly part of the duplicate edge group; 3. ScheduleGroupPass, which does not consider vertices with multiple outgoing edges pointing outside when looking for a cycle in the graph.
**Minor changes to note:**
- None
**Tests for the changes:**
- Added an integration test for ALS, the case where it was initially failing.
**Other comments:**
- None
Closes #232
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
index b72f7ae..4484505 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
@@ -38,6 +38,8 @@
import java.io.Serializable;
import java.util.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -349,8 +351,8 @@
void addLoopVertexCheckers() {
final NeighborChecker duplicateEdgeGroupId = ((v, inEdges, outEdges) -> {
- final Map<Optional<String>, List<IREdge>> tagToOutEdges = groupOutEdgesByAdditionalOutputTag(outEdges);
- for (final List<IREdge> sameTagOutEdges : tagToOutEdges.values()) {
+ // In loop vertices, edges with an empty output tag must each be treated as a separate group.
+ for (final List<IREdge> sameTagOutEdges : groupOutEdgesByAdditionalOutputTag(outEdges, true)) {
if (sameTagOutEdges.stream()
.map(e -> e.getPropertyValue(DuplicateEdgeGroupProperty.class)
.map(DuplicateEdgeGroupPropertyValue::getGroupId))
@@ -429,7 +431,7 @@
void addEncodingCompressionCheckers() {
final NeighborChecker additionalOutputEncoder = ((irVertex, inEdges, outEdges) -> {
- for (final List<IREdge> sameTagOutEdges : groupOutEdgesByAdditionalOutputTag(outEdges).values()) {
+ for (final List<IREdge> sameTagOutEdges : groupOutEdgesByAdditionalOutputTag(outEdges, false)) {
final List<IREdge> nonStreamVertexEdge = sameTagOutEdges.stream()
.filter(stoe -> !isConnectedToStreamVertex(stoe))
.collect(Collectors.toList());
@@ -464,6 +466,23 @@
singleEdgeCheckerList.add(compressAndDecompress);
}
+ /**
+ * Group outgoing edges by the additional output tag property.
+ * @param outEdges the outedges to group.
+ * @param distinguishEmpty if true, each edge without a tag forms its own group; if false, they share one group.
+ * @return the edges grouped by the additional output tag property value.
+ */
+ private Collection<List<IREdge>> groupOutEdgesByAdditionalOutputTag(final List<IREdge> outEdges,
+ final boolean distinguishEmpty) {
+ final AtomicInteger distinctIntegerForEmptyOutputTag = new AtomicInteger(0);
+ final IntSupplier tagValueSupplier = distinguishEmpty
+ ? distinctIntegerForEmptyOutputTag::getAndIncrement : distinctIntegerForEmptyOutputTag::get;
+
+ return outEdges.stream().collect(Collectors.groupingBy(
+ outEdge -> outEdge.getPropertyValue(AdditionalOutputTagProperty.class)
+ .orElse(String.valueOf(tagValueSupplier.getAsInt())),
+ Collectors.toList())).values();
+ }
///////////////////////////// Private helper methods
@@ -471,12 +490,6 @@
return irEdge.getDst() instanceof RelayVertex || irEdge.getSrc() instanceof RelayVertex;
}
- private Map<Optional<String>, List<IREdge>> groupOutEdgesByAdditionalOutputTag(final List<IREdge> outEdges) {
- return outEdges.stream().collect(Collectors.groupingBy(
- (outEdge -> outEdge.getPropertyValue(AdditionalOutputTagProperty.class)),
- Collectors.toList()));
- }
-
private Set<Integer> getZeroToNSet(final int n) {
return IntStream.range(0, n)
.boxed()
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
index e0dd0e3..b9af41b 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
@@ -34,6 +34,7 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntPredicate;
/**
@@ -41,13 +42,13 @@
*/
public final class LoopVertex extends IRVertex {
- private static int duplicateEdgeGroupId = 0;
+ private final AtomicInteger duplicateEdgeGroupId = new AtomicInteger(0);
// Contains DAG information
private final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
private final String compositeTransformFullName;
// for the initial iteration
private final Map<IRVertex, Set<IREdge>> dagIncomingEdges = new HashMap<>();
- // Edges from previous iterations connected internal.
+ // Edges from previous iterations connected internally.
private final Map<IRVertex, Set<IREdge>> iterativeIncomingEdges = new HashMap<>();
// Edges from outside previous iterations.
private final Map<IRVertex, Set<IREdge>> nonIterativeIncomingEdges = new HashMap<>();
@@ -210,15 +211,20 @@
* Marks duplicate edges with DuplicateEdgeGroupProperty.
*/
public void markDuplicateEdges() {
- nonIterativeIncomingEdges.forEach(((irVertex, irEdges) -> irEdges.forEach(irEdge -> {
- irEdge.setProperty(
- DuplicateEdgeGroupProperty.of(new DuplicateEdgeGroupPropertyValue(String.valueOf(duplicateEdgeGroupId))));
- duplicateEdgeGroupId++;
+ nonIterativeIncomingEdges.forEach(((irVertex, inEdges) -> inEdges.forEach(inEdge -> {
+ final DuplicateEdgeGroupPropertyValue value =
+ new DuplicateEdgeGroupPropertyValue(String.valueOf(duplicateEdgeGroupId.getAndIncrement()));
+ inEdge.setProperty(DuplicateEdgeGroupProperty.of(value));
+ getDagIncomingEdges().getOrDefault(irVertex, new HashSet<>()).stream()
+ .filter(irEdge -> irEdge.getSrc().equals(inEdge.getSrc()))
+ .forEach(irEdge -> irEdge.setProperty(DuplicateEdgeGroupProperty.of(value)));
})));
}
/**
* Method for unrolling an iteration of the LoopVertex.
+ * In the original place of the LoopVertex, it puts a clone of the sub-DAG that iterates and appends the
+ * LoopVertex after that clone; this continues until the termination condition has been met.
*
* @param dagBuilder DAGBuilder to add the unrolled iteration to.
* @return a LoopVertex with one less maximum iteration.
@@ -244,7 +250,7 @@
});
});
- // process DAG incoming edges.
+ // process the initial DAG incoming edges for the first loop.
getDagIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
edge.getSrc(), originalToNewIRVertex.get(dstVertex));
@@ -253,7 +259,7 @@
}));
if (loopTerminationConditionMet()) {
- // if termination condition met, we process the DAG outgoing edge.
+ // if the termination condition is met, process the DAG outgoing edges of the final loop; otherwise skip them.
getDagOutgoingEdges().forEach((srcVertex, irEdges) -> irEdges.forEach(edge -> {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
originalToNewIRVertex.get(srcVertex), edge.getDst());
@@ -262,7 +268,8 @@
}));
}
- // process next iteration's DAG incoming edges
+ // process next iteration's DAG incoming edges, and add them as the next loop's incoming edges:
+ // clear, as we're done with the current loop and need to prepare it for the next one.
this.getDagIncomingEdges().clear();
this.nonIterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(this::addDagIncomingEdge));
this.iterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
@@ -287,7 +294,7 @@
* @return whether or not the loop termination condition has been met.
*/
public Boolean loopTerminationConditionMet(final Integer intPredicateInput) {
- return maxNumberOfIterations <= 0 || terminationCondition.test(intPredicateInput);
+ return maxNumberOfIterations <= 0 || (terminationCondition != null && terminationCondition.test(intPredicateInput));
}
/**
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
index f69f57c..ed614f1 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
@@ -21,7 +21,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.lang.mutable.MutableInt;
-import org.apache.nemo.common.Pair;
import org.apache.nemo.common.Util;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
@@ -113,9 +112,13 @@
groupIdToVertices.putIfAbsent(curId, new ArrayList<>());
groupIdToVertices.get(curId).add(irVertex);
- final List<IREdge> allOutEdges = dag.getOutgoingEdgesOf(irVertex);
- final List<IREdge> noCycleOutEdges = allOutEdges.stream().filter(curEdge -> {
- final List<IREdge> outgoingEdgesWithoutCurEdge = new ArrayList<>(allOutEdges);
+ final List<IRVertex> verticesOfGroup = groupIdToVertices.get(curId);
+ final List<IREdge> allOutEdgesOfGroup = groupIdToVertices.get(curId).stream()
+ .flatMap(vtx -> dag.getOutgoingEdgesOf(vtx).stream())
+ .filter(edge -> !verticesOfGroup.contains(edge.getDst())) // We don't count the group-internal edges.
+ .collect(Collectors.toList());
+ final List<IREdge> noCycleOutEdges = allOutEdgesOfGroup.stream().filter(curEdge -> {
+ final List<IREdge> outgoingEdgesWithoutCurEdge = new ArrayList<>(allOutEdgesOfGroup);
outgoingEdgesWithoutCurEdge.remove(curEdge);
return outgoingEdgesWithoutCurEdge.stream()
.map(IREdge::getDst)
@@ -132,10 +135,6 @@
});
// Step 2: Topologically sort schedule groups
- final Map<Integer, List<Pair>> vIdTogId = irVertexToGroupIdMap.entrySet().stream()
- .map(entry -> Pair.of(entry.getKey().getId(), entry.getValue()))
- .collect(Collectors.groupingBy(p -> (Integer) ((Pair) p).right()));
-
final DAGBuilder<ScheduleGroup, ScheduleGroupEdge> builder = new DAGBuilder<>();
final Map<Integer, ScheduleGroup> idToGroup = new HashMap<>();
@@ -158,14 +157,11 @@
// Step 3: Actually set new schedule group properties based on topological ordering
final MutableInt actualScheduleGroup = new MutableInt(0);
final DAG<ScheduleGroup, ScheduleGroupEdge> sgDAG = builder.buildWithoutSourceSinkCheck();
- final List<ScheduleGroup> sorted = sgDAG.getTopologicalSort();
- sorted.stream()
- .map(sg -> groupIdToVertices.get(sg.getScheduleGroupId()))
- .forEach(vertices -> {
- vertices.forEach(vertex -> vertex.setPropertyPermanently(
- ScheduleGroupProperty.of(actualScheduleGroup.intValue())));
- actualScheduleGroup.increment();
- });
+ sgDAG.topologicalDo(sg -> {
+ sg.vertices.forEach(vertex ->
+ vertex.setPropertyPermanently(ScheduleGroupProperty.of(actualScheduleGroup.intValue())));
+ actualScheduleGroup.increment();
+ });
return dag;
}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/PolicyImpl.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/PolicyImpl.java
index 0e8bcb2..9059891 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/PolicyImpl.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/PolicyImpl.java
@@ -96,9 +96,9 @@
final IRDAGChecker.CheckerResult integrity = processedDAG.checkIntegrity();
if (!integrity.isPassed()) {
final long curTime = System.currentTimeMillis();
- processedDAG.storeJSON("debug", String.valueOf(curTime), "integrity failure");
+ processedDAG.storeJSON(dagDirectory, String.valueOf(curTime), "integrity failure");
throw new CompileTimeOptimizationException(integrity.getFailReason()
- + " / For DAG visualization, check out debug/" + curTime + ".json");
+ + " / For DAG visualization, check out " + dagDirectory + curTime + ".json");
}
// Save the processed JSON DAG.
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
index 8452bb5..5e0fb1a 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
@@ -22,6 +22,7 @@
import org.apache.nemo.common.test.ArgBuilder;
import org.apache.nemo.common.test.ExampleTestArgs;
import org.apache.nemo.common.test.ExampleTestUtil;
+import org.apache.nemo.compiler.optimizer.policy.TransientResourcePolicy;
import org.apache.nemo.examples.beam.policy.DefaultPolicyParallelismFive;
import org.junit.After;
import org.junit.Before;
@@ -72,6 +73,15 @@
.build());
}
+ @Test(timeout = ExampleTestArgs.TIMEOUT)
+ public void testTransient() throws Exception {
+ JobLauncher.main(builder
+ .addResourceJson(noPoisonResources)
+ .addJobId(AlternatingLeastSquareITCase.class.getSimpleName() + "_transient")
+ .addOptimizationPolicy(TransientResourcePolicy.class.getCanonicalName())
+ .build());
+ }
+
// TODO #137: Retry parent task(s) upon task INPUT_READ_FAILURE
// @Test (timeout = TIMEOUT)
// public void testTransientResourceWithPoison() throws Exception {