[NEMO-327] Fix skew handling for multi shuffle edge receiver (#189)

JIRA: [NEMO-327: Fix skew handling for multi shuffle edge receiver](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-327)

**Major changes:**
- When a vertex receives multiple shuffle edges, `DataSkewPolicy` now collects the metric data for all of those edges in a single metric aggregation vertex and optimizes the edges at once.

**Minor changes to note:**
- Makes the runtime pass accept multiple target edges instead of a single target edge.

**Tests for the changes:**
- Adds an integration test that runs `NetworkTraceAnalysis`, which has a join, with the data skew policy.

**Other comments:**
- N/A.

Closes #189
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
index 6a6ca4d..30826a8 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
@@ -24,6 +24,7 @@
 import org.apache.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
 import org.apache.nemo.common.ir.vertex.*;
 import org.apache.nemo.common.exception.IllegalVertexOperationException;
+import org.apache.nemo.common.ir.vertex.transform.AggregateMetricTransform;
 
 import java.io.Serializable;
 import java.util.*;
@@ -259,8 +260,9 @@
   private void executionPropertyCheck() {
     // DataSizeMetricCollection is not compatible with Push (All data have to be stored before the data collection)
     vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
-        .filter(e -> Optional.of(MetricCollectionProperty.Value.DataSkewRuntimePass)
-            .equals(e.getPropertyValue(MetricCollectionProperty.class)))
+        .filter(e -> e.getPropertyValue(MetricCollectionProperty.class).isPresent())
+        .filter(e -> !(e.getDst() instanceof OperatorVertex
+          && ((OperatorVertex) e.getDst()).getTransform() instanceof AggregateMetricTransform))
         .filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
         .forEach(e -> {
           throw new CompileTimeOptimizationException("DAG execution property check: "
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
index c9749a8..fd7dfb6 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
@@ -23,12 +23,12 @@
 /**
  * MetricCollection ExecutionProperty that indicates the edge of which data metric will be collected.
  */
-public final class MetricCollectionProperty extends EdgeExecutionProperty<MetricCollectionProperty.Value> {
+public final class MetricCollectionProperty extends EdgeExecutionProperty<Integer> {
   /**
    * Constructor.
    * @param value value of the execution property.
    */
-  private MetricCollectionProperty(final Value value) {
+  private MetricCollectionProperty(final Integer value) {
     super(value);
   }
 
@@ -37,14 +37,7 @@
    * @param value value of the new execution property.
    * @return the newly created execution property.
    */
-  public static MetricCollectionProperty of(final Value value) {
+  public static MetricCollectionProperty of(final Integer value) {
     return new MetricCollectionProperty(value);
   }
-
-  /**
-   * Possible values of MetricCollection ExecutionProperty.
-   */
-  public enum Value {
-    DataSkewRuntimePass
-  }
 }
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index da436d9..57a1647 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -31,6 +31,7 @@
 import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
 import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
 import org.apache.nemo.common.ir.vertex.transform.AggregateMetricTransform;
 import org.apache.nemo.common.ir.vertex.transform.MetricCollectTransform;
 import org.apache.nemo.compiler.optimizer.PairKeyExtractor;
@@ -40,9 +41,7 @@
 import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
 
@@ -68,10 +67,12 @@
 
   @Override
   public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
+    int mcCount = 0;
+    // destination vertex ID to metric aggregation vertex - ID pair map
+    final Map<String, Pair<OperatorVertex, Integer>> dstVtxIdToABV = new HashMap<>();
     final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
-    final List<OperatorVertex> metricCollectVertices = new ArrayList<>();
 
-    dag.topologicalDo(v -> {
+    for (final IRVertex v : dag.getTopologicalSort()) {
       // We care about OperatorVertices that have shuffle incoming edges with main output.
       // TODO #210: Data-aware dynamic optimization at run-time
       if (v instanceof OperatorVertex && dag.getIncomingEdgesOf(v).stream().anyMatch(irEdge ->
@@ -80,15 +81,32 @@
         && dag.getIncomingEdgesOf(v).stream().noneMatch(irEdge ->
       irEdge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())) {
 
-        dag.getIncomingEdgesOf(v).forEach(edge -> {
+        for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
           if (CommunicationPatternProperty.Value.Shuffle
-                .equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
-            final OperatorVertex abv = generateMetricAggregationVertex();
+            .equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
+            final String dstId = edge.getDst().getId();
+
+            // Get or generate a metric collection vertex.
+            final int metricCollectionId;
+            final OperatorVertex abv;
+            if (!dstVtxIdToABV.containsKey(dstId)) {
+              // There is no metric aggregation vertex for this destination vertex.
+              metricCollectionId = mcCount++;
+              abv = generateMetricAggregationVertex();
+              builder.addVertex(abv);
+
+              abv.setPropertyPermanently(ResourceSlotProperty.of(false));
+              dstVtxIdToABV.put(dstId, Pair.of(abv, metricCollectionId));
+            } else {
+              // There is a metric aggregation vertex for this destination vertex already.
+              final Pair<OperatorVertex, Integer> aggrPair = dstVtxIdToABV.get(dstId);
+              metricCollectionId = aggrPair.right();
+              abv = aggrPair.left();
+            }
+
             final OperatorVertex mcv = generateMetricCollectVertex(edge);
-            metricCollectVertices.add(mcv);
             builder.addVertex(v);
             builder.addVertex(mcv);
-            builder.addVertex(abv);
 
             // We then insert the vertex with MetricCollectTransform and vertex with AggregateMetricTransform
             // between the vertex and incoming vertices.
@@ -97,8 +115,8 @@
             final IREdge edgeToOriginalDstV =
               new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), edge.getSrc(), v);
             edge.copyExecutionPropertiesTo(edgeToOriginalDstV);
-            edgeToOriginalDstV.setPropertyPermanently(
-              MetricCollectionProperty.of(MetricCollectionProperty.Value.DataSkewRuntimePass));
+            edgeToOriginalDstV.setPropertyPermanently(MetricCollectionProperty.of(metricCollectionId));
+            edgeToABV.setPropertyPermanently(MetricCollectionProperty.of(metricCollectionId));
 
             builder.connectVertices(edgeToMCV);
             builder.connectVertices(edgeToABV);
@@ -111,12 +129,12 @@
           } else {
             builder.connectVertices(edge);
           }
-        });
+        }
       } else { // Others are simply added to the builder, unless it comes from an updated vertex
         builder.addVertex(v);
         dag.getIncomingEdgesOf(v).forEach(builder::connectVertices);
       }
-    });
+    }
     final DAG<IRVertex, IREdge> newDAG = builder.build();
     return newDAG;
   }
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
index 4005731..3c031b0 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
@@ -22,6 +22,7 @@
 import org.apache.nemo.common.test.ArgBuilder;
 import org.apache.nemo.common.test.ExampleTestArgs;
 import org.apache.nemo.common.test.ExampleTestUtil;
+import org.apache.nemo.examples.beam.policy.DataSkewPolicyParallelismFive;
 import org.apache.nemo.examples.beam.policy.DefaultPolicyParallelismFive;
 import org.apache.nemo.examples.beam.policy.TransientResourcePolicyParallelismFive;
 import org.apache.nemo.examples.beam.policy.LargeShufflePolicyParallelismFive;
@@ -86,4 +87,16 @@
         .addOptimizationPolicy(TransientResourcePolicyParallelismFive.class.getCanonicalName())
         .build());
   }
+
+  /**
+   * Testing data skew dynamic optimization.
+   * @throws Exception exception on the way.
+   */
+  @Test (timeout = ExampleTestArgs.TIMEOUT)
+  public void testDataSkew() throws Exception {
+    JobLauncher.main(builder
+      .addJobId(NetworkTraceAnalysisITCase.class.getSimpleName() + "_skew")
+      .addOptimizationPolicy(DataSkewPolicyParallelismFive.class.getCanonicalName())
+      .build());
+  }
 }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
index 58d6745..3623d8e 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
@@ -22,6 +22,8 @@
 import org.apache.nemo.runtime.common.plan.PhysicalPlan;
 import org.apache.nemo.runtime.common.plan.StageEdge;
 
+import java.util.Set;
+
 /**
  * An event for triggering dynamic optimization.
  */
@@ -30,24 +32,27 @@
   private final Object dynOptData;
   private final String taskId;
   private final String executorId;
-  private final StageEdge targetEdge;
+  private final Set<StageEdge> targetEdges;
 
   /**
    * Default constructor.
-   * @param physicalPlan physical plan to be optimized.
-   * @param taskId id of the task which triggered the dynamic optimization.
-   * @param executorId the id of executor which executes {@code taskId}
+   *
+   * @param physicalPlan the physical plan to be optimized.
+   * @param dynOptData   the metric data.
+   * @param taskId       the ID of the task which triggered the dynamic optimization.
+   * @param executorId   the ID of executor which executes {@code taskId}
+   * @param targetEdges  the target edges.
    */
   public DynamicOptimizationEvent(final PhysicalPlan physicalPlan,
                                   final Object dynOptData,
                                   final String taskId,
                                   final String executorId,
-                                  final StageEdge targetEdge) {
+                                  final Set<StageEdge> targetEdges) {
     this.physicalPlan = physicalPlan;
     this.taskId = taskId;
     this.dynOptData = dynOptData;
     this.executorId = executorId;
-    this.targetEdge = targetEdge;
+    this.targetEdges = targetEdges;
   }
 
   /**
@@ -75,7 +80,7 @@
     return this.dynOptData;
   }
 
-  public StageEdge getTargetEdge() {
-    return this.targetEdge;
+  public Set<StageEdge> getTargetEdges() {
+    return this.targetEdges;
   }
 }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
index 1ed7ddf..11472a0 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
@@ -26,6 +26,7 @@
 import org.apache.reef.wake.impl.PubSubEventHandler;
 
 import javax.inject.Inject;
+import java.util.Set;
 
 /**
  * Class for handling event to perform dynamic optimization.
@@ -51,9 +52,9 @@
   public void onNext(final DynamicOptimizationEvent dynamicOptimizationEvent) {
     final PhysicalPlan physicalPlan = dynamicOptimizationEvent.getPhysicalPlan();
     final Object dynOptData = dynamicOptimizationEvent.getDynOptData();
-    final StageEdge targetEdge = dynamicOptimizationEvent.getTargetEdge();
+    final Set<StageEdge> targetEdges = dynamicOptimizationEvent.getTargetEdges();
 
-    final PhysicalPlan newPlan = RunTimeOptimizer.dynamicOptimization(physicalPlan, dynOptData, targetEdge);
+    final PhysicalPlan newPlan = RunTimeOptimizer.dynamicOptimization(physicalPlan, dynOptData, targetEdges);
 
     pubSubEventHandler.onNext(new UpdatePhysicalPlanEvent(newPlan, dynamicOptimizationEvent.getTaskId(),
         dynamicOptimizationEvent.getExecutorId()));
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
index f464e69..c7ea42f 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
@@ -38,18 +38,20 @@
   /**
    * Dynamic optimization method to process the dag with an appropriate pass, decided by the stats.
    *
-   * @param originalPlan original physical execution plan.
+   * @param originalPlan the original physical execution plan.
+   * @param dynOptData   the data metric.
+   * @param targetEdges  the target edges.
    * @return the newly updated optimized physical plan.
    */
   public static synchronized PhysicalPlan dynamicOptimization(
           final PhysicalPlan originalPlan,
           final Object dynOptData,
-          final StageEdge targetEdge) {
+          final Set<StageEdge> targetEdges) {
     // Data for dynamic optimization used in DataSkewRuntimePass
     // is a map of <hash value, partition size>.
     final PhysicalPlan physicalPlan =
       new DataSkewRuntimePass()
-        .apply(originalPlan, Pair.of(targetEdge, (Map<Object, Long>) dynOptData));
+        .apply(originalPlan, Pair.of(targetEdges, (Map<Object, Long>) dynOptData));
     return physicalPlan;
   }
 }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
index 2f6c998..d0a0be2 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
@@ -45,7 +45,7 @@
  * this RuntimePass identifies a number of keys with big partition sizes(skewed key)
  * and evenly redistributes data via overwriting incoming edges of destination tasks.
  */
-public final class DataSkewRuntimePass extends RuntimePass<Pair<StageEdge, Map<Object, Long>>> {
+public final class DataSkewRuntimePass extends RuntimePass<Pair<Set<StageEdge>, Map<Object, Long>>> {
   private static final Logger LOG = LoggerFactory.getLogger(DataSkewRuntimePass.class.getName());
   private static final int DEFAULT_NUM_SKEWED_KEYS = 1;
   /*
@@ -57,7 +57,7 @@
    * The reason why we do not divide the output into a fixed number is that the fixed number can be smaller than
    * the destination task parallelism.
    */
-  public static final int HASH_RANGE_MULTIPLIER = 10;
+  public static final int HASH_RANGE_MULTIPLIER = 5;
 
   private final Set<Class<? extends RuntimeEventHandler>> eventHandlers;
   // Skewed keys denote for top n keys in terms of partition size.
@@ -87,17 +87,20 @@
 
   @Override
   public PhysicalPlan apply(final PhysicalPlan originalPlan,
-                            final Pair<StageEdge, Map<Object, Long>> metricData) {
-    final StageEdge targetEdge = metricData.left();
+                            final Pair<Set<StageEdge>, Map<Object, Long>> metricData) {
+    final Set<StageEdge> targetEdges = metricData.left();
     // Get number of evaluators of the next stage (number of blocks).
-    final Integer dstParallelism = targetEdge.getDst().getPropertyValue(ParallelismProperty.class).
-        orElseThrow(() -> new RuntimeException("No parallelism on a vertex"));
+    final StageEdge firstEdge = targetEdges.stream().findFirst()
+      .orElseThrow(() -> new RuntimeException("Empty target edge set!"));
+    final Integer dstParallelism =  firstEdge
+      .getDst().getPropertyValue(ParallelismProperty.class)
+      .orElseThrow(() -> new RuntimeException("No parallelism on a vertex"));
     if (!PartitionerProperty.Value.DataSkewHashPartitioner
-      .equals(targetEdge.getPropertyValue(PartitionerProperty.class)
+      .equals(firstEdge.getPropertyValue(PartitionerProperty.class)
         .orElseThrow(() -> new RuntimeException("No partitioner property!")))) {
       throw new RuntimeException("Invalid partitioner is assigned to the target edge!");
     }
-    final DataSkewHashPartitioner partitioner = (DataSkewHashPartitioner) Partitioner.getPartitioner(targetEdge);
+    final DataSkewHashPartitioner partitioner = (DataSkewHashPartitioner) Partitioner.getPartitioner(firstEdge);
 
     // Calculate keyRanges.
     final List<KeyRange> keyRanges = calculateKeyRanges(metricData.right(), dstParallelism, partitioner);
@@ -108,10 +111,9 @@
 
     // Overwrite the previously assigned key range in the physical DAG with the new range.
     final DAG<Stage, StageEdge> stageDAG = originalPlan.getStageDAG();
-    for (Stage stage : stageDAG.getVertices()) {
-      List<StageEdge> stageEdges = stageDAG.getOutgoingEdgesOf(stage);
-      for (StageEdge edge : stageEdges) {
-        if (edge.equals(targetEdge)) {
+    for (final Stage stage : stageDAG.getVertices()) {
+      for (final StageEdge edge : stageDAG.getOutgoingEdgesOf(stage)) {
+        if (targetEdges.contains(edge)) {
           edge.setTaskIdxToKeyRange(taskIdxToKeyRange);
         }
       }
@@ -163,12 +165,11 @@
                                            final Integer dstParallelism,
                                            final Partitioner<Integer> partitioner) {
     final Map<Integer, Long> partitionKeyToPartitionCount = new HashMap<>();
-    int lastKey = 0;
+    int lastKey = dstParallelism * HASH_RANGE_MULTIPLIER - 1;
     // Aggregate the counts per each "partition key" assigned by Partitioner.
 
     for (final Map.Entry<Object, Long> entry : keyToCountMap.entrySet()) {
       final int partitionKey = partitioner.partition(entry.getKey());
-      lastKey = Math.max(lastKey, partitionKey);
       partitionKeyToPartitionCount.compute(partitionKey,
         (existPartitionKey, prevCount) -> (prevCount == null) ? entry.getValue() : prevCount + entry.getValue());
     }
@@ -231,6 +232,7 @@
             currentAccumulatedSize - prevAccumulatedSize);
       }
     }
+
     return keyRanges;
   }
 }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
index 910a4b3..2f0d7a1 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
@@ -23,8 +23,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.math.BigInteger;
-
 /**
  * An implementation of {@link Partitioner} which hashes output data from a source task appropriate to detect data skew.
  * It hashes data finer than {@link HashPartitioner}.
@@ -38,7 +36,6 @@
 public final class DataSkewHashPartitioner implements Partitioner<Integer> {
   private static final Logger LOG = LoggerFactory.getLogger(DataSkewHashPartitioner.class.getName());
   private final KeyExtractor keyExtractor;
-  private final BigInteger hashRangeBase;
   private final int hashRange;
 
   /**
@@ -51,10 +48,7 @@
                                  final KeyExtractor keyExtractor) {
     this.keyExtractor = keyExtractor;
     // For this hash range, please check the description of HashRangeMultiplier in JobConf.
-    // For actual hash range to use, we calculate a prime number right next to the desired hash range.
-    this.hashRangeBase = new BigInteger(String.valueOf(dstParallelism * DataSkewRuntimePass.HASH_RANGE_MULTIPLIER));
-    this.hashRange = hashRangeBase.nextProbablePrime().intValue();
-    LOG.info("hashRangeBase {} resulting hashRange {}", hashRangeBase, hashRange);
+    this.hashRange = dstParallelism * DataSkewRuntimePass.HASH_RANGE_MULTIPLIER;
   }
 
   @Override
diff --git a/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java b/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
index 605c5c3..ce4103d 100644
--- a/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
+++ b/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
@@ -21,7 +21,7 @@
 import org.apache.nemo.common.HashRange;
 import org.apache.nemo.common.KeyExtractor;
 import org.apache.nemo.common.KeyRange;
-import org.apache.nemo.runtime.common.partitioner.HashPartitioner;
+import org.apache.nemo.runtime.common.partitioner.DataSkewHashPartitioner;
 import org.apache.nemo.runtime.common.partitioner.Partitioner;
 import org.junit.Before;
 import org.junit.Test;
@@ -39,9 +39,9 @@
   @Before
   public void setUp() {
     // Skewed partition size lists
-    buildPartitionSizeList(Arrays.asList(5L, 5L, 10L, 50L, 100L));
-    buildPartitionSizeList(Arrays.asList(5L, 10L, 5L, 0L, 0L));
-    buildPartitionSizeList(Arrays.asList(10L, 5L, 5L, 0L, 0L));
+    buildPartitionSizeList(Arrays.asList(5L, 5L, 10L, 50L, 110L, 5L, 5L, 10L, 50L, 100L));
+    buildPartitionSizeList(Arrays.asList(5L, 10L, 5L, 0L, 0L, 5L, 10L, 5L, 0L, 0L));
+    buildPartitionSizeList(Arrays.asList(10L, 5L, 5L, 0L, 0L, 10L, 5L, 5L, 0L, 0L));
   }
 
   /**
@@ -50,31 +50,22 @@
    */
   @Test
   public void testDataSkewDynamicOptimizationPass() {
-    final Integer taskNum = 5;
+    final Integer taskNum = 2;
     final KeyExtractor asIsExtractor = new AsIsKeyExtractor();
-    final Partitioner partitioner = new HashPartitioner(taskNum, asIsExtractor);
+    final Partitioner partitioner = new DataSkewHashPartitioner(taskNum, asIsExtractor);
 
     final List<KeyRange> keyRanges =
-        new DataSkewRuntimePass(2).calculateKeyRanges(testMetricData, taskNum, partitioner);
+        new DataSkewRuntimePass(1).calculateKeyRanges(testMetricData, taskNum, partitioner);
 
     // Test whether it correctly redistributed hash ranges.
     assertEquals(0, keyRanges.get(0).rangeBeginInclusive());
-    assertEquals(2, keyRanges.get(0).rangeEndExclusive());
-    assertEquals(2, keyRanges.get(1).rangeBeginInclusive());
-    assertEquals(3, keyRanges.get(1).rangeEndExclusive());
-    assertEquals(3, keyRanges.get(2).rangeBeginInclusive());
-    assertEquals(4, keyRanges.get(2).rangeEndExclusive());
-    assertEquals(4, keyRanges.get(3).rangeBeginInclusive());
-    assertEquals(5, keyRanges.get(3).rangeEndExclusive());
-    assertEquals(5, keyRanges.get(4).rangeBeginInclusive());
-    assertEquals(5, keyRanges.get(4).rangeEndExclusive());
+    assertEquals(5, keyRanges.get(0).rangeEndExclusive());
+    assertEquals(5, keyRanges.get(1).rangeBeginInclusive());
+    assertEquals(10, keyRanges.get(1).rangeEndExclusive());
 
     // Test whether it caught the provided skewness.
-    assertEquals(false, ((HashRange)keyRanges.get(0)).isSkewed());
+    assertEquals(true, ((HashRange)keyRanges.get(0)).isSkewed());
     assertEquals(false, ((HashRange)keyRanges.get(1)).isSkewed());
-    assertEquals(true, ((HashRange)keyRanges.get(2)).isSkewed());
-    assertEquals(true, ((HashRange)keyRanges.get(3)).isSkewed());
-    assertEquals(false, ((HashRange)keyRanges.get(4)).isSkewed());
   }
 
   /**
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index bb76bc5..5605658 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -422,42 +422,52 @@
   }
 
   /**
-   * Get the target edge of dynamic optimization.
-   * The edge is annotated with {@link MetricCollectionProperty}, which is an outgoing edge of
-   * a parent of the stage put on hold.
+   * Get the target edges of dynamic optimization.
+   * The edges are annotated with {@link MetricCollectionProperty}, which are outgoing edges of
+   * parents of the stage put on hold.
    *
    * See {@link org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
-   * for setting the target edge of dynamic optimization.
+   * for setting the target edges of dynamic optimization.
    *
    * @param taskId the task ID that sent stage-level aggregated metric for dynamic optimization.
-   * @return the edge to optimize.
+   * @return the edges to optimize.
    */
-  private StageEdge getEdgeToOptimize(final String taskId) {
+  private Set<StageEdge> getEdgesToOptimize(final String taskId) {
+    final DAG<Stage, StageEdge> stageDag = planStateManager.getPhysicalPlan().getStageDAG();
+
     // Get a stage including the given task
-    final Stage stagePutOnHold = planStateManager.getPhysicalPlan().getStageDAG().getVertices().stream()
+    final Stage stagePutOnHold = stageDag.getVertices().stream()
       .filter(stage -> stage.getId().equals(RuntimeIdManager.getStageIdFromTaskId(taskId)))
       .findFirst()
       .orElseThrow(() -> new RuntimeException());
 
     // Stage put on hold, i.e. stage with vertex containing AggregateMetricTransform
     // should have a parent stage whose outgoing edges contain the target edge of dynamic optimization.
-    final List<Stage> parentStages = planStateManager.getPhysicalPlan().getStageDAG()
-      .getParents(stagePutOnHold.getId());
+    final List<StageEdge> edgesToStagePutOnHold = stageDag.getIncomingEdgesOf(stagePutOnHold);
+    if (edgesToStagePutOnHold.isEmpty()) {
+      throw new RuntimeException("No edges toward specified put on hold stage");
+    }
+    final int metricCollectionId = edgesToStagePutOnHold.get(0).getPropertyValue(MetricCollectionProperty.class)
+      .orElseThrow(() -> new RuntimeException("No metric collection property value for this put on hold stage"));
 
-    if (parentStages.size() > 1) {
-      throw new RuntimeException("Error in setting target edge of dynamic optimization!");
+    final Set<StageEdge> targetEdges = new HashSet<>();
+
+    // Get edges with identical MetricCollectionProperty (except the put on hold stage)
+    for (final Stage stage : stageDag.getVertices()) {
+      final Set<StageEdge> targetEdgesFound = stageDag.getOutgoingEdgesOf(stage).stream()
+        .filter(candidateEdge -> {
+          final Optional<Integer> candidateMCId =
+            candidateEdge.getPropertyValue(MetricCollectionProperty.class);
+          return candidateMCId.isPresent() && candidateMCId.get().equals(metricCollectionId)
+            && !edgesToStagePutOnHold.contains(candidateEdge);
+        })
+        .collect(Collectors.toSet());
+      targetEdges.addAll(targetEdgesFound);
     }
 
-    // Get outgoing edges of that stage with MetricCollectionProperty
-    final List<StageEdge> stageEdges = planStateManager.getPhysicalPlan().getStageDAG()
-      .getOutgoingEdgesOf(parentStages.get(0));
-    for (StageEdge edge : stageEdges) {
-      if (edge.getExecutionProperties().containsKey(MetricCollectionProperty.class)) {
-        return edge;
-      }
-    }
-
-    return null;
+    LOG.info("Target edges to optimize: {}",
+      targetEdges.stream().map(edge -> edge.getId()).collect(Collectors.toSet()));
+    return targetEdges;
   }
 
   /**
@@ -478,8 +488,8 @@
     final boolean stageComplete =
       planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE);
 
-    final StageEdge targetEdge = getEdgeToOptimize(taskId);
-    if (targetEdge == null) {
+    final Set<StageEdge> targetEdges = getEdgesToOptimize(taskId);
+    if (targetEdges.isEmpty()) {
       throw new RuntimeException("No edges specified for data skew optimization");
     }
 
@@ -489,7 +499,7 @@
         .findFirst().orElseThrow(() -> new RuntimeException("DataSkewDynOptDataHandler is not registered!"));
       pubSubEventHandlerWrapper.getPubSubEventHandler()
         .onNext(new DynamicOptimizationEvent(planStateManager.getPhysicalPlan(), dynOptDataHandler.getDynOptData(),
-          taskId, executorId, targetEdge));
+          taskId, executorId, targetEdges));
     }
   }