[#1608][part-5] feat(spark3): always use the available assignment (#1652)

### What changes were proposed in this pull request?

1. make the write client always use the latest available assignment for the following writing when the block reassign happens.
2. support multi time retry for partition reassign
3. limit the max reassign server num of one partition
4. refactor the reassign rpc
5. rename the faultyServer -> receivingFailureServer. 

#### Reassign whole process
![image](https://github.com/apache/incubator-uniffle/assets/8609142/8afa5386-be39-4ccb-9c10-95ffb3154939)

#### Always using the latest assignment

To acheive always using the latest assignment, I introduce the `TaskAttemptAssignment` to get the latest assignment for current task. The creating process of AddBlockEvent also will apply the latest assignment by `TaskAttemptAssignment` 

And it will be updated by the `reassignOnBlockSendFailure` rpc. 
That means the original reassign rpc response will be refactored and replaced by the whole latest `shuffleHandleInfo`.

### Why are the changes needed?

This PR is the subtask for #1608.

Leverging the #1615 / #1610 / #1609, we have implemented the reassign servers mechansim when write client encounters the server failure or unhealthy. But this is not good enough that will not share the faulty server state to the unstarted tasks and latter `AddBlockEvent` .

### Does this PR introduce _any_ user-facing change?

Yes. 

### How was this patch tested?

Unit and integration tests.

Integration tests as follows:
1. `PartitionBlockDataReassignBasicTest` to validate the reassign mechanism valid
2. `PartitionBlockDataReassignMultiTimesTest` is to test the partition reassign mechanism of multiple retries.

---------

Co-authored-by: Enrico Minack <github@enrico.minack.dev>
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
index 60c9597..acf6815 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
@@ -22,6 +22,7 @@
 
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -31,14 +32,14 @@
   private String appId;
   private int numMaps;
   private ShuffleDependency<K, V, C> dependency;
-  private Broadcast<ShuffleHandleInfo> handlerInfoBd;
+  private Broadcast<SimpleShuffleHandleInfo> handlerInfoBd;
 
   public RssShuffleHandle(
       int shuffleId,
       String appId,
       int numMaps,
       ShuffleDependency<K, V, C> dependency,
-      Broadcast<ShuffleHandleInfo> handlerInfoBd) {
+      Broadcast<SimpleShuffleHandleInfo> handlerInfoBd) {
     super(shuffleId);
     this.appId = appId;
     this.numMaps = numMaps;
@@ -67,6 +68,6 @@
   }
 
   public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
-    return handlerInfoBd.value().getPartitionToServers();
+    return handlerInfoBd.value().getAvailablePartitionServersForWriter();
   }
 }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index ee1278c..f118c85 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -64,6 +64,19 @@
           .withDescription(
               "The memory spill switch triggered by Spark TaskMemoryManager, default value is false.");
 
+  public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM =
+      ConfigOptions.key("rss.client.reassign.maxReassignServerNum")
+          .intType()
+          .defaultValue(10)
+          .withDescription(
+              "The max reassign server num for one partition when using partition reassign mechanism.");
+
+  public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES =
+      ConfigOptions.key("rss.client.reassign.blockRetryMaxTimes")
+          .intType()
+          .defaultValue(1)
+          .withDescription("The block retry max times when partition reassign is enabled.");
+
   public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
 
   public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index cf49d3e..51384f1 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -35,6 +35,7 @@
 import org.apache.spark.SparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.deploy.SparkHadoopUtil;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.storage.BlockManagerId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -62,8 +63,8 @@
 
   private static final Logger LOG = LoggerFactory.getLogger(RssSparkShuffleUtils.class);
 
-  public static final ClassTag<ShuffleHandleInfo> SHUFFLE_HANDLER_INFO_CLASS_TAG =
-      scala.reflect.ClassTag$.MODULE$.apply(ShuffleHandleInfo.class);
+  public static final ClassTag<SimpleShuffleHandleInfo> DEFAULT_SHUFFLE_HANDLER_INFO_CLASS_TAG =
+      scala.reflect.ClassTag$.MODULE$.apply(SimpleShuffleHandleInfo.class);
   public static final ClassTag<byte[]> BYTE_ARRAY_CLASS_TAG =
       scala.reflect.ClassTag$.MODULE$.apply(byte[].class);
 
@@ -256,7 +257,7 @@
   }
 
   /**
-   * create broadcast variable of {@link ShuffleHandleInfo}
+   * create broadcast variable of {@link SimpleShuffleHandleInfo}
    *
    * @param sc expose for easy unit-test
    * @param shuffleId
@@ -264,14 +265,14 @@
    * @param storageInfo
    * @return Broadcast variable registered for auto cleanup
    */
-  public static Broadcast<ShuffleHandleInfo> broadcastShuffleHdlInfo(
+  public static Broadcast<SimpleShuffleHandleInfo> broadcastShuffleHdlInfo(
       SparkContext sc,
       int shuffleId,
       Map<Integer, List<ShuffleServerInfo>> partitionToServers,
       RemoteStorageInfo storageInfo) {
-    ShuffleHandleInfo handleInfo =
-        new ShuffleHandleInfo(shuffleId, partitionToServers, storageInfo);
-    return sc.broadcast(handleInfo, SHUFFLE_HANDLER_INFO_CLASS_TAG);
+    SimpleShuffleHandleInfo handleInfo =
+        new SimpleShuffleHandleInfo(shuffleId, partitionToServers, storageInfo);
+    return sc.broadcast(handleInfo, DEFAULT_SHUFFLE_HANDLER_INFO_CLASS_TAG);
   }
 
   private static <T> T instantiateFetchFailedException(
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
deleted file mode 100644
index 0e4dd49..0000000
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
+++ /dev/null
@@ -1,250 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-import com.google.common.annotations.VisibleForTesting;
-
-import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
-import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.proto.RssProtos;
-
-/**
- * Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote storage info
- *
- * <p>It's to be broadcast to executors and referenced by shuffle tasks.
- */
-public class ShuffleHandleInfo implements Serializable {
-
-  private int shuffleId;
-  private RemoteStorageInfo remoteStorage;
-
-  /**
-   * partitionId -> replica -> assigned servers.
-   *
-   * <p>The first index of list<ShuffleServerInfo> is the initial static assignment server.
-   *
-   * <p>The remaining indexes are the replacement servers if exists.
-   */
-  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers;
-  // faulty servers replacement mapping
-  private Map<String, Set<ShuffleServerInfo>> faultyServerToReplacements;
-
-  public static final ShuffleHandleInfo EMPTY_HANDLE_INFO =
-      new ShuffleHandleInfo(-1, Collections.EMPTY_MAP, RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
-
-  public ShuffleHandleInfo(
-      int shuffleId,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
-      RemoteStorageInfo storageInfo) {
-    this.shuffleId = shuffleId;
-    this.remoteStorage = storageInfo;
-    this.faultyServerToReplacements = new HashMap<>();
-    this.partitionReplicaAssignedServers = toPartitionReplicaMapping(partitionToServers);
-  }
-
-  public ShuffleHandleInfo() {
-    // ignore
-  }
-
-  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping(
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
-    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
-        new HashMap<>();
-    for (Map.Entry<Integer, List<ShuffleServerInfo>> partitionEntry :
-        partitionToServers.entrySet()) {
-      int partitionId = partitionEntry.getKey();
-      Map<Integer, List<ShuffleServerInfo>> replicaMapping =
-          partitionReplicaAssignedServers.computeIfAbsent(partitionId, x -> new HashMap<>());
-
-      List<ShuffleServerInfo> replicaServers = partitionEntry.getValue();
-      for (int i = 0; i < replicaServers.size(); i++) {
-        int replicaIdx = i;
-        replicaMapping
-            .computeIfAbsent(replicaIdx, x -> new ArrayList<>())
-            .add(replicaServers.get(i));
-      }
-    }
-    return partitionReplicaAssignedServers;
-  }
-
-  /**
-   * This composes the partition's replica servers + replacement servers, this will be used by the
-   * shuffleReader to get the blockIds
-   */
-  public Map<Integer, List<ShuffleServerInfo>> listPartitionAssignedServers() {
-    Map<Integer, List<ShuffleServerInfo>> partitionServers = new HashMap<>();
-    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
-        partitionReplicaAssignedServers.entrySet()) {
-      int partitionId = entry.getKey();
-      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
-      List<ShuffleServerInfo> servers =
-          replicaServers.values().stream().flatMap(x -> x.stream()).collect(Collectors.toList());
-      partitionServers.computeIfAbsent(partitionId, x -> new ArrayList<>()).addAll(servers);
-    }
-    return partitionServers;
-  }
-
-  /** Return all the assigned servers for the writer to commit */
-  public Set<ShuffleServerInfo> listAssignedServers() {
-    return partitionReplicaAssignedServers.values().stream()
-        .flatMap(x -> x.values().stream())
-        .flatMap(x -> x.stream())
-        .collect(Collectors.toSet());
-  }
-
-  public RemoteStorageInfo getRemoteStorage() {
-    return remoteStorage;
-  }
-
-  public int getShuffleId() {
-    return shuffleId;
-  }
-
-  @VisibleForTesting
-  protected boolean isMarkedAsFaultyServer(String serverId) {
-    return faultyServerToReplacements.containsKey(serverId);
-  }
-
-  public Set<ShuffleServerInfo> getExistingReplacements(String faultyServerId) {
-    return faultyServerToReplacements.get(faultyServerId);
-  }
-
-  public void updateReassignment(
-      Set<Integer> partitionIds, String faultyServerId, Set<ShuffleServerInfo> replacements) {
-    if (replacements == null) {
-      return;
-    }
-
-    faultyServerToReplacements.put(faultyServerId, replacements);
-    // todo: optimize the multiple for performance
-    for (Integer partitionId : partitionIds) {
-      Map<Integer, List<ShuffleServerInfo>> replicaServers =
-          partitionReplicaAssignedServers.get(partitionId);
-      for (Map.Entry<Integer, List<ShuffleServerInfo>> serverEntry : replicaServers.entrySet()) {
-        List<ShuffleServerInfo> servers = serverEntry.getValue();
-        if (servers.stream()
-            .map(x -> x.getId())
-            .collect(Collectors.toSet())
-            .contains(faultyServerId)) {
-          Set<ShuffleServerInfo> tempSet = new HashSet<>();
-          tempSet.addAll(replacements);
-          tempSet.removeAll(servers);
-          servers.addAll(tempSet);
-        }
-      }
-    }
-  }
-
-  // partitionId -> replica -> failover servers
-  // always return the last server.
-  @VisibleForTesting
-  public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
-    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
-        partitionReplicaAssignedServers.entrySet()) {
-      int partitionId = entry.getKey();
-      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
-      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
-          replicaServers.entrySet()) {
-        ShuffleServerInfo lastServer =
-            replicaServerEntry.getValue().get(replicaServerEntry.getValue().size() - 1);
-        partitionToServers.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(lastServer);
-      }
-    }
-    return partitionToServers;
-  }
-
-  public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() {
-    PartitionDataReplicaRequirementTracking replicaRequirement =
-        new PartitionDataReplicaRequirementTracking(shuffleId, partitionReplicaAssignedServers);
-    return replicaRequirement;
-  }
-
-  public static RssProtos.ShuffleHandleInfo toProto(ShuffleHandleInfo handleInfo) {
-    Map<Integer, RssProtos.PartitionReplicaServers> partitionToServers = new HashMap<>();
-    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
-        handleInfo.partitionReplicaAssignedServers.entrySet()) {
-      int partitionId = entry.getKey();
-
-      Map<Integer, RssProtos.ReplicaServersItem> replicaServersProto = new HashMap<>();
-      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
-      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
-          replicaServers.entrySet()) {
-        RssProtos.ReplicaServersItem item =
-            RssProtos.ReplicaServersItem.newBuilder()
-                .addAllServerId(ShuffleServerInfo.toProto(replicaServerEntry.getValue()))
-                .build();
-        replicaServersProto.put(replicaServerEntry.getKey(), item);
-      }
-
-      RssProtos.PartitionReplicaServers partitionReplicaServerProto =
-          RssProtos.PartitionReplicaServers.newBuilder()
-              .putAllReplicaServers(replicaServersProto)
-              .build();
-      partitionToServers.put(partitionId, partitionReplicaServerProto);
-    }
-
-    RssProtos.ShuffleHandleInfo handleProto =
-        RssProtos.ShuffleHandleInfo.newBuilder()
-            .setShuffleId(handleInfo.shuffleId)
-            .setRemoteStorageInfo(
-                RssProtos.RemoteStorageInfo.newBuilder()
-                    .setPath(handleInfo.remoteStorage.getPath())
-                    .putAllConfItems(handleInfo.remoteStorage.getConfItems())
-                    .build())
-            .putAllPartitionToServers(partitionToServers)
-            .build();
-    return handleProto;
-  }
-
-  public static ShuffleHandleInfo fromProto(RssProtos.ShuffleHandleInfo handleProto) {
-    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionToServers = new HashMap<>();
-    for (Map.Entry<Integer, RssProtos.PartitionReplicaServers> entry :
-        handleProto.getPartitionToServersMap().entrySet()) {
-      Map<Integer, List<ShuffleServerInfo>> replicaServers =
-          partitionToServers.computeIfAbsent(entry.getKey(), x -> new HashMap<>());
-      for (Map.Entry<Integer, RssProtos.ReplicaServersItem> serverEntry :
-          entry.getValue().getReplicaServersMap().entrySet()) {
-        int replicaIdx = serverEntry.getKey();
-        List<ShuffleServerInfo> shuffleServerInfos =
-            ShuffleServerInfo.fromProto(serverEntry.getValue().getServerIdList());
-        replicaServers.put(replicaIdx, shuffleServerInfos);
-      }
-    }
-    RemoteStorageInfo remoteStorageInfo =
-        new RemoteStorageInfo(
-            handleProto.getRemoteStorageInfo().getPath(),
-            handleProto.getRemoteStorageInfo().getConfItemsMap());
-    ShuffleHandleInfo handle = new ShuffleHandleInfo();
-    handle.shuffleId = handle.getShuffleId();
-    handle.partitionReplicaAssignedServers = partitionToServers;
-    handle.remoteStorage = remoteStorageInfo;
-    return handle;
-  }
-}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfoManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfoManager.java
new file mode 100644
index 0000000..cc3d3b4
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfoManager.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+
+import org.apache.uniffle.common.util.JavaUtils;
+
+public class ShuffleHandleInfoManager implements Closeable {
+  private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo;
+
+  public ShuffleHandleInfoManager() {
+    this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
+  }
+
+  public ShuffleHandleInfo get(int shuffleId) {
+    return shuffleIdToShuffleHandleInfo.get(shuffleId);
+  }
+
+  public void remove(int shuffleId) {
+    shuffleIdToShuffleHandleInfo.remove(shuffleId);
+  }
+
+  public void register(int shuffleId, ShuffleHandleInfo handle) {
+    shuffleIdToShuffleHandleInfo.put(shuffleId, handle);
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (shuffleIdToShuffleHandleInfo == null) {
+      return;
+    }
+    shuffleIdToShuffleHandleInfo.clear();
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
new file mode 100644
index 0000000..4a05239
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.proto.RssProtos;
+
+/** This class holds the dynamic partition assignment for partition reassign mechanism. */
+public class MutableShuffleHandleInfo extends ShuffleHandleInfoBase {
+  private static final Logger LOGGER = LoggerFactory.getLogger(MutableShuffleHandleInfo.class);
+
+  /**
+   * partitionId -> replica -> assigned servers.
+   *
+   * <p>The first index of list<ShuffleServerInfo> is the initial static assignment server.
+   *
+   * <p>The remaining indexes are the replacement servers if exists.
+   */
+  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers;
+
+  private Map<String, Set<ShuffleServerInfo>> excludedServerToReplacements;
+
+  public MutableShuffleHandleInfo(
+      int shuffleId,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      RemoteStorageInfo storageInfo) {
+    super(shuffleId, storageInfo);
+    this.excludedServerToReplacements = new HashMap<>();
+    this.partitionReplicaAssignedServers = toPartitionReplicaMapping(partitionToServers);
+  }
+
+  @VisibleForTesting
+  protected MutableShuffleHandleInfo(
+      int shuffleId,
+      RemoteStorageInfo storageInfo,
+      Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers) {
+    super(shuffleId, storageInfo);
+    this.excludedServerToReplacements = new HashMap<>();
+    this.partitionReplicaAssignedServers = partitionReplicaAssignedServers;
+  }
+
+  public MutableShuffleHandleInfo(int shuffleId, RemoteStorageInfo storageInfo) {
+    super(shuffleId, storageInfo);
+  }
+
+  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping(
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
+        new HashMap<>();
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> partitionEntry :
+        partitionToServers.entrySet()) {
+      int partitionId = partitionEntry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaMapping =
+          partitionReplicaAssignedServers.computeIfAbsent(partitionId, x -> new HashMap<>());
+
+      List<ShuffleServerInfo> replicaServers = partitionEntry.getValue();
+      for (int i = 0; i < replicaServers.size(); i++) {
+        int replicaIdx = i;
+        replicaMapping
+            .computeIfAbsent(replicaIdx, x -> new ArrayList<>())
+            .add(replicaServers.get(i));
+      }
+    }
+    return partitionReplicaAssignedServers;
+  }
+
+  public Set<ShuffleServerInfo> getReplacements(String faultyServerId) {
+    return excludedServerToReplacements.get(faultyServerId);
+  }
+
+  public Set<ShuffleServerInfo> updateAssignment(
+      int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) {
+    if (replacements == null || StringUtils.isEmpty(receivingFailureServerId)) {
+      return Collections.emptySet();
+    }
+    excludedServerToReplacements.put(receivingFailureServerId, replacements);
+
+    Set<ShuffleServerInfo> updatedServers = new HashSet<>();
+    Map<Integer, List<ShuffleServerInfo>> replicaServers =
+        partitionReplicaAssignedServers.get(partitionId);
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> serverEntry : replicaServers.entrySet()) {
+      List<ShuffleServerInfo> servers = serverEntry.getValue();
+      if (servers.stream()
+          .map(x -> x.getId())
+          .collect(Collectors.toSet())
+          .contains(receivingFailureServerId)) {
+        Set<ShuffleServerInfo> tempSet = new HashSet<>();
+        tempSet.addAll(replacements);
+        tempSet.removeAll(servers);
+
+        if (CollectionUtils.isNotEmpty(tempSet)) {
+          updatedServers.addAll(tempSet);
+          servers.addAll(tempSet);
+        }
+      }
+    }
+    return updatedServers;
+  }
+
+  @Override
+  public Set<ShuffleServerInfo> getServers() {
+    return partitionReplicaAssignedServers.values().stream()
+        .flatMap(x -> x.values().stream().flatMap(k -> k.stream()))
+        .collect(Collectors.toSet());
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter() {
+    Map<Integer, List<ShuffleServerInfo>> assignment = new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+        partitionReplicaAssignedServers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
+          replicaServers.entrySet()) {
+        ShuffleServerInfo candidate;
+        int candidateSize = replicaServerEntry.getValue().size();
+        candidate = replicaServerEntry.getValue().get(candidateSize - 1);
+        assignment.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(candidate);
+      }
+    }
+    return assignment;
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> getAllPartitionServersForReader() {
+    Map<Integer, List<ShuffleServerInfo>> assignment = new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+        partitionReplicaAssignedServers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
+          replicaServers.entrySet()) {
+        assignment
+            .computeIfAbsent(partitionId, x -> new ArrayList<>())
+            .addAll(replicaServerEntry.getValue());
+      }
+    }
+    return assignment;
+  }
+
+  @Override
+  public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() {
+    PartitionDataReplicaRequirementTracking replicaRequirement =
+        new PartitionDataReplicaRequirementTracking(shuffleId, partitionReplicaAssignedServers);
+    return replicaRequirement;
+  }
+
+  public Set<String> listExcludedServers() {
+    return excludedServerToReplacements.keySet();
+  }
+
+  public void checkPartitionReassignServerNum(
+      Set<Integer> partitionIds, int legalReassignServerNum) {
+    for (int partitionId : partitionIds) {
+      Map<Integer, List<ShuffleServerInfo>> replicas =
+          partitionReplicaAssignedServers.get(partitionId);
+      for (List<ShuffleServerInfo> servers : replicas.values()) {
+        if (servers.size() - 1 > legalReassignServerNum) {
+          throw new RssException(
+              "Illegal reassignment servers for partitionId: "
+                  + partitionId
+                  + " that exceeding the max legal reassign server num: "
+                  + legalReassignServerNum);
+        }
+      }
+    }
+  }
+
+  public static RssProtos.MutableShuffleHandleInfo toProto(MutableShuffleHandleInfo handleInfo) {
+    synchronized (handleInfo) {
+      Map<Integer, RssProtos.PartitionReplicaServers> partitionToServers = new HashMap<>();
+      for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+          handleInfo.partitionReplicaAssignedServers.entrySet()) {
+        int partitionId = entry.getKey();
+
+        Map<Integer, RssProtos.ReplicaServersItem> replicaServersProto = new HashMap<>();
+        Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+        for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
+            replicaServers.entrySet()) {
+          RssProtos.ReplicaServersItem item =
+              RssProtos.ReplicaServersItem.newBuilder()
+                  .addAllServerId(ShuffleServerInfo.toProto(replicaServerEntry.getValue()))
+                  .build();
+          replicaServersProto.put(replicaServerEntry.getKey(), item);
+        }
+
+        RssProtos.PartitionReplicaServers partitionReplicaServerProto =
+            RssProtos.PartitionReplicaServers.newBuilder()
+                .putAllReplicaServers(replicaServersProto)
+                .build();
+        partitionToServers.put(partitionId, partitionReplicaServerProto);
+      }
+
+      RssProtos.MutableShuffleHandleInfo handleProto =
+          RssProtos.MutableShuffleHandleInfo.newBuilder()
+              .setShuffleId(handleInfo.shuffleId)
+              .setRemoteStorageInfo(
+                  RssProtos.RemoteStorageInfo.newBuilder()
+                      .setPath(handleInfo.remoteStorage.getPath())
+                      .putAllConfItems(handleInfo.remoteStorage.getConfItems())
+                      .build())
+              .putAllPartitionToServers(partitionToServers)
+              .build();
+      return handleProto;
+    }
+  }
+
+  public static MutableShuffleHandleInfo fromProto(RssProtos.MutableShuffleHandleInfo handleProto) {
+    if (handleProto == null) {
+      return null;
+    }
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionToServers = new HashMap<>();
+    for (Map.Entry<Integer, RssProtos.PartitionReplicaServers> entry :
+        handleProto.getPartitionToServersMap().entrySet()) {
+      Map<Integer, List<ShuffleServerInfo>> replicaServers =
+          partitionToServers.computeIfAbsent(entry.getKey(), x -> new HashMap<>());
+      for (Map.Entry<Integer, RssProtos.ReplicaServersItem> serverEntry :
+          entry.getValue().getReplicaServersMap().entrySet()) {
+        int replicaIdx = serverEntry.getKey();
+        List<ShuffleServerInfo> shuffleServerInfos =
+            ShuffleServerInfo.fromProto(serverEntry.getValue().getServerIdList());
+        replicaServers.put(replicaIdx, shuffleServerInfos);
+      }
+    }
+    RemoteStorageInfo remoteStorageInfo =
+        new RemoteStorageInfo(
+            handleProto.getRemoteStorageInfo().getPath(),
+            handleProto.getRemoteStorageInfo().getConfItemsMap());
+    MutableShuffleHandleInfo handle =
+        new MutableShuffleHandleInfo(handleProto.getShuffleId(), remoteStorageInfo);
+    handle.partitionReplicaAssignedServers = partitionToServers;
+    return handle;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
new file mode 100644
index 0000000..99f7a74
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+public interface ShuffleHandleInfo {
+  /** Get all the assigned servers including the excluded servers. */
+  Set<ShuffleServerInfo> getServers();
+
+  /**
+   * Get the assignment of available servers for writer to write partitioned blocks to corresponding
+   * shuffleServers. Implementations might return dynamic, up-to-date information here.
+   */
+  Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter();
+
+  /**
+   * Get all servers ever assigned to writers group by partitionId for reader to get the data
+   * written to these servers
+   */
+  Map<Integer, List<ShuffleServerInfo>> getAllPartitionServersForReader();
+
+  /** Create the partition replicas tracker for the writer to check data replica requirements */
+  PartitionDataReplicaRequirementTracking createPartitionReplicaTracking();
+
+  int getShuffleId();
+
+  RemoteStorageInfo getRemoteStorage();
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfoBase.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfoBase.java
new file mode 100644
index 0000000..f24bd0f
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfoBase.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.io.Serializable;
+
+import org.apache.uniffle.common.RemoteStorageInfo;
+
+public abstract class ShuffleHandleInfoBase implements ShuffleHandleInfo, Serializable {
+  protected int shuffleId;
+  protected RemoteStorageInfo remoteStorage;
+
+  public ShuffleHandleInfoBase(int shuffleId, RemoteStorageInfo remoteStorage) {
+    this.shuffleId = shuffleId;
+    this.remoteStorage = remoteStorage;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public RemoteStorageInfo getRemoteStorage() {
+    return remoteStorage;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
new file mode 100644
index 0000000..60cb6f2
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+/**
+ * Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote storage info
+ *
+ * <p>It's to be broadcast to executors and referenced by shuffle tasks.
+ */
+public class SimpleShuffleHandleInfo extends ShuffleHandleInfoBase implements Serializable {
+  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+
+  public SimpleShuffleHandleInfo(
+      int shuffleId,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      RemoteStorageInfo storageInfo) {
+    super(shuffleId, storageInfo);
+    this.partitionToServers = partitionToServers;
+  }
+
+  @Override
+  public Set<ShuffleServerInfo> getServers() {
+    return partitionToServers.values().stream()
+        .flatMap(x -> x.stream())
+        .collect(Collectors.toSet());
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter() {
+    return partitionToServers;
+  }
+
+  @Override
+  public Map<Integer, List<ShuffleServerInfo>> getAllPartitionServersForReader() {
+    return partitionToServers;
+  }
+
+  @Override
+  public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() {
+    return new PartitionDataReplicaRequirementTracking(partitionToServers, shuffleId);
+  }
+
+  public RemoteStorageInfo getRemoteStorage() {
+    return remoteStorage;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
new file mode 100644
index 0000000..0044ba2
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+
+/** This class is to get the partition assignment for ShuffleWriter. */
+public class TaskAttemptAssignment {
+  private Map<Integer, List<ShuffleServerInfo>> assignment;
+
+  public TaskAttemptAssignment(long taskAttemptId, ShuffleHandleInfo shuffleHandleInfo) {
+    this.update(shuffleHandleInfo);
+  }
+
+  /**
+   * Retrieving the partition's current available shuffleServers.
+   *
+   * @param partitionId
+   * @return
+   */
+  public List<ShuffleServerInfo> retrieve(int partitionId) {
+    return assignment.get(partitionId);
+  }
+
+  public void update(ShuffleHandleInfo handle) {
+    if (handle == null) {
+      throw new RssException("Errors on updating shuffle handle by the empty handleInfo.");
+    }
+    this.assignment = handle.getAvailablePartitionServersForWriter();
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index f5fa497..6717428 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -77,7 +77,6 @@
   private ShuffleWriteMetrics shuffleWriteMetrics;
   // cache partition -> records
   private Map<Integer, WriterBuffer> buffers;
-  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
   private int serializerBufferSize;
   private int bufferSegmentSize;
   private long copyTime = 0;
@@ -98,6 +97,7 @@
   private int memorySpillTimeoutSec;
   private boolean isRowBased;
   private BlockIdLayout blockIdLayout;
+  private Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc;
 
   public WriteBufferManager(
       int shuffleId,
@@ -127,11 +127,11 @@
       long taskAttemptId,
       BufferManagerOptions bufferManagerOptions,
       Serializer serializer,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
       TaskMemoryManager taskMemoryManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssConf rssConf,
-      Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
+      Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
+      Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
@@ -139,7 +139,6 @@
     this.shuffleId = shuffleId;
     this.taskId = taskId;
     this.taskAttemptId = taskAttemptId;
-    this.partitionToServers = partitionToServers;
     this.shuffleWriteMetrics = shuffleWriteMetrics;
     this.serializerBufferSize = bufferManagerOptions.getSerializerBufferSize();
     this.bufferSegmentSize = bufferManagerOptions.getBufferSegmentSize();
@@ -164,6 +163,31 @@
     this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
     this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
     this.blockIdLayout = BlockIdLayout.from(rssConf);
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+  }
+
+  public WriteBufferManager(
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      BufferManagerOptions bufferManagerOptions,
+      Serializer serializer,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssConf rssConf,
+      Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
+    this(
+        shuffleId,
+        taskId,
+        taskAttemptId,
+        bufferManagerOptions,
+        serializer,
+        taskMemoryManager,
+        shuffleWriteMetrics,
+        rssConf,
+        spillFunc,
+        partitionId -> partitionToServers.get(partitionId));
   }
 
   /** add serialized columnar data directly when integrate with gluten */
@@ -353,7 +377,7 @@
         compressed.length,
         crc32,
         compressed,
-        partitionToServers.get(partitionId),
+        partitionAssignmentRetrieveFunc.apply(partitionId),
         uncompressLength,
         wb.getMemoryUsed(),
         taskAttemptId);
@@ -582,4 +606,9 @@
   public void setSendSizeLimit(long sendSizeLimit) {
     this.sendSizeLimit = sendSizeLimit;
   }
+
+  public void setPartitionAssignmentRetrieveFunc(
+      Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+  }
 }
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 51d191a..4f16917 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -17,12 +17,14 @@
 
 package org.apache.uniffle.shuffle.manager;
 
-import java.util.Set;
+import java.util.List;
+import java.util.Map;
 
 import org.apache.spark.SparkException;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 
-import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.ReceivingFailureServer;
 
 /**
  * This is a proxy interface that mainly delegates the un-registration of shuffles to the
@@ -78,6 +80,6 @@
   boolean reassignAllShuffleServersForWholeStage(
       int stageId, int stageAttemptNumber, int shuffleId, int numMaps);
 
-  ShuffleServerInfo reassignFaultyShuffleServerForTasks(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId);
+  MutableShuffleHandleInfo reassignOnBlockSendFailure(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers);
 }
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index b713bd7..11f613f 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -25,13 +25,14 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
-import com.google.common.collect.Sets;
 import io.grpc.stub.StreamObserver;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.proto.RssProtos;
@@ -189,22 +190,18 @@
     RssProtos.PartitionToShuffleServerResponse reply;
     RssProtos.StatusCode code;
     int shuffleId = request.getShuffleId();
-    ShuffleHandleInfo shuffleHandleInfoByShuffleId =
-        shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
-    if (shuffleHandleInfoByShuffleId != null) {
+    MutableShuffleHandleInfo shuffleHandle =
+        (MutableShuffleHandleInfo) shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
+    if (shuffleHandle != null) {
       code = RssProtos.StatusCode.SUCCESS;
       reply =
           RssProtos.PartitionToShuffleServerResponse.newBuilder()
               .setStatus(code)
-              .setShuffleHandleInfo(ShuffleHandleInfo.toProto(shuffleHandleInfoByShuffleId))
+              .setShuffleHandleInfo(MutableShuffleHandleInfo.toProto(shuffleHandle))
               .build();
     } else {
       code = RssProtos.StatusCode.INVALID_REQUEST;
-      reply =
-          RssProtos.PartitionToShuffleServerResponse.newBuilder()
-              .setStatus(code)
-              .setShuffleHandleInfo(ShuffleHandleInfo.toProto(ShuffleHandleInfo.EMPTY_HANDLE_INFO))
-              .build();
+      reply = RssProtos.PartitionToShuffleServerResponse.newBuilder().setStatus(code).build();
     }
     responseObserver.onNext(reply);
     responseObserver.onCompleted();
@@ -232,20 +229,35 @@
   }
 
   @Override
-  public void reassignFaultyShuffleServer(
-      RssProtos.RssReassignFaultyShuffleServerRequest request,
-      StreamObserver<RssProtos.RssReassignFaultyShuffleServerResponse> responseObserver) {
-    ShuffleServerInfo shuffleServerInfo =
-        shuffleManager.reassignFaultyShuffleServerForTasks(
-            request.getShuffleId(),
-            Sets.newHashSet(request.getPartitionIdsList()),
-            request.getFaultyShuffleServerId());
-    RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
-    RssProtos.RssReassignFaultyShuffleServerResponse reply =
-        RssProtos.RssReassignFaultyShuffleServerResponse.newBuilder()
-            .setStatus(code)
-            .setServer(ShuffleServerInfo.convertToShuffleServerId(shuffleServerInfo))
-            .build();
+  public void reassignOnBlockSendFailure(
+      org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureRequest request,
+      io.grpc.stub.StreamObserver<
+              org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureResponse>
+          responseObserver) {
+    RssProtos.StatusCode code = RssProtos.StatusCode.INTERNAL_ERROR;
+    RssProtos.RssReassignOnBlockSendFailureResponse reply;
+    try {
+      MutableShuffleHandleInfo handle =
+          shuffleManager.reassignOnBlockSendFailure(
+              request.getShuffleId(),
+              request.getFailurePartitionToServerIdsMap().entrySet().stream()
+                  .collect(
+                      Collectors.toMap(
+                          Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue()))));
+      code = RssProtos.StatusCode.SUCCESS;
+      reply =
+          RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+              .setStatus(code)
+              .setHandle(MutableShuffleHandleInfo.toProto(handle))
+              .build();
+    } catch (Exception e) {
+      LOG.error("Errors on reassigning when block send failure.", e);
+      reply =
+          RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
+              .setStatus(code)
+              .setMsg(e.getMessage())
+              .build();
+    }
     responseObserver.onNext(reply);
     responseObserver.onCompleted();
   }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java
deleted file mode 100644
index a2fe771..0000000
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import com.google.common.collect.Sets;
-import org.junit.jupiter.api.Test;
-
-import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
-import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleServerInfo;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class ShuffleHandleInfoTest {
-
-  private ShuffleServerInfo createFakeServerInfo(String id) {
-    return new ShuffleServerInfo(id, id, 1);
-  }
-
-  @Test
-  public void testReassignment() {
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
-    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
-    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
-
-    ShuffleHandleInfo handleInfo =
-        new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
-
-    assertFalse(handleInfo.isMarkedAsFaultyServer("a"));
-    Set<Integer> partitions = Sets.newHashSet(1);
-    handleInfo.updateReassignment(partitions, "a", Sets.newHashSet(createFakeServerInfo("d")));
-    assertTrue(handleInfo.isMarkedAsFaultyServer("a"));
-  }
-
-  @Test
-  public void testListAllPartitionAssignmentServers() {
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
-    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
-    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
-
-    ShuffleHandleInfo handleInfo =
-        new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
-
-    // case1
-    Set<Integer> partitions = Sets.newHashSet(2);
-    handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
-
-    Map<Integer, List<ShuffleServerInfo>> partitionAssignment =
-        handleInfo.listPartitionAssignedServers();
-    assertEquals(2, partitionAssignment.size());
-    assertEquals(
-        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
-        partitionAssignment.get(2));
-
-    // case2: reassign multiple times for one partition, it will not append the same replacement
-    // servers
-    handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
-    partitionAssignment = handleInfo.listPartitionAssignedServers();
-    assertEquals(
-        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
-        partitionAssignment.get(2));
-
-    // case3: reassign multiple times for one partition, it will append the non-existing replacement
-    // servers
-    handleInfo.updateReassignment(
-        partitions, "c", Sets.newHashSet(createFakeServerInfo("d"), createFakeServerInfo("e")));
-    partitionAssignment = handleInfo.listPartitionAssignedServers();
-    assertEquals(
-        Arrays.asList(
-            createFakeServerInfo("c"), createFakeServerInfo("d"), createFakeServerInfo("e")),
-        partitionAssignment.get(2));
-  }
-
-  @Test
-  public void testCreatePartitionReplicaTracking() {
-    ShuffleServerInfo a = createFakeServerInfo("a");
-    ShuffleServerInfo b = createFakeServerInfo("b");
-    ShuffleServerInfo c = createFakeServerInfo("c");
-
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
-    partitionToServers.put(1, Arrays.asList(a, b));
-    partitionToServers.put(2, Arrays.asList(c));
-
-    ShuffleHandleInfo handleInfo =
-        new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
-
-    // not any replacements
-    PartitionDataReplicaRequirementTracking tracking = handleInfo.createPartitionReplicaTracking();
-    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = tracking.getInventory();
-    assertEquals(a, inventory.get(1).get(0).get(0));
-    assertEquals(b, inventory.get(1).get(1).get(0));
-    assertEquals(c, inventory.get(2).get(0).get(0));
-  }
-}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
new file mode 100644
index 0000000..e861923
--- /dev/null
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.handle;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class MutableShuffleHandleInfoTest {
+
+  private ShuffleServerInfo createFakeServerInfo(String id) {
+    return new ShuffleServerInfo(id, id, 1);
+  }
+
+  @Test
+  public void testUpdateAssignment() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
+    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
+
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // case1: update the replacement servers but has existing servers
+    Set<ShuffleServerInfo> updated =
+        handleInfo.updateAssignment(
+            1, "a", Sets.newHashSet(createFakeServerInfo("a"), createFakeServerInfo("d")));
+    assertTrue(updated.stream().findFirst().get().getId().equals("d"));
+
+    // case2: update when having multiple servers
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
+        new HashMap<>();
+    List<ShuffleServerInfo> servers =
+        new ArrayList<>(
+            Arrays.asList(
+                createFakeServerInfo("a"),
+                createFakeServerInfo("b"),
+                createFakeServerInfo("c"),
+                createFakeServerInfo("d")));
+    partitionReplicaAssignedServers
+        .computeIfAbsent(1, x -> new HashMap<>())
+        .computeIfAbsent(0, x -> servers);
+    handleInfo =
+        new MutableShuffleHandleInfo(1, new RemoteStorageInfo(""), partitionReplicaAssignedServers);
+    int partitionId = 1;
+    updated =
+        handleInfo.updateAssignment(
+            partitionId,
+            "a",
+            Sets.newHashSet(
+                createFakeServerInfo("b"),
+                createFakeServerInfo("d"),
+                createFakeServerInfo("e"),
+                createFakeServerInfo("f")));
+    assertEquals(updated, Sets.newHashSet(createFakeServerInfo("e"), createFakeServerInfo("f")));
+  }
+
+  @Test
+  public void testListAllPartitionAssignmentServers() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
+    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
+
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // case1
+    int partitionId = 2;
+    handleInfo.updateAssignment(partitionId, "c", Sets.newHashSet(createFakeServerInfo("d")));
+
+    Map<Integer, List<ShuffleServerInfo>> partitionAssignment =
+        handleInfo.getAllPartitionServersForReader();
+    assertEquals(2, partitionAssignment.size());
+    assertEquals(
+        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
+        partitionAssignment.get(2));
+
+    // case2: reassign multiple times for one partition, it will not append the same replacement
+    // servers
+    handleInfo.updateAssignment(partitionId, "c", Sets.newHashSet(createFakeServerInfo("d")));
+    partitionAssignment = handleInfo.getAllPartitionServersForReader();
+    assertEquals(
+        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
+        partitionAssignment.get(2));
+
+    // case3: reassign multiple times for one partition, it will append the non-existing replacement
+    // servers
+    handleInfo.updateAssignment(
+        partitionId, "c", Sets.newHashSet(createFakeServerInfo("d"), createFakeServerInfo("e")));
+    partitionAssignment = handleInfo.getAllPartitionServersForReader();
+    assertEquals(
+        Arrays.asList(
+            createFakeServerInfo("c"), createFakeServerInfo("d"), createFakeServerInfo("e")),
+        partitionAssignment.get(2));
+  }
+
+  @Test
+  public void testCreatePartitionReplicaTracking() {
+    ShuffleServerInfo a = createFakeServerInfo("a");
+    ShuffleServerInfo b = createFakeServerInfo("b");
+    ShuffleServerInfo c = createFakeServerInfo("c");
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(a, b));
+    partitionToServers.put(2, Arrays.asList(c));
+
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // not any replacements
+    PartitionDataReplicaRequirementTracking tracking = handleInfo.createPartitionReplicaTracking();
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = tracking.getInventory();
+    assertEquals(a, inventory.get(1).get(0).get(0));
+    assertEquals(b, inventory.get(1).get(1).get(0));
+    assertEquals(c, inventory.get(2).get(0).get(0));
+  }
+}
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 37d2b8e..e7acaaf 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -18,11 +18,14 @@
 package org.apache.uniffle.shuffle.manager;
 
 import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
-import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfoBase;
 
-import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.ReceivingFailureServer;
 
 import static org.mockito.Mockito.mock;
 
@@ -55,7 +58,7 @@
   }
 
   @Override
-  public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
+  public ShuffleHandleInfoBase getShuffleHandleInfoByShuffleId(int shuffleId) {
     return null;
   }
 
@@ -69,8 +72,8 @@
   }
 
   @Override
-  public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
-    return mock(ShuffleServerInfo.class);
+  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
+    return mock(MutableShuffleHandleInfo.class);
   }
 }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index cfd5ae3..e7bc631 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -41,6 +41,9 @@
 import org.apache.spark.TaskContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.DataPusher;
@@ -62,6 +65,7 @@
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -320,7 +324,7 @@
               + shuffleId
               + "], partitionNum is 0, "
               + "return the empty RssShuffleHandle directly");
-      Broadcast<ShuffleHandleInfo> hdlInfoBd =
+      Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
           RssSparkShuffleUtils.broadcastShuffleHdlInfo(
               RssSparkShuffleUtils.getActiveSparkContext(),
               shuffleId,
@@ -380,11 +384,11 @@
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length);
     if (shuffleManagerRpcServiceEnabled) {
-      ShuffleHandleInfo handleInfo =
-          new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+      MutableShuffleHandleInfo handleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
       shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
     }
-    Broadcast<ShuffleHandleInfo> hdlInfoBd =
+    Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
             RssSparkShuffleUtils.getActiveSparkContext(),
             shuffleId,
@@ -481,7 +485,7 @@
         shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
       } else {
         shuffleHandleInfo =
-            new ShuffleHandleInfo(
+            new SimpleShuffleHandleInfo(
                 shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
       }
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
@@ -551,13 +555,13 @@
         shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
       } else {
         shuffleHandleInfo =
-            new ShuffleHandleInfo(
+            new SimpleShuffleHandleInfo(
                 shuffleId,
                 rssShuffleHandle.getPartitionToServers(),
                 rssShuffleHandle.getRemoteStorage());
       }
       Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-          shuffleHandleInfo.getPartitionToServers();
+          shuffleHandleInfo.getAllPartitionServersForReader();
       Roaring64NavigableMap blockIdBitmap =
           getShuffleResult(
               clientType,
@@ -812,8 +816,8 @@
    * @param shuffleId shuffleId
    * @return ShuffleHandleInfo
    */
-  private synchronized ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
-    ShuffleHandleInfo shuffleHandleInfo;
+  private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+    MutableShuffleHandleInfo shuffleHandleInfo;
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     String driver = rssConf.getString("driver.host", "");
     int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
@@ -824,7 +828,8 @@
         new RssPartitionToShuffleServerRequest(shuffleId);
     RssPartitionToShuffleServerResponse handleInfoResponse =
         shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
-    shuffleHandleInfo = ShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
+    shuffleHandleInfo =
+        MutableShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
     return shuffleHandleInfo;
   }
 
@@ -873,8 +878,8 @@
         LOG.error("Clear MapoutTracker Meta failed!");
         throw new RssException("Clear MapoutTracker Meta failed!", e);
       }
-      ShuffleHandleInfo handleInfo =
-          new ShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
+      MutableShuffleHandleInfo handleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
       shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
       serverAssignedInfos.put(stageIdAndAttempt, true);
       return true;
@@ -887,22 +892,10 @@
     }
   }
 
-  // this is only valid on driver side that exposed to being invoked by grpc server
   @Override
-  public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
-    ShuffleHandleInfo handleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
-    synchronized (handleInfo) {
-      // find out whether this server has been marked faulty in this shuffle
-      // if it has been reassigned, directly return the replacement server.
-      Set<ShuffleServerInfo> replacements =
-          handleInfo.getExistingReplacements(faultyShuffleServerId);
-      if (replacements == null) {
-        replacements = Sets.newHashSet(assignShuffleServer(shuffleId, faultyShuffleServerId));
-      }
-      handleInfo.updateReassignment(partitionIds, faultyShuffleServerId, replacements);
-      return replacements.stream().findFirst().get();
-    }
+  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
+    throw new RssException("Illegal access for reassignOnBlockSendFailure that is not supported.");
   }
 
   private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 5a5d8bd..c38f159 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -55,8 +55,9 @@
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.storage.BlockManagerId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -122,7 +123,7 @@
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
-      ShuffleHandleInfo shuffleHandleInfo,
+      SimpleShuffleHandleInfo shuffleHandleInfo,
       TaskContext context) {
     this(
         appId,
@@ -168,8 +169,8 @@
     this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = shuffleHandleInfo.listAssignedServers();
-    this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
+    this.shuffleServersForData = shuffleHandleInfo.getServers();
+    this.partitionToServers = shuffleHandleInfo.getAvailablePartitionServersForWriter();
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
@@ -211,7 +212,7 @@
             taskAttemptId,
             bufferOptions,
             rssHandle.getDependency().serializer(),
-            shuffleHandleInfo.getPartitionToServers(),
+            shuffleHandleInfo.getAvailablePartitionServersForWriter(),
             context.taskMemoryManager(),
             shuffleWriteMetrics,
             RssSparkConfig.toRssConf(sparkConf),
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 8711c48..f35e655 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -45,7 +45,7 @@
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
@@ -97,7 +97,7 @@
     when(mockHandle.getPartitionToServers()).thenReturn(Maps.newHashMap());
     TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
 
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager =
@@ -286,7 +286,7 @@
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
@@ -398,7 +398,7 @@
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> writer =
         new RssShuffleWriter<>(
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1f61876..e629b23 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -19,9 +19,9 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -41,6 +41,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
 import org.apache.spark.ShuffleDependency;
@@ -51,6 +52,9 @@
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.DataPusher;
@@ -74,6 +78,7 @@
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -83,6 +88,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.rpc.GrpcServer;
+import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
@@ -92,6 +98,7 @@
 import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
 import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
 
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
 import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
 import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
 
@@ -133,11 +140,6 @@
   protected ShuffleWriteClient shuffleWriteClient;
 
   private ShuffleManagerClient shuffleManagerClient;
-  /**
-   * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is dynamically allocated.
-   * ShuffleServer is not obtained from RssShuffleHandle, but from this mapping.
-   */
-  private Map<Integer, ShuffleHandleInfo> shuffleIdToShuffleHandleInfo;
   /** Whether to enable the dynamic shuffleServer function rewrite and reread functions */
   private boolean rssResubmitStage;
 
@@ -152,6 +154,10 @@
    */
   private Map<String, Boolean> serverAssignedInfos;
 
+  private final int partitionReassignMaxServerNum;
+
+  private final ShuffleHandleInfoManager shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
     boolean supportsRelocation =
@@ -246,6 +252,15 @@
             && RssSparkShuffleUtils.isStageResubmitSupported();
     this.taskBlockSendFailureRetryEnabled =
         rssConf.getBoolean(RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED);
+
+    // The feature of partition reassign is exclusive with multiple replicas and stage retry.
+    if (taskBlockSendFailureRetryEnabled) {
+      if (rssResubmitStage || dataReplica > 1) {
+        throw new RssException(
+            "The feature of partition reassign is incompatible with multiple replicas and stage retry.");
+      }
+    }
+
     this.shuffleManagerRpcServiceEnabled = taskBlockSendFailureRetryEnabled || rssResubmitStage;
     if (isDriver) {
       heartBeatScheduledExecutorService =
@@ -280,9 +295,10 @@
             failedTaskIds,
             poolSize,
             keepAliveTime);
-    this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
     this.failuresShuffleServerIds = Sets.newHashSet();
     this.serverAssignedInfos = JavaUtils.newConcurrentMap();
+    this.partitionReassignMaxServerNum =
+        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
   }
 
   public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -370,6 +386,8 @@
     this.heartBeatScheduledExecutorService = null;
     this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
     this.dataPusher = dataPusher;
+    this.partitionReassignMaxServerNum =
+        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
   }
 
   // This method is called in Spark driver side,
@@ -423,7 +441,7 @@
               + shuffleId
               + "], partitionNum is 0, "
               + "return the empty RssShuffleHandle directly");
-      Broadcast<ShuffleHandleInfo> hdlInfoBd =
+      Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
           RssSparkShuffleUtils.broadcastShuffleHdlInfo(
               RssSparkShuffleUtils.getActiveSparkContext(),
               shuffleId,
@@ -479,11 +497,11 @@
     shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions());
     shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length);
     if (shuffleManagerRpcServiceEnabled) {
-      ShuffleHandleInfo handleInfo =
-          new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
-      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+      MutableShuffleHandleInfo handleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+      shuffleHandleInfoManager.register(shuffleId, handleInfo);
     }
-    Broadcast<ShuffleHandleInfo> hdlInfoBd =
+    Broadcast<SimpleShuffleHandleInfo> hdlInfoBd =
         RssSparkShuffleUtils.broadcastShuffleHdlInfo(
             RssSparkShuffleUtils.getActiveSparkContext(),
             shuffleId,
@@ -521,7 +539,7 @@
       shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
     } else {
       shuffleHandleInfo =
-          new ShuffleHandleInfo(
+          new SimpleShuffleHandleInfo(
               shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
     }
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
@@ -663,7 +681,7 @@
       shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
     } else {
       shuffleHandleInfo =
-          new ShuffleHandleInfo(
+          new SimpleShuffleHandleInfo(
               shuffleId,
               rssShuffleHandle.getPartitionToServers(),
               rssShuffleHandle.getRemoteStorage());
@@ -721,13 +739,13 @@
         readMetrics,
         RssSparkConfig.toRssConf(sparkConf),
         dataDistributionType,
-        shuffleHandleInfo.listPartitionAssignedServers());
+        shuffleHandleInfo.getAllPartitionServersForReader());
   }
 
   private Map<ShuffleServerInfo, Set<Integer>> getPartitionDataServers(
       ShuffleHandleInfo shuffleHandleInfo, int startPartition, int endPartition) {
     Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
-        shuffleHandleInfo.listPartitionAssignedServers();
+        shuffleHandleInfo.getAllPartitionServersForReader();
     Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
         allPartitionToServers.entrySet().stream()
             .filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition)
@@ -1106,7 +1124,7 @@
 
   @Override
   public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
-    return shuffleIdToShuffleHandleInfo.get(shuffleId);
+    return shuffleHandleInfoManager.get(shuffleId);
   }
 
   private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
@@ -1122,8 +1140,8 @@
    * @param shuffleId shuffleId
    * @return ShuffleHandleInfo
    */
-  private synchronized ShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
-    ShuffleHandleInfo shuffleHandleInfo;
+  private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
+    MutableShuffleHandleInfo shuffleHandleInfo;
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     String driver = rssConf.getString("driver.host", "");
     int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
@@ -1135,7 +1153,8 @@
     RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
         shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
     shuffleHandleInfo =
-        ShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
+        MutableShuffleHandleInfo.fromProto(
+            rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
     return shuffleHandleInfo;
   }
 
@@ -1185,9 +1204,9 @@
         LOG.error("Clear MapoutTracker Meta failed!");
         throw new RssException("Clear MapoutTracker Meta failed!", e);
       }
-      ShuffleHandleInfo handleInfo =
-          new ShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
-      shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
+      MutableShuffleHandleInfo handleInfo =
+          new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
+      shuffleHandleInfoManager.register(shuffleId, handleInfo);
       serverAssignedInfos.put(stageIdAndAttempt, true);
       return true;
     } else {
@@ -1199,63 +1218,126 @@
     }
   }
 
-  // this is only valid on driver side that exposed to being invoked by grpc server
+  /** this is only valid on driver side that exposed to being invoked by grpc server */
   @Override
-  public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
-    ShuffleHandleInfo handleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
+  public MutableShuffleHandleInfo reassignOnBlockSendFailure(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
+    long startTime = System.currentTimeMillis();
+    MutableShuffleHandleInfo handleInfo =
+        (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId);
     synchronized (handleInfo) {
-      // find out whether this server has been marked faulty in this shuffle
-      // if it has been reassigned, directly return the replacement server.
-      // otherwise, it should request new servers to reassign
-      Set<ShuffleServerInfo> replacements =
-          handleInfo.getExistingReplacements(faultyShuffleServerId);
-      if (replacements == null) {
-        replacements = requestServersForTask(shuffleId, partitionIds, faultyShuffleServerId);
+      // If the reassignment servers for one partition exceeds the max reassign server num,
+      // it should fast fail.
+      handleInfo.checkPartitionReassignServerNum(
+          partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
+
+      Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new HashMap<>();
+      // receivingFailureServer -> partitionId -> replacementServerIds. For logging
+      Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
+
+      for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
+          partitionToFailureServers.entrySet()) {
+        int partitionId = entry.getKey();
+        for (ReceivingFailureServer receivingFailureServer : entry.getValue()) {
+          StatusCode code = receivingFailureServer.getStatusCode();
+          String serverId = receivingFailureServer.getServerId();
+
+          boolean serverHasReplaced = false;
+          Set<ShuffleServerInfo> replacements = handleInfo.getReplacements(serverId);
+          if (CollectionUtils.isEmpty(replacements)) {
+            final int requiredServerNum = 1;
+            Set<String> excludedServers = new HashSet<>(handleInfo.listExcludedServers());
+            excludedServers.add(serverId);
+            replacements =
+                reassignServerForTask(
+                    shuffleId, Sets.newHashSet(partitionId), excludedServers, requiredServerNum);
+          } else {
+            serverHasReplaced = true;
+          }
+
+          Set<ShuffleServerInfo> updatedReassignServers =
+              handleInfo.updateAssignment(partitionId, serverId, replacements);
+
+          reassignResult
+              .computeIfAbsent(serverId, x -> new HashMap<>())
+              .computeIfAbsent(partitionId, x -> new HashSet<>())
+              .addAll(
+                  updatedReassignServers.stream().map(x -> x.getId()).collect(Collectors.toSet()));
+
+          if (serverHasReplaced) {
+            for (ShuffleServerInfo serverInfo : updatedReassignServers) {
+              newServerToPartitions
+                  .computeIfAbsent(serverInfo, x -> new ArrayList<>())
+                  .add(new PartitionRange(partitionId, partitionId));
+            }
+          }
+        }
       }
-      handleInfo.updateReassignment(partitionIds, faultyShuffleServerId, replacements);
+      if (!newServerToPartitions.isEmpty()) {
+        LOG.info(
+            "Register the new partition->servers assignment on reassign. {}",
+            newServerToPartitions);
+        registerShuffleServers(id.get(), shuffleId, newServerToPartitions, getRemoteStorageInfo());
+      }
+
       LOG.info(
-          "Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
-          faultyShuffleServerId,
-          replacements,
-          shuffleId,
-          partitionIds);
-      return replacements.stream().findFirst().get();
+          "Finished reassignOnBlockSendFailure request and cost {}(ms). Reassign result: {}",
+          System.currentTimeMillis() - startTime,
+          reassignResult);
+
+      return handleInfo;
     }
   }
 
-  private Set<ShuffleServerInfo> requestServersForTask(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
-    Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
-    faultyServerIds.addAll(failuresShuffleServerIds);
-    AtomicReference<ShuffleServerInfo> replacementRef = new AtomicReference<>();
+  /**
+   * Creating the shuffleAssignmentInfo from the servers and partitionIds
+   *
+   * @param servers
+   * @param partitionIds
+   * @return
+   */
+  private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(
+      Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
+    Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
+    List<PartitionRange> partitionRanges = new ArrayList<>();
+    for (Integer partitionId : partitionIds) {
+      newPartitionToServers.put(partitionId, new ArrayList<>(servers));
+      partitionRanges.add(new PartitionRange(partitionId, partitionId));
+    }
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+    for (ShuffleServerInfo server : servers) {
+      serverToPartitionRanges.put(server, partitionRanges);
+    }
+    return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
+  }
+
+  /** Request the new shuffle-servers to replace faulty server. */
+  private Set<ShuffleServerInfo> reassignServerForTask(
+      int shuffleId,
+      Set<Integer> partitionIds,
+      Set<String> excludedServers,
+      int requiredServerNum) {
+    AtomicReference<Set<ShuffleServerInfo>> replacementsRef =
+        new AtomicReference<>(new HashSet<>());
     requestShuffleAssignment(
         shuffleId,
+        requiredServerNum,
         1,
+        requiredServerNum,
         1,
-        1,
-        1,
-        faultyServerIds,
+        excludedServers,
         shuffleAssignmentsInfo -> {
           if (shuffleAssignmentsInfo == null) {
             return null;
           }
-          Optional<List<ShuffleServerInfo>> replacementOpt =
-              shuffleAssignmentsInfo.getPartitionToServers().values().stream().findFirst();
-          ShuffleServerInfo replacement = replacementOpt.get().get(0);
-          replacementRef.set(replacement);
-
-          Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
-          List<PartitionRange> partitionRanges = new ArrayList<>();
-          for (Integer partitionId : partitionIds) {
-            newPartitionToServers.put(partitionId, Arrays.asList(replacement));
-            partitionRanges.add(new PartitionRange(partitionId, partitionId));
-          }
-          Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
-          serverToPartitionRanges.put(replacement, partitionRanges);
-          return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
+          Set<ShuffleServerInfo> replacements =
+              shuffleAssignmentsInfo.getPartitionToServers().values().stream()
+                  .flatMap(x -> x.stream())
+                  .collect(Collectors.toSet());
+          replacementsRef.set(replacements);
+          return createShuffleAssignmentsInfo(replacements, partitionIds);
         });
-    return Sets.newHashSet(replacementRef.get());
+    return replacementsRef.get();
   }
 
   private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
@@ -1285,6 +1367,7 @@
                     assignmentShuffleServerNumber,
                     estimateTaskConcurrency,
                     faultyServerIds);
+            LOG.info("Finished reassign");
             if (reassignmentHandler != null) {
               response = reassignmentHandler.apply(response);
             }
@@ -1315,4 +1398,14 @@
   public void setDataPusher(DataPusher dataPusher) {
     this.dataPusher = dataPusher;
   }
+
+  @VisibleForTesting
+  public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
+    return taskToSuccessBlockIds;
+  }
+
+  @VisibleForTesting
+  public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
+    return taskToFailedBlockSendTracker;
+  }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 0283b84..06c8772 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -21,6 +21,8 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -58,8 +60,9 @@
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.storage.BlockManagerId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -69,13 +72,14 @@
 import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.impl.TrackingBlockStatus;
-import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
+import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
 import org.apache.uniffle.client.request.RssReassignServersRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
-import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
+import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReassignServersReponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
 import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -84,9 +88,10 @@
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
+
 public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
@@ -125,9 +130,11 @@
 
   private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>();
 
-  // shuffleServerId -> failoverShuffleServer
-  private final Map<String, ShuffleServerInfo> replacementShuffleServers =
-      JavaUtils.newConcurrentMap();
+  // Will be updated when the reassignment is triggered.
+  private TaskAttemptAssignment taskAttemptAssignment;
+
+  private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
+      Sets.newHashSet(StatusCode.NO_REGISTER);
 
   // Only for tests
   @VisibleForTesting
@@ -158,6 +165,7 @@
         shuffleHandleInfo,
         context);
     this.bufferManager = bufferManager;
+    this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, shuffleHandleInfo);
   }
 
   private RssShuffleWriter(
@@ -189,7 +197,7 @@
     this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = shuffleHandleInfo.listAssignedServers();
+    this.shuffleServersForData = shuffleHandleInfo.getServers();
     this.partitionLengths = new long[partitioner.numPartitions()];
     Arrays.fill(partitionLengths, 0);
     this.isMemoryShuffleEnabled =
@@ -202,6 +210,8 @@
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(),
             RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
+    this.blockFailSentRetryMaxTimes =
+        RssSparkConfig.toRssConf(sparkConf).get(RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
   }
 
   public RssShuffleWriter(
@@ -238,12 +248,18 @@
             taskAttemptId,
             bufferOptions,
             rssHandle.getDependency().serializer(),
-            shuffleHandleInfo.getPartitionToServers(),
             context.taskMemoryManager(),
             shuffleWriteMetrics,
             RssSparkConfig.toRssConf(sparkConf),
-            this::processShuffleBlockInfos);
+            this::processShuffleBlockInfos,
+            this::getPartitionAssignedServers);
     this.bufferManager = bufferManager;
+    this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, shuffleHandleInfo);
+  }
+
+  @VisibleForTesting
+  protected List<ShuffleServerInfo> getPartitionAssignedServers(int partitionId) {
+    return this.taskAttemptAssignment.retrieve(partitionId);
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
@@ -481,8 +497,11 @@
     // to check whether the blocks resent exceed the max resend count.
     for (Long blockId : failedBlockIds) {
       List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
-      int retryIndex = failedBlockStatus.get(0).getShuffleBlockInfo().getRetryCnt();
-      // todo: support retry times by config
+      int retryIndex =
+          failedBlockStatus.stream()
+              .map(x -> x.getShuffleBlockInfo().getRetryCnt())
+              .max(Comparator.comparing(Integer::valueOf))
+              .get();
       if (retryIndex >= blockFailSentRetryMaxTimes) {
         LOG.error(
             "Partial blocks for taskId: [{}] retry exceeding the max retry times: [{}]. Fast fail! faulty server list: {}",
@@ -495,6 +514,19 @@
         break;
       }
 
+      for (TrackingBlockStatus status : failedBlockStatus) {
+        StatusCode code = status.getStatusCode();
+        if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
+          LOG.error(
+              "Partial blocks for taskId: [{}] failed on the illegal status code: [{}] without resend on server: {}",
+              taskId,
+              code,
+              status.getShuffleServerInfo());
+          isFastFail = true;
+          break;
+        }
+      }
+
       // todo: if setting multi replica and another replica is succeed to send, no need to resend
       resendCandidates.addAll(failedBlockStatus);
     }
@@ -513,47 +545,98 @@
           "Errors on resending the blocks data to the remote shuffle-server.");
     }
 
-    resendFailedBlocks(resendCandidates);
+    reassignAndResendBlocks(resendCandidates);
   }
 
-  private void resendFailedBlocks(Set<TrackingBlockStatus> failedBlockStatusSet) {
-    List<ShuffleBlockInfo> reassignBlocks = Lists.newArrayList();
-    Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions =
-        failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));
-
-    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
-        faultyServerToPartitions.entrySet()) {
-      Set<Integer> partitionIds =
-          entry.getValue().stream()
-              .map(x -> x.getShuffleBlockInfo().getPartitionId())
-              .collect(Collectors.toSet());
-      ShuffleServerInfo replacement = replacementShuffleServers.get(entry.getKey().getId());
-      if (replacement == null) {
-        // todo: merge multiple requests into one.
-        replacement = reassignFaultyShuffleServer(partitionIds, entry.getKey().getId());
-        replacementShuffleServers.put(entry.getKey().getId(), replacement);
+  private void doReassignOnBlockSendFailure(
+      Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers) {
+    LOG.info(
+        "Initiate reassignOnBlockSendFailure. failure partition servers: {}",
+        failurePartitionToServers);
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    String driver = rssConf.getString("driver.host", "");
+    int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+    try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
+      RssReassignOnBlockSendFailureRequest request =
+          new RssReassignOnBlockSendFailureRequest(shuffleId, failurePartitionToServers);
+      RssReassignOnBlockSendFailureResponse response =
+          shuffleManagerClient.reassignOnBlockSendFailure(request);
+      if (response.getStatusCode() != StatusCode.SUCCESS) {
+        String msg =
+            String.format(
+                "Reassign failed. statusCode: %s, msg: %s",
+                response.getStatusCode(), response.getMessage());
+        throw new RssException(msg);
       }
+      MutableShuffleHandleInfo handle = MutableShuffleHandleInfo.fromProto(response.getHandle());
+      taskAttemptAssignment.update(handle);
+    } catch (Exception e) {
+      throw new RssException(
+          "Errors on reassign on block send failure. failure partition->servers : "
+              + failurePartitionToServers,
+          e);
+    }
+  }
 
-      for (TrackingBlockStatus blockStatus : failedBlockStatusSet) {
-        // clear the previous retry state of block
-        ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
-        clearFailedBlockState(block);
+  private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
+    List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
+    Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
+        blocks.stream()
+            .collect(Collectors.groupingBy(d -> d.getShuffleBlockInfo().getPartitionId()));
 
-        final ShuffleBlockInfo newBlock = block;
-        newBlock.incrRetryCnt();
-        newBlock.reassignShuffleServers(Arrays.asList(replacement));
-
-        reassignBlocks.add(newBlock);
+    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new HashMap<>();
+    for (Map.Entry<Integer, List<TrackingBlockStatus>> entry : partitionedFailedBlocks.entrySet()) {
+      int partitionId = entry.getKey();
+      List<TrackingBlockStatus> partitionBlocks = entry.getValue();
+      Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks =
+          partitionBlocks.stream()
+              .collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()))
+              .entrySet()
+              .stream()
+              .collect(
+                  Collectors.toMap(
+                      Map.Entry::getKey, x -> x.getValue().stream().findFirst().get()));
+      for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry :
+          serverBlocks.entrySet()) {
+        String serverId = blockStatusEntry.getKey().getId();
+        // avoid duplicate reassign for the same failure server.
+        String latestServerId = getPartitionAssignedServers(partitionId).get(0).getId();
+        if (!serverId.equals(latestServerId)) {
+          continue;
+        }
+        StatusCode code = blockStatusEntry.getValue().getStatusCode();
+        failurePartitionToServers
+            .computeIfAbsent(partitionId, x -> new ArrayList<>())
+            .add(new ReceivingFailureServer(serverId, code));
       }
     }
 
-    processShuffleBlockInfos(reassignBlocks);
+    if (!failurePartitionToServers.isEmpty()) {
+      doReassignOnBlockSendFailure(failurePartitionToServers);
+    }
+
+    for (TrackingBlockStatus blockStatus : blocks) {
+      ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+      // todo: getting the replacement should support multi replica.
+      ShuffleServerInfo replacement = getPartitionAssignedServers(block.getPartitionId()).get(0);
+      if (blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
+        throw new RssException(
+            "No available replacement server for: " + blockStatus.getShuffleServerInfo().getId());
+      }
+      // clear the previous retry state of block
+      clearFailedBlockState(block);
+      final ShuffleBlockInfo newBlock = block;
+      newBlock.incrRetryCnt();
+      newBlock.reassignShuffleServers(Arrays.asList(replacement));
+      resendCandidates.add(newBlock);
+    }
+
+    processShuffleBlockInfos(resendCandidates);
   }
 
   private void clearFailedBlockState(ShuffleBlockInfo block) {
     shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
     block.getShuffleServerInfos().stream()
-        .filter(s -> replacementShuffleServers.containsKey(s.getId()))
         .forEach(
             s ->
                 serverToPartitionToBlockIds
@@ -561,30 +644,7 @@
                     .get(block.getPartitionId())
                     .remove(block.getBlockId()));
     partitionLengths[block.getPartitionId()] -= block.getLength();
-  }
-
-  private ShuffleServerInfo reassignFaultyShuffleServer(
-      Set<Integer> partitionIds, String faultyServerId) {
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    String driver = rssConf.getString("driver.host", "");
-    int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-    try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
-      RssReassignFaultyShuffleServerRequest request =
-          new RssReassignFaultyShuffleServerRequest(shuffleId, partitionIds, faultyServerId);
-      RssReassignFaultyShuffleServerResponse response =
-          shuffleManagerClient.reassignFaultyShuffleServer(request);
-      if (response.getStatusCode() != StatusCode.SUCCESS) {
-        throw new RssException(
-            "reassign server response with statusCode[" + response.getStatusCode() + "]");
-      }
-      if (response.getShuffleServer() == null) {
-        throw new RssException("empty newer reassignment server!");
-      }
-      return response.getShuffleServer();
-    } catch (Exception e) {
-      throw new RssException(
-          "Failed to reassign a new server for faultyServerId server[" + faultyServerId + "]", e);
-    }
+    blockIds.remove(block.getBlockId());
   }
 
   @VisibleForTesting
@@ -760,11 +820,6 @@
   }
 
   @VisibleForTesting
-  protected void addReassignmentShuffleServer(String shuffleId, ShuffleServerInfo replacement) {
-    replacementShuffleServers.put(shuffleId, replacement);
-  }
-
-  @VisibleForTesting
   protected void setTaskId(String taskId) {
     this.taskId = taskId;
   }
@@ -773,4 +828,13 @@
   protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> getServerToPartitionToBlockIds() {
     return serverToPartitionToBlockIds;
   }
+
+  @VisibleForTesting
+  protected RssShuffleManager getShuffleManager() {
+    return shuffleManager;
+  }
+
+  public TaskAttemptAssignment getTaskAttemptAssignment() {
+    return taskAttemptAssignment;
+  }
 }
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 2f930a8..6d81e2e 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -49,13 +49,15 @@
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.TestUtils;
+import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.rpc.StatusCode;
@@ -86,8 +88,259 @@
     return data;
   }
 
+  private MutableShuffleHandleInfo createMutableShuffleHandle() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
+    List<ShuffleServerInfo> ssi12 =
+        Arrays.asList(
+            new ShuffleServerInfo("id1", "0.0.0.1", 100),
+            new ShuffleServerInfo("id2", "0.0.0.2", 100));
+    partitionToServers.put(0, ssi12);
+    List<ShuffleServerInfo> ssi34 =
+        Arrays.asList(
+            new ShuffleServerInfo("id3", "0.0.0.3", 100),
+            new ShuffleServerInfo("id4", "0.0.0.4", 100));
+    partitionToServers.put(1, ssi34);
+    List<ShuffleServerInfo> ssi56 =
+        Arrays.asList(
+            new ShuffleServerInfo("id5", "0.0.0.5", 100),
+            new ShuffleServerInfo("id6", "0.0.0.6", 100));
+    partitionToServers.put(2, ssi56);
+
+    MutableShuffleHandleInfo shuffleHandleInfo =
+        new MutableShuffleHandleInfo(0, partitionToServers, RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+    return shuffleHandleInfo;
+  }
+
+  private RssShuffleWriter createMockWriter(MutableShuffleHandleInfo shuffleHandle, String taskId) {
+    SparkConf conf = new SparkConf();
+    conf.setAppName("testApp")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
+
+    Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+    taskToFailedBlockSendTracker.put(taskId, new FailedBlockSendTracker());
+
+    FakedDataPusher dataPusher = null;
+    final RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, dataPusher, successBlockIds, taskToFailedBlockSendTracker);
+    Serializer kryoSerializer = new KryoSerializer(conf);
+    Partitioner mockPartitioner = mock(Partitioner.class);
+    final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
+    RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
+    when(mockHandle.getDependency()).thenReturn(mockDependency);
+    when(mockDependency.serializer()).thenReturn(kryoSerializer);
+    when(mockDependency.partitioner()).thenReturn(mockPartitioner);
+    when(mockPartitioner.numPartitions()).thenReturn(3);
+
+    when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
+    WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            0,
+            0,
+            bufferOptions,
+            kryoSerializer,
+            shuffleHandle.getAvailablePartitionServersForWriter(),
+            mockTaskMemoryManager,
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(conf));
+    bufferManager.setTaskId(taskId);
+
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+    TaskContext contextMock = mock(TaskContext.class);
+    RssShuffleWriter<String, String, String> rssShuffleWriter =
+        new RssShuffleWriter<>(
+            "appId",
+            0,
+            taskId,
+            1L,
+            bufferManagerSpy,
+            shuffleWriteMetrics,
+            manager,
+            conf,
+            mockShuffleWriteClient,
+            mockHandle,
+            shuffleHandle,
+            contextMock);
+    rssShuffleWriter.enableBlockFailSentRetry();
+    doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
+
+    RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
+    doNothing().when(rssShuffleWriterSpy).sendCommit();
+
+    return rssShuffleWriterSpy;
+  }
+
+  private void updateShuffleHandleAssignment(
+      MutableShuffleHandleInfo handle,
+      Set<Integer> partitionIds,
+      String receivingFailureServerId,
+      Set<ShuffleServerInfo> replacements) {
+    for (int partitionId : partitionIds) {
+      handle.updateAssignment(partitionId, receivingFailureServerId, replacements);
+    }
+  }
+
+  /** Test the reassign multi times for one partitionId. */
   @Test
-  public void blockFailureResendTest() throws Exception {
+  public void reassignMultiTimesForOnePartitionIdTest() {
+    String taskId = "taskId";
+    MutableShuffleHandleInfo shuffleHandle = createMutableShuffleHandle();
+    RssShuffleWriter writer = createMockWriter(shuffleHandle, taskId);
+    writer.setBlockFailSentRetryMaxTimes(10);
+
+    // Make the id1 + id10 + id11 broken, and then finally, it will use the id12 successfully
+    AtomicInteger failureCnt = new AtomicInteger();
+    RssShuffleManager shuffleManager = writer.getShuffleManager();
+    Map<String, Set<Long>> taskToSuccessBlockIds = shuffleManager.getTaskToSuccessBlockIds();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
+        shuffleManager.getTaskToFailedBlockSendTracker();
+    TaskAttemptAssignment taskAssignment = writer.getTaskAttemptAssignment();
+    FakedDataPusher pusher =
+        new FakedDataPusher(
+            addBlockEvent -> {
+              List<ShuffleBlockInfo> blocks = addBlockEvent.getShuffleDataInfoList();
+              for (ShuffleBlockInfo block : blocks) {
+                ShuffleServerInfo server = block.getShuffleServerInfos().get(0);
+                String serverId = server.getId();
+                if (Arrays.asList("id1", "id10", "id11").contains(serverId)) {
+                  taskToFailedBlockSendTracker
+                      .computeIfAbsent(taskId, x -> new FailedBlockSendTracker())
+                      .add(block, server, StatusCode.NO_BUFFER);
+                  failureCnt.incrementAndGet();
+
+                  // refresh the assignment to simulate the reassign rpc.
+                  if (serverId.equals("id1")) {
+                    ShuffleServerInfo replacement1 = new ShuffleServerInfo("id10", "0.0.0.10", 100);
+                    updateShuffleHandleAssignment(
+                        shuffleHandle,
+                        Sets.newHashSet(0, 1, 2),
+                        "id1",
+                        Sets.newHashSet(replacement1));
+                    taskAssignment.update(shuffleHandle);
+                  } else if (serverId.equals("id10")) {
+                    ShuffleServerInfo replacement2 = new ShuffleServerInfo("id11", "0.0.0.10", 100);
+                    updateShuffleHandleAssignment(
+                        shuffleHandle,
+                        Sets.newHashSet(0, 1, 2),
+                        "id10",
+                        Sets.newHashSet(replacement2));
+                    taskAssignment.update(shuffleHandle);
+                  } else if (serverId.equals("id11")) {
+                    ShuffleServerInfo replacement3 = new ShuffleServerInfo("id12", "0.0.0.10", 100);
+                    updateShuffleHandleAssignment(
+                        shuffleHandle,
+                        Sets.newHashSet(0, 1, 2),
+                        "id11",
+                        Sets.newHashSet(replacement3));
+                    taskAssignment.update(shuffleHandle);
+                  }
+
+                } else {
+                  taskToSuccessBlockIds
+                      .computeIfAbsent(taskId, x -> new HashSet<>())
+                      .add(block.getBlockId());
+                }
+              }
+              return new CompletableFuture<>();
+            });
+    shuffleManager.setDataPusher(pusher);
+
+    writer
+        .getBufferManager()
+        .setPartitionAssignmentRetrieveFunc(
+            partitionId -> writer.getPartitionAssignedServers(partitionId));
+
+    // case1: the reassignment will refresh the following plan. So the failure will only occur one
+    // time.
+    MutableList<Product2<String, String>> mockedData = createMockRecords();
+    writer.write(mockedData.iterator());
+
+    Awaitility.await()
+        .timeout(Duration.ofSeconds(5))
+        .until(() -> taskToSuccessBlockIds.get(taskId).size() == mockedData.size());
+    assertEquals(3, failureCnt.get());
+  }
+
+  /** Once the reassignment occurs, the following AddBlockEvents will use the latest assignment. */
+  @Test
+  public void refreshAssignmentTest() {
+    String taskId = "taskId";
+    MutableShuffleHandleInfo shuffleHandle = createMutableShuffleHandle();
+    RssShuffleWriter writer = createMockWriter(shuffleHandle, taskId);
+
+    ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 100);
+    updateShuffleHandleAssignment(
+        shuffleHandle, Sets.newHashSet(0, 1, 2), "id1", Sets.newHashSet(replacement));
+
+    AtomicInteger failureCnt = new AtomicInteger();
+    RssShuffleManager shuffleManager = writer.getShuffleManager();
+    Map<String, Set<Long>> taskToSuccessBlockIds = shuffleManager.getTaskToSuccessBlockIds();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
+        shuffleManager.getTaskToFailedBlockSendTracker();
+    FakedDataPusher pusher =
+        new FakedDataPusher(
+            addBlockEvent -> {
+              List<ShuffleBlockInfo> blocks = addBlockEvent.getShuffleDataInfoList();
+              for (ShuffleBlockInfo block : blocks) {
+                ShuffleServerInfo server = block.getShuffleServerInfos().get(0);
+                if (server.getId().equals("id1")) {
+                  taskToFailedBlockSendTracker
+                      .computeIfAbsent(taskId, x -> new FailedBlockSendTracker())
+                      .add(block, server, StatusCode.NO_BUFFER);
+                  failureCnt.incrementAndGet();
+                  // refresh the assignment to simulate the reassign rpc.
+                  writer.getTaskAttemptAssignment().update(shuffleHandle);
+                } else {
+                  taskToSuccessBlockIds
+                      .computeIfAbsent(taskId, x -> new HashSet<>())
+                      .add(block.getBlockId());
+                }
+              }
+              return new CompletableFuture<>();
+            });
+    shuffleManager.setDataPusher(pusher);
+
+    writer
+        .getBufferManager()
+        .setPartitionAssignmentRetrieveFunc(
+            partitionId -> writer.getPartitionAssignedServers(partitionId));
+
+    // case1: the reassignment will refresh the following plan. So the failure will only occur one
+    // time.
+    MutableList<Product2<String, String>> mockedData = createMockRecords();
+    writer.write(mockedData.iterator());
+
+    Awaitility.await()
+        .timeout(Duration.ofSeconds(5))
+        .until(() -> taskToSuccessBlockIds.get(taskId).size() == mockedData.size());
+    assertEquals(1, failureCnt.get());
+  }
+
+  @Test
+  public void blockFailureResendTest() {
     SparkConf conf = new SparkConf();
     conf.setAppName("testApp")
         .setMaster("local[2]")
@@ -157,6 +410,7 @@
             new ShuffleServerInfo("id5", "0.0.0.5", 100),
             new ShuffleServerInfo("id6", "0.0.0.6", 100));
     partitionToServers.put(2, ssi56);
+
     when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
     when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
     when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
@@ -185,7 +439,8 @@
 
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    MutableShuffleHandleInfo shuffleHandleInfo =
+        new MutableShuffleHandleInfo(0, partitionToServers, RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
             "appId",
@@ -198,12 +453,17 @@
             conf,
             mockShuffleWriteClient,
             mockHandle,
-            mockShuffleHandleInfo,
+            shuffleHandleInfo,
             contextMock);
     rssShuffleWriter.enableBlockFailSentRetry();
     doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
+
     ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 100);
-    rssShuffleWriter.addReassignmentShuffleServer("id1", replacement);
+    shuffleHandleInfo.updateAssignment(0, "id1", Sets.newHashSet(replacement));
+    shuffleHandleInfo.updateAssignment(1, "id1", Sets.newHashSet(replacement));
+    shuffleHandleInfo.updateAssignment(2, "id1", Sets.newHashSet(replacement));
+
+    rssShuffleWriter.getTaskAttemptAssignment().update(shuffleHandleInfo);
 
     RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
     doNothing().when(rssShuffleWriterSpy).sendCommit();
@@ -237,6 +497,7 @@
     rssShuffleWriter.setTaskId("taskId2");
     rssShuffleWriter.getBufferManager().setTaskId("taskId2");
     taskToFailedBlockSendTracker.put("taskId2", new FailedBlockSendTracker());
+    AtomicInteger rejectCnt = new AtomicInteger(0);
     FakedDataPusher alwaysFailedDataPusher =
         new FakedDataPusher(
             event -> {
@@ -245,9 +506,10 @@
               for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
                 boolean isSuccessful = true;
                 ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
-                if (shuffleServer.getId().equals("id1")) {
+                if (shuffleServer.getId().equals("id1") && rejectCnt.get() <= 3) {
                   tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
                   isSuccessful = false;
+                  rejectCnt.incrementAndGet();
                 } else {
                   successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
                   successBlockIds.get(event.getTaskId()).add(block.getBlockId());
@@ -315,7 +577,7 @@
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
 
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
             "appId",
@@ -461,7 +723,7 @@
     when(mockDependency.partitioner()).thenReturn(mockPartitioner);
     when(mockPartitioner.numPartitions()).thenReturn(1);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
 
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
@@ -584,7 +846,7 @@
 
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
     RssShuffleWriter<String, String, String> rssShuffleWriter =
         new RssShuffleWriter<>(
             "appId",
@@ -695,7 +957,7 @@
     RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     TaskContext contextMock = mock(TaskContext.class);
-    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
 
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 31);
diff --git a/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java b/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java
index 02d5b62..2b22f79 100644
--- a/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java
+++ b/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.client;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -42,6 +43,27 @@
     this.inventory = inventory;
   }
 
+  // for the DefaultShuffleHandleInfo
+  public PartitionDataReplicaRequirementTracking(
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers, int shuffleId) {
+    this.shuffleId = shuffleId;
+    this.inventory = toPartitionReplicaServers(partitionToServers);
+  }
+
+  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaServers(
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = new HashMap<>();
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : partitionToServers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicas =
+          inventory.computeIfAbsent(partitionId, x -> new HashMap<>());
+      for (int i = 0; i < entry.getValue().size(); i++) {
+        replicas.computeIfAbsent(i, x -> new ArrayList<>()).add(entry.getValue().get(i));
+      }
+    }
+    return inventory;
+  }
+
   public boolean isSatisfied(int partitionId, int minReplica) {
     // replica index -> successful count
     Map<Integer, Integer> succeedReplicas = succeedList.get(partitionId);
diff --git a/common/src/main/java/org/apache/uniffle/common/ReceivingFailureServer.java b/common/src/main/java/org/apache/uniffle/common/ReceivingFailureServer.java
new file mode 100644
index 0000000..4d3b0c9
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/ReceivingFailureServer.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
+
+public class ReceivingFailureServer {
+  private String serverId;
+  private StatusCode statusCode;
+
+  private ReceivingFailureServer() {
+    // ignore
+  }
+
+  public ReceivingFailureServer(String serverId, StatusCode statusCode) {
+    this.serverId = serverId;
+    this.statusCode = statusCode;
+  }
+
+  public String getServerId() {
+    return serverId;
+  }
+
+  public StatusCode getStatusCode() {
+    return statusCode;
+  }
+
+  @Override
+  public String toString() {
+    return "ReceivingFailureServer{"
+        + "serverId='"
+        + serverId
+        + '\''
+        + ", statusCode="
+        + statusCode
+        + '}';
+  }
+
+  public static List<ReceivingFailureServer> fromProto(RssProtos.ReceivingFailureServers proto) {
+    List<ReceivingFailureServer> servers = new ArrayList<>();
+    for (RssProtos.ReceivingFailureServer protoServer : proto.getServerList()) {
+      ReceivingFailureServer server = new ReceivingFailureServer();
+      server.serverId = protoServer.getServerId();
+      server.statusCode = StatusCode.fromProto(protoServer.getStatusCode());
+      servers.add(server);
+    }
+    return servers;
+  }
+
+  public static RssProtos.ReceivingFailureServers toProto(List<ReceivingFailureServer> servers) {
+    List<RssProtos.ReceivingFailureServer> protoServers = new ArrayList<>();
+    for (ReceivingFailureServer server : servers) {
+      protoServers.add(
+          RssProtos.ReceivingFailureServer.newBuilder()
+              .setServerId(server.serverId)
+              .setStatusCode(server.statusCode.toProto())
+              .build());
+    }
+    return RssProtos.ReceivingFailureServers.newBuilder().addAllServer(protoServers).build();
+  }
+
+  public static RssProtos.ReceivingFailureServer toProto(ReceivingFailureServer server) {
+    return RssProtos.ReceivingFailureServer.newBuilder()
+        .setServerId(server.serverId)
+        .setStatusCode(server.statusCode.toProto())
+        .build();
+  }
+}
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
index 3cf1736..0401b2c 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RssShuffleManagerTest.java
@@ -35,7 +35,7 @@
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
-import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
@@ -220,7 +220,7 @@
               .build();
       ShuffleHandleInfo handle = shuffleManager.getShuffleHandleInfoByShuffleId(0);
       Set<ShuffleServerInfo> servers =
-          handle.getPartitionToServers().values().stream()
+          handle.getAvailablePartitionServersForWriter().values().stream()
               .flatMap(Collection::stream)
               .collect(Collectors.toSet());
 
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
similarity index 72%
rename from integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java
rename to integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
index 562320d..0a1bd15 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
@@ -25,6 +25,8 @@
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.io.TempDir;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.rpc.ServerType;
 import org.apache.uniffle.coordinator.CoordinatorConf;
@@ -34,16 +36,20 @@
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
+import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_RETRY_MAX;
 import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED;
-import static org.junit.jupiter.api.Assertions.assertEquals;
 
-/** This class is to test the mechanism of partition block data reassignment. */
-public class PartitionBlockDataReassignTest extends SparkSQLTest {
+/** This class is to basic test the mechanism of partition block data reassignment */
+public class PartitionBlockDataReassignBasicTest extends SparkSQLTest {
+  private static final Logger LOGGER =
+      LoggerFactory.getLogger(PartitionBlockDataReassignBasicTest.class);
 
-  private static String basePath;
+  protected static String basePath;
 
   @BeforeAll
   public static void setupServers(@TempDir File tmpDir) throws Exception {
+    LOGGER.info("Setup servers");
+
     // for coordinator
     CoordinatorConf coordinatorConf = getCoordinatorConf();
     coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
@@ -71,13 +77,19 @@
 
     startServers();
 
-    // simulate one server without enough buffer
-    ShuffleServer faultyShuffleServer = grpcShuffleServers.get(0);
-    ShuffleBufferManager bufferManager = faultyShuffleServer.getShuffleBufferManager();
+    // simulate one server without enough buffer for grpc
+    ShuffleServer grpcServer = grpcShuffleServers.get(0);
+    ShuffleBufferManager bufferManager = grpcServer.getShuffleBufferManager();
+    bufferManager.setUsedMemory(bufferManager.getCapacity() + 100);
+
+    // simulate one server without enough buffer for netty server
+    ShuffleServer nettyServer = nettyShuffleServers.get(0);
+    bufferManager = nettyServer.getShuffleBufferManager();
     bufferManager.setUsedMemory(bufferManager.getCapacity() + 100);
   }
 
-  private static ShuffleServerConf buildShuffleServerConf(ServerType serverType) throws Exception {
+  protected static ShuffleServerConf buildShuffleServerConf(ServerType serverType)
+      throws Exception {
     ShuffleServerConf shuffleServerConf = getShuffleServerConf(serverType);
     shuffleServerConf.setLong("rss.server.heartbeat.interval", 5000);
     shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 4000);
@@ -87,18 +99,22 @@
   }
 
   @Override
-  public void updateRssStorage(SparkConf sparkConf) {
-    sparkConf.set("spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER, "1");
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set("spark.sql.shuffle.partitions", "4");
+    sparkConf.set("spark." + RSS_CLIENT_RETRY_MAX, "2");
+    sparkConf.set(
+        "spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+        String.valueOf(grpcShuffleServers.size()));
     sparkConf.set("spark." + RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(), "true");
   }
 
   @Override
+  public void updateRssStorage(SparkConf sparkConf) {
+    // ignore
+  }
+
+  @Override
   public void checkShuffleData() throws Exception {
-    Thread.sleep(12000);
-    String[] paths = basePath.split(",");
-    for (String path : paths) {
-      File f = new File(path);
-      assertEquals(0, f.list().length);
-    }
+    // ignore
   }
 }
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignMultiTimesTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignMultiTimesTest.java
new file mode 100644
index 0000000..5ee3d7b
--- /dev/null
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignMultiTimesTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.File;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.io.TempDir;
+
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.coordinator.strategy.assignment.AssignmentStrategyFactory;
+import org.apache.uniffle.server.MockedGrpcServer;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.buffer.ShuffleBufferManager;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
+import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
+import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_RETRY_MAX;
+import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED;
+import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_ASSIGNMENT_STRATEGY;
+
+/** This class is to test the partition reassign mechanism of multiple retries. */
+public class PartitionBlockDataReassignMultiTimesTest extends PartitionBlockDataReassignBasicTest {
+  @BeforeAll
+  public static void setupServers(@TempDir File tmpDir) throws Exception {
+    // for coordinator
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
+    coordinatorConf.set(
+        COORDINATOR_ASSIGNMENT_STRATEGY, AssignmentStrategyFactory.StrategyName.BASIC);
+
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+
+    // for shuffle-server
+    File dataDir1 = new File(tmpDir, "data1");
+    File dataDir2 = new File(tmpDir, "data2");
+    basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
+
+    // grpc server.
+    ShuffleServerConf grpcShuffleServerConf1 = buildShuffleServerConf(ServerType.GRPC);
+    createMockedShuffleServer(grpcShuffleServerConf1);
+
+    ShuffleServerConf grpcShuffleServerConf2 = buildShuffleServerConf(ServerType.GRPC);
+    createMockedShuffleServer(grpcShuffleServerConf2);
+
+    ShuffleServerConf grpcShuffleServerConf3 = buildShuffleServerConf(ServerType.GRPC);
+    createMockedShuffleServer(grpcShuffleServerConf3);
+
+    // netty server.
+    ShuffleServerConf grpcShuffleServerConf4 = buildShuffleServerConf(ServerType.GRPC_NETTY);
+    createShuffleServer(grpcShuffleServerConf4);
+
+    ShuffleServerConf grpcShuffleServerConf5 = buildShuffleServerConf(ServerType.GRPC_NETTY);
+    createShuffleServer(grpcShuffleServerConf5);
+
+    startServers();
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set("spark.sql.shuffle.partitions", "4");
+    sparkConf.set("spark." + RSS_CLIENT_RETRY_MAX, "2");
+    sparkConf.set("spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER, "1");
+    sparkConf.set("spark." + RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(), "true");
+    sparkConf.set("spark." + RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES.key(), "10");
+
+    // simulate the grpc servers has different free memory
+    // and make the assign priority seq: g1 -> g2 -> g3
+    ShuffleServer g1 = grpcShuffleServers.get(0);
+    ShuffleBufferManager bufferManager = g1.getShuffleBufferManager();
+    bufferManager.setUsedMemory(bufferManager.getCapacity() - 3000000);
+    g1.sendHeartbeat();
+
+    ShuffleServer g2 = grpcShuffleServers.get(1);
+    bufferManager = g2.getShuffleBufferManager();
+    bufferManager.setUsedMemory(bufferManager.getCapacity() - 2000000);
+    g2.sendHeartbeat();
+
+    ShuffleServer g3 = grpcShuffleServers.get(2);
+    bufferManager = g3.getShuffleBufferManager();
+    bufferManager.setUsedMemory(bufferManager.getCapacity() - 1000000);
+    g3.sendHeartbeat();
+
+    // This will make the partition of g1 reassign to g2 servers.
+    ((MockedGrpcServer) g1.getServer()).getService().enableMockRequireBufferFailWithNoBuffer();
+
+    // And then reassign to g3
+    ((MockedGrpcServer) g2.getServer()).getService().enableMockRequireBufferFailWithNoBuffer();
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index 77506a7..c74843c 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -20,12 +20,12 @@
 import java.io.Closeable;
 
 import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
+import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
 import org.apache.uniffle.client.request.RssReassignServersRequest;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
-import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
+import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReassignServersReponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
@@ -48,6 +48,6 @@
 
   RssReassignServersReponse reassignShuffleServers(RssReassignServersRequest req);
 
-  RssReassignFaultyShuffleServerResponse reassignFaultyShuffleServer(
-      RssReassignFaultyShuffleServerRequest request);
+  RssReassignOnBlockSendFailureResponse reassignOnBlockSendFailure(
+      RssReassignOnBlockSendFailureRequest request);
 }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 128f26d..61e24b5 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -24,12 +24,12 @@
 
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
-import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
+import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
 import org.apache.uniffle.client.request.RssReassignServersRequest;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
-import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
+import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReassignServersReponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
@@ -119,12 +119,12 @@
   }
 
   @Override
-  public RssReassignFaultyShuffleServerResponse reassignFaultyShuffleServer(
-      RssReassignFaultyShuffleServerRequest request) {
-    RssProtos.RssReassignFaultyShuffleServerRequest rssReassignFaultyShuffleServerRequest =
-        request.toProto();
-    RssProtos.RssReassignFaultyShuffleServerResponse response =
-        getBlockingStub().reassignFaultyShuffleServer(rssReassignFaultyShuffleServerRequest);
-    return RssReassignFaultyShuffleServerResponse.fromProto(response);
+  public RssReassignOnBlockSendFailureResponse reassignOnBlockSendFailure(
+      RssReassignOnBlockSendFailureRequest request) {
+    RssProtos.RssReassignOnBlockSendFailureRequest protoReq =
+        RssReassignOnBlockSendFailureRequest.toProto(request);
+    RssProtos.RssReassignOnBlockSendFailureResponse response =
+        getBlockingStub().reassignOnBlockSendFailure(protoReq);
+    return RssReassignOnBlockSendFailureResponse.fromProto(response);
   }
 }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 3ab81a0..f20cd85 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -234,6 +234,7 @@
         appId, 0, Collections.emptyList(), requireSize, retryMax, retryIntervalMax);
   }
 
+  @VisibleForTesting
   public long requirePreAllocation(
       String appId,
       int shuffleId,
@@ -241,6 +242,24 @@
       int requireSize,
       int retryMax,
       long retryIntervalMax) {
+    return requirePreAllocation(
+        appId,
+        shuffleId,
+        partitionIds,
+        requireSize,
+        retryMax,
+        retryIntervalMax,
+        new AtomicReference<>(StatusCode.INTERNAL_ERROR));
+  }
+
+  public long requirePreAllocation(
+      String appId,
+      int shuffleId,
+      List<Integer> partitionIds,
+      int requireSize,
+      int retryMax,
+      long retryIntervalMax,
+      AtomicReference<StatusCode> failedStatusCodeRef) {
     RequireBufferRequest rpcRequest =
         RequireBufferRequest.newBuilder()
             .setShuffleId(shuffleId)
@@ -275,6 +294,7 @@
           && rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER_FOR_HUGE_PARTITION) {
         break;
       }
+      failedStatusCodeRef.set(StatusCode.fromCode(rpcResponse.getStatus().getNumber()));
       if (retry >= retryMax) {
         LOG.warn(
             "ShuffleServer "
@@ -492,7 +512,8 @@
                       partitionIds,
                       allocateSize,
                       request.getRetryMax() / maxRetryAttempts,
-                      request.getRetryIntervalMax());
+                      request.getRetryIntervalMax(),
+                      failedStatusCode);
               if (requireId == FAILED_REQUIRE_ID) {
                 throw new RssException(
                     String.format(
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 0c6860a..ca06d55 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -17,9 +17,11 @@
 
 package org.apache.uniffle.client.impl.grpc;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -107,6 +109,7 @@
     Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
         request.getShuffleIdToBlocks();
     boolean isSuccessful = true;
+    AtomicReference<StatusCode> failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR);
 
     for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb :
         shuffleIdToBlocks.entrySet()) {
@@ -137,9 +140,12 @@
               long requireId =
                   requirePreAllocation(
                       request.getAppId(),
+                      0,
+                      Collections.emptyList(),
                       allocateSize,
                       request.getRetryMax(),
-                      request.getRetryIntervalMax());
+                      request.getRetryIntervalMax(),
+                      failedStatusCode);
               if (requireId == FAILED_REQUIRE_ID) {
                 throw new RssException(
                     String.format(
@@ -164,6 +170,7 @@
                     port);
               }
               if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
+                failedStatusCode.set(StatusCode.fromCode(rpcResponse.getStatusCode().statusCode()));
                 String msg =
                     "Can't send shuffle data with "
                         + finalBlockNum
@@ -198,7 +205,7 @@
     if (isSuccessful) {
       response = new RssSendShuffleDataResponse(StatusCode.SUCCESS);
     } else {
-      response = new RssSendShuffleDataResponse(StatusCode.INTERNAL_ERROR);
+      response = new RssSendShuffleDataResponse(failedStatusCode.get());
     }
     return response;
   }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java
deleted file mode 100644
index c85666f..0000000
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.uniffle.client.request;
-
-import java.util.Set;
-
-import org.apache.uniffle.proto.RssProtos;
-
-public class RssReassignFaultyShuffleServerRequest {
-
-  private int shuffleId;
-  private Set<Integer> partitionIds;
-  private String faultyShuffleServerId;
-
-  public RssReassignFaultyShuffleServerRequest(
-      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
-    this.shuffleId = shuffleId;
-    this.partitionIds = partitionIds;
-    this.faultyShuffleServerId = faultyShuffleServerId;
-  }
-
-  public int getShuffleId() {
-    return shuffleId;
-  }
-
-  public Set<Integer> getPartitionIds() {
-    return partitionIds;
-  }
-
-  public String getFaultyShuffleServerId() {
-    return faultyShuffleServerId;
-  }
-
-  public RssProtos.RssReassignFaultyShuffleServerRequest toProto() {
-    RssProtos.RssReassignFaultyShuffleServerRequest.Builder builder =
-        RssProtos.RssReassignFaultyShuffleServerRequest.newBuilder()
-            .setShuffleId(this.shuffleId)
-            .setFaultyShuffleServerId(this.faultyShuffleServerId)
-            .addAllPartitionIds(this.partitionIds);
-    return builder.build();
-  }
-}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java
new file mode 100644
index 0000000..7a28493
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.request;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.uniffle.common.ReceivingFailureServer;
+import org.apache.uniffle.proto.RssProtos;
+
+public class RssReassignOnBlockSendFailureRequest {
+  private int shuffleId;
+  private Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers;
+
+  public RssReassignOnBlockSendFailureRequest(
+      int shuffleId, Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers) {
+    this.shuffleId = shuffleId;
+    this.failurePartitionToServers = failurePartitionToServers;
+  }
+
+  public static RssProtos.RssReassignOnBlockSendFailureRequest toProto(
+      RssReassignOnBlockSendFailureRequest request) {
+    return RssProtos.RssReassignOnBlockSendFailureRequest.newBuilder()
+        .setShuffleId(request.shuffleId)
+        .putAllFailurePartitionToServerIds(
+            request.failurePartitionToServers.entrySet().stream()
+                .collect(
+                    Collectors.toMap(
+                        Map.Entry::getKey, x -> ReceivingFailureServer.toProto(x.getValue()))))
+        .build();
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
index 66c3288..9daa002 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
@@ -21,15 +21,17 @@
 import org.apache.uniffle.proto.RssProtos;
 
 public class RssPartitionToShuffleServerResponse extends ClientResponse {
-  private RssProtos.ShuffleHandleInfo shuffleHandleInfoProto;
+  private RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto;
 
   public RssPartitionToShuffleServerResponse(
-      StatusCode statusCode, String message, RssProtos.ShuffleHandleInfo shuffleHandleInfoProto) {
+      StatusCode statusCode,
+      String message,
+      RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto) {
     super(statusCode, message);
     this.shuffleHandleInfoProto = shuffleHandleInfoProto;
   }
 
-  public RssProtos.ShuffleHandleInfo getShuffleHandleInfoProto() {
+  public RssProtos.MutableShuffleHandleInfo getShuffleHandleInfoProto() {
     return shuffleHandleInfoProto;
   }
 
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java
deleted file mode 100644
index 4c3b7c4..0000000
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.uniffle.client.response;
-
-import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.proto.RssProtos;
-
-public class RssReassignFaultyShuffleServerResponse extends ClientResponse {
-
-  private ShuffleServerInfo shuffleServer;
-
-  public RssReassignFaultyShuffleServerResponse(
-      StatusCode statusCode, String message, ShuffleServerInfo shuffleServer) {
-    super(statusCode, message);
-    this.shuffleServer = shuffleServer;
-  }
-
-  public ShuffleServerInfo getShuffleServer() {
-    return shuffleServer;
-  }
-
-  public static RssReassignFaultyShuffleServerResponse fromProto(
-      RssProtos.RssReassignFaultyShuffleServerResponse response) {
-    return new RssReassignFaultyShuffleServerResponse(
-        StatusCode.valueOf(response.getStatus().name()),
-        response.getMsg(),
-        new ShuffleServerInfo(
-            response.getServer().getId(),
-            response.getServer().getIp(),
-            response.getServer().getPort(),
-            response.getServer().getNettyPort()));
-  }
-}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
new file mode 100644
index 0000000..81ca7d5
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignOnBlockSendFailureResponse.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.response;
+
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
+
+public class RssReassignOnBlockSendFailureResponse extends ClientResponse {
+  private RssProtos.MutableShuffleHandleInfo handle;
+
+  public RssReassignOnBlockSendFailureResponse(
+      StatusCode statusCode, String message, RssProtos.MutableShuffleHandleInfo handle) {
+    super(statusCode, message);
+    this.handle = handle;
+  }
+
+  public RssProtos.MutableShuffleHandleInfo getHandle() {
+    return handle;
+  }
+
+  public static RssReassignOnBlockSendFailureResponse fromProto(
+      RssProtos.RssReassignOnBlockSendFailureResponse response) {
+    return new RssReassignOnBlockSendFailureResponse(
+        StatusCode.valueOf(response.getStatus().name()), response.getMsg(), response.getHandle());
+  }
+}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 720002c..97470f4 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -523,8 +523,8 @@
   rpc reportShuffleWriteFailure (ReportShuffleWriteFailureRequest) returns (ReportShuffleWriteFailureResponse);
   // Reassign the RPC interface of the ShuffleServer list
   rpc reassignShuffleServers(ReassignServersRequest) returns (ReassignServersReponse);
-  // Reassign a new server instead a faulty server the RPC interface
-  rpc reassignFaultyShuffleServer(RssReassignFaultyShuffleServerRequest) returns (RssReassignFaultyShuffleServerResponse);
+  // Reassign on block send failure that occurs in writer
+  rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns (RssReassignOnBlockSendFailureResponse);
 }
 
 message ReportShuffleFetchFailureRequest {
@@ -552,10 +552,10 @@
 message PartitionToShuffleServerResponse {
   StatusCode status = 1;
   string msg = 2;
-  ShuffleHandleInfo shuffleHandleInfo = 3;
+  MutableShuffleHandleInfo shuffleHandleInfo = 3;
 }
 
-message ShuffleHandleInfo {
+message MutableShuffleHandleInfo {
   int32 shuffleId = 1;
   map<int32, PartitionReplicaServers> partitionToServers = 2;
   RemoteStorageInfo remoteStorageInfo = 3;
@@ -601,16 +601,24 @@
   string msg = 3;
 }
 
-message RssReassignFaultyShuffleServerRequest{
+message RssReassignOnBlockSendFailureRequest{
   int32 shuffleId  = 1;
-  repeated int32 partitionIds = 2;
-  string faultyShuffleServerId = 3;
+  map<int32, ReceivingFailureServers> failurePartitionToServerIds = 2;
 }
 
-message RssReassignFaultyShuffleServerResponse{
+message ReceivingFailureServers {
+  repeated ReceivingFailureServer server = 1;
+}
+
+message ReceivingFailureServer {
+  string serverId = 1;
+  StatusCode statusCode = 2;
+}
+
+message RssReassignOnBlockSendFailureResponse {
   StatusCode status = 1;
-  ShuffleServerId server = 2;
-  string msg = 3;
+  string msg = 2;
+  MutableShuffleHandleInfo handle = 3;
 }
 
 
diff --git a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
index 9b7ca1d..8181ddc 100644
--- a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
+++ b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
@@ -95,7 +95,7 @@
   }
 
   @VisibleForTesting
-  boolean sendHeartBeat(
+  public boolean sendHeartBeat(
       String id,
       String ip,
       int grpcPort,
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
index 79fe35b..de613e2 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
@@ -527,4 +527,21 @@
   public String getEncodedTags() {
     return StringUtils.join(tags, ",");
   }
+
+  @VisibleForTesting
+  public void sendHeartbeat() {
+    ShuffleServer shuffleServer = this;
+    registerHeartBeat.sendHeartBeat(
+        shuffleServer.getId(),
+        shuffleServer.getIp(),
+        shuffleServer.getGrpcPort(),
+        shuffleServer.getUsedMemory(),
+        shuffleServer.getPreAllocatedMemory(),
+        shuffleServer.getAvailableMemory(),
+        shuffleServer.getEventNumInFlush(),
+        shuffleServer.getTags(),
+        shuffleServer.getServerStatus(),
+        shuffleServer.getStorageManager().getStorageInfo(),
+        shuffleServer.getNettyPort());
+  }
 }
diff --git a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
index 87b9abd..d0d77f6 100644
--- a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
+++ b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
@@ -45,6 +45,9 @@
 
   private boolean mockSendDataFailed = false;
 
+  private boolean mockRequireBufferFailedWithNoBuffer = false;
+  private boolean isMockRequireBufferFailedWithNoBufferForHugePartition = false;
+
   private boolean recordGetShuffleResult = false;
 
   private long numOfFailedReadRequest = 0;
@@ -62,6 +65,14 @@
     this.mockSendDataFailed = mockSendDataFailed;
   }
 
+  public void enableMockRequireBufferFailWithNoBuffer() {
+    this.mockRequireBufferFailedWithNoBuffer = true;
+  }
+
+  public void enableMockRequireBufferFailWithNoBufferForHugePartition() {
+    this.isMockRequireBufferFailedWithNoBufferForHugePartition = true;
+  }
+
   public void enableRecordGetShuffleResult() {
     recordGetShuffleResult = true;
   }
@@ -88,6 +99,30 @@
   }
 
   @Override
+  public void requireBuffer(
+      RssProtos.RequireBufferRequest request,
+      StreamObserver<RssProtos.RequireBufferResponse> responseObserver) {
+    if (mockRequireBufferFailedWithNoBuffer
+        || isMockRequireBufferFailedWithNoBufferForHugePartition) {
+      LOG.info("Make require buffer mocked failed.");
+      StatusCode code =
+          mockRequireBufferFailedWithNoBuffer
+              ? StatusCode.NO_BUFFER
+              : StatusCode.NO_BUFFER_FOR_HUGE_PARTITION;
+      RssProtos.RequireBufferResponse response =
+          RssProtos.RequireBufferResponse.newBuilder()
+              .setStatus(code.toProto())
+              .setRequireBufferId(-1)
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
+
+    super.requireBuffer(request, responseObserver);
+  }
+
+  @Override
   public void sendShuffleData(
       RssProtos.SendShuffleDataRequest request,
       StreamObserver<RssProtos.SendShuffleDataResponse> responseObserver) {