[#2636] feat(spark): Cache shuffle handle info for reader to reduce RPC cost when partition reassign is enabled (#2637)
### What changes were proposed in this pull request?
This PR introduces a cache for the shuffle handle info fetched by readers, to reduce the RPC cost and the driver-side GC pressure when partition reassign is enabled.
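At a high level, the cache is a per-shuffleId concurrent map inside the shuffle manager that falls back to an RPC to the driver on a miss. Below is a minimal sketch of the pattern with illustrative names only (`ShuffleHandleCacheSketch` and `getOrFetch` are not real Uniffle classes; the actual implementation is `RssShuffleManagerBase#getOrFetchShuffleHandle` in the diff below):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Illustrative sketch only, not the real Uniffle code.
class ShuffleHandleCacheSketch<T> {
  private final Map<Integer, T> cache = new ConcurrentHashMap<>();

  // On a cache miss, "fetch" stands in for the RPC that asks the driver
  // for the current shuffle handle info.
  T getOrFetch(int shuffleId, Supplier<T> fetch) {
    T handle = cache.computeIfAbsent(shuffleId, id -> fetch.get());
    if (handle == null) {
      // A null result is never cached, so the next reader retries the fetch.
      throw new IllegalStateException("Shuffle handle for " + shuffleId + " not found");
    }
    return handle;
  }
}
```

The key point is `computeIfAbsent` on a `ConcurrentHashMap`: concurrent reader tasks of the same shuffle share a single fetch instead of each issuing its own RPC.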
### Why are the changes needed?
For #2636.
In cluster Spark jobs, I found some tasks failing because the RPC that fetches the shuffle handle from the driver failed when partition reassign is enabled. This is the first step toward optimizing shuffle handle fetching on the reader side.
### Does this PR introduce _any_ user-facing change?
Yes. A new client-side config option is introduced, disabled by default:
`rss.client.read.shuffleHandleCacheEnabled=false`
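For example, to turn the cache on for a job (a usage sketch; the `spark.` prefix follows the convention used by the other RSS client options, as in the integration test below):

```java
SparkConf sparkConf = new SparkConf();
// Only takes effect when partition reassign is enabled as well.
sparkConf.set("spark.rss.client.read.shuffleHandleCacheEnabled", "true");
```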
### How was this patch tested?
Existing tests
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 4af6cd4..8dc17b7 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@
public class RssSparkConfig {
+ public static final ConfigOption<Boolean> RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED =
+ ConfigOptions.key("rss.client.read.shuffleHandleCacheEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription("Whether to cache the shuffle handle info on the reader side to avoid repeated RPCs to the driver. Only valid when partition reassign is enabled.");
+
public static final ConfigOption<Boolean> RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED =
ConfigOptions.key("rss.client.read.overlappingDecompressionEnable")
.booleanType()
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index ad1c4dd..be5e56a 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -109,6 +109,7 @@
import static org.apache.spark.launcher.SparkLauncher.EXECUTOR_CORES;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
import static org.apache.spark.shuffle.RssSparkShuffleUtils.isSparkUIEnabled;
@@ -184,6 +185,11 @@
private AtomicBoolean reassignTriggeredOnBlockSendFailure = new AtomicBoolean(false);
private AtomicBoolean reassignTriggeredOnStageRetry = new AtomicBoolean(false);
+ // Cache of shuffle handle info, used to reduce the RPC cost when creating readers.
+ // This is only valid when partition reassign is enabled.
+ protected final boolean readShuffleHandleCacheEnabled;
+ private final Map<Integer, ShuffleHandleInfo> readShuffleHandleCache = Maps.newConcurrentMap();
+
private boolean isDriver = false;
public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
@@ -373,6 +379,8 @@
this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
this.rssStageResubmitManager = new RssStageResubmitManager();
this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+
+ this.readShuffleHandleCacheEnabled = rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
}
@VisibleForTesting
@@ -424,6 +432,7 @@
this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
this.rssStageResubmitManager = new RssStageResubmitManager();
this.shuffleIdMappingManager = new ShuffleIdMappingManager();
+ this.readShuffleHandleCacheEnabled = rssConf.get(RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED);
}
public BlockIdManager getBlockIdManager() {
@@ -444,6 +453,9 @@
if (blockIdManager != null) {
blockIdManager.remove(shuffleId);
}
+ readShuffleHandleCache.remove(shuffleId);
if (SparkEnv.get().executorId().equals("driver")) {
shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId);
shuffleIdToPartitionNum.remove(shuffleId);
@@ -1574,4 +1586,34 @@
}
return new CompletableFuture<>();
}
+
+ @VisibleForTesting
+ public void clearShuffleHandleCache() {
+ readShuffleHandleCache.clear();
+ }
+
+ // Get the shuffle handle info from the cache, falling back to the given
+ // supplier (an RPC to the driver) on a miss. A failed or null fetch is not
+ // cached, so subsequent calls will retry.
+ public ShuffleHandleInfo getOrFetchShuffleHandle(
+ int shuffleId, Supplier<ShuffleHandleInfo> func) {
+ ShuffleHandleInfo handle =
+ readShuffleHandleCache.computeIfAbsent(
+ shuffleId,
+ id -> {
+ long start = System.currentTimeMillis();
+ try {
+ return func.get();
+ } catch (Exception e) {
+ LOG.error("Failed to get the shuffle handle for shuffle: {}", shuffleId, e);
+ return null;
+ } finally {
+ LOG.info(
+ "Fetching the shuffle handle for shuffle: {} cost {} ms",
+ shuffleId,
+ System.currentTimeMillis() - start);
+ }
+ });
+ if (handle == null) {
+ throw new RssException("Shuffle handle for shuffleId: " + shuffleId + " not found");
+ }
+ return handle;
+ }
}
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 58349b5..e4b180a 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
import scala.Tuple2;
@@ -365,16 +366,22 @@
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled && rssStageRetryForWriteFailureEnabled) {
- // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
- shuffleHandleInfo =
- getRemoteShuffleHandleInfoWithStageRetry(
- context.stageId(), context.stageAttemptNumber(), shuffleId, false);
- } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
- // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
- shuffleHandleInfo =
- getRemoteShuffleHandleInfoWithBlockRetry(
- context.stageId(), context.stageAttemptNumber(), shuffleId, false);
+
+ if (shuffleManagerRpcServiceEnabled
+ && (rssStageRetryForWriteFailureEnabled || partitionReassignEnabled)) {
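+ // In both stage-retry and block-retry (partition reassign) modes, the
+ // ShuffleServer list is fetched from the driver based on the shuffleId.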
+ Supplier<ShuffleHandleInfo> func =
+ rssStageRetryForWriteFailureEnabled
+ ? () ->
+ getRemoteShuffleHandleInfoWithStageRetry(
+ context.stageId(), context.stageAttemptNumber(), shuffleId, false)
+ : () ->
+ getRemoteShuffleHandleInfoWithBlockRetry(
+ context.stageId(), context.stageAttemptNumber(), shuffleId, false);
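+ // When enabled, the per-shuffleId cache avoids repeating the RPC for
+ // every reader task of the same shuffle.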
+ if (readShuffleHandleCacheEnabled) {
+ shuffleHandleInfo = super.getOrFetchShuffleHandle(shuffleId, func);
+ } else {
+ shuffleHandleInfo = func.get();
+ }
} else {
shuffleHandleInfo =
new SimpleShuffleHandleInfo(
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 877bf6e..df6255c 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -17,14 +17,22 @@
package org.apache.spark.shuffle;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
@@ -33,6 +41,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
@@ -42,6 +51,7 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -252,4 +262,42 @@
conf.set("spark.driver.host", "localhost");
return conf;
}
+
+ @Test
+ public void testReadCacheShuffleInfo() {
+ SparkConf conf = new SparkConf();
+ conf.setAppName("testApp")
+ .setMaster("local[2]")
+ .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+ .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+ .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000")
+ .set(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key(), "10")
+ .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+ .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name())
+ .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
+ Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+ RssShuffleManager manager =
+ TestUtils.createShuffleManager(
+ conf, false, null, successBlocks, taskToFailedBlockSendTracker);
+
+ // case1: legal fetch; the second call should hit the cache and return the same handle
+ Supplier<ShuffleHandleInfo> func1 =
+ () ->
+ new SimpleShuffleHandleInfo(
+ 1, Collections.emptyMap(), RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+ ShuffleHandleInfo handle1 = manager.getOrFetchShuffleHandle(1, func1);
+ ShuffleHandleInfo handle2 = manager.getOrFetchShuffleHandle(1, func1);
+ assertEquals(handle1, handle2);
+
+ // case2: illegal fetch; a supplier that returns null must surface as an RssException
+ manager.clearShuffleHandleCache();
+ Supplier<ShuffleHandleInfo> func2 = () -> null;
+ try {
+ manager.getOrFetchShuffleHandle(1, func2);
+ fail("Expected an RssException when the supplier returns null");
+ } catch (RssException e) {
+ // expected
+ }
+ }
}
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
index c0fbb3a..9adc4b3 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/PartitionBlockDataReassignBasicTest.java
@@ -35,6 +35,7 @@
import org.apache.uniffle.server.buffer.ShuffleBufferManager;
import org.apache.uniffle.storage.util.StorageType;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED;
import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_RETRY_MAX;
import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REASSIGN_ENABLED;
@@ -99,6 +100,7 @@
"spark." + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
String.valueOf(grpcShuffleServers.size()));
sparkConf.set("spark." + RSS_CLIENT_REASSIGN_ENABLED.key(), "true");
+ sparkConf.set("spark." + RSS_READ_SHUFFLE_HANDLE_CACHE_ENABLED.key(), "true");
}
@Override