[NEMO-390] Address SonarCloud issues for the IR package (#217)
JIRA: [NEMO-390: Address SonarCloud issues for the IR package](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-390)
**Major changes:**
- Fixes transient/serializable issues raised by SonarCloud
**Minor changes to note:**
- Renames MessageBarrierVertex to TriggerVertex
- Renames StreamVertex to RelayVertex
Closes #217
diff --git a/common/src/main/java/org/apache/nemo/common/Util.java b/common/src/main/java/org/apache/nemo/common/Util.java
index 7a84491..0688da9 100644
--- a/common/src/main/java/org/apache/nemo/common/Util.java
+++ b/common/src/main/java/org/apache/nemo/common/Util.java
@@ -23,9 +23,9 @@
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import java.io.IOException;
import java.lang.instrument.Instrumentation;
@@ -189,8 +189,8 @@
public static boolean isUtilityVertex(final IRVertex v) {
return v instanceof SamplingVertex
|| v instanceof MessageAggregatorVertex
- || v instanceof MessageBarrierVertex
- || v instanceof StreamVertex;
+ || v instanceof TriggerVertex
+ || v instanceof RelayVertex;
}
/**
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index dfc1103..f480d41 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -36,9 +36,9 @@
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -70,7 +70,7 @@
private DAG<IRVertex, IREdge> modifiedDAG; // the DAG that is being updated.
// To remember original encoders/decoders, and etc
- private final Map<StreamVertex, IREdge> streamVertexToOriginalEdge;
+ private final Map<RelayVertex, IREdge> streamVertexToOriginalEdge;
// To remember sampling vertex groups
private final Map<SamplingVertex, Set<SamplingVertex>> samplingVertexToGroup;
@@ -128,7 +128,7 @@
/**
* Deletes a previously inserted utility vertex.
- * (e.g., MessageBarrierVertex, StreamVertex, SamplingVertex)
+ * (e.g., TriggerVertex, RelayVertex, SamplingVertex)
* <p>
* Notice that the actual number of vertices that will be deleted after this call returns can be more than one.
* We roll back the changes made with the previous insert(), while preserving application semantics.
@@ -144,7 +144,7 @@
}
private Set<IRVertex> getVertexGroupToDelete(final IRVertex vertexToDelete) {
- if (vertexToDelete instanceof StreamVertex) {
+ if (vertexToDelete instanceof RelayVertex) {
return Sets.newHashSet(vertexToDelete);
} else if (vertexToDelete instanceof SamplingVertex) {
final Set<SamplingVertex> samplingVertexGroup = samplingVertexToGroup.get(vertexToDelete);
@@ -153,7 +153,7 @@
converted.add(sv); // explicit conversion to IRVertex is needed.. otherwise the compiler complains :(
}
return converted;
- } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof MessageBarrierVertex) {
+ } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof TriggerVertex) {
return messageVertexToGroup.get(vertexToDelete);
} else {
throw new IllegalArgumentException(vertexToDelete.getId());
@@ -200,7 +200,7 @@
Sets.difference(utilityParents, vertexGroupToDelete).forEach(ptd -> deleteRecursively(ptd, visited));
// STEP 2: Delete the specified vertex(vertices)
- if (vertexToDelete instanceof StreamVertex) {
+ if (vertexToDelete instanceof RelayVertex) {
final DAGBuilder<IRVertex, IREdge> builder = rebuildExcluding(modifiedDAG, vertexGroupToDelete);
// Add a new edge that directly connects the src of the stream vertex to its dst
@@ -214,7 +214,7 @@
.forEach(srcVertex -> builder.connectVertices(
Util.cloneEdge(streamVertexToOriginalEdge.get(vertexToDelete), srcVertex, dstVertex))));
modifiedDAG = builder.buildWithoutSourceSinkCheck();
- } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof MessageBarrierVertex) {
+ } else if (vertexToDelete instanceof MessageAggregatorVertex || vertexToDelete instanceof TriggerVertex) {
modifiedDAG = rebuildExcluding(modifiedDAG, vertexGroupToDelete).buildWithoutSourceSinkCheck();
final int deletedMessageId = vertexGroupToDelete.stream()
.filter(vtd -> vtd instanceof MessageAggregatorVertex)
@@ -245,16 +245,16 @@
* Inserts a new vertex that streams data.
* <p>
* Before: src - edgeToStreamize - dst
- * After: src - edgeToStreamizeWithNewDestination - streamVertex - oneToOneEdge - dst
+ * After: src - edgeToStreamizeWithNewDestination - relayVertex - oneToOneEdge - dst
* (replaces the "Before" relationships)
* <p>
- * This preserves semantics as the streamVertex simply forwards data elements from the input edge to the output edge.
+ * This preserves semantics as the relayVertex simply forwards data elements from the input edge to the output edge.
*
- * @param streamVertex to insert.
+ * @param relayVertex to insert.
* @param edgeToStreamize to modify.
*/
- public void insert(final StreamVertex streamVertex, final IREdge edgeToStreamize) {
- assertNonExistence(streamVertex);
+ public void insert(final RelayVertex relayVertex, final IREdge edgeToStreamize) {
+ assertNonExistence(relayVertex);
assertNonControlEdge(edgeToStreamize);
// Create a completely new DAG with the vertex inserted.
@@ -267,7 +267,7 @@
}
// Insert the vertex.
- final IRVertex vertexToInsert = wrapSamplingVertexIfNeeded(streamVertex, edgeToStreamize.getSrc());
+ final IRVertex vertexToInsert = wrapSamplingVertexIfNeeded(relayVertex, edgeToStreamize.getSrc());
builder.addVertex(vertexToInsert);
edgeToStreamize.getSrc().getPropertyValue(ParallelismProperty.class)
.ifPresent(p -> vertexToInsert.setProperty(ParallelismProperty.of(p)));
@@ -280,14 +280,14 @@
if (edge.equals(edgeToStreamize)) {
// MATCH!
- // Edge to the streamVertex
+ // Edge to the relayVertex
final IREdge toSV = new IREdge(
edgeToStreamize.getPropertyValue(CommunicationPatternProperty.class).get(),
edgeToStreamize.getSrc(),
vertexToInsert);
edgeToStreamize.copyExecutionPropertiesTo(toSV);
- // Edge from the streamVertex.
+ // Edge from the relayVertex.
final IREdge fromSV = new IREdge(CommunicationPatternProperty.Value.OneToOne, vertexToInsert, v);
fromSV.setProperty(EncoderProperty.of(edgeToStreamize.getPropertyValue(EncoderProperty.class).get()));
fromSV.setProperty(DecoderProperty.of(edgeToStreamize.getPropertyValue(DecoderProperty.class).get()));
@@ -313,12 +313,12 @@
}
});
- if (edgeToStreamize.getSrc() instanceof StreamVertex) {
- streamVertexToOriginalEdge.put(streamVertex, streamVertexToOriginalEdge.get(edgeToStreamize.getSrc()));
- } else if (edgeToStreamize.getDst() instanceof StreamVertex) {
- streamVertexToOriginalEdge.put(streamVertex, streamVertexToOriginalEdge.get(edgeToStreamize.getDst()));
+ if (edgeToStreamize.getSrc() instanceof RelayVertex) {
+ streamVertexToOriginalEdge.put(relayVertex, streamVertexToOriginalEdge.get(edgeToStreamize.getSrc()));
+ } else if (edgeToStreamize.getDst() instanceof RelayVertex) {
+ streamVertexToOriginalEdge.put(relayVertex, streamVertexToOriginalEdge.get(edgeToStreamize.getDst()));
} else {
- streamVertexToOriginalEdge.put(streamVertex, edgeToStreamize);
+ streamVertexToOriginalEdge.put(relayVertex, edgeToStreamize);
}
modifiedDAG = builder.build(); // update the DAG.
}
@@ -329,28 +329,28 @@
* For each edge in edgesToGetStatisticsOf...
* <p>
* Before: src - edge - dst
- * After: src - oneToOneEdge(a clone of edge) - messageBarrierVertex -
+ * After: src - oneToOneEdge(a clone of edge) - triggerVertex -
* shuffleEdge - messageAggregatorVertex - broadcastEdge - dst
* (the "Before" relationships are unmodified)
* <p>
* This preserves semantics as the results of the inserted message vertices are never consumed by the original IRDAG.
* <p>
- * TODO #345: Simplify insert(MessageBarrierVertex)
+ * TODO #345: Simplify insert(TriggerVertex)
*
- * @param messageBarrierVertex to insert.
+ * @param triggerVertex to insert.
* @param messageAggregatorVertex to insert.
- * @param mbvOutputEncoder to use.
- * @param mbvOutputDecoder to use.
+ * @param triggerOutputEncoder to use.
+ * @param triggerOutputDecoder to use.
* @param edgesToGetStatisticsOf to examine.
* @param edgesToOptimize to optimize.
*/
- public void insert(final MessageBarrierVertex messageBarrierVertex,
+ public void insert(final TriggerVertex triggerVertex,
final MessageAggregatorVertex messageAggregatorVertex,
- final EncoderProperty mbvOutputEncoder,
- final DecoderProperty mbvOutputDecoder,
+ final EncoderProperty triggerOutputEncoder,
+ final DecoderProperty triggerOutputDecoder,
final Set<IREdge> edgesToGetStatisticsOf,
final Set<IREdge> edgesToOptimize) {
- assertNonExistence(messageBarrierVertex);
+ assertNonExistence(triggerVertex);
assertNonExistence(messageAggregatorVertex);
edgesToGetStatisticsOf.forEach(this::assertNonControlEdge);
edgesToOptimize.forEach(this::assertNonControlEdge);
@@ -371,43 +371,43 @@
modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
});
- ////////////////////////////////// STEP 1: Insert new vertices and edges (src - mbv - mav - dst)
+ ////////////////////////////////// STEP 1: Insert new vertices and edges (src - trigger - agg - dst)
- // From src to mbv
- final List<IRVertex> mbvList = new ArrayList<>();
+ // From src to trigger
+ final List<IRVertex> triggerList = new ArrayList<>();
for (final IREdge edge : edgesToGetStatisticsOf) {
- final IRVertex mbvToAdd = wrapSamplingVertexIfNeeded(
- new MessageBarrierVertex<>(messageBarrierVertex.getMessageFunction()), edge.getSrc());
- builder.addVertex(mbvToAdd);
- mbvList.add(mbvToAdd);
+ final IRVertex triggerToAdd = wrapSamplingVertexIfNeeded(
+ new TriggerVertex<>(triggerVertex.getMessageFunction()), edge.getSrc());
+ builder.addVertex(triggerToAdd);
+ triggerList.add(triggerToAdd);
edge.getSrc().getPropertyValue(ParallelismProperty.class)
- .ifPresent(p -> mbvToAdd.setProperty(ParallelismProperty.of(p)));
+ .ifPresent(p -> triggerToAdd.setProperty(ParallelismProperty.of(p)));
final IREdge edgeToClone;
- if (edge.getSrc() instanceof StreamVertex) {
+ if (edge.getSrc() instanceof RelayVertex) {
edgeToClone = streamVertexToOriginalEdge.get(edge.getSrc());
- } else if (edge.getDst() instanceof StreamVertex) {
+ } else if (edge.getDst() instanceof RelayVertex) {
edgeToClone = streamVertexToOriginalEdge.get(edge.getDst());
} else {
edgeToClone = edge;
}
final IREdge clone = Util.cloneEdge(
- CommunicationPatternProperty.Value.OneToOne, edgeToClone, edge.getSrc(), mbvToAdd);
+ CommunicationPatternProperty.Value.OneToOne, edgeToClone, edge.getSrc(), triggerToAdd);
builder.connectVertices(clone);
}
- // Add mav (no need to wrap inside sampling vertices)
+ // Add agg (no need to wrap inside sampling vertices)
builder.addVertex(messageAggregatorVertex);
- // From mbv to mav
- for (final IRVertex mbv : mbvList) {
+ // From trigger to agg
+ for (final IRVertex trigger : triggerList) {
final IREdge edgeToMav = edgeToMessageAggregator(
- mbv, messageAggregatorVertex, mbvOutputEncoder, mbvOutputDecoder);
+ trigger, messageAggregatorVertex, triggerOutputEncoder, triggerOutputDecoder);
builder.connectVertices(edgeToMav);
}
- // From mav to dst
+ // From agg to dst
// Add a control dependency (no output) from the messageAggregatorVertex to the destination.
builder.connectVertices(
Util.createControlEdge(messageAggregatorVertex, edgesToGetStatisticsOf.iterator().next().getDst()));
@@ -426,9 +426,9 @@
});
final Set<IRVertex> insertedVertices = new HashSet<>();
- insertedVertices.addAll(mbvList);
+ insertedVertices.addAll(triggerList);
insertedVertices.add(messageAggregatorVertex);
- mbvList.forEach(mbv -> messageVertexToGroup.put(mbv, insertedVertices));
+ triggerList.forEach(trigger -> messageVertexToGroup.put(trigger, insertedVertices));
messageVertexToGroup.put(messageAggregatorVertex, insertedVertices);
modifiedDAG = builder.build(); // update the DAG.
@@ -574,17 +574,17 @@
}
/**
- * @param mbv src.
- * @param mav dst.
+ * @param trigger src.
+ * @param agg dst.
* @param encoder src-dst encoder.
* @param decoder src-dst decoder.
* @return the edge.
*/
- private IREdge edgeToMessageAggregator(final IRVertex mbv,
- final IRVertex mav,
+ private IREdge edgeToMessageAggregator(final IRVertex trigger,
+ final IRVertex agg,
final EncoderProperty encoder,
final DecoderProperty decoder) {
- final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, mbv, mav);
+ final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, trigger, agg);
newEdge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
newEdge.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
newEdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Push));
@@ -592,7 +592,7 @@
newEdge.setPropertyPermanently(decoder);
newEdge.setPropertyPermanently(KeyExtractorProperty.of(new PairKeyExtractor()));
- // TODO #345: Simplify insert(MessageBarrierVertex)
+ // TODO #345: Simplify insert(TriggerVertex)
// these are obviously wrong, but hacks for now...
newEdge.setPropertyPermanently(KeyEncoderProperty.of(encoder.getValue()));
newEdge.setPropertyPermanently(KeyDecoderProperty.of(decoder.getValue()));
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
index dafb9cd..cd14d97 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
@@ -32,7 +32,7 @@
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.*;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,7 +68,7 @@
addShuffleEdgeCheckers();
addPartitioningCheckers();
addEncodingCompressionCheckers();
- addMessageBarrierVertexCheckers();
+ addTriggerVertexCheckers();
addStreamVertexCheckers();
addLoopVertexCheckers();
addScheduleGroupCheckers();
@@ -305,7 +305,7 @@
neighborCheckerList.add(shuffleChecker);
}
- void addMessageBarrierVertexCheckers() {
+ void addTriggerVertexCheckers() {
final GlobalDAGChecker messageIds = (dag -> {
final long numMessageAggregatorVertices = dag.getVertices()
.stream()
@@ -468,7 +468,7 @@
///////////////////////////// Private helper methods
private boolean isConnectedToStreamVertex(final IREdge irEdge) {
- return irEdge.getDst() instanceof StreamVertex || irEdge.getSrc() instanceof StreamVertex;
+ return irEdge.getDst() instanceof RelayVertex || irEdge.getSrc() instanceof RelayVertex;
}
private Map<Optional<String>, List<IREdge>> groupOutEdgesByAdditionalOutputTag(final List<IREdge> outEdges) {
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
index e33900f..ada8af3 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregatorTransform.java
@@ -20,13 +20,12 @@
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.OutputCollector;
+import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.function.BiFunction;
-
/**
- * A {@link Transform} that aggregates statistics generated by the {@link MessageBarrierTransform}.
+ * A {@link Transform} that aggregates statistics generated by the {@link TriggerTransform}.
*
* @param <K> input key type.
* @param <V> input value type.
@@ -34,35 +33,38 @@
*/
public final class MessageAggregatorTransform<K, V, O> extends NoWatermarkEmitTransform<Pair<K, V>, O> {
private static final Logger LOG = LoggerFactory.getLogger(MessageAggregatorTransform.class.getName());
- private OutputCollector<O> outputCollector;
- private O aggregatedDynOptData;
- private final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator;
+
+ private transient O state;
+ private transient OutputCollector<O> outputCollector;
+
+ private final MessageAggregatorVertex.InitialStateSupplier<O> initialStateSupplier;
+ private final MessageAggregatorVertex.MessageAggregatorFunction<K, V, O> aggregator;
/**
* Default constructor.
- *
- * @param aggregatedDynOptData per-stage aggregated dynamic optimization data.
- * @param dynOptDataAggregator aggregator to use.
+ * @param initialStateSupplier to use.
+ * @param aggregator to use.
*/
- public MessageAggregatorTransform(final O aggregatedDynOptData,
- final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator) {
- this.aggregatedDynOptData = aggregatedDynOptData;
- this.dynOptDataAggregator = dynOptDataAggregator;
+ public MessageAggregatorTransform(final MessageAggregatorVertex.InitialStateSupplier<O> initialStateSupplier,
+ final MessageAggregatorVertex.MessageAggregatorFunction<K, V, O> aggregator) {
+ this.initialStateSupplier = initialStateSupplier;
+ this.aggregator = aggregator;
}
@Override
public void prepare(final Context context, final OutputCollector<O> oc) {
+ this.state = initialStateSupplier.get();
this.outputCollector = oc;
}
@Override
public void onData(final Pair<K, V> element) {
- aggregatedDynOptData = dynOptDataAggregator.apply(element, aggregatedDynOptData);
+ state = aggregator.apply(element, state);
}
@Override
public void close() {
- outputCollector.emit(aggregatedDynOptData);
+ outputCollector.emit(state);
}
@Override
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java
similarity index 74%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java
index fe2e523..5c982ef 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/TriggerTransform.java
@@ -20,33 +20,34 @@
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.OutputCollector;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
-import java.util.function.BiFunction;
/**
- * A {@link Transform} that collects task-level statistics used for dynamic optimization.
+ * A {@link Transform} for the trigger vertex.
*
* @param <I> input type.
* @param <K> output key type.
* @param <V> output value type.
*/
-public final class MessageBarrierTransform<I, K, V> extends NoWatermarkEmitTransform<I, Pair<K, V>> {
- private static final Logger LOG = LoggerFactory.getLogger(MessageBarrierTransform.class.getName());
- private final BiFunction<I, Map<K, V>, Map<K, V>> userFunction;
+public final class TriggerTransform<I, K, V> extends NoWatermarkEmitTransform<I, Pair<K, V>> {
+ private static final Logger LOG = LoggerFactory.getLogger(TriggerTransform.class.getName());
- private OutputCollector<Pair<K, V>> outputCollector;
- private Map<K, V> holder;
+ private transient OutputCollector<Pair<K, V>> outputCollector;
+ private transient Map<K, V> holder;
+
+ private final TriggerVertex.MessageGeneratorFunction<I, K, V> userFunction;
/**
- * MessageBarrierTransform constructor.
+ * TriggerTransform constructor.
*
* @param userFunction that analyzes the data.
*/
- public MessageBarrierTransform(final BiFunction<I, Map<K, V>, Map<K, V>> userFunction) {
+ public TriggerTransform(final TriggerVertex.MessageGeneratorFunction<I, K, V> userFunction) {
this.userFunction = userFunction;
}
@@ -72,7 +73,7 @@
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
- sb.append(MessageBarrierTransform.class);
+ sb.append(TriggerTransform.class);
sb.append(":");
sb.append(super.toString());
return sb.toString();
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
index 2e2cdb6..cc58f27 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
@@ -26,8 +26,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.Serializable;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
+import java.util.function.Supplier;
/**
* Aggregates upstream messages.
@@ -41,13 +43,29 @@
private static final AtomicInteger MESSAGE_ID_GENERATOR = new AtomicInteger(0);
/**
- * @param initialState to use.
+ * @param initialStateSupplier for producing the initial state.
* @param userFunction for aggregating the messages.
*/
- public MessageAggregatorVertex(final O initialState,
- final BiFunction<Pair<K, V>, O, O> userFunction) {
- super(new MessageAggregatorTransform<>(initialState, userFunction));
+ public MessageAggregatorVertex(final InitialStateSupplier<O> initialStateSupplier,
+ final MessageAggregatorFunction<K, V, O> userFunction) {
+ super(new MessageAggregatorTransform<>(initialStateSupplier, userFunction));
this.setPropertyPermanently(MessageIdVertexProperty.of(MESSAGE_ID_GENERATOR.incrementAndGet()));
this.setProperty(ParallelismProperty.of(1));
}
+
+ /**
+ * Creates the initial aggregated message.
+ * @param <O> of the output aggregated message.
+ */
+ public interface InitialStateSupplier<O> extends Supplier<O>, Serializable {
+ }
+
+ /**
+ * Aggregates incoming messages.
+ * @param <K> of the input pair.
+ * @param <V> of the input pair.
+ * @param <O> of the output aggregated message.
+ */
+ public interface MessageAggregatorFunction<K, V, O> extends BiFunction<Pair<K, V>, O, O>, Serializable {
+ }
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/StreamVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/RelayVertex.java
similarity index 93%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/utility/StreamVertex.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/utility/RelayVertex.java
index 58d98a7..cf3d6c3 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/StreamVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/RelayVertex.java
@@ -24,11 +24,11 @@
/**
* Relays input data from upstream vertex to downstream vertex promptly.
*/
-public final class StreamVertex extends OperatorVertex {
+public final class RelayVertex extends OperatorVertex {
/**
* Constructor.
*/
- public StreamVertex() {
+ public RelayVertex() {
super(new StreamTransform());
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
index b614022..6201575 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
@@ -38,9 +38,9 @@
*/
public SamplingVertex(final IRVertex originalVertex, final float desiredSampleRate) {
super();
- if (!(originalVertex instanceof MessageBarrierVertex) && (Util.isUtilityVertex(originalVertex))) {
+ if (!(originalVertex instanceof TriggerVertex) && (Util.isUtilityVertex(originalVertex))) {
throw new IllegalArgumentException(
- "Cannot sample non-MessageBarrier utility vertices: " + originalVertex.toString());
+ "Cannot sample non-Trigger utility vertices: " + originalVertex.toString());
}
if (desiredSampleRate > 1 || desiredSampleRate <= 0) {
throw new IllegalArgumentException(String.valueOf(desiredSampleRate));
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageBarrierVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java
similarity index 61%
rename from common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageBarrierVertex.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java
index 2647794..bb2337d 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageBarrierVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/TriggerVertex.java
@@ -19,30 +19,41 @@
package org.apache.nemo.common.ir.vertex.utility;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
-import org.apache.nemo.common.ir.vertex.transform.MessageBarrierTransform;
+import org.apache.nemo.common.ir.vertex.transform.TriggerTransform;
+import java.io.Serializable;
import java.util.Map;
import java.util.function.BiFunction;
/**
- * Generates messages.
+ * Produces a message and triggers a run-time pass.
*
* @param <I> input type
* @param <K> of the output pair.
* @param <V> of the output pair.
*/
-public final class MessageBarrierVertex<I, K, V> extends OperatorVertex {
- private final BiFunction<I, Map<K, V>, Map<K, V>> messageFunction;
+public final class TriggerVertex<I, K, V> extends OperatorVertex {
+ private final MessageGeneratorFunction<I, K, V> messageFunction;
/**
* @param messageFunction for producing a message.
*/
- public MessageBarrierVertex(final BiFunction<I, Map<K, V>, Map<K, V>> messageFunction) {
- super(new MessageBarrierTransform<>(messageFunction));
+ public TriggerVertex(final MessageGeneratorFunction<I, K, V> messageFunction) {
+ super(new TriggerTransform<>(messageFunction));
this.messageFunction = messageFunction;
}
- public BiFunction<I, Map<K, V>, Map<K, V>> getMessageFunction() {
+ public MessageGeneratorFunction<I, K, V> getMessageFunction() {
return messageFunction;
}
+
+ /**
+ * Applied on the input data elements to produce a message.
+ *
+ * @param <I> input type
+ * @param <K> of the output pair.
+ * @param <V> of the output pair.
+ */
+ public interface MessageGeneratorFunction<I, K, V> extends BiFunction<I, Map<K, V>, Map<K, V>>, Serializable {
+ }
}
diff --git a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
index 9444c33..c33f751 100644
--- a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
+++ b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
@@ -31,9 +31,9 @@
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.*;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
import org.apache.nemo.common.test.EmptyComponents;
import org.junit.Before;
import org.junit.Test;
@@ -244,11 +244,11 @@
@Test
public void testStreamVertex() {
- final StreamVertex svOne = new StreamVertex();
+ final RelayVertex svOne = new RelayVertex();
irdag.insert(svOne, oneToOneEdge);
mustPass();
- final StreamVertex svTwo = new StreamVertex();
+ final RelayVertex svTwo = new RelayVertex();
irdag.insert(svTwo, shuffleEdge);
mustPass();
@@ -260,11 +260,11 @@
}
@Test
- public void testMessageBarrierVertex() {
- final MessageAggregatorVertex maOne = insertNewMessageBarrierVertex(irdag, oneToOneEdge);
+ public void testTriggerVertex() {
+ final MessageAggregatorVertex maOne = insertNewTriggerVertex(irdag, oneToOneEdge);
mustPass();
- final MessageAggregatorVertex maTwo = insertNewMessageBarrierVertex(irdag, shuffleEdge);
+ final MessageAggregatorVertex maTwo = insertNewTriggerVertex(irdag, shuffleEdge);
mustPass();
irdag.delete(maTwo);
@@ -292,9 +292,9 @@
mustPass();
}
- private MessageAggregatorVertex insertNewMessageBarrierVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) {
- final MessageBarrierVertex mb = new MessageBarrierVertex<>((l, r) -> null);
- final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(new Object(), (l, r) -> null);
+ private MessageAggregatorVertex insertNewTriggerVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) {
+ final TriggerVertex mb = new TriggerVertex<>((l, r) -> null);
+ final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(() -> new Object(), (l, r) -> null);
dag.insert(
mb,
ma,
@@ -344,15 +344,15 @@
// Reshaping methods
case 7:
- final StreamVertex streamVertex = new StreamVertex();
+ final RelayVertex relayVertex = new RelayVertex();
final IREdge edgeToStreamize = selectRandomEdge();
if (!(edgeToStreamize.getPropertyValue(MessageIdEdgeProperty.class).isPresent()
&& !edgeToStreamize.getPropertyValue(MessageIdEdgeProperty.class).get().isEmpty())) {
- irdag.insert(streamVertex, edgeToStreamize);
+ irdag.insert(relayVertex, edgeToStreamize);
}
break;
case 8:
- insertNewMessageBarrierVertex(irdag, selectRandomEdge());
+ insertNewTriggerVertex(irdag, selectRandomEdge());
break;
case 9:
final IRVertex vertexToSample = selectRandomNonUtilityVertex();
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
index 120c598..431b225 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
@@ -24,11 +24,11 @@
import org.apache.nemo.common.ir.edge.executionproperty.DataPersistenceProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
/**
- * This pass assumes that a StreamVertex was previously inserted to receive each shuffle edge.
+ * This pass assumes that a RelayVertex was previously inserted to receive each shuffle edge.
* <p>
* src - shuffle-edge - streamvertex - one-to-one-edge - dst
* <p>
@@ -57,7 +57,7 @@
public IRDAG apply(final IRDAG dag) {
dag.topologicalDo(irVertex ->
dag.getIncomingEdgesOf(irVertex).forEach(edge -> {
- if (edge.getDst().getClass().equals(StreamVertex.class)) {
+ if (edge.getDst().getClass().equals(RelayVertex.class)) {
// CASE #1: To a stream vertex
// Data transfers
@@ -67,7 +67,7 @@
// Resource slots
edge.getDst().setPropertyPermanently(ResourceSlotProperty.of(false));
- } else if (edge.getSrc().getClass().equals(StreamVertex.class)) {
+ } else if (edge.getSrc().getClass().equals(RelayVertex.class)) {
// CASE #2: From a stream vertex
// Data transfers
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleReshapingPass.java
index 442c018..4088fb8 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleReshapingPass.java
@@ -20,11 +20,11 @@
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
-import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
+import org.apache.nemo.common.ir.vertex.utility.RelayVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
/**
- * Inserts the StreamVertex for each shuffle edge.
+ * Inserts the RelayVertex for each shuffle edge.
*/
@Requires(CommunicationPatternProperty.class)
public final class LargeShuffleReshapingPass extends ReshapingPass {
@@ -43,7 +43,7 @@
dag.getIncomingEdgesOf(vertex).forEach(edge -> {
if (CommunicationPatternProperty.Value.Shuffle
.equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
- dag.insert(new StreamVertex(), edge);
+ dag.insert(new RelayVertex(), edge);
}
});
});
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
index efb9416..50412a9 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
@@ -27,7 +27,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
@@ -55,7 +55,7 @@
* (P3 is not cloned here because it is a sink partition, and none of the outgoing edges of its vertices needs to be
* optimized)
* <p>
- * For each Px' this pass also inserts a MessageBarrierVertex, to use its data statistics for dynamically optimizing
+ * For each Px' this pass also inserts a TriggerVertex, to use its data statistics for dynamically optimizing
* the execution behaviors of Px.
*/
@Requires(CommunicationPatternProperty.class)
@@ -112,8 +112,8 @@
final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
dag.insert(
- new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor)),
- new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator()),
+ new TriggerVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
+ new MessageAggregatorVertex(() -> new HashMap<>(), SkewHandlingUtil.getMessageAggregator()),
SkewHandlingUtil.getEncoder(e),
SkewHandlingUtil.getDecoder(e),
new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
index c176ad6..b724bdd 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
@@ -19,7 +19,6 @@
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
import org.apache.nemo.common.KeyExtractor;
-import org.apache.nemo.common.Pair;
import org.apache.nemo.common.coder.LongDecoderFactory;
import org.apache.nemo.common.coder.LongEncoderFactory;
import org.apache.nemo.common.coder.PairDecoderFactory;
@@ -29,10 +28,10 @@
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
+import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
-import java.io.Serializable;
import java.util.Map;
-import java.util.function.BiFunction;
/**
* A utility class for skew handling passes.
@@ -41,31 +40,30 @@
private SkewHandlingUtil() {
}
- static BiFunction<Object, Map<Object, Long>, Map<Object, Long>> getDynOptCollector(final KeyExtractor keyExtractor) {
- return (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
- (element, dynOptData) -> {
- Object key = keyExtractor.extractKey(element);
- if (dynOptData.containsKey(key)) {
- dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
- } else {
- dynOptData.put(key, 1L);
- }
- return dynOptData;
- };
+ static TriggerVertex.MessageGeneratorFunction<Object, Object, Long> getMessageGenerator(
+ final KeyExtractor keyExtractor) {
+ return (element, dynOptData) -> {
+ Object key = keyExtractor.extractKey(element);
+ if (dynOptData.containsKey(key)) {
+ dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
+ } else {
+ dynOptData.put(key, 1L);
+ }
+ return dynOptData;
+ };
}
- static BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> getDynOptAggregator() {
- return (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
- (element, aggregatedDynOptData) -> {
- final Object key = element.left();
- final Long count = element.right();
- if (aggregatedDynOptData.containsKey(key)) {
- aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
- } else {
- aggregatedDynOptData.put(key, count);
- }
- return aggregatedDynOptData;
- };
+ static MessageAggregatorVertex.MessageAggregatorFunction<Object, Long, Map<Object, Long>> getMessageAggregator() {
+ return (element, aggregatedDynOptData) -> {
+ final Object key = element.left();
+ final Long count = element.right();
+ if (aggregatedDynOptData.containsKey(key)) {
+ aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+ } else {
+ aggregatedDynOptData.put(key, count);
+ }
+ return aggregatedDynOptData;
+ };
}
static EncoderProperty getEncoder(final IREdge irEdge) {
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index 0b843dd..8985b92 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -25,7 +25,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
-import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.TriggerVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -38,7 +38,7 @@
/**
* Pass to reshape the IR DAG for skew handling.
- * We insert a {@link MessageBarrierVertex} for each shuffle edge,
+ * We insert a {@link TriggerVertex} for each shuffle edge,
* and aggregate messages for multiple same-destination shuffle edges.
*/
@Requires(CommunicationPatternProperty.class)
@@ -74,10 +74,10 @@
final KeyExtractor keyExtractor = representativeEdge.getPropertyValue(KeyExtractorProperty.class).get();
// Insert the vertices
- final MessageBarrierVertex mbv = new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor));
+ final TriggerVertex trigger = new TriggerVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor));
final MessageAggregatorVertex mav =
- new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator());
- dag.insert(mbv, mav, SkewHandlingUtil.getEncoder(representativeEdge),
+ new MessageAggregatorVertex(() -> new HashMap(), SkewHandlingUtil.getMessageAggregator());
+ dag.insert(trigger, mav, SkewHandlingUtil.getEncoder(representativeEdge),
SkewHandlingUtil.getDecoder(representativeEdge), shuffleEdgeGroup, shuffleEdgeGroup);
}
});
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
index c2d51d1..8728803 100644
--- a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
+++ b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
@@ -26,7 +26,7 @@
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceAntiAffinityProperty;
import org.apache.nemo.common.ir.vertex.transform.MessageAggregatorTransform;
-import org.apache.nemo.common.ir.vertex.transform.MessageBarrierTransform;
+import org.apache.nemo.common.ir.vertex.transform.TriggerTransform;
import org.apache.nemo.compiler.CompilerTestUtil;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.AnnotatingPass;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass;
@@ -74,7 +74,7 @@
/**
* Test for {@link SkewCompositePass} with MR workload.
- * It should have inserted vertex with {@link MessageBarrierTransform}
+ * It should have inserted vertex with {@link TriggerTransform}
* and vertex with {@link MessageAggregatorTransform} for each shuffle edge.
*
* @throws Exception exception on the way.
@@ -95,7 +95,7 @@
assertEquals(originalVerticesNum + numOfShuffleEdges * 2, processedDAG.getVertices().size());
processedDAG.filterVertices(v -> v instanceof OperatorVertex
- && ((OperatorVertex) v).getTransform() instanceof MessageBarrierTransform)
+ && ((OperatorVertex) v).getTransform() instanceof TriggerTransform)
.forEach(metricV -> {
final List<IRVertex> reducerV = processedDAG.getChildren(metricV.getId());
reducerV.forEach(rV -> {
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index b812136..a37fa52 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -170,7 +170,7 @@
.map(edge -> edge.getExecutionProperties().get(MessageIdEdgeProperty.class).get())
.findFirst().get();
// Here we simply use findFirst() for now...
- // TODO #345: Simplify insert(MessageBarrierVertex)
+ // TODO #345: Simplify insert
return messageIds.iterator().next();
}