[#134][FOLLOWUP] improvement(spark2): Use taskId and attemptNo as taskAttemptId (#1544)
### What changes were proposed in this pull request?
Use map index and task attempt number as the task attempt id in Spark2.
### Why are the changes needed?
This aligns Spark2 taskAttemptId of the blockId with Spark3.
See #1529
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing integration tests.
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index b70ab69..fa7aa05 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -23,6 +23,7 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
@@ -50,6 +51,63 @@
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;
+ /**
+ * Provides a task attempt id that is unique for a shuffle stage.
+ *
+ * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
+ * that is unique across the entire Spark app which can reach very large numbers, which can
+ * practically reach LONG.MAX_VALUE. That would overflow the bits in the block id.
+ *
+ * <p>Here we use the map index or task id, appended by the attempt number per task. The map index
+ * is limited by the number of partitions of a stage. The attempt number per task is limited /
+ * configured by spark.task.maxFailures (default: 4).
+ *
+ * @return a task attempt id unique for a shuffle stage
+ */
+ @VisibleForTesting
+ protected static long getTaskAttemptId(
+ int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
+ // attempt number is zero based: 0, 1, …, maxFailures-1
+ // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
+ int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+ // with speculative execution enabled we could observe +1 attempts
+ if (speculation) {
+ maxAttemptNo++;
+ }
+
+ if (attemptNo > maxAttemptNo) {
+ // this should never happen, if it does, our assumptions are wrong,
+ // and we risk overflowing the attempt number bits
+ throw new RssException(
+ "Observing attempt number "
+ + attemptNo
+ + " while maxFailures is set to "
+ + maxFailures
+ + (speculation ? " with speculation enabled" : "")
+ + ".");
+ }
+
+ int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
+ int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+ if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
+ throw new RssException(
+ "Observing mapIndex["
+ + mapIndex
+ + "] that would produce a taskAttemptId with "
+ + (mapIndexBits + attemptBits)
+ + " bits which is larger than the allowed "
+ + maxTaskAttemptIdBits
+ + " bits (maxFailures["
+ + maxFailures
+ + "], speculation["
+ + speculation
+ + "]). Please consider providing more bits for taskAttemptIds.");
+ }
+
+ return (long) mapIndex << attemptBits | attemptNo;
+ }
+
@Override
public void unregisterAllMapOutput(int shuffleId) throws SparkException {
if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
index 15cc7fa..3d8ea05 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
@@ -17,12 +17,16 @@
package org.apache.uniffle.shuffle.manager;
+import java.util.Arrays;
+
import org.apache.spark.SparkConf;
import org.junit.jupiter.api.Test;
import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.exception.RssException;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RssShuffleManagerBaseTest {
@@ -39,4 +43,254 @@
assertEquals(remoteStorageInfo.getConfItems().size(), 1);
assertEquals(remoteStorageInfo.getConfItems().get("fs.defaultFs"), "hdfs://rbf-xxx/foo");
}
+
+ private long bits(String string) {
+ return Long.parseLong(string.replaceAll("[|]", ""), 2);
+ }
+
+ @Test
+ public void testGetTaskAttemptIdWithoutSpeculation() {
+ // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+ // separate map index from attempt number, so merely for visualization purposes
+
+ // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+ for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+ assertEquals(
+ bits("0000|"),
+ RssShuffleManagerBase.getTaskAttemptId(0, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ bits("0001|"),
+ RssShuffleManagerBase.getTaskAttemptId(1, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ bits("0010|"),
+ RssShuffleManagerBase.getTaskAttemptId(2, 0, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ }
+
+ // maxFailures of 2
+ assertEquals(bits("000|0"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, false, 10));
+ assertEquals(bits("000|1"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, false, 10));
+ assertEquals(bits("001|0"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, false, 10));
+ assertEquals(bits("001|1"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, false, 10));
+ assertEquals(bits("010|0"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, false, 10));
+ assertEquals(bits("010|1"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, false, 10));
+ assertEquals(bits("011|0"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, false, 10));
+ assertEquals(bits("011|1"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, false, 10));
+
+ // maxFailures of 3
+ assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, false, 10));
+ assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, false, 10));
+ assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, false, 10));
+ assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, false, 10));
+ assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, false, 10));
+ assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, false, 10));
+ assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, false, 10));
+ assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, false, 10));
+ assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, false, 10));
+ assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, false, 10));
+ assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, false, 10));
+ assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, false, 10));
+
+ // maxFailures of 4
+ assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, false, 10));
+ assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 4, false, 10));
+ assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 4, false, 10));
+ assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 4, false, 10));
+ assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 4, false, 10));
+ assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 4, false, 10));
+ assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 4, false, 10));
+ assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 4, false, 10));
+ assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 4, false, 10));
+ assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 4, false, 10));
+ assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 4, false, 10));
+ assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 4, false, 10));
+ assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 4, false, 10));
+ assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 4, false, 10));
+ assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 4, false, 10));
+ assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 4, false, 10));
+
+ // maxFailures of 5
+ assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 5, false, 10));
+ assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 5, false, 10));
+
+ // test with ints that overflow into signed int and long
+ assertEquals(
+ Integer.MAX_VALUE,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1 | 1,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 2 | 3,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 3 | 7,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));
+
+ // test with attemptNo >= maxFailures
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, -1, false, 10));
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, 0, false, 10));
+ for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ Exception e =
+ assertThrowsExactly(
+ RssException.class,
+ () ->
+ RssShuffleManagerBase.getTaskAttemptId(
+ 0, maxFailures + 128, maxFailures, false, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ "Observing attempt number "
+ + (maxFailures + 128)
+ + " while maxFailures is set to "
+ + maxFailures
+ + ".",
+ e.getMessage());
+ }
+
+ // test with mapIndex that would require more than maxTaskAttemptBits
+ Exception e =
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 3, true, 10));
+ assertEquals(
+ "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+ + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
+ + "Please consider providing more bits for taskAttemptIds.",
+ e.getMessage());
+ // check that a lower mapIndex works as expected
+ assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 3, true, 10));
+ }
+
+ @Test
+ public void testGetTaskAttemptIdWithSpeculation() {
+ // with speculation, we expect maxFailures+1 attempts
+
+ // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+ // separate map index from attempt number, so merely for visualization purposes
+
+ // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+ for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+ for (int attemptNo : Arrays.asList(0, 1)) {
+ assertEquals(
+ bits("0000|" + attemptNo),
+ RssShuffleManagerBase.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ assertEquals(
+ bits("0001|" + attemptNo),
+ RssShuffleManagerBase.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ assertEquals(
+ bits("0010|" + attemptNo),
+ RssShuffleManagerBase.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
+ "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+ }
+ }
+
+ // maxFailures of 2
+ assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, true, 10));
+ assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, true, 10));
+ assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 2, true, 10));
+ assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, true, 10));
+ assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, true, 10));
+ assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 2, true, 10));
+ assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, true, 10));
+ assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, true, 10));
+ assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 2, true, 10));
+ assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, true, 10));
+ assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, true, 10));
+ assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 2, true, 10));
+
+ // maxFailures of 3
+ assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, true, 10));
+ assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, true, 10));
+ assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, true, 10));
+ assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 3, true, 10));
+ assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, true, 10));
+ assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, true, 10));
+ assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, true, 10));
+ assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 3, true, 10));
+ assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, true, 10));
+ assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, true, 10));
+ assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, true, 10));
+ assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 3, true, 10));
+ assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, true, 10));
+ assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, true, 10));
+ assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, true, 10));
+ assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 3, true, 10));
+
+ // maxFailures of 4
+ assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, true, 10));
+ assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 4, true, 10));
+
+ // test with ints that overflow into signed int and long
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 1 | 1,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 2 | 3,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
+ assertEquals(
+ (long) Integer.MAX_VALUE << 3 | 7,
+ RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
+
+ // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, -1, true, 10));
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, 0, true, 10));
+ for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ assertThrowsExactly(
+ RssException.class,
+ () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ Exception e =
+ assertThrowsExactly(
+ RssException.class,
+ () ->
+ RssShuffleManagerBase.getTaskAttemptId(
+ 0, maxFailures + 128, maxFailures, true, 10),
+ String.valueOf(maxFailures));
+ assertEquals(
+ "Observing attempt number "
+ + (maxFailures + 128)
+ + " while maxFailures is set to "
+ + maxFailures
+ + " with speculation enabled.",
+ e.getMessage());
+ }
+
+ // test with mapIndex that would require more than maxTaskAttemptBits
+ Exception e =
+ assertThrowsExactly(
+ RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 4, false, 10));
+ assertEquals(
+ "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+ + "which is larger than the allowed 10 bits (maxFailures[4], speculation[false]). "
+ + "Please consider providing more bits for taskAttemptIds.",
+ e.getMessage());
+ // check that a lower mapIndex works as expected
+ assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 4, false, 10));
+ }
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index ed3f340..4483cdd 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -73,6 +73,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
@@ -106,6 +107,8 @@
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
+ private final int maxFailures;
+ private final boolean speculation;
private final String user;
private final String uuid;
private DataPusher dataPusher;
@@ -140,6 +143,8 @@
"Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false.");
}
this.sparkConf = sparkConf;
+ this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+ this.speculation = sparkConf.getBoolean("spark.speculation", false);
this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
// set & check replica config
@@ -461,11 +466,18 @@
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
}
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
+ long taskAttemptId =
+ getTaskAttemptId(
+ context.partitionId(),
+ context.attemptNumber(),
+ maxFailures,
+ speculation,
+ Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
return new RssShuffleWriter<>(
rssHandle.getAppId(),
shuffleId,
taskId,
- context.taskAttemptId(),
+ taskAttemptId,
writeMetrics,
this,
sparkConf,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index b7710e7..b223cea 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -529,63 +529,6 @@
shuffleHandleInfo);
}
- /**
- * Provides a task attempt id that is unique for a shuffle stage.
- *
- * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
- * that is unique across the entire Spark app which can reach very large numbers, which can
- * practically reach LONG.MAX_VALUE. That would overflow the bits in the block id.
- *
- * <p>Here we use the map index or task id, appended by the attempt number per task. The map index
- * is limited by the number of partitions of a stage. The attempt number per task is limited /
- * configured by spark.task.maxFailures (default: 4).
- *
- * @return a task attempt id unique for a shuffle stage
- */
- @VisibleForTesting
- protected static long getTaskAttemptId(
- int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
- // attempt number is zero based: 0, 1, …, maxFailures-1
- // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
- int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
-
- // with speculative execution enabled we could observe +1 attempts
- if (speculation) {
- maxAttemptNo++;
- }
-
- if (attemptNo > maxAttemptNo) {
- // this should never happen, if it does, our assumptions are wrong,
- // and we risk overflowing the attempt number bits
- throw new RssException(
- "Observing attempt number "
- + attemptNo
- + " while maxFailures is set to "
- + maxFailures
- + (speculation ? " with speculation enabled" : "")
- + ".");
- }
-
- int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
- int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
- if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
- throw new RssException(
- "Observing mapIndex["
- + mapIndex
- + "] that would produce a taskAttemptId with "
- + (mapIndexBits + attemptBits)
- + " bits which is larger than the allowed "
- + maxTaskAttemptIdBits
- + " bits (maxFailures["
- + maxFailures
- + "], speculation["
- + speculation
- + "]). Please consider providing more bits for taskAttemptIds.");
- }
-
- return (long) mapIndex << attemptBits | attemptNo;
- }
-
public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
// todo: this implement is tricky, we should refactor it
if (id.get() == null) {
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 9150d6d..64bd6f9 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,8 +17,6 @@
package org.apache.spark.shuffle;
-import java.util.Arrays;
-
import org.apache.spark.SparkConf;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.Test;
@@ -26,14 +24,12 @@
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RssShuffleManagerTest extends RssShuffleManagerTestBase {
@@ -88,252 +84,6 @@
}
}
- private long bits(String string) {
- return Long.parseLong(string.replaceAll("[|]", ""), 2);
- }
-
- @Test
- public void testGetTaskAttemptIdWithoutSpeculation() {
- // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
- // separate map index from attempt number, so merely for visualization purposes
-
- // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
- for (int maxFailures : Arrays.asList(-1, 0, 1)) {
- assertEquals(
- bits("0000|"),
- RssShuffleManager.getTaskAttemptId(0, 0, maxFailures, false, 10),
- String.valueOf(maxFailures));
- assertEquals(
- bits("0001|"),
- RssShuffleManager.getTaskAttemptId(1, 0, maxFailures, false, 10),
- String.valueOf(maxFailures));
- assertEquals(
- bits("0010|"),
- RssShuffleManager.getTaskAttemptId(2, 0, maxFailures, false, 10),
- String.valueOf(maxFailures));
- }
-
- // maxFailures of 2
- assertEquals(bits("000|0"), RssShuffleManager.getTaskAttemptId(0, 0, 2, false, 10));
- assertEquals(bits("000|1"), RssShuffleManager.getTaskAttemptId(0, 1, 2, false, 10));
- assertEquals(bits("001|0"), RssShuffleManager.getTaskAttemptId(1, 0, 2, false, 10));
- assertEquals(bits("001|1"), RssShuffleManager.getTaskAttemptId(1, 1, 2, false, 10));
- assertEquals(bits("010|0"), RssShuffleManager.getTaskAttemptId(2, 0, 2, false, 10));
- assertEquals(bits("010|1"), RssShuffleManager.getTaskAttemptId(2, 1, 2, false, 10));
- assertEquals(bits("011|0"), RssShuffleManager.getTaskAttemptId(3, 0, 2, false, 10));
- assertEquals(bits("011|1"), RssShuffleManager.getTaskAttemptId(3, 1, 2, false, 10));
-
- // maxFailures of 3
- assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, false, 10));
- assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, false, 10));
- assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, false, 10));
- assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, false, 10));
- assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, false, 10));
- assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, false, 10));
- assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, false, 10));
- assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, false, 10));
- assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, false, 10));
- assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, false, 10));
- assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, false, 10));
- assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, false, 10));
-
- // maxFailures of 4
- assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 4, false, 10));
- assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 4, false, 10));
- assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 4, false, 10));
- assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 4, false, 10));
- assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 4, false, 10));
- assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 4, false, 10));
- assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 4, false, 10));
- assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 4, false, 10));
- assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 4, false, 10));
- assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 4, false, 10));
- assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 4, false, 10));
- assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 4, false, 10));
- assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 4, false, 10));
- assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 4, false, 10));
- assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 4, false, 10));
- assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 4, false, 10));
-
- // maxFailures of 5
- assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 5, false, 10));
- assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 5, false, 10));
-
- // test with ints that overflow into signed int and long
- assertEquals(
- Integer.MAX_VALUE, RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
- assertEquals(
- (long) Integer.MAX_VALUE << 1 | 1,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
- assertEquals(
- (long) Integer.MAX_VALUE << 2 | 3,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
- assertEquals(
- (long) Integer.MAX_VALUE << 3 | 7,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));
-
- // test with attemptNo >= maxFailures
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, -1, false, 10));
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, 0, false, 10));
- for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
- String.valueOf(maxFailures));
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
- String.valueOf(maxFailures));
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
- String.valueOf(maxFailures));
- Exception e =
- assertThrowsExactly(
- RssException.class,
- () ->
- RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, false, 10),
- String.valueOf(maxFailures));
- assertEquals(
- "Observing attempt number "
- + (maxFailures + 128)
- + " while maxFailures is set to "
- + maxFailures
- + ".",
- e.getMessage());
- }
-
- // test with mapIndex that would require more than maxTaskAttemptBits
- Exception e =
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 3, true, 10));
- assertEquals(
- "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
- + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
- + "Please consider providing more bits for taskAttemptIds.",
- e.getMessage());
- // check that a lower mapIndex works as expected
- assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 3, true, 10));
- }
-
- @Test
- public void testGetTaskAttemptIdWithSpeculation() {
- // with speculation, we expect maxFailures+1 attempts
-
- // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
- // separate map index from attempt number, so merely for visualization purposes
-
- // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
- for (int maxFailures : Arrays.asList(-1, 0, 1)) {
- for (int attemptNo : Arrays.asList(0, 1)) {
- assertEquals(
- bits("0000|" + attemptNo),
- RssShuffleManager.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
- "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
- assertEquals(
- bits("0001|" + attemptNo),
- RssShuffleManager.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
- "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
- assertEquals(
- bits("0010|" + attemptNo),
- RssShuffleManager.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
- "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
- }
- }
-
- // maxFailures of 2
- assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 2, true, 10));
- assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 2, true, 10));
- assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 2, true, 10));
- assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 2, true, 10));
- assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 2, true, 10));
- assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 2, true, 10));
- assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 2, true, 10));
- assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 2, true, 10));
- assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 2, true, 10));
- assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 2, true, 10));
- assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 2, true, 10));
- assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 2, true, 10));
-
- // maxFailures of 3
- assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, true, 10));
- assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, true, 10));
- assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, true, 10));
- assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 3, true, 10));
- assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, true, 10));
- assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, true, 10));
- assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, true, 10));
- assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 3, true, 10));
- assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, true, 10));
- assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, true, 10));
- assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, true, 10));
- assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 3, true, 10));
- assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, true, 10));
- assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, true, 10));
- assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, true, 10));
- assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 3, true, 10));
-
- // maxFailures of 4
- assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 4, true, 10));
- assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 4, true, 10));
-
- // test with ints that overflow into signed int and long
- assertEquals(
- (long) Integer.MAX_VALUE << 1,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
- assertEquals(
- (long) Integer.MAX_VALUE << 1 | 1,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
- assertEquals(
- (long) Integer.MAX_VALUE << 2 | 3,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
- assertEquals(
- (long) Integer.MAX_VALUE << 3 | 7,
- RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
-
- // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, -1, true, 10));
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, 0, true, 10));
- for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
- String.valueOf(maxFailures));
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
- String.valueOf(maxFailures));
- Exception e =
- assertThrowsExactly(
- RssException.class,
- () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, true, 10),
- String.valueOf(maxFailures));
- assertEquals(
- "Observing attempt number "
- + (maxFailures + 128)
- + " while maxFailures is set to "
- + maxFailures
- + " with speculation enabled.",
- e.getMessage());
- }
-
- // test with mapIndex that would require more than maxTaskAttemptBits
- Exception e =
- assertThrowsExactly(
- RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 4, false, 10));
- assertEquals(
- "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
- + "which is larger than the allowed 10 bits (maxFailures[4], speculation[false]). "
- + "Please consider providing more bits for taskAttemptIds.",
- e.getMessage());
- // check that a lower mapIndex works as expected
- assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 4, false, 10));
- }
-
@Test
public void testCreateShuffleManagerServer() {
setupMockedRssShuffleUtils(StatusCode.SUCCESS);