[#1373][FOLLOWUP] fix(spark): register with incorrect partitionRanges after reassign (#1612)

### What changes were proposed in this pull request?

Fix partition id inconsistency when reassigning to a new shuffle server.

For example:
When writing data on node a1, the registered partition id is 1003.
Node a1 fails, so node b1 is reassigned and registered as the replacement shuffle server, but the assignment request covers only one partition with partitionNumPerRange = 1, so b1 is registered for partition 0 only.
When data for partition 1003 is then written to b1, a NO_REGISTER exception is thrown.
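
A minimal sketch of the failure mode, using a simplified stand-in for `org.apache.uniffle.common.PartitionRange` (the class name is real, but the fields and the `include` check here are assumptions for illustration):

```java
import java.util.Collections;
import java.util.List;

public class RangeMismatchSketch {
  // Simplified stand-in for org.apache.uniffle.common.PartitionRange.
  static class PartitionRange {
    final int start;
    final int end;

    PartitionRange(int start, int end) {
      this.start = start;
      this.end = end;
    }

    boolean include(int partitionId) {
      return partitionId >= start && partitionId <= end;
    }
  }

  public static void main(String[] args) {
    // Before this fix: the reassignment requested a single partition with
    // partitionNumPerRange = 1, so the replacement b1 was registered for [0, 0] only.
    List<PartitionRange> registeredOnB1 =
        Collections.singletonList(new PartitionRange(0, 0));

    int writtenPartitionId = 1003; // partition previously registered on the faulty a1
    boolean covered =
        registeredOnB1.stream().anyMatch(r -> r.include(writtenPartitionId));
    // Prints NO_REGISTER: no registered range on b1 covers partition 1003.
    System.out.println(covered ? "OK" : "NO_REGISTER");
  }
}
```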

### Why are the changes needed?

Fix: (#1373)

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

---------

Co-authored-by: shun01.ding <shun01.ding@vipshop.com>
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1b4df17..6d9487c 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -18,7 +18,10 @@
 package org.apache.spark.shuffle;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -27,6 +30,7 @@
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -1157,7 +1161,8 @@
               1,
               requiredShuffleServerNumber,
               estimateTaskConcurrency,
-              failuresShuffleServerIds);
+              failuresShuffleServerIds,
+              null);
       /**
        * we need to clear the metadata of the completed task, otherwise some of the stage's data
        * will be lost
@@ -1196,24 +1201,54 @@
       }
 
       // get the newer server to replace faulty server.
-      ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
+      ShuffleServerInfo newAssignedServer =
+          reassignShuffleServerForTask(shuffleId, partitionIds, faultyShuffleServerId);
       if (newAssignedServer != null) {
         handleInfo.createNewReassignmentForMultiPartitions(
             partitionIds, faultyShuffleServerId, newAssignedServer);
       }
+      LOG.info(
+          "Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
+          faultyShuffleServerId,
+          newAssignedServer,
+          shuffleId,
+          partitionIds);
       return newAssignedServer;
     }
   }
 
-  private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
+  private ShuffleServerInfo reassignShuffleServerForTask(
+      int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
     Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
     faultyServerIds.addAll(failuresShuffleServerIds);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers =
-        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
-    if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) {
-      return partitionToServers.get(0).get(0);
-    }
-    return null;
+    AtomicReference<ShuffleServerInfo> replacementRef = new AtomicReference<>();
+    requestShuffleAssignment(
+        shuffleId,
+        1,
+        1,
+        1,
+        1,
+        faultyServerIds,
+        shuffleAssignmentsInfo -> {
+          if (shuffleAssignmentsInfo == null) {
+            return null;
+          }
+          Optional<List<ShuffleServerInfo>> replacementOpt =
+              shuffleAssignmentsInfo.getPartitionToServers().values().stream().findFirst();
+          ShuffleServerInfo replacement = replacementOpt.get().get(0);
+          replacementRef.set(replacement);
+
+          Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
+          List<PartitionRange> partitionRanges = new ArrayList<>();
+          for (Integer partitionId : partitionIds) {
+            newPartitionToServers.put(partitionId, Arrays.asList(replacement));
+            partitionRanges.add(new PartitionRange(partitionId, partitionId));
+          }
+          Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+          serverToPartitionRanges.put(replacement, partitionRanges);
+          return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
+        });
+    return replacementRef.get();
   }
 
   private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
@@ -1222,7 +1257,8 @@
       int partitionNumPerRange,
       int assignmentShuffleServerNumber,
       int estimateTaskConcurrency,
-      Set<String> faultyServerIds) {
+      Set<String> faultyServerIds,
+      Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler) {
     Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
     ClientUtils.validateClientType(clientType);
     assignmentTags.add(clientType);
@@ -1242,6 +1278,9 @@
                     assignmentShuffleServerNumber,
                     estimateTaskConcurrency,
                     faultyServerIds);
+            if (reassignmentHandler != null) {
+              response = reassignmentHandler.apply(response);
+            }
             registerShuffleServers(
                 id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
             return response.getPartitionToServers();
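
For reference, a self-contained sketch of what the new `reassignmentHandler` does, using simplified stand-ins for `ShuffleAssignmentsInfo`, `ShuffleServerInfo`, and `PartitionRange` (the real classes live in the Uniffle client/common modules; the constructors and fields below are assumptions for illustration):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

public class ReassignmentHandlerSketch {
  // Simplified stand-ins for the Uniffle types used in the diff above.
  static class ShuffleServerInfo {
    final String id;
    ShuffleServerInfo(String id) { this.id = id; }
    @Override public String toString() { return id; }
  }

  static class PartitionRange {
    final int start, end;
    PartitionRange(int start, int end) { this.start = start; this.end = end; }
    @Override public String toString() { return "[" + start + ", " + end + "]"; }
  }

  static class ShuffleAssignmentsInfo {
    final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    final Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges;

    ShuffleAssignmentsInfo(
        Map<Integer, List<ShuffleServerInfo>> partitionToServers,
        Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges) {
      this.partitionToServers = partitionToServers;
      this.serverToPartitionRanges = serverToPartitionRanges;
    }
  }

  // Rewrites the coordinator's single-partition answer so the replacement
  // server is registered with a singleton range per real partition id.
  static Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> handlerFor(
      Set<Integer> partitionIds) {
    return info -> {
      if (info == null) {
        return null;
      }
      ShuffleServerInfo replacement =
          info.partitionToServers.values().iterator().next().get(0);
      Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
      List<PartitionRange> ranges = new ArrayList<>();
      for (Integer partitionId : partitionIds) {
        newPartitionToServers.put(partitionId, Collections.singletonList(replacement));
        ranges.add(new PartitionRange(partitionId, partitionId));
      }
      return new ShuffleAssignmentsInfo(
          newPartitionToServers, Collections.singletonMap(replacement, ranges));
    };
  }

  public static void main(String[] args) {
    ShuffleServerInfo b1 = new ShuffleServerInfo("b1");
    // What the coordinator returns for the 1-partition reassignment request.
    ShuffleAssignmentsInfo fromCoordinator =
        new ShuffleAssignmentsInfo(
            Collections.singletonMap(0, Collections.singletonList(b1)),
            Collections.singletonMap(
                b1, Collections.singletonList(new PartitionRange(0, 0))));
    ShuffleAssignmentsInfo rewritten =
        handlerFor(Collections.singleton(1003)).apply(fromCoordinator);
    // b1 is now registered for [1003, 1003] instead of [0, 0].
    System.out.println(rewritten.serverToPartitionRanges);
  }
}
```

Because the handler runs before `registerShuffleServers`, the rewritten `serverToPartitionRanges` is what the replacement server actually registers, so subsequent writes for the original partition ids pass the server-side registration check.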