[#1472][part-5] Use UnpooledByteBufAllocator to fix inaccurate usedMemory issue causing OOM (#1534)
### What changes were proposed in this pull request?
When we use `UnpooledByteBufAllocator` to allocate off-heap `ByteBuf`, Netty directly requests off-heap memory from the operating system instead of allocating it according to `pageSize` and `chunkSize`. This way, we can obtain the exact `ByteBuf` size during the pre-allocation of memory, avoiding distortion of metrics such as `usedMemory`.
Moreover, we have restored the code submission of the PR [#1521](https://github.com/apache/incubator-uniffle/pull/1521). We ensure that there is sufficient direct memory for the Netty server during decoding `sendShuffleDataRequest` by taking into account the `encodedLength` of `ByteBuf` in advance during the pre-allocation of memory, thus avoiding OOM during decoding `sendShuffleDataRequest`.
Since we are not using `PooledByteBufAllocator`, the PR [#1524](https://github.com/apache/incubator-uniffle/pull/1524) is no longer needed.
### Why are the changes needed?
A sub PR for: https://github.com/apache/incubator-uniffle/pull/1519
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs.
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
index 5492478..b8c687c 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java
@@ -28,6 +28,7 @@
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.ByteBufUtils;
+import org.apache.uniffle.common.util.NettyUtils;
public class Decoders {
public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) {
@@ -46,7 +47,8 @@
long crc = byteBuf.readLong();
long taskAttemptId = byteBuf.readLong();
int dataLength = byteBuf.readInt();
- ByteBuf data = byteBuf.retain().readSlice(dataLength);
+ ByteBuf data = NettyUtils.getNettyBufferAllocator().directBuffer(dataLength);
+ data.writeBytes(byteBuf, dataLength);
int lengthOfShuffleServers = byteBuf.readInt();
List<ShuffleServerInfo> serverInfos = Lists.newArrayList();
for (int k = 0; k < lengthOfShuffleServers; k++) {
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
index 317f4e7..cab4769 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -130,6 +130,10 @@
return requireId;
}
+ public void setRequireId(long requireId) {
+ this.requireId = requireId;
+ }
+
public Map<Integer, List<ShuffleBlockInfo>> getPartitionToBlocks() {
return partitionToBlocks;
}
diff --git a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
index 5f1c87c..468d2a7 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
@@ -19,8 +19,10 @@
import java.util.concurrent.ThreadFactory;
+import io.netty.buffer.AbstractByteBufAllocator;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
@@ -39,6 +41,8 @@
public class NettyUtils {
private static final Logger logger = LoggerFactory.getLogger(NettyUtils.class);
+ private static final long MAX_DIRECT_MEMORY_IN_BYTES = PlatformDependent.maxDirectMemory();
+
/** Creates a Netty EventLoopGroup based on the IOMode. */
public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
ThreadFactory threadFactory = ThreadUtils.getNettyThreadFactory(threadPrefix);
@@ -114,22 +118,18 @@
}
private static class AllocatorHolder {
- private static final PooledByteBufAllocator INSTANCE = createAllocator();
+ private static final AbstractByteBufAllocator INSTANCE = createUnpooledByteBufAllocator(true);
}
- public static PooledByteBufAllocator getNettyBufferAllocator() {
+ public static AbstractByteBufAllocator getNettyBufferAllocator() {
return AllocatorHolder.INSTANCE;
}
- private static PooledByteBufAllocator createAllocator() {
- return new PooledByteBufAllocator(
- true,
- PooledByteBufAllocator.defaultNumHeapArena(),
- PooledByteBufAllocator.defaultNumDirectArena(),
- PooledByteBufAllocator.defaultPageSize(),
- PooledByteBufAllocator.defaultMaxOrder(),
- 0,
- 0,
- PooledByteBufAllocator.defaultUseCacheForAllThreads());
+ public static UnpooledByteBufAllocator createUnpooledByteBufAllocator(boolean preferDirect) {
+ return new UnpooledByteBufAllocator(preferDirect);
+ }
+
+ public static long getMaxDirectMemory() {
+ return MAX_DIRECT_MEMORY_IN_BYTES;
}
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index a8ee154..7949ca9 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -103,7 +103,15 @@
}
}
- int allocateSize = size;
+ SendShuffleDataRequest sendShuffleDataRequest =
+ new SendShuffleDataRequest(
+ requestId(),
+ request.getAppId(),
+ shuffleId,
+ 0L,
+ stb.getValue(),
+ System.currentTimeMillis());
+ int allocateSize = size + sendShuffleDataRequest.encodedLength();
int finalBlockNum = blockNum;
try {
RetryUtils.retryWithCondition(
@@ -122,14 +130,7 @@
allocateSize, host, port));
}
- SendShuffleDataRequest sendShuffleDataRequest =
- new SendShuffleDataRequest(
- requestId(),
- request.getAppId(),
- shuffleId,
- requireId,
- stb.getValue(),
- System.currentTimeMillis());
+ sendShuffleDataRequest.setRequireId(requireId);
long start = System.currentTimeMillis();
RpcResponse rpcResponse =
transportClient.sendRpcSync(sendShuffleDataRequest, rpcTimeout);
diff --git a/server/src/main/java/org/apache/uniffle/server/NettyDirectMemoryTracker.java b/server/src/main/java/org/apache/uniffle/server/NettyDirectMemoryTracker.java
index c3a31fd..96206cc 100644
--- a/server/src/main/java/org/apache/uniffle/server/NettyDirectMemoryTracker.java
+++ b/server/src/main/java/org/apache/uniffle/server/NettyDirectMemoryTracker.java
@@ -25,7 +25,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.uniffle.common.util.NettyUtils;
import org.apache.uniffle.common.util.ThreadUtils;
public class NettyDirectMemoryTracker {
@@ -55,19 +54,10 @@
() -> {
try {
long usedDirectMemory = PlatformDependent.usedDirectMemory();
- long allocatedDirectMemory =
- NettyUtils.getNettyBufferAllocator().metric().usedDirectMemory();
- long pinnedDirectMemory = NettyUtils.getNettyBufferAllocator().pinnedDirectMemory();
if (LOG.isDebugEnabled()) {
- LOG.debug(
- "Current usedDirectMemory:{}, allocatedDirectMemory:{}, pinnedDirectMemory:{}",
- usedDirectMemory,
- allocatedDirectMemory,
- pinnedDirectMemory);
+ LOG.debug("Current usedDirectMemory:{}", usedDirectMemory);
}
ShuffleServerMetrics.gaugeUsedDirectMemorySize.set(usedDirectMemory);
- ShuffleServerMetrics.gaugeAllocatedDirectMemorySize.set(allocatedDirectMemory);
- ShuffleServerMetrics.gaugePinnedDirectMemorySize.set(pinnedDirectMemory);
} catch (Throwable t) {
LOG.error("Failed to report direct memory.", t);
}
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index ac9b95c..ce7d2d6 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -256,6 +256,7 @@
final long start = System.currentTimeMillis();
List<ShufflePartitionedData> shufflePartitionedData = toPartitionedData(req);
long alreadyReleasedSize = 0;
+ boolean hasFailureOccurred = false;
for (ShufflePartitionedData spd : shufflePartitionedData) {
String shuffleDataInfo =
"appId["
@@ -275,6 +276,7 @@
+ ret;
LOG.error(errorMsg);
responseMessage = errorMsg;
+ hasFailureOccurred = true;
break;
} else {
long toReleasedSize = spd.getTotalBlockSize();
@@ -293,9 +295,13 @@
ret = StatusCode.INTERNAL_ERROR;
responseMessage = errorMsg;
LOG.error(errorMsg);
+ hasFailureOccurred = true;
break;
}
}
+ if (hasFailureOccurred) {
+ shuffleServer.getShuffleBufferManager().releaseMemory(info.getRequireSize(), false, false);
+ }
// since the required buffer id is only used once, the shuffle client would try to require
// another buffer whether
// current connection succeeded or not. Therefore, the preAllocatedBuffer is first get and
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
index 649e504..274cde0 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerMetrics.java
@@ -186,8 +186,6 @@
public static Gauge.Child gaugeUsedBufferSize;
public static Gauge.Child gaugeReadBufferUsedSize;
public static Gauge.Child gaugeUsedDirectMemorySize;
- public static Gauge.Child gaugeAllocatedDirectMemorySize;
- public static Gauge.Child gaugePinnedDirectMemorySize;
public static Gauge.Child gaugeWriteHandler;
public static Gauge.Child gaugeEventQueueSize;
public static Gauge.Child gaugeHadoopFlushThreadPoolQueueSize;
@@ -384,8 +382,6 @@
gaugeUsedBufferSize = metricsManager.addLabeledGauge(USED_BUFFER_SIZE);
gaugeReadBufferUsedSize = metricsManager.addLabeledGauge(READ_USED_BUFFER_SIZE);
gaugeUsedDirectMemorySize = metricsManager.addLabeledGauge(USED_DIRECT_MEMORY_SIZE);
- gaugeAllocatedDirectMemorySize = metricsManager.addLabeledGauge(ALLOCATED_DIRECT_MEMORY_SIZE);
- gaugePinnedDirectMemorySize = metricsManager.addLabeledGauge(PINNED_DIRECT_MEMORY_SIZE);
gaugeWriteHandler = metricsManager.addLabeledGauge(TOTAL_WRITE_HANDLER);
gaugeEventQueueSize = metricsManager.addLabeledGauge(EVENT_QUEUE_SIZE);
gaugeHadoopFlushThreadPoolQueueSize =
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index 9f23f0f..8636706 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -32,15 +32,18 @@
import com.google.common.collect.RangeMap;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeMap;
+import io.netty.util.internal.PlatformDependent;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.rpc.ServerType;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.NettyUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.server.ShuffleDataFlushEvent;
import org.apache.uniffle.server.ShuffleFlushManager;
@@ -68,6 +71,7 @@
// Huge partition vars
private long hugePartitionSizeThreshold;
private long hugePartitionMemoryLimitSize;
+ private boolean nettyServerEnabled;
protected long bufferSize = 0;
protected AtomicLong preAllocatedSize = new AtomicLong(0L);
@@ -80,11 +84,16 @@
protected Map<String, Map<Integer, AtomicLong>> shuffleSizeMap = JavaUtils.newConcurrentMap();
public ShuffleBufferManager(ShuffleServerConf conf, ShuffleFlushManager shuffleFlushManager) {
+ this.nettyServerEnabled = conf.get(ShuffleServerConf.RPC_SERVER_TYPE) == ServerType.GRPC_NETTY;
long heapSize = Runtime.getRuntime().maxMemory();
this.capacity = conf.getSizeAsBytes(ShuffleServerConf.SERVER_BUFFER_CAPACITY);
if (this.capacity < 0) {
this.capacity =
- (long) (heapSize * conf.getDouble(ShuffleServerConf.SERVER_BUFFER_CAPACITY_RATIO));
+ nettyServerEnabled
+ ? (long)
+ (NettyUtils.getMaxDirectMemory()
+ * conf.getDouble(ShuffleServerConf.SERVER_BUFFER_CAPACITY_RATIO))
+ : (long) (heapSize * conf.getDouble(ShuffleServerConf.SERVER_BUFFER_CAPACITY_RATIO));
}
this.readCapacity = conf.getSizeAsBytes(ShuffleServerConf.SERVER_READ_BUFFER_CAPACITY);
if (this.readCapacity < 0) {
@@ -321,6 +330,25 @@
if (isPreAllocated) {
requirePreAllocatedSize(size);
}
+ if (LOG.isDebugEnabled()) {
+ long usedDirectMemory = PlatformDependent.usedDirectMemory();
+ long usedHeapMemory =
+ Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
+ LOG.debug(
+ "Require memory succeeded with "
+ + size
+ + " bytes, usedMemory["
+ + usedMemory.get()
+ + "] include preAllocation["
+ + preAllocatedSize.get()
+ + "], inFlushSize["
+ + inFlushSize.get()
+ + "], usedDirectMemory["
+ + usedDirectMemory
+ + "], usedHeapMemory["
+ + usedHeapMemory
+ + "]");
+ }
return true;
}
if (LOG.isDebugEnabled()) {
@@ -372,7 +400,7 @@
+ inFlushSize.get()
+ "] is less than released["
+ size
- + "], set allocated memory to 0");
+ + "], set in flush memory to 0");
inFlushSize.set(0L);
}
ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
@@ -465,7 +493,17 @@
}
public void releasePreAllocatedSize(long delta) {
- preAllocatedSize.addAndGet(-delta);
+ if (preAllocatedSize.get() >= delta) {
+ preAllocatedSize.addAndGet(-delta);
+ } else {
+ LOG.warn(
+ "Current pre-allocated memory["
+ + preAllocatedSize.get()
+ + "] is less than released["
+ + delta
+ + "], set pre-allocated memory to 0");
+ preAllocatedSize.set(0L);
+ }
ShuffleServerMetrics.gaugeAllocatedBufferSize.set(preAllocatedSize.get());
}
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 184cde0..ac8973e 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -56,6 +56,7 @@
import org.apache.uniffle.server.ShuffleServerMetrics;
import org.apache.uniffle.server.ShuffleTaskManager;
import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
+import org.apache.uniffle.server.buffer.ShuffleBufferManager;
import org.apache.uniffle.storage.common.Storage;
import org.apache.uniffle.storage.common.StorageReadMetrics;
import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -114,11 +115,13 @@
}
}
int requireSize = shuffleServer.getShuffleTaskManager().getRequireBufferSize(requireBufferId);
+ int requireBlocksSize =
+ requireSize - req.encodedLength() < 0 ? 0 : requireSize - req.encodedLength();
StatusCode ret = StatusCode.SUCCESS;
String responseMessage = "OK";
if (req.getPartitionToBlocks().size() > 0) {
- ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireSize);
+ ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireBlocksSize);
ShuffleTaskManager manager = shuffleServer.getShuffleTaskManager();
PreAllocatedBufferInfo info = manager.getAndRemovePreAllocatedBuffer(requireBufferId);
boolean isPreAllocated = info != null;
@@ -134,18 +137,21 @@
+ appId
+ "], shuffleId["
+ shuffleId
- + "]";
+ + "], probably because the pre-allocated buffer has expired. "
+ + "Please increase the expiration time using "
+ + ShuffleServerConf.SERVER_PRE_ALLOCATION_EXPIRED.key()
+ + " in ShuffleServer's configuration";
LOG.warn(errorMsg);
- responseMessage = errorMsg;
- rpcResponse =
- new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, responseMessage);
+ rpcResponse = new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, errorMsg);
client.getChannel().writeAndFlush(rpcResponse);
return;
}
final long start = System.currentTimeMillis();
+ ShuffleBufferManager shuffleBufferManager = shuffleServer.getShuffleBufferManager();
+ shuffleBufferManager.releaseMemory(req.encodedLength(), false, true);
List<ShufflePartitionedData> shufflePartitionedData = toPartitionedData(req);
long alreadyReleasedSize = 0;
- boolean isFailureOccurs = false;
+ boolean hasFailureOccurred = false;
for (ShufflePartitionedData spd : shufflePartitionedData) {
String shuffleDataInfo =
"appId["
@@ -156,7 +162,7 @@
+ spd.getPartitionId()
+ "]";
try {
- if (isFailureOccurs) {
+ if (hasFailureOccurred) {
continue;
}
ret = manager.cacheShuffleData(appId, shuffleId, isPreAllocated, spd);
@@ -168,7 +174,7 @@
+ ret;
LOG.error(errorMsg);
responseMessage = errorMsg;
- isFailureOccurs = true;
+ hasFailureOccurred = true;
} else {
long toReleasedSize = spd.getTotalBlockSize();
// after each cacheShuffleData call, the `preAllocatedSize` is updated timely.
@@ -186,11 +192,12 @@
ret = StatusCode.INTERNAL_ERROR;
responseMessage = errorMsg;
LOG.error(errorMsg);
- isFailureOccurs = true;
+ hasFailureOccurred = true;
} finally {
// Once the cache failure occurs, we should explicitly release data held by byteBuf
- if (isFailureOccurs) {
+ if (hasFailureOccurred) {
Arrays.stream(spd.getBlockList()).forEach(block -> block.getData().release());
+ shuffleBufferManager.releaseMemory(spd.getTotalBlockSize(), false, false);
}
}
}
@@ -199,8 +206,8 @@
// current connection succeeded or not. Therefore, the preAllocatedBuffer is first get and
// removed, then after
// cacheShuffleData finishes, the preAllocatedSize should be updated accordingly.
- if (info.getRequireSize() > alreadyReleasedSize) {
- manager.releasePreAllocatedSize(info.getRequireSize() - alreadyReleasedSize);
+ if (requireBlocksSize > alreadyReleasedSize) {
+ manager.releasePreAllocatedSize(requireBlocksSize - alreadyReleasedSize);
}
rpcResponse = new RpcResponse(req.getRequestId(), ret, responseMessage);
long costTime = System.currentTimeMillis() - start;
@@ -218,7 +225,7 @@
+ " ms with "
+ shufflePartitionedData.size()
+ " blocks and "
- + requireSize
+ + requireBlocksSize
+ " bytes");
}
} else {