[#134][FOLLOWUP] improvement(spark2): Use taskId and attemptNo as taskAttemptId (#1544)

### What changes were proposed in this pull request?
Use map index and task attempt number as the task attempt id in Spark2.

### Why are the changes needed?

This aligns the taskAttemptId that Spark2 encodes in the blockId with the one used by Spark3.

See #1529

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing integration tests.
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index b70ab69..fa7aa05 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -23,6 +23,7 @@
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
@@ -50,6 +51,63 @@
   private Method unregisterAllMapOutputMethod;
   private Method registerShuffleMethod;
 
+  /**
+   * Provides a task attempt id that is unique for a shuffle stage.
+   *
+   * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
+   * that is unique across the entire Spark app and can therefore become very large; in principle
+   * it can reach Long.MAX_VALUE. That would overflow the bits in the block id.
+   *
+   * <p>Here we use the map index or task id, followed by the attempt number per task. The map index
+   * is limited by the number of partitions of a stage. The attempt number per task is limited by
+   * spark.task.maxFailures (default: 4), plus one extra attempt when speculation is enabled.
+   *
+   * @return a task attempt id unique for a shuffle stage
+   */
+  @VisibleForTesting
+  protected static long getTaskAttemptId(
+      int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
+    // attempt number is zero based: 0, 1, …, maxFailures-1
+    // maxFailures < 1 is not allowed, but for safety we interpret that as maxFailures == 1
+    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+    // with speculative execution enabled we could observe one additional attempt
+    if (speculation) {
+      maxAttemptNo++;
+    }
+
+    if (attemptNo > maxAttemptNo) {
+      // this should never happen; if it does, our assumptions are wrong
+      // and we risk overflowing the attempt number bits
+      throw new RssException(
+          "Observing attempt number "
+              + attemptNo
+              + " while maxFailures is set to "
+              + maxFailures
+              + (speculation ? " with speculation enabled" : "")
+              + ".");
+    }
+
+    int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
+    int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+    if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
+      throw new RssException(
+          "Observing mapIndex["
+              + mapIndex
+              + "] that would produce a taskAttemptId with "
+              + (mapIndexBits + attemptBits)
+              + " bits which is larger than the allowed "
+              + maxTaskAttemptIdBits
+              + " bits (maxFailures["
+              + maxFailures
+              + "], speculation["
+              + speculation
+              + "]). Please consider providing more bits for taskAttemptIds.");
+    }
+
+    return (long) mapIndex << attemptBits | attemptNo;
+  }
+
   @Override
   public void unregisterAllMapOutput(int shuffleId) throws SparkException {
     if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
index 15cc7fa..3d8ea05 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
@@ -17,12 +17,16 @@
 
 package org.apache.uniffle.shuffle.manager;
 
+import java.util.Arrays;
+
 import org.apache.spark.SparkConf;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.exception.RssException;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssShuffleManagerBaseTest {
@@ -39,4 +43,254 @@
     assertEquals(remoteStorageInfo.getConfItems().size(), 1);
     assertEquals(remoteStorageInfo.getConfItems().get("fs.defaultFs"), "hdfs://rbf-xxx/foo");
   }
+
+  private long bits(String string) {
+    return Long.parseLong(string.replace("|", ""), 2);
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithoutSpeculation() {
+    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+    // separate map index from attempt number, so merely for visualization purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      assertEquals(
+          bits("0000|"),
+          RssShuffleManagerBase.getTaskAttemptId(0, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0001|"),
+          RssShuffleManagerBase.getTaskAttemptId(1, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertEquals(
+          bits("0010|"),
+          RssShuffleManagerBase.getTaskAttemptId(2, 0, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("000|0"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, false, 10));
+    assertEquals(bits("000|1"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, false, 10));
+    assertEquals(bits("001|0"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, false, 10));
+    assertEquals(bits("001|1"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, false, 10));
+    assertEquals(bits("010|0"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, false, 10));
+    assertEquals(bits("010|1"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, false, 10));
+    assertEquals(bits("011|0"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, false, 10));
+    assertEquals(bits("011|1"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, false, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, false, 10));
+    assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, false, 10));
+    assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, false, 10));
+    assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, false, 10));
+    assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, false, 10));
+    assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, false, 10));
+    assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, false, 10));
+    assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, false, 10));
+    assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, false, 10));
+    assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, false, 10));
+    assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, false, 10));
+    assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, false, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, false, 10));
+    assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 4, false, 10));
+    assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 4, false, 10));
+    assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 4, false, 10));
+    assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 4, false, 10));
+    assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 4, false, 10));
+    assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 4, false, 10));
+    assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 4, false, 10));
+    assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 4, false, 10));
+    assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 4, false, 10));
+    assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 4, false, 10));
+    assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 4, false, 10));
+    assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 4, false, 10));
+    assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 4, false, 10));
+    assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 4, false, 10));
+    assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 4, false, 10));
+
+    // maxFailures of 5
+    assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 5, false, 10));
+    assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 5, false, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        Integer.MAX_VALUE,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));
+
+    // test with attemptNo >= maxFailures
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, -1, false, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 1, 0, false, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () ->
+                  RssShuffleManagerBase.getTaskAttemptId(
+                      0, maxFailures + 128, maxFailures, false, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + ".",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits (speculation disabled)
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 3, false, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+            + "which is larger than the allowed 10 bits (maxFailures[3], speculation[false]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 3, false, 10));
+  }
+
+  @Test
+  public void testGetTaskAttemptIdWithSpeculation() {
+    // with speculation, we expect maxFailures+1 attempts
+
+    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
+    // separate map index from attempt number, so merely for visualization purposes
+
+    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
+    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
+      for (int attemptNo : Arrays.asList(0, 1)) {
+        assertEquals(
+            bits("0000|" + attemptNo),
+            RssShuffleManagerBase.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0001|" + attemptNo),
+            RssShuffleManagerBase.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+        assertEquals(
+            bits("0010|" + attemptNo),
+            RssShuffleManagerBase.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
+            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
+      }
+    }
+
+    // maxFailures of 2
+    assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 2, true, 10));
+    assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 2, true, 10));
+    assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 2, true, 10));
+    assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 2, true, 10));
+    assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 2, true, 10));
+    assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 2, true, 10));
+    assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 2, true, 10));
+    assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 2, true, 10));
+    assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 2, true, 10));
+    assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 2, true, 10));
+    assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 2, true, 10));
+    assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 2, true, 10));
+
+    // maxFailures of 3
+    assertEquals(bits("00|00"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 3, true, 10));
+    assertEquals(bits("00|01"), RssShuffleManagerBase.getTaskAttemptId(0, 1, 3, true, 10));
+    assertEquals(bits("00|10"), RssShuffleManagerBase.getTaskAttemptId(0, 2, 3, true, 10));
+    assertEquals(bits("00|11"), RssShuffleManagerBase.getTaskAttemptId(0, 3, 3, true, 10));
+    assertEquals(bits("01|00"), RssShuffleManagerBase.getTaskAttemptId(1, 0, 3, true, 10));
+    assertEquals(bits("01|01"), RssShuffleManagerBase.getTaskAttemptId(1, 1, 3, true, 10));
+    assertEquals(bits("01|10"), RssShuffleManagerBase.getTaskAttemptId(1, 2, 3, true, 10));
+    assertEquals(bits("01|11"), RssShuffleManagerBase.getTaskAttemptId(1, 3, 3, true, 10));
+    assertEquals(bits("10|00"), RssShuffleManagerBase.getTaskAttemptId(2, 0, 3, true, 10));
+    assertEquals(bits("10|01"), RssShuffleManagerBase.getTaskAttemptId(2, 1, 3, true, 10));
+    assertEquals(bits("10|10"), RssShuffleManagerBase.getTaskAttemptId(2, 2, 3, true, 10));
+    assertEquals(bits("10|11"), RssShuffleManagerBase.getTaskAttemptId(2, 3, 3, true, 10));
+    assertEquals(bits("11|00"), RssShuffleManagerBase.getTaskAttemptId(3, 0, 3, true, 10));
+    assertEquals(bits("11|01"), RssShuffleManagerBase.getTaskAttemptId(3, 1, 3, true, 10));
+    assertEquals(bits("11|10"), RssShuffleManagerBase.getTaskAttemptId(3, 2, 3, true, 10));
+    assertEquals(bits("11|11"), RssShuffleManagerBase.getTaskAttemptId(3, 3, 3, true, 10));
+
+    // maxFailures of 4
+    assertEquals(bits("0|000"), RssShuffleManagerBase.getTaskAttemptId(0, 0, 4, true, 10));
+    assertEquals(bits("1|100"), RssShuffleManagerBase.getTaskAttemptId(1, 4, 4, true, 10));
+
+    // test with ints that overflow into signed int and long
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 1 | 1,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 2 | 3,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
+    assertEquals(
+        (long) Integer.MAX_VALUE << 3 | 7,
+        RssShuffleManagerBase.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
+
+    // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, -1, true, 10));
+    assertThrowsExactly(
+        RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(0, 2, 0, true, 10));
+    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      assertThrowsExactly(
+          RssException.class,
+          () -> RssShuffleManagerBase.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
+          String.valueOf(maxFailures));
+      Exception e =
+          assertThrowsExactly(
+              RssException.class,
+              () ->
+                  RssShuffleManagerBase.getTaskAttemptId(
+                      0, maxFailures + 128, maxFailures, true, 10),
+              String.valueOf(maxFailures));
+      assertEquals(
+          "Observing attempt number "
+              + (maxFailures + 128)
+              + " while maxFailures is set to "
+              + maxFailures
+              + " with speculation enabled.",
+          e.getMessage());
+    }
+
+    // test with mapIndex that would require more than maxTaskAttemptBits (speculation enabled)
+    Exception e =
+        assertThrowsExactly(
+            RssException.class, () -> RssShuffleManagerBase.getTaskAttemptId(256, 0, 3, true, 10));
+    assertEquals(
+        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
+            + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
+            + "Please consider providing more bits for taskAttemptIds.",
+        e.getMessage());
+    // check that a lower mapIndex works as expected
+    assertEquals(bits("11111111|00"), RssShuffleManagerBase.getTaskAttemptId(255, 0, 3, true, 10));
+  }
 }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index ed3f340..4483cdd 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -73,6 +73,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
@@ -106,6 +107,8 @@
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
+  private final int maxFailures;
+  private final boolean speculation;
   private final String user;
   private final String uuid;
   private DataPusher dataPusher;
@@ -140,6 +143,8 @@
           "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false.");
     }
     this.sparkConf = sparkConf;
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
     this.user = sparkConf.get("spark.rss.quota.user", "user");
     this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
     // set & check replica config
@@ -461,11 +466,18 @@
                 shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
       }
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
+      long taskAttemptId =
+          getTaskAttemptId(
+              context.partitionId(),
+              context.attemptNumber(),
+              maxFailures,
+              speculation,
+              Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
           shuffleId,
           taskId,
-          context.taskAttemptId(),
+          taskAttemptId,
           writeMetrics,
           this,
           sparkConf,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index b7710e7..b223cea 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -529,63 +529,6 @@
         shuffleHandleInfo);
   }
 
-  /**
-   * Provides a task attempt id that is unique for a shuffle stage.
-   *
-   * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
-   * that is unique across the entire Spark app which can reach very large numbers, which can
-   * practically reach LONG.MAX_VALUE. That would overflow the bits in the block id.
-   *
-   * <p>Here we use the map index or task id, appended by the attempt number per task. The map index
-   * is limited by the number of partitions of a stage. The attempt number per task is limited /
-   * configured by spark.task.maxFailures (default: 4).
-   *
-   * @return a task attempt id unique for a shuffle stage
-   */
-  @VisibleForTesting
-  protected static long getTaskAttemptId(
-      int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
-    // attempt number is zero based: 0, 1, …, maxFailures-1
-    // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
-    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
-
-    // with speculative execution enabled we could observe +1 attempts
-    if (speculation) {
-      maxAttemptNo++;
-    }
-
-    if (attemptNo > maxAttemptNo) {
-      // this should never happen, if it does, our assumptions are wrong,
-      // and we risk overflowing the attempt number bits
-      throw new RssException(
-          "Observing attempt number "
-              + attemptNo
-              + " while maxFailures is set to "
-              + maxFailures
-              + (speculation ? " with speculation enabled" : "")
-              + ".");
-    }
-
-    int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
-    int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
-    if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
-      throw new RssException(
-          "Observing mapIndex["
-              + mapIndex
-              + "] that would produce a taskAttemptId with "
-              + (mapIndexBits + attemptBits)
-              + " bits which is larger than the allowed "
-              + maxTaskAttemptIdBits
-              + " bits (maxFailures["
-              + maxFailures
-              + "], speculation["
-              + speculation
-              + "]). Please consider providing more bits for taskAttemptIds.");
-    }
-
-    return (long) mapIndex << attemptBits | attemptNo;
-  }
-
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
     // todo: this implement is tricky, we should refactor it
     if (id.get() == null) {
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 9150d6d..64bd6f9 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,8 +17,6 @@
 
 package org.apache.spark.shuffle;
 
-import java.util.Arrays;
-
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.internal.SQLConf;
 import org.junit.jupiter.api.Test;
@@ -26,14 +24,12 @@
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class RssShuffleManagerTest extends RssShuffleManagerTestBase {
@@ -88,252 +84,6 @@
     }
   }
 
-  private long bits(String string) {
-    return Long.parseLong(string.replaceAll("[|]", ""), 2);
-  }
-
-  @Test
-  public void testGetTaskAttemptIdWithoutSpeculation() {
-    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
-    // separate map index from attempt number, so merely for visualization purposes
-
-    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
-    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
-      assertEquals(
-          bits("0000|"),
-          RssShuffleManager.getTaskAttemptId(0, 0, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-      assertEquals(
-          bits("0001|"),
-          RssShuffleManager.getTaskAttemptId(1, 0, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-      assertEquals(
-          bits("0010|"),
-          RssShuffleManager.getTaskAttemptId(2, 0, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-    }
-
-    // maxFailures of 2
-    assertEquals(bits("000|0"), RssShuffleManager.getTaskAttemptId(0, 0, 2, false, 10));
-    assertEquals(bits("000|1"), RssShuffleManager.getTaskAttemptId(0, 1, 2, false, 10));
-    assertEquals(bits("001|0"), RssShuffleManager.getTaskAttemptId(1, 0, 2, false, 10));
-    assertEquals(bits("001|1"), RssShuffleManager.getTaskAttemptId(1, 1, 2, false, 10));
-    assertEquals(bits("010|0"), RssShuffleManager.getTaskAttemptId(2, 0, 2, false, 10));
-    assertEquals(bits("010|1"), RssShuffleManager.getTaskAttemptId(2, 1, 2, false, 10));
-    assertEquals(bits("011|0"), RssShuffleManager.getTaskAttemptId(3, 0, 2, false, 10));
-    assertEquals(bits("011|1"), RssShuffleManager.getTaskAttemptId(3, 1, 2, false, 10));
-
-    // maxFailures of 3
-    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, false, 10));
-    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, false, 10));
-    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, false, 10));
-    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, false, 10));
-    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, false, 10));
-    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, false, 10));
-    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, false, 10));
-    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, false, 10));
-    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, false, 10));
-    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, false, 10));
-    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, false, 10));
-    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, false, 10));
-
-    // maxFailures of 4
-    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 4, false, 10));
-    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 4, false, 10));
-    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 4, false, 10));
-    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 4, false, 10));
-    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 4, false, 10));
-    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 4, false, 10));
-    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 4, false, 10));
-    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 4, false, 10));
-    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 4, false, 10));
-    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 4, false, 10));
-    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 4, false, 10));
-    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 4, false, 10));
-    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 4, false, 10));
-    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 4, false, 10));
-    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 4, false, 10));
-    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 4, false, 10));
-
-    // maxFailures of 5
-    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 5, false, 10));
-    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 5, false, 10));
-
-    // test with ints that overflow into signed int and long
-    assertEquals(
-        Integer.MAX_VALUE, RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, false, 31));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 1 | 1,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 2, false, 32));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 2 | 3,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 4, false, 33));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 3 | 7,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 8, false, 34));
-
-    // test with attemptNo >= maxFailures
-    assertThrowsExactly(
-        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, -1, false, 10));
-    assertThrowsExactly(
-        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 1, 0, false, 10));
-    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
-      assertThrowsExactly(
-          RssException.class,
-          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-      assertThrowsExactly(
-          RssException.class,
-          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-      assertThrowsExactly(
-          RssException.class,
-          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, false, 10),
-          String.valueOf(maxFailures));
-      Exception e =
-          assertThrowsExactly(
-              RssException.class,
-              () ->
-                  RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, false, 10),
-              String.valueOf(maxFailures));
-      assertEquals(
-          "Observing attempt number "
-              + (maxFailures + 128)
-              + " while maxFailures is set to "
-              + maxFailures
-              + ".",
-          e.getMessage());
-    }
-
-    // test with mapIndex that would require more than maxTaskAttemptBits
-    Exception e =
-        assertThrowsExactly(
-            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 3, true, 10));
-    assertEquals(
-        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
-            + "which is larger than the allowed 10 bits (maxFailures[3], speculation[true]). "
-            + "Please consider providing more bits for taskAttemptIds.",
-        e.getMessage());
-    // check that a lower mapIndex works as expected
-    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 3, true, 10));
-  }
-
-  @Test
-  public void testGetTaskAttemptIdWithSpeculation() {
-    // with speculation, we expect maxFailures+1 attempts
-
-    // the expected bits("xy|z") represents the expected Long in bit notation where | is used to
-    // separate map index from attempt number, so merely for visualization purposes
-
-    // maxFailures < 1 not allowed, we fall back to maxFailures=1 to be robust
-    for (int maxFailures : Arrays.asList(-1, 0, 1)) {
-      for (int attemptNo : Arrays.asList(0, 1)) {
-        assertEquals(
-            bits("0000|" + attemptNo),
-            RssShuffleManager.getTaskAttemptId(0, attemptNo, maxFailures, true, 10),
-            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
-        assertEquals(
-            bits("0001|" + attemptNo),
-            RssShuffleManager.getTaskAttemptId(1, attemptNo, maxFailures, true, 10),
-            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
-        assertEquals(
-            bits("0010|" + attemptNo),
-            RssShuffleManager.getTaskAttemptId(2, attemptNo, maxFailures, true, 10),
-            "maxFailures=" + maxFailures + ", attemptNo=" + attemptNo);
-      }
-    }
-
-    // maxFailures of 2
-    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 2, true, 10));
-    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 2, true, 10));
-    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 2, true, 10));
-    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 2, true, 10));
-    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 2, true, 10));
-    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 2, true, 10));
-    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 2, true, 10));
-    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 2, true, 10));
-    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 2, true, 10));
-    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 2, true, 10));
-    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 2, true, 10));
-    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 2, true, 10));
-
-    // maxFailures of 3
-    assertEquals(bits("00|00"), RssShuffleManager.getTaskAttemptId(0, 0, 3, true, 10));
-    assertEquals(bits("00|01"), RssShuffleManager.getTaskAttemptId(0, 1, 3, true, 10));
-    assertEquals(bits("00|10"), RssShuffleManager.getTaskAttemptId(0, 2, 3, true, 10));
-    assertEquals(bits("00|11"), RssShuffleManager.getTaskAttemptId(0, 3, 3, true, 10));
-    assertEquals(bits("01|00"), RssShuffleManager.getTaskAttemptId(1, 0, 3, true, 10));
-    assertEquals(bits("01|01"), RssShuffleManager.getTaskAttemptId(1, 1, 3, true, 10));
-    assertEquals(bits("01|10"), RssShuffleManager.getTaskAttemptId(1, 2, 3, true, 10));
-    assertEquals(bits("01|11"), RssShuffleManager.getTaskAttemptId(1, 3, 3, true, 10));
-    assertEquals(bits("10|00"), RssShuffleManager.getTaskAttemptId(2, 0, 3, true, 10));
-    assertEquals(bits("10|01"), RssShuffleManager.getTaskAttemptId(2, 1, 3, true, 10));
-    assertEquals(bits("10|10"), RssShuffleManager.getTaskAttemptId(2, 2, 3, true, 10));
-    assertEquals(bits("10|11"), RssShuffleManager.getTaskAttemptId(2, 3, 3, true, 10));
-    assertEquals(bits("11|00"), RssShuffleManager.getTaskAttemptId(3, 0, 3, true, 10));
-    assertEquals(bits("11|01"), RssShuffleManager.getTaskAttemptId(3, 1, 3, true, 10));
-    assertEquals(bits("11|10"), RssShuffleManager.getTaskAttemptId(3, 2, 3, true, 10));
-    assertEquals(bits("11|11"), RssShuffleManager.getTaskAttemptId(3, 3, 3, true, 10));
-
-    // maxFailures of 4
-    assertEquals(bits("0|000"), RssShuffleManager.getTaskAttemptId(0, 0, 4, true, 10));
-    assertEquals(bits("1|100"), RssShuffleManager.getTaskAttemptId(1, 4, 4, true, 10));
-
-    // test with ints that overflow into signed int and long
-    assertEquals(
-        (long) Integer.MAX_VALUE << 1,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 0, 1, true, 32));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 1 | 1,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 1, 1, true, 32));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 2 | 3,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 3, 3, true, 33));
-    assertEquals(
-        (long) Integer.MAX_VALUE << 3 | 7,
-        RssShuffleManager.getTaskAttemptId(Integer.MAX_VALUE, 7, 7, true, 34));
-
-    // test with attemptNo > maxFailures (attemptNo == maxFailures allowed for speculation enabled)
-    assertThrowsExactly(
-        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, -1, true, 10));
-    assertThrowsExactly(
-        RssException.class, () -> RssShuffleManager.getTaskAttemptId(0, 2, 0, true, 10));
-    for (int maxFailures : Arrays.asList(1, 2, 3, 4, 8, 128)) {
-      assertThrowsExactly(
-          RssException.class,
-          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 1, maxFailures, true, 10),
-          String.valueOf(maxFailures));
-      assertThrowsExactly(
-          RssException.class,
-          () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 2, maxFailures, true, 10),
-          String.valueOf(maxFailures));
-      Exception e =
-          assertThrowsExactly(
-              RssException.class,
-              () -> RssShuffleManager.getTaskAttemptId(0, maxFailures + 128, maxFailures, true, 10),
-              String.valueOf(maxFailures));
-      assertEquals(
-          "Observing attempt number "
-              + (maxFailures + 128)
-              + " while maxFailures is set to "
-              + maxFailures
-              + " with speculation enabled.",
-          e.getMessage());
-    }
-
-    // test with mapIndex that would require more than maxTaskAttemptBits
-    Exception e =
-        assertThrowsExactly(
-            RssException.class, () -> RssShuffleManager.getTaskAttemptId(256, 0, 4, false, 10));
-    assertEquals(
-        "Observing mapIndex[256] that would produce a taskAttemptId with 11 bits "
-            + "which is larger than the allowed 10 bits (maxFailures[4], speculation[false]). "
-            + "Please consider providing more bits for taskAttemptIds.",
-        e.getMessage());
-    // check that a lower mapIndex works as expected
-    assertEquals(bits("11111111|00"), RssShuffleManager.getTaskAttemptId(255, 0, 4, false, 10));
-  }
-
   @Test
   public void testCreateShuffleManagerServer() {
     setupMockedRssShuffleUtils(StatusCode.SUCCESS);