[NEMO-434] Logical DAG modification for Dynamic sampling of task metrics during the execution of a stage (#292)
JIRA: [NEMO-434: Logical DAG modification for Dynamic sampling of task metrics during the execution of a stage](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-434)
**Major changes:**
- Added new mechanism for launching runtime pass
- For runtime passes which need to gather information by themselves, use MessageGeneratorVertex and MessageAggregatorVertex
- For runtime passes which need Metric information, use SignalVertex
- Added SignalVertex and SignalTransform to launch runtime pass without collecting any runtime information
- New compile-time pass for DTS which modifies the DAG
- New TaskSizeSplitterVertex to distinguish and manage sampled data and the rest at the vertex level (extends LoopVertex)
- Added methods for inserting / deleting TaskSizeSplitterVertex and SignalVertex in IRDAG
**Minor changes to note:**
- Changed the name of TriggerVertex to MessageGeneratorVertex. Since it is the combination of TriggerVertex and MessageAggregatorVertex (MAV) that triggers a runtime pass, not the TriggerVertex alone, the name 'Trigger' can be misleading.
- Removed duplicate assignments in the EmptyComponents class
- Added new methods to LoopVertex
**Tests for the changes:**
- IRDAGTest and SkewCompositePassTest have been modified to follow the class name refactoring
- Added new test methods in IRDAGTest regarding TaskSizeSplitterVertex and SignalVertex
**Other comments:**
Closes #292
diff --git a/common/src/main/java/org/apache/nemo/common/Util.java b/common/src/main/java/org/apache/nemo/common/Util.java
index 4bf4593..fdd862c 100644
--- a/common/src/main/java/org/apache/nemo/common/Util.java
+++ b/common/src/main/java/org/apache/nemo/common/Util.java
@@ -27,10 +27,12 @@
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.executionproperty.ResourceSpecification;
import org.apache.nemo.common.ir.vertex.IRVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.SignalVertex;
import java.io.IOException;
import java.lang.instrument.Instrumentation;
@@ -196,8 +198,10 @@
public static boolean isUtilityVertex(final IRVertex v) {
return v instanceof SamplingVertex
|| v instanceof MessageAggregatorVertex
- || v instanceof TriggerVertex
- || v instanceof RelayVertex;
+ || v instanceof MessageGeneratorVertex
+ || v instanceof RelayVertex
+ || v instanceof SignalVertex
+ || v instanceof TaskSizeSplitterVertex;
}
/**
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
index 1d6af69..7ec85b7 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
@@ -18,15 +18,21 @@
*/
package org.apache.nemo.common.dag;
+import org.apache.nemo.common.Util;
import org.apache.nemo.common.exception.CompileTimeOptimizationException;
import org.apache.nemo.common.exception.IllegalVertexOperationException;
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.*;
@@ -42,6 +48,7 @@
* @param <E> the edge type.
*/
public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Serializable {
+ private static final Logger LOG = LoggerFactory.getLogger(DAGBuilder.class.getName());
private final Set<V> vertices;
private final Map<V, Set<E>> incomingEdges;
private final Map<V, Set<E>> outgoingEdges;
@@ -169,6 +176,132 @@
return this;
}
+ // The two similar methods below are for connecting a SplitterVertex in the DAG.
+
+ /**
+ * This method replaces the current SplitterVertex's LoopEdge - InternalEdge relationship with a new relationship
+ * and connects the edge.
+ * The changes which invoke this method should not be caused by the SplitterVertex itself. Therefore, this method
+ * should be used when there are changes in the vertices before / after the SplitterVertex.
+ *
+ * CAUTION: TaskSizeSplitterVertex must only appear in IRDAG.
+ * {@code originalEdge} and {@code edgeToInsert} should have the same source and destination.
+ *
+ * Relation to be Erased: originalEdge - internalEdge
+ * Relation to insert: edgeToInsert - newInternalEdge
+ *
+ * @param originalEdge edge connected to SplitterVertex, and is to be replaced.
+ * @param edgeToInsert edge connected to SplitterVertex, and is to be inserted.
+ * @return itself.
+ */
+ public DAGBuilder<V, E> connectSplitterVertexWithReplacing(final E originalEdge, final E edgeToInsert) {
+   final V src = edgeToInsert.getSrc();
+   final V dst = edgeToInsert.getDst();
+
+   if (vertices.contains(src) && vertices.contains(dst)) {
+     // integrity check: TaskSizeSplitterVertex should only appear in IRDAG, so non-IREdge inputs are ignored.
+     if (!(edgeToInsert instanceof IREdge)) {
+       return this;
+     }
+
+     if (!originalEdge.getSrc().equals(src)) {
+       throw new IllegalVertexOperationException(originalEdge.getId()
+         + " and " + edgeToInsert.getId() + " should have same source, but found\n edge : source\n "
+         + originalEdge.getId() + " : " + originalEdge.getSrc().getId() + "\n "
+         + edgeToInsert.getId() + " : " + edgeToInsert.getSrc().getId());
+     }
+
+     if (!originalEdge.getDst().equals(dst)) {
+       throw new IllegalVertexOperationException(originalEdge.getId()
+         + " and " + edgeToInsert.getId() + " should have same destination, but found\n edge : dest\n "
+         + originalEdge.getId() + " : " + originalEdge.getDst().getId() + "\n "
+         + edgeToInsert.getId() + " : " + edgeToInsert.getDst().getId());
+     }
+
+     if (src instanceof TaskSizeSplitterVertex) {
+       TaskSizeSplitterVertex spSrc = (TaskSizeSplitterVertex) src;
+       IREdge internalEdge = spSrc.getEdgeWithInternalVertex((IREdge) originalEdge);
+       IREdge newInternalEdge = Util.cloneEdge(internalEdge, internalEdge.getSrc(), (IRVertex) dst);
+       spSrc.mapEdgeWithLoop((IREdge) originalEdge, newInternalEdge);
+       spSrc.mapEdgeWithLoop((IREdge) edgeToInsert, newInternalEdge);
+     }
+     if (dst instanceof TaskSizeSplitterVertex) {
+       TaskSizeSplitterVertex spDst = (TaskSizeSplitterVertex) dst;
+       IREdge internalEdge = spDst.getEdgeWithInternalVertex((IREdge) originalEdge);
+       IREdge newInternalEdge = Util.cloneEdge(internalEdge, (IRVertex) src, internalEdge.getDst());
+       spDst.mapEdgeWithLoop((IREdge) originalEdge, newInternalEdge);
+       spDst.mapEdgeWithLoop((IREdge) edgeToInsert, newInternalEdge);
+     }
+     incomingEdges.get(dst).add(edgeToInsert);
+     outgoingEdges.get(src).add(edgeToInsert);
+   } else {
+     this.buildWithoutSourceSinkCheck().storeJSON("debug", "errored_ir", "Errored IR");
+     throw new IllegalVertexOperationException("The DAG does not contain"
+       + (vertices.contains(src) ? "" : " [source]") + (vertices.contains(dst) ? "" : " [destination]")
+       + " of the edge: [" + (src == null ? null : src.getId())
+       + "]->[" + (dst == null ? null : dst.getId()) + "] in "
+       + vertices.stream().map(V::getId).collect(Collectors.toSet()));
+   }
+   return this;
+ }
+
+ /**
+ * This method adds information to SplitterVertex's LoopEdge - InternalEdge relationship and connects the edge
+ * without replacing existing mapping relationships.
+ * The changes which invoke this method should not be caused by SplitterVertex itself. Therefore, this method
+ * should be used when there are changes in vertices before / after SplitterVertex.
+ * Since {@code edgeToInsert} should also have a mapping relationship to originalVertices of SplitterVertex,
+ * we give {@code edgeToReference} together to copy the mapping information. Therefore, these two parameters must
+ * have a common source or a common destination.
+ *
+ * Relation to reference: edgeToReference - internalEdge
+ * Relation to add: edgeToInsert - newInternalEdge
+ *
+ * CAUTION: TaskSizeSplitterVertex must only appear in IRDAG.
+ *
+ * Use case example: when inserting trigger vertices before / after splitterVertex.
+ *
+ * @param edgeToReference edge connected to SplitterVertex, and to reference.
+ * @param edgeToInsert edge connected to SplitterVertex, and to insert.
+ * @return itself.
+ */
+ public DAGBuilder<V, E> connectSplitterVertexWithoutReplacing(final E edgeToReference, final E edgeToInsert) {
+ final V src = edgeToInsert.getSrc();
+ final V dst = edgeToInsert.getDst();
+
+ if (vertices.contains(src) && vertices.contains(dst)) {
+ // integrity check: TaskSizeSplitterVertex should only appear in IRDAG.
+ if (!(edgeToInsert instanceof IREdge)) {
+ return this;
+ }
+
+ if (src instanceof TaskSizeSplitterVertex && edgeToReference.getSrc().equals(src)) {
+ TaskSizeSplitterVertex spSrc = (TaskSizeSplitterVertex) src;
+ IREdge internalEdge = spSrc.getEdgeWithInternalVertex((IREdge) edgeToReference);
+ IREdge newInternalEdge = Util.cloneEdge((IREdge) edgeToInsert, internalEdge.getSrc(), (IRVertex) dst);
+ spSrc.mapEdgeWithLoop((IREdge) edgeToInsert, newInternalEdge);
+ }
+ if (dst instanceof TaskSizeSplitterVertex && edgeToReference.getDst().equals(dst)) {
+ TaskSizeSplitterVertex spDst = (TaskSizeSplitterVertex) dst;
+ IREdge internalEdge = spDst.getEdgeWithInternalVertex((IREdge) edgeToReference);
+ IREdge newInternalEdge = Util.cloneEdge(internalEdge,
+ (IRVertex) src,
+ internalEdge.getDst());
+ spDst.mapEdgeWithLoop((IREdge) edgeToInsert, newInternalEdge);
+ }
+ incomingEdges.get(dst).add(edgeToInsert);
+ outgoingEdges.get(src).add(edgeToInsert);
+ } else {
+ this.buildWithoutSourceSinkCheck().storeJSON("debug", "errored_ir", "Errored IR");
+ throw new IllegalVertexOperationException("The DAG does not contain"
+ + (vertices.contains(src) ? "" : " [source]") + (vertices.contains(dst) ? "" : " [destination]")
+ + " of the edge: [" + (src == null ? null : src.getId())
+ + "]->[" + (dst == null ? null : dst.getId()) + "] in "
+ + vertices.stream().map(V::getId).collect(Collectors.toSet()));
+ }
+ return this;
+ }
+
/**
* Checks whether the DAGBuilder is empty.
*
@@ -231,7 +364,10 @@
.filter(v -> v instanceof IRVertex);
// They should all match SourceVertex
if (!(verticesToObserve.get().allMatch(v -> (v instanceof SourceVertex)
- || (v instanceof SamplingVertex && ((SamplingVertex) v).getCloneOfOriginalVertex() instanceof SourceVertex)))) {
+ || (v instanceof SamplingVertex && ((SamplingVertex) v).getCloneOfOriginalVertex() instanceof SourceVertex)
+ || (v instanceof TaskSizeSplitterVertex && ((TaskSizeSplitterVertex) v).getOriginalVertices().stream()
+ .anyMatch(irVertex -> irVertex instanceof SourceVertex))
+ ))) {
final String problematicVertices = verticesToObserve.get()
.filter(v -> !(v instanceof SourceVertex))
.map(V::getId)
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index c95bcf0..ca4ac5f 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -38,10 +38,12 @@
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.SignalVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -200,8 +202,12 @@
converted.add(sv); // explicit conversion to IRVertex is needed.. otherwise the compiler complains :(
}
return converted;
- } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof TriggerVertex) {
+ } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof MessageGeneratorVertex) {
return messageVertexToGroup.get(vertexToDelete);
+ } else if (vertexToDelete instanceof SignalVertex) {
+ return Sets.newHashSet(vertexToDelete);
+ } else if (vertexToDelete instanceof TaskSizeSplitterVertex) {
+ return Sets.newHashSet(vertexToDelete);
} else {
throw new IllegalArgumentException(vertexToDelete.getId());
}
@@ -261,7 +267,7 @@
.forEach(srcVertex -> builder.connectVertices(
Util.cloneEdge(streamVertexToOriginalEdge.get(vertexToDelete), srcVertex, dstVertex))));
modifiedDAG = builder.buildWithoutSourceSinkCheck();
- } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof TriggerVertex) {
+ } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof MessageGeneratorVertex) {
modifiedDAG = rebuildExcluding(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
final Optional<Integer> deletedMessageIdOptional = vertexGroupToDelete.stream()
.filter(vtd -> vtd instanceof MessageAggregatorVertex)
@@ -275,6 +281,19 @@
hashSet -> hashSet.remove(deletedMessageId))));
} else if (vertexToDelete instanceof SamplingVertex) {
modifiedDAG = rebuildExcluding(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
+ } else if (vertexToDelete instanceof SignalVertex) {
+ modifiedDAG = rebuildExcluding(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
+ final Optional<Integer> deletedMessageIdOptional = vertexGroupToDelete.stream()
+ .map(vtd -> vtd.getPropertyValue(MessageIdVertexProperty.class).<IllegalArgumentException>orElseThrow(
+ () -> new IllegalArgumentException(
+ "SignalVertex " + vtd.getId() + " does not have MessageIdVertexProperty.")))
+ .findAny();
+ deletedMessageIdOptional.ifPresent(deletedMessageId ->
+ modifiedDAG.getEdges().forEach(e ->
+ e.getPropertyValue(MessageIdEdgeProperty.class).ifPresent(
+ hashSet -> hashSet.remove(deletedMessageId))));
+ } else if (vertexToDelete instanceof TaskSizeSplitterVertex) {
+ modifiedDAG = rebuildExcludingSplitter(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
} else {
throw new IllegalArgumentException(vertexToDelete.getId());
}
@@ -292,6 +311,64 @@
}
/**
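+ * Helper method for deleting a splitter vertex.
+ * @param dag dag to get information from.
+ * @param excluded set of splitter vertices to delete. Always a singleton set.
+ * @return a builder of the DAG with the splitter vertex replaced by its internal (non-signal) vertices.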
+ */
+ private DAGBuilder<IRVertex, IREdge> rebuildExcludingSplitter(final DAG<IRVertex, IREdge> dag,
+ final Set<IRVertex> excluded) {
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+ dag.getVertices().stream().filter(v -> !excluded.contains(v)).forEach(builder::addVertex);
+ dag.getEdges().stream()
+ .filter(e -> !(excluded.contains(e.getSrc()) || excluded.contains(e.getDst())))
+ .forEach(builder::connectVertices);
+
+ for (IRVertex vertex : excluded) {
+ if (!(vertex instanceof TaskSizeSplitterVertex)) {
+ break;
+ }
+ final TaskSizeSplitterVertex splitter = (TaskSizeSplitterVertex) vertex;
+ //first, restore original vertices
+ DAG<IRVertex, IREdge> internalDag = splitter.getDAG();
+ internalDag.getVertices().stream().filter(v -> !(v instanceof SignalVertex)).forEach(builder::addVertex);
+ internalDag.getEdges().stream()
+ .filter(e -> !(e.getSrc() instanceof SignalVertex || e.getDst() instanceof SignalVertex))
+ .forEach(builder::connectVertices);
+
+ //second, take care of edges connected to splitter vertex
+ for (IREdge edgeToSplitter : dag.getIncomingEdgesOf(splitter)) {
+ if (edgeToSplitter.getSrc() instanceof TaskSizeSplitterVertex) {
+ final TaskSizeSplitterVertex prevSp = (TaskSizeSplitterVertex) edgeToSplitter.getSrc();
+ final IREdge internalEdge = prevSp.getEdgeWithInternalVertex(edgeToSplitter);
+ final IREdge newEdgeToPrevSp = Util.cloneEdge(internalEdge, prevSp, internalEdge.getDst());
+ prevSp.mapEdgeWithLoop(newEdgeToPrevSp, internalEdge);
+
+ builder.connectVertices(newEdgeToPrevSp);
+ } else {
+ final IREdge internalEdge = splitter.getEdgeWithInternalVertex(edgeToSplitter);
+ builder.connectVertices(internalEdge);
+ }
+ }
+
+ for (IREdge edgeFromSplitter : dag.getOutgoingEdgesOf(splitter)) {
+ if (edgeFromSplitter.getDst() instanceof TaskSizeSplitterVertex) {
+ final TaskSizeSplitterVertex nextSp = (TaskSizeSplitterVertex) edgeFromSplitter.getDst();
+ final IREdge internalEdge = nextSp.getEdgeWithInternalVertex(edgeFromSplitter);
+ final IREdge newEdgeToNextSp = Util.cloneEdge(internalEdge, internalEdge.getSrc(), nextSp);
+ nextSp.mapEdgeWithLoop(newEdgeToNextSp, internalEdge);
+
+ builder.connectVertices(newEdgeToNextSp);
+ } else {
+ final IREdge internalEdge = splitter.getEdgeWithInternalVertex(edgeFromSplitter);
+ builder.connectVertices(internalEdge);
+ }
+ }
+ }
+ return builder;
+ }
+
+ /**
* Inserts a new vertex that streams data.
* <p>
* Before: src - edgeToStreamize - dst
@@ -316,6 +393,11 @@
throw new CompileTimeOptimizationException(edgeToStreamize.getId() + " has a MessageId, and cannot be removed");
}
+ // RelayVertex should not be inserted before SplitterVertex.
+ if (edgeToStreamize.getDst() instanceof TaskSizeSplitterVertex) {
+ return;
+ }
+
// Insert the vertex.
final IRVertex vertexToInsert = wrapSamplingVertexIfNeeded(relayVertex, edgeToStreamize.getSrc());
builder.addVertex(vertexToInsert);
@@ -387,20 +469,21 @@
* <p>
* TODO #345: Simplify insert(TriggerVertex)
*
- * @param triggerVertex to insert.
+ * @param messageGeneratorVertex to insert.
* @param messageAggregatorVertex to insert.
* @param triggerOutputEncoder to use.
* @param triggerOutputDecoder to use.
* @param edgesToGetStatisticsOf to examine.
* @param edgesToOptimize to optimize.
*/
- public void insert(final TriggerVertex triggerVertex,
+ public void insert(final MessageGeneratorVertex messageGeneratorVertex,
final MessageAggregatorVertex messageAggregatorVertex,
final EncoderProperty triggerOutputEncoder,
final DecoderProperty triggerOutputDecoder,
final Set<IREdge> edgesToGetStatisticsOf,
final Set<IREdge> edgesToOptimize) {
- assertNonExistence(triggerVertex);
+ //edge case: when the destination of mav is splitter, do not insert!
+ assertNonExistence(messageGeneratorVertex);
assertNonExistence(messageAggregatorVertex);
edgesToGetStatisticsOf.forEach(this::assertNonControlEdge);
edgesToOptimize.forEach(this::assertNonControlEdge);
@@ -427,7 +510,7 @@
final List<IRVertex> triggerList = new ArrayList<>();
for (final IREdge edge : edgesToGetStatisticsOf) {
final IRVertex triggerToAdd = wrapSamplingVertexIfNeeded(
- new TriggerVertex<>(triggerVertex.getMessageFunction()), edge.getSrc());
+ new MessageGeneratorVertex<>(messageGeneratorVertex.getMessageFunction()), edge.getSrc());
builder.addVertex(triggerToAdd);
triggerList.add(triggerToAdd);
edge.getSrc().getPropertyValue(ParallelismProperty.class)
@@ -444,7 +527,11 @@
final IREdge clone = Util.cloneEdge(
CommunicationPatternProperty.Value.ONE_TO_ONE, edgeToClone, edge.getSrc(), triggerToAdd);
- builder.connectVertices(clone);
+ if (edge.getSrc() instanceof TaskSizeSplitterVertex) {
+ builder.connectSplitterVertexWithoutReplacing(edgeToClone, clone);
+ } else {
+ builder.connectVertices(clone);
+ }
}
// Add agg (no need to wrap inside sampling vertices)
@@ -459,8 +546,14 @@
// From agg to dst
// Add a control dependency (no output) from the messageAggregatorVertex to the destination.
- builder.connectVertices(
- Util.createControlEdge(messageAggregatorVertex, edgesToGetStatisticsOf.iterator().next().getDst()));
+ IREdge aggToDst = Util.createControlEdge(
+ messageAggregatorVertex, edgesToGetStatisticsOf.iterator().next().getDst());
+ if (edgesToGetStatisticsOf.iterator().next().getDst() instanceof TaskSizeSplitterVertex) {
+ builder.connectSplitterVertexWithoutReplacing(edgesToGetStatisticsOf.iterator().next(), aggToDst);
+ } else {
+ builder.connectVertices(aggToDst);
+ }
+
////////////////////////////////// STEP 2: Annotate the MessageId on optimization target edges
@@ -485,6 +578,60 @@
}
/**
+ * Inserts new vertex which calls for runtime pass.
+ *
+ * e.g.) suppose that we want to change vertex 2's property by using a runtime pass, but the related data is not
+ * obtained directly from the incoming edge of vertex 2 (for example, the data is obtained by simulation).
+ * In this case, it is unnecessary to insert a message generator vertex and a message aggregator vertex to launch
+ * the runtime pass.
+ *
+ * Original case: (vertex1) -- shuffle edge -- (vertex 2)
+ *
+ * After inserting signal Vertex:
+ * (vertex 1) -------------------- shuffle edge ------------------- (vertex 2)
+ * -- control edge -- (signal vertex) -- control edge --
+ *
+ * Therefore, the shuffle edge to vertex 2 is executed after the signal vertex is executed.
+ * Since the signal vertex only 'signals' the launch of the runtime pass, a parallelism of 1 is sufficient for it.
+ * @param toInsert Signal vertex to insert.
+ * @param edgeToOptimize Original edge to optimize (in the above example, the shuffle edge).
+ */
+ public void insert(final SignalVertex toInsert,
+ final IREdge edgeToOptimize) {
+
+ // Create a completely new DAG with the vertex inserted.
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+ // All of the existing vertices and edges remain intact
+ modifiedDAG.topologicalDo(v -> {
+ builder.addVertex(v);
+ modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
+ });
+
+ // insert Signal Vertex in DAG.
+ builder.addVertex(toInsert);
+
+ final IREdge controlEdgeToSV = Util.createControlEdge(edgeToOptimize.getSrc(), toInsert);
+ final IREdge controlEdgeFromSV = Util.createControlEdge(toInsert, edgeToOptimize.getDst());
+
+ builder.connectVertices(controlEdgeToSV);
+ builder.connectVertices(controlEdgeFromSV);
+
+ modifiedDAG.topologicalDo(v ->
+ modifiedDAG.getIncomingEdgesOf(v).forEach(inEdge -> {
+ if (edgeToOptimize.equals(inEdge)) {
+ final HashSet<Integer> msgEdgeIds =
+ inEdge.getPropertyValue(MessageIdEdgeProperty.class).orElse(new HashSet<>(0));
+ msgEdgeIds.add(toInsert.getPropertyValue(MessageIdVertexProperty.class).get());
+ inEdge.setProperty(MessageIdEdgeProperty.of(msgEdgeIds));
+ }
+ })
+ );
+ // update the DAG.
+ modifiedDAG = builder.build();
+ }
+
+ /**
* Inserts a set of samplingVertices that process sampled data.
* <p>
* This method automatically inserts the following three types of edges.
@@ -577,6 +724,82 @@
}
/**
+ * Inserts a TaskSizeSplitterVertex into the DAG.
+ * @param toInsert TaskSizeSplitterVertex to insert.
+ */
+ public void insert(final TaskSizeSplitterVertex toInsert) {
+ final Set<IRVertex> originalVertices = toInsert.getOriginalVertices();
+
+ final Set<IREdge> incomingEdgesOfOriginalVertices = originalVertices
+ .stream()
+ .flatMap(ov -> modifiedDAG.getIncomingEdgesOf(ov).stream())
+ .collect(Collectors.toSet());
+
+ final Set<IREdge> outgoingEdgesOfOriginalVertices = originalVertices
+ .stream()
+ .flatMap(ov -> modifiedDAG.getOutgoingEdgesOf(ov).stream())
+ .collect(Collectors.toSet());
+
+ final Set<IREdge> fromOutsideToOriginal = toInsert.getEdgesFromOutsideToOriginal(modifiedDAG);
+ final Set<IREdge> fromOriginalToOutside = toInsert.getEdgesFromOriginalToOutside(modifiedDAG);
+
+ // make edges connected to splitter vertex
+ final Set<IREdge> fromOutsideToSplitter = toInsert.getEdgesFromOutsideToSplitter(modifiedDAG);
+ final Set<IREdge> fromSplitterToOutside = toInsert.getEdgesFromSplitterToOutside(modifiedDAG);
+
+ //map splitter vertex connection to corresponding internal vertex connection
+ for (IREdge splitterEdge : fromSplitterToOutside) {
+ for (IREdge internalEdge : fromOriginalToOutside) {
+ if (splitterEdge.getDst() instanceof TaskSizeSplitterVertex) {
+ TaskSizeSplitterVertex nextSplitter = (TaskSizeSplitterVertex) splitterEdge.getDst();
+ if (nextSplitter.getOriginalVertices().contains(internalEdge.getDst())) {
+ toInsert.mapEdgeWithLoop(splitterEdge, internalEdge);
+ }
+ } else {
+ if (splitterEdge.getDst().equals(internalEdge.getDst())) {
+ toInsert.mapEdgeWithLoop(splitterEdge, internalEdge);
+ }
+ }
+ }
+ }
+
+ for (IREdge splitterEdge : fromOutsideToSplitter) {
+ for (IREdge internalEdge : fromOutsideToOriginal) {
+ if (splitterEdge.getSrc().equals(internalEdge.getSrc())) {
+ toInsert.mapEdgeWithLoop(splitterEdge, internalEdge);
+ }
+ }
+ }
+
+ fromOutsideToOriginal.forEach(toInsert::addDagIncomingEdge);
+ fromOutsideToOriginal.forEach(toInsert::addNonIterativeIncomingEdge);
+ fromOriginalToOutside.forEach(toInsert::addDagOutgoingEdge);
+
+ // All preparation done. Insert splitter vertex.
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+ //insert vertex and edges irrelevant to splitter vertex
+ modifiedDAG.topologicalDo(v -> {
+ if (!originalVertices.contains(v)) {
+ builder.addVertex(v);
+ for (IREdge edge : modifiedDAG.getIncomingEdgesOf(v)) {
+ if (!incomingEdgesOfOriginalVertices.contains(edge) && !outgoingEdgesOfOriginalVertices.contains(edge)) {
+ builder.connectVertices(edge);
+ }
+ }
+ }
+ });
+ //insert splitter vertices
+ builder.addVertex(toInsert);
+
+ //connect splitter to outside world
+ fromOutsideToSplitter.forEach(builder::connectVertices);
+ fromSplitterToOutside.forEach(builder::connectVertices);
+
+ modifiedDAG = builder.build();
+ }
+
+ /**
* Reshape unsafely, without guarantees on preserving application semantics.
* TODO #330: Refactor Unsafe Reshaping Passes
*
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
index bdca93c..ed7a444 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
@@ -29,9 +29,11 @@
import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.*;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.transform.SignalTransform;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -311,7 +313,8 @@
final GlobalDAGChecker messageIds = (dag -> {
final long numMessageAggregatorVertices = dag.getVertices()
.stream()
- .filter(v -> v instanceof MessageAggregatorVertex)
+ .filter(v -> v instanceof MessageAggregatorVertex
+ || (v instanceof OperatorVertex && ((OperatorVertex) v).getTransform() instanceof SignalTransform))
.count();
// Triggering ids, must be unique
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IdManager.java b/common/src/main/java/org/apache/nemo/common/ir/IdManager.java
index 2ff5255..8d4c98d 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IdManager.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IdManager.java
@@ -40,6 +40,7 @@
private static AtomicInteger vertexId = new AtomicInteger(1);
private static AtomicInteger edgeId = new AtomicInteger(1);
private static AtomicLong resourceSpecIdGenerator = new AtomicLong(0);
+ private static AtomicInteger messageId = new AtomicInteger(1);
private static volatile boolean isDriver = false;
// Vertex ID Map to be used upon cloning in loop vertices.
@@ -99,6 +100,9 @@
return "ResourceSpec" + resourceSpecIdGenerator.getAndIncrement();
}
+ public static Integer generateMessageId() {
+ return messageId.getAndIncrement();
+ }
/**
* Set the realm of the loaded class as REEF driver.
*/
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
index b9af41b..5780c6f 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
@@ -28,6 +28,8 @@
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.HashMap;
@@ -40,8 +42,9 @@
/**
* IRVertex that contains a partial DAG that is iterative.
*/
-public final class LoopVertex extends IRVertex {
-
+//TODO 454: Change dependency between LoopVertex and TaskSizeSplitterVertex.
+public class LoopVertex extends IRVertex {
+ private static final Logger LOG = LoggerFactory.getLogger(LoopVertex.class.getName());
private final AtomicInteger duplicateEdgeGroupId = new AtomicInteger(0);
// Contains DAG information
private final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
@@ -95,7 +98,7 @@
}
@Override
- public LoopVertex getClone() {
+ public final LoopVertex getClone() {
return new LoopVertex(this);
}
@@ -127,6 +130,16 @@
* @param edgeWithInternalVertex the corresponding edge from/to internal vertex
*/
public void mapEdgeWithLoop(final IREdge edgeWithLoop, final IREdge edgeWithInternalVertex) {
+ if (this.edgeWithLoopToEdgeWithInternalVertex.containsKey(edgeWithLoop)
+ && !this.edgeWithInternalVertexToEdgeWithLoop.containsKey(edgeWithInternalVertex)) {
+ // A B to A B'
+ this.edgeWithInternalVertexToEdgeWithLoop.remove(this.edgeWithLoopToEdgeWithInternalVertex.get(edgeWithLoop));
+ } else if (this.edgeWithInternalVertexToEdgeWithLoop.containsKey(edgeWithInternalVertex)
+ && !this.edgeWithLoopToEdgeWithInternalVertex.containsKey(edgeWithLoop)) {
+ // A B to A' B
+ this.edgeWithLoopToEdgeWithInternalVertex.remove(
+ this.edgeWithInternalVertexToEdgeWithLoop.get(edgeWithInternalVertex));
+ }
this.edgeWithLoopToEdgeWithInternalVertex.put(edgeWithLoop, edgeWithInternalVertex);
this.edgeWithInternalVertexToEdgeWithLoop.put(edgeWithInternalVertex, edgeWithLoop);
}
@@ -140,6 +153,29 @@
}
/**
+ * @param edgeWithLoop an edge with loop
+ * @return the corresponding edge with internal vertex for the specified edge with loop
+ */
+ public IREdge getEdgeWithInternalVertex(final IREdge edgeWithLoop) {
+   final IREdge mapped = this.edgeWithLoopToEdgeWithInternalVertex.get(edgeWithLoop); // direct hit: no map copy
+   return mapped != null ? mapped : new HashMap<>(this.edgeWithLoopToEdgeWithInternalVertex).get(edgeWithLoop);
+ }
+
+ /**
+ * Getter method for edgeWithLoopToEdgeWithInternalVertex.
+ */
+ public Map<IREdge, IREdge> getEdgeWithLoopToEdgeWithInternalVertex() {
+ return this.edgeWithLoopToEdgeWithInternalVertex;
+ }
+
+ /**
+ * Getter method for edgeWithInternalVertexToEdgeWithLoop.
+ */
+ public Map<IREdge, IREdge> getEdgeWithInternalVertexToEdgeWithLoop() {
+ return this.edgeWithInternalVertexToEdgeWithLoop;
+ }
+
+ /**
* Adds the incoming edge of the contained DAG.
*
* @param edge edge to add.
@@ -157,6 +193,17 @@
}
/**
+ * Removes the incoming edge of the contained DAG.
+ *
+ * @param edge edge to remove
+ */
+  public void removeDagIncomingEdge(final IREdge edge) {
+    final IRVertex destination = edge.getDst();
+    if (!this.dagIncomingEdges.containsKey(destination)) {
+      // No incoming edges are registered for this destination, so there is nothing to remove.
+      return;
+    }
+    this.dagIncomingEdges.get(destination).remove(edge);
+  }
+
+ /**
* Adds an iterative incoming edge, from the previous iteration, but connection internally.
*
* @param edge edge to add.
@@ -174,6 +221,17 @@
}
/**
+ * Remove an iterative incoming edge.
+ *
+ * @param edge edge to remove
+ */
+  public void removeIterativeIncomingEdge(final IREdge edge) {
+    final IRVertex destination = edge.getDst();
+    if (!this.iterativeIncomingEdges.containsKey(destination)) {
+      // No iterative incoming edges are registered for this destination, so there is nothing to remove.
+      return;
+    }
+    this.iterativeIncomingEdges.get(destination).remove(edge);
+  }
+
+ /**
* Adds a non-iterative incoming edge, from outside the previous iteration.
*
* @param edge edge to add.
@@ -191,6 +249,16 @@
}
/**
+ * Removes non iterative incoming edge.
+ * @param edge edge to remove.
+ */
+  public void removeNonIterativeIncomingEdge(final IREdge edge) {
+    final IRVertex destination = edge.getDst();
+    if (!this.nonIterativeIncomingEdges.containsKey(destination)) {
+      // No non-iterative incoming edges are registered for this destination, so there is nothing to remove.
+      return;
+    }
+    this.nonIterativeIncomingEdges.get(destination).remove(edge);
+  }
+
+ /**
* Adds and outgoing edge of the contained DAG.
*
* @param edge edge to add.
@@ -208,6 +276,17 @@
}
/**
+ * Removes a dag outgoing edge.
+ *
+ * @param edge edge to remove.
+ */
+  public void removeDagOutgoingEdge(final IREdge edge) {
+    final IRVertex source = edge.getSrc();
+    if (!this.dagOutgoingEdges.containsKey(source)) {
+      // No outgoing edges are registered for this source, so there is nothing to remove.
+      return;
+    }
+    this.dagOutgoingEdges.get(source).remove(edge);
+  }
+
+ /**
* Marks duplicate edges with DuplicateEdgeGroupProperty.
*/
public void markDuplicateEdges() {
@@ -330,7 +409,7 @@
/**
* decrease the value of maximum number of iterations by 1.
*/
- private void decreaseMaxNumberOfIterations() {
+ protected void decreaseMaxNumberOfIterations() {
this.maxNumberOfIterations--;
}
@@ -359,6 +438,9 @@
}
@Override
+ /**
+ * Parse Properties to JsonNode.
+ */
public ObjectNode getPropertiesAsJsonNode() {
final ObjectNode node = getIRVertexPropertiesAsJsonNode();
node.put("remainingIteration", maxNumberOfIterations);
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
index ada8af3..f706a8e 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
@@ -20,12 +20,12 @@
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.OutputCollector;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * A {@link Transform} that aggregates statistics generated by the {@link TriggerTransform}.
+ * A {@link Transform} that aggregates statistics generated by the {@link MessageGeneratorTransform}.
*
* @param <K> input key type.
* @param <V> input value type.
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageGeneratorTransform.java
similarity index 77%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageGeneratorTransform.java
index 5c982ef..ec198ef 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageGeneratorTransform.java
@@ -20,7 +20,7 @@
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.OutputCollector;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -28,26 +28,26 @@
import java.util.Map;
/**
- * A {@link Transform} for the trigger vertex.
+ * A {@link Transform} for the message generator vertex.
*
* @param <I> input type.
* @param <K> output key type.
* @param <V> output value type.
*/
-public final class TriggerTransform<I, K, V> extends NoWatermarkEmitTransform<I, Pair<K, V>> {
- private static final Logger LOG = LoggerFactory.getLogger(TriggerTransform.class.getName());
+public final class MessageGeneratorTransform<I, K, V> extends NoWatermarkEmitTransform<I, Pair<K, V>> {
+ private static final Logger LOG = LoggerFactory.getLogger(MessageGeneratorTransform.class.getName());
private transient OutputCollector<Pair<K, V>> outputCollector;
private transient Map<K, V> holder;
- private final TriggerVertex.MessageGeneratorFunction<I, K, V> userFunction;
+ private final MessageGeneratorVertex.MessageGeneratorFunction<I, K, V> userFunction;
/**
* TriggerTransform constructor.
*
* @param userFunction that analyzes the data.
*/
- public TriggerTransform(final TriggerVertex.MessageGeneratorFunction<I, K, V> userFunction) {
+ public MessageGeneratorTransform(final MessageGeneratorVertex.MessageGeneratorFunction<I, K, V> userFunction) {
this.userFunction = userFunction;
}
@@ -73,7 +73,7 @@
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
- sb.append(TriggerTransform.class);
+ sb.append(MessageGeneratorTransform.class);
sb.append(":");
sb.append(super.toString());
return sb.toString();
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/SignalTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/SignalTransform.java
new file mode 100644
index 0000000..ae711f0
--- /dev/null
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/SignalTransform.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.common.ir.vertex.transform;
+
+import org.apache.nemo.common.ir.OutputCollector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A transform used to trigger a run-time pass when no run-time information needs to be transferred.
+ * It keeps the most recent element handed to {@link #onData(Void)} and emits that value once, on {@link #close()}.
+ */
+public final class SignalTransform extends NoWatermarkEmitTransform<Void, Void> {
+  private static final Logger LOG = LoggerFactory.getLogger(SignalTransform.class.getName());
+  private transient Void elementHolder;
+  private transient OutputCollector<Void> outputCollector;
+
+  @Override
+  public void prepare(final Context context, final OutputCollector<Void> oc) {
+    // Keep the collector so that close() can emit through it.
+    outputCollector = oc;
+  }
+
+  @Override
+  public void onData(final Void element) {
+    // Nothing is emitted per element; only the latest element is remembered.
+    this.elementHolder = element;
+  }
+
+  @Override
+  public void close() {
+    this.outputCollector.emit(this.elementHolder);
+  }
+
+  @Override
+  public String toString() {
+    return SignalTransform.class + ":" + super.toString();
+  }
+}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
index 6201575..77671f9 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
@@ -22,6 +22,7 @@
import org.apache.nemo.common.Util;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
/**
* Executes the original IRVertex using a subset of input data partitions.
@@ -38,7 +39,7 @@
*/
public SamplingVertex(final IRVertex originalVertex, final float desiredSampleRate) {
super();
- if (!(originalVertex instanceof TriggerVertex) && (Util.isUtilityVertex(originalVertex))) {
+ if (!(originalVertex instanceof MessageGeneratorVertex) && (Util.isUtilityVertex(originalVertex))) {
throw new IllegalArgumentException(
"Cannot sample non-Trigger utility vertices: " + originalVertex.toString());
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TaskSizeSplitterVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TaskSizeSplitterVertex.java
new file mode 100644
index 0000000..fad3170
--- /dev/null
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TaskSizeSplitterVertex.java
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.common.ir.vertex.utility;
+
+import org.apache.commons.lang.mutable.MutableInt;
+import org.apache.nemo.common.HashRange;
+import org.apache.nemo.common.KeyRange;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.dag.DAG;
+import org.apache.nemo.common.dag.DAGBuilder;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.SubPartitionSetProperty;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.LoopVertex;
+import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import org.apache.nemo.common.ir.vertex.transform.SignalTransform;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.SignalVertex;
+import org.apache.nemo.common.test.EmptyComponents;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+
+/**
+ * This vertex works as a partition-based sampling vertex of dynamic task sizing pass.
+ * It covers both sampling vertices and optimized vertices known from sampling by iterating same vertices, giving
+ * different properties in each iteration.
+ */
+//TODO 454: Change dependency between LoopVertex and TaskSizeSplitterVertex.
+public final class TaskSizeSplitterVertex extends LoopVertex {
+ // Information about original(before splitting) vertices
+ private static final Logger LOG = LoggerFactory.getLogger(TaskSizeSplitterVertex.class.getName());
+ private final Set<IRVertex> originalVertices;
+  // Vertices which have incoming edges from other groups. The stage partitioner guarantees exactly one per group.
+ private final Set<IRVertex> groupStartingVertices;
+  // Vertices which have outgoing edges to other groups. There can be more than one in a group.
+ private final Set<IRVertex> verticesWithGroupOutgoingEdges;
+  // Vertices which do not have any outgoing edge to vertices in the same group.
+ private final Set<IRVertex> groupEndingVertices;
+
+ // Information about partition sizes
+ private final int partitionerProperty;
+
+ // Information about splitter vertex's iteration
+ private final MutableInt testingTrial;
+
+ private final Map<IRVertex, IRVertex> mapOfOriginalVertexToClone = new HashMap<>();
+
+  /**
+   * Default constructor of TaskSizeSplitterVertex.
+   * Stores the given sets as-is (no defensive copies), records a clone of every original vertex, and registers
+   * the original vertices and the edges between them with this vertex's DAG builder.
+   * @param splitterVertexName              for now, this doesn't do anything. This is inserted to enable extension
+   *                                        from LoopVertex.
+   * @param originalVertices                Set of vertices which form one stage and which splitter will wrap up.
+   * @param groupStartingVertices           The first vertex in stage. Although it is given as a Set, we assert that
+   *                                        this set has only one element (guaranteed by stage partitioner logic)
+   * @param verticesWithGroupOutgoingEdges  Vertices which have outgoing edges to other stages.
+   * @param groupEndingVertices             Vertices which have only outgoing edges to other stages.
+   * @param edgesBetweenOriginalVertices    Edges which connect original vertices.
+   * @param partitionerProperty             PartitionerProperty of incoming stage edge regarding to job data size.
+   *                                        It is used as the end of the hash range
+   *                                        {@code HashRange.of(512, partitionerProperty)} that
+   *                                        setSubPartitionSetPropertyByTestingTrial assigns after the first trial.
+   */
+  public TaskSizeSplitterVertex(final String splitterVertexName,
+                                final Set<IRVertex> originalVertices,
+                                final Set<IRVertex> groupStartingVertices,
+                                final Set<IRVertex> verticesWithGroupOutgoingEdges,
+                                final Set<IRVertex> groupEndingVertices,
+                                final Set<IREdge> edgesBetweenOriginalVertices,
+                                final int partitionerProperty) {
+    super(splitterVertexName); // need to take care of here
+    testingTrial = new MutableInt(0);
+    this.originalVertices = originalVertices;
+    this.partitionerProperty = partitionerProperty;
+    for (IRVertex original : originalVertices) {
+      mapOfOriginalVertexToClone.putIfAbsent(original, original.getClone());
+    }
+    this.groupStartingVertices = groupStartingVertices;
+    this.verticesWithGroupOutgoingEdges = verticesWithGroupOutgoingEdges;
+    this.groupEndingVertices = groupEndingVertices;
+
+    insertWorkingVertices(originalVertices, edgesBetweenOriginalVertices);
+    // The SignalVertex is not inserted here: unRollIteration inserts it when testingTrial is still 0.
+  }
+
+  // Getters of attributes. Each one returns the set received by the constructor, not a copy.
+
+  /**
+   * @return the original vertices wrapped by this splitter.
+   */
+  public Set<IRVertex> getOriginalVertices() {
+    return originalVertices;
+  }
+
+  /**
+   * @return the vertices which have incoming edges from other groups.
+   */
+  public Set<IRVertex> getGroupStartingVertices() {
+    return groupStartingVertices;
+  }
+
+  /**
+   * @return the vertices which have outgoing edges to other groups.
+   */
+  public Set<IRVertex> getVerticesWithGroupOutgoingEdges() {
+    return verticesWithGroupOutgoingEdges;
+  }
+
+  /**
+   * @return the vertices which do not have any outgoing edge to vertices in the same group.
+   */
+  public Set<IRVertex> getGroupEndingVertices() {
+    return groupEndingVertices;
+  }
+
+  /**
+   * Registers the given vertices, and then the edges connecting them, with this vertex's DAG builder.
+   * @param stageVertices   vertices to insert; may be the same set as the original vertices.
+   * @param edgesInBetween  edges connecting stageVertices. Both ends of every edge need to be elements of
+   *                        stageVertices; edges touching any other vertex must not be passed in.
+   */
+  private void insertWorkingVertices(final Set<IRVertex> stageVertices, final Set<IREdge> edgesInBetween) {
+    for (final IRVertex stageVertex : stageVertices) {
+      getBuilder().addVertex(stageVertex);
+    }
+    for (final IREdge innerEdge : edgesInBetween) {
+      getBuilder().connectVertices(innerEdge);
+    }
+  }
+
+  /**
+   * Inserts signal Vertex at the end of the iteration. Last iteration does not contain any signal vertex.
+   * (stage finishing vertices) - dummyShuffleEdge - SignalVertex
+   * SignalVertex - ControlEdge - (stage starting vertices)
+   * The signal vertex and the dummy shuffle edges go into the DAG builder; the control edges are registered
+   * as iterative incoming edges instead.
+   * @param toInsert SignalVertex to insert.
+   */
+  private void insertSignalVertex(final SignalVertex toInsert) {
+    getBuilder().addVertex(toInsert);
+    for (IRVertex lastVertex : groupEndingVertices) {
+      IREdge edgeToSignal = EmptyComponents.newDummyShuffleEdge(lastVertex, toInsert);
+      getBuilder().connectVertices(edgeToSignal);
+      // NOTE(review): this loop is nested inside the ending-vertex loop, so one control edge is created per
+      // (ending vertex, starting vertex) pair. Confirm this is intended when a group has several ending vertices.
+      for (IRVertex firstVertex : groupStartingVertices) {
+        IREdge controlEdgeToBeginning = Util.createControlEdge(toInsert, firstVertex);
+        addIterativeIncomingEdge(controlEdgeToBeginning);
+      }
+    }
+  }
+
+  /**
+   * Advances the testing-trial counter by one.
+   */
+  public void increaseTestingTrial() {
+    testingTrial.increment();
+  }
+
+  /**
+   * Unrolls one iteration of this splitter into the given DAG builder and prepares this vertex for the next one.
+   * Need to be careful about Signal Vertex, because they do not appear in the last iteration.
+   * Steps, in order: insert the signal vertex on the first trial, decrease the remaining iteration count,
+   * clone the working vertices and their internal edges, connect the DAG incoming and outgoing edges,
+   * add the signal vertex clone unless the loop has terminated, mark the edges to optimize, rebuild the
+   * DAG incoming edges for the next iteration, and finally increase the testing trial.
+   * @param dagBuilder DAGBuilder to add the unrolled iteration to.
+   * @return Modified this object
+   */
+  public TaskSizeSplitterVertex unRollIteration(final DAGBuilder<IRVertex, IREdge> dagBuilder) {
+    final HashMap<IRVertex, IRVertex> originalToNewIRVertex = new HashMap<>();
+    final HashSet<IRVertex> originalUtilityVertices = new HashSet<>();
+    final HashSet<IREdge> edgesToOptimize = new HashSet<>();
+
+    // The signal vertex is inserted lazily, only before the first trial is unrolled.
+    if (testingTrial.intValue() == 0) {
+      insertSignalVertex(new SignalVertex());
+    }
+
+    final List<OperatorVertex> previousSignalVertex = new ArrayList<>(1);
+    final DAG<IRVertex, IREdge> dagToAdd = getDAG();
+
+    // Decreased before the loopTerminationConditionMet() checks below, so they see the updated count.
+    decreaseMaxNumberOfIterations();
+
+    // add the working vertex and its incoming edges to the dagBuilder.
+    // SignalVertex instances are only collected here; their clones are added further below.
+    dagToAdd.topologicalDo(irVertex -> {
+      if (!(irVertex instanceof SignalVertex)) {
+        final IRVertex newIrVertex = irVertex.getClone();
+        setParallelismPropertyByTestingTrial(newIrVertex);
+        originalToNewIRVertex.putIfAbsent(irVertex, newIrVertex);
+        dagBuilder.addVertex(newIrVertex, dagToAdd);
+        dagToAdd.getIncomingEdgesOf(irVertex).forEach(edge -> {
+          final IRVertex newSrc = originalToNewIRVertex.get(edge.getSrc());
+          final IREdge newIrEdge =
+            new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), newSrc, newIrVertex);
+          edge.copyExecutionPropertiesTo(newIrEdge);
+          setSubPartitionSetPropertyByTestingTrial(newIrEdge);
+          edgesToOptimize.add(newIrEdge);
+          dagBuilder.connectVertices(newIrEdge);
+        });
+      } else {
+        originalUtilityVertices.add(irVertex);
+      }
+    });
+
+    // process the initial DAG incoming edges for the first loop.
+    // An edge whose source carries a SignalTransform is not marked for optimization; its source is remembered
+    // as the previous signal vertex instead.
+    getDagIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
+      final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
+        edge.getSrc(), originalToNewIRVertex.get(dstVertex));
+      edge.copyExecutionPropertiesTo(newIrEdge);
+      setSubPartitionSetPropertyByTestingTrial(newIrEdge);
+      if (edge.getSrc() instanceof OperatorVertex
+        && ((OperatorVertex) edge.getSrc()).getTransform() instanceof SignalTransform) {
+        previousSignalVertex.add((OperatorVertex) edge.getSrc());
+      } else {
+        edgesToOptimize.add(newIrEdge);
+      }
+      dagBuilder.connectVertices(newIrEdge);
+    }));
+
+    // Process the DAG outgoing edges. The matching edge with loop is found by comparing edge IDs against the
+    // keys of the internal-edge-to-loop-edge map, not by a direct map lookup.
+    getDagOutgoingEdges().forEach((srcVertex, irEdges) -> irEdges.forEach(edgeFromOriginal -> {
+      for (Map.Entry<IREdge, IREdge> entry : this.getEdgeWithInternalVertexToEdgeWithLoop().entrySet()) {
+        if (entry.getKey().getId().equals(edgeFromOriginal.getId())) {
+          final IREdge correspondingEdge = entry.getValue(); // edge to next splitter vertex
+          if (correspondingEdge.getDst() instanceof TaskSizeSplitterVertex) {
+            // The destination is another splitter: rewire that splitter's incoming edges so that they start
+            // from the clone created in this iteration.
+            TaskSizeSplitterVertex nextSplitter = (TaskSizeSplitterVertex) correspondingEdge.getDst();
+            IRVertex dstVertex = edgeFromOriginal.getDst(); // vertex inside of next splitter vertex
+            List<IREdge> edgesToDelete = new ArrayList<>();
+            List<IREdge> edgesToAdd = new ArrayList<>();
+            for (IREdge edgeToDst : nextSplitter.getDagIncomingEdges().get(dstVertex)) {
+              if (edgeToDst.getSrc().getId().equals(srcVertex.getId())) {
+                final IREdge newIrEdge = new IREdge(
+                  edgeFromOriginal.getPropertyValue(CommunicationPatternProperty.class).get(),
+                  originalToNewIRVertex.get(srcVertex),
+                  edgeFromOriginal.getDst());
+                edgeToDst.copyExecutionPropertiesTo(newIrEdge);
+                edgesToDelete.add(edgeToDst);
+                edgesToAdd.add(newIrEdge);
+                final IREdge newLoopEdge = Util.cloneEdge(
+                  correspondingEdge, newIrEdge.getSrc(), correspondingEdge.getDst());
+                nextSplitter.mapEdgeWithLoop(newLoopEdge, newIrEdge);
+              }
+            }
+            // The replaced edges are removed from the next splitter only once this loop has terminated;
+            // the new edges are added on every iteration.
+            if (loopTerminationConditionMet()) {
+              for (IREdge edgeToDelete : edgesToDelete) {
+                nextSplitter.removeDagIncomingEdge(edgeToDelete);
+                nextSplitter.removeNonIterativeIncomingEdge(edgeToDelete);
+              }
+            }
+            for (IREdge edgeToAdd : edgesToAdd) {
+              nextSplitter.addDagIncomingEdge(edgeToAdd);
+              nextSplitter.addNonIterativeIncomingEdge(edgeToAdd);
+            }
+          } else {
+            // The destination is not a splitter: connect the clone directly to it in the dagBuilder.
+            final IREdge newIrEdge = new IREdge(
+              edgeFromOriginal.getPropertyValue(CommunicationPatternProperty.class).get(),
+              originalToNewIRVertex.get(srcVertex), edgeFromOriginal.getDst());
+            edgeFromOriginal.copyExecutionPropertiesTo(newIrEdge);
+            dagBuilder.addVertex(edgeFromOriginal.getDst()).connectVertices(newIrEdge);
+          }
+        }
+      }
+    }));
+
+    // if loop termination condition is false, add signal vertex
+    if (!loopTerminationConditionMet()) {
+      for (IRVertex helper : originalUtilityVertices) {
+        final IRVertex newHelper = helper.getClone();
+        originalToNewIRVertex.putIfAbsent(helper, newHelper);
+        setParallelismPropertyByTestingTrial(newHelper);
+        dagBuilder.addVertex(newHelper, dagToAdd);
+        dagToAdd.getIncomingEdgesOf(helper).forEach(edge -> {
+          final IRVertex newSrc = originalToNewIRVertex.get(edge.getSrc());
+          final IREdge newIrEdge =
+            new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), newSrc, newHelper);
+          edge.copyExecutionPropertiesTo(newIrEdge);
+          dagBuilder.connectVertices(newIrEdge);
+        });
+      }
+    }
+
+    // assign signal vertex of n-th iteration with nonIterativeIncomingEdges of (n+1)th iteration
+    markEdgesToOptimize(previousSignalVertex, edgesToOptimize);
+
+    // process next iteration's DAG incoming edges, and add them as the next loop's incoming edges:
+    // clear, as we're done with the current loop and need to prepare it for the next one.
+    this.getDagIncomingEdges().clear();
+    this.getNonIterativeIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(this::addDagIncomingEdge));
+    if (!loopTerminationConditionMet()) {
+      // Iterative edges are re-created so that they start from the clones made in this iteration.
+      this.getIterativeIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
+        final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
+          originalToNewIRVertex.get(edge.getSrc()), dstVertex);
+        edge.copyExecutionPropertiesTo(newIrEdge);
+        this.addDagIncomingEdge(newIrEdge);
+      }));
+    }
+
+    increaseTestingTrial();
+    return this;
+  }
+
+ // private helper methods
+
+  /**
+   * Chooses the parallelism of a vertex depending on the current testing trial.
+   * During the first trial every vertex that does not carry a SignalTransform is permanently given a
+   * parallelism of 32; in every other case the parallelism is set to 1.
+   * @param irVertex vertex to set parallelism property.
+   */
+  private void setParallelismPropertyByTestingTrial(final IRVertex irVertex) {
+    final boolean isFirstTrial = testingTrial.intValue() == 0;
+    final boolean carriesSignalTransform = irVertex instanceof OperatorVertex
+      && ((OperatorVertex) irVertex).getTransform() instanceof SignalTransform;
+    if (isFirstTrial && !carriesSignalTransform) {
+      irVertex.setPropertyPermanently(ParallelismProperty.of(32));
+      return;
+    }
+    irVertex.setProperty(ParallelismProperty.of(1));
+  }
+
+  /**
+   * Assigns the SubPartitionSetProperty of an edge depending on the current testing trial.
+   * First trial: 32 hash ranges covering 0 to 512, made of four ranges of width 1 followed by groups of four
+   * equally wide ranges whose starting index doubles from 4 up to 256.
+   * Later trials: one single hash range from 512 to partitionerProperty.
+   * @param edge edge to set subPartitionSetProperty
+   */
+  private void setSubPartitionSetPropertyByTestingTrial(final IREdge edge) {
+    final ArrayList<KeyRange> ranges = new ArrayList<>();
+    if (testingTrial.intValue() != 0) {
+      ranges.add(HashRange.of(512, partitionerProperty));
+    } else {
+      // The first four ranges each cover exactly one index.
+      for (int index = 0; index < 4; index++) {
+        ranges.add(HashRange.of(index, index + 1));
+      }
+      // Every following group spans [groupStart, 2 * groupStart) and is cut into four equal ranges.
+      for (int groupStart = 4; groupStart < 512; groupStart *= 2) {
+        final int width = groupStart / 4;
+        final int groupEnd = groupStart * 2;
+        for (int rangeStart = groupStart; rangeStart < groupEnd; rangeStart += width) {
+          ranges.add(HashRange.of(rangeStart, rangeStart + width));
+        }
+      }
+    }
+    edge.setProperty(SubPartitionSetProperty.of(ranges));
+  }
+
+  /**
+   * Marks edges for DTS (i.e. incoming edges of second iteration vertices) by adding the message id of the
+   * first given signal vertex to each edge's MessageIdEdgeProperty. Does nothing during the first trial.
+   *
+   * @param toAssign        Signal Vertex to get MessageIdVertexProperty
+   * @param edgesToOptimize Edges to mark for DTS
+   * @throws IllegalArgumentException if the destination of an edge does not have a parallelism of 1.
+   */
+  private void markEdgesToOptimize(final List<OperatorVertex> toAssign, final Set<IREdge> edgesToOptimize) {
+    if (testingTrial.intValue() <= 0) {
+      return;
+    }
+    for (final IREdge target : edgesToOptimize) {
+      final boolean parallelismIsOne =
+        target.getDst().getPropertyValue(ParallelismProperty.class).get().equals(1);
+      if (!parallelismIsOne) {
+        throw new IllegalArgumentException("Target edges should begin with Parallelism of 1.");
+      }
+      // Reuse the message id set already attached to the edge, or start with an empty one.
+      final HashSet<Integer> msgEdgeIds =
+        target.getPropertyValue(MessageIdEdgeProperty.class).orElse(new HashSet<>(0));
+      msgEdgeIds.add(toAssign.get(0).getPropertyValue(MessageIdVertexProperty.class).get());
+      target.setProperty(MessageIdEdgeProperty.of(msgEdgeIds));
+    }
+  }
+
+ // These similar four methods are for inserting TaskSizeSplitterVertex in DAG
+
+  /**
+   * Collects the edges which enter the group starting vertices from outside, as observed in the given dag.
+   * These become the 'dagIncomingEdges' of the splitter vertex.
+   * Edge case: when the source of an incoming edge is itself a splitter vertex, the original edges are not
+   * visible in the dag. They are then taken from that splitter's dag outgoing edges, keeping only those which
+   * point at the starting vertex.
+   *
+   * @param dag dag to insert Splitter Vertex.
+   * @return a set of edges from outside to original vertices.
+   */
+  public Set<IREdge> getEdgesFromOutsideToOriginal(final DAG<IRVertex, IREdge> dag) {
+    final Set<IREdge> collected = new HashSet<>();
+    for (final IRVertex startingVertex : this.groupStartingVertices) {
+      for (final IREdge incoming : dag.getIncomingEdgesOf(startingVertex)) {
+        if (!(incoming.getSrc() instanceof TaskSizeSplitterVertex)) {
+          collected.add(incoming);
+          continue;
+        }
+        // The source is a splitter: look inside it for the edges that really reach the starting vertex.
+        final TaskSizeSplitterVertex prevSplitter = (TaskSizeSplitterVertex) incoming.getSrc();
+        for (final IRVertex innerSource : prevSplitter.getVerticesWithGroupOutgoingEdges()) {
+          for (final IREdge candidate : prevSplitter.getDagOutgoingEdges().get(innerSource)) {
+            if (candidate.getDst().equals(startingVertex)) {
+              collected.add(candidate);
+            }
+          }
+        }
+      }
+    }
+    return collected;
+  }
+
+  /**
+   * Collects the edges which leave the original vertices towards outside destinations, as observed in the
+   * given dag. These become the 'dagOutgoingEdges' of the splitter vertex.
+   * Edge case: when the destination of an outgoing edge is itself a splitter vertex, the original edges are
+   * not visible in the dag. They are then taken from that splitter's dag incoming edges, keeping only those
+   * which start from the vertex being inspected.
+   *
+   * @param dag dag to insert Splitter Vertex.
+   * @return a set of edges from original vertices to outside.
+   */
+  public Set<IREdge> getEdgesFromOriginalToOutside(final DAG<IRVertex, IREdge> dag) {
+    final Set<IREdge> collected = new HashSet<>();
+    for (final IRVertex vertex : verticesWithGroupOutgoingEdges) {
+      for (final IREdge outgoing : dag.getOutgoingEdgesOf(vertex)) {
+        if (outgoing.getDst() instanceof TaskSizeSplitterVertex) {
+          // The destination is a splitter: look inside it for the edges that really start from this vertex.
+          final TaskSizeSplitterVertex nextSplitter = (TaskSizeSplitterVertex) outgoing.getDst();
+          for (final IRVertex innerDestination : nextSplitter.getGroupStartingVertices()) {
+            for (final IREdge candidate : nextSplitter.getDagIncomingEdges().get(innerDestination)) {
+              if (candidate.getSrc().equals(vertex)) {
+                collected.add(candidate);
+              }
+            }
+          }
+          continue;
+        }
+        // Edges which stay inside the group of original vertices are not outgoing edges.
+        if (!originalVertices.contains(outgoing.getDst())) {
+          collected.add(outgoing);
+        }
+      }
+    }
+    return collected;
+  }
+
+  /**
+   * Builds the edges which point at this splitter from outside sources. Every incoming edge of a group starting
+   * vertex is cloned with this splitter as its destination, so the result has a one-to-one relationship with
+   * edgesFromOutsideToOriginal.
+   * Edge case: when the source is also a splitter vertex, the LoopEdge - InternalEdge mapping of that previous
+   * splitter is updated, since inserting this splitter changes the destination of its LoopEdge from the
+   * original vertex to this splitter vertex.
+   *
+   * @param dag dag to insert Splitter Vertex
+   * @return a set of edges pointing at Splitter Vertex
+   */
+  public Set<IREdge> getEdgesFromOutsideToSplitter(final DAG<IRVertex, IREdge> dag) {
+    final HashSet<IREdge> edgesIntoSplitter = new HashSet<>();
+    for (final IRVertex startingVertex : groupStartingVertices) {
+      for (final IREdge incoming : dag.getIncomingEdgesOf(startingVertex)) {
+        final IRVertex source = incoming.getSrc();
+        if (!(source instanceof TaskSizeSplitterVertex)) {
+          edgesIntoSplitter.add(Util.cloneEdge(incoming, source, this));
+          continue;
+        }
+        // Look up the internal edge before cloning, then register the clone as its new edge with loop.
+        final TaskSizeSplitterVertex prevSplitter = (TaskSizeSplitterVertex) source;
+        final IREdge internalEdge = prevSplitter.getEdgeWithInternalVertex(incoming);
+        final IREdge redirected = Util.cloneEdge(incoming, source, this);
+        prevSplitter.mapEdgeWithLoop(redirected, internalEdge);
+        edgesIntoSplitter.add(redirected);
+      }
+    }
+    return edgesIntoSplitter;
+  }
+
+  /**
+   * Builds the edges which leave this splitter towards outside destinations. Every outgoing edge that leaves the
+   * group is cloned with this splitter as its source, so the result has a one-to-one relationship with
+   * edgesFromOriginalToOutside.
+   * Edge case: when the destination is also a splitter vertex, the LoopEdge - InternalEdge mapping of that next
+   * splitter is updated, since inserting this splitter changes the source of nextSplitter's LoopEdge from the
+   * original vertex to this splitter vertex.
+   *
+   * @param dag dag to insert Splitter Vertex.
+   * @return a set of edges coming out from Splitter Vertex.
+   */
+  public Set<IREdge> getEdgesFromSplitterToOutside(final DAG<IRVertex, IREdge> dag) {
+    final HashSet<IREdge> edgesOutOfSplitter = new HashSet<>();
+    for (final IRVertex vertex : verticesWithGroupOutgoingEdges) {
+      for (final IREdge outgoing : dag.getOutgoingEdgesOf(vertex)) {
+        final IRVertex destination = outgoing.getDst();
+        if (destination instanceof TaskSizeSplitterVertex) {
+          // Look up the internal edge before cloning, then register the clone as its new edge with loop.
+          final TaskSizeSplitterVertex nextSplitter = (TaskSizeSplitterVertex) destination;
+          final IREdge internalEdge = nextSplitter.getEdgeWithInternalVertex(outgoing);
+          final IREdge redirected = Util.cloneEdge(outgoing, this, destination);
+          nextSplitter.mapEdgeWithLoop(redirected, internalEdge);
+          edgesOutOfSplitter.add(redirected);
+          continue;
+        }
+        if (originalVertices.contains(destination)) {
+          // The edge stays inside the group of original vertices, so it does not leave the splitter.
+          continue;
+        }
+        edgesOutOfSplitter.add(Util.cloneEdge(outgoing, this, destination));
+      }
+    }
+    return edgesOutOfSplitter;
+  }
+
+  /**
+   * Dumps the edge bookkeeping of this splitter vertex for debugging: its id, the dag incoming, iterative
+   * incoming, non-iterative incoming and dag outgoing edges, and both directions of the loop edge mapping.
+   * This is diagnostic output and not an error report, so it is logged at DEBUG level; enable DEBUG logging for
+   * this class to see it.
+   */
+  public void printLogs() {
+    LOG.debug("[Vertex] this is splitter {}", this.getId());
+    LOG.debug("[Vertex] get dag incoming edges: {}", this.getDagIncomingEdges().entrySet());
+    LOG.debug("[Vertex] get dag iterative incoming edges: {}", this.getIterativeIncomingEdges().entrySet());
+    LOG.debug("[Vertex] get dag nonIterative incoming edges: {}", this.getNonIterativeIncomingEdges().entrySet());
+    LOG.debug("[Vertex] get dag outgoing edges: {}", this.getDagOutgoingEdges().entrySet());
+    LOG.debug("[Vertex] get edge map with loop {}", this.getEdgeWithLoopToEdgeWithInternalVertex().entrySet());
+    LOG.debug("[Vertex] get edge map with internal vertex {}",
+      this.getEdgeWithInternalVertexToEdgeWithLoop().entrySet());
+  }
+}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageAggregatorVertex.java
similarity index 89%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageAggregatorVertex.java
index cc58f27..31351e8 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageAggregatorVertex.java
@@ -16,9 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.common.ir.vertex.utility;
+package org.apache.nemo.common.ir.vertex.utility.runtimepass;
import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.ir.IdManager;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
@@ -27,7 +28,6 @@
import org.slf4j.LoggerFactory;
import java.io.Serializable;
-import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Supplier;
@@ -40,7 +40,6 @@
*/
public final class MessageAggregatorVertex<K, V, O> extends OperatorVertex {
private static final Logger LOG = LoggerFactory.getLogger(MessageAggregatorVertex.class.getName());
- private static final AtomicInteger MESSAGE_ID_GENERATOR = new AtomicInteger(0);
/**
* @param initialStateSupplier for producing the initial state.
@@ -49,7 +48,7 @@
public MessageAggregatorVertex(final InitialStateSupplier<O> initialStateSupplier,
final MessageAggregatorFunction<K, V, O> userFunction) {
super(new MessageAggregatorTransform<>(initialStateSupplier, userFunction));
- this.setPropertyPermanently(MessageIdVertexProperty.of(MESSAGE_ID_GENERATOR.incrementAndGet()));
+ this.setPropertyPermanently(MessageIdVertexProperty.of(IdManager.generateMessageId()));
this.setProperty(ParallelismProperty.of(1));
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageGeneratorVertex.java
similarity index 79%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageGeneratorVertex.java
index bb2337d..e7aceef 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/MessageGeneratorVertex.java
@@ -16,30 +16,30 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.common.ir.vertex.utility;
+package org.apache.nemo.common.ir.vertex.utility.runtimepass;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
-import org.apache.nemo.common.ir.vertex.transform.TriggerTransform;
+import org.apache.nemo.common.ir.vertex.transform.MessageGeneratorTransform;
import java.io.Serializable;
import java.util.Map;
import java.util.function.BiFunction;
/**
- * Produces a message and triggers a run-time pass.
+ * Produces a message for run-time pass.
*
* @param <I> input type
* @param <K> of the output pair.
* @param <V> of the output pair.
*/
-public final class TriggerVertex<I, K, V> extends OperatorVertex {
+public final class MessageGeneratorVertex<I, K, V> extends OperatorVertex {
private final MessageGeneratorFunction<I, K, V> messageFunction;
/**
* @param messageFunction for producing a message.
*/
- public TriggerVertex(final MessageGeneratorFunction<I, K, V> messageFunction) {
- super(new TriggerTransform<>(messageFunction));
+ public MessageGeneratorVertex(final MessageGeneratorFunction<I, K, V> messageFunction) {
+ super(new MessageGeneratorTransform<>(messageFunction));
this.messageFunction = messageFunction;
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/SignalVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/SignalVertex.java
new file mode 100644
index 0000000..f300511
--- /dev/null
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/runtimepass/SignalVertex.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.common.ir.vertex.utility.runtimepass;
+
+import org.apache.nemo.common.ir.IdManager;
+import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import org.apache.nemo.common.ir.vertex.transform.SignalTransform;
+
+
+/**
+ * Signal vertex holding signal transform.
+ * It triggers runtime pass without examining related edge's data.
+ */
+public final class SignalVertex extends OperatorVertex {
+
+ /**
+ * Creates a signal vertex wrapping a new {@link SignalTransform}.
+ * A message id generated by {@link IdManager} and a parallelism of 1 are both set permanently on this vertex.
+ */
+ public SignalVertex() {
+ super(new SignalTransform());
+ this.setPropertyPermanently(MessageIdVertexProperty.of(IdManager.generateMessageId()));
+ this.setPropertyPermanently(ParallelismProperty.of(1));
+ }
+}
diff --git a/common/src/main/java/org/apache/nemo/common/test/EmptyComponents.java b/common/src/main/java/org/apache/nemo/common/test/EmptyComponents.java
index 1f93a81..650a2fa 100644
--- a/common/src/main/java/org/apache/nemo/common/test/EmptyComponents.java
+++ b/common/src/main/java/org/apache/nemo/common/test/EmptyComponents.java
@@ -58,8 +58,6 @@
edge.setProperty(KeyDecoderProperty.of(new DecoderFactory.DummyDecoderFactory()));
edge.setProperty(EncoderProperty.of(new EncoderFactory.DummyEncoderFactory()));
edge.setProperty(DecoderProperty.of(new DecoderFactory.DummyDecoderFactory()));
- edge.setProperty(KeyEncoderProperty.of(new EncoderFactory.DummyEncoderFactory()));
- edge.setProperty(KeyDecoderProperty.of(new DecoderFactory.DummyDecoderFactory()));
return edge;
}
diff --git a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
index f1d6a9e..6113c85 100644
--- a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
+++ b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
@@ -24,16 +24,19 @@
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.dag.DAGBuilder;
+import org.apache.nemo.common.dag.Edge;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.*;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.SignalVertex;
import org.apache.nemo.common.test.EmptyComponents;
import org.junit.Before;
import org.junit.Test;
@@ -275,13 +278,27 @@
}
@Test
+ public void testSignalVertex() {
+ // Insert a signal vertex on a one-to-one edge, then on a shuffle edge; the DAG check must pass each time.
+ final SignalVertex sg1 = insertNewSignalVertex(irdag, oneToOneEdge);
+ mustPass();
+
+ final SignalVertex sg2 = insertNewSignalVertex(irdag, shuffleEdge);
+ mustPass();
+
+ // Delete the signal vertices in insertion order; the DAG check must pass after each deletion as well.
+ irdag.delete(sg1);
+ mustPass();
+
+ irdag.delete(sg2);
+ mustPass();
+ }
+
+ @Test
public void testSamplingVertex() {
final SamplingVertex svOne = new SamplingVertex(sourceVertex, 0.1f);
irdag.insert(Sets.newHashSet(svOne), Sets.newHashSet(sourceVertex));
mustPass();
final SamplingVertex svTwo = new SamplingVertex(firstOperatorVertex, 0.1f);
- ;
irdag.insert(Sets.newHashSet(svTwo), Sets.newHashSet(firstOperatorVertex));
mustPass();
@@ -292,8 +309,26 @@
mustPass();
}
+ /**
+ * Inserts a splitter vertex whose vertex group consists of secondOperatorVertex only, then deletes it again.
+ * The DAG check must pass after both the insertion and the deletion.
+ */
+ @Test
+ public void testSplitterVertex() {
+ // secondOperatorVertex is at once the whole group, its starting vertex and its ending vertex;
+ // the sets of vertices with group-outgoing edges and of group-internal edges are left empty.
+ final TaskSizeSplitterVertex sp = new TaskSizeSplitterVertex(
+ "splitter_1",
+ Sets.newHashSet(secondOperatorVertex),
+ Sets.newHashSet(secondOperatorVertex),
+ Sets.newHashSet(),
+ Sets.newHashSet(secondOperatorVertex),
+ Sets.newHashSet(),
+ 1024
+ );
+ irdag.insert(sp);
+ mustPass();
+
+ irdag.delete(sp);
+ mustPass();
+ }
+
private MessageAggregatorVertex insertNewTriggerVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) {
- final TriggerVertex mb = new TriggerVertex<>((l, r) -> null);
+ final MessageGeneratorVertex mb = new MessageGeneratorVertex<>((l, r) -> null);
final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(() -> new Object(), (l, r) -> null);
dag.insert(
mb,
@@ -305,6 +340,53 @@
return ma;
}
+ /**
+  * Wraps the vertex group found behind the given edge in a new splitter vertex and inserts it into the DAG.
+  * Every DAG lookup goes through the {@code dag} parameter, so the helper operates on the DAG it was handed.
+  *
+  * @param dag                  DAG to inspect and to insert the splitter vertex into.
+  * @param edgeToSplitterVertex candidate incoming edge of the splitter vertex.
+  * @return the inserted splitter vertex, or empty if no vertex group could be formed for the edge.
+  */
+ private Optional<TaskSizeSplitterVertex> insertNewSplitterVertex(final IRDAG dag,
+                                                                  final IREdge edgeToSplitterVertex) {
+   final Set<IRVertex> vertexGroup = getVertexGroupToInsertSplitter(dag, edgeToSplitterVertex);
+   if (vertexGroup.isEmpty()) {
+     return Optional.empty();
+   }
+
+   // Group members that have at least one outgoing edge leaving the group.
+   final Set<IRVertex> verticesWithGroupOutgoingEdges = new HashSet<>();
+   for (final IRVertex vertex : vertexGroup) {
+     final Set<IRVertex> nextVertices = dag.getOutgoingEdgesOf(vertex).stream().map(Edge::getDst)
+       .collect(Collectors.toSet());
+     for (final IRVertex nextVertex : nextVertices) {
+       if (!vertexGroup.contains(nextVertex)) {
+         verticesWithGroupOutgoingEdges.add(vertex);
+       }
+     }
+   }
+
+   // Group members with no outgoing edge into the group (vertices without any outgoing edge qualify too).
+   final Set<IRVertex> groupEndingVertices = vertexGroup.stream()
+     .filter(groupVertex -> dag.getOutgoingEdgesOf(groupVertex).stream().map(Edge::getDst)
+       .noneMatch(vertexGroup::contains))
+     .collect(Collectors.toSet());
+
+   // Edges whose source and destination are both inside the group.
+   final Set<IREdge> edgesBetweenOriginalVertices = vertexGroup
+     .stream()
+     .flatMap(ov -> dag.getIncomingEdgesOf(ov).stream())
+     .filter(edge -> vertexGroup.contains(edge.getSrc()))
+     .collect(Collectors.toSet());
+
+   final TaskSizeSplitterVertex sp = new TaskSizeSplitterVertex(
+     "sp" + edgeToSplitterVertex.getId(),
+     vertexGroup,
+     Sets.newHashSet(edgeToSplitterVertex.getDst()),
+     verticesWithGroupOutgoingEdges,
+     groupEndingVertices,
+     edgesBetweenOriginalVertices,
+     1024);
+
+   dag.insert(sp);
+
+   return Optional.of(sp);
+ }
+
+ /**
+ * Creates a new signal vertex and inserts it into the DAG for the given edge.
+ * @param dag dag to insert the signal vertex into.
+ * @param edgeToOptimize edge passed to the insertion together with the signal vertex.
+ * @return the inserted signal vertex.
+ */
+ private SignalVertex insertNewSignalVertex(final IRDAG dag, final IREdge edgeToOptimize) {
+ final SignalVertex sg = new SignalVertex();
+ dag.insert(sg, edgeToOptimize);
+ return sg;
+ }
+
////////////////////////////////////////////////////// Random generative tests
private Random random = new Random(0); // deterministic seed for reproducibility
@@ -314,8 +396,8 @@
// Thousand random configurations (some duplicate configurations possible)
final int thousandConfigs = 1000;
for (int i = 0; i < thousandConfigs; i++) {
- // LOG.info("Doing {}", i);
- final int numOfTotalMethods = 11;
+ //LOG.info("Doing {}", i);
+ final int numOfTotalMethods = 13;
final int methodIndex = random.nextInt(numOfTotalMethods);
switch (methodIndex) {
// Annotation methods
@@ -360,6 +442,12 @@
irdag.insert(Sets.newHashSet(samplingVertex), Sets.newHashSet(vertexToSample));
break;
case 10:
+ insertNewSignalVertex(irdag, selectRandomEdge());
+ break;
+ case 11:
+ insertNewSplitterVertex(irdag, selectRandomEdge());
+ break;
+ case 12:
// the last index must be (numOfTotalMethods - 1)
selectRandomUtilityVertex().ifPresent(irdag::delete);
break;
@@ -374,7 +462,7 @@
if (methodIndex >= 7) {
// Uncomment to visualize DAG snapshots after reshaping (insert, delete)
- // irdag.storeJSON("test_reshaping_snapshots", i + "(methodIndex_" + methodIndex + ")", "test");
+ //irdag.storeJSON("test_reshaping_snapshots", i + "(methodIndex_" + methodIndex + ")", "test");
}
// Must always pass
@@ -411,6 +499,105 @@
: Optional.of(utilityVertices.get(random.nextInt(utilityVertices.size())));
}
+ /**
+ * Private helper method to check if the parameter observingEdge is appropriate for inserting Splitter Vertex.
+ * Specifically, this edge is considered to be the incoming edge of splitter vertex.
+ * This edge should have communication property of shuffle, and should be the only edge coming out from/coming in to
+ * its source/dest.
+ * @param dag dag to observe.
+ * @param observingEdge observing edge.
+ * @return true if this edge is appropriate for inserting splitter vertex.
+ */
+ private boolean isThisEdgeAppropriateForInsertingSplitterVertex(IRDAG dag, IREdge observingEdge) {
+   // The edge must be a shuffle edge ...
+   final boolean isShuffleEdge = CommunicationPatternProperty.Value.SHUFFLE.equals(
+     observingEdge.getPropertyValue(CommunicationPatternProperty.class).get());
+
+   // ... and it must be the only incoming edge of its destination and the only outgoing edge of its source.
+   return isShuffleEdge
+     && dag.getIncomingEdgesOf(observingEdge.getDst()).size() <= 1
+     && dag.getOutgoingEdgesOf(observingEdge.getSrc()).size() <= 1;
+ }
+
+ /**
+ * Collects the group of vertices that a splitter vertex would wrap when observingEdge is its incoming edge.
+ * The group starts at the destination of observingEdge and is grown along outgoing edges by
+ * recursivelyAddVertexGroup.
+ * @param dag dag to observe.
+ * @param observingEdge candidate incoming edge of the splitter vertex.
+ * @return the vertex group, or an empty set if the edge, its destination, or the resulting group is not suitable.
+ */
+ private Set<IRVertex> getVertexGroupToInsertSplitter(IRDAG dag, IREdge observingEdge) {
+ final Set<IRVertex> vertexGroup = new HashSet<>();
+
+ // If this edge is not appropriate to be the incoming edge of splitter vertex, return empty set.
+ if (!isThisEdgeAppropriateForInsertingSplitterVertex(dag, observingEdge)) {
+ return new HashSet<>();
+ }
+
+ // A group never starts at a message generator or message aggregator vertex.
+ if (observingEdge.getDst() instanceof MessageGeneratorVertex
+ || observingEdge.getDst() instanceof MessageAggregatorVertex) {
+ return new HashSet<>();
+ }
+ // Get the vertex group.
+ vertexGroup.add(observingEdge.getDst());
+ for (IREdge edge : dag.getOutgoingEdgesOf(observingEdge.getDst())) {
+ vertexGroup.addAll(recursivelyAddVertexGroup(dag, edge, vertexGroup));
+ }
+
+ // Check if this vertex group is appropriate for inserting splitter vertex
+ // by looking at the edges that leave the group.
+ Set<IREdge> stageOutgoingEdges = vertexGroup
+ .stream()
+ .flatMap(vertex -> dag.getOutgoingEdgesOf(vertex).stream())
+ .filter(edge -> !vertexGroup.contains(edge.getDst()))
+ .collect(Collectors.toSet());
+ if (stageOutgoingEdges.isEmpty()) {
+ return vertexGroup;
+ } else {
+ // A single one-to-one edge leaving the group disqualifies the whole group.
+ for (IREdge edge : stageOutgoingEdges) {
+ if (CommunicationPatternProperty.Value.ONE_TO_ONE.equals(
+ edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
+ return new HashSet<>();
+ }
+ }
+ }
+ return vertexGroup;
+ }
+
+ /**
+ * Check if the destination of the observing edge can be added to the vertex group.
+ * @param dag dag to observe.
+ * @param observingEdge edge to observe.
+ * @param vertexGroup vertex group to add.
+ * @return updated vertex group.
+ */
+ private Set<IRVertex> recursivelyAddVertexGroup(IRDAG dag, IREdge observingEdge, Set<IRVertex> vertexGroup) {
+ // Stop here: the destination has more than one incoming edge.
+ if (dag.getIncomingEdgesOf(observingEdge.getDst()).size() > 1) {
+ return vertexGroup;
+ }
+ // Stop here: the edge is not one-to-one (throws IllegalStateException if the property is missing).
+ if (observingEdge.getPropertyValue(CommunicationPatternProperty.class).orElseThrow(IllegalStateException::new)
+ != CommunicationPatternProperty.Value.ONE_TO_ONE) {
+ return vertexGroup;
+ }
+ // Stop here: source and destination do not have equal execution properties.
+ if (!observingEdge.getSrc().getExecutionProperties().equals(observingEdge.getDst().getExecutionProperties())) {
+ return vertexGroup;
+ }
+ // Stop here: message generator and message aggregator vertices are never added to the group.
+ if (observingEdge.getDst() instanceof MessageGeneratorVertex
+ || observingEdge.getDst() instanceof MessageAggregatorVertex) {
+ return vertexGroup;
+ }
+ // All checks passed: add the destination and continue along its outgoing edges.
+ // The passed-in set is mutated in place and is also the returned value.
+ vertexGroup.add(observingEdge.getDst());
+ for (IREdge edge : dag.getOutgoingEdgesOf(observingEdge.getDst())) {
+ vertexGroup.addAll(recursivelyAddVertexGroup(dag, edge, vertexGroup));
+ }
+ return vertexGroup;
+ }
+
///////////////// Random vertex EP
private ClonedSchedulingProperty randomCSP() {
diff --git a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
index 550699a..1298868 100644
--- a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
+++ b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
@@ -23,7 +23,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
import org.apache.nemo.common.ir.executionproperty.ExecutionPropertyMap;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.compiler.optimizer.NemoOptimizer;
import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
import org.apache.nemo.runtime.common.comm.ControlMessage;
diff --git a/compiler/optimizer/pom.xml b/compiler/optimizer/pom.xml
index 9b9c4c1..c6dbb79 100644
--- a/compiler/optimizer/pom.xml
+++ b/compiler/optimizer/pom.xml
@@ -20,6 +20,18 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <source>${java.version}</source>
+ <target>${java.version}</target>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
<parent>
<groupId>org.apache.nemo</groupId>
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/SkewAnnotatingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/SkewAnnotatingPass.java
index b87ca83..18ca85b 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/SkewAnnotatingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/SkewAnnotatingPass.java
@@ -22,7 +22,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.PartitionerProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
index 4c0cf29..686f904 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
@@ -26,8 +26,8 @@
import org.apache.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
@@ -112,7 +112,7 @@
final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
dag.insert(
- new TriggerVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
+ new MessageGeneratorVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
new MessageAggregatorVertex(HashMap::new, SkewHandlingUtil.getMessageAggregator()),
SkewHandlingUtil.getEncoder(e),
SkewHandlingUtil.getDecoder(e),
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingTaskSizingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingTaskSizingPass.java
new file mode 100644
index 0000000..d85b7c6
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingTaskSizingPass.java
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+
+import org.apache.nemo.common.dag.Edge;
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.EnableDynamicTaskSizingProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.Annotates;
+import org.apache.nemo.runtime.common.plan.StagePartitioner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * Compiler pass for dynamic task size optimization. Happens only when the edge property is SHUFFLE.
+ * If (size of given job) >= 1GB: enable dynamic task sizing optimization.
+ * else: break.
+ *
+ *
+ * @Attributes
+ * PARTITIONER_PROPERTY_FOR_SMALL_JOB: PartitionerProperty for jobs in range of [1GB, 10GB) size.
+ * PARTITIONER_PROPERTY_FOR_MEDIUM_JOB: PartitionerProperty for jobs in range of [10GB, 100GB) size.
+ * PARTITIONER_PROPERTY_FOR_LARGE_JOB: PartitionerProperty for jobs in range of [100GB, - ) size(No upper limit).
+ *
+ * source stage - shuffle edge - current stage - next stage
+ * -> source stage - [curr stage - signal vertex] - next stage
+ * where [] is a splitter vertex
+ */
+@Annotates({EnableDynamicTaskSizingProperty.class, PartitionerProperty.class, SubPartitionSetProperty.class,
+ ParallelismProperty.class})
+public final class SamplingTaskSizingPass extends ReshapingPass {
+ private static final Logger LOG = LoggerFactory.getLogger(SamplingTaskSizingPass.class.getName());
+
+ private static final int PARTITIONER_PROPERTY_FOR_SMALL_JOB = 1024;
+ private static final int PARTITIONER_PROPERTY_FOR_MEDIUM_JOB = 2048;
+ private static final int PARTITIONER_PROPERTY_FOR_LARGE_JOB = 4096;
+ private final StagePartitioner stagePartitioner = new StagePartitioner();
+
+ /**
+ * Default constructor.
+ */
+ public SamplingTaskSizingPass() {
+ super(SamplingTaskSizingPass.class);
+ }
+
+ /**
+  * Reshapes the DAG for dynamic task sizing (DTS) when the job input is at least 1GB.
+  * When DTS applies, every vertex is marked with {@link EnableDynamicTaskSizingProperty}, the eligible shuffle
+  * edges get a hash partitioner sized by the job input, and each stage fed by such an edge is wrapped in a
+  * splitter vertex.
+  *
+  * @param dag the DAG to reshape.
+  * @return the given DAG; it is returned untouched when the job is too small for DTS.
+  */
+ @Override
+ public IRDAG apply(final IRDAG dag) {
+   /* Step 1. check DTS launch by job size */
+   if (!isDTSEnabledByJobSize(dag)) {
+     return dag;
+   }
+   dag.topologicalDo(v -> v.setProperty(EnableDynamicTaskSizingProperty.of(true)));
+
+   final int partitionerProperty = getPartitionerPropertyByJobSize(dag);
+
+   /* Step 2-1. Group vertices by stage using stage merging logic */
+   final Map<IRVertex, Integer> vertexToStageId = stagePartitioner.apply(dag);
+   final Map<Integer, Set<IRVertex>> stageIdToStageVertices = new HashMap<>();
+   vertexToStageId.forEach((vertex, stageId) ->
+     stageIdToStageVertices.computeIfAbsent(stageId, id -> new HashSet<>()).add(vertex));
+
+   /* Step 2-2. Get target edges of DTS */
+   final Set<IREdge> shuffleEdgesForDTS = new HashSet<>();
+   dag.topologicalDo(v -> {
+     for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
+       if (isAppropriateForInsertingSplitterVertex(dag, v, edge, vertexToStageId, stageIdToStageVertices)) {
+         shuffleEdgesForDTS.add(edge);
+       }
+     }
+   });
+
+   /* Step 2-3. Change partitioner property for DTS target edges */
+   dag.topologicalDo(v -> {
+     for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
+       if (shuffleEdgesForDTS.contains(edge)) {
+         // The edge is taken out of the set while it is mutated and re-added afterwards.
+         // NOTE(review): presumably because changing a property can change the edge's hash - confirm.
+         shuffleEdgesForDTS.remove(edge);
+         edge.setProperty(PartitionerProperty.of(PartitionerProperty.Type.HASH, partitionerProperty));
+         shuffleEdgesForDTS.add(edge);
+       }
+     }
+   });
+
+   /* Step 3. Insert Splitter Vertex */
+   final List<IRVertex> reverseTopologicalOrder = dag.getTopologicalSort();
+   Collections.reverse(reverseTopologicalOrder);
+   for (final IRVertex v : reverseTopologicalOrder) {
+     for (final IREdge edge : dag.getOutgoingEdgesOf(v)) {
+       if (shuffleEdgesForDTS.contains(edge)) {
+         // edge is the incoming edge of observing stage, v is the last vertex of previous stage
+         final Set<IRVertex> stageVertices = stageIdToStageVertices.get(vertexToStageId.get(edge.getDst()));
+
+         // Stage vertices that have at least one outgoing edge leaving the stage.
+         final Set<IRVertex> verticesWithStageOutgoingEdges = new HashSet<>();
+         for (final IRVertex stageVertex : stageVertices) {
+           final Set<IRVertex> nextVertices = dag.getOutgoingEdgesOf(stageVertex).stream().map(Edge::getDst)
+             .collect(Collectors.toSet());
+           for (final IRVertex nextVertex : nextVertices) {
+             if (!stageVertices.contains(nextVertex)) {
+               verticesWithStageOutgoingEdges.add(stageVertex);
+             }
+           }
+         }
+
+         // Stage vertices with no outgoing edge to another vertex of the same stage.
+         final Set<IRVertex> stageEndingVertices = stageVertices.stream()
+           .filter(stageVertex -> dag.getOutgoingEdgesOf(stageVertex).isEmpty()
+             || !dag.getOutgoingEdgesOf(stageVertex).stream().map(Edge::getDst).anyMatch(stageVertices::contains))
+           .collect(Collectors.toSet());
+
+         // A stage whose incoming edges all originate inside the stage is not wrapped;
+         // the remaining outgoing edges of v are skipped as well.
+         final boolean isSourcePartition = stageVertices.stream()
+           .flatMap(vertexInPartition -> dag.getIncomingEdgesOf(vertexInPartition).stream())
+           .map(Edge::getSrc)
+           .allMatch(stageVertices::contains);
+         if (isSourcePartition) {
+           break;
+         }
+         insertSplitterVertex(dag, stageVertices, Collections.singleton(edge.getDst()),
+           verticesWithStageOutgoingEdges, stageEndingVertices, partitionerProperty);
+       }
+     }
+   }
+   return dag;
+ }
+
+ private boolean isDTSEnabledByJobSize(final IRDAG dag) {
+   // DTS is launched only for jobs whose input is at least 1GB.
+   final long oneGigabyteInBytes = 1024 * 1024 * 1024;
+   final long inputSizeInBytes = dag.getInputSize();
+   return inputSizeInBytes >= oneGigabyteInBytes;
+ }
+
+ /**
+ * should be called after EnableDynamicTaskSizingProperty is declared as true.
+ * @param dag IRDAG to get job input data size from
+ * @return partitioner property regarding job size
+ */
+ private int getPartitionerPropertyByJobSize(final IRDAG dag) {
+   final long inputSizeInBytes = dag.getInputSize();
+   final long inputSizeInGB = inputSizeInBytes / (1024 * 1024 * 1024);
+
+   // [1GB, 10GB) is a small job, [10GB, 100GB) a medium job; every other size falls through to large.
+   if (inputSizeInGB >= 1 && inputSizeInGB < 10) {
+     return PARTITIONER_PROPERTY_FOR_SMALL_JOB;
+   }
+   if (inputSizeInGB >= 10 && inputSizeInGB < 100) {
+     return PARTITIONER_PROPERTY_FOR_MEDIUM_JOB;
+   }
+   return PARTITIONER_PROPERTY_FOR_LARGE_JOB;
+ }
+
+ /**
+ * Check if stage containing observing Vertex is appropriate for inserting splitter vertex.
+ * @param dag dag to observe
+ * @param observingVertex observing vertex
+ * @param observingEdge incoming edge of observing vertex
+ * @param vertexToStageId maps vertex to its corresponding stage id
+ * @param stageIdToStageVertices maps stage id to its vertices
+ * @return true if we can wrap this stage with splitter vertex (i.e. appropriate for DTS)
+ */
+ private boolean isAppropriateForInsertingSplitterVertex(final IRDAG dag,
+                                                         final IRVertex observingVertex,
+                                                         final IREdge observingEdge,
+                                                         final Map<IRVertex, Integer> vertexToStageId,
+                                                         final Map<Integer, Set<IRVertex>> stageIdToStageVertices) {
+   // The observing edge must be a shuffle edge, must be the only incoming edge of the observing vertex,
+   // and must be the only outgoing edge of its own source vertex.
+   final boolean isShuffleEdge = CommunicationPatternProperty.Value.SHUFFLE.equals(
+     observingEdge.getPropertyValue(CommunicationPatternProperty.class).get());
+   if (!isShuffleEdge
+     || dag.getIncomingEdgesOf(observingVertex).size() > 1
+     || dag.getOutgoingEdgesOf(observingEdge.getSrc()).size() > 1) {
+     return false;
+   }
+
+   // Gather the edges that leave the stage containing the observing vertex.
+   // insert to do: accumulate DTS result by changing o2o stage edge into shuffle
+   final Set<IRVertex> verticesOfStage = stageIdToStageVertices.get(vertexToStageId.get(observingVertex));
+   final Set<IREdge> edgesLeavingStage = verticesOfStage
+     .stream()
+     .flatMap(vertex -> dag.getOutgoingEdgesOf(vertex).stream())
+     .filter(edge -> !verticesOfStage.contains(edge.getDst()))
+     .collect(Collectors.toSet());
+
+   // The stage qualifies when none of these edges is one-to-one.
+   // A sink stage has no such edge at all, so it qualifies as well.
+   return edgesLeavingStage.stream()
+     .noneMatch(edge -> CommunicationPatternProperty.Value.ONE_TO_ONE.equals(
+       edge.getPropertyValue(CommunicationPatternProperty.class).get()));
+ }
+
+ /**
+ * Make splitter vertex and insert it in the dag.
+ * @param dag dag to insert splitter vertex
+ * @param stageVertices stage vertices which will be grouped to be inserted into splitter vertex
+ * @param stageStartingVertices subset of stage vertices which have incoming edge from other stages
+ * @param verticesWithStageOutgoingEdges subset of stage vertices which have outgoing edge to other stages
+ * @param stageEndingVertices subset of stage vertices which do not have outgoing edge to other
+ * vertices in this stage
+ * @param partitionerProperty partitioner property
+ */
+ private void insertSplitterVertex(final IRDAG dag,
+ final Set<IRVertex> stageVertices,
+ final Set<IRVertex> stageStartingVertices,
+ final Set<IRVertex> verticesWithStageOutgoingEdges,
+ final Set<IRVertex> stageEndingVertices,
+ final int partitionerProperty) {
+
+ // Edges whose source and destination both belong to the stage.
+ final Set<IREdge> edgesBetweenOriginalVertices = stageVertices
+ .stream()
+ .flatMap(ov -> dag.getIncomingEdgesOf(ov).stream())
+ .filter(edge -> stageVertices.contains(edge.getSrc()))
+ .collect(Collectors.toSet());
+
+ // The splitter is named after the id of the first starting vertex returned by the set's iterator.
+ final TaskSizeSplitterVertex toInsert = new TaskSizeSplitterVertex(
+ "Splitter" + stageStartingVertices.iterator().next().getId(),
+ stageVertices,
+ stageStartingVertices,
+ verticesWithStageOutgoingEdges,
+ stageEndingVertices,
+ edgesBetweenOriginalVertices,
+ partitionerProperty);
+
+ // By default, set the number of iterations as 2
+ toInsert.setMaxNumberOfIterations(2);
+
+ // insert splitter vertex
+ dag.insert(toInsert);
+
+ // Log the splitter's id and edge maps after insertion.
+ toInsert.printLogs();
+ }
+
+ /**
+ * Changes stage outgoing edges' execution property from one-to-one to shuffle when stage incoming edge became the
+ * target of DTS.
+ * Need to be careful about referenceShuffleEdge because this code does not check whether it is a valid shuffle edge
+ * or not.
+ * @param edge edge to change execution property.
+ * @param referenceShuffleEdge reference shuffle edge to copy key related execution properties
+ * @param partitionerProperty partitioner property of shuffle
+ */
+ //TODO #452: Allow changing Communication Property of Edge from one-to-one to shuffle.
+ private IREdge changeOneToOneEdgeToShuffleEdge(final IREdge edge,
+ final IREdge referenceShuffleEdge,
+ final int partitionerProperty) {
+ // Guard: act only when edge is one-to-one and the reference edge is shuffle;
+ // otherwise hand the edge back unchanged.
+ if (!CommunicationPatternProperty.Value.ONE_TO_ONE.equals(
+ edge.getPropertyValue(CommunicationPatternProperty.class).get())
+ || !CommunicationPatternProperty.Value.SHUFFLE.equals(
+ referenceShuffleEdge.getPropertyValue(CommunicationPatternProperty.class).get())) {
+ return edge;
+ }
+
+ // properties related to data
+ edge.setProperty(CommunicationPatternProperty.of(CommunicationPatternProperty.Value.SHUFFLE));
+ edge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.PULL));
+ edge.setProperty(PartitionerProperty.of(PartitionerProperty.Type.HASH, partitionerProperty));
+ edge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LOCAL_FILE_STORE));
+
+ // properties related to key: copied from the reference shuffle edge only when the edge lacks them
+ if (!edge.getPropertyValue(KeyExtractorProperty.class).isPresent()) {
+ edge.setProperty(KeyExtractorProperty.of(
+ referenceShuffleEdge.getPropertyValue(KeyExtractorProperty.class).get()));
+ }
+ if (!edge.getPropertyValue(KeyEncoderProperty.class).isPresent()) {
+ edge.setProperty(KeyEncoderProperty.of(
+ referenceShuffleEdge.getPropertyValue(KeyEncoderProperty.class).get()));
+ }
+ if (!edge.getPropertyValue(KeyDecoderProperty.class).isPresent()) {
+ edge.setProperty(KeyDecoderProperty.of(
+ referenceShuffleEdge.getPropertyValue(KeyDecoderProperty.class).get()));
+ }
+ // The same edge instance is returned after being modified in place.
+ return edge;
+ }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
index fa4d728..1ee52ff 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
@@ -28,8 +28,8 @@
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import java.util.Map;
@@ -40,7 +40,7 @@
private SkewHandlingUtil() {
}
- static TriggerVertex.MessageGeneratorFunction<Object, Object, Long> getMessageGenerator(
+ static MessageGeneratorVertex.MessageGeneratorFunction<Object, Object, Long> getMessageGenerator(
final KeyExtractor keyExtractor) {
return (element, dynOptData) -> {
Object key = keyExtractor.extractKey(element);
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index eb0ec39..c1ae7cf 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -24,8 +24,8 @@
import org.apache.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
-import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -38,7 +38,7 @@
/**
* Pass to reshape the IR DAG for skew handling.
- * We insert a {@link TriggerVertex} for each shuffle edge,
+ * We insert a {@link MessageGeneratorVertex} for each shuffle edge,
* and aggregate messages for multiple same-destination shuffle edges.
*/
@Requires(CommunicationPatternProperty.class)
@@ -73,7 +73,8 @@
final KeyExtractor keyExtractor = representativeEdge.getPropertyValue(KeyExtractorProperty.class).get();
// Insert the vertices
- final TriggerVertex trigger = new TriggerVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor));
+ final MessageGeneratorVertex trigger = new MessageGeneratorVertex<>(
+ SkewHandlingUtil.getMessageGenerator(keyExtractor));
final MessageAggregatorVertex mav =
new MessageAggregatorVertex(HashMap::new, SkewHandlingUtil.getMessageAggregator());
dag.insert(trigger, mav, SkewHandlingUtil.getEncoder(representativeEdge),
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
index 5d9e205..9cb30a2 100644
--- a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
+++ b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
@@ -26,7 +26,7 @@
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceAntiAffinityProperty;
import org.apache.nemo.common.ir.vertex.transform.MessageAggregatorTransform;
-import org.apache.nemo.common.ir.vertex.transform.TriggerTransform;
+import org.apache.nemo.common.ir.vertex.transform.MessageGeneratorTransform;
import org.apache.nemo.compiler.CompilerTestUtil;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.AnnotatingPass;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass;
@@ -74,7 +74,7 @@
/**
* Test for {@link SkewCompositePass} with MR workload.
- * It should have inserted vertex with {@link TriggerTransform}
+ * It should have inserted vertex with {@link MessageGeneratorTransform}
* and vertex with {@link MessageAggregatorTransform} for each shuffle edge.
*
* @throws Exception exception on the way.
@@ -95,7 +95,7 @@
assertEquals(originalVerticesNum + numOfShuffleEdges * 2, processedDAG.getVertices().size());
processedDAG.filterVertices(v -> v instanceof OperatorVertex
- && ((OperatorVertex) v).getTransform() instanceof TriggerTransform)
+ && ((OperatorVertex) v).getTransform() instanceof MessageGeneratorTransform)
.forEach(metricV -> {
final List<IRVertex> reducerV = processedDAG.getChildren(metricV.getId());
reducerV.forEach(rV -> {
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
index 5e0fb1a..b11f1fd 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/AlternatingLeastSquareITCase.java
@@ -92,4 +92,6 @@
// .addOptimizationPolicy(TransientResourcePolicyParallelismTen.class.getCanonicalName())
// .build());
// }
+
+ // TODO #453: Add test methods related to Dynamic Task Sizing in Nemo.
}
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/MultinomialLogisticRegressionITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/MultinomialLogisticRegressionITCase.java
index e4d9425..bdfeef8 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/MultinomialLogisticRegressionITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/MultinomialLogisticRegressionITCase.java
@@ -57,4 +57,6 @@
.addResourceJson(executorResourceFileName)
.build());
}
+
+ // TODO #453: Add test methods related to Dynamic Task Sizing in Nemo.
}
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java
index d4278a8..ff5d0fc 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java
@@ -127,4 +127,6 @@
.addOptimizationPolicy(AggressiveSpeculativeCloningPolicyParallelismFive.class.getCanonicalName())
.build());
}
+
+  // TODO #453: Add test methods related to Dynamic Task Sizing in Nemo.
}