[#1608][part-3] feat(spark3): support reading data from multiple reassigned servers (#1615)

### What changes were proposed in this pull request?

Support reading from partition block data reassignment servers.

### Why are the changes needed?

For: #1608

The writer already writes data to the reassigned servers, so the reader must be able to read from them as well.
The block IDs are stored on the partition servers that own them, so this PR reads the block IDs from those servers
while still honoring the min-replica requirement.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`PartitionBlockDataReassignTest` integration test.
diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index f016bfc..e5a25b6 100644
--- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -39,6 +39,7 @@
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
@@ -593,7 +594,8 @@
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
         int shuffleId,
-        Set<Integer> failedPartitions) {
+        Set<Integer> failedPartitions,
+        PartitionDataReplicaRequirementTracking tracking) {
       return null;
     }
 
diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index eeae036..b858312 100644
--- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -59,6 +59,7 @@
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
@@ -560,7 +561,8 @@
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
         int shuffleId,
-        Set<Integer> failedPartitions) {
+        Set<Integer> failedPartitions,
+        PartitionDataReplicaRequirementTracking tracking) {
       return null;
     }
 
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
index 34081e4..60c9597 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
@@ -19,7 +19,6 @@
 
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.broadcast.Broadcast;
@@ -70,8 +69,4 @@
   public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
     return handlerInfoBd.value().getPartitionToServers();
   }
-
-  public Set<ShuffleServerInfo> getShuffleServersForData() {
-    return handlerInfoBd.value().getShuffleServersForData();
-  }
 }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
index e54145c..0e4dd49 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
@@ -18,17 +18,21 @@
 package org.apache.spark.shuffle;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
-import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
+import com.google.common.annotations.VisibleForTesting;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos;
 
 /**
  * Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote storage info
@@ -38,19 +42,19 @@
 public class ShuffleHandleInfo implements Serializable {
 
   private int shuffleId;
-
-  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
-
-  // partitionId -> replica -> failover servers
-  private Map<Integer, Map<Integer, ShuffleServerInfo>> failoverPartitionServers;
-  // todo: support mores replacement servers for one faulty server.
-  private Map<String, ShuffleServerInfo> faultyServerReplacements;
-
-  // shuffle servers which is for store shuffle data
-  private Set<ShuffleServerInfo> shuffleServersForData;
-  // remoteStorage used for this job
   private RemoteStorageInfo remoteStorage;
 
+  /**
+   * partitionId -> replica -> assigned servers.
+   *
+   * <p>The first entry of each {@code List<ShuffleServerInfo>} is the initially assigned server.
+   *
+   * <p>The remaining entries are the replacement servers, if any exist.
+   */
+  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers;
+  // faulty servers replacement mapping
+  private Map<String, Set<ShuffleServerInfo>> faultyServerToReplacements;
+
   public static final ShuffleHandleInfo EMPTY_HANDLE_INFO =
       new ShuffleHandleInfo(-1, Collections.EMPTY_MAP, RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
 
@@ -59,22 +63,59 @@
       Map<Integer, List<ShuffleServerInfo>> partitionToServers,
       RemoteStorageInfo storageInfo) {
     this.shuffleId = shuffleId;
-    this.partitionToServers = partitionToServers;
-    this.shuffleServersForData = Sets.newHashSet();
-    this.failoverPartitionServers = Maps.newConcurrentMap();
-    for (List<ShuffleServerInfo> ssis : partitionToServers.values()) {
-      this.shuffleServersForData.addAll(ssis);
-    }
     this.remoteStorage = storageInfo;
-    this.faultyServerReplacements = JavaUtils.newConcurrentMap();
+    this.faultyServerToReplacements = new HashMap<>();
+    this.partitionReplicaAssignedServers = toPartitionReplicaMapping(partitionToServers);
   }
 
-  public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
-    return partitionToServers;
+  public ShuffleHandleInfo() {
+    // ignore
   }
 
-  public Set<ShuffleServerInfo> getShuffleServersForData() {
-    return shuffleServersForData;
+  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping(
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
+        new HashMap<>();
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> partitionEntry :
+        partitionToServers.entrySet()) {
+      int partitionId = partitionEntry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaMapping =
+          partitionReplicaAssignedServers.computeIfAbsent(partitionId, x -> new HashMap<>());
+
+      List<ShuffleServerInfo> replicaServers = partitionEntry.getValue();
+      for (int i = 0; i < replicaServers.size(); i++) {
+        int replicaIdx = i;
+        replicaMapping
+            .computeIfAbsent(replicaIdx, x -> new ArrayList<>())
+            .add(replicaServers.get(i));
+      }
+    }
+    return partitionReplicaAssignedServers;
+  }
+
+  /**
+   * This composes the partition's replica servers + replacement servers, this will be used by the
+   * shuffleReader to get the blockIds
+   */
+  public Map<Integer, List<ShuffleServerInfo>> listPartitionAssignedServers() {
+    Map<Integer, List<ShuffleServerInfo>> partitionServers = new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+        partitionReplicaAssignedServers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+      List<ShuffleServerInfo> servers =
+          replicaServers.values().stream().flatMap(x -> x.stream()).collect(Collectors.toList());
+      partitionServers.computeIfAbsent(partitionId, x -> new ArrayList<>()).addAll(servers);
+    }
+    return partitionServers;
+  }
+
+  /** Return all the assigned servers for the writer to commit */
+  public Set<ShuffleServerInfo> listAssignedServers() {
+    return partitionReplicaAssignedServers.values().stream()
+        .flatMap(x -> x.values().stream())
+        .flatMap(x -> x.stream())
+        .collect(Collectors.toSet());
   }
 
   public RemoteStorageInfo getRemoteStorage() {
@@ -85,33 +126,125 @@
     return shuffleId;
   }
 
-  public boolean isExistingFaultyServer(String serverId) {
-    return faultyServerReplacements.containsKey(serverId);
+  @VisibleForTesting
+  protected boolean isMarkedAsFaultyServer(String serverId) {
+    return faultyServerToReplacements.containsKey(serverId);
   }
 
-  public ShuffleServerInfo useExistingReassignmentForMultiPartitions(
-      Set<Integer> partitionIds, String faultyServerId) {
-    return createNewReassignmentForMultiPartitions(partitionIds, faultyServerId, null);
+  public Set<ShuffleServerInfo> getExistingReplacements(String faultyServerId) {
+    return faultyServerToReplacements.get(faultyServerId);
   }
 
-  public ShuffleServerInfo createNewReassignmentForMultiPartitions(
-      Set<Integer> partitionIds, String faultyServerId, ShuffleServerInfo replacement) {
-    if (replacement != null) {
-      faultyServerReplacements.put(faultyServerId, replacement);
+  public void updateReassignment(
+      Set<Integer> partitionIds, String faultyServerId, Set<ShuffleServerInfo> replacements) {
+    if (replacements == null) {
+      return;
     }
 
-    replacement = faultyServerReplacements.get(faultyServerId);
+    faultyServerToReplacements.put(faultyServerId, replacements);
+    // todo: optimize the multiple for performance
     for (Integer partitionId : partitionIds) {
-      List<ShuffleServerInfo> replicaServers = partitionToServers.get(partitionId);
-      for (int i = 0; i < replicaServers.size(); i++) {
-        if (replicaServers.get(i).getId().equals(faultyServerId)) {
-          Map<Integer, ShuffleServerInfo> replicaReplacements =
-              failoverPartitionServers.computeIfAbsent(
-                  partitionId, k -> JavaUtils.newConcurrentMap());
-          replicaReplacements.put(i, replacement);
+      Map<Integer, List<ShuffleServerInfo>> replicaServers =
+          partitionReplicaAssignedServers.get(partitionId);
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> serverEntry : replicaServers.entrySet()) {
+        List<ShuffleServerInfo> servers = serverEntry.getValue();
+        if (servers.stream()
+            .map(x -> x.getId())
+            .collect(Collectors.toSet())
+            .contains(faultyServerId)) {
+          Set<ShuffleServerInfo> tempSet = new HashSet<>();
+          tempSet.addAll(replacements);
+          tempSet.removeAll(servers);
+          servers.addAll(tempSet);
         }
       }
     }
-    return replacement;
+  }
+
+  // partitionId -> replica -> failover servers
+  // always return the last server.
+  @VisibleForTesting
+  public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+        partitionReplicaAssignedServers.entrySet()) {
+      int partitionId = entry.getKey();
+      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
+          replicaServers.entrySet()) {
+        ShuffleServerInfo lastServer =
+            replicaServerEntry.getValue().get(replicaServerEntry.getValue().size() - 1);
+        partitionToServers.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(lastServer);
+      }
+    }
+    return partitionToServers;
+  }
+
+  public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() {
+    PartitionDataReplicaRequirementTracking replicaRequirement =
+        new PartitionDataReplicaRequirementTracking(shuffleId, partitionReplicaAssignedServers);
+    return replicaRequirement;
+  }
+
+  public static RssProtos.ShuffleHandleInfo toProto(ShuffleHandleInfo handleInfo) {
+    Map<Integer, RssProtos.PartitionReplicaServers> partitionToServers = new HashMap<>();
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
+        handleInfo.partitionReplicaAssignedServers.entrySet()) {
+      int partitionId = entry.getKey();
+
+      Map<Integer, RssProtos.ReplicaServersItem> replicaServersProto = new HashMap<>();
+      Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
+          replicaServers.entrySet()) {
+        RssProtos.ReplicaServersItem item =
+            RssProtos.ReplicaServersItem.newBuilder()
+                .addAllServerId(ShuffleServerInfo.toProto(replicaServerEntry.getValue()))
+                .build();
+        replicaServersProto.put(replicaServerEntry.getKey(), item);
+      }
+
+      RssProtos.PartitionReplicaServers partitionReplicaServerProto =
+          RssProtos.PartitionReplicaServers.newBuilder()
+              .putAllReplicaServers(replicaServersProto)
+              .build();
+      partitionToServers.put(partitionId, partitionReplicaServerProto);
+    }
+
+    RssProtos.ShuffleHandleInfo handleProto =
+        RssProtos.ShuffleHandleInfo.newBuilder()
+            .setShuffleId(handleInfo.shuffleId)
+            .setRemoteStorageInfo(
+                RssProtos.RemoteStorageInfo.newBuilder()
+                    .setPath(handleInfo.remoteStorage.getPath())
+                    .putAllConfItems(handleInfo.remoteStorage.getConfItems())
+                    .build())
+            .putAllPartitionToServers(partitionToServers)
+            .build();
+    return handleProto;
+  }
+
+  public static ShuffleHandleInfo fromProto(RssProtos.ShuffleHandleInfo handleProto) {
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionToServers = new HashMap<>();
+    for (Map.Entry<Integer, RssProtos.PartitionReplicaServers> entry :
+        handleProto.getPartitionToServersMap().entrySet()) {
+      Map<Integer, List<ShuffleServerInfo>> replicaServers =
+          partitionToServers.computeIfAbsent(entry.getKey(), x -> new HashMap<>());
+      for (Map.Entry<Integer, RssProtos.ReplicaServersItem> serverEntry :
+          entry.getValue().getReplicaServersMap().entrySet()) {
+        int replicaIdx = serverEntry.getKey();
+        List<ShuffleServerInfo> shuffleServerInfos =
+            ShuffleServerInfo.fromProto(serverEntry.getValue().getServerIdList());
+        replicaServers.put(replicaIdx, shuffleServerInfos);
+      }
+    }
+    RemoteStorageInfo remoteStorageInfo =
+        new RemoteStorageInfo(
+            handleProto.getRemoteStorageInfo().getPath(),
+            handleProto.getRemoteStorageInfo().getConfItemsMap());
+    ShuffleHandleInfo handle = new ShuffleHandleInfo();
+    handle.shuffleId = handleProto.getShuffleId(); // read the id from the proto message
+    handle.partitionReplicaAssignedServers = partitionToServers;
+    handle.remoteStorage = remoteStorageInfo;
+    return handle;
   }
 }
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index bcf1303..b713bd7 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -32,7 +32,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.proto.RssProtos;
@@ -194,37 +193,17 @@
         shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
     if (shuffleHandleInfoByShuffleId != null) {
       code = RssProtos.StatusCode.SUCCESS;
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-          shuffleHandleInfoByShuffleId.getPartitionToServers();
-      Map<Integer, RssProtos.GetShuffleServerListResponse> protopartitionToServers =
-          JavaUtils.newConcurrentMap();
-      for (Map.Entry<Integer, List<ShuffleServerInfo>> integerListEntry :
-          partitionToServers.entrySet()) {
-        List<RssProtos.ShuffleServerId> shuffleServerIds =
-            ShuffleServerInfo.toProto(integerListEntry.getValue());
-        RssProtos.GetShuffleServerListResponse getShuffleServerListResponse =
-            RssProtos.GetShuffleServerListResponse.newBuilder()
-                .addAllServers(shuffleServerIds)
-                .build();
-        protopartitionToServers.put(integerListEntry.getKey(), getShuffleServerListResponse);
-      }
-      RemoteStorageInfo remoteStorage = shuffleHandleInfoByShuffleId.getRemoteStorage();
-      RssProtos.RemoteStorageInfo.Builder protosRemoteStage =
-          RssProtos.RemoteStorageInfo.newBuilder()
-              .setPath(remoteStorage.getPath())
-              .putAllConfItems(remoteStorage.getConfItems());
       reply =
           RssProtos.PartitionToShuffleServerResponse.newBuilder()
               .setStatus(code)
-              .putAllPartitionToShuffleServer(protopartitionToServers)
-              .setRemoteStorageInfo(protosRemoteStage)
+              .setShuffleHandleInfo(ShuffleHandleInfo.toProto(shuffleHandleInfoByShuffleId))
               .build();
     } else {
       code = RssProtos.StatusCode.INVALID_REQUEST;
       reply =
           RssProtos.PartitionToShuffleServerResponse.newBuilder()
               .setStatus(code)
-              .putAllPartitionToShuffleServer(null)
+              .setShuffleHandleInfo(ShuffleHandleInfo.toProto(ShuffleHandleInfo.EMPTY_HANDLE_INFO))
               .build();
     }
     responseObserver.onNext(reply);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java
index bb1b1ca..a2fe771 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/ShuffleHandleInfoTest.java
@@ -19,13 +19,14 @@
 
 import java.util.Arrays;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 
@@ -35,8 +36,8 @@
 
 public class ShuffleHandleInfoTest {
 
-  private ShuffleServerInfo createFakeServerInfo(String host) {
-    return new ShuffleServerInfo(host, 1);
+  private ShuffleServerInfo createFakeServerInfo(String id) {
+    return new ShuffleServerInfo(id, id, 1);
   }
 
   @Test
@@ -48,14 +49,69 @@
     ShuffleHandleInfo handleInfo =
         new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
 
-    // case1
-    assertFalse(handleInfo.isExistingFaultyServer("a"));
-    Set<Integer> partitions = new HashSet<>();
-    partitions.add(1);
-    ShuffleServerInfo newServer = createFakeServerInfo("d");
-    handleInfo.createNewReassignmentForMultiPartitions(partitions, "a", createFakeServerInfo("d"));
-    assertTrue(handleInfo.isExistingFaultyServer("a"));
+    assertFalse(handleInfo.isMarkedAsFaultyServer("a"));
+    Set<Integer> partitions = Sets.newHashSet(1);
+    handleInfo.updateReassignment(partitions, "a", Sets.newHashSet(createFakeServerInfo("d")));
+    assertTrue(handleInfo.isMarkedAsFaultyServer("a"));
+  }
 
-    assertEquals(newServer, handleInfo.useExistingReassignmentForMultiPartitions(partitions, "a"));
+  @Test
+  public void testListAllPartitionAssignmentServers() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
+    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
+
+    ShuffleHandleInfo handleInfo =
+        new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // case1
+    Set<Integer> partitions = Sets.newHashSet(2);
+    handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
+
+    Map<Integer, List<ShuffleServerInfo>> partitionAssignment =
+        handleInfo.listPartitionAssignedServers();
+    assertEquals(2, partitionAssignment.size());
+    assertEquals(
+        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
+        partitionAssignment.get(2));
+
+    // case2: reassign multiple times for one partition, it will not append the same replacement
+    // servers
+    handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
+    partitionAssignment = handleInfo.listPartitionAssignedServers();
+    assertEquals(
+        Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
+        partitionAssignment.get(2));
+
+    // case3: reassign multiple times for one partition, it will append the non-existing replacement
+    // servers
+    handleInfo.updateReassignment(
+        partitions, "c", Sets.newHashSet(createFakeServerInfo("d"), createFakeServerInfo("e")));
+    partitionAssignment = handleInfo.listPartitionAssignedServers();
+    assertEquals(
+        Arrays.asList(
+            createFakeServerInfo("c"), createFakeServerInfo("d"), createFakeServerInfo("e")),
+        partitionAssignment.get(2));
+  }
+
+  @Test
+  public void testCreatePartitionReplicaTracking() {
+    ShuffleServerInfo a = createFakeServerInfo("a");
+    ShuffleServerInfo b = createFakeServerInfo("b");
+    ShuffleServerInfo c = createFakeServerInfo("c");
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(a, b));
+    partitionToServers.put(2, Arrays.asList(c));
+
+    ShuffleHandleInfo handleInfo =
+        new ShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // not any replacements
+    PartitionDataReplicaRequirementTracking tracking = handleInfo.createPartitionReplicaTracking();
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = tracking.getInventory();
+    assertEquals(a, inventory.get(1).get(0).get(0));
+    assertEquals(b, inventory.get(1).get(1).get(0));
+    assertEquals(c, inventory.get(2).get(0).get(0));
   }
 }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 78bcc2c..cfd5ae3 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -822,13 +822,9 @@
     }
     RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
         new RssPartitionToShuffleServerRequest(shuffleId);
-    RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
+    RssPartitionToShuffleServerResponse handleInfoResponse =
         shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
-    shuffleHandleInfo =
-        new ShuffleHandleInfo(
-            shuffleId,
-            rpcPartitionToShufflerServer.getPartitionToServers(),
-            rpcPartitionToShufflerServer.getRemoteStorageInfo());
+    shuffleHandleInfo = ShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
     return shuffleHandleInfo;
   }
 
@@ -899,18 +895,13 @@
     synchronized (handleInfo) {
       // find out whether this server has been marked faulty in this shuffle
       // if it has been reassigned, directly return the replacement server.
-      if (handleInfo.isExistingFaultyServer(faultyShuffleServerId)) {
-        return handleInfo.useExistingReassignmentForMultiPartitions(
-            partitionIds, faultyShuffleServerId);
+      Set<ShuffleServerInfo> replacements =
+          handleInfo.getExistingReplacements(faultyShuffleServerId);
+      if (replacements == null) {
+        replacements = Sets.newHashSet(assignShuffleServer(shuffleId, faultyShuffleServerId));
       }
-
-      // get the newer server to replace faulty server.
-      ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
-      if (newAssignedServer != null) {
-        handleInfo.createNewReassignmentForMultiPartitions(
-            partitionIds, faultyShuffleServerId, newAssignedServer);
-      }
-      return newAssignedServer;
+      handleInfo.updateReassignment(partitionIds, faultyShuffleServerId, replacements);
+      return replacements.stream().findFirst().get();
     }
   }
 
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 9e64b2f..5a5d8bd 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -168,7 +168,7 @@
     this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
+    this.shuffleServersForData = shuffleHandleInfo.listAssignedServers();
     this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6d9487c..1f61876 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -62,6 +62,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
@@ -667,14 +668,8 @@
               rssShuffleHandle.getPartitionToServers(),
               rssShuffleHandle.getRemoteStorage());
     }
-    Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
-        shuffleHandleInfo.getPartitionToServers();
-    Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
-        allPartitionToServers.entrySet().stream()
-            .filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition)
-            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
     Map<ShuffleServerInfo, Set<Integer>> serverToPartitions =
-        RssUtils.generateServerToPartitions(requirePartitionToServers);
+        getPartitionDataServers(shuffleHandleInfo, startPartition, endPartition);
     long start = System.currentTimeMillis();
     Roaring64NavigableMap blockIdBitmap =
         getShuffleResultForMultiPart(
@@ -682,7 +677,8 @@
             serverToPartitions,
             rssShuffleHandle.getAppId(),
             shuffleId,
-            context.stageAttemptNumber());
+            context.stageAttemptNumber(),
+            shuffleHandleInfo.createPartitionReplicaTracking());
     LOG.info(
         "Get shuffle blockId cost "
             + (System.currentTimeMillis() - start)
@@ -725,7 +721,20 @@
         readMetrics,
         RssSparkConfig.toRssConf(sparkConf),
         dataDistributionType,
-        allPartitionToServers);
+        shuffleHandleInfo.listPartitionAssignedServers());
+  }
+
+  private Map<ShuffleServerInfo, Set<Integer>> getPartitionDataServers(
+      ShuffleHandleInfo shuffleHandleInfo, int startPartition, int endPartition) {
+    Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
+        shuffleHandleInfo.listPartitionAssignedServers();
+    Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
+        allPartitionToServers.entrySet().stream()
+            .filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition)
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    Map<ShuffleServerInfo, Set<Integer>> serverToPartitions =
+        RssUtils.generateServerToPartitions(requirePartitionToServers);
+    return serverToPartitions;
   }
 
   @SuppressFBWarnings("REC_CATCH_EXCEPTION")
@@ -1074,11 +1083,17 @@
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
       int shuffleId,
-      int stageAttemptId) {
+      int stageAttemptId,
+      PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
     Set<Integer> failedPartitions = Sets.newHashSet();
     try {
       return shuffleWriteClient.getShuffleResultForMultiPart(
-          clientType, serverToPartitions, appId, shuffleId, failedPartitions);
+          clientType,
+          serverToPartitions,
+          appId,
+          shuffleId,
+          failedPartitions,
+          replicaRequirementTracking);
     } catch (RssFetchFailedException e) {
       throw RssSparkShuffleUtils.reportRssFetchFailedException(
           e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
@@ -1120,10 +1135,7 @@
     RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
         shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
     shuffleHandleInfo =
-        new ShuffleHandleInfo(
-            shuffleId,
-            rpcPartitionToShufflerServer.getPartitionToServers(),
-            rpcPartitionToShufflerServer.getRemoteStorageInfo());
+        ShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
     return shuffleHandleInfo;
   }
 
@@ -1195,29 +1207,24 @@
     synchronized (handleInfo) {
       // find out whether this server has been marked faulty in this shuffle
       // if it has been reassigned, directly return the replacement server.
-      if (handleInfo.isExistingFaultyServer(faultyShuffleServerId)) {
-        return handleInfo.useExistingReassignmentForMultiPartitions(
-            partitionIds, faultyShuffleServerId);
+      // otherwise, it should request new servers to reassign
+      Set<ShuffleServerInfo> replacements =
+          handleInfo.getExistingReplacements(faultyShuffleServerId);
+      if (replacements == null) {
+        replacements = requestServersForTask(shuffleId, partitionIds, faultyShuffleServerId);
       }
-
-      // get the newer server to replace faulty server.
-      ShuffleServerInfo newAssignedServer =
-          reassignShuffleServerForTask(shuffleId, partitionIds, faultyShuffleServerId);
-      if (newAssignedServer != null) {
-        handleInfo.createNewReassignmentForMultiPartitions(
-            partitionIds, faultyShuffleServerId, newAssignedServer);
-      }
+      handleInfo.updateReassignment(partitionIds, faultyShuffleServerId, replacements);
       LOG.info(
           "Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
           faultyShuffleServerId,
-          newAssignedServer,
+          replacements,
           shuffleId,
           partitionIds);
-      return newAssignedServer;
+      return replacements.stream().findFirst().get();
     }
   }
 
-  private ShuffleServerInfo reassignShuffleServerForTask(
+  private Set<ShuffleServerInfo> requestServersForTask(
       int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
     Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
     faultyServerIds.addAll(failuresShuffleServerIds);
@@ -1248,7 +1255,7 @@
           serverToPartitionRanges.put(replacement, partitionRanges);
           return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
         });
-    return replacementRef.get();
+    return Sets.newHashSet(replacementRef.get());
   }
 
   private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8a22b73..0283b84 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -189,7 +189,7 @@
     this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
-    this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
+    this.shuffleServersForData = shuffleHandleInfo.listAssignedServers();
     this.partitionLengths = new long[partitioner.numPartitions()];
     Arrays.fill(partitionLengths, 0);
     this.isMemoryShuffleEnabled =
@@ -467,6 +467,10 @@
     }
 
     FailedBlockSendTracker failedTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
+    if (failedTracker == null) {
+      return;
+    }
+
     Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
     if (CollectionUtils.isEmpty(failedBlockIds)) {
       return;
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 5ca85ec..2f930a8 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -233,8 +233,10 @@
     assertEquals(2, serverToPartitionToBlockIds.get(replacement).get(0).size());
 
     // case2. If exceeding the max retry times, it will fast fail.
-    rssShuffleWriterSpy.setBlockFailSentRetryMaxTimes(1);
-    rssShuffleWriterSpy.setTaskId("taskId2");
+    rssShuffleWriter.setBlockFailSentRetryMaxTimes(1);
+    rssShuffleWriter.setTaskId("taskId2");
+    rssShuffleWriter.getBufferManager().setTaskId("taskId2");
+    taskToFailedBlockSendTracker.put("taskId2", new FailedBlockSendTracker());
     FakedDataPusher alwaysFailedDataPusher =
         new FakedDataPusher(
             event -> {
@@ -257,8 +259,9 @@
     manager.setDataPusher(alwaysFailedDataPusher);
 
     MutableList<Product2<String, String>> mockedData = createMockRecords();
+
     try {
-      rssShuffleWriterSpy.write(mockedData.iterator());
+      rssShuffleWriter.write(mockedData.iterator());
       fail();
     } catch (Exception e) {
       // ignore
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index d8d435b..0805dfe 100644
--- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -54,6 +54,7 @@
 import org.junit.jupiter.api.io.TempDir;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
@@ -663,7 +664,8 @@
         Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
         String appId,
         int shuffleId,
-        Set<Integer> failedPartitions) {
+        Set<Integer> failedPartitions,
+        PartitionDataReplicaRequirementTracking tracking) {
       return null;
     }
 
diff --git a/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java b/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java
new file mode 100644
index 0000000..02d5b62
--- /dev/null
+++ b/client/src/main/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTracking.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+/**
+ * Tracks, per partition, how many servers of each replica chain have been read successfully. It is
+ * used by {@link org.apache.uniffle.client.impl.ShuffleWriteClientImpl} to check whether the
+ * blockIds read from multi/single server(s) meet the min replica.
+ *
+ * <p>A replica counts as successful once the number of successful marks recorded for it reaches
+ * the number of servers in its chain. Marks are counted, not de-duplicated, so callers should mark
+ * each (partition, server) pair at most once.
+ */
+public class PartitionDataReplicaRequirementTracking {
+  private final int shuffleId;
+
+  // partitionId -> replicaIndex -> shuffleServerInfo
+  private final Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory;
+
+  // partitionId -> replicaIndex -> number of successful marks
+  private final Map<Integer, Map<Integer, Integer>> succeedList = new HashMap<>();
+
+  /**
+   * @param shuffleId the shuffle this tracking belongs to
+   * @param inventory partitionId -> replicaIndex -> servers holding that replica; kept by
+   *     reference, not copied
+   */
+  public PartitionDataReplicaRequirementTracking(
+      int shuffleId, Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory) {
+    this.shuffleId = shuffleId;
+    this.inventory = inventory;
+  }
+
+  /**
+   * Checks whether at least {@code minReplica} replicas of the partition are fully read.
+   *
+   * @param partitionId the partition to check
+   * @param minReplica the number of fully read replicas required
+   * @return true if enough replicas have as many successful marks as servers in their chain
+   */
+  public boolean isSatisfied(int partitionId, int minReplica) {
+    // replica index -> successful count
+    Map<Integer, Integer> succeedReplicas = succeedList.get(partitionId);
+    Map<Integer, List<ShuffleServerInfo>> replicaList = inventory.get(partitionId);
+    if (succeedReplicas == null || replicaList == null) {
+      // nothing recorded yet, or the partition is unknown: no replica is complete
+      return minReplica <= 0;
+    }
+
+    int replicaSuccessfulCnt = 0;
+    for (Map.Entry<Integer, Integer> succeedReplica : succeedReplicas.entrySet()) {
+      List<ShuffleServerInfo> expectedServers = replicaList.get(succeedReplica.getKey());
+      // a replica that is missing from the inventory can never be counted as complete
+      if (expectedServers != null && succeedReplica.getValue() >= expectedServers.size()) {
+        replicaSuccessfulCnt += 1;
+      }
+    }
+    return replicaSuccessfulCnt >= minReplica;
+  }
+
+  /**
+   * Records one successful read of the partition from the given server for every replica chain
+   * that contains the server. Partitions that are not part of the inventory are ignored.
+   */
+  public void markPartitionOfServerSuccessful(int partitionId, ShuffleServerInfo server) {
+    Map<Integer, List<ShuffleServerInfo>> replicaServerChains = inventory.get(partitionId);
+    if (replicaServerChains == null) {
+      // untracked partition: it stays unsatisfied instead of raising a NullPointerException
+      return;
+    }
+
+    Map<Integer, Integer> partitionRequirements =
+        succeedList.computeIfAbsent(partitionId, l -> new HashMap<>());
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : replicaServerChains.entrySet()) {
+      if (entry.getValue().contains(server)) {
+        partitionRequirements.merge(entry.getKey(), 1, Integer::sum);
+      }
+    }
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public Map<Integer, Map<Integer, List<ShuffleServerInfo>>> getInventory() {
+    return inventory;
+  }
+
+  @Override
+  public String toString() {
+    return "PartitionDataReplicaRequirementTracking{"
+        + "shuffleId="
+        + shuffleId
+        + ", inventory="
+        + inventory
+        + ", succeedList="
+        + succeedList
+        + '}';
+  }
+}
diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 88d97c3..7d8f533 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -24,6 +24,7 @@
 
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
@@ -103,7 +104,8 @@
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
       int shuffleId,
-      Set<Integer> failedPartitions);
+      Set<Integer> failedPartitions,
+      PartitionDataReplicaRequirementTracking replicaRequirementTracking);
 
   void close();
 
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 129dadc..42e60f3 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -21,6 +21,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -42,6 +43,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.api.ShuffleServerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
@@ -815,18 +817,19 @@
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
       String appId,
       int shuffleId,
-      Set<Integer> failedPartitions) {
-    Map<Integer, Integer> partitionReadSuccess = Maps.newHashMap();
+      Set<Integer> failedPartitions,
+      PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    Set<Integer> allRequestedPartitionIds = new HashSet<>();
     for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry : serverToPartitions.entrySet()) {
       ShuffleServerInfo shuffleServerInfo = entry.getKey();
       Set<Integer> requestPartitions = Sets.newHashSet();
       for (Integer partitionId : entry.getValue()) {
-        partitionReadSuccess.putIfAbsent(partitionId, 0);
-        if (partitionReadSuccess.get(partitionId) < replicaRead) {
+        if (!replicaRequirementTracking.isSatisfied(partitionId, replicaRead)) {
           requestPartitions.add(partitionId);
         }
       }
+      allRequestedPartitionIds.addAll(requestPartitions);
       RssGetShuffleResultForMultiPartRequest request =
           new RssGetShuffleResultForMultiPartRequest(
               appId, shuffleId, requestPartitions, blockIdLayout);
@@ -838,8 +841,8 @@
           Roaring64NavigableMap blockIdBitmapOfServer = response.getBlockIdBitmap();
           blockIdBitmap.or(blockIdBitmapOfServer);
           for (Integer partitionId : requestPartitions) {
-            Integer oldVal = partitionReadSuccess.get(partitionId);
-            partitionReadSuccess.put(partitionId, oldVal + 1);
+            replicaRequirementTracking.markPartitionOfServerSuccessful(
+                partitionId, shuffleServerInfo);
           }
         }
       } catch (Exception e) {
@@ -852,12 +855,15 @@
                 + "], shuffleId["
                 + shuffleId
                 + "], requestPartitions"
-                + requestPartitions);
+                + requestPartitions,
+            e);
       }
     }
     boolean isSuccessful =
-        partitionReadSuccess.entrySet().stream().allMatch(x -> x.getValue() >= replicaRead);
+        allRequestedPartitionIds.stream()
+            .allMatch(x -> replicaRequirementTracking.isSatisfied(x, replicaRead));
     if (!isSuccessful) {
+      LOG.error("Failed to meet replica requirement: {}", replicaRequirementTracking);
       throw new RssFetchFailedException(
           "Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]");
     }
diff --git a/client/src/test/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTrackingTest.java b/client/src/test/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTrackingTest.java
new file mode 100644
index 0000000..39554e1
--- /dev/null
+++ b/client/src/test/java/org/apache/uniffle/client/PartitionDataReplicaRequirementTrackingTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/** Unit tests for {@link PartitionDataReplicaRequirementTracking}. */
+public class PartitionDataReplicaRequirementTrackingTest {
+
+  /** One replica per partition, each replica served by a single server. */
+  @Test
+  public void testSingleReplicaWithSingleShuffleServer() {
+    // partitionId -> replicaIndex -> shuffleServerInfo
+    ShuffleServerInfo s1 = new ShuffleServerInfo("s1", "1.1.1.1", 2);
+    ShuffleServerInfo s2 = new ShuffleServerInfo("s2", "1.1.1.1", 3);
+
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = new HashMap<>();
+
+    Map<Integer, List<ShuffleServerInfo>> partition0 =
+        inventory.computeIfAbsent(0, x -> new HashMap<>());
+    partition0.put(0, Arrays.asList(s1));
+
+    Map<Integer, List<ShuffleServerInfo>> partition1 =
+        inventory.computeIfAbsent(1, x -> new HashMap<>());
+    partition1.put(0, Arrays.asList(s2));
+
+    PartitionDataReplicaRequirementTracking tracking =
+        new PartitionDataReplicaRequirementTracking(1, inventory);
+    assertFalse(tracking.isSatisfied(0, 1));
+    assertFalse(tracking.isSatisfied(1, 1));
+
+    tracking.markPartitionOfServerSuccessful(0, s1);
+    assertTrue(tracking.isSatisfied(0, 1));
+    assertFalse(tracking.isSatisfied(1, 1));
+
+    tracking.markPartitionOfServerSuccessful(1, s2);
+    assertTrue(tracking.isSatisfied(0, 1));
+    assertTrue(tracking.isSatisfied(1, 1));
+  }
+
+  /** One replica served by two servers: both must be marked before it is satisfied. */
+  @Test
+  public void testSingleReplicaWithMultiServers() {
+    // partitionId -> replicaIndex -> shuffleServerInfo
+    ShuffleServerInfo s1 = new ShuffleServerInfo("s1", "1.1.1.1", 2);
+    ShuffleServerInfo s2 = new ShuffleServerInfo("s2", "1.1.1.1", 3);
+
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = new HashMap<>();
+
+    int partitionId = 0;
+    Map<Integer, List<ShuffleServerInfo>> partition0 =
+        inventory.computeIfAbsent(partitionId, x -> new HashMap<>());
+    // replicaIdx -> shuffle-servers: a single replica (index 0) served by two servers
+    partition0.put(0, Arrays.asList(s1, s2));
+
+    PartitionDataReplicaRequirementTracking tracking =
+        new PartitionDataReplicaRequirementTracking(1, inventory);
+    assertFalse(tracking.isSatisfied(partitionId, 1));
+
+    // only one of the two servers is marked, so the replica is still incomplete.
+    tracking.markPartitionOfServerSuccessful(partitionId, s1);
+    assertFalse(tracking.isSatisfied(partitionId, 1));
+
+    // marking the remaining server completes the replica.
+    tracking.markPartitionOfServerSuccessful(partitionId, s2);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+  }
+
+  /** Three replicas, each served by a single server. */
+  @Test
+  public void testMultipleReplicaWithSingleServer() {
+    // partitionId -> replicaIndex -> shuffleServerInfo
+    ShuffleServerInfo s1 = new ShuffleServerInfo("s1", "1.1.1.1", 2);
+    ShuffleServerInfo s2 = new ShuffleServerInfo("s2", "1.1.1.1", 3);
+    ShuffleServerInfo s3 = new ShuffleServerInfo("s3", "1.1.1.1", 3);
+
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = new HashMap<>();
+    int partitionId = 1;
+
+    Map<Integer, List<ShuffleServerInfo>> partition1 =
+        inventory.computeIfAbsent(partitionId, x -> new HashMap<>());
+
+    // replicaIdx -> shuffle-servers
+    partition1.put(0, Arrays.asList(s1));
+    partition1.put(1, Arrays.asList(s2));
+    partition1.put(2, Arrays.asList(s3));
+
+    // partition1 has 3 replicas
+    PartitionDataReplicaRequirementTracking tracking =
+        new PartitionDataReplicaRequirementTracking(1, inventory);
+    assertFalse(tracking.isSatisfied(partitionId, 1));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s1);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+    assertFalse(tracking.isSatisfied(partitionId, 2));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s2);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+    assertTrue(tracking.isSatisfied(partitionId, 2));
+    assertFalse(tracking.isSatisfied(partitionId, 3));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s3);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+    assertTrue(tracking.isSatisfied(partitionId, 2));
+    assertTrue(tracking.isSatisfied(partitionId, 3));
+  }
+
+  /** Two replicas, each served by two servers. */
+  @Test
+  public void testMultipleReplicaWithMultiServers() {
+    ShuffleServerInfo s1 = new ShuffleServerInfo("s1", "1.1.1.1", 2);
+    ShuffleServerInfo s2 = new ShuffleServerInfo("s2", "1.1.1.1", 3);
+    ShuffleServerInfo s3 = new ShuffleServerInfo("s3", "1.1.1.1", 3);
+    ShuffleServerInfo s4 = new ShuffleServerInfo("s4", "1.1.1.1", 3);
+
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> inventory = new HashMap<>();
+    int partitionId = 0;
+
+    Map<Integer, List<ShuffleServerInfo>> partition0 =
+        inventory.computeIfAbsent(partitionId, x -> new HashMap<>());
+
+    // replicaIdx -> shuffle-servers
+    partition0.put(0, Arrays.asList(s1, s2));
+    partition0.put(1, Arrays.asList(s3, s4));
+
+    PartitionDataReplicaRequirementTracking tracking =
+        new PartitionDataReplicaRequirementTracking(1, inventory);
+    assertFalse(tracking.isSatisfied(partitionId, 1));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s1);
+    tracking.markPartitionOfServerSuccessful(partitionId, s3);
+    assertFalse(tracking.isSatisfied(partitionId, 1));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s2);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+    assertFalse(tracking.isSatisfied(partitionId, 2));
+
+    tracking.markPartitionOfServerSuccessful(partitionId, s4);
+    assertTrue(tracking.isSatisfied(partitionId, 1));
+    assertTrue(tracking.isSatisfied(partitionId, 2));
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 259c81e..4f3459d 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -114,11 +114,14 @@
     }
   }
 
-  private static ShuffleServerInfo convertToShuffleServerId(
+  /** Converts a proto {@code ShuffleServerId} into a {@link ShuffleServerInfo}. */
+  private static ShuffleServerInfo convertFromShuffleServerId(
       RssProtos.ShuffleServerId shuffleServerId) {
-    ShuffleServerInfo shuffleServerInfo =
-        new ShuffleServerInfo(
-            shuffleServerId.getId(), shuffleServerId.getIp(), shuffleServerId.getPort(), 0);
-    return shuffleServerInfo;
+    // id, ip, port and netty port are copied over unchanged
+    return new ShuffleServerInfo(
+        shuffleServerId.getId(),
+        shuffleServerId.getIp(),
+        shuffleServerId.getPort(),
+        shuffleServerId.getNettyPort());
   }
 
@@ -136,7 +139,8 @@
 
+  /** Converts every proto {@code ShuffleServerId} of the list, preserving the list order. */
   public static List<ShuffleServerInfo> fromProto(List<RssProtos.ShuffleServerId> servers) {
     return servers.stream()
-        .map(server -> convertToShuffleServerId(server))
+        .map(ShuffleServerInfo::convertFromShuffleServerId)
         .collect(Collectors.toList());
   }
 
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java
new file mode 100644
index 0000000..562320d
--- /dev/null
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.File;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.io.TempDir;
+
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.buffer.ShuffleBufferManager;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
+import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/** Integration test for the partition block data reassignment mechanism. */
+public class PartitionBlockDataReassignTest extends SparkSQLTest {
+
+  // comma separated storage directories handed to every shuffle server
+  private static String basePath;
+
+  /** Starts a coordinator and four shuffle servers, then marks one server as out of buffer. */
+  @BeforeAll
+  public static void setupServers(@TempDir File tmpDir) throws Exception {
+    // coordinator: MEMORY_LOCALFILE is published as the storage type through the dynamic conf
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+
+    // shuffle servers: all of them store under the same two directories of the temp dir
+    File firstDataDir = new File(tmpDir, "data1");
+    File secondDataDir = new File(tmpDir, "data2");
+    basePath = String.join(",", firstDataDir.getAbsolutePath(), secondDataDir.getAbsolutePath());
+
+    // two GRPC servers followed by two GRPC_NETTY servers
+    createShuffleServer(buildShuffleServerConf(ServerType.GRPC));
+    createShuffleServer(buildShuffleServerConf(ServerType.GRPC));
+    createShuffleServer(buildShuffleServerConf(ServerType.GRPC_NETTY));
+    createShuffleServer(buildShuffleServerConf(ServerType.GRPC_NETTY));
+
+    startServers();
+
+    // simulate one server without enough buffer
+    ShuffleServer exhaustedServer = grpcShuffleServers.get(0);
+    ShuffleBufferManager exhaustedBuffers = exhaustedServer.getShuffleBufferManager();
+    exhaustedBuffers.setUsedMemory(exhaustedBuffers.getCapacity() + 100);
+  }
+
+  /** Builds a MEMORY_LOCALFILE server conf of the given type rooted at basePath. */
+  private static ShuffleServerConf buildShuffleServerConf(ServerType serverType) throws Exception {
+    ShuffleServerConf conf = getShuffleServerConf(serverType);
+    conf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE.name());
+    conf.setString("rss.storage.basePath", basePath);
+    conf.setLong("rss.server.app.expired.withoutHeartbeat", 4000);
+    conf.setLong("rss.server.heartbeat.interval", 5000);
+    return conf;
+  }
+
+  /** Assigns a single shuffle server per shuffle and enables block send failure retry. */
+  @Override
+  public void updateRssStorage(SparkConf sparkConf) {
+    sparkConf.set("spark." + RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(), "true");
+    sparkConf.set("spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER, "1");
+  }
+
+  /** Waits 12 seconds, then asserts that every storage directory is empty. */
+  @Override
+  public void checkShuffleData() throws Exception {
+    Thread.sleep(12000);
+    for (String storageDir : basePath.split(",")) {
+      assertEquals(0, new File(storageDir).list().length);
+    }
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
index 74d508e..66c3288 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java
@@ -17,85 +17,29 @@
 
 package org.apache.uniffle.client.response;
 
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
-
-import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.proto.RssProtos;
 
+/** Client-side response carrying the shuffle handle info proto returned by the RPC. */
 public class RssPartitionToShuffleServerResponse extends ClientResponse {
-
-  private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
-  private Set<ShuffleServerInfo> shuffleServersForData;
-  private RemoteStorageInfo remoteStorageInfo;
+  private final RssProtos.ShuffleHandleInfo shuffleHandleInfoProto;
 
   public RssPartitionToShuffleServerResponse(
-      StatusCode statusCode,
-      String message,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
-      Set<ShuffleServerInfo> shuffleServersForData,
-      RemoteStorageInfo remoteStorageInfo) {
+      StatusCode statusCode, String message, RssProtos.ShuffleHandleInfo shuffleHandleInfoProto) {
     super(statusCode, message);
-    this.partitionToServers = partitionToServers;
-    this.remoteStorageInfo = remoteStorageInfo;
-    this.shuffleServersForData = shuffleServersForData;
+    this.shuffleHandleInfoProto = shuffleHandleInfoProto;
   }
 
-  public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
-    return partitionToServers;
-  }
-
-  public Set<ShuffleServerInfo> getShuffleServersForData() {
-    return shuffleServersForData;
-  }
-
-  public RemoteStorageInfo getRemoteStorageInfo() {
-    return remoteStorageInfo;
+  public RssProtos.ShuffleHandleInfo getShuffleHandleInfoProto() {
+    return shuffleHandleInfoProto;
   }
 
+  /** Wraps the proto response: status, message and the still-encoded shuffle handle info. */
   public static RssPartitionToShuffleServerResponse fromProto(
       RssProtos.PartitionToShuffleServerResponse response) {
-    Map<Integer, RssProtos.GetShuffleServerListResponse> partitionToShuffleServerMap =
-        response.getPartitionToShuffleServerMap();
-    Map<Integer, List<ShuffleServerInfo>> rpcPartitionToShuffleServerInfos = Maps.newHashMap();
-    Set<Map.Entry<Integer, RssProtos.GetShuffleServerListResponse>> entries =
-        partitionToShuffleServerMap.entrySet();
-    for (Map.Entry<Integer, RssProtos.GetShuffleServerListResponse> entry : entries) {
-      Integer partitionId = entry.getKey();
-      List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList();
-      List<? extends RssProtos.ShuffleServerIdOrBuilder> serversOrBuilderList =
-          entry.getValue().getServersOrBuilderList();
-      for (RssProtos.ShuffleServerIdOrBuilder shuffleServerIdOrBuilder : serversOrBuilderList) {
-        shuffleServerInfos.add(
-            new ShuffleServerInfo(
-                shuffleServerIdOrBuilder.getId(),
-                shuffleServerIdOrBuilder.getIp(),
-                shuffleServerIdOrBuilder.getPort(),
-                shuffleServerIdOrBuilder.getNettyPort()));
-      }
-
-      rpcPartitionToShuffleServerInfos.put(partitionId, shuffleServerInfos);
-    }
-    Set<ShuffleServerInfo> rpcShuffleServersForData = Sets.newHashSet();
-    for (List<ShuffleServerInfo> ssis : rpcPartitionToShuffleServerInfos.values()) {
-      rpcShuffleServersForData.addAll(ssis);
-    }
-    RssProtos.RemoteStorageInfo protoRemoteStorageInfo = response.getRemoteStorageInfo();
-    RemoteStorageInfo rpcRemoteStorageInfo =
-        new RemoteStorageInfo(
-            protoRemoteStorageInfo.getPath(), protoRemoteStorageInfo.getConfItemsMap());
     return new RssPartitionToShuffleServerResponse(
         StatusCode.valueOf(response.getStatus().name()),
         response.getMsg(),
-        rpcPartitionToShuffleServerInfos,
-        rpcShuffleServersForData,
-        rpcRemoteStorageInfo);
+        response.getShuffleHandleInfo());
   }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 9ac38b7..720002c 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -551,9 +551,28 @@
 
 message PartitionToShuffleServerResponse {
   StatusCode status = 1;
-  map<int32,GetShuffleServerListResponse> partitionToShuffleServer = 2;
-  RemoteStorageInfo remote_storage_info = 3;
+  // 2 and 3 used to carry partitionToShuffleServer and remote_storage_info. They stay reserved
+  // so the numbers are never reused with another type, which would break wire compatibility.
+  reserved 2, 3;
   string msg = 4;
+  ShuffleHandleInfo shuffleHandleInfo = 5;
+}
+
+// Server assignment of one shuffle together with its remote storage info.
+message ShuffleHandleInfo {
+  int32 shuffleId = 1;
+  // presumably partitionId -> replica servers of that partition
+  map<int32, PartitionReplicaServers> partitionToServers = 2;
+  RemoteStorageInfo remoteStorageInfo = 3;
+}
+
+message PartitionReplicaServers {
+  // presumably replicaIndex -> servers holding that replica
+  map<int32, ReplicaServersItem> replicaServers = 1;
+}
+
+message ReplicaServersItem {
+  repeated ShuffleServerId serverId = 1;
 }
 
 message RemoteStorageInfo{
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index 35ef3fa..578b6c4 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -712,4 +712,13 @@
     }
     return false;
   }
+
+  /**
+   * Overwrites the used-memory counter with the given value.
+   *
+   * <p>NOTE(review): appears to exist so tests can simulate a server without free buffer.
+   */
+  public void setUsedMemory(long usedMemory) {
+    this.usedMemory.set(usedMemory);
+  }
 }