[#1606] feat(client): Add client retry mechanism for NO_BUFFER when reading data (memory/local/index) (#1616)
### What changes were proposed in this pull request?
1. Remove the server-side retry when reading data and replace it with a client-side backoff retry.
2. Reduce lock contention when reading data by using CAS (compare-and-swap).
### Why are the changes needed?
Fix https://github.com/apache/incubator-uniffle/issues/1606.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New UTs added. Tested in our env.
diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
index 1f7aae9..df3f667 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
@@ -216,6 +216,13 @@
LOG.info("In reduce: " + reduceId + ", Rss MR client starts to fetch blocks from RSS server");
JobConf readerJobConf = getRemoteConf();
boolean expectedTaskIdsBitmapFilterEnable = serverInfoList.size() > 1;
+ int retryMax =
+ rssJobConf.getInt(
+ RssMRConfig.RSS_CLIENT_RETRY_MAX, RssMRConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
+ long retryIntervalMax =
+ rssJobConf.getLong(
+ RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
+ RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance()
.createShuffleReadClient(
@@ -232,6 +239,8 @@
.hadoopConf(readerJobConf)
.idHelper(new MRIdHelper())
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
.rssConf(RssMRConfig.toRssConf(rssJobConf)));
RssFetcher fetcher =
new RssFetcher(
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 75855ba..76bfed6 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -108,6 +108,14 @@
@Override
public Iterator<Product2<K, C>> read() {
LOG.info("Shuffle read started:" + getReadInfo());
+ int retryMax =
+ rssConf.getInteger(
+ RssClientConfig.RSS_CLIENT_RETRY_MAX,
+ RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
+ long retryIntervalMax =
+ rssConf.getLong(
+ RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
+ RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance()
.createShuffleReadClient(
@@ -123,6 +131,8 @@
.shuffleServerInfoList(shuffleServerInfoList)
.hadoopConf(hadoopConf)
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
.rssConf(rssConf));
RssShuffleDataIterator rssShuffleDataIterator =
new RssShuffleDataIterator<K, C>(
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3d7b58b..0c7f3be 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -246,6 +246,14 @@
boolean expectedTaskIdsBitmapFilterEnable =
!(mapStartIndex == 0 && mapEndIndex == Integer.MAX_VALUE)
|| shuffleServerInfoList.size() > 1;
+ int retryMax =
+ rssConf.getInteger(
+ RssClientConfig.RSS_CLIENT_RETRY_MAX,
+ RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
+ long retryIntervalMax =
+ rssConf.getLong(
+ RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
+ RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance()
.createShuffleReadClient(
@@ -262,6 +270,8 @@
.hadoopConf(hadoopConf)
.shuffleDataDistributionType(dataDistributionType)
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
.rssConf(rssConf));
RssShuffleDataIterator<K, C> iterator =
new RssShuffleDataIterator<>(
diff --git a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index ce7d900..8efdd44 100644
--- a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -212,6 +212,8 @@
private int indexReadLimit;
private long readBufferSize;
private ClientType clientType;
+ private int retryMax;
+ private long retryIntervalMax;
public ReadClientBuilder appId(String appId) {
this.appId = appId;
@@ -310,6 +312,16 @@
return this;
}
+ /**
+ * Stores the retry limit on this builder; it is read back through {@code getRetryMax()}.
+ *
+ * @param retryMax maximum number of retries (presumably for NO_BUFFER read responses; confirm
+ *     against the request handling code)
+ * @return this builder, for call chaining
+ */
+ public ReadClientBuilder retryMax(int retryMax) {
+ this.retryMax = retryMax;
+ return this;
+ }
+
+ /**
+ * Stores the retry interval cap on this builder; it is read back through
+ * {@code getRetryIntervalMax()}.
+ *
+ * @param retryIntervalMax upper bound for the wait between retries (presumably milliseconds;
+ *     confirm against the code that sleeps)
+ * @return this builder, for call chaining
+ */
+ public ReadClientBuilder retryIntervalMax(long retryIntervalMax) {
+ this.retryIntervalMax = retryIntervalMax;
+ return this;
+ }
+
public ReadClientBuilder() {}
public String getAppId() {
@@ -388,6 +400,14 @@
return clientType;
}
+ /** Returns the retry limit configured via {@code retryMax(int)}; 0 if it was never set. */
+ public int getRetryMax() {
+ return retryMax;
+ }
+
+ /** Returns the retry interval cap set via {@code retryIntervalMax(long)}; 0 if never set. */
+ public long getRetryIntervalMax() {
+ return retryIntervalMax;
+ }
+
public ShuffleReadClientImpl build() {
return new ShuffleReadClientImpl(this);
}
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 49bb2de..4a789bf 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -168,6 +168,8 @@
request.setExpectTaskIds(taskIdBitmap);
request.setClientConf(builder.getRssConf());
request.setClientType(builder.getClientType());
+ request.setRetryMax(builder.getRetryMax());
+ request.setRetryIntervalMax(builder.getRetryIntervalMax());
if (builder.isExpectedTaskIdsBitmapFilterEnable()) {
request.useExpectedTaskIdsBitmapFilter();
}
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
new file mode 100644
index 0000000..abefb1b
--- /dev/null
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.coordinator.CoordinatorServer;
+import org.apache.uniffle.server.MockedGrpcServer;
+import org.apache.uniffle.server.MockedShuffleServer;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * Integration test for the client-side retry of read requests.
+ *
+ * <p>Three mocked shuffle servers are started once for the class. The test writes blocks, makes
+ * the mocked servers fail the first N read requests, and checks that a read client configured
+ * with {@code retryMax(3)} fails when N exceeds the retry budget and succeeds when it does not.
+ */
+public class RpcClientRetryTest extends ShuffleReadWriteBase {
+
+  private static ShuffleServerInfo shuffleServerInfo0;
+  private static ShuffleServerInfo shuffleServerInfo1;
+  private static ShuffleServerInfo shuffleServerInfo2;
+  // Re-created by registerShuffleServer() for every test invocation; released in cleanEnv().
+  private static MockedShuffleWriteClientImpl shuffleWriteClientImpl;
+
+  /** Returns a read builder pre-filled with the settings shared by all read clients here. */
+  private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType storageType) {
+    return ShuffleClientFactory.newReadBuilder()
+        .storageType(storageType.name())
+        .shuffleId(0)
+        .partitionId(0)
+        .indexReadLimit(100)
+        .partitionNumPerRange(1)
+        .partitionNum(10)
+        .readBufferSize(1000);
+  }
+
+  /**
+   * Creates (but does not start) a mocked gRPC shuffle server that stores data under two
+   * directories derived from {@code id} inside {@code tmpDir}.
+   */
+  public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
+      throws Exception {
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
+    File dataDir1 = new File(tmpDir, id + "_1");
+    File dataDir2 = new File(tmpDir, id + "_2");
+    String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
+    shuffleServerConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE.name());
+    shuffleServerConf.setString("rss.storage.basePath", basePath);
+    shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE, 5.0);
+    shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 15.0);
+    shuffleServerConf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 600L);
+    return new MockedShuffleServer(shuffleServerConf);
+  }
+
+  /** Starts one coordinator and three mocked shuffle servers shared by all invocations. */
+  @BeforeAll
+  public static void initCluster(@TempDir File tmpDir) throws Exception {
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    createCoordinatorServer(coordinatorConf);
+
+    grpcShuffleServers.add(createMockedShuffleServer(0, tmpDir));
+    grpcShuffleServers.add(createMockedShuffleServer(1, tmpDir));
+    grpcShuffleServers.add(createMockedShuffleServer(2, tmpDir));
+
+    shuffleServerInfo0 =
+        new ShuffleServerInfo(
+            String.format("127.0.0.1-%s", grpcShuffleServers.get(0).getGrpcPort()),
+            grpcShuffleServers.get(0).getIp(),
+            grpcShuffleServers.get(0).getGrpcPort());
+    shuffleServerInfo1 =
+        new ShuffleServerInfo(
+            String.format("127.0.0.1-%s", grpcShuffleServers.get(1).getGrpcPort()),
+            grpcShuffleServers.get(1).getIp(),
+            grpcShuffleServers.get(1).getGrpcPort());
+    shuffleServerInfo2 =
+        new ShuffleServerInfo(
+            String.format("127.0.0.1-%s", grpcShuffleServers.get(2).getGrpcPort()),
+            grpcShuffleServers.get(2).getIp(),
+            grpcShuffleServers.get(2).getGrpcPort());
+    for (CoordinatorServer coordinator : coordinators) {
+      coordinator.start();
+    }
+    for (ShuffleServer shuffleServer : grpcShuffleServers) {
+      shuffleServer.start();
+    }
+  }
+
+  /** Stops every coordinator and shuffle server and resets the shared server lists. */
+  public static void cleanCluster() throws Exception {
+    for (CoordinatorServer coordinator : coordinators) {
+      coordinator.stopServer();
+    }
+    for (ShuffleServer shuffleServer : grpcShuffleServers) {
+      shuffleServer.stopServer();
+    }
+    grpcShuffleServers = Lists.newArrayList();
+    coordinators = Lists.newArrayList();
+  }
+
+  /** Closes the last write client (if any) and shuts the cluster down. */
+  @AfterAll
+  public static void cleanEnv() throws Exception {
+    if (shuffleWriteClientImpl != null) {
+      shuffleWriteClientImpl.close();
+    }
+    cleanCluster();
+  }
+
+  private static Stream<Arguments> testRpcRetryLogicProvider() {
+    return Stream.of(
+        Arguments.of(StorageType.MEMORY_LOCALFILE),
+        // According to SERVER_BUFFER_CAPACITY & SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE,
+        // data will be flushed to disk, so read from disk only
+        Arguments.of(StorageType.LOCALFILE));
+  }
+
+  /**
+   * Writes blocks to the three servers, then verifies that a read fails when the servers reject
+   * more requests than the client may retry, and succeeds when they reject fewer.
+   */
+  @ParameterizedTest
+  @MethodSource("testRpcRetryLogicProvider")
+  public void testRpcRetryLogic(StorageType storageType) throws Exception {
+    String testAppId = "testRpcRetryLogic";
+    registerShuffleServer(testAppId, 3, 2, 2, true);
+    Map<Long, byte[]> expectedData = Maps.newHashMap();
+    Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
+
+    List<ShuffleBlockInfo> blocks =
+        createShuffleBlockList(
+            0,
+            0,
+            0,
+            3,
+            25,
+            blockIdBitmap,
+            expectedData,
+            Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
+
+    SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks);
+    Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    Roaring64NavigableMap successfulBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    for (Long blockId : result.getSuccessBlockIds()) {
+      successfulBlockIdBitmap.addLong(blockId);
+    }
+    for (Long blockId : result.getFailedBlockIds()) {
+      failedBlockIdBitmap.addLong(blockId);
+    }
+    assertEquals(0, failedBlockIdBitmap.getLongCardinality());
+    assertEquals(blockIdBitmap, successfulBlockIdBitmap);
+
+    Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
+
+    ShuffleReadClientImpl readClient1 =
+        baseReadBuilder(storageType)
+            .appId(testAppId)
+            .blockIdBitmap(blockIdBitmap)
+            .taskIdBitmap(taskIdBitmap)
+            .shuffleServerInfoList(
+                Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
+            .retryMax(3)
+            .retryIntervalMax(1)
+            .build();
+
+    // The data cannot be read because the maximum number of retries is 3
+    enableFirstNReadRequestsToFail(4);
+    try {
+      validateResult(readClient1, expectedData);
+      // fail() throws an Error, so the catch block below does not swallow it.
+      fail();
+    } catch (Exception e) {
+      // Expected: the retry budget is exhausted while the servers still reject requests.
+    } finally {
+      // Always reset the fault injection so a failure here cannot leak into later invocations.
+      disableFirstNReadRequestsToFail();
+    }
+
+    ShuffleReadClientImpl readClient2 =
+        baseReadBuilder(storageType)
+            .appId(testAppId)
+            .blockIdBitmap(blockIdBitmap)
+            .taskIdBitmap(taskIdBitmap)
+            .shuffleServerInfoList(
+                Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
+            .retryMax(3)
+            .retryIntervalMax(1)
+            .build();
+
+    // The data can be read because the reader will retry
+    enableFirstNReadRequestsToFail(1);
+    try {
+      validateResult(readClient2, expectedData);
+    } finally {
+      disableFirstNReadRequestsToFail();
+    }
+  }
+
+  /** Makes every mocked server fail its first {@code failedCount} read requests. */
+  private static void enableFirstNReadRequestsToFail(int failedCount) {
+    for (ShuffleServer server : grpcShuffleServers) {
+      ((MockedGrpcServer) server.getServer())
+          .getService()
+          .enableFirstNReadRequestToFail(failedCount);
+    }
+  }
+
+  /** Clears the read-failure injection on every mocked server. */
+  private static void disableFirstNReadRequestsToFail() {
+    for (ShuffleServer server : grpcShuffleServers) {
+      ((MockedGrpcServer) server.getServer()).getService().resetFirstNReadRequestToFail();
+    }
+  }
+
+  /** Write client exposing a two-argument {@code sendShuffleData} that never asks to stop. */
+  static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
+    MockedShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) {
+      super(builder);
+    }
+
+    public SendShuffleDataResult sendShuffleData(
+        String appId, List<ShuffleBlockInfo> shuffleBlockInfoList) {
+      return super.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
+    }
+  }
+
+  /**
+   * Creates a fresh write client and registers partition range [0, 0] of shuffle 0 for
+   * {@code testAppId} on the first {@code replica} servers.
+   */
+  private void registerShuffleServer(
+      String testAppId, int replica, int replicaWrite, int replicaRead, boolean replicaSkip)
+      throws Exception {
+
+    // The test is parameterized, so this method runs once per storage type. Close the client
+    // left over from the previous invocation; cleanEnv() only releases the last one.
+    if (shuffleWriteClientImpl != null) {
+      shuffleWriteClientImpl.close();
+    }
+
+    shuffleWriteClientImpl =
+        new MockedShuffleWriteClientImpl(
+            ShuffleClientFactory.newWriteBuilder()
+                .clientType(ClientType.GRPC.name())
+                .retryMax(3)
+                .retryIntervalMax(1000)
+                .heartBeatThreadNum(1)
+                .replica(replica)
+                .replicaWrite(replicaWrite)
+                .replicaRead(replicaRead)
+                .replicaSkipEnabled(replicaSkip)
+                .dataTransferPoolSize(1)
+                .dataCommitPoolSize(1)
+                .unregisterThreadPoolSize(10)
+                .unregisterRequestTimeSec(10));
+
+    List<ShuffleServerInfo> allServers =
+        Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2);
+
+    for (int i = 0; i < replica; i++) {
+      shuffleWriteClientImpl.registerShuffle(
+          allServers.get(i),
+          testAppId,
+          0,
+          Lists.newArrayList(new PartitionRange(0, 0)),
+          new RemoteStorageInfo(""),
+          ShuffleDataDistributionType.NORMAL,
+          1);
+    }
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 5a6919e..3ab81a0 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -35,6 +35,7 @@
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.request.RetryableRequest;
import org.apache.uniffle.client.request.RssAppHeartBeatRequest;
import org.apache.uniffle.client.request.RssFinishShuffleRequest;
import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest;
@@ -109,12 +110,26 @@
import org.apache.uniffle.proto.ShuffleServerGrpc;
import org.apache.uniffle.proto.ShuffleServerGrpc.ShuffleServerBlockingStub;
+import static org.apache.uniffle.proto.RssProtos.StatusCode.NO_BUFFER;
+
public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServerClient {
private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcClient.class);
protected static final long FAILED_REQUIRE_ID = -1;
protected long rpcTimeout;
private ShuffleServerBlockingStub blockingStub;
+ /**
+ * A single instance of the Random class is created as a member variable and reused throughout
+ * `ShuffleServerGrpcClient`. This approach has the following benefits: 1. Performance
+ * optimization: It avoids creating a new generator for every backoff computation, reducing
+ * memory allocation and garbage collection costs. 2. Randomness: All jitter values are drawn
+ * from one sequence instead of from many freshly seeded generators. Note that, per the JDK
+ * documentation, `new Random()` already picks a seed that is very likely to be distinct for each
+ * instance, so this point is a minor benefit rather than a correctness requirement.
+ */
+ protected Random random = new Random();
+
+ protected static final int BACK_OFF_BASE = 2000;
@VisibleForTesting
public ShuffleServerGrpcClient(String host, int port) {
@@ -237,8 +252,6 @@
long start = System.currentTimeMillis();
int retry = 0;
long result = FAILED_REQUIRE_ID;
- Random random = new Random();
- final int backOffBase = 2000;
if (LOG.isDebugEnabled()) {
LOG.debug(
"Requiring buffer for appId: {}, shuffleId: {}, partitionIds: {} with {} bytes from {}:{}",
@@ -258,7 +271,7 @@
"Exception happened when requiring pre-allocated buffer from {}:{}", host, port, e);
return result;
}
- if (rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER
+ if (rpcResponse.getStatus() != NO_BUFFER
&& rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER_FOR_HUGE_PARTITION) {
break;
}
@@ -291,7 +304,7 @@
long backoffTime =
Math.min(
retryIntervalMax,
- backOffBase * (1L << Math.min(retry, 16)) + random.nextInt(backOffBase));
+ BACK_OFF_BASE * (1L << Math.min(retry, 16)) + random.nextInt(BACK_OFF_BASE));
Thread.sleep(backoffTime);
} catch (Exception e) {
LOG.warn(
@@ -822,7 +835,6 @@
.setLength(request.getLength())
.setTimestamp(start)
.build();
- GetLocalShuffleDataResponse rpcResponse = getBlockingStub().getLocalShuffleData(rpcRequest);
String requestInfo =
"appId["
+ request.getAppId()
@@ -831,22 +843,29 @@
+ "], partitionId["
+ request.getPartitionId()
+ "]";
- LOG.info(
- "GetShuffleData from {}:{} for {} cost {} ms",
- host,
- port,
- requestInfo,
- System.currentTimeMillis() - start);
-
- RssProtos.StatusCode statusCode = rpcResponse.getStatus();
-
+ int retry = 0;
+ GetLocalShuffleDataResponse rpcResponse;
+ while (true) {
+ rpcResponse = getBlockingStub().getLocalShuffleData(rpcRequest);
+ if (rpcResponse.getStatus() != NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(
+ request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start);
+ retry++;
+ }
RssGetShuffleDataResponse response;
- switch (statusCode) {
+ switch (rpcResponse.getStatus()) {
case SUCCESS:
+ LOG.info(
+ "GetShuffleData from {}:{} for {} cost {} ms",
+ host,
+ port,
+ requestInfo,
+ System.currentTimeMillis() - start);
response =
new RssGetShuffleDataResponse(
StatusCode.SUCCESS, ByteBuffer.wrap(rpcResponse.getData().toByteArray()));
-
break;
default:
String msg =
@@ -874,8 +893,6 @@
.setPartitionNumPerRange(request.getPartitionNumPerRange())
.setPartitionNum(request.getPartitionNum())
.build();
- long start = System.currentTimeMillis();
- GetLocalShuffleIndexResponse rpcResponse = getBlockingStub().getLocalShuffleIndex(rpcRequest);
String requestInfo =
"appId["
+ request.getAppId()
@@ -884,18 +901,27 @@
+ "], partitionId["
+ request.getPartitionId()
+ "]";
- LOG.info(
- "GetShuffleIndex from {}:{} for {} cost {} ms",
- host,
- port,
- requestInfo,
- System.currentTimeMillis() - start);
-
- RssProtos.StatusCode statusCode = rpcResponse.getStatus();
-
+ long start = System.currentTimeMillis();
+ int retry = 0;
+ GetLocalShuffleIndexResponse rpcResponse;
+ while (true) {
+ rpcResponse = getBlockingStub().getLocalShuffleIndex(rpcRequest);
+ if (rpcResponse.getStatus() != NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(
+ request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start);
+ retry++;
+ }
RssGetShuffleIndexResponse response;
- switch (statusCode) {
+ switch (rpcResponse.getStatus()) {
case SUCCESS:
+ LOG.info(
+ "GetShuffleIndex from {}:{} for {} cost {} ms",
+ host,
+ port,
+ requestInfo,
+ System.currentTimeMillis() - start);
response =
new RssGetShuffleIndexResponse(
StatusCode.SUCCESS,
@@ -944,8 +970,6 @@
.setSerializedExpectedTaskIdsBitmap(serializedTaskIdsBytes)
.setTimestamp(start)
.build();
-
- GetMemoryShuffleDataResponse rpcResponse = getBlockingStub().getMemoryShuffleData(rpcRequest);
String requestInfo =
"appId["
+ request.getAppId()
@@ -954,20 +978,28 @@
+ "], partitionId["
+ request.getPartitionId()
+ "]";
- LOG.info(
- "GetInMemoryShuffleData from {}:{} for "
- + requestInfo
- + " cost "
- + (System.currentTimeMillis() - start)
- + " ms",
- host,
- port);
-
- RssProtos.StatusCode statusCode = rpcResponse.getStatus();
-
+ int retry = 0;
+ GetMemoryShuffleDataResponse rpcResponse;
+ while (true) {
+ rpcResponse = getBlockingStub().getMemoryShuffleData(rpcRequest);
+ if (rpcResponse.getStatus() != NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(
+ request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start);
+ retry++;
+ }
RssGetInMemoryShuffleDataResponse response;
- switch (statusCode) {
+ switch (rpcResponse.getStatus()) {
case SUCCESS:
+ LOG.info(
+ "GetInMemoryShuffleData from {}:{} for "
+ + requestInfo
+ + " cost "
+ + (System.currentTimeMillis() - start)
+ + " ms",
+ host,
+ port);
response =
new RssGetInMemoryShuffleDataResponse(
StatusCode.SUCCESS,
@@ -995,6 +1027,47 @@
return "ShuffleServerGrpcClient for host[" + host + "], port[" + port + "]";
}
+  /**
+   * Backs off before the next retry of a read request that the server rejected, or gives up once
+   * the request's retry budget is exhausted.
+   *
+   * <p>The wait is {@code BACK_OFF_BASE * 2^min(retry, 16)} plus a random jitter below
+   * {@code BACK_OFF_BASE}, capped at {@code request.getRetryIntervalMax()} and never negative.
+   *
+   * @param request the request being retried; supplies the retry limit, the backoff cap and the
+   *     operation name used in messages
+   * @param retry number of retries already performed for this request (0 before the first retry)
+   * @param requestInfo human-readable description of the request, used for logging only
+   * @param statusCode status reported by the server for the rejected attempt
+   * @param start time in milliseconds at which the caller started, used to report elapsed time
+   * @throws RssFetchFailedException if {@code retry} has reached {@code request.getRetryMax()},
+   *     or if the calling thread is interrupted while backing off
+   */
+  protected void waitOrThrow(
+      RetryableRequest request, int retry, String requestInfo, StatusCode statusCode, long start) {
+    if (retry >= request.getRetryMax()) {
+      String msg =
+          String.format(
+              "ShuffleServer %s:%s is full when %s due to %s, after %d retries, cost %d ms",
+              host,
+              port,
+              request.operationType(),
+              statusCode,
+              request.getRetryMax(),
+              System.currentTimeMillis() - start);
+      LOG.error(msg);
+      throw new RssFetchFailedException(msg);
+    }
+    // Clamp at zero: Thread.sleep rejects negative values, which a negative retryIntervalMax
+    // would otherwise produce.
+    long backoffTime =
+        Math.max(
+            0L,
+            Math.min(
+                request.getRetryIntervalMax(),
+                BACK_OFF_BASE * (1L << Math.min(retry, 16)) + random.nextInt(BACK_OFF_BASE)));
+    LOG.warn(
+        "Can't acquire buffer for {} from {}:{} when executing {}, due to {}. "
+            + "Will retry {} more time(s) after waiting {} milliseconds.",
+        requestInfo,
+        host,
+        port,
+        request.operationType(),
+        statusCode,
+        request.getRetryMax() - retry,
+        backoffTime);
+    try {
+      Thread.sleep(backoffTime);
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so code further up the stack can observe the interruption,
+      // and stop retrying instead of silently issuing the next request without any backoff.
+      Thread.currentThread().interrupt();
+      String msg =
+          String.format(
+              "Interrupted while waiting to retry %s from %s:%s for %s",
+              request.operationType(), host, port, requestInfo);
+      // The cause is logged here because only the message-only constructor is known to exist.
+      LOG.warn(msg, e);
+      throw new RssFetchFailedException(msg);
+    }
+  }
+
private List<ShufflePartitionRange> toShufflePartitionRanges(
List<PartitionRange> partitionRanges) {
List<ShufflePartitionRange> ret = Lists.newArrayList();
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index fc8aa02..f677b63 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -221,12 +221,27 @@
+ "], lastBlockId["
+ request.getLastBlockId()
+ "]";
- RpcResponse rpcResponse = transportClient.sendRpcSync(getMemoryShuffleDataRequest, rpcTimeout);
- GetMemoryShuffleDataResponse getMemoryShuffleDataResponse =
- (GetMemoryShuffleDataResponse) rpcResponse;
- StatusCode statusCode = rpcResponse.getStatusCode();
- switch (statusCode) {
+ long start = System.currentTimeMillis();
+ int retry = 0;
+ RpcResponse rpcResponse;
+ GetMemoryShuffleDataResponse getMemoryShuffleDataResponse;
+ while (true) {
+ rpcResponse = transportClient.sendRpcSync(getMemoryShuffleDataRequest, rpcTimeout);
+ getMemoryShuffleDataResponse = (GetMemoryShuffleDataResponse) rpcResponse;
+ if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start);
+ retry++;
+ }
+ switch (rpcResponse.getStatusCode()) {
case SUCCESS:
+ LOG.info(
+ "GetInMemoryShuffleData from {}:{} for {} cost {} ms",
+ host,
+ nettyPort,
+ requestInfo,
+ System.currentTimeMillis() - start);
return new RssGetInMemoryShuffleDataResponse(
StatusCode.SUCCESS,
getMemoryShuffleDataResponse.body(),
@@ -236,7 +251,7 @@
"Can't get shuffle in memory data from "
+ host
+ ":"
- + port
+ + nettyPort
+ " for "
+ requestInfo
+ ", errorMsg:"
@@ -257,8 +272,6 @@
request.getPartitionId(),
request.getPartitionNumPerRange(),
request.getPartitionNum());
- long start = System.currentTimeMillis();
- RpcResponse rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout);
String requestInfo =
"appId["
+ request.getAppId()
@@ -266,17 +279,27 @@
+ request.getShuffleId()
+ "], partitionId["
+ request.getPartitionId();
- LOG.info(
- "GetShuffleIndex from {}:{} for {} cost {} ms",
- host,
- port,
- requestInfo,
- System.currentTimeMillis() - start);
- GetLocalShuffleIndexResponse getLocalShuffleIndexResponse =
- (GetLocalShuffleIndexResponse) rpcResponse;
- StatusCode statusCode = rpcResponse.getStatusCode();
- switch (statusCode) {
+ long start = System.currentTimeMillis();
+ int retry = 0;
+ RpcResponse rpcResponse;
+ GetLocalShuffleIndexResponse getLocalShuffleIndexResponse;
+ while (true) {
+ rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout);
+ getLocalShuffleIndexResponse = (GetLocalShuffleIndexResponse) rpcResponse;
+ if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start);
+ retry++;
+ }
+ switch (rpcResponse.getStatusCode()) {
case SUCCESS:
+ LOG.info(
+ "GetShuffleIndex from {}:{} for {} cost {} ms",
+ host,
+ nettyPort,
+ requestInfo,
+ System.currentTimeMillis() - start);
return new RssGetShuffleIndexResponse(
StatusCode.SUCCESS,
getLocalShuffleIndexResponse.body(),
@@ -286,7 +309,7 @@
"Can't get shuffle index from "
+ host
+ ":"
- + port
+ + nettyPort
+ " for "
+ requestInfo
+ ", errorMsg:"
@@ -310,8 +333,6 @@
request.getOffset(),
request.getLength(),
System.currentTimeMillis());
- long start = System.currentTimeMillis();
- RpcResponse rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout);
String requestInfo =
"appId["
+ request.getAppId()
@@ -320,17 +341,27 @@
+ "], partitionId["
+ request.getPartitionId()
+ "]";
- LOG.info(
- "GetShuffleData from {}:{} for {} cost {} ms",
- host,
- port,
- requestInfo,
- System.currentTimeMillis() - start);
- GetLocalShuffleDataResponse getLocalShuffleDataResponse =
- (GetLocalShuffleDataResponse) rpcResponse;
- StatusCode statusCode = rpcResponse.getStatusCode();
- switch (statusCode) {
+ long start = System.currentTimeMillis();
+ int retry = 0;
+ RpcResponse rpcResponse;
+ GetLocalShuffleDataResponse getLocalShuffleDataResponse;
+ while (true) {
+ rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout);
+ getLocalShuffleDataResponse = (GetLocalShuffleDataResponse) rpcResponse;
+ if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) {
+ break;
+ }
+ waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start);
+ retry++;
+ }
+ switch (rpcResponse.getStatusCode()) {
case SUCCESS:
+ LOG.info(
+ "GetShuffleData from {}:{} for {} cost {} ms",
+ host,
+ nettyPort,
+ requestInfo,
+ System.currentTimeMillis() - start);
return new RssGetShuffleDataResponse(
StatusCode.SUCCESS, getLocalShuffleDataResponse.body());
default:
@@ -338,7 +369,7 @@
"Can't get shuffle data from "
+ host
+ ":"
- + port
+ + nettyPort
+ " for "
+ requestInfo
+ ", errorMsg:"
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java
new file mode 100644
index 0000000..2abe4b2
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.request;
+
+/**
+ * Base class for client requests that carry their own retry settings.
+ *
+ * <p>The two fields have no setters here; subclasses assign them directly (for example in their
+ * constructors). {@code ShuffleServerGrpcClient#waitOrThrow} reads them to decide whether to retry
+ * and how long to wait.
+ */
+public abstract class RetryableRequest {
+ // Retry limit: waitOrThrow gives up once the retry counter reaches this value.
+ protected int retryMax;
+ // Upper bound applied to the computed backoff time before sleeping.
+ protected long retryIntervalMax;
+
+ /** Returns the maximum number of retries allowed for this request. */
+ public int getRetryMax() {
+ return retryMax;
+ }
+
+ /** Returns the cap applied to the wait between two retries of this request. */
+ public long getRetryIntervalMax() {
+ return retryIntervalMax;
+ }
+
+ /** Returns a short name for the operation; it is used in retry log and error messages. */
+ public abstract String operationType();
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
index bf3534e..64c4110 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java
@@ -17,9 +17,10 @@
package org.apache.uniffle.client.request;
+import com.google.common.annotations.VisibleForTesting;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
-public class RssGetInMemoryShuffleDataRequest {
+public class RssGetInMemoryShuffleDataRequest extends RetryableRequest {
private final String appId;
private final int shuffleId;
private final int partitionId;
@@ -33,13 +34,28 @@
int partitionId,
long lastBlockId,
int readBufferSize,
- Roaring64NavigableMap expectedTaskIds) {
+ Roaring64NavigableMap expectedTaskIds,
+ int retryMax,
+ long retryIntervalMax) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
this.lastBlockId = lastBlockId;
this.readBufferSize = readBufferSize;
this.expectedTaskIds = expectedTaskIds;
+ this.retryMax = retryMax;
+ this.retryIntervalMax = retryIntervalMax;
+ }
+
+ @VisibleForTesting
+ public RssGetInMemoryShuffleDataRequest( // delegates with retryMax=1, retryIntervalMax=0
+ String appId,
+ int shuffleId,
+ int partitionId,
+ long lastBlockId,
+ int readBufferSize,
+ Roaring64NavigableMap expectedTaskIds) {
+ this(appId, shuffleId, partitionId, lastBlockId, readBufferSize, expectedTaskIds, 1, 0);
}
public String getAppId() {
@@ -65,4 +81,9 @@
public Roaring64NavigableMap getExpectedTaskIds() {
return expectedTaskIds;
}
+
+ @Override
+ public String operationType() { // fixed label for the in-memory shuffle data request
+ return "GetInMemoryShuffleData";
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
index 0b9997a..5801922 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
@@ -17,7 +17,9 @@
package org.apache.uniffle.client.request;
-public class RssGetShuffleDataRequest {
+import com.google.common.annotations.VisibleForTesting;
+
+public class RssGetShuffleDataRequest extends RetryableRequest {
private final String appId;
private final int shuffleId;
@@ -34,7 +36,9 @@
int partitionNumPerRange,
int partitionNum,
long offset,
- int length) {
+ int length,
+ int retryMax,
+ long retryIntervalMax) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
@@ -42,6 +46,20 @@
this.partitionNum = partitionNum;
this.offset = offset;
this.length = length;
+ this.retryMax = retryMax;
+ this.retryIntervalMax = retryIntervalMax;
+ }
+
+ @VisibleForTesting
+ public RssGetShuffleDataRequest( // delegates with retryMax=1, retryIntervalMax=0
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int partitionNumPerRange,
+ int partitionNum,
+ long offset,
+ int length) {
+ this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, offset, length, 1, 0);
}
public String getAppId() {
@@ -71,4 +89,9 @@
public int getLength() {
return length;
}
+
+ @Override
+ public String operationType() { // fixed label for the shuffle data request
+ return "GetShuffleData";
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java
index 8ae5da7..0e61206 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java
@@ -17,7 +17,9 @@
package org.apache.uniffle.client.request;
-public class RssGetShuffleIndexRequest {
+import com.google.common.annotations.VisibleForTesting;
+
+public class RssGetShuffleIndexRequest extends RetryableRequest {
private final String appId;
private final int shuffleId;
@@ -26,12 +28,26 @@
private final int partitionNum;
public RssGetShuffleIndexRequest(
- String appId, int shuffleId, int partitionId, int partitionNumPerRange, int partitionNum) {
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int partitionNumPerRange,
+ int partitionNum,
+ int retryMax,
+ long retryIntervalMax) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
this.partitionNumPerRange = partitionNumPerRange;
this.partitionNum = partitionNum;
+ this.retryMax = retryMax;
+ this.retryIntervalMax = retryIntervalMax;
+ }
+
+ @VisibleForTesting
+ public RssGetShuffleIndexRequest( // delegates with retryMax=1, retryIntervalMax=0
+ String appId, int shuffleId, int partitionId, int partitionNumPerRange, int partitionNum) {
+ this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, 1, 0);
}
public String getAppId() {
@@ -53,4 +69,9 @@
public int getPartitionNum() {
return partitionNum;
}
+
+ @Override
+ public String operationType() { // fixed label for the shuffle index request
+ return "GetShuffleIndex";
+ }
}
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
index 9ea2e84..bd71c3b 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
@@ -132,12 +132,6 @@
.withDescription(
"Expired time (ms) for application which has no heartbeat with coordinator");
- public static final ConfigOption<Integer> SERVER_MEMORY_REQUEST_RETRY_MAX =
- ConfigOptions.key("rss.server.memory.request.retry.max")
- .intType()
- .defaultValue(50)
- .withDescription("Max times to retry for memory request");
-
public static final ConfigOption<Long> SERVER_PRE_ALLOCATION_EXPIRED =
ConfigOptions.key("rss.server.preAllocation.expired")
.longType()
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index dcc2717..9f8f79e 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -674,7 +674,7 @@
storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
}
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(length)) {
try {
long start = System.currentTimeMillis();
sdr =
@@ -722,7 +722,7 @@
shuffleServer.getShuffleBufferManager().releaseReadMemory(length);
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get shuffle data";
LOG.error(msg + " for " + requestInfo);
reply =
@@ -766,7 +766,7 @@
shuffleServer
.getShuffleServerConf()
.getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT);
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(assumedFileSize)) {
ShuffleIndexResult shuffleIndexResult = null;
try {
long start = System.currentTimeMillis();
@@ -812,7 +812,7 @@
shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize);
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get shuffle index";
LOG.error(msg + " for " + requestInfo);
reply =
@@ -853,7 +853,7 @@
"appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId[" + partitionId + "]";
// todo: if can get the exact memory size?
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(readBufferSize)) {
ShuffleDataResult shuffleDataResult = null;
try {
Roaring64NavigableMap expectedTaskIds = null;
@@ -915,7 +915,7 @@
shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize);
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get in memory shuffle data";
LOG.error(msg + " for " + requestInfo);
reply =
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index 4d42b05..8f41a07 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -60,7 +60,6 @@
private final ShuffleFlushManager shuffleFlushManager;
private long capacity;
private long readCapacity;
- private int retryNum;
private long highWaterMark;
private long lowWaterMark;
private boolean bufferFlushEnabled;
@@ -111,7 +110,6 @@
readCapacity);
this.shuffleFlushManager = shuffleFlushManager;
this.bufferPool = new ConcurrentHashMap<>();
- this.retryNum = conf.getInteger(ShuffleServerConf.SERVER_MEMORY_REQUEST_RETRY_MAX);
this.highWaterMark =
(long)
(capacity
@@ -424,35 +422,37 @@
ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get());
}
- public boolean requireReadMemoryWithRetry(long size) {
+ public boolean requireReadMemory(long size) { // no sleep or retry here; false = not reserved
ShuffleServerMetrics.counterTotalRequireReadMemoryNum.inc();
- for (int i = 0; i < retryNum; i++) {
- synchronized (this) {
- if (readDataMemory.get() + size < readCapacity) {
- readDataMemory.addAndGet(size);
- ShuffleServerMetrics.gaugeReadBufferUsedSize.inc(size);
- return true;
- }
+ boolean isSuccessful = false;
+
+ do { // loops again only when the compareAndSet below fails
+ long currentReadDataMemory = readDataMemory.get();
+ long newReadDataMemory = currentReadDataMemory + size;
+ if (newReadDataMemory >= readCapacity) { // same bound as the removed "current + size < capacity"
+ break;
}
- LOG.info(
+ if (readDataMemory.compareAndSet(currentReadDataMemory, newReadDataMemory)) {
+ ShuffleServerMetrics.gaugeReadBufferUsedSize.inc(size);
+ isSuccessful = true;
+ break;
+ }
+ } while (true);
+
+ if (!isSuccessful) { // failure path: log, then bump the Retry and Failed counters
+ LOG.error(
"Can't require["
+ size
+ "] for read data, current["
+ readDataMemory.get()
+ "], capacity["
+ readCapacity
- + "], re-try "
- + i
- + " times");
+ + "]");
ShuffleServerMetrics.counterTotalRequireReadMemoryRetryNum.inc();
- try {
- Thread.sleep(1000);
- } catch (Exception e) {
- LOG.warn("Error happened when require memory", e);
- }
+ ShuffleServerMetrics.counterTotalRequireReadMemoryFailedNum.inc(); // NOTE(review): Retry counter still bumped without a retry - confirm
}
- ShuffleServerMetrics.counterTotalRequireReadMemoryFailedNum.inc();
- return false;
+
+ return isSuccessful;
}
public void releaseReadMemory(long size) {
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 2e0c070..e87f9aa 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -263,7 +263,7 @@
"appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId[" + partitionId + "]";
// todo: if can get the exact memory size?
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(readBufferSize)) {
ShuffleDataResult shuffleDataResult = null;
try {
shuffleDataResult =
@@ -308,7 +308,7 @@
req.getRequestId(), status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get in memory shuffle data";
LOG.error(msg + " for " + requestInfo);
response =
@@ -347,7 +347,7 @@
shuffleServer
.getShuffleServerConf()
.getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT);
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(assumedFileSize)) {
ShuffleIndexResult shuffleIndexResult = null;
try {
final long start = System.currentTimeMillis();
@@ -392,7 +392,7 @@
req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L);
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get shuffle index";
LOG.error(msg + " for " + requestInfo);
response =
@@ -447,7 +447,7 @@
storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
}
- if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) {
+ if (shuffleServer.getShuffleBufferManager().requireReadMemory(length)) {
ShuffleDataResult sdr = null;
try {
final long start = System.currentTimeMillis();
@@ -486,7 +486,7 @@
req.getRequestId(), status, msg, new NettyManagedBuffer(Unpooled.EMPTY_BUFFER));
}
} else {
- status = StatusCode.INTERNAL_ERROR;
+ status = StatusCode.NO_BUFFER;
msg = "Can't require memory to get shuffle data";
LOG.error(msg + " for " + requestInfo);
response =
diff --git a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
index eafd832..87b9abd 100644
--- a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
+++ b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
@@ -22,11 +22,14 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Uninterruptibles;
+import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
@@ -45,7 +48,11 @@
private boolean recordGetShuffleResult = false;
private long numOfFailedReadRequest = 0;
- private AtomicInteger failedReadRequest = new AtomicInteger(0);
+ private AtomicInteger failedGetShuffleResultRequest = new AtomicInteger(0);
+ private AtomicInteger failedGetShuffleResultForMultiPartRequest = new AtomicInteger(0);
+ private AtomicInteger failedGetMemoryShuffleDataRequest = new AtomicInteger(0);
+ private AtomicInteger failedGetLocalShuffleDataRequest = new AtomicInteger(0);
+ private AtomicInteger failedGetLocalShuffleIndexRequest = new AtomicInteger(0);
public void enableMockedTimeout(long timeout) {
mockedTimeout = timeout;
@@ -69,7 +76,11 @@
public void resetFirstNReadRequestToFail() {
numOfFailedReadRequest = 0;
- failedReadRequest.set(0);
+ failedGetShuffleResultRequest.set(0);
+ failedGetShuffleResultForMultiPartRequest.set(0);
+ failedGetMemoryShuffleDataRequest.set(0);
+ failedGetLocalShuffleDataRequest.set(0);
+ failedGetLocalShuffleIndexRequest.set(0);
}
public MockedShuffleServerGrpcService(ShuffleServer shuffleServer) {
@@ -111,7 +122,7 @@
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
if (numOfFailedReadRequest > 0) {
- int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+ int currentFailedReadRequest = failedGetShuffleResultRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
@@ -128,11 +139,11 @@
RssProtos.GetShuffleResultForMultiPartRequest request,
StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
if (mockedTimeout > 0) {
- LOG.info("Add a mocked timeout on getShuffleResult");
+ LOG.info("Add a mocked timeout on getShuffleResultForMultiPart");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
if (numOfFailedReadRequest > 0) {
- int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+ int currentFailedReadRequest = failedGetShuffleResultForMultiPartRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
@@ -163,15 +174,81 @@
RssProtos.GetMemoryShuffleDataRequest request,
StreamObserver<RssProtos.GetMemoryShuffleDataResponse> responseObserver) {
if (numOfFailedReadRequest > 0) {
- int currentFailedReadRequest = failedReadRequest.getAndIncrement();
+ int currentFailedReadRequest = failedGetMemoryShuffleDataRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
- throw new RuntimeException("This request is failed as mocked failure");
+ StatusCode status = StatusCode.NO_BUFFER;
+ String msg =
+ "Can't require memory to get in memory shuffle data (This request is failed as mocked failure)";
+ RssProtos.GetMemoryShuffleDataResponse reply =
+ RssProtos.GetMemoryShuffleDataResponse.newBuilder()
+ .setData(UnsafeByteOperations.unsafeWrap(new byte[] {}))
+ .addAllShuffleDataBlockSegments(Lists.newArrayList())
+ .setStatus(status.toProto())
+ .setRetMsg(msg)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
}
}
super.getMemoryShuffleData(request, responseObserver);
}
+
+ @Override
+ public void getLocalShuffleData( // test hook: may answer NO_BUFFER instead of serving data
+ RssProtos.GetLocalShuffleDataRequest request,
+ StreamObserver<RssProtos.GetLocalShuffleDataResponse> responseObserver) {
+ if (numOfFailedReadRequest > 0) {
+ int currentFailedReadRequest = failedGetLocalShuffleDataRequest.getAndIncrement();
+ if (currentFailedReadRequest < numOfFailedReadRequest) { // only the first N calls are failed
+ LOG.info(
+ "This request is failed as mocked failure, current/firstN: {}/{}",
+ currentFailedReadRequest,
+ numOfFailedReadRequest);
+ StatusCode status = StatusCode.NO_BUFFER;
+ String msg =
+ "Can't require memory to get shuffle data (This request is failed as mocked failure)";
+ RssProtos.GetLocalShuffleDataResponse reply =
+ RssProtos.GetLocalShuffleDataResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(msg)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return; // mocked reply already sent; skip the real handler
+ }
+ }
+ super.getLocalShuffleData(request, responseObserver);
+ }
+
+ @Override
+ public void getLocalShuffleIndex( // test hook: may answer NO_BUFFER instead of serving the index
+ RssProtos.GetLocalShuffleIndexRequest request,
+ StreamObserver<RssProtos.GetLocalShuffleIndexResponse> responseObserver) {
+ if (numOfFailedReadRequest > 0) {
+ int currentFailedReadRequest = failedGetLocalShuffleIndexRequest.getAndIncrement();
+ if (currentFailedReadRequest < numOfFailedReadRequest) { // only the first N calls are failed
+ LOG.info(
+ "This request is failed as mocked failure, current/firstN: {}/{}",
+ currentFailedReadRequest,
+ numOfFailedReadRequest);
+ StatusCode status = StatusCode.NO_BUFFER;
+ String msg =
+ "Can't require memory to get shuffle index (This request is failed as mocked failure)";
+ RssProtos.GetLocalShuffleIndexResponse reply =
+ RssProtos.GetLocalShuffleIndexResponse.newBuilder()
+ .setStatus(status.toProto())
+ .setRetMsg(msg)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return; // mocked reply already sent; skip the real handler
+ }
+ }
+ super.getLocalShuffleIndex(request, responseObserver);
+ }
}
diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 1db3326..819c26e 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -134,7 +134,9 @@
request.getPartitionId(),
request.getReadBufferSize(),
shuffleServerClient,
- expectTaskIds);
+ expectTaskIds,
+ request.getRetryMax(),
+ request.getRetryIntervalMax());
return memoryClientReadHandler;
}
@@ -155,7 +157,9 @@
request.getProcessBlockIds(),
shuffleServerClient,
request.getDistributionType(),
- request.getExpectTaskIds());
+ request.getExpectTaskIds(),
+ request.getRetryMax(),
+ request.getRetryIntervalMax());
}
private ClientReadHandler getHadoopClientReadHandler(
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
index 9fc5088..2b5ea8f 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.storage.handler.impl;
+import com.google.common.annotations.VisibleForTesting;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,6 +38,8 @@
private final int partitionNumPerRange;
private final int partitionNum;
private ShuffleServerClient shuffleServerClient;
+ private int retryMax;
+ private long retryIntervalMax;
public LocalFileClientReadHandler(
String appId,
@@ -50,7 +53,9 @@
Roaring64NavigableMap processBlockIds,
ShuffleServerClient shuffleServerClient,
ShuffleDataDistributionType distributionType,
- Roaring64NavigableMap expectTaskIds) {
+ Roaring64NavigableMap expectTaskIds,
+ int retryMax,
+ long retryIntervalMax) {
super(
appId,
shuffleId,
@@ -63,9 +68,11 @@
this.shuffleServerClient = shuffleServerClient;
this.partitionNumPerRange = partitionNumPerRange;
this.partitionNum = partitionNum;
+ this.retryMax = retryMax;
+ this.retryIntervalMax = retryIntervalMax;
}
- /** Only for test */
+ @VisibleForTesting
public LocalFileClientReadHandler(
String appId,
int shuffleId,
@@ -89,7 +96,9 @@
processBlockIds,
shuffleServerClient,
ShuffleDataDistributionType.NORMAL,
- Roaring64NavigableMap.bitmapOf());
+ Roaring64NavigableMap.bitmapOf(),
+ 1,
+ 0);
}
@Override
@@ -97,7 +106,13 @@
ShuffleIndexResult shuffleIndexResult = null;
RssGetShuffleIndexRequest request =
new RssGetShuffleIndexRequest(
- appId, shuffleId, partitionId, partitionNumPerRange, partitionNum);
+ appId,
+ shuffleId,
+ partitionId,
+ partitionNumPerRange,
+ partitionNum,
+ retryMax,
+ retryIntervalMax);
try {
shuffleIndexResult = shuffleServerClient.getShuffleIndex(request).getShuffleIndexResult();
} catch (RssFetchFailedException e) {
@@ -141,17 +156,16 @@
partitionNumPerRange,
partitionNum,
shuffleDataSegment.getOffset(),
- expectedLength);
+ expectedLength,
+ retryMax,
+ retryIntervalMax);
try {
RssGetShuffleDataResponse response = shuffleServerClient.getShuffleData(request);
result =
new ShuffleDataResult(response.getShuffleData(), shuffleDataSegment.getBufferSegments());
} catch (Exception e) {
throw new RssException(
- "Failed to read shuffle data with "
- + shuffleServerClient.getClientInfo()
- + " due to "
- + e.getMessage());
+ "Failed to read shuffle data with " + shuffleServerClient.getClientInfo(), e);
}
if (result.getDataBuffer().remaining() != expectedLength) {
throw new RssException(
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
index a3a7931..f1fbe23 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
@@ -19,6 +19,7 @@
import java.util.List;
+import com.google.common.annotations.VisibleForTesting;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,6 +38,8 @@
private long lastBlockId = Constants.INVALID_BLOCK_ID;
private ShuffleServerClient shuffleServerClient;
private Roaring64NavigableMap expectTaskIds;
+ private int retryMax;
+ private long retryIntervalMax;
public MemoryClientReadHandler(
String appId,
@@ -44,13 +47,28 @@
int partitionId,
int readBufferSize,
ShuffleServerClient shuffleServerClient,
- Roaring64NavigableMap expectTaskIds) {
+ Roaring64NavigableMap expectTaskIds,
+ int retryMax,
+ long retryIntervalMax) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
this.readBufferSize = readBufferSize;
this.shuffleServerClient = shuffleServerClient;
this.expectTaskIds = expectTaskIds;
+ this.retryMax = retryMax;
+ this.retryIntervalMax = retryIntervalMax;
+ }
+
+ @VisibleForTesting
+ public MemoryClientReadHandler( // delegates with retryMax=1, retryIntervalMax=0
+ String appId,
+ int shuffleId,
+ int partitionId,
+ int readBufferSize,
+ ShuffleServerClient shuffleServerClient,
+ Roaring64NavigableMap expectTaskIds) {
+ this(appId, shuffleId, partitionId, readBufferSize, shuffleServerClient, expectTaskIds, 1, 0);
}
@Override
@@ -59,7 +77,14 @@
RssGetInMemoryShuffleDataRequest request =
new RssGetInMemoryShuffleDataRequest(
- appId, shuffleId, partitionId, lastBlockId, readBufferSize, expectTaskIds);
+ appId,
+ shuffleId,
+ partitionId,
+ lastBlockId,
+ readBufferSize,
+ expectTaskIds,
+ retryMax,
+ retryIntervalMax);
try {
RssGetInMemoryShuffleDataResponse response =
@@ -70,10 +95,7 @@
} catch (Exception e) {
// todo: fault tolerance solution should be added
throw new RssFetchFailedException(
- "Failed to read in memory shuffle data with "
- + shuffleServerClient.getClientInfo()
- + " due to "
- + e);
+ "Failed to read in memory shuffle data with " + shuffleServerClient.getClientInfo(), e);
}
// update lastBlockId for next rpc call
diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 38c7e9e..9b73dc8 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -39,6 +39,8 @@
private int partitionNumPerRange;
private int partitionNum;
private int readBufferSize;
+ private int retryMax;
+ private long retryIntervalMax;
private String storageBasePath;
private RssBaseConf rssBaseConf;
private Configuration hadoopConf;
@@ -129,6 +131,22 @@
this.readBufferSize = readBufferSize;
}
+ public int getRetryMax() { // presumably what ShuffleHandlerFactory forwards to read handlers - verify
+ return retryMax;
+ }
+
+ public void setRetryMax(int retryMax) { // stored as-is; no range check
+ this.retryMax = retryMax;
+ }
+
+ public long getRetryIntervalMax() { // time unit not shown in this class - TODO confirm
+ return retryIntervalMax;
+ }
+
+ public void setRetryIntervalMax(long retryIntervalMax) { // stored as-is; no range check
+ this.retryIntervalMax = retryIntervalMax;
+ }
+
public String getStorageBasePath() {
return storageBasePath;
}
diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
index 884f2b9..2a55ae4 100644
--- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
+++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java
@@ -34,7 +34,6 @@
import org.apache.uniffle.client.request.RssGetShuffleDataRequest;
import org.apache.uniffle.client.response.RssGetShuffleDataResponse;
import org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
-import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
@@ -135,9 +134,7 @@
readBufferSize,
expectBlockIds,
processBlockIds,
- mockShuffleServerClient,
- ShuffleDataDistributionType.NORMAL,
- Roaring64NavigableMap.bitmapOf());
+ mockShuffleServerClient);
int totalSegment = ((blockSize * actualWriteDataBlock) / bytesPerSegment) + 1;
int readBlocks = 0;
for (int i = 0; i < totalSegment; i++) {