[NEMO-327] Fix skew handling for multi shuffle edge receiver (#189)
JIRA: [NEMO-327: Fix skew handling for multi shuffle edge receive](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-327)
**Major changes:**
- For the case that a vertex receives multiple shuffle edges, makes `DataSkewPolicy` collect metric data for the shuffle edges in a single metric aggregation vertex and optimize the edges at once.
**Minor changes to note:**
- Makes runtime pass receive multiple target edges.
**Tests for the changes:**
- Adds an integration test with the data skew policy for `NetworkTraceAnalysis` that has a join.
**Other comments:**
- N/A.
Closes #189
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
index 6a6ca4d..30826a8 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
@@ -24,6 +24,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
import org.apache.nemo.common.ir.vertex.*;
import org.apache.nemo.common.exception.IllegalVertexOperationException;
+import org.apache.nemo.common.ir.vertex.transform.AggregateMetricTransform;
import java.io.Serializable;
import java.util.*;
@@ -259,8 +260,9 @@
private void executionPropertyCheck() {
// DataSizeMetricCollection is not compatible with Push (All data have to be stored before the data collection)
vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
- .filter(e -> Optional.of(MetricCollectionProperty.Value.DataSkewRuntimePass)
- .equals(e.getPropertyValue(MetricCollectionProperty.class)))
+ .filter(e -> e.getPropertyValue(MetricCollectionProperty.class).isPresent())
+ .filter(e -> !(e.getDst() instanceof OperatorVertex
+ && ((OperatorVertex) e.getDst()).getTransform() instanceof AggregateMetricTransform))
.filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
.forEach(e -> {
throw new CompileTimeOptimizationException("DAG execution property check: "
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
index c9749a8..fd7dfb6 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MetricCollectionProperty.java
@@ -23,12 +23,12 @@
/**
* MetricCollection ExecutionProperty that indicates the edge of which data metric will be collected.
*/
-public final class MetricCollectionProperty extends EdgeExecutionProperty<MetricCollectionProperty.Value> {
+public final class MetricCollectionProperty extends EdgeExecutionProperty<Integer> {
/**
* Constructor.
* @param value value of the execution property.
*/
- private MetricCollectionProperty(final Value value) {
+ private MetricCollectionProperty(final Integer value) {
super(value);
}
@@ -37,14 +37,7 @@
* @param value value of the new execution property.
* @return the newly created execution property.
*/
- public static MetricCollectionProperty of(final Value value) {
+ public static MetricCollectionProperty of(final Integer value) {
return new MetricCollectionProperty(value);
}
-
- /**
- * Possible values of MetricCollection ExecutionProperty.
- */
- public enum Value {
- DataSkewRuntimePass
- }
}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index da436d9..57a1647 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -31,6 +31,7 @@
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
import org.apache.nemo.common.ir.vertex.transform.AggregateMetricTransform;
import org.apache.nemo.common.ir.vertex.transform.MetricCollectTransform;
import org.apache.nemo.compiler.optimizer.PairKeyExtractor;
@@ -40,9 +41,7 @@
import org.slf4j.LoggerFactory;
import java.io.Serializable;
-import java.util.ArrayList;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
@@ -68,10 +67,12 @@
@Override
public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
+ int mcCount = 0;
+ // destination vertex ID to metric aggregation vertex - ID pair map
+ final Map<String, Pair<OperatorVertex, Integer>> dstVtxIdToABV = new HashMap<>();
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
- final List<OperatorVertex> metricCollectVertices = new ArrayList<>();
- dag.topologicalDo(v -> {
+ for (final IRVertex v : dag.getTopologicalSort()) {
// We care about OperatorVertices that have shuffle incoming edges with main output.
// TODO #210: Data-aware dynamic optimization at run-time
if (v instanceof OperatorVertex && dag.getIncomingEdgesOf(v).stream().anyMatch(irEdge ->
@@ -80,15 +81,32 @@
&& dag.getIncomingEdgesOf(v).stream().noneMatch(irEdge ->
irEdge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())) {
- dag.getIncomingEdgesOf(v).forEach(edge -> {
+ for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
if (CommunicationPatternProperty.Value.Shuffle
- .equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
- final OperatorVertex abv = generateMetricAggregationVertex();
+ .equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
+ final String dstId = edge.getDst().getId();
+
+ // Get or generate a metric collection vertex.
+ final int metricCollectionId;
+ final OperatorVertex abv;
+ if (!dstVtxIdToABV.containsKey(dstId)) {
+ // There is no metric aggregation vertex for this destination vertex.
+ metricCollectionId = mcCount++;
+ abv = generateMetricAggregationVertex();
+ builder.addVertex(abv);
+
+ abv.setPropertyPermanently(ResourceSlotProperty.of(false));
+ dstVtxIdToABV.put(dstId, Pair.of(abv, metricCollectionId));
+ } else {
+ // There is a metric aggregation vertex for this destination vertex already.
+ final Pair<OperatorVertex, Integer> aggrPair = dstVtxIdToABV.get(dstId);
+ metricCollectionId = aggrPair.right();
+ abv = aggrPair.left();
+ }
+
final OperatorVertex mcv = generateMetricCollectVertex(edge);
- metricCollectVertices.add(mcv);
builder.addVertex(v);
builder.addVertex(mcv);
- builder.addVertex(abv);
// We then insert the vertex with MetricCollectTransform and vertex with AggregateMetricTransform
// between the vertex and incoming vertices.
@@ -97,8 +115,8 @@
final IREdge edgeToOriginalDstV =
new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), edge.getSrc(), v);
edge.copyExecutionPropertiesTo(edgeToOriginalDstV);
- edgeToOriginalDstV.setPropertyPermanently(
- MetricCollectionProperty.of(MetricCollectionProperty.Value.DataSkewRuntimePass));
+ edgeToOriginalDstV.setPropertyPermanently(MetricCollectionProperty.of(metricCollectionId));
+ edgeToABV.setPropertyPermanently(MetricCollectionProperty.of(metricCollectionId));
builder.connectVertices(edgeToMCV);
builder.connectVertices(edgeToABV);
@@ -111,12 +129,12 @@
} else {
builder.connectVertices(edge);
}
- });
+ }
} else { // Others are simply added to the builder, unless it comes from an updated vertex
builder.addVertex(v);
dag.getIncomingEdgesOf(v).forEach(builder::connectVertices);
}
- });
+ }
final DAG<IRVertex, IREdge> newDAG = builder.build();
return newDAG;
}
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
index 4005731..3c031b0 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/NetworkTraceAnalysisITCase.java
@@ -22,6 +22,7 @@
import org.apache.nemo.common.test.ArgBuilder;
import org.apache.nemo.common.test.ExampleTestArgs;
import org.apache.nemo.common.test.ExampleTestUtil;
+import org.apache.nemo.examples.beam.policy.DataSkewPolicyParallelismFive;
import org.apache.nemo.examples.beam.policy.DefaultPolicyParallelismFive;
import org.apache.nemo.examples.beam.policy.TransientResourcePolicyParallelismFive;
import org.apache.nemo.examples.beam.policy.LargeShufflePolicyParallelismFive;
@@ -86,4 +87,16 @@
.addOptimizationPolicy(TransientResourcePolicyParallelismFive.class.getCanonicalName())
.build());
}
+
+ /**
+ * Testing data skew dynamic optimization.
+ * @throws Exception exception on the way.
+ */
+ @Test (timeout = ExampleTestArgs.TIMEOUT)
+ public void testDataSkew() throws Exception {
+ JobLauncher.main(builder
+ .addJobId(NetworkTraceAnalysisITCase.class.getSimpleName() + "_skew")
+ .addOptimizationPolicy(DataSkewPolicyParallelismFive.class.getCanonicalName())
+ .build());
+ }
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
index 58d6745..3623d8e 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEvent.java
@@ -22,6 +22,8 @@
import org.apache.nemo.runtime.common.plan.PhysicalPlan;
import org.apache.nemo.runtime.common.plan.StageEdge;
+import java.util.Set;
+
/**
* An event for triggering dynamic optimization.
*/
@@ -30,24 +32,27 @@
private final Object dynOptData;
private final String taskId;
private final String executorId;
- private final StageEdge targetEdge;
+ private final Set<StageEdge> targetEdges;
/**
* Default constructor.
- * @param physicalPlan physical plan to be optimized.
- * @param taskId id of the task which triggered the dynamic optimization.
- * @param executorId the id of executor which executes {@code taskId}
+ *
+ * @param physicalPlan the physical plan to be optimized.
+ * @param dynOptData the metric data.
+ * @param taskId the ID of the task which triggered the dynamic optimization.
+ * @param executorId the ID of executor which executes {@code taskId}
+ * @param targetEdges the target edges.
*/
public DynamicOptimizationEvent(final PhysicalPlan physicalPlan,
final Object dynOptData,
final String taskId,
final String executorId,
- final StageEdge targetEdge) {
+ final Set<StageEdge> targetEdges) {
this.physicalPlan = physicalPlan;
this.taskId = taskId;
this.dynOptData = dynOptData;
this.executorId = executorId;
- this.targetEdge = targetEdge;
+ this.targetEdges = targetEdges;
}
/**
@@ -75,7 +80,7 @@
return this.dynOptData;
}
- public StageEdge getTargetEdge() {
- return this.targetEdge;
+ public Set<StageEdge> getTargetEdges() {
+ return this.targetEdges;
}
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
index 1ed7ddf..11472a0 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java
@@ -26,6 +26,7 @@
import org.apache.reef.wake.impl.PubSubEventHandler;
import javax.inject.Inject;
+import java.util.Set;
/**
* Class for handling event to perform dynamic optimization.
@@ -51,9 +52,9 @@
public void onNext(final DynamicOptimizationEvent dynamicOptimizationEvent) {
final PhysicalPlan physicalPlan = dynamicOptimizationEvent.getPhysicalPlan();
final Object dynOptData = dynamicOptimizationEvent.getDynOptData();
- final StageEdge targetEdge = dynamicOptimizationEvent.getTargetEdge();
+ final Set<StageEdge> targetEdges = dynamicOptimizationEvent.getTargetEdges();
- final PhysicalPlan newPlan = RunTimeOptimizer.dynamicOptimization(physicalPlan, dynOptData, targetEdge);
+ final PhysicalPlan newPlan = RunTimeOptimizer.dynamicOptimization(physicalPlan, dynOptData, targetEdges);
pubSubEventHandler.onNext(new UpdatePhysicalPlanEvent(newPlan, dynamicOptimizationEvent.getTaskId(),
dynamicOptimizationEvent.getExecutorId()));
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
index f464e69..c7ea42f 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/RunTimeOptimizer.java
@@ -38,18 +38,20 @@
/**
* Dynamic optimization method to process the dag with an appropriate pass, decided by the stats.
*
- * @param originalPlan original physical execution plan.
+ * @param originalPlan the original physical execution plan.
+ * @param dynOptData the data metric.
+ * @param targetEdges the target edges.
* @return the newly updated optimized physical plan.
*/
public static synchronized PhysicalPlan dynamicOptimization(
final PhysicalPlan originalPlan,
final Object dynOptData,
- final StageEdge targetEdge) {
+ final Set<StageEdge> targetEdges) {
// Data for dynamic optimization used in DataSkewRuntimePass
// is a map of <hash value, partition size>.
final PhysicalPlan physicalPlan =
new DataSkewRuntimePass()
- .apply(originalPlan, Pair.of(targetEdge, (Map<Object, Long>) dynOptData));
+ .apply(originalPlan, Pair.of(targetEdges, (Map<Object, Long>) dynOptData));
return physicalPlan;
}
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
index 2f6c998..d0a0be2 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
@@ -45,7 +45,7 @@
* this RuntimePass identifies a number of keys with big partition sizes(skewed key)
* and evenly redistributes data via overwriting incoming edges of destination tasks.
*/
-public final class DataSkewRuntimePass extends RuntimePass<Pair<StageEdge, Map<Object, Long>>> {
+public final class DataSkewRuntimePass extends RuntimePass<Pair<Set<StageEdge>, Map<Object, Long>>> {
private static final Logger LOG = LoggerFactory.getLogger(DataSkewRuntimePass.class.getName());
private static final int DEFAULT_NUM_SKEWED_KEYS = 1;
/*
@@ -57,7 +57,7 @@
* The reason why we do not divide the output into a fixed number is that the fixed number can be smaller than
* the destination task parallelism.
*/
- public static final int HASH_RANGE_MULTIPLIER = 10;
+ public static final int HASH_RANGE_MULTIPLIER = 5;
private final Set<Class<? extends RuntimeEventHandler>> eventHandlers;
// Skewed keys denote for top n keys in terms of partition size.
@@ -87,17 +87,20 @@
@Override
public PhysicalPlan apply(final PhysicalPlan originalPlan,
- final Pair<StageEdge, Map<Object, Long>> metricData) {
- final StageEdge targetEdge = metricData.left();
+ final Pair<Set<StageEdge>, Map<Object, Long>> metricData) {
+ final Set<StageEdge> targetEdges = metricData.left();
// Get number of evaluators of the next stage (number of blocks).
- final Integer dstParallelism = targetEdge.getDst().getPropertyValue(ParallelismProperty.class).
- orElseThrow(() -> new RuntimeException("No parallelism on a vertex"));
+ final StageEdge firstEdge = targetEdges.stream().findFirst()
+ .orElseThrow(() -> new RuntimeException("Empty target edge set!"));
+ final Integer dstParallelism = firstEdge
+ .getDst().getPropertyValue(ParallelismProperty.class)
+ .orElseThrow(() -> new RuntimeException("No parallelism on a vertex"));
if (!PartitionerProperty.Value.DataSkewHashPartitioner
- .equals(targetEdge.getPropertyValue(PartitionerProperty.class)
+ .equals(firstEdge.getPropertyValue(PartitionerProperty.class)
.orElseThrow(() -> new RuntimeException("No partitioner property!")))) {
throw new RuntimeException("Invalid partitioner is assigned to the target edge!");
}
- final DataSkewHashPartitioner partitioner = (DataSkewHashPartitioner) Partitioner.getPartitioner(targetEdge);
+ final DataSkewHashPartitioner partitioner = (DataSkewHashPartitioner) Partitioner.getPartitioner(firstEdge);
// Calculate keyRanges.
final List<KeyRange> keyRanges = calculateKeyRanges(metricData.right(), dstParallelism, partitioner);
@@ -108,10 +111,9 @@
// Overwrite the previously assigned key range in the physical DAG with the new range.
final DAG<Stage, StageEdge> stageDAG = originalPlan.getStageDAG();
- for (Stage stage : stageDAG.getVertices()) {
- List<StageEdge> stageEdges = stageDAG.getOutgoingEdgesOf(stage);
- for (StageEdge edge : stageEdges) {
- if (edge.equals(targetEdge)) {
+ for (final Stage stage : stageDAG.getVertices()) {
+ for (final StageEdge edge : stageDAG.getOutgoingEdgesOf(stage)) {
+ if (targetEdges.contains(edge)) {
edge.setTaskIdxToKeyRange(taskIdxToKeyRange);
}
}
@@ -163,12 +165,11 @@
final Integer dstParallelism,
final Partitioner<Integer> partitioner) {
final Map<Integer, Long> partitionKeyToPartitionCount = new HashMap<>();
- int lastKey = 0;
+ int lastKey = dstParallelism * HASH_RANGE_MULTIPLIER - 1;
// Aggregate the counts per each "partition key" assigned by Partitioner.
for (final Map.Entry<Object, Long> entry : keyToCountMap.entrySet()) {
final int partitionKey = partitioner.partition(entry.getKey());
- lastKey = Math.max(lastKey, partitionKey);
partitionKeyToPartitionCount.compute(partitionKey,
(existPartitionKey, prevCount) -> (prevCount == null) ? entry.getValue() : prevCount + entry.getValue());
}
@@ -231,6 +232,7 @@
currentAccumulatedSize - prevAccumulatedSize);
}
}
+
return keyRanges;
}
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
index 910a4b3..2f0d7a1 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/partitioner/DataSkewHashPartitioner.java
@@ -23,8 +23,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.math.BigInteger;
-
/**
* An implementation of {@link Partitioner} which hashes output data from a source task appropriate to detect data skew.
* It hashes data finer than {@link HashPartitioner}.
@@ -38,7 +36,6 @@
public final class DataSkewHashPartitioner implements Partitioner<Integer> {
private static final Logger LOG = LoggerFactory.getLogger(DataSkewHashPartitioner.class.getName());
private final KeyExtractor keyExtractor;
- private final BigInteger hashRangeBase;
private final int hashRange;
/**
@@ -51,10 +48,7 @@
final KeyExtractor keyExtractor) {
this.keyExtractor = keyExtractor;
// For this hash range, please check the description of HashRangeMultiplier in JobConf.
- // For actual hash range to use, we calculate a prime number right next to the desired hash range.
- this.hashRangeBase = new BigInteger(String.valueOf(dstParallelism * DataSkewRuntimePass.HASH_RANGE_MULTIPLIER));
- this.hashRange = hashRangeBase.nextProbablePrime().intValue();
- LOG.info("hashRangeBase {} resulting hashRange {}", hashRangeBase, hashRange);
+ this.hashRange = dstParallelism * DataSkewRuntimePass.HASH_RANGE_MULTIPLIER;
}
@Override
diff --git a/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java b/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
index 605c5c3..ce4103d 100644
--- a/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
+++ b/runtime/common/src/test/java/org/apache/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
@@ -21,7 +21,7 @@
import org.apache.nemo.common.HashRange;
import org.apache.nemo.common.KeyExtractor;
import org.apache.nemo.common.KeyRange;
-import org.apache.nemo.runtime.common.partitioner.HashPartitioner;
+import org.apache.nemo.runtime.common.partitioner.DataSkewHashPartitioner;
import org.apache.nemo.runtime.common.partitioner.Partitioner;
import org.junit.Before;
import org.junit.Test;
@@ -39,9 +39,9 @@
@Before
public void setUp() {
// Skewed partition size lists
- buildPartitionSizeList(Arrays.asList(5L, 5L, 10L, 50L, 100L));
- buildPartitionSizeList(Arrays.asList(5L, 10L, 5L, 0L, 0L));
- buildPartitionSizeList(Arrays.asList(10L, 5L, 5L, 0L, 0L));
+ buildPartitionSizeList(Arrays.asList(5L, 5L, 10L, 50L, 110L, 5L, 5L, 10L, 50L, 100L));
+ buildPartitionSizeList(Arrays.asList(5L, 10L, 5L, 0L, 0L, 5L, 10L, 5L, 0L, 0L));
+ buildPartitionSizeList(Arrays.asList(10L, 5L, 5L, 0L, 0L, 10L, 5L, 5L, 0L, 0L));
}
/**
@@ -50,31 +50,22 @@
*/
@Test
public void testDataSkewDynamicOptimizationPass() {
- final Integer taskNum = 5;
+ final Integer taskNum = 2;
final KeyExtractor asIsExtractor = new AsIsKeyExtractor();
- final Partitioner partitioner = new HashPartitioner(taskNum, asIsExtractor);
+ final Partitioner partitioner = new DataSkewHashPartitioner(taskNum, asIsExtractor);
final List<KeyRange> keyRanges =
- new DataSkewRuntimePass(2).calculateKeyRanges(testMetricData, taskNum, partitioner);
+ new DataSkewRuntimePass(1).calculateKeyRanges(testMetricData, taskNum, partitioner);
// Test whether it correctly redistributed hash ranges.
assertEquals(0, keyRanges.get(0).rangeBeginInclusive());
- assertEquals(2, keyRanges.get(0).rangeEndExclusive());
- assertEquals(2, keyRanges.get(1).rangeBeginInclusive());
- assertEquals(3, keyRanges.get(1).rangeEndExclusive());
- assertEquals(3, keyRanges.get(2).rangeBeginInclusive());
- assertEquals(4, keyRanges.get(2).rangeEndExclusive());
- assertEquals(4, keyRanges.get(3).rangeBeginInclusive());
- assertEquals(5, keyRanges.get(3).rangeEndExclusive());
- assertEquals(5, keyRanges.get(4).rangeBeginInclusive());
- assertEquals(5, keyRanges.get(4).rangeEndExclusive());
+ assertEquals(5, keyRanges.get(0).rangeEndExclusive());
+ assertEquals(5, keyRanges.get(1).rangeBeginInclusive());
+ assertEquals(10, keyRanges.get(1).rangeEndExclusive());
// Test whether it caught the provided skewness.
- assertEquals(false, ((HashRange)keyRanges.get(0)).isSkewed());
+ assertEquals(true, ((HashRange)keyRanges.get(0)).isSkewed());
assertEquals(false, ((HashRange)keyRanges.get(1)).isSkewed());
- assertEquals(true, ((HashRange)keyRanges.get(2)).isSkewed());
- assertEquals(true, ((HashRange)keyRanges.get(3)).isSkewed());
- assertEquals(false, ((HashRange)keyRanges.get(4)).isSkewed());
}
/**
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index bb76bc5..5605658 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -422,42 +422,52 @@
}
/**
- * Get the target edge of dynamic optimization.
- * The edge is annotated with {@link MetricCollectionProperty}, which is an outgoing edge of
- * a parent of the stage put on hold.
+ * Get the target edges of dynamic optimization.
+ * The edges are annotated with {@link MetricCollectionProperty}, which are outgoing edges of
+ * parents of the stage put on hold.
*
* See {@link org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
- * for setting the target edge of dynamic optimization.
+ * for setting the target edges of dynamic optimization.
*
* @param taskId the task ID that sent stage-level aggregated metric for dynamic optimization.
- * @return the edge to optimize.
+ * @return the edges to optimize.
*/
- private StageEdge getEdgeToOptimize(final String taskId) {
+ private Set<StageEdge> getEdgesToOptimize(final String taskId) {
+ final DAG<Stage, StageEdge> stageDag = planStateManager.getPhysicalPlan().getStageDAG();
+
// Get a stage including the given task
- final Stage stagePutOnHold = planStateManager.getPhysicalPlan().getStageDAG().getVertices().stream()
+ final Stage stagePutOnHold = stageDag.getVertices().stream()
.filter(stage -> stage.getId().equals(RuntimeIdManager.getStageIdFromTaskId(taskId)))
.findFirst()
.orElseThrow(() -> new RuntimeException());
// Stage put on hold, i.e. stage with vertex containing AggregateMetricTransform
// should have a parent stage whose outgoing edges contain the target edge of dynamic optimization.
- final List<Stage> parentStages = planStateManager.getPhysicalPlan().getStageDAG()
- .getParents(stagePutOnHold.getId());
+ final List<StageEdge> edgesToStagePutOnHold = stageDag.getIncomingEdgesOf(stagePutOnHold);
+ if (edgesToStagePutOnHold.isEmpty()) {
+ throw new RuntimeException("No edges toward specified put on hold stage");
+ }
+ final int metricCollectionId = edgesToStagePutOnHold.get(0).getPropertyValue(MetricCollectionProperty.class)
+ .orElseThrow(() -> new RuntimeException("No metric collection property value for this put on hold stage"));
- if (parentStages.size() > 1) {
- throw new RuntimeException("Error in setting target edge of dynamic optimization!");
+ final Set<StageEdge> targetEdges = new HashSet<>();
+
+ // Get edges with identical MetricCollectionProperty (except the put on hold stage)
+ for (final Stage stage : stageDag.getVertices()) {
+ final Set<StageEdge> targetEdgesFound = stageDag.getOutgoingEdgesOf(stage).stream()
+ .filter(candidateEdge -> {
+ final Optional<Integer> candidateMCId =
+ candidateEdge.getPropertyValue(MetricCollectionProperty.class);
+ return candidateMCId.isPresent() && candidateMCId.get().equals(metricCollectionId)
+ && !edgesToStagePutOnHold.contains(candidateEdge);
+ })
+ .collect(Collectors.toSet());
+ targetEdges.addAll(targetEdgesFound);
}
- // Get outgoing edges of that stage with MetricCollectionProperty
- final List<StageEdge> stageEdges = planStateManager.getPhysicalPlan().getStageDAG()
- .getOutgoingEdgesOf(parentStages.get(0));
- for (StageEdge edge : stageEdges) {
- if (edge.getExecutionProperties().containsKey(MetricCollectionProperty.class)) {
- return edge;
- }
- }
-
- return null;
+ LOG.info("Target edges to optimize: {}",
+ targetEdges.stream().map(edge -> edge.getId()).collect(Collectors.toSet()));
+ return targetEdges;
}
/**
@@ -478,8 +488,8 @@
final boolean stageComplete =
planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE);
- final StageEdge targetEdge = getEdgeToOptimize(taskId);
- if (targetEdge == null) {
+ final Set<StageEdge> targetEdges = getEdgesToOptimize(taskId);
+ if (targetEdges.isEmpty()) {
throw new RuntimeException("No edges specified for data skew optimization");
}
@@ -489,7 +499,7 @@
.findFirst().orElseThrow(() -> new RuntimeException("DataSkewDynOptDataHandler is not registered!"));
pubSubEventHandlerWrapper.getPubSubEventHandler()
.onNext(new DynamicOptimizationEvent(planStateManager.getPhysicalPlan(), dynOptDataHandler.getDynOptData(),
- taskId, executorId, targetEdge));
+ taskId, executorId, targetEdges));
}
}