[#1538] feat(spark): report blockIds to spark driver optionally (#1677)
### What changes were proposed in this pull request?
Optionally support reporting blockIds to the Spark driver instead of tracking them on the shuffle servers.
### Why are the changes needed?
Fix: #1538
### Does this PR introduce _any_ user-facing change?
Yes. `rss.blockId.selfManagementEnabled` is introduced; the default value is false.
### How was this patch tested?
Integration tests.
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index f118c85..54a08a5 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -37,6 +37,13 @@
public class RssSparkConfig {
+ public static final ConfigOption<Boolean> RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED =
+ ConfigOptions.key("rss.blockId.selfManagementEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+ "Whether to enable the blockId self management in spark driver side. Default value is false.");
+
public static final ConfigOption<Long> RSS_CLIENT_SEND_SIZE_LIMITATION =
ConfigOptions.key("rss.client.send.size.limit")
.longType()
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdManager.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdManager.java
new file mode 100644
index 0000000..56c38b5
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdManager.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.RssUtils;
+
+/** The class is to manage the shuffle data blockIds in spark driver side. */
+public class BlockIdManager {
+ private static final Logger LOGGER = LoggerFactory.getLogger(BlockIdManager.class);
+
+ // shuffleId -> partitionId -> blockIds
+ private Map<Integer, Map<Integer, Roaring64NavigableMap>> blockIds;
+
+ public BlockIdManager() {
+ this.blockIds = JavaUtils.newConcurrentMap();
+ }
+
+ public void add(int shuffleId, int partitionId, List<Long> ids) {
+ if (CollectionUtils.isEmpty(ids)) {
+ return;
+ }
+ Map<Integer, Roaring64NavigableMap> partitionedBlockIds =
+ blockIds.computeIfAbsent(shuffleId, (k) -> JavaUtils.newConcurrentMap());
+ partitionedBlockIds.compute(
+ partitionId,
+ (id, bitmap) -> {
+ Roaring64NavigableMap store = bitmap == null ? Roaring64NavigableMap.bitmapOf() : bitmap;
+ ids.stream().forEach(x -> store.add(x));
+ return store;
+ });
+ }
+
+ public Roaring64NavigableMap get(int shuffleId, int partitionId) {
+ Map<Integer, Roaring64NavigableMap> partitionedBlockIds = blockIds.get(shuffleId);
+ if (partitionedBlockIds == null || partitionedBlockIds.isEmpty()) {
+ return Roaring64NavigableMap.bitmapOf();
+ }
+
+ Roaring64NavigableMap idMap = partitionedBlockIds.get(partitionId);
+ if (idMap == null || idMap.isEmpty()) {
+ return Roaring64NavigableMap.bitmapOf();
+ }
+
+ return RssUtils.cloneBitMap(idMap);
+ }
+
+ public boolean remove(int shuffleId) {
+ blockIds.remove(shuffleId);
+ return true;
+ }
+}
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
new file mode 100644
index 0000000..1429bac
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
+import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.BlockIdLayout;
+
+/**
+ * This class delegates the blockIds reporting/getting operations from shuffleServer side to Spark
+ * driver side.
+ */
+public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl {
+ private ShuffleManagerClient shuffleManagerClient;
+
+ public BlockIdSelfManagedShuffleWriteClient(
+ RssShuffleClientFactory.ExtendWriteClientBuilder builder) {
+ super(builder);
+
+ if (builder.getShuffleManagerClient() == null) {
+ throw new RssException("Illegal empty shuffleManagerClient. This should not happen");
+ }
+ this.shuffleManagerClient = builder.getShuffleManagerClient();
+ }
+
+ @Override
+ public void reportShuffleResult(
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
+ String appId,
+ int shuffleId,
+ long taskAttemptId,
+ int bitmapNum) {
+ Map<Integer, List<Long>> partitionToBlockIds = new HashMap<>();
+ for (Map<Integer, Set<Long>> k : serverToPartitionToBlockIds.values()) {
+ for (Map.Entry<Integer, Set<Long>> entry : k.entrySet()) {
+ int partitionId = entry.getKey();
+ partitionToBlockIds
+ .computeIfAbsent(partitionId, x -> new ArrayList<>())
+ .addAll(entry.getValue());
+ }
+ }
+
+ RssReportShuffleResultRequest request =
+ new RssReportShuffleResultRequest(
+ appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum);
+ shuffleManagerClient.reportShuffleResult(request);
+ }
+
+ @Override
+ public Roaring64NavigableMap getShuffleResult(
+ String clientType,
+ Set<ShuffleServerInfo> shuffleServerInfoSet,
+ String appId,
+ int shuffleId,
+ int partitionId) {
+ RssGetShuffleResultRequest request =
+ new RssGetShuffleResultRequest(appId, shuffleId, partitionId, BlockIdLayout.DEFAULT);
+ return shuffleManagerClient.getShuffleResult(request).getBlockIdBitmap();
+ }
+
+ @Override
+ public Roaring64NavigableMap getShuffleResultForMultiPart(
+ String clientType,
+ Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
+ String appId,
+ int shuffleId,
+ Set<Integer> failedPartitions,
+ PartitionDataReplicaRequirementTracking replicaRequirementTracking) {
+ Set<Integer> partitionIds =
+ serverToPartitions.values().stream().flatMap(x -> x.stream()).collect(Collectors.toSet());
+ RssGetShuffleResultForMultiPartRequest request =
+ new RssGetShuffleResultForMultiPartRequest(
+ appId, shuffleId, partitionIds, BlockIdLayout.DEFAULT);
+ return shuffleManagerClient.getShuffleResultForMultiPart(request).getBlockIdBitmap();
+ }
+}
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
new file mode 100644
index 0000000..c19d913
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+
+public class RssShuffleClientFactory extends ShuffleClientFactory {
+
+ private static final RssShuffleClientFactory INSTANCE = new RssShuffleClientFactory();
+
+ public static RssShuffleClientFactory getInstance() {
+ return INSTANCE;
+ }
+
+ public ShuffleWriteClient createShuffleWriteClient(ExtendWriteClientBuilder builder) {
+ return builder.build();
+ }
+
+ public static ExtendWriteClientBuilder<?> newWriteBuilder() {
+ return new ExtendWriteClientBuilder();
+ }
+
+ public static class ExtendWriteClientBuilder<T extends ExtendWriteClientBuilder<T>>
+ extends WriteClientBuilder<T> {
+ private boolean blockIdSelfManagedEnabled;
+ private ShuffleManagerClient shuffleManagerClient;
+
+ public boolean isBlockIdSelfManagedEnabled() {
+ return blockIdSelfManagedEnabled;
+ }
+
+ public ShuffleManagerClient getShuffleManagerClient() {
+ return shuffleManagerClient;
+ }
+
+ public T shuffleManagerClient(ShuffleManagerClient client) {
+ this.shuffleManagerClient = client;
+ return self();
+ }
+
+ public T blockIdSelfManagedEnabled(boolean blockIdSelfManagedEnabled) {
+ this.blockIdSelfManagedEnabled = blockIdSelfManagedEnabled;
+ return self();
+ }
+
+ @Override
+ public ShuffleWriteClientImpl build() {
+ if (blockIdSelfManagedEnabled) {
+ return new BlockIdSelfManagedShuffleWriteClient(this);
+ }
+ return super.build();
+ }
+ }
+}
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 2392264..6a9baec 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -52,6 +52,7 @@
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.shuffle.BlockIdManager;
import static org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED;
@@ -61,6 +62,28 @@
private AtomicBoolean isInitialized = new AtomicBoolean(false);
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;
+ private volatile BlockIdManager blockIdManager;
+ private Object blockIdManagerLock = new Object();
+
+ public BlockIdManager getBlockIdManager() {
+ if (blockIdManager == null) {
+ synchronized (blockIdManagerLock) {
+ if (blockIdManager == null) {
+ blockIdManager = new BlockIdManager();
+ LOG.info("BlockId manager has been initialized.");
+ }
+ }
+ }
+ return blockIdManager;
+ }
+
+ @Override
+ public boolean unregisterShuffle(int shuffleId) {
+ if (blockIdManager != null) {
+ blockIdManager.remove(shuffleId);
+ }
+ return true;
+ }
/** See static overload of this method. */
public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf);
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 4f16917..b213600 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -25,6 +25,7 @@
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.uniffle.common.ReceivingFailureServer;
+import org.apache.uniffle.shuffle.BlockIdManager;
/**
* This is a proxy interface that mainly delegates the un-registration of shuffles to the
@@ -82,4 +83,6 @@
MutableShuffleHandleInfo reassignOnBlockSendFailure(
int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers);
+
+ BlockIdManager getBlockIdManager();
}
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index 11f613f..5aaf23a 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -27,16 +27,20 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc.ShuffleManagerImplBase;
+import org.apache.uniffle.shuffle.BlockIdManager;
public class ShuffleManagerGrpcService extends ShuffleManagerImplBase {
private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);
@@ -437,4 +441,122 @@
});
}
}
+
+ @Override
+ public void getShuffleResult(
+ RssProtos.GetShuffleResultRequest request,
+ StreamObserver<RssProtos.GetShuffleResultResponse> responseObserver) {
+ String appId = request.getAppId();
+ if (!appId.equals(shuffleManager.getAppId())) {
+ RssProtos.GetShuffleResultResponse reply =
+ RssProtos.GetShuffleResultResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
+ .setRetMsg("Illegal appId: " + appId)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
+
+ int shuffleId = request.getShuffleId();
+ int partitionId = request.getPartitionId();
+
+ BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
+ Roaring64NavigableMap blockIdBitmap = blockIdManager.get(shuffleId, partitionId);
+ RssProtos.GetShuffleResultResponse reply;
+ try {
+ byte[] serializeBitmap = RssUtils.serializeBitMap(blockIdBitmap);
+ reply =
+ RssProtos.GetShuffleResultResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.SUCCESS)
+ .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(serializeBitmap))
+ .build();
+ } catch (Exception exception) {
+ LOG.error("Errors on getting the blockId bitmap.", exception);
+ reply =
+ RssProtos.GetShuffleResultResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
+ .build();
+ }
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
+
+ @Override
+ public void getShuffleResultForMultiPart(
+ RssProtos.GetShuffleResultForMultiPartRequest request,
+ StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
+ String appId = request.getAppId();
+ if (!appId.equals(shuffleManager.getAppId())) {
+ RssProtos.GetShuffleResultForMultiPartResponse reply =
+ RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
+ .setRetMsg("Illegal appId: " + appId)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
+
+ BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
+ int shuffleId = request.getShuffleId();
+ List<Integer> partitionIds = request.getPartitionsList();
+
+ Roaring64NavigableMap blockIdBitmapCollection = Roaring64NavigableMap.bitmapOf();
+ for (int partitionId : partitionIds) {
+ Roaring64NavigableMap blockIds = blockIdManager.get(shuffleId, partitionId);
+ blockIds.forEach(x -> blockIdBitmapCollection.add(x));
+ }
+
+ RssProtos.GetShuffleResultForMultiPartResponse reply;
+ try {
+ byte[] serializeBitmap = RssUtils.serializeBitMap(blockIdBitmapCollection);
+ reply =
+ RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.SUCCESS)
+ .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(serializeBitmap))
+ .build();
+ } catch (Exception exception) {
+ LOG.error("Errors on getting the blockId bitmap.", exception);
+ reply =
+ RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
+ .build();
+ }
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
+
+ @Override
+ public void reportShuffleResult(
+ RssProtos.ReportShuffleResultRequest request,
+ StreamObserver<RssProtos.ReportShuffleResultResponse> responseObserver) {
+ String appId = request.getAppId();
+ if (!appId.equals(shuffleManager.getAppId())) {
+ RssProtos.ReportShuffleResultResponse reply =
+ RssProtos.ReportShuffleResultResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
+ .setRetMsg("Illegal appId: " + appId)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ return;
+ }
+
+ BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
+ int shuffleId = request.getShuffleId();
+
+ for (RssProtos.PartitionToBlockIds partitionToBlockIds : request.getPartitionToBlockIdsList()) {
+ int partitionId = partitionToBlockIds.getPartitionId();
+ List<Long> blockIds = partitionToBlockIds.getBlockIdsList();
+ blockIdManager.add(shuffleId, partitionId, blockIds);
+ }
+
+ RssProtos.ReportShuffleResultResponse reply =
+ RssProtos.ReportShuffleResultResponse.newBuilder()
+ .setStatus(RssProtos.StatusCode.SUCCESS)
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
}
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index e7acaaf..66bb26d 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -26,6 +26,7 @@
import org.apache.spark.shuffle.handle.ShuffleHandleInfoBase;
import org.apache.uniffle.common.ReceivingFailureServer;
+import org.apache.uniffle.shuffle.BlockIdManager;
import static org.mockito.Mockito.mock;
@@ -76,4 +77,9 @@
int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
return mock(MutableShuffleHandleInfo.class);
}
+
+ @Override
+ public BlockIdManager getBlockIdManager() {
+ return null;
+ }
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e7bc631..ba2d275 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -56,7 +56,6 @@
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -80,10 +79,12 @@
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.shuffle.RssShuffleClientFactory;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
@@ -142,6 +143,8 @@
*/
private Map<String, Boolean> serverAssignedInfos = JavaUtils.newConcurrentMap();
+ private boolean blockIdSelfManagedEnabled;
+
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
throw new IllegalArgumentException(
@@ -198,24 +201,6 @@
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
- this.shuffleWriteClient =
- ShuffleClientFactory.getInstance()
- .createShuffleWriteClient(
- ShuffleClientFactory.newWriteBuilder()
- .clientType(clientType)
- .retryMax(retryMax)
- .retryIntervalMax(retryIntervalMax)
- .heartBeatThreadNum(heartBeatThreadNum)
- .replica(dataReplica)
- .replicaWrite(dataReplicaWrite)
- .replicaRead(dataReplicaRead)
- .replicaSkipEnabled(dataReplicaSkipEnabled)
- .dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize)
- .unregisterThreadPoolSize(unregisterThreadPoolSize)
- .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
- .rssConf(rssConf));
- registerCoordinator();
// External shuffle service is not supported when using remote shuffle service
sparkConf.set("spark.shuffle.service.enabled", "false");
LOG.info("Disable external shuffle service in RssShuffleManager.");
@@ -228,7 +213,9 @@
&& RssSparkShuffleUtils.isStageResubmitSupported();
this.taskBlockSendFailureRetry =
rssConf.getBoolean(RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED);
- this.shuffleManagerRpcServiceEnabled = taskBlockSendFailureRetry || rssResubmitStage;
+ this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+ this.shuffleManagerRpcServiceEnabled =
+ taskBlockSendFailureRetry || rssResubmitStage || blockIdSelfManagedEnabled;
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
if (isDriver) {
heartBeatScheduledExecutorService =
@@ -252,6 +239,31 @@
}
}
}
+
+ if (shuffleManagerRpcServiceEnabled) {
+ this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+ }
+ this.shuffleWriteClient =
+ RssShuffleClientFactory.getInstance()
+ .createShuffleWriteClient(
+ RssShuffleClientFactory.newWriteBuilder()
+ .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+ .shuffleManagerClient(shuffleManagerClient)
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
+ registerCoordinator();
+
// for non-driver executor, start a thread for sending shuffle data to shuffle server
LOG.info("RSS data pusher is starting...");
int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
@@ -618,6 +630,7 @@
@Override
public boolean unregisterShuffle(int shuffleId) {
try {
+ super.unregisterShuffle(shuffleId);
if (SparkEnv.get().executorId().equals("driver")) {
shuffleWriteClient.unregisterShuffle(appId, shuffleId);
shuffleIdToNumMapTasks.remove(shuffleId);
@@ -810,6 +823,18 @@
.createShuffleManagerClient(ClientType.GRPC, host, port);
}
+ private ShuffleManagerClient getOrCreateShuffleManagerClient() {
+ if (shuffleManagerClient == null) {
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ String driver = rssConf.getString("driver.host", "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ this.shuffleManagerClient =
+ ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, driver, port);
+ }
+ return shuffleManagerClient;
+ }
+
/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
@@ -817,18 +842,12 @@
* @return ShuffleHandleInfo
*/
private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
- MutableShuffleHandleInfo shuffleHandleInfo;
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- if (shuffleManagerClient == null) {
- shuffleManagerClient = createShuffleManagerClient(driver, port);
- }
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssPartitionToShuffleServerResponse handleInfoResponse =
- shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
- shuffleHandleInfo =
+ getOrCreateShuffleManagerClient()
+ .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+ MutableShuffleHandleInfo shuffleHandleInfo =
MutableShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto());
return shuffleHandleInfo;
}
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e629b23..983a2a0 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -69,7 +69,6 @@
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -94,10 +93,12 @@
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.shuffle.RssShuffleClientFactory;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
@@ -157,6 +158,7 @@
private final int partitionReassignMaxServerNum;
private final ShuffleHandleInfoManager shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ private boolean blockIdSelfManagedEnabled;
public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
@@ -209,32 +211,8 @@
// configureBlockIdLayout requires maxFailures and speculation to be initialized
configureBlockIdLayout(sparkConf, rssConf);
this.blockIdLayout = BlockIdLayout.from(rssConf);
- long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
- int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
- int unregisterThreadPoolSize =
- sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
- int unregisterRequestTimeoutSec =
- sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
- shuffleWriteClient =
- ShuffleClientFactory.getInstance()
- .createShuffleWriteClient(
- ShuffleClientFactory.newWriteBuilder()
- .clientType(clientType)
- .retryMax(retryMax)
- .retryIntervalMax(retryIntervalMax)
- .heartBeatThreadNum(heartBeatThreadNum)
- .replica(dataReplica)
- .replicaWrite(dataReplicaWrite)
- .replicaRead(dataReplicaRead)
- .replicaSkipEnabled(dataReplicaSkipEnabled)
- .dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize)
- .unregisterThreadPoolSize(unregisterThreadPoolSize)
- .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
- .rssConf(rssConf));
- registerCoordinator();
// External shuffle service is not supported when using remote shuffle service
sparkConf.set("spark.shuffle.service.enabled", "false");
sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
@@ -261,7 +239,9 @@
}
}
- this.shuffleManagerRpcServiceEnabled = taskBlockSendFailureRetryEnabled || rssResubmitStage;
+ this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+ this.shuffleManagerRpcServiceEnabled =
+ taskBlockSendFailureRetryEnabled || rssResubmitStage || blockIdSelfManagedEnabled;
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
@@ -284,6 +264,36 @@
}
}
}
+ if (shuffleManagerRpcServiceEnabled) {
+ this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+ }
+ int unregisterThreadPoolSize =
+ sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterRequestTimeoutSec =
+ sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+ long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+ int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+ shuffleWriteClient =
+ RssShuffleClientFactory.getInstance()
+ .createShuffleWriteClient(
+ RssShuffleClientFactory.newWriteBuilder()
+ .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+ .shuffleManagerClient(shuffleManagerClient)
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
+ registerCoordinator();
+
LOG.info("Rss data pusher is starting...");
int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
@@ -366,9 +376,10 @@
int unregisterRequestTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
shuffleWriteClient =
- ShuffleClientFactory.getInstance()
+ RssShuffleClientFactory.getInstance()
.createShuffleWriteClient(
- ShuffleClientFactory.newWriteBuilder()
+ RssShuffleClientFactory.getInstance()
+ .newWriteBuilder()
.clientType(clientType)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
@@ -872,6 +883,7 @@
@Override
public boolean unregisterShuffle(int shuffleId) {
try {
+ super.unregisterShuffle(shuffleId);
if (SparkEnv.get().executorId().equals("driver")) {
shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
shuffleIdToPartitionNum.remove(shuffleId);
@@ -1127,11 +1139,18 @@
return shuffleHandleInfoManager.get(shuffleId);
}
- private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
- // Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is
- // constructed.
- return ShuffleManagerClientFactory.getInstance()
- .createShuffleManagerClient(ClientType.GRPC, host, port);
+ // TODO: automatically close the client when it is idle, to avoid holding too many
+ // connections open to the Spark driver.
+ private ShuffleManagerClient getOrCreateShuffleManagerClient() {
+ if (shuffleManagerClient == null) {
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ String driver = rssConf.getString("driver.host", "");
+ int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+ this.shuffleManagerClient =
+ ShuffleManagerClientFactory.getInstance()
+ .createShuffleManagerClient(ClientType.GRPC, driver, port);
+ }
+ return shuffleManagerClient;
}
/**
@@ -1141,18 +1160,12 @@
* @return ShuffleHandleInfo
*/
private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) {
- MutableShuffleHandleInfo shuffleHandleInfo;
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- String driver = rssConf.getString("driver.host", "");
- int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
- if (shuffleManagerClient == null) {
- shuffleManagerClient = createShuffleManagerClient(driver, port);
- }
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer =
- shuffleManagerClient.getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
- shuffleHandleInfo =
+ getOrCreateShuffleManagerClient()
+ .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest);
+ MutableShuffleHandleInfo shuffleHandleInfo =
MutableShuffleHandleInfo.fromProto(
rpcPartitionToShufflerServer.getShuffleHandleInfoProto());
return shuffleHandleInfo;
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
index 2157cda..66b2c9a 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/RssShuffleManagerTest.java
@@ -100,6 +100,7 @@
SparkConf conf = new SparkConf();
conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002");
+ conf.set("spark.driver.host", "localhost");
conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name());
conf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE, true);
// enable stage recompute
diff --git a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index 8efdd44..0eed01e 100644
--- a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -36,8 +36,6 @@
private static final ShuffleClientFactory INSTANCE = new ShuffleClientFactory();
- private ShuffleClientFactory() {}
-
public static ShuffleClientFactory getInstance() {
return INSTANCE;
}
@@ -53,9 +51,7 @@
return builder.build();
}
- public static class WriteClientBuilder {
- private WriteClientBuilder() {}
-
+ public static class WriteClientBuilder<T extends WriteClientBuilder> {
private String clientType;
private int retryMax;
private long retryIntervalMax;
@@ -122,69 +118,73 @@
return rssConf;
}
- public WriteClientBuilder clientType(String clientType) {
+ protected T self() {
+ return (T) this;
+ }
+
+ public T clientType(String clientType) {
this.clientType = clientType;
- return this;
+ return self();
}
- public WriteClientBuilder retryMax(int retryMax) {
+ public T retryMax(int retryMax) {
this.retryMax = retryMax;
- return this;
+ return self();
}
- public WriteClientBuilder retryIntervalMax(long retryIntervalMax) {
+ public T retryIntervalMax(long retryIntervalMax) {
this.retryIntervalMax = retryIntervalMax;
- return this;
+ return self();
}
- public WriteClientBuilder heartBeatThreadNum(int heartBeatThreadNum) {
+ public T heartBeatThreadNum(int heartBeatThreadNum) {
this.heartBeatThreadNum = heartBeatThreadNum;
- return this;
+ return self();
}
- public WriteClientBuilder replica(int replica) {
+ public T replica(int replica) {
this.replica = replica;
- return this;
+ return self();
}
- public WriteClientBuilder replicaWrite(int replicaWrite) {
+ public T replicaWrite(int replicaWrite) {
this.replicaWrite = replicaWrite;
- return this;
+ return self();
}
- public WriteClientBuilder replicaRead(int replicaRead) {
+ public T replicaRead(int replicaRead) {
this.replicaRead = replicaRead;
- return this;
+ return self();
}
- public WriteClientBuilder replicaSkipEnabled(boolean replicaSkipEnabled) {
+ public T replicaSkipEnabled(boolean replicaSkipEnabled) {
this.replicaSkipEnabled = replicaSkipEnabled;
- return this;
+ return self();
}
- public WriteClientBuilder dataTransferPoolSize(int dataTransferPoolSize) {
+ public T dataTransferPoolSize(int dataTransferPoolSize) {
this.dataTransferPoolSize = dataTransferPoolSize;
- return this;
+ return self();
}
- public WriteClientBuilder dataCommitPoolSize(int dataCommitPoolSize) {
+ public T dataCommitPoolSize(int dataCommitPoolSize) {
this.dataCommitPoolSize = dataCommitPoolSize;
- return this;
+ return self();
}
- public WriteClientBuilder unregisterThreadPoolSize(int unregisterThreadPoolSize) {
+ public T unregisterThreadPoolSize(int unregisterThreadPoolSize) {
this.unregisterThreadPoolSize = unregisterThreadPoolSize;
- return this;
+ return self();
}
- public WriteClientBuilder unregisterRequestTimeSec(int unregisterRequestTimeSec) {
+ public T unregisterRequestTimeSec(int unregisterRequestTimeSec) {
this.unregisterRequestTimeSec = unregisterRequestTimeSec;
- return this;
+ return self();
}
- public WriteClientBuilder rssConf(RssConf rssConf) {
+ public T rssConf(RssConf rssConf) {
this.rssConf = rssConf;
- return this;
+ return self();
}
public ShuffleWriteClientImpl build() {
diff --git a/docs/client_guide/spark_client_guide.md b/docs/client_guide/spark_client_guide.md
index bbcfe59..2f2dcbb 100644
--- a/docs/client_guide/spark_client_guide.md
+++ b/docs/client_guide/spark_client_guide.md
@@ -126,6 +126,11 @@
For example: `22` bits is sufficient for `taskAttemptIdBits` with `partitionIdBits=20`, and Spark conf `spark.task.maxFailures=4` and `spark.speculation=false`.
3. Reserve the remaining bits to `sequenceNoBits`: `sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits`.
+### Block id self management (experimental)
+
+Now, the block ids can be managed by the Spark driver itself when `spark.rss.blockId.selfManagementEnabled=true` is specified.
+This reduces pressure on the shuffle servers, but significantly increases memory consumption on the Spark driver side.
+
### Adaptive Remote Shuffle Enabling
Currently, this feature only supports Spark.
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
index e1095e2..3a04680 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
@@ -61,22 +61,29 @@
start = System.currentTimeMillis();
Map resultWithRssGrpc = runSparkApp(sparkConf, fileName);
final long durationWithRssGrpc = System.currentTimeMillis() - start;
+ verifyTestResult(resultWithoutRss, resultWithRssGrpc);
updateSparkConfWithRssNetty(sparkConf);
start = System.currentTimeMillis();
Map resultWithRssNetty = runSparkApp(sparkConf, fileName);
final long durationWithRssNetty = System.currentTimeMillis() - start;
- verifyTestResult(resultWithoutRss, resultWithRssGrpc);
verifyTestResult(resultWithoutRss, resultWithRssNetty);
+ updateSparkConfWithBlockIdSelfManaged(sparkConf);
+ start = System.currentTimeMillis();
+ Map resultWithBlockIdSelfManaged = runSparkApp(sparkConf, fileName);
+ final long durationWithBlockIdSelfManaged = System.currentTimeMillis() - start;
+ verifyTestResult(resultWithoutRss, resultWithBlockIdSelfManaged);
+
LOG.info(
"Test: durationWithoutRss["
+ durationWithoutRss
+ "], durationWithRssGrpc["
+ durationWithRssGrpc
- + "]"
+ "], durationWithRssNetty["
+ durationWithRssNetty
+ + "], durationWithBlockIdSelfManaged["
+ + durationWithBlockIdSelfManaged
+ "]");
}
@@ -127,6 +134,14 @@
sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, ClientType.GRPC_NETTY.name());
}
+ public void updateSparkConfWithBlockIdSelfManaged(SparkConf sparkConf) {
+ sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
+ sparkConf.set(
+ RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
+ + RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED.key(),
+ "true");
+ }
+
protected void verifyTestResult(Map expected, Map actual) {
assertEquals(expected.size(), actual.size());
for (Object expectedKey : expected.keySet()) {
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java
index 1b46e6e..24b9d70 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java
@@ -48,6 +48,7 @@
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -238,8 +239,12 @@
// Validate getShuffleResultForMultiPart is correct before return result
ClientType clientType =
ClientType.valueOf(spark.sparkContext().getConf().get(RssSparkConfig.RSS_CLIENT_TYPE));
- if (ClientType.GRPC == clientType) {
+ boolean blockIdSelfManagedEnabled =
+ RssSparkConfig.toRssConf(spark.sparkContext().getConf())
+ .get(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+ if (ClientType.GRPC == clientType && !blockIdSelfManagedEnabled) {
// TODO skip validating for GRPC_NETTY, needs to mock ShuffleServerNettyHandler
+ // Skip validating when the blockId is managed on the Spark driver side.
validateRequestCount(
spark.sparkContext().applicationId(), expectRequestNum * replicateRead);
}
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
index 462ee8d..cd5510e 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
@@ -56,6 +56,7 @@
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -227,8 +228,12 @@
// Validate getShuffleResultForMultiPart is correct before return result
ClientType clientType =
ClientType.valueOf(spark.sparkContext().getConf().get(RssSparkConfig.RSS_CLIENT_TYPE));
- if (ClientType.GRPC == clientType) {
+ boolean blockIdSelfManagedEnabled =
+ RssSparkConfig.toRssConf(spark.sparkContext().getConf())
+ .get(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+ if (ClientType.GRPC == clientType && !blockIdSelfManagedEnabled) {
// TODO skip validating for GRPC_NETTY, needs to mock ShuffleServerNettyHandler
+ // Skip validating when the blockId is managed on the Spark driver side.
validateRequestCount(
spark.sparkContext().applicationId(), expectRequestNum * replicateRead);
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index c74843c..45d570e 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -19,15 +19,20 @@
import java.io.Closeable;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
+import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
+import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
public interface ShuffleManagerClient extends Closeable {
@@ -50,4 +55,11 @@
RssReassignOnBlockSendFailureResponse reassignOnBlockSendFailure(
RssReassignOnBlockSendFailureRequest request);
+
+ RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest request);
+
+ RssGetShuffleResultResponse getShuffleResultForMultiPart(
+ RssGetShuffleResultForMultiPartRequest request);
+
+ RssReportShuffleResultResponse reportShuffleResult(RssReportShuffleResultRequest request);
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 61e24b5..bebee89 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -23,15 +23,20 @@
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
+import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
+import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
+import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.exception.RssException;
@@ -127,4 +132,26 @@
getBlockingStub().reassignOnBlockSendFailure(protoReq);
return RssReassignOnBlockSendFailureResponse.fromProto(response);
}
+
+ @Override
+ public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest request) {
+ RssProtos.GetShuffleResultResponse response =
+ getBlockingStub().getShuffleResult(request.toProto());
+ return RssGetShuffleResultResponse.fromProto(response);
+ }
+
+ @Override
+ public RssGetShuffleResultResponse getShuffleResultForMultiPart(
+ RssGetShuffleResultForMultiPartRequest request) {
+ RssProtos.GetShuffleResultForMultiPartResponse response =
+ getBlockingStub().getShuffleResultForMultiPart(request.toProto());
+ return RssGetShuffleResultResponse.fromProto(response);
+ }
+
+ @Override
+ public RssReportShuffleResultResponse reportShuffleResult(RssReportShuffleResultRequest request) {
+ RssProtos.ReportShuffleResultResponse response =
+ getBlockingStub().reportShuffleResult(request.toProto());
+ return RssReportShuffleResultResponse.fromProto(response);
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java
index ec8f460..23c0a6a 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java
@@ -20,6 +20,7 @@
import java.util.Set;
import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.proto.RssProtos;
public class RssGetShuffleResultForMultiPartRequest {
private String appId;
@@ -50,4 +51,21 @@
public BlockIdLayout getBlockIdLayout() {
return blockIdLayout;
}
+
+ public RssProtos.GetShuffleResultForMultiPartRequest toProto() {
+ RssGetShuffleResultForMultiPartRequest request = this;
+ RssProtos.GetShuffleResultForMultiPartRequest rpcRequest =
+ RssProtos.GetShuffleResultForMultiPartRequest.newBuilder()
+ .setAppId(request.getAppId())
+ .setShuffleId(request.getShuffleId())
+ .addAllPartitions(request.getPartitions())
+ .setBlockIdLayout(
+ RssProtos.BlockIdLayout.newBuilder()
+ .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits)
+ .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits)
+ .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits)
+ .build())
+ .build();
+ return rpcRequest;
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultRequest.java
index 0d0796a..c2e4fea 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultRequest.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.client.request;
import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.proto.RssProtos;
public class RssGetShuffleResultRequest {
@@ -49,4 +50,21 @@
public BlockIdLayout getBlockIdLayout() {
return layout;
}
+
+ public RssProtos.GetShuffleResultRequest toProto() {
+ RssGetShuffleResultRequest request = this;
+ RssProtos.GetShuffleResultRequest rpcRequest =
+ RssProtos.GetShuffleResultRequest.newBuilder()
+ .setAppId(request.getAppId())
+ .setShuffleId(request.getShuffleId())
+ .setPartitionId(request.getPartitionId())
+ .setBlockIdLayout(
+ RssProtos.BlockIdLayout.newBuilder()
+ .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits)
+ .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits)
+ .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits)
+ .build())
+ .build();
+ return rpcRequest;
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
index 76af691..3a4f9fb 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleResultRequest.java
@@ -20,6 +20,10 @@
import java.util.List;
import java.util.Map;
+import com.google.common.collect.Lists;
+
+import org.apache.uniffle.proto.RssProtos;
+
public class RssReportShuffleResultRequest {
private String appId;
@@ -60,4 +64,29 @@
public Map<Integer, List<Long>> getPartitionToBlockIds() {
return partitionToBlockIds;
}
+
+ public RssProtos.ReportShuffleResultRequest toProto() {
+ RssReportShuffleResultRequest request = this;
+ List<RssProtos.PartitionToBlockIds> partitionToBlockIds = Lists.newArrayList();
+ for (Map.Entry<Integer, List<Long>> entry : request.getPartitionToBlockIds().entrySet()) {
+ List<Long> blockIds = entry.getValue();
+ if (blockIds != null && !blockIds.isEmpty()) {
+ partitionToBlockIds.add(
+ RssProtos.PartitionToBlockIds.newBuilder()
+ .setPartitionId(entry.getKey())
+ .addAllBlockIds(entry.getValue())
+ .build());
+ }
+ }
+
+ RssProtos.ReportShuffleResultRequest rpcRequest =
+ RssProtos.ReportShuffleResultRequest.newBuilder()
+ .setAppId(request.getAppId())
+ .setShuffleId(request.getShuffleId())
+ .setTaskAttemptId(request.getTaskAttemptId())
+ .setBitmapNum(request.getBitmapNum())
+ .addAllPartitionToBlockIds(partitionToBlockIds)
+ .build();
+ return rpcRequest;
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
index 4ba8717..aca33aa 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleResultResponse.java
@@ -21,8 +21,10 @@
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.RssUtils;
+import org.apache.uniffle.proto.RssProtos;
public class RssGetShuffleResultResponse extends ClientResponse {
@@ -37,4 +39,26 @@
public Roaring64NavigableMap getBlockIdBitmap() {
return blockIdBitmap;
}
+
+ public static RssGetShuffleResultResponse fromProto(
+ RssProtos.GetShuffleResultResponse rpcResponse) {
+ try {
+ return new RssGetShuffleResultResponse(
+ StatusCode.fromProto(rpcResponse.getStatus()),
+ rpcResponse.getSerializedBitmap().toByteArray());
+ } catch (Exception e) {
+ throw new RssException(e);
+ }
+ }
+
+ public static RssGetShuffleResultResponse fromProto(
+ RssProtos.GetShuffleResultForMultiPartResponse rpcResponse) {
+ try {
+ return new RssGetShuffleResultResponse(
+ StatusCode.fromProto(rpcResponse.getStatus()),
+ rpcResponse.getSerializedBitmap().toByteArray());
+ } catch (Exception e) {
+ throw new RssException(e);
+ }
+ }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReportShuffleResultResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReportShuffleResultResponse.java
index ab87ee0..f70f7d4 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReportShuffleResultResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReportShuffleResultResponse.java
@@ -18,10 +18,16 @@
package org.apache.uniffle.client.response;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
public class RssReportShuffleResultResponse extends ClientResponse {
public RssReportShuffleResultResponse(StatusCode statusCode) {
super(statusCode);
}
+
+ public static RssReportShuffleResultResponse fromProto(
+ RssProtos.ReportShuffleResultResponse rpcResponse) {
+ return new RssReportShuffleResultResponse(StatusCode.fromProto(rpcResponse.getStatus()));
+ }
}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 97470f4..d8d384f 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -525,6 +525,9 @@
rpc reassignShuffleServers(ReassignServersRequest) returns (ReassignServersReponse);
// Reassign on block send failure that occurs in writer
rpc reassignOnBlockSendFailure(RssReassignOnBlockSendFailureRequest) returns (RssReassignOnBlockSendFailureResponse);
+ rpc reportShuffleResult (ReportShuffleResultRequest) returns (ReportShuffleResultResponse);
+ rpc getShuffleResult (GetShuffleResultRequest) returns (GetShuffleResultResponse);
+ rpc getShuffleResultForMultiPart (GetShuffleResultForMultiPartRequest) returns (GetShuffleResultForMultiPartResponse);
}
message ReportShuffleFetchFailureRequest {