[#1571] fix(server): Memory may leak when `EventInvalidException` occurs (#1574)
### What changes were proposed in this pull request?
In the implementation of the methods `flushBuffer`, `handleEventAndUpdateMetrics`, and `removeBufferByShuffleId`, read-write locks have been added to manage concurrency. This ensures that a `ShuffleBuffer` successfully converted into a `flushEvent` won't be cleaned up again by `removeBufferByShuffleId`, and a `ShuffleBuffer` already cleaned up by `removeBufferByShuffleId` won't be transformed back into a `flushEvent`. This effectively resolves the concurrency issue.
### Why are the changes needed?
Fix https://github.com/apache/incubator-uniffle/issues/1571 & https://github.com/apache/incubator-uniffle/issues/1560 & https://github.com/apache/incubator-uniffle/issues/1542
The key logic of the PR is as follows:
Before this PR:
1. A `ShuffleBuffer` is turned into a `FlushEvent`, and **_its blocks and size are cleared_**
→
2. The `FlushEvent` is added to the flushing queue
→
3. The method `removeBufferByShuffleId` is executed, which causes the following things to happen:
3.1. Running the following code snippet, but please note that in the code below, `buffer.getBlocks()` **_is empty and size is 0_**, because of the step 1 above:
```
for (ShuffleBuffer buffer : buffers) {
buffer.getBlocks().forEach(spb -> spb.getData().release());
ShuffleServerMetrics.gaugeTotalPartitionNum.dec();
size += buffer.getSize();
}
```
3.2. `appId` is removed from the `bufferPool`
→
4. The `FlushEvent` is taken out from the queue and encounters an `EventInvalidException` because the `appId` was removed before
→
5. When handling the `EventInvalidException`, nothing is done and the `event.doCleanup()` method **_is not called, causing a memory leak_**.
Of course, this is just one scenario of concurrency exceptions. In the previous code, without locking, in the `processFlushEvent` method, it is possible that the event may become invalid at any time when continuing executing in `processFlushEvent` method, which is why there is https://github.com/apache/incubator-uniffle/issues/1542. Also, there is https://github.com/apache/incubator-uniffle/issues/1560.
---
After this PR:
We will set a read lock for steps 1 and 2 above, a write lock for step 3, a read lock for step 4, and when encountering an `EventInvalidException` in step 5, we will call the `event.doCleanup()` method to release the memory.
In this way, we can ensure the following things when resources are being cleaned up:
1. `ShuffleBuffers` that have not yet been converted to `FlushEvents` will not be converted in the future, but will be directly cleaned up.
2. `FlushEvents` that have been converted from `ShuffleBuffers` will definitely encounter an `EventInvalidException`, and we will eventually handle this exception correctly, releasing memory.
3. If there is already a `FlushEvent` being processed and it is about to be flushed to disk, the resource cleanup task will wait for all `FlushEvents` related to the `appId` to be completed before starting the cleanup task, ensuring that the cleanup and flushing tasks are completely independent and do not interfere with each other.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs.
---------
Co-authored-by: leslizhang <leslizhang@tencent.com>
diff --git a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
index c5b3200..2ff85ba 100644
--- a/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/DefaultFlushEventHandler.java
@@ -21,6 +21,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Queues;
@@ -50,17 +51,20 @@
private final StorageType storageType;
protected final BlockingQueue<ShuffleDataFlushEvent> flushQueue = Queues.newLinkedBlockingQueue();
private ConsumerWithException<ShuffleDataFlushEvent> eventConsumer;
+ private final ShuffleServer shuffleServer;
private volatile boolean stopped = false;
public DefaultFlushEventHandler(
ShuffleServerConf conf,
StorageManager storageManager,
+ ShuffleServer shuffleServer,
ConsumerWithException<ShuffleDataFlushEvent> eventConsumer) {
this.shuffleServerConf = conf;
this.storageType =
StorageType.valueOf(shuffleServerConf.get(RssBaseConf.RSS_STORAGE_TYPE).name());
this.storageManager = storageManager;
+ this.shuffleServer = shuffleServer;
this.eventConsumer = eventConsumer;
initFlushEventExecutor();
}
@@ -83,8 +87,17 @@
*/
private void handleEventAndUpdateMetrics(ShuffleDataFlushEvent event, Storage storage) {
long start = System.currentTimeMillis();
+ String appId = event.getAppId();
+ ReentrantReadWriteLock.ReadLock readLock =
+ shuffleServer.getShuffleTaskManager().getAppReadLock(appId);
try {
- eventConsumer.accept(event);
+ readLock.lock();
+ try {
+ eventConsumer.accept(event);
+ } finally {
+ readLock.unlock();
+ }
+
if (storage != null) {
ShuffleServerMetrics.incStorageSuccessCounter(storage.getStorageHost());
}
@@ -124,8 +137,7 @@
}
if (e instanceof EventInvalidException) {
- // Invalid events have already been released / cleaned up
- // so no need to call event.doCleanup() here
+ event.doCleanup();
return;
}
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
index 41cb26b..15ea147 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java
@@ -81,7 +81,8 @@
storageBasePaths = RssUtils.getConfiguredLocalDirs(shuffleServerConf);
pendingEventTimeoutSec = shuffleServerConf.getLong(ShuffleServerConf.PENDING_EVENT_TIMEOUT_SEC);
eventHandler =
- new DefaultFlushEventHandler(shuffleServerConf, storageManager, this::processFlushEvent);
+ new DefaultFlushEventHandler(
+ shuffleServerConf, storageManager, shuffleServer, this::processFlushEvent);
}
public void addToFlushQueue(ShuffleDataFlushEvent event) {
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index f9ed125..b26167a 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -32,7 +32,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.Cache;
@@ -112,7 +112,7 @@
private Map<Long, PreAllocatedBufferInfo> requireBufferIds = JavaUtils.newConcurrentMap();
private Thread clearResourceThread;
private BlockingQueue<PurgeEvent> expiredAppIdQueue = Queues.newLinkedBlockingQueue();
- private final Cache<String, Lock> appLocks;
+ private final Cache<String, ReentrantReadWriteLock> appLocks;
public ShuffleTaskManager(
ShuffleServerConf conf,
@@ -222,9 +222,18 @@
topNShuffleDataSizeOfAppCalcTask.start();
}
- private Lock getAppLock(String appId) {
+ public ReentrantReadWriteLock.WriteLock getAppWriteLock(String appId) {
try {
- return appLocks.get(appId, ReentrantLock::new);
+ return appLocks.get(appId, ReentrantReadWriteLock::new).writeLock();
+ } catch (ExecutionException e) {
+ LOG.error("Failed to get App lock.", e);
+ throw new RssException(e);
+ }
+ }
+
+ public ReentrantReadWriteLock.ReadLock getAppReadLock(String appId) {
+ try {
+ return appLocks.get(appId, ReentrantReadWriteLock::new).readLock();
} catch (ExecutionException e) {
LOG.error("Failed to get App lock.", e);
throw new RssException(e);
@@ -257,7 +266,7 @@
String user,
ShuffleDataDistributionType dataDistType,
int maxConcurrencyPerPartitionToWrite) {
- Lock lock = getAppLock(appId);
+ ReentrantReadWriteLock.WriteLock lock = getAppWriteLock(appId);
try {
lock.lock();
refreshAppId(appId);
@@ -692,35 +701,42 @@
* @param shuffleIds
*/
public void removeResourcesByShuffleIds(String appId, List<Integer> shuffleIds) {
- if (CollectionUtils.isEmpty(shuffleIds)) {
- return;
- }
-
- LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, shuffleIds);
- final long start = System.currentTimeMillis();
- final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
- if (taskInfo != null) {
- for (Integer shuffleId : shuffleIds) {
- taskInfo.getCachedBlockIds().remove(shuffleId);
- taskInfo.getCommitCounts().remove(shuffleId);
- taskInfo.getCommitLocks().remove(shuffleId);
+ Lock writeLock = getAppWriteLock(appId);
+ writeLock.lock();
+ try {
+ if (CollectionUtils.isEmpty(shuffleIds)) {
+ return;
}
+
+ LOG.info("Start remove resource for appId[{}], shuffleIds[{}]", appId, shuffleIds);
+ final long start = System.currentTimeMillis();
+ final ShuffleTaskInfo taskInfo = shuffleTaskInfos.get(appId);
+ if (taskInfo != null) {
+ for (Integer shuffleId : shuffleIds) {
+ taskInfo.getCachedBlockIds().remove(shuffleId);
+ taskInfo.getCommitCounts().remove(shuffleId);
+ taskInfo.getCommitLocks().remove(shuffleId);
+ }
+ }
+ Optional.ofNullable(partitionsToBlockIds.get(appId))
+ .ifPresent(
+ x -> {
+ for (Integer shuffleId : shuffleIds) {
+ x.remove(shuffleId);
+ }
+ });
+ shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
+ shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
+ storageManager.removeResources(
+ new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds));
+ LOG.info(
+ "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
+ appId,
+ shuffleIds,
+ System.currentTimeMillis() - start);
+ } finally {
+ writeLock.unlock();
}
- Optional.ofNullable(partitionsToBlockIds.get(appId))
- .ifPresent(
- x -> {
- for (Integer shuffleId : shuffleIds) {
- x.remove(shuffleId);
- }
- });
- shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds);
- shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds);
- storageManager.removeResources(new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds));
- LOG.info(
- "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]",
- appId,
- shuffleIds,
- System.currentTimeMillis() - start);
}
public void checkLeakShuffleData() {
@@ -736,7 +752,7 @@
@VisibleForTesting
public void removeResources(String appId, boolean checkAppExpired) {
- Lock lock = getAppLock(appId);
+ Lock lock = getAppWriteLock(appId);
try {
lock.lock();
LOG.info("Start remove resource for appId[" + appId + "]");
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index b03aec5..ceca592 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -18,12 +18,14 @@
package org.apache.uniffle.server.buffer;
import java.util.Collection;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
@@ -291,23 +293,36 @@
int startPartition,
int endPartition,
boolean isHugePartition) {
- ShuffleDataFlushEvent event =
- buffer.toFlushEvent(
- appId,
- shuffleId,
- startPartition,
- endPartition,
- () -> bufferPool.containsKey(appId),
- shuffleFlushManager.getDataDistributionType(appId));
- if (event != null) {
- event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, false));
- updateShuffleSize(appId, shuffleId, -event.getSize());
- inFlushSize.addAndGet(event.getSize());
- if (isHugePartition) {
- event.markOwnedByHugePartition();
+ ReentrantReadWriteLock.ReadLock readLock = shuffleTaskManager.getAppReadLock(appId);
+ readLock.lock();
+ if (!bufferPool.getOrDefault(appId, new HashMap<>()).containsKey(shuffleId)) {
+ LOG.info(
+ "Shuffle[{}] for app[{}] has already been removed, no need to flush the buffer",
+ shuffleId,
+ appId);
+ return;
+ }
+ try {
+ ShuffleDataFlushEvent event =
+ buffer.toFlushEvent(
+ appId,
+ shuffleId,
+ startPartition,
+ endPartition,
+ () -> bufferPool.getOrDefault(appId, new HashMap<>()).containsKey(shuffleId),
+ shuffleFlushManager.getDataDistributionType(appId));
+ if (event != null) {
+ event.addCleanupCallback(() -> releaseMemory(event.getSize(), true, false));
+ updateShuffleSize(appId, shuffleId, -event.getSize());
+ inFlushSize.addAndGet(event.getSize());
+ if (isHugePartition) {
+ event.markOwnedByHugePartition();
+ }
+ ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
+ shuffleFlushManager.addToFlushQueue(event);
}
- ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
- shuffleFlushManager.addToFlushQueue(event);
+ } finally {
+ readLock.unlock();
}
}
diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
index ab38414..76d06cb 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerOnKerberizedHadoopTest.java
@@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
@@ -104,6 +105,12 @@
when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId1)).thenReturn("alex");
when(mockShuffleServer.getShuffleTaskManager().getUserByAppId(appId2)).thenReturn("alex");
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+ .thenReturn(rsLock.readLock());
+ ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+ .thenReturn(rsLock2.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index c170870..274c2cb 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -26,6 +26,7 @@
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import java.util.stream.IntStream;
@@ -139,6 +140,9 @@
ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, maxConcurrency);
String appId = "concurrentWrite2HdfsWriteOneByOne_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -171,6 +175,9 @@
ShuffleServerConf.SERVER_MAX_CONCURRENCY_OF_ONE_PARTITION, maxConcurrency);
String appId = "concurrentWrite2HdfsWriteOfSinglePartition_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -198,6 +205,9 @@
@Test
public void writeTest() throws Exception {
String appId = "writeTest_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
storageManager.registerRemoteStorage(appId, remoteStorage);
@@ -263,6 +273,8 @@
// test case for process event whose related app was cleared already
assertEquals(0, ShuffleServerMetrics.gaugeWriteHandler.get(), 0.5);
ShuffleDataFlushEvent fakeEvent = createShuffleDataFlushEvent("fakeAppId", 1, 1, 1, null);
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock("fakeAppId"))
+ .thenReturn(rsLock.readLock());
manager.addToFlushQueue(fakeEvent);
waitForQueueClear(manager);
waitForMetrics(ShuffleServerMetrics.gaugeWriteHandler, 0, 0.5);
@@ -276,6 +288,9 @@
ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
String appId = "localMetricsTest_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
ShuffleFlushManager manager =
@@ -306,6 +321,9 @@
ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
String appId = "localMetricsTest_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
ShuffleFlushManager manager =
@@ -355,6 +373,9 @@
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
String appId = "complexWriteTest_appId";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
storageManager.registerRemoteStorage(appId, remoteStorage);
List<ShufflePartitionedBlock> expectedBlocks = Lists.newArrayList();
List<ShuffleDataFlushEvent> flushEvents1 = Lists.newArrayList();
@@ -399,6 +420,12 @@
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
String appId1 = "complexWriteTest_appId1";
String appId2 = "complexWriteTest_appId2";
+ ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+ .thenReturn(rsLock1.readLock());
+ ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+ .thenReturn(rsLock2.readLock());
storageManager.registerRemoteStorage(appId1, remoteStorage);
storageManager.registerRemoteStorage(appId2, remoteStorage);
ShuffleFlushManager manager =
@@ -456,6 +483,12 @@
public void clearLocalTest(@TempDir File tempDir) throws Exception {
final String appId1 = "clearLocalTest_appId1";
final String appId2 = "clearLocalTest_appId12";
+ ReentrantReadWriteLock rsLock1 = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId1))
+ .thenReturn(rsLock1.readLock());
+ ReentrantReadWriteLock rsLock2 = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId2))
+ .thenReturn(rsLock2.readLock());
ShuffleServerConf serverConf = new ShuffleServerConf();
serverConf.set(
ShuffleServerConf.RSS_STORAGE_BASE_PATH, Arrays.asList(tempDir.getAbsolutePath()));
@@ -690,6 +723,9 @@
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
storageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage.getPath()));
ShuffleFlushManager flushManager =
@@ -740,6 +776,9 @@
StorageManager storageManager =
StorageManagerFactory.getInstance().createStorageManager(shuffleServerConf);
String appId = "fallbackWrittenWhenMultiStorageManagerEnableTest";
+ ReentrantReadWriteLock rsLock = new ReentrantReadWriteLock();
+ when(mockShuffleServer.getShuffleTaskManager().getAppReadLock(appId))
+ .thenReturn(rsLock.readLock());
storageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage.getPath()));
ShuffleFlushManager flushManager =
diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index 94e5d60..08428cc 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -24,6 +24,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import com.google.common.collect.RangeMap;
import com.google.common.util.concurrent.Uninterruptibles;
@@ -85,6 +86,7 @@
mockShuffleTaskManager = mock(ShuffleTaskManager.class);
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
shuffleBufferManager = new ShuffleBufferManager(conf, mockShuffleFlushManager, false);
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
}
@Test
@@ -115,6 +117,10 @@
@Test
public void getShuffleDataWithExpectedTaskIdsTest() {
String appId = "getShuffleDataWithExpectedTaskIdsTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
ShufflePartitionedData spd1 = createData(0, 1, 68);
ShufflePartitionedData spd2 = createData(0, 2, 68);
@@ -146,6 +152,10 @@
@Test
public void getShuffleDataTest() {
String appId = "getShuffleDataTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
shuffleBufferManager.registerBuffer(appId, 1, 0, 1);
shuffleBufferManager.registerBuffer(appId, 2, 0, 1);
shuffleBufferManager.registerBuffer(appId, 3, 0, 1);
@@ -209,6 +219,10 @@
public void shuffleIdToSizeTest() {
String appId1 = "shuffleIdToSizeTest1";
String appId2 = "shuffleIdToSizeTest2";
+ ReentrantReadWriteLock rwLock1 = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId1)).thenReturn(rwLock1.readLock());
+ ReentrantReadWriteLock rwLock2 = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId2)).thenReturn(rwLock2.readLock());
shuffleBufferManager.registerBuffer(appId1, 1, 0, 0);
shuffleBufferManager.registerBuffer(appId1, 2, 0, 0);
shuffleBufferManager.registerBuffer(appId2, 1, 0, 0);
@@ -254,9 +268,12 @@
@Test
public void cacheShuffleDataTest() {
String appId = "cacheShuffleDataTest";
- int shuffleId = 1;
-
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
int startPartitionNum = (int) ShuffleServerMetrics.gaugeTotalPartitionNum.get();
+ int shuffleId = 1;
StatusCode sc =
shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, createData(0, 16));
assertEquals(StatusCode.NO_REGISTER, sc);
@@ -322,8 +339,11 @@
@Test
public void cacheShuffleDataWithPreAllocationTest() {
String appId = "cacheShuffleDataWithPreAllocationTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
int shuffleId = 1;
-
shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
// pre allocate memory
shuffleBufferManager.requireMemory(48, true);
@@ -393,8 +413,12 @@
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
String appId = "bufferSizeTest";
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
int shuffleId = 1;
-
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
shuffleBufferManager.registerBuffer(appId, shuffleId, 4, 5);
@@ -523,8 +547,11 @@
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
String appId = "bufferSizeTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
int shuffleId = 1;
-
shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
shuffleBufferManager.cacheShuffleData(appId, shuffleId, false, createData(0, 64));
@@ -555,10 +582,14 @@
shuffleBufferManager = new ShuffleBufferManager(serverConf, shuffleFlushManager, false);
String appId = "shuffleFlushTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
+
int shuffleId = 0;
int smallShuffleId = 1;
int smallShuffleIdTwo = 2;
-
shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);
shuffleBufferManager.registerBuffer(appId, smallShuffleId, 0, 1);
@@ -676,6 +707,10 @@
when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mock(ShuffleTaskManager.class));
String appId = "bufferSizeTest";
+ shuffleBufferManager.setShuffleTaskManager(mockShuffleTaskManager);
+ ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
+ when(mockShuffleTaskManager.getAppReadLock(appId)).thenReturn(rwLock.readLock());
+ when(mockShuffleServer.getShuffleTaskManager()).thenReturn(mockShuffleTaskManager);
int shuffleId = 1;
shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 1);
shuffleBufferManager.registerBuffer(appId, shuffleId, 2, 3);