[#1571] fix(server): Memory may leak when `EventInvalidException` occurs (#1574)

### What changes were proposed in this pull request?

In the implementation of the methods `flushBuffer`, `handleEventAndUpdateMetrics`, and `removeBufferByShuffleId`, read-write locks have been added to manage concurrency. This ensures that a `ShuffleBuffer` successfully converted into a `flushEvent` won't be cleaned up again by `removeBufferByShuffleId`, and a `ShuffleBuffer` already cleaned up by `removeBufferByShuffleId` won't be transformed back into a `flushEvent`. This effectively resolves the concurrency issue.

### Why are the changes needed?

Fix https://github.com/apache/incubator-uniffle/issues/1571 & https://github.com/apache/incubator-uniffle/issues/1560 & https://github.com/apache/incubator-uniffle/issues/1542

The key logic of the PR is as follows:

Before this PR:
1. A `ShuffleBuffer` is turned into a `FlushEvent`, and **_its blocks and size are cleared_**
→
2. The `FlushEvent` is added to the flushing queue
→
3. The method `removeBufferByShuffleId` is executed, which causes the following things to happen:

3.1. Running the following code snippet, but please note that in the code below, `buffer.getBlocks()` **_is empty and size is 0_**, because of the step 1 above:
```
for (ShuffleBuffer buffer : buffers) {
  buffer.getBlocks().forEach(spb -> spb.getData().release());
  ShuffleServerMetrics.gaugeTotalPartitionNum.dec();
  size += buffer.getSize();
}
```

3.2. `appId` is removed from the `bufferPool`
→
4. The `FlushEvent` is taken out from the queue and encounters an `EventInvalidException` because the `appId` was removed before
→
5. When handling the `EventInvalidException`, nothing is done and the `event.doCleanup()` method **_is not called, causing a memory leak_**.
Of course, this is just one scenario of concurrency exceptions. In the previous code, without locking, in the `processFlushEvent` method, it is possible that the event may become invalid at any time when continuing executing in `processFlushEvent` method, which is why there is https://github.com/apache/incubator-uniffle/issues/1542. Also, there is https://github.com/apache/incubator-uniffle/issues/1560.

---

After this PR:
We will set a read lock for steps 1 and 2 above, a write lock for step 3, a read lock for step 4, and when encountering an `EventInvalidException` in step 5, we will call the `event.doCleanup()` method to release the memory.

In this way, we can ensure the following things when resources are being cleaned up:
1. `ShuffleBuffers` that have not yet been converted to `FlushEvents` will not be converted in the future, but will be directly cleaned up.
2. `FlushEvents` that have been converted from `ShuffleBuffers` will definitely encounter an `EventInvalidException`, and we will eventually handle this exception correctly, releasing memory.
3. If there is already a `FlushEvent` being processed and it is about to be flushed to disk, the resource cleanup task will wait for all `FlushEvents` related to the `appId` to be completed before starting the cleanup task, ensuring that the cleanup and flushing tasks are completely independent and do not interfere with each other.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UTs.
---------

Co-authored-by: leslizhang <leslizhang@tencent.com>
diff --git a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
index c5b3200..2ff85ba 100644
--- a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
@@ -21,6 +21,7 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Queues;
@@ -50,17 +51,20 @@
   private final StorageType storageType;
   protected final BlockingQueue<ShuffleDataFlushEvent> flushQueue = Queues.newLinkedBlockingQueue();
   private ConsumerWithException<ShuffleDataFlushEvent> eventConsumer;
+  private final ShuffleServer shuffleServer;
 
   private volatile boolean stopped = false;
 
   public DefaultFlushEventHandler(
       ShuffleServerConf conf,
       StorageManager storageManager,
+      ShuffleServer shuffleServer,
       ConsumerWithException<ShuffleDataFlushEvent> eventConsumer) {
     this.shuffleServerConf = conf;
     this.storageType =
         StorageType.valueOf(shuffleServerConf.get(RssBaseConf.RSS_STORAGE_TYPE).name());
     this.storageManager = storageManager;
+    this.shuffleServer = shuffleServer;
     this.eventConsumer = eventConsumer;
     initFlushEventExecutor();
   }
@@ -83,8 +87,17 @@
    */
   private void handleEventAndUpdateMetrics(ShuffleDataFlushEvent event, Storage storage) {
     long start = System.currentTimeMillis();
+    String appId = event.getAppId();
+    ReentrantReadWriteLock.ReadLock readLock =
+        shuffleServer.getShuffleTaskManager().getAppReadLock(appId);
     try {
-      eventConsumer.accept(event);
+      readLock.lock();
+      try {
+        eventConsumer.accept(event);
+      } finally {
+        readLock.unlock();
+      }
+
       if (storage != null) {
         ShuffleServerMetrics.incStorageSuccessCounter(storage.getStorageHost());
       }
@@ -124,8 +137,7 @@
       }
 
       if (e instanceof EventInvalidException) {
-        // Invalid events have already been released / cleaned up
-        // so no need to call event.doCleanup() here
+        event.doCleanup();
         return;
       }
 
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
index 41cb26b..15ea147 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
@@ -81,7 +81,8 @@
     storageBasePaths = RssUtils.getConfiguredLocalDirs(shuffleServerConf);
     pendingEventTimeoutSec = shuffleServerConf.getLong(ShuffleServerConf.PENDING_EVENT_TIMEOUT_SEC);
     eventHandler =
-        new DefaultFlushEventHandler(shuffleServerConf, storageManager, this::processFlushEvent);
+        new DefaultFlushEventHandler(
+            shuffleServerConf, storageManager, shuffleServer, this::processFlushEvent);
   }
 
   public void addToFlushQueue(ShuffleDataFlushEvent event) {
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index f9ed125..b26167a 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -32,7 +32,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.cache.Cache;
@@ -112,7 +112,7 @@
   private Map<Long, PreAllocatedBufferInfo> requireBufferIds = JavaUtils.newConcurrentMap();
   private Thread clearResourceThread;
   private BlockingQueue<PurgeEvent> expiredAppIdQueue = Queues.newLinkedBlockingQueue();
-  private final Cache<String, Lock> appLocks;
+  private final Cache<String, ReentrantReadWriteLock> appLocks;
 
   public ShuffleTaskManager(
       ShuffleServerConf conf,
@@ -222,9 +222,18 @@
     topNShuffleDataSizeOfAppCalcTask.start();
   }
 
-  private Lock getAppLock(String appId) {
+  public ReentrantReadWriteLock.WriteLock getAppWriteLock(String appId) {
     try {
-      return appLocks.get(appId, ReentrantLock::new);
+      return appLocks.get(appId, ReentrantReadWriteLock::new).writeLock();
+    } catch (ExecutionException e) {
+      LOG.error("Failed to get App lock.", e);
+      throw new RssException(e);
+    }
+  }
+
+  public ReentrantReadWriteLock.ReadLock getAppReadLock(String appId) {
+    try {
+      return appLocks.get(appId, ReentrantReadWriteLock::new).readLock();
     } catch (ExecutionException e) {
       LOG.error("Failed to get App lock.", e);
       throw new RssException(e);
@@ -257,7 +266,7 @@
       String user,
       ShuffleDataDistributionType dataDistType,
       int maxConcurrencyPerPartitionToWrite) {
-    Lock lock = getAppLock(appId);
+    ReentrantReadWriteLock.WriteLock lock = getAppWriteLock(appId);
     try {
       lock.lock();
       refreshAppId(appId);
@@ -692,35 +701,42 @@
    * @param shuffleIds
    */
   public void removeResourcesByShuffleIds(String appId, List<Integer> shuffleIds) {
-    if (CollectionUtils.isEmpty(shuffleIds)) {
-      return;
-    }
-
-    LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, shuffleIds);
-    final long start = System.currentTimeMillis();
-    final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
-    if (taskInfo != null) {
-      for (Integer shuffleId : shuffleIds) {
-        taskInfo.getCachedBlockIds().remove(shuffleId);
-        taskInfo.getCommitCounts().remove(shuffleId);
-        taskInfo.getCommitLocks().remove(shuffleId);
+    Lock writeLock = getAppWriteLock(appId);
+    writeLock.lock();
+    try {
+      if (CollectionUtils.isEmpty(shuffleIds)) {
+        return;
       }
+
+      LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, shuffleIds);
+      final long start = System.currentTimeMillis();
+      final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
+      if (taskInfo != null) {
+        for (Integer shuffleId : shuffleIds) {
+          taskInfo.getCachedBlockIds().remove(shuffleId);
+          taskInfo.getCommitCounts().remove(shuffleId);
+          taskInfo.getCommitLocks().remove(shuffleId);
+        }
+      }
+      Optional.ofNullable(partitionsToBlockIds.get(appId))
+          .ifPresent(
+              x -> {
+                for (Integer shuffleId : shuffleIds) {
+                  x.remove(shuffleId);
+                }
+              });
+      shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
+      shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
+      storageManager.removeResources(
+          new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds));
+      LOG.info(
+          "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
+          appId,
+          shuffleIds,
+          System.currentTimeMillis() - start);
+    } finally {
+      writeLock.unlock();
     }
-    Optional.ofNullable(partitionsToBlockIds.get(appId))
-        .ifPresent(
-            x -> {
-              for (Integer shuffleId : shuffleIds) {
-                x.remove(shuffleId);
-              }
-            });
-    shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
-    shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
-    storageManager.removeResources(new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds));
-    LOG.info(
-        "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
-        appId,
-        shuffleIds,
-        System.currentTimeMillis() - start);
   }
 
   public void checkLeakShuffleData() {
@@ -736,7 +752,7 @@
 
   @VisibleForTesting
   public void removeResources(String appId, boolean checkAppExpired) {
-    Lock lock = getAppLock(appId);
+    Lock lock = getAppWriteLock(appId);
     try {
       lock.lock();
       LOG.info("Start remove resource for appId[" + appId + "]");
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index b03aec5..ceca592 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -18,12 +18,14 @@
 package org.apache.uniffle.server.buffer;
 
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -291,23 +293,36 @@
       int startPartition,
       int endPartition,
       boolean isHugePartition) {
-    ShuffleDataFlushEvent event =
-        buffer.toFlushEvent(
-            appId,
-            shuffleId,
-            startPartition,
-            endPartition,
-            () -> bufferPool.containsKey(appId),
-            shuffleFlushManager.getDataDistributionType(appId));
-    if (event != null) {
-      event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, false));
-      updateShuffleSize(appId, shuffleId, -event.getSize());
-      inFlushSize.addAndGet(event.getSize());
-      if (isHugePartition) {
-        event.markOwnedByHugePartition();
+    ReentrantReadWriteLock.ReadLock readLock = shuffleTaskManager.getAppReadLock(appId);
+    readLock.lock();
+    if (!bufferPool.getOrDefault(appId, new HashMap<>()).containsKey(shuffleId)) {
+      LOG.info(
+          "Shuffle[{}] for app[{}] has already been removed, no need to flush the buffer",
+          shuffleId,
+          appId);
+      return;
+    }
+    try {
+      ShuffleDataFlushEvent event =
+          buffer.toFlushEvent(
+              appId,
+              shuffleId,
+              startPartition,
+              endPartition,
+              () -> bufferPool.getOrDefault(appId, new HashMap<>()).containsKey(shuffleId),
+              shuffleFlushManager.getDataDistributionType(appId));
+      if (event != null) {
+        event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, false));
+        updateShuffleSize(appId, shuffleId, -event.getSize());
+        inFlushSize.addAndGet(event.getSize());
+        if (isHugePartition) {
+          event.markOwnedByHugePartition();
+        }
+        ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
+        shuffleFlushManager.addToFlushQueue(event);
       }
-      ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
-      shuffleFlushManager.addToFlushQueue(event);
+    } finally {
+      readLock.unlock();
     }
   }
 
diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
index ab38414..76d06cb 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
@@ -22,6 +22,7 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.Path;
@@ -104,6 +105,12 @@
 
     when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId1)).thenReturn("alex");
     when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId2)).thenReturn("alex");
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
 
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index c170870..274c2cb 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -26,6 +26,7 @@
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Supplier;
 import java.util.stream.IntStream;
 
@@ -139,6 +140,9 @@
         ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, maxConcurrency);
 
     String appId = "concurrentWrite2HdfsWriteOneByOne_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -171,6 +175,9 @@
         ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, maxConcurrency);
 
     String appId = "concurrentWrite2HdfsWriteOfSinglePartition_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -198,6 +205,9 @@
   @Test
   public void writeTest() throws Exception {
     String appId = "writeTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -263,6 +273,8 @@
     // test case for process event whose related app was cleared already
     assertEquals(0, ShuffleServerMetrics.gaugeWriteHandler.get(), 0.5);
     ShuffleDataFlushEvent fakeEvent = createShuffleDataFlushEvent("fakeAppId", 1, 1, 1, null);
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock("fakeAppId"))
+        .thenReturn(rsLock.readLock());
     manager.addToFlushQueue(fakeEvent);
     waitForQueueClear(manager);
     waitForMetrics(ShuffleServerMetrics.gaugeWriteHandler, 0, 0.5);
@@ -276,6 +288,9 @@
         ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
 
     String appId = "localMetricsTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     ShuffleFlushManager manager =
@@ -306,6 +321,9 @@
         ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
 
     String appId = "localMetricsTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     ShuffleFlushManager manager =
@@ -355,6 +373,9 @@
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "complexWriteTest_appId";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, remoteStorage);
     List<ShufflePartitionedBlock> expectedBlocks = Lists.newArrayList();
     List<ShuffleDataFlushEvent> flushEvents1 = Lists.newArrayList();
@@ -399,6 +420,12 @@
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId1 = "complexWriteTest_appId1";
     String appId2 = "complexWriteTest_appId2";
+    ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock1.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
     storageManager.registerRemoteStorage(appId1, remoteStorage);
     storageManager.registerRemoteStorage(appId2, remoteStorage);
     ShuffleFlushManager manager =
@@ -456,6 +483,12 @@
   public void clearLocalTest(@TempDir File tempDir) throws Exception {
     final String appId1 = "clearLocalTest_appId1";
     final String appId2 = "clearLocalTest_appId12";
+    ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+        .thenReturn(rsLock1.readLock());
+    ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+        .thenReturn(rsLock2.readLock());
     ShuffleServerConf serverConf = new ShuffleServerConf();
     serverConf.set(
         ShuffleServerConf.RSS_STORAGE_BASE_PATH, Arrays.asList(tempDir.getAbsolutePath()));
@@ -690,6 +723,9 @@
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage.getPath()));
 
     ShuffleFlushManager flushManager =
@@ -740,6 +776,9 @@
     StorageManager storageManager =
         StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
     String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+    ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+    when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+        .thenReturn(rsLock.readLock());
     storageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage.getPath()));
 
     ShuffleFlushManager flushManager =
diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index 94e5d60..08428cc 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -24,6 +24,7 @@
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import com.google.common.collect.RangeMap;
 import com.google.common.util.concurrent.Uninterruptibles;
@@ -85,6 +86,7 @@
     mockShuffleTaskManager = mock(ShuffleTaskManager.class);
     when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager = new ShuffleBufferManager(conf, mockShuffleFlushManager, false);
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
   }
 
   @Test
@@ -115,6 +117,10 @@
   @Test
   public void getShuffleDataWithExpectedTaskIdsTest() {
     String appId = "getShuffleDataWithExpectedTaskIdsTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
     ShufflePartitionedData spd1 = createData(0, 1, 68);
     ShufflePartitionedData spd2 = createData(0, 2, 68);
@@ -146,6 +152,10 @@
   @Test
   public void getShuffleDataTest() {
     String appId = "getShuffleDataTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
     shuffleBufferManager.registerBuffer(appId, 2, 0, 1);
     shuffleBufferManager.registerBuffer(appId, 3, 0, 1);
@@ -209,6 +219,10 @@
   public void shuffleIdToSizeTest() {
     String appId1 = "shuffleIdToSizeTest1";
     String appId2 = "shuffleIdToSizeTest2";
+    ReentrantReadWriteLock rwLock1 = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId1)).thenReturn(rwLock1.readLock());
+    ReentrantReadWriteLock rwLock2 = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId2)).thenReturn(rwLock2.readLock());
     shuffleBufferManager.registerBuffer(appId1, 1, 0, 0);
     shuffleBufferManager.registerBuffer(appId1, 2, 0, 0);
     shuffleBufferManager.registerBuffer(appId2, 1, 0, 0);
@@ -254,9 +268,12 @@
   @Test
   public void cacheShuffleDataTest() {
     String appId = "cacheShuffleDataTest";
-    int shuffleId = 1;
-
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int startPartitionNum = (int) ShuffleServerMetrics.gaugeTotalPartitionNum.get();
+    int shuffleId = 1;
     StatusCode sc =
         shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, createData(0, 16));
     assertEquals(StatusCode.NO_REGISTER, sc);
@@ -322,8 +339,11 @@
   @Test
   public void cacheShuffleDataWithPreAllocationTest() {
     String appId = "cacheShuffleDataWithPreAllocationTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     // pre allocate memory
     shuffleBufferManager.requireMemory(48, true);
@@ -393,8 +413,12 @@
     when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
     int shuffleId = 1;
-
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 4, 5);
@@ -523,8 +547,11 @@
     when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, createData(0, 64));
@@ -555,10 +582,14 @@
     shuffleBufferManager = new ShuffleBufferManager(serverConf, shuffleFlushManager, false);
 
     String appId = "shuffleFlushTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+
     int shuffleId = 0;
     int smallShuffleId = 1;
     int smallShuffleIdTwo = 2;
-
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
     shuffleBufferManager.registerBuffer(appId, smallShuffleId, 0, 1);
@@ -676,6 +707,10 @@
     when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
 
     String appId = "bufferSizeTest";
+    shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+    ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+    when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+    when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
     int shuffleId = 1;
     shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
     shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);