[#134] improvement(spark3): Use taskId and attemptNo as taskAttemptId (#1529)

### What changes were proposed in this pull request?
Use map index and task attempt number as the task attempt id in Spark3.

This requires to rework the bits of the blockId to maximize bit utilization for Spark3:
https://github.com/apache/incubator-uniffle/blob/b924acacb0c555370a593f3a069187cf8b8081d7/common/src/main/java/org/apache/uniffle/common/util/Constants.java#L30-L35

Ideally, the `TASK_ATTEMPT_ID_MAX_LENGTH` is set equal to `PARTITION_ID_MAX_LENGTH` + the number of bits required to store the largest task attempt number. The largest task attempt number is `maxFailures - 1`, or `maxFailures` if speculative execution is enabled (configured via `spark.speculation` and disabled by default). The `maxFailures` is configured via `spark.task.maxFailures` and defaults to 4. So by default, two bits are required to store the largest attempt number and `TASK_ATTEMPT_ID_MAX_LENGTH` should be set to `PARTITION_ID_MAX_LENGTH + 2`.

Example:

- with `PARTITION_ID_MAX_LENGTH = 20`, Uniffle supports 1,048,576 partitions
- requiring `TASK_ATTEMPT_ID_MAX_LENGTH = 22`
- allowing for `ATOMIC_INT_MAX_LENGTH = 21`.

### Why are the changes needed?
The map index (map partition id) is limited to the number of partitions of a shuffle. The task attempt number is limited by the max number of failures configured by `spark.task.maxFailures`, which defaults to 4. This provides us an id that is unique per shuffe while not growing arbitrarily large as `context.taskAttemptId` does.

Fix: #134

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit and integration tests.
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6250891..cb35ce3 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -78,6 +78,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
@@ -112,6 +113,8 @@
   private boolean dynamicConfEnabled = false;
   private final ShuffleDataDistributionType dataDistributionType;
   private final int maxConcurrencyPerPartitionToWrite;
+  private final int maxFailures;
+  private final boolean speculation;
   private String user;
   private String uuid;
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
@@ -182,6 +185,8 @@
     this.dataDistributionType = getDataDistributionType(sparkConf);
     this.maxConcurrencyPerPartitionToWrite =
         RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
     long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
     int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
     this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
@@ -307,6 +312,8 @@
         RssSparkConfig.toRssConf(sparkConf).get(RssClientConf.DATA_DISTRIBUTION_TYPE);
     this.maxConcurrencyPerPartitionToWrite =
         RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
     this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
     this.heartbeatTimeout =
         sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2);
@@ -503,11 +510,18 @@
     }
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
+    long taskAttemptId =
+        getTaskAttemptId(
+            context.partitionId(),
+            context.attemptNumber(),
+            maxFailures,
+            speculation,
+            Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
     return new RssShuffleWriter<>(
         rssHandle.getAppId(),
         shuffleId,
         taskId,
-        context.taskAttemptId(),
+        taskAttemptId,
         writeMetrics,
         this,
         sparkConf,
@@ -518,6 +532,63 @@
         shuffleHandleInfo);
   }
 
+  /**
+   * Provides a task attempt id that is unique for a shuffle stage.
+   *
+   * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
+   * that is unique across the entire Spark app which can reach very large numbers, which can
+   * practically reach LONG.MAX_VALUE. That would overflow the bits in the block id.
+   *
+   * <p>Here we use the map index or task id, appended by the attempt number per task. The map index
+   * is limited by the number of partitions of a stage. The attempt number per task is limited /
+   * configured by spark.task.maxFailures (default: 4).
+   *
+   * @return a task attempt id unique for a shuffle stage
+   */
+  @VisibleForTesting
+  protected static long getTaskAttemptId(
+      int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
+    // attempt number is zero based: 0, 1, …, maxFailures-1
+    // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
+    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+    // with speculative execution enabled we could observe +1 attempts
+    if (speculation) {
+      maxAttemptNo++;
+    }
+
+    if (attemptNo > maxAttemptNo) {
+      // this should never happen, if it does, our assumptions are wrong,
+      // and we risk overflowing the attempt number bits
+      throw new RssException(
+          "Observing attempt number "
+              + attemptNo
+              + " while maxFailures is set to "
+              + maxFailures
+              + (speculation ? " with speculation enabled" : "")
+              + ".");
+    }
+
+    int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
+    int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+    if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
+      throw new RssException(
+          "Observing mapIndex["
+              + mapIndex
+              + "] that would produce a taskAttemptId with "
+              + (mapIndexBits + attemptBits)
+              + " bits which is larger than the allowed "
+              + maxTaskAttemptIdBits
+              + " bits (maxFailures["
+              + maxFailures
+              + "], speculation["
+              + speculation
+              + "]). Please consider providing more bits for taskAttemptIds.");
+    }
+
+    return (long) mapIndex << attemptBits | attemptNo;
+  }
+
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
     // todo: this implement is tricky, we should refactor it
     if (id.get() == null) {
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 64bd6f9..9150d6d 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle;
 
+import java.util.Arrays;
+
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;
 import org.junit.jupiter.api.Test;
@@ -24,12 +26,14 @@
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssShuffleManagerTest extends RssShuffleManagerTestBase {
@@ -84,6 +88,252 @@
     }
   }
 
+  private long bits(String string) {
+    return Long.parseLong(string.replaceAll("[|]", ""), 2);
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithoutSpeculation() {
+    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+    // separate map index from attempt number, so merely for visualization purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      assertEquals(
+          bits("0000|"),
+          RssShuffleManager.getTaskAttemptId(0, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0001|"),
+          RssShuffleManager.getTaskAttemptId(1, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0010|"),
+          RssShuffleManager.getTaskAttemptId(2, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("000|0"), RssShuffleManager.getTaskAttemptId(0, 0, 2, false, 10));
+    assertEquals(bits("000|1"), RssShuffleManager.getTaskAttemptId(0, 1, 2, false, 10));
+    assertEquals(bits("001|0"), RssShuffleManager.getTaskAttemptId(1, 0, 2, false, 10));
+    assertEquals(bits("001|1"), RssShuffleManager.getTaskAttemptId(1, 1, 2, false, 10));
+    assertEquals(bits("010|0"), RssShuffleManager.getTaskAttemptId(2, 0, 2, false, 10));
+    assertEquals(bits("010|1"), RssShuffleManager.getTaskAttemptId(2, 1, 2, false, 10));
+    assertEquals(bits("011|0"), RssShuffleManager.getTaskAttemptId(3, 0, 2, false, 10));
+    assertEquals(bits("011|1"), RssShuffleManager.getTaskAttemptId(3, 1, 2, false, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, false, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, false, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, false, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, false, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, false, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, false, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, false, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, false, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, false, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, false, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, false, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, false, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 4, false, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 4, false, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 4, false, 10));
+    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 4, false, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 4, false, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 4, false, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 4, false, 10));
+    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 4, false, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 4, false, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 4, false, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 4, false, 10));
+    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 4, false, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 4, false, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 4, false, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 4, false, 10));
+    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 4, false, 10));
+
+    // maxFailures of 5
+    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 5, false, 10));
+    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 5, false, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        Integer.MAX_VALUE, RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));
+
+    // test with attemptNo >= maxFailures
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, -1, false, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, 0, false, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () ->
+                  RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, false, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + ".",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 3, true, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+            + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 3, true, 10));
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithSpeculation() {
+    // with speculation, we expect maxFailures+1 attempts
+
+    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+    // separate map index from attempt number, so merely for visualization purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      for (int attemptNo : Arrays.asList(0, 1)) {
+        assertEquals(
+            bits("0000|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0001|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0010|" + attemptNo),
+            RssShuffleManager.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+      }
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 2, true, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 2, true, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 2, true, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 2, true, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 2, true, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 2, true, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 2, true, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 2, true, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 2, true, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 2, true, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 2, true, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 2, true, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, true, 10));
+    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, true, 10));
+    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, true, 10));
+    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 3, true, 10));
+    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, true, 10));
+    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, true, 10));
+    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, true, 10));
+    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 3, true, 10));
+    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, true, 10));
+    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, true, 10));
+    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, true, 10));
+    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 3, true, 10));
+    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, true, 10));
+    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, true, 10));
+    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, true, 10));
+    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 3, true, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 4, true, 10));
+    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 4, true, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
+
+    // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, -1, true, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, 0, true, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, true, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + " with speculation enabled.",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 4, false, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+            + "which is larger than the allowed 10 bits (maxFailures[4], speculation[false]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 4, false, 10));
+  }
+
   @Test
   public void testCreateShuffleManagerServer() {
     setupMockedRssShuffleUtils(StatusCode.SUCCESS);
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
new file mode 100644
index 0000000..e9a0818
--- /dev/null
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/FailingTasksTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.function.MapPartitionsFunction;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+// This test has all tasks fail twice, the third attempt succeeds.
+// The failing attempts all provide zeros to the shuffle step, while the succeeding attempts
+// provide the actual non-zero integers (actually only one zero). If blocks from the failing
+// attempts leak into the read shuffle data, we would see those zeros and fail when comparing
+// to without RSS.
+public class FailingTasksTest extends SparkTaskFailureIntegrationTestBase {
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    shutdownServers();
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
+    dynamicConf.put(
+        RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    createShuffleServer(shuffleServerConf);
+    startServers();
+  }
+
+  @Override
+  Map runTest(SparkSession spark, String fileName) throws Exception {
+    int n = 1000000;
+    return spark.range(0, n, 1, 4)
+        .mapPartitions(
+            (MapPartitionsFunction<Long, Long>)
+                it ->
+                    new Iterator<Long>() {
+                      final TaskContext context = TaskContext.get();
+
+                      @Override
+                      public boolean hasNext() {
+                        // the first two attempts fail in the end
+                        return context.attemptNumber() < 2 || it.hasNext();
+                      }
+
+                      @Override
+                      public Long next() {
+                        if (it.hasNext()) {
+                          Long next = it.next();
+                          // the failing attempt returns only zeros
+                          if (context.attemptNumber() < 2) {
+                            return 0L;
+                          } else {
+                            return next;
+                          }
+                        } else {
+                          throw new RuntimeException("let this task fail");
+                        }
+                      }
+                    },
+            Encoders.LONG())
+        .repartition(3, new Column("value"))
+        .mapPartitions((MapPartitionsFunction<Long, Long>) it -> it, Encoders.LONG())
+        .collectAsList().stream()
+        .collect(Collectors.toMap(v -> v, v -> v));
+  }
+
+  @Test
+  public void testFailedTasks() throws Exception {
+    run();
+  }
+}
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
index 3e73969..e7fed90 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageDynamicServerReWriteTest.java
@@ -39,12 +39,10 @@
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
-public class RSSStageDynamicServerReWriteTest extends SparkIntegrationTestBase {
+public class RSSStageDynamicServerReWriteTest extends SparkTaskFailureIntegrationTestBase {
 
   private static final Logger LOG = LoggerFactory.getLogger(RSSStageDynamicServerReWriteTest.class);
 
-  private static int maxTaskFailures = 3;
-
   @BeforeAll
   public static void setupServers(@TempDir File tmpDir) throws Exception {
     CoordinatorConf coordinatorConf = getCoordinatorConf();
@@ -97,17 +95,10 @@
   }
 
   @Override
-  protected SparkConf createSparkConf() {
-    return new SparkConf()
-        .setAppName(this.getClass().getSimpleName())
-        .setMaster(String.format("local[4,%d]", maxTaskFailures));
-  }
-
-  @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
+    super.updateSparkConfCustomer(sparkConf);
     sparkConf.set(
         RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE, "true");
-    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
   }
 
   @Test
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
index 5e95bc0..497e232 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java
@@ -35,9 +35,7 @@
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
-public class RSSStageResubmitTest extends SparkIntegrationTestBase {
-
-  private static int maxTaskFailures = 3;
+public class RSSStageResubmitTest extends SparkTaskFailureIntegrationTestBase {
 
   @BeforeAll
   public static void setupServers() throws Exception {
@@ -74,17 +72,10 @@
   }
 
   @Override
-  protected SparkConf createSparkConf() {
-    return new SparkConf()
-        .setAppName(this.getClass().getSimpleName())
-        .setMaster(String.format("local[4,%d]", maxTaskFailures));
-  }
-
-  @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
+    super.updateSparkConfCustomer(sparkConf);
     sparkConf.set(
         RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE, "true");
-    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
   }
 
   @Test
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
new file mode 100644
index 0000000..a1e0c37
--- /dev/null
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkTaskFailureIntegrationTestBase.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import org.apache.spark.SparkConf;
+
+public abstract class SparkTaskFailureIntegrationTestBase extends SparkIntegrationTestBase {
+
+  protected static final int maxTaskFailures = 3;
+
+  @Override
+  protected SparkConf createSparkConf() {
+    return new SparkConf()
+        .setAppName(this.getClass().getSimpleName())
+        .setMaster(String.format("local[4,%d]", maxTaskFailures));
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures));
+  }
+}