[#1608][part-2] fix(spark): avoid releasing block in advance when enable block resend (#1610)

### What changes were proposed in this pull request?

1. avoid releasing block previously when enable block resend
2. introduce the block max retry times

### Why are the changes needed?

For: #1608

In the current codebase for partition reassignment, it has some bugs as follows
1. data has been released when resending.
2. if the blocks fail to resend, it may fast fail without retry again

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`RssShuffleWriterTest#blockFailureResendTest` is to test the resending block mechanism.
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index 5a93c2b..9751ba0 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -34,14 +34,6 @@
     this.processedCallbackChain = new ArrayList<>();
   }
 
-  public AddBlockEvent(
-      String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable callback) {
-    this.taskId = taskId;
-    this.shuffleDataInfoList = shuffleBlockInfoList;
-    this.processedCallbackChain = new ArrayList<>();
-    addCallback(callback);
-  }
-
   /** @param callback, should not throw any exception and execute fast. */
   public void addCallback(Runnable callback) {
     processedCallbackChain.add(callback);
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
new file mode 100644
index 0000000..116d194
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockFailureCallback {
+  void onBlockFailure(ShuffleBlockInfo block);
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
new file mode 100644
index 0000000..2b5dc0d
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockSuccessCallback {
+  void onBlockSuccess(ShuffleBlockInfo block);
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 30f649f..1517b71 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -88,14 +88,23 @@
         () -> {
           String taskId = event.getTaskId();
           List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
+          SendShuffleDataResult result = null;
           try {
-            SendShuffleDataResult result =
+            result =
                 shuffleWriteClient.sendShuffleData(
                     rssAppId, shuffleBlockInfoList, () -> !isValidTask(taskId));
             putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
             putFailedBlockSendTracker(
                 taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
           } finally {
+            Set<Long> succeedBlockIds =
+                result.getSuccessBlockIds() == null
+                    ? Collections.emptySet()
+                    : result.getSuccessBlockIds();
+            for (ShuffleBlockInfo block : shuffleBlockInfoList) {
+              block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+            }
+
             List<Runnable> callbackChain =
                 Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
             for (Runnable runnable : callbackChain) {
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index d826104..efe376a 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -408,14 +408,18 @@
     }
   }
 
+  public void releaseBlockResource(ShuffleBlockInfo block) {
+    this.freeAllocatedMemory(block.getFreeMemory());
+    block.getData().release();
+  }
+
   public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
     long totalSize = 0;
-    long memoryUsed = 0;
     List<AddBlockEvent> events = new ArrayList<>();
     List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
     for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+      sbi.withCompletionCallback((block, isSuccessful) -> this.releaseBlockResource(block));
       totalSize += sbi.getSize();
-      memoryUsed += sbi.getFreeMemory();
       shuffleBlockInfosPerEvent.add(sbi);
       // split shuffle data according to the size
       if (totalSize > sendSizeLimit) {
@@ -427,20 +431,9 @@
                   + totalSize
                   + " bytes");
         }
-        // Use final temporary variables for closures
-        final long memoryUsedTemp = memoryUsed;
-        final List<ShuffleBlockInfo> shuffleBlocksTemp = shuffleBlockInfosPerEvent;
-        events.add(
-            new AddBlockEvent(
-                taskId,
-                shuffleBlockInfosPerEvent,
-                () -> {
-                  freeAllocatedMemory(memoryUsedTemp);
-                  shuffleBlocksTemp.stream().forEach(x -> x.getData().release());
-                }));
+        events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
         shuffleBlockInfosPerEvent = Lists.newArrayList();
         totalSize = 0;
-        memoryUsed = 0;
       }
     }
     if (!shuffleBlockInfosPerEvent.isEmpty()) {
@@ -453,16 +446,7 @@
                 + " bytes");
       }
       // Use final temporary variables for closures
-      final long memoryUsedTemp = memoryUsed;
-      final List<ShuffleBlockInfo> shuffleBlocksTemp = shuffleBlockInfosPerEvent;
-      events.add(
-          new AddBlockEvent(
-              taskId,
-              shuffleBlockInfosPerEvent,
-              () -> {
-                freeAllocatedMemory(memoryUsedTemp);
-                shuffleBlocksTemp.stream().forEach(x -> x.getData().release());
-              }));
+      events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
     }
     return events;
   }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 38ebbbd..22143bc 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -371,6 +371,9 @@
           long sum = 0L;
           List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
           for (AddBlockEvent event : events) {
+            for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+              block.executeCompletionCallback(true);
+            }
             event.getProcessedCallbackChain().stream().forEach(x -> x.run());
             sum += event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
           }
@@ -413,6 +416,9 @@
                             // ignore.
                           }
                         }
+                        for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                          block.executeCompletionCallback(true);
+                        }
                         event.getProcessedCallbackChain().stream().forEach(x -> x.run());
                         sum +=
                             event.getShuffleDataInfoList().stream()
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0b4faef..1b4df17 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -1264,4 +1264,9 @@
   public boolean isRssResubmitStage() {
     return rssResubmitStage;
   }
+
+  @VisibleForTesting
+  public void setDataPusher(DataPusher dataPusher) {
+    this.dataPusher = dataPusher;
+  }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 635b359..8a22b73 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -21,10 +21,10 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
@@ -46,6 +46,7 @@
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
@@ -83,6 +84,7 @@
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
@@ -94,7 +96,7 @@
   private final String appId;
   private final int shuffleId;
   private WriteBufferManager bufferManager;
-  private final String taskId;
+  private String taskId;
   private final int numMaps;
   private final ShuffleDependency<K, V, C> shuffleDependency;
   private final Partitioner partitioner;
@@ -113,7 +115,8 @@
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
   private SparkConf sparkConf;
-  private boolean blockSendFailureRetryEnabled;
+  private boolean blockFailSentRetryEnabled;
+  private int blockFailSentRetryMaxTimes = 1;
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
@@ -122,7 +125,9 @@
 
   private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>();
 
-  private final Map<String, ShuffleServerInfo> faultyServers = new HashMap<>();
+  // shuffleServerId -> failoverShuffleServer
+  private final Map<String, ShuffleServerInfo> replacementShuffleServers =
+      JavaUtils.newConcurrentMap();
 
   // Only for tests
   @VisibleForTesting
@@ -192,7 +197,7 @@
     this.taskFailureCallback = taskFailureCallback;
     this.taskContext = context;
     this.sparkConf = sparkConf;
-    this.blockSendFailureRetryEnabled =
+    this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(),
@@ -269,8 +274,8 @@
     long recordCount = 0;
     while (records.hasNext()) {
       recordCount++;
-      // Task should fast fail when sending data failed
-      checkIfBlocksFailed();
+
+      checkDataIfAnyFailure();
 
       Product2<K, V> record = records.next();
       K key = record._1();
@@ -363,6 +368,17 @@
       List<ShuffleBlockInfo> shuffleBlockInfoList) {
     List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      if (blockFailSentRetryEnabled) {
+        // do nothing if failed.
+        for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+          block.withCompletionCallback(
+              (completionBlock, isSuccessful) -> {
+                if (isSuccessful) {
+                  bufferManager.releaseBlockResource(completionBlock);
+                }
+              });
+        }
+      }
       event.addCallback(
           () -> {
             boolean ret = finishEventQueue.add(new Object());
@@ -386,7 +402,7 @@
       while (true) {
         try {
           finishEventQueue.clear();
-          checkIfBlocksFailed();
+          checkDataIfAnyFailure();
           Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
           blockIds.removeAll(successBlockIds);
           if (blockIds.isEmpty()) {
@@ -422,105 +438,128 @@
     }
   }
 
-  private void checkIfBlocksFailed() {
-    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
-    if (blockSendFailureRetryEnabled && !failedBlockIds.isEmpty()) {
-      Set<TrackingBlockStatus> shouldResendBlockSet = shouldResendBlockStatusSet(failedBlockIds);
-      try {
-        reSendFailedBlockIds(shouldResendBlockSet);
-      } catch (Exception e) {
-        LOG.error("resend failed blocks failed.", e);
+  private void checkDataIfAnyFailure() {
+    if (blockFailSentRetryEnabled) {
+      collectFailedBlocksToResend();
+    } else {
+      if (hasAnyBlockFailure()) {
+        throw new RssSendFailedException("Fail to send the block");
       }
-      failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     }
+  }
+
+  private boolean hasAnyBlockFailure() {
+    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     if (!failedBlockIds.isEmpty()) {
-      String errorMsg =
-          "Send failed: Task["
-              + taskId
-              + "]"
-              + " failed because "
-              + failedBlockIds.size()
-              + " blocks can't be sent to shuffle server: "
-              + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
-      LOG.error(errorMsg);
-      throw new RssSendFailedException(errorMsg);
+      LOG.error(
+          "Errors on sending blocks for task[{}]. {} blocks can't be sent to remote servers: {}",
+          taskId,
+          failedBlockIds.size(),
+          shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers());
+      return true;
     }
+    return false;
   }
 
-  private Set<TrackingBlockStatus> shouldResendBlockStatusSet(Set<Long> failedBlockIds) {
-    FailedBlockSendTracker failedBlockTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
-    Set<TrackingBlockStatus> resendBlockStatusSet = Sets.newHashSet();
-    for (Long failedBlockId : failedBlockIds) {
-      failedBlockTracker.getFailedBlockStatus(failedBlockId).stream()
-          // todo: more status need reassign
-          .filter(
-              trackingBlockStatus -> trackingBlockStatus.getStatusCode() == StatusCode.NO_BUFFER)
-          .forEach(trackingBlockStatus -> resendBlockStatusSet.add(trackingBlockStatus));
+  private void collectFailedBlocksToResend() {
+    if (!blockFailSentRetryEnabled) {
+      return;
     }
-    return resendBlockStatusSet;
+
+    FailedBlockSendTracker failedTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
+    Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
+    if (CollectionUtils.isEmpty(failedBlockIds)) {
+      return;
+    }
+
+    boolean isFastFail = false;
+    Set<TrackingBlockStatus> resendCandidates = new HashSet<>();
+    // to check whether the blocks resent exceed the max resend count.
+    for (Long blockId : failedBlockIds) {
+      List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
+      int retryIndex = failedBlockStatus.get(0).getShuffleBlockInfo().getRetryCnt();
+      // todo: support retry times by config
+      if (retryIndex >= blockFailSentRetryMaxTimes) {
+        LOG.error(
+            "Partial blocks for taskId: [{}] retry exceeding the max retry times: [{}]. Fast fail! faulty server list: {}",
+            taskId,
+            blockFailSentRetryMaxTimes,
+            failedBlockStatus.stream()
+                .map(x -> x.getShuffleServerInfo())
+                .collect(Collectors.toSet()));
+        isFastFail = true;
+        break;
+      }
+
+      // todo: if setting multi replica and another replica is succeed to send, no need to resend
+      resendCandidates.addAll(failedBlockStatus);
+    }
+
+    if (isFastFail) {
+      // release data and allocated memory
+      for (Long blockId : failedBlockIds) {
+        List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
+        Optional<TrackingBlockStatus> blockStatus = failedBlockStatus.stream().findFirst();
+        if (blockStatus.isPresent()) {
+          blockStatus.get().getShuffleBlockInfo().executeCompletionCallback(true);
+        }
+      }
+
+      throw new RssSendFailedException(
+          "Errors on resending the blocks data to the remote shuffle-server.");
+    }
+
+    resendFailedBlocks(resendCandidates);
   }
 
-  private void reSendFailedBlockIds(Set<TrackingBlockStatus> failedBlockStatusSet) {
-    List<ShuffleBlockInfo> reAssignSeverBlockInfoList = Lists.newArrayList();
-    List<ShuffleBlockInfo> failedBlockInfoList = Lists.newArrayList();
+  private void resendFailedBlocks(Set<TrackingBlockStatus> failedBlockStatusSet) {
+    List<ShuffleBlockInfo> reassignBlocks = Lists.newArrayList();
     Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions =
         failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));
-    faultyServerToPartitions.entrySet().stream()
+
+    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
+        faultyServerToPartitions.entrySet()) {
+      Set<Integer> partitionIds =
+          entry.getValue().stream()
+              .map(x -> x.getShuffleBlockInfo().getPartitionId())
+              .collect(Collectors.toSet());
+      ShuffleServerInfo replacement = replacementShuffleServers.get(entry.getKey().getId());
+      if (replacement == null) {
+        // todo: merge multiple requests into one.
+        replacement = reassignFaultyShuffleServer(partitionIds, entry.getKey().getId());
+        replacementShuffleServers.put(entry.getKey().getId(), replacement);
+      }
+
+      for (TrackingBlockStatus blockStatus : failedBlockStatusSet) {
+        // clear the previous retry state of block
+        ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+        clearFailedBlockState(block);
+
+        final ShuffleBlockInfo newBlock = block;
+        newBlock.incrRetryCnt();
+        newBlock.reassignShuffleServers(Arrays.asList(replacement));
+
+        reassignBlocks.add(newBlock);
+      }
+    }
+
+    processShuffleBlockInfos(reassignBlocks);
+  }
+
+  private void clearFailedBlockState(ShuffleBlockInfo block) {
+    shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
+    block.getShuffleServerInfos().stream()
+        .filter(s -> replacementShuffleServers.containsKey(s.getId()))
         .forEach(
-            t -> {
-              Set<Integer> partitionIds =
-                  t.getValue().stream()
-                      .map(x -> x.getShuffleBlockInfo().getPartitionId())
-                      .collect(Collectors.toSet());
-              ShuffleServerInfo dynamicShuffleServer = faultyServers.get(t.getKey().getId());
-              if (dynamicShuffleServer == null) {
-                dynamicShuffleServer =
-                    reAssignFaultyShuffleServer(partitionIds, t.getKey().getId());
-                faultyServers.put(t.getKey().getId(), dynamicShuffleServer);
-              }
-
-              ShuffleServerInfo finalDynamicShuffleServer = dynamicShuffleServer;
-              failedBlockStatusSet.forEach(
-                  trackingBlockStatus -> {
-                    ShuffleBlockInfo failedBlockInfo = trackingBlockStatus.getShuffleBlockInfo();
-                    failedBlockInfoList.add(failedBlockInfo);
-                    reAssignSeverBlockInfoList.add(
-                        new ShuffleBlockInfo(
-                            failedBlockInfo.getShuffleId(),
-                            failedBlockInfo.getPartitionId(),
-                            failedBlockInfo.getBlockId(),
-                            failedBlockInfo.getLength(),
-                            failedBlockInfo.getCrc(),
-                            failedBlockInfo.getData(),
-                            Lists.newArrayList(finalDynamicShuffleServer),
-                            failedBlockInfo.getUncompressLength(),
-                            failedBlockInfo.getFreeMemory(),
-                            taskAttemptId));
-                  });
-            });
-    clearFailedBlockIdsStates(failedBlockInfoList, faultyServers);
-    processShuffleBlockInfos(reAssignSeverBlockInfoList);
-    checkIfBlocksFailed();
+            s ->
+                serverToPartitionToBlockIds
+                    .get(s)
+                    .get(block.getPartitionId())
+                    .remove(block.getBlockId()));
+    partitionLengths[block.getPartitionId()] -= block.getLength();
   }
 
-  private void clearFailedBlockIdsStates(
-      List<ShuffleBlockInfo> failedBlockInfoList, Map<String, ShuffleServerInfo> faultyServers) {
-    failedBlockInfoList.forEach(
-        shuffleBlockInfo -> {
-          shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(shuffleBlockInfo.getBlockId());
-          shuffleBlockInfo.getShuffleServerInfos().stream()
-              .filter(s -> faultyServers.containsKey(s.getId()))
-              .forEach(
-                  s ->
-                      serverToPartitionToBlockIds
-                          .get(s)
-                          .get(shuffleBlockInfo.getPartitionId())
-                          .remove(shuffleBlockInfo.getBlockId()));
-          partitionLengths[shuffleBlockInfo.getPartitionId()] -= shuffleBlockInfo.getLength();
-        });
-  }
-
-  private ShuffleServerInfo reAssignFaultyShuffleServer(
+  private ShuffleServerInfo reassignFaultyShuffleServer(
       Set<Integer> partitionIds, String faultyServerId) {
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     String driver = rssConf.getString("driver.host", "");
@@ -611,6 +650,17 @@
         return Option.empty();
       }
     } finally {
+      if (blockFailSentRetryEnabled) {
+        if (success) {
+          if (CollectionUtils.isNotEmpty(shuffleManager.getFailedBlockIds(taskId))) {
+            LOG.error(
+                "Errors on stopping writer due to the remaining failed blockIds. This should not happen.");
+            return Option.empty();
+          }
+        } else {
+          shuffleManager.getBlockIdsFailedSendTracker(taskId).clearAndReleaseBlockResources();
+        }
+      }
       // free all memory & metadata, or memory leak happen in executor
       if (bufferManager != null) {
         bufferManager.freeAllMemory();
@@ -694,4 +744,29 @@
     }
     throw new RssException(e);
   }
+
+  @VisibleForTesting
+  protected void enableBlockFailSentRetry() {
+    this.blockFailSentRetryEnabled = true;
+  }
+
+  @VisibleForTesting
+  protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) {
+    this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+  }
+
+  @VisibleForTesting
+  protected void addReassignmentShuffleServer(String shuffleId, ShuffleServerInfo replacement) {
+    replacementShuffleServers.put(shuffleId, replacement);
+  }
+
+  @VisibleForTesting
+  protected void setTaskId(String taskId) {
+    this.taskId = taskId;
+  }
+
+  @VisibleForTesting
+  protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> getServerToPartitionToBlockIds() {
+    return serverToPartitionToBlockIds;
+  }
 }
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index b68d4b7..5ca85ec 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -26,6 +26,7 @@
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -64,6 +65,7 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
@@ -73,6 +75,198 @@
 
 public class RssShuffleWriterTest {
 
+  private MutableList<Product2<String, String>> createMockRecords() {
+    MutableList<Product2<String, String>> data = new MutableList<>();
+    data.appendElem(new Tuple2<>("testKey2", "testValue2"));
+    data.appendElem(new Tuple2<>("testKey3", "testValue3"));
+    data.appendElem(new Tuple2<>("testKey4", "testValue4"));
+    data.appendElem(new Tuple2<>("testKey6", "testValue6"));
+    data.appendElem(new Tuple2<>("testKey1", "testValue1"));
+    data.appendElem(new Tuple2<>("testKey5", "testValue5"));
+    return data;
+  }
+
+  @Test
+  public void blockFailureResendTest() throws Exception {
+    SparkConf conf = new SparkConf();
+    conf.setAppName("testApp")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
+
+    List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
+    Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+    taskToFailedBlockSendTracker.put("taskId", new FailedBlockSendTracker());
+
+    AtomicInteger sentFailureCnt = new AtomicInteger();
+    FakedDataPusher dataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId", event.getTaskId());
+              FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1") && block.getRetryCnt() == 0) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  sentFailureCnt.addAndGet(1);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
+                  successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                  shuffleBlockInfos.add(block);
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+
+    final RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, dataPusher, successBlockIds, taskToFailedBlockSendTracker);
+    Serializer kryoSerializer = new KryoSerializer(conf);
+    Partitioner mockPartitioner = mock(Partitioner.class);
+    final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
+    RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
+    when(mockHandle.getDependency()).thenReturn(mockDependency);
+    when(mockDependency.serializer()).thenReturn(kryoSerializer);
+    when(mockDependency.partitioner()).thenReturn(mockPartitioner);
+    when(mockPartitioner.numPartitions()).thenReturn(3);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
+    List<ShuffleServerInfo> ssi12 =
+        Arrays.asList(
+            new ShuffleServerInfo("id1", "0.0.0.1", 100),
+            new ShuffleServerInfo("id2", "0.0.0.2", 100));
+    partitionToServers.put(0, ssi12);
+    List<ShuffleServerInfo> ssi34 =
+        Arrays.asList(
+            new ShuffleServerInfo("id3", "0.0.0.3", 100),
+            new ShuffleServerInfo("id4", "0.0.0.4", 100));
+    partitionToServers.put(1, ssi34);
+    List<ShuffleServerInfo> ssi56 =
+        Arrays.asList(
+            new ShuffleServerInfo("id5", "0.0.0.5", 100),
+            new ShuffleServerInfo("id6", "0.0.0.6", 100));
+    partitionToServers.put(2, ssi56);
+    when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
+    WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            0,
+            0,
+            bufferOptions,
+            kryoSerializer,
+            partitionToServers,
+            mockTaskMemoryManager,
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(conf));
+    bufferManager.setTaskId("taskId");
+
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    RssShuffleWriter<String, String, String> rssShuffleWriter =
+        new RssShuffleWriter<>(
+            "appId",
+            0,
+            "taskId",
+            1L,
+            bufferManagerSpy,
+            shuffleWriteMetrics,
+            manager,
+            conf,
+            mockShuffleWriteClient,
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
+    rssShuffleWriter.enableBlockFailSentRetry();
+    doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
+    ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 100);
+    rssShuffleWriter.addReassignmentShuffleServer("id1", replacement);
+
+    RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
+    doNothing().when(rssShuffleWriterSpy).sendCommit();
+
+    // case 1. failed blocks will be resent
+    MutableList<Product2<String, String>> data = createMockRecords();
+    rssShuffleWriterSpy.write(data.iterator());
+
+    Awaitility.await()
+        .timeout(Duration.ofSeconds(5))
+        .until(() -> successBlockIds.get("taskId").size() == data.size());
+    assertEquals(2, sentFailureCnt.get());
+    assertEquals(0, taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().size());
+    assertEquals(6, shuffleWriteMetrics.recordsWritten());
+    assertEquals(
+        shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+        shuffleWriteMetrics.bytesWritten());
+    assertEquals(6, shuffleBlockInfos.size());
+
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+
+    // check the blockId -> servers mapping.
+    // server -> partitionId -> blockIds
+    Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
+        rssShuffleWriterSpy.getServerToPartitionToBlockIds();
+    assertEquals(2, serverToPartitionToBlockIds.get(replacement).get(0).size());
+
+    // case2. If exceeding the max retry times, it will fast fail.
+    rssShuffleWriterSpy.setBlockFailSentRetryMaxTimes(1);
+    rssShuffleWriterSpy.setTaskId("taskId2");
+    FakedDataPusher alwaysFailedDataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId2", event.getTaskId());
+              FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1")) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
+                  successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+    manager.setDataPusher(alwaysFailedDataPusher);
+
+    MutableList<Product2<String, String>> mockedData = createMockRecords();
+    try {
+      rssShuffleWriterSpy.write(mockedData.iterator());
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+  }
+
   @Test
   public void checkBlockSendResultTest() {
     SparkConf conf = new SparkConf();
@@ -161,8 +355,7 @@
         assertThrows(
             RuntimeException.class,
             () -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
-    System.out.println(e2.getMessage());
-    assertTrue(e3.getMessage().startsWith("Send failed:"));
+    assertTrue(e3.getMessage().startsWith("Fail to send the block"));
     successBlocks.clear();
     taskToFailedBlockSendTracker.clear();
   }
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index 0c239c7..93e20dd 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -32,6 +32,12 @@
 
 public class FailedBlockSendTracker {
 
+  /**
+   * blockId -> list(trackingStatus)
+   *
+   * <p>This indicates the blockId latest sending status, and it will not store the resending
+   * history. The list data structure is to describe the multiple servers for the multiple replica
+   */
   private Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap;
 
   public FailedBlockSendTracker() {
@@ -55,7 +61,10 @@
     trackingBlockStatusMap.remove(blockId);
   }
 
-  public void clear() {
+  public void clearAndReleaseBlockResources() {
+    trackingBlockStatusMap.values().stream()
+        .flatMap(x -> x.stream())
+        .forEach(x -> x.getShuffleBlockInfo().executeCompletionCallback(true));
     trackingBlockStatusMap.clear();
   }
 
diff --git a/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
new file mode 100644
index 0000000..01ba694
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+public interface BlockCompletionCallback {
+  void onBlockCompletion(ShuffleBlockInfo block, boolean isSuccessful);
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 8de75d9..36dec5e 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -36,6 +36,9 @@
   private List<ShuffleServerInfo> shuffleServerInfos;
   private int uncompressLength;
   private long freeMemory;
+  private int retryCnt = 0;
+
+  private transient BlockCompletionCallback completionCallback;
 
   public ShuffleBlockInfo(
       int shuffleId,
@@ -153,7 +156,30 @@
     return sb.toString();
   }
 
+  public void incrRetryCnt() {
+    this.retryCnt += 1;
+  }
+
+  public int getRetryCnt() {
+    return retryCnt;
+  }
+
+  public void reassignShuffleServers(List<ShuffleServerInfo> replacements) {
+    this.shuffleServerInfos = replacements;
+  }
+
   public synchronized void copyDataTo(ByteBuf to) {
     ByteBufUtils.copyByteBuf(data, to);
   }
+
+  public void withCompletionCallback(BlockCompletionCallback callback) {
+    this.completionCallback = callback;
+  }
+
+  public void executeCompletionCallback(boolean isSuccessful) {
+    if (completionCallback == null) {
+      return;
+    }
+    completionCallback.onBlockCompletion(this, isSuccessful);
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
new file mode 100644
index 0000000..2a46387
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.function;
+
+@FunctionalInterface
+public interface TupleConsumer<T, F> {
+  void accept(T t, F f);
+}