[#1373][part-1] feat(spark): partition write to multi servers leveraging from reassignment mechanism (#1445)
### What changes were proposed in this pull request?
Support writing a partition to multiple servers by leveraging the reassignment mechanism.
### Why are the changes needed?
For: https://github.com/apache/incubator-uniffle/issues/1373
### Does this PR introduce _any_ user-facing change?
1. Add config `rss.server.dynamic.assign.enabled` to control whether a faulty shuffle server is reassigned.
2. Support reassigning a new shuffle server for blocks that failed to send.
3. ShuffleReader support for reading a partition from multiple servers will be implemented in the next PR.
### How was this patch tested?
UTs
---------
Co-authored-by: jam.xu <jam.xu@vipshop.com>
diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index 0da7cd9..c98870e 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -85,7 +85,9 @@
private final long sendCheckInterval;
private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
private final int bitmapSplitNum;
- private final Map<Integer, List<Long>> partitionToBlocks = JavaUtils.newConcurrentMap();
+ // server -> partitionId -> blockIds
+ private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
+ Maps.newHashMap();
private final long maxSegmentSize;
private final boolean isMemoryShuffleEnabled;
private final int numMaps;
@@ -250,8 +252,17 @@
buffer.clear();
shuffleBlocks.add(block);
allBlockIds.add(block.getBlockId());
- partitionToBlocks.computeIfAbsent(block.getPartitionId(), key -> Lists.newArrayList());
- partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
+ block
+ .getShuffleServerInfos()
+ .forEach(
+ shuffleServerInfo -> {
+ Map<Integer, Set<Long>> pToBlockIds =
+ serverToPartitionToBlockIds.computeIfAbsent(
+ shuffleServerInfo, k -> Maps.newHashMap());
+ pToBlockIds
+ .computeIfAbsent(block.getPartitionId(), v -> Sets.newHashSet())
+ .add(block.getBlockId());
+ });
}
public SortWriteBuffer<K, V> combineBuffer(SortWriteBuffer<K, V> buffer)
@@ -336,7 +347,7 @@
start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(
- partitionToServers, appId, 0, taskAttemptId, partitionToBlocks, bitmapSplitNum);
+ serverToPartitionToBlockIds, appId, 0, taskAttemptId, bitmapSplitNum);
LOG.info(
"Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
taskAttemptId,
diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index fbb2803..f016bfc 100644
--- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -40,6 +40,7 @@
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
@@ -49,6 +50,7 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -69,7 +71,6 @@
Set<Long> failedBlocks = Sets.newConcurrentHashSet();
Counters.Counter mapOutputByteCounter = new Counters.Counter();
Counters.Counter mapOutputRecordCounter = new Counters.Counter();
- RssException rssException;
SortWriteBufferManager<BytesWritable, BytesWritable> manager;
manager =
new SortWriteBufferManager<BytesWritable, BytesWritable>(
@@ -101,6 +102,7 @@
// case 1
Random random = new Random();
+ partitionToServers.put(1, Lists.newArrayList(mock(ShuffleServerInfo.class)));
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
@@ -108,7 +110,7 @@
random.nextBytes(value);
manager.addRecord(1, new BytesWritable(key), new BytesWritable(value));
}
- rssException = assertThrows(RssException.class, manager::waitSendFinished);
+ RssException rssException = assertThrows(RssException.class, manager::waitSendFinished);
assertTrue(rssException.getMessage().contains("Timeout"));
// case 2
@@ -220,6 +222,7 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
assertTrue(manager.getWaitSendBuffers().isEmpty());
}
@@ -271,6 +274,7 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
manager.waitSendFinished();
@@ -337,6 +341,7 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
manager.waitSendFinished();
@@ -474,7 +479,12 @@
if (mode == 0) {
throw new RssException("send data failed");
} else if (mode == 1) {
- return new SendShuffleDataResult(Sets.newHashSet(2L), Sets.newHashSet(1L));
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ ShuffleBlockInfo failedBlock =
+ new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
+ failedBlockSendTracker.add(
+ failedBlock, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
+ return new SendShuffleDataResult(Sets.newHashSet(2L), failedBlockSendTracker);
} else {
if (mode == 3) {
try {
@@ -489,7 +499,7 @@
for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
successBlockIds.add(blockInfo.getBlockId());
}
- return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
+ return new SendShuffleDataResult(successBlockIds, new FailedBlockSendTracker());
}
}
@@ -537,17 +547,21 @@
@Override
public void reportShuffleResult(
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
String appId,
int shuffleId,
long taskAttemptId,
- Map<Integer, List<Long>> partitionToBlockIds,
int bitmapNum) {
if (mode == 3) {
- mockedShuffleServer.addFinishedBlockInfos(
- partitionToBlockIds.values().stream()
- .flatMap(it -> it.stream())
- .collect(Collectors.toList()));
+ serverToPartitionToBlockIds
+ .values()
+ .forEach(
+ partitionToBlockIds -> {
+ mockedShuffleServer.addFinishedBlockInfos(
+ partitionToBlockIds.values().stream()
+ .flatMap(it -> it.stream())
+ .collect(Collectors.toList()));
+ });
}
}
diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 6a7d36d..eeae036 100644
--- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -61,6 +61,7 @@
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
@@ -73,10 +74,12 @@
import org.apache.uniffle.common.compression.Lz4Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
public class FetcherTest {
static JobID jobId = new JobID("a", 0);
@@ -338,6 +341,7 @@
MockShuffleWriteClient client = new MockShuffleWriteClient();
client.setMode(2);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
+ partitionToServers.put(0, Lists.newArrayList(mock(ShuffleServerInfo.class)));
Set<Long> successBlocks = Sets.newConcurrentHashSet();
Set<Long> failedBlocks = Sets.newConcurrentHashSet();
Counters.Counter mapOutputByteCounter = new Counters.Counter();
@@ -463,7 +467,12 @@
if (mode == 0) {
throw new RssException("send data failed");
} else if (mode == 1) {
- return new SendShuffleDataResult(Sets.newHashSet(2L), Sets.newHashSet(1L));
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ ShuffleBlockInfo failedBlock =
+ new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
+ failedBlockSendTracker.add(
+ failedBlock, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
+ return new SendShuffleDataResult(Sets.newHashSet(2L), failedBlockSendTracker);
} else {
Set<Long> successBlockIds = Sets.newHashSet();
for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
@@ -476,7 +485,7 @@
block.getData().nioBuffer(), block.getUncompressLength(), uncompressedBuffer, 0);
data.add(uncompressedBuffer.array());
});
- return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
+ return new SendShuffleDataResult(successBlockIds, new FailedBlockSendTracker());
}
}
@@ -517,11 +526,10 @@
@Override
public void reportShuffleResult(
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
String appId,
int shuffleId,
long taskAttemptId,
- Map<Integer, List<Long>> partitionToBlockIds,
int bitmapNum) {}
@Override
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
index c599aee..de999ed 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
@@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Set;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.uniffle.common.RemoteStorageInfo;
@@ -38,6 +39,9 @@
private int shuffleId;
private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+
+ // partitionId -> replica -> failover servers
+ private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> failoverPartitionServers;
// shuffle servers which is for store shuffle data
private Set<ShuffleServerInfo> shuffleServersForData;
// remoteStorage used for this job
@@ -53,6 +57,7 @@
this.shuffleId = shuffleId;
this.partitionToServers = partitionToServers;
this.shuffleServersForData = Sets.newHashSet();
+ this.failoverPartitionServers = Maps.newConcurrentMap();
for (List<ShuffleServerInfo> ssis : partitionToServers.values()) {
this.shuffleServersForData.addAll(ssis);
}
@@ -63,6 +68,10 @@
return partitionToServers;
}
+ public Map<Integer, Map<Integer, List<ShuffleServerInfo>>> getFailoverPartitionServers() {
+ return failoverPartitionServers;
+ }
+
public Set<ShuffleServerInfo> getShuffleServersForData() {
return shuffleServersForData;
}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 68ec8fb..30f649f 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -24,7 +24,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
@@ -36,11 +35,10 @@
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
-import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
-import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.ThreadUtils;
/**
@@ -56,9 +54,7 @@
// Must be thread safe
private final Map<String, Set<Long>> taskToSuccessBlockIds;
// Must be thread safe
- private final Map<String, Set<Long>> taskToFailedBlockIds;
- private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
- taskToFailedBlockIdsAndServer;
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
private String rssAppId;
// Must be thread safe
private final Set<String> failedTaskIds;
@@ -66,15 +62,13 @@
public DataPusher(
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
- Map<String, Set<Long>> taskToFailedBlockIds,
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime) {
this.shuffleWriteClient = shuffleWriteClient;
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
- this.taskToFailedBlockIds = taskToFailedBlockIds;
- this.taskToFailedBlockIdsAndServer = taskToFailedBlockIdsAndServer;
+ this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
this.failedTaskIds = failedTaskIds;
this.executorService =
new ThreadPoolExecutor(
@@ -99,9 +93,8 @@
shuffleWriteClient.sendShuffleData(
rssAppId, shuffleBlockInfoList, () -> !isValidTask(taskId));
putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
- putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
- putSendFailedBlockIdAndShuffleServer(
- taskToFailedBlockIdsAndServer, taskId, result.getSendFailedBlockIds());
+ putFailedBlockSendTracker(
+ taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
} finally {
List<Runnable> callbackChain =
Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
@@ -127,16 +120,16 @@
.addAll(blockIds);
}
- private synchronized void putSendFailedBlockIdAndShuffleServer(
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
+ private synchronized void putFailedBlockSendTracker(
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
String taskAttemptId,
- Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsAndServer) {
- if (blockIdsAndServer == null || blockIdsAndServer.isEmpty()) {
+ FailedBlockSendTracker failedBlockSendTracker) {
+ if (failedBlockSendTracker == null) {
return;
}
- taskToFailedBlockIdsAndServer
- .computeIfAbsent(taskAttemptId, x -> JavaUtils.newConcurrentMap())
- .putAll(blockIdsAndServer);
+ taskToFailedBlockSendTracker
+ .computeIfAbsent(taskAttemptId, x -> new FailedBlockSendTracker())
+ .merge(failedBlockSendTracker);
}
public boolean isValidTask(String taskId) {
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 73fdbdb..3160cc6 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -17,9 +17,13 @@
package org.apache.uniffle.shuffle.manager;
+import java.util.Set;
+
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
/**
* This is a proxy interface that mainly delegates the un-registration of shuffles to the
* MapOutputTrackerMaster on the driver. It provides a unified interface that hides implementation
@@ -71,5 +75,9 @@
*/
void addFailuresShuffleServerInfos(String shuffleServerId);
- boolean reassignShuffleServers(int stageId, int stageAttemptNumber, int shuffleId, int numMaps);
+ boolean reassignAllShuffleServersForWholeStage(
+ int stageId, int stageAttemptNumber, int shuffleId, int numMaps);
+
+ ShuffleServerInfo reassignFaultyShuffleServerForTasks(
+ int shuffleId, Set<String> partitionIds, String faultyShuffleServerId);
}
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index e95aad6..bcf1303 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -26,6 +26,7 @@
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
+import com.google.common.collect.Sets;
import io.grpc.stub.StreamObserver;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.slf4j.Logger;
@@ -239,7 +240,7 @@
int shuffleId = request.getShuffleId();
int numPartitions = request.getNumPartitions();
boolean needReassign =
- shuffleManager.reassignShuffleServers(
+ shuffleManager.reassignAllShuffleServersForWholeStage(
stageId, stageAttemptNumber, shuffleId, numPartitions);
RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
RssProtos.ReassignServersReponse reply =
@@ -251,6 +252,25 @@
responseObserver.onCompleted();
}
+ @Override
+ public void reassignFaultyShuffleServer(
+ RssProtos.RssReassignFaultyShuffleServerRequest request,
+ StreamObserver<RssProtos.RssReassignFaultyShuffleServerResponse> responseObserver) {
+ ShuffleServerInfo shuffleServerInfo =
+ shuffleManager.reassignFaultyShuffleServerForTasks(
+ request.getShuffleId(),
+ Sets.newHashSet(request.getPartitionIdsList()),
+ request.getFaultyShuffleServerId());
+ RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
+ RssProtos.RssReassignFaultyShuffleServerResponse reply =
+ RssProtos.RssReassignFaultyShuffleServerResponse.newBuilder()
+ .setStatus(code)
+ .setServer(ShuffleServerInfo.convertToShuffleServerId(shuffleServerInfo))
+ .build();
+ responseObserver.onNext(reply);
+ responseObserver.onCompleted();
+ }
+
/**
* Remove the no longer used shuffle id's rss shuffle status. This is called when ShuffleManager
* unregisters the corresponding shuffle id.
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index a3cdbb6..2a608bd 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -22,7 +22,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
@@ -32,10 +31,12 @@
import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -81,35 +82,39 @@
FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
- Map<String, Set<Long>> taskToFailedBlockIds = Maps.newConcurrentMap();
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
- JavaUtils.newConcurrentMap();
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
Set<String> failedTaskIds = new HashSet<>();
DataPusher dataPusher =
new DataPusher(
shuffleWriteClient,
taskToSuccessBlockIds,
- taskToFailedBlockIds,
- taskToFailedBlockIdsAndServer,
+ taskToFailedBlockSendTracker,
failedTaskIds,
1,
2);
dataPusher.setRssAppId("testSendData_appId");
-
- // sync send
- AddBlockEvent event =
- new AddBlockEvent(
- "taskId",
- Arrays.asList(new ShuffleBlockInfo(1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1)));
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ ShuffleBlockInfo failedBlock1 =
+ new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
+ ShuffleBlockInfo failedBlock2 =
+ new ShuffleBlockInfo(1, 1, 4, 1, 1, new byte[1], null, 1, 100, 1);
+ failedBlockSendTracker.add(
+ failedBlock1, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
+ failedBlockSendTracker.add(
+ failedBlock2, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
shuffleWriteClient.setFakedShuffleDataResult(
- new SendShuffleDataResult(Sets.newHashSet(1L, 2L), Sets.newHashSet(3L, 4L)));
+ new SendShuffleDataResult(Sets.newHashSet(1L, 2L), failedBlockSendTracker));
+ ShuffleBlockInfo shuffleBlockInfo =
+ new ShuffleBlockInfo(1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1);
+ AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(shuffleBlockInfo));
+ // sync send
CompletableFuture<Long> future = dataPusher.send(event);
long memoryFree = future.get();
assertEquals(100, memoryFree);
assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L));
assertTrue(taskToSuccessBlockIds.get("taskId").contains(2L));
- assertTrue(taskToFailedBlockIds.get("taskId").contains(3L));
- assertTrue(taskToFailedBlockIds.get("taskId").contains(4L));
+ assertTrue(taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().contains(3L));
+ assertTrue(taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().contains(4L));
}
}
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 0b9e4f8..841e13f 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -22,6 +22,10 @@
import org.apache.spark.shuffle.ShuffleHandleInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.mockito.Mockito.mock;
+
public class DummyRssShuffleManager implements RssShuffleManagerInterface {
public Set<Integer> unregisteredShuffleIds = new LinkedHashSet<>();
@@ -59,8 +63,14 @@
public void addFailuresShuffleServerInfos(String shuffleServerId) {}
@Override
- public boolean reassignShuffleServers(
+ public boolean reassignAllShuffleServersForWholeStage(
int stageId, int stageAttemptNumber, int shuffleId, int numMaps) {
return false;
}
+
+ @Override
+ public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
+ int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
+ return mock(ShuffleServerInfo.class);
+ }
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 74c8ba0..ed3f340 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -22,9 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -34,6 +32,7 @@
import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
@@ -58,6 +57,7 @@
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
@@ -95,10 +95,8 @@
private String clientType;
private ShuffleWriteClient shuffleWriteClient;
private Map<String, Set<Long>> taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
- private Map<String, Set<Long>> taskToFailedBlockIds = JavaUtils.newConcurrentMap();
- // Record both the block that failed to be sent and the ShuffleServer
- private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
- taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
+ private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
+ JavaUtils.newConcurrentMap();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
@@ -134,6 +132,8 @@
*/
private Map<String, Boolean> serverAssignedInfos = JavaUtils.newConcurrentMap();
+ private Map<String, ShuffleServerInfo> reassignedFaultyServers = JavaUtils.newConcurrentMap();
+
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
throw new IllegalArgumentException(
@@ -242,8 +242,7 @@
new DataPusher(
shuffleWriteClient,
taskToSuccessBlockIds,
- taskToFailedBlockIds,
- taskToFailedBlockIdsAndServer,
+ taskToFailedBlockSendTracker,
failedTaskIds,
poolSize,
keepAliveTime);
@@ -637,11 +636,11 @@
}
public Set<Long> getFailedBlockIds(String taskId) {
- Set<Long> result = taskToFailedBlockIds.get(taskId);
- if (result == null) {
- result = Collections.emptySet();
+ FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId);
+ if (blockIdsFailedSendTracker == null) {
+ return Collections.emptySet();
}
- return result;
+ return blockIdsFailedSendTracker.getFailedBlockIds();
}
public Set<Long> getSuccessBlockIds(String taskId) {
@@ -653,22 +652,6 @@
}
@VisibleForTesting
- public void addFailedBlockIds(String taskId, Set<Long> blockIds) {
- if (taskToFailedBlockIds.get(taskId) == null) {
- taskToFailedBlockIds.put(taskId, Sets.newHashSet());
- }
- taskToFailedBlockIds.get(taskId).addAll(blockIds);
- }
-
- @VisibleForTesting
- public void addTaskToFailedBlockIdsAndServer(
- String taskId, Long blockId, ShuffleServerInfo shuffleServerInfo) {
- taskToFailedBlockIdsAndServer.putIfAbsent(taskId, Maps.newHashMap());
- taskToFailedBlockIdsAndServer.get(taskId).putIfAbsent(blockId, new LinkedBlockingDeque<>());
- taskToFailedBlockIdsAndServer.get(taskId).get(blockId).add(shuffleServerInfo);
- }
-
- @VisibleForTesting
public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
if (taskToSuccessBlockIds.get(taskId) == null) {
taskToSuccessBlockIds.put(taskId, Sets.newHashSet());
@@ -676,9 +659,15 @@
taskToSuccessBlockIds.get(taskId).addAll(blockIds);
}
+ @VisibleForTesting
+ public void addFailedBlockSendTracker(
+ String taskId, FailedBlockSendTracker failedBlockSendTracker) {
+ taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
+ }
+
public void clearTaskMeta(String taskId) {
taskToSuccessBlockIds.remove(taskId);
- taskToFailedBlockIds.remove(taskId);
+ taskToFailedBlockSendTracker.remove(taskId);
}
@VisibleForTesting
@@ -759,19 +748,8 @@
}
}
- /**
- * The ShuffleServer list of block sending failures is returned using the shuffle task ID
- *
- * @param taskId Shuffle taskId
- * @return List of failed ShuffleServer blocks
- */
- public Map<Long, BlockingQueue<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(
- String taskId) {
- Map<Long, BlockingQueue<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
- if (result == null) {
- result = Collections.emptyMap();
- }
- return result;
+ public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
+ return taskToFailedBlockSendTracker.get(taskId);
}
@Override
@@ -829,51 +807,24 @@
* @param numPartitions
*/
@Override
- public synchronized boolean reassignShuffleServers(
+ public synchronized boolean reassignAllShuffleServersForWholeStage(
int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
- Boolean needReassgin = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
- if (!needReassgin) {
- String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- RemoteStorageInfo defaultRemoteStorage =
- new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
- RemoteStorageInfo remoteStorage =
- ClientUtils.fetchRemoteStorage(
- appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
- Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
- ClientUtils.validateClientType(clientType);
- assignmentTags.add(clientType);
+ Boolean needReassign = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
+ if (!needReassign) {
int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
- long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
/** Before reassigning ShuffleServer, clear the ShuffleServer list in ShuffleWriteClient. */
shuffleWriteClient.unregisterShuffle(appId, shuffleId);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers;
- try {
- partitionToServers =
- RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- appId,
- shuffleId,
- numPartitions,
- 1,
- assignmentTags,
- requiredShuffleServerNumber,
- estimateTaskConcurrency,
- failuresShuffleServerIds);
- registerShuffleServers(
- appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage);
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
- } catch (Throwable throwable) {
- throw new RssException("registerShuffle failed!", throwable);
- }
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ requestShuffleAssignment(
+ shuffleId,
+ numPartitions,
+ 1,
+ requiredShuffleServerNumber,
+ estimateTaskConcurrency,
+ failuresShuffleServerIds);
/**
* we need to clear the metadata of the completed task, otherwise some of the stage's data
* will be lost
@@ -885,7 +836,7 @@
throw new RssException("Clear MapoutTracker Meta failed!", e);
}
ShuffleHandleInfo handleInfo =
- new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+ new ShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
serverAssignedInfos.put(stageIdAndAttempt, true);
return true;
@@ -898,6 +849,97 @@
}
}
+  /**
+   * Assigns a replacement shuffle server for a faulty one and records it as a failover target
+   * of every given partition that was served by the faulty server.
+   *
+   * <p>The replacement is remembered per faulty server id, so repeated calls for the same
+   * faulty server return the server chosen by the first successful call.
+   *
+   * @param shuffleId the shuffle whose handle info is updated
+   * @param partitionIds partition ids (decimal strings) whose blocks failed to be sent
+   * @param faultyShuffleServerId id of the shuffle server to be replaced
+   * @return the replacement server, or {@code null} if none could be assigned
+   */
+  @Override
+  public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
+      int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
+    return reassignedFaultyServers.computeIfAbsent(
+        faultyShuffleServerId,
+        id -> {
+          ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, id);
+          if (newAssignedServer == null) {
+            // Do not pollute the failover lists with null; returning null records no mapping,
+            // so a later call can try the assignment again.
+            return null;
+          }
+          ShuffleHandleInfo shuffleHandleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
+          for (String partitionId : partitionIds) {
+            // Both maps are keyed by Integer; a lookup with the String id would always miss.
+            Integer partitionIdInteger = Integer.valueOf(partitionId);
+            List<ShuffleServerInfo> shuffleServerInfoList =
+                shuffleHandleInfo.getPartitionToServers().get(partitionIdInteger);
+            if (shuffleServerInfoList == null) {
+              // Unknown partition for this shuffle: nothing to fail over.
+              continue;
+            }
+            for (int i = 0; i < shuffleServerInfoList.size(); i++) {
+              if (shuffleServerInfoList.get(i).getId().equals(faultyShuffleServerId)) {
+                // Index i is the replica slot the faulty server occupied for this partition.
+                shuffleHandleInfo
+                    .getFailoverPartitionServers()
+                    .computeIfAbsent(partitionIdInteger, k -> Maps.newHashMap())
+                    .computeIfAbsent(i, j -> Lists.newArrayList())
+                    .add(newAssignedServer);
+              }
+            }
+          }
+          return newAssignedServer;
+        });
+  }
+
+ /**
+ * Returns the map from a faulty shuffle server id to the replacement server assigned for it.
+ * The returned map is the internal field itself, not a copy.
+ */
+ public Map<String, ShuffleServerInfo> getReassignedFaultyServers() {
+ return reassignedFaultyServers;
+ }
+
+  /**
+   * Requests a single-partition, single-server assignment that excludes the given faulty
+   * server as well as all servers already recorded as failed.
+   *
+   * @param shuffleId the shuffle the assignment is requested for
+   * @param faultyShuffleServerId id of the server that must not be chosen
+   * @return the assigned server, or {@code null} unless partition 0 got exactly one server
+   */
+  private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
+    Set<String> excludedServerIds = Sets.newHashSet(faultyShuffleServerId);
+    excludedServerIds.addAll(failuresShuffleServerIds);
+    List<ShuffleServerInfo> candidates =
+        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, excludedServerIds).get(0);
+    boolean exactlyOneServer = candidates != null && candidates.size() == 1;
+    return exactlyOneServer ? candidates.get(0) : null;
+  }
+
+  /**
+   * Asks for a shuffle assignment and registers the shuffle on the assigned servers, retrying
+   * the whole request with the configured interval and retry count.
+   *
+   * <p>Note that {@code faultyServerIds} is modified: all ids in {@code
+   * failuresShuffleServerIds} are added to it before the request is sent.
+   *
+   * @param shuffleId the shuffle to assign servers for
+   * @param partitionNum number of partitions to assign
+   * @param partitionNumPerRange number of partitions per range
+   * @param assignmentShuffleServerNumber number of shuffle servers requested
+   * @param estimateTaskConcurrency estimated task concurrency passed with the request
+   * @param faultyServerIds ids of servers passed to the request as faulty
+   * @return the partition to servers mapping of the assignment response
+   * @throws RssException if the request still fails after all retries
+   */
+  private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+      int shuffleId,
+      int partitionNum,
+      int partitionNumPerRange,
+      int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency,
+      Set<String> faultyServerIds) {
+    Set<String> tags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    tags.add(clientType);
+
+    long intervalMs = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int maxRetries = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    faultyServerIds.addAll(failuresShuffleServerIds);
+
+    Map<Integer, List<ShuffleServerInfo>> assignedPartitionToServers;
+    try {
+      assignedPartitionToServers =
+          RetryUtils.retry(
+              () -> {
+                ShuffleAssignmentsInfo assignments =
+                    shuffleWriteClient.getShuffleAssignments(
+                        appId,
+                        shuffleId,
+                        partitionNum,
+                        partitionNumPerRange,
+                        tags,
+                        assignmentShuffleServerNumber,
+                        estimateTaskConcurrency,
+                        faultyServerIds);
+                registerShuffleServers(
+                    appId,
+                    shuffleId,
+                    assignments.getServerToPartitionRanges(),
+                    getRemoteStorageInfo());
+                return assignments.getPartitionToServers();
+              },
+              intervalMs,
+              maxRetries);
+    } catch (Throwable cause) {
+      throw new RssException("registerShuffle failed!", cause);
+    }
+    return assignedPartitionToServers;
+  }
+
+  /**
+   * Resolves the remote storage for this application, using the configured remote storage
+   * path (empty string if unset) as the default handed to {@code fetchRemoteStorage}.
+   *
+   * @return the remote storage info returned by {@code ClientUtils.fetchRemoteStorage}
+   */
+  private RemoteStorageInfo getRemoteStorageInfo() {
+    String configuredStorageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
+    String configuredPath = sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "");
+    RemoteStorageInfo fallbackStorage = new RemoteStorageInfo(configuredPath);
+    return ClientUtils.fetchRemoteStorage(
+        appId, fallbackStorage, dynamicConfEnabled, configuredStorageType, shuffleWriteClient);
+  }
+
public boolean isRssResubmitStage() {
return rssResubmitStage;
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index b634839..b5428bd 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -20,12 +20,11 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -65,6 +64,7 @@
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssReassignServersReponse;
@@ -87,7 +87,8 @@
private static final int DUMMY_PORT = 99999;
// they will be used in commit phase
private final Set<ShuffleServerInfo> shuffleServersForData;
- private final Map<Integer, Set<Long>> partitionToBlockIds;
+ // server -> partitionId -> blockIds
+ private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds;
private final ShuffleWriteClient shuffleWriteClient;
private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private String appId;
@@ -165,7 +166,7 @@
this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
- this.partitionToBlockIds = Maps.newHashMap();
+ this.serverToPartitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
@@ -309,9 +310,16 @@
blockIds.add(blockId);
// update [partition, blockIds], it will be sent to shuffle server
int partitionId = sbi.getPartitionId();
- partitionToBlockIds
- .computeIfAbsent(partitionId, k -> Sets.newHashSet())
- .add(blockId);
+ sbi.getShuffleServerInfos()
+ .forEach(
+ shuffleServerInfo -> {
+ Map<Integer, Set<Long>> pToBlockIds =
+ serverToPartitionToBlockIds.computeIfAbsent(
+ shuffleServerInfo, k -> Maps.newHashMap());
+ pToBlockIds
+ .computeIfAbsent(partitionId, v -> Sets.newHashSet())
+ .add(blockId);
+ });
});
return postBlockEvent(shuffleBlockInfoList);
}
@@ -366,9 +374,7 @@
protected void checkBlockSendResult(Set<Long> blockIds) {
long start = System.currentTimeMillis();
while (true) {
- Map<Long, BlockingQueue<ShuffleServerInfo>> failedBlockIdsWithShuffleServer =
- shuffleManager.getFailedBlockIdsWithShuffleServer(taskId);
- Set<Long> failedBlockIds = failedBlockIdsWithShuffleServer.keySet();
+ Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
// if failed when send data to shuffle server, mark task as failed
if (failedBlockIds.size() > 0) {
@@ -378,9 +384,7 @@
+ "] failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server: "
- + failedBlockIdsWithShuffleServer.values().stream()
- .flatMap(Collection::stream)
- .collect(Collectors.toSet());
+ + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
LOG.error(errorMsg);
throw new RssSendFailedException(errorMsg);
}
@@ -417,14 +421,9 @@
Arrays.fill(partitionLengths, 1);
final BlockManagerId blockManagerId =
createDummyBlockManagerId(appId + "_" + taskId, taskAttemptId);
-
- Map<Integer, List<Long>> ptb = Maps.newHashMap();
- for (Map.Entry<Integer, Set<Long>> entry : partitionToBlockIds.entrySet()) {
- ptb.put(entry.getKey(), Lists.newArrayList(entry.getValue()));
- }
long start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(
- partitionToServers, appId, shuffleId, taskAttemptId, ptb, bitmapSplitNum);
+ serverToPartitionToBlockIds, appId, shuffleId, taskAttemptId, bitmapSplitNum);
LOG.info(
"Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
taskAttemptId,
@@ -457,7 +456,17 @@
@VisibleForTesting
protected Map<Integer, Set<Long>> getPartitionToBlockIds() {
- return partitionToBlockIds;
+ return serverToPartitionToBlockIds.values().stream()
+ .flatMap(s -> s.entrySet().stream())
+ .collect(
+ Collectors.toMap(
+ Map.Entry::getKey,
+ Map.Entry::getValue,
+ (existingSet, newSet) -> {
+ Set<Long> mergedSet = new HashSet<>(existingSet);
+ mergedSet.addAll(newSet);
+ return mergedSet;
+ }));
}
@VisibleForTesting
@@ -476,13 +485,10 @@
private void throwFetchFailedIfNecessary(Exception e) {
// The shuffleServer is registered only when a Block fails to be sent
if (e instanceof RssSendFailedException) {
- Map<Long, BlockingQueue<ShuffleServerInfo>> failedBlockIds =
- shuffleManager.getFailedBlockIdsWithShuffleServer(taskId);
- List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList();
- for (Map.Entry<Long, BlockingQueue<ShuffleServerInfo>> longListEntry :
- failedBlockIds.entrySet()) {
- shuffleServerInfos.addAll(longListEntry.getValue());
- }
+ FailedBlockSendTracker blockIdsFailedSendTracker =
+ shuffleManager.getBlockIdsFailedSendTracker(taskId);
+ List<ShuffleServerInfo> shuffleServerInfos =
+ Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers());
RssReportShuffleWriteFailureRequest req =
new RssReportShuffleWriteFailureRequest(
appId,
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 8b150f9..8711c48 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -21,7 +21,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -50,9 +49,11 @@
import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -145,9 +146,12 @@
// case 3: partial blocks are sent failed, Runtime exception will be thrown
manager.addSuccessBlockIds(taskId, Sets.newHashSet(1L, 2L));
- manager.addFailedBlockIds(taskId, Sets.newHashSet(3L));
- ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo("127.0.0.1", 20001);
- manager.addTaskToFailedBlockIdsAndServer(taskId, 3L, shuffleServerInfo);
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ ShuffleBlockInfo failedBlock1 =
+ new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
+ failedBlockSendTracker.add(
+ failedBlock1, new ShuffleServerInfo("127.0.0.1", 20001), StatusCode.INTERNAL_ERROR);
+ manager.addFailedBlockSendTracker(taskId, failedBlockSendTracker);
Throwable e3 =
assertThrows(
RuntimeException.class,
@@ -171,7 +175,7 @@
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime,
@@ -179,8 +183,7 @@
super(
shuffleWriteClient,
taskToSuccessBlockIds,
- taskToFailedBlockIds,
- taskToFailedBlockIdsAndServer,
+ taskToFailedBlockSendTracker,
failedTaskIds,
threadPoolSize,
threadKeepAliveTime);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index b3aa469..b7710e7 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -23,7 +23,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -36,6 +35,8 @@
import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.apache.hadoop.conf.Configuration;
@@ -63,6 +64,7 @@
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
import org.apache.uniffle.client.util.ClientUtils;
@@ -103,10 +105,7 @@
private final int dataTransferPoolSize;
private final int dataCommitPoolSize;
private final Map<String, Set<Long>> taskToSuccessBlockIds;
- private final Map<String, Set<Long>> taskToFailedBlockIds;
- // Record both the block that failed to be sent and the ShuffleServer
- private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
- taskToFailedBlockIdsAndServer;
+ private final Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker;
private ScheduledExecutorService heartBeatScheduledExecutorService;
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
@@ -145,6 +144,8 @@
*/
private Map<String, Boolean> serverAssignedInfos;
+ private Map<String, ShuffleServerInfo> reassignedFaultyServers;
+
public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
boolean supportsRelocation =
@@ -234,8 +235,7 @@
sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
LOG.info("Disable shuffle data locality in RssShuffleManager.");
taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
- taskToFailedBlockIds = JavaUtils.newConcurrentMap();
- this.taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
+ taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
this.rssResubmitStage =
rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& RssSparkShuffleUtils.isStageResubmitSupported();
@@ -268,14 +268,14 @@
new DataPusher(
shuffleWriteClient,
taskToSuccessBlockIds,
- taskToFailedBlockIds,
- taskToFailedBlockIdsAndServer,
+ taskToFailedBlockSendTracker,
failedTaskIds,
poolSize,
keepAliveTime);
this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
this.failuresShuffleServerIds = Sets.newHashSet();
this.serverAssignedInfos = JavaUtils.newConcurrentMap();
+ this.reassignedFaultyServers = JavaUtils.newConcurrentMap();
}
public CompletableFuture<Long> sendData(AddBlockEvent event) {
@@ -303,8 +303,7 @@
boolean isDriver,
DataPusher dataPusher,
Map<String, Set<Long>> taskToSuccessBlockIds,
- Map<String, Set<Long>> taskToFailedBlockIds,
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
this.sparkConf = conf;
this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
this.dataDistributionType =
@@ -359,9 +358,8 @@
.unregisterRequestTimeSec(unregisterRequestTimeoutSec)
.rssConf(RssSparkConfig.toRssConf(sparkConf)));
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
- this.taskToFailedBlockIds = taskToFailedBlockIds;
- this.taskToFailedBlockIdsAndServer = taskToFailedBlockIdsAndServer;
this.heartBeatScheduledExecutorService = null;
+ this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
this.dataPusher = dataPusher;
}
@@ -930,7 +928,7 @@
public void clearTaskMeta(String taskId) {
taskToSuccessBlockIds.remove(taskId);
- taskToFailedBlockIds.remove(taskId);
+ taskToFailedBlockSendTracker.remove(taskId);
}
@VisibleForTesting
@@ -999,11 +997,11 @@
}
public Set<Long> getFailedBlockIds(String taskId) {
- Set<Long> result = taskToFailedBlockIds.get(taskId);
- if (result == null) {
- result = Collections.emptySet();
+ FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId);
+ if (blockIdsFailedSendTracker == null) {
+ return Collections.emptySet();
}
- return result;
+ return blockIdsFailedSendTracker.getFailedBlockIds();
}
public Set<Long> getSuccessBlockIds(String taskId) {
@@ -1121,19 +1119,8 @@
}
}
- /**
- * The ShuffleServer list of block sending failures is returned using the shuffle task ID
- *
- * @param taskId Shuffle taskId
- * @return failed ShuffleServer blocks
- */
- public Map<Long, BlockingQueue<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(
- String taskId) {
- Map<Long, BlockingQueue<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
- if (result == null) {
- result = Collections.emptyMap();
- }
- return result;
+ public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
+ return taskToFailedBlockSendTracker.get(taskId);
}
@Override
@@ -1191,50 +1178,24 @@
* @param numPartitions
*/
@Override
- public synchronized boolean reassignShuffleServers(
+ public synchronized boolean reassignAllShuffleServersForWholeStage(
int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
- Boolean needReassgin = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
- if (!needReassgin) {
- String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
- RemoteStorageInfo defaultRemoteStorage =
- new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
- RemoteStorageInfo remoteStorage =
- ClientUtils.fetchRemoteStorage(
- id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
- Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+ Boolean needReassign = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false);
+ if (!needReassign) {
int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
- long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
- int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
/** Before reassigning ShuffleServer, clear the ShuffleServer list in ShuffleWriteClient. */
shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers;
- try {
- partitionToServers =
- RetryUtils.retry(
- () -> {
- ShuffleAssignmentsInfo response =
- shuffleWriteClient.getShuffleAssignments(
- id.get(),
- shuffleId,
- numPartitions,
- 1,
- assignmentTags,
- requiredShuffleServerNumber,
- estimateTaskConcurrency,
- failuresShuffleServerIds);
- registerShuffleServers(
- id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage);
- return response.getPartitionToServers();
- },
- retryInterval,
- retryTimes);
-
- } catch (Throwable throwable) {
- throw new RssException("registerShuffle failed!", throwable);
- }
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+ requestShuffleAssignment(
+ shuffleId,
+ numPartitions,
+ 1,
+ requiredShuffleServerNumber,
+ estimateTaskConcurrency,
+ failuresShuffleServerIds);
/**
* we need to clear the metadata of the completed task, otherwise some of the stage's data
* will be lost
@@ -1246,7 +1207,7 @@
throw new RssException("Clear MapoutTracker Meta failed!", e);
}
ShuffleHandleInfo handleInfo =
- new ShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage);
+ new ShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo());
shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo);
serverAssignedInfos.put(stageIdAndAttempt, true);
return true;
@@ -1259,6 +1220,94 @@
}
}
+  /**
+   * Assigns a replacement shuffle server for a faulty one and records it as a failover target
+   * of every given partition that was served by the faulty server.
+   *
+   * <p>The replacement is remembered per faulty server id, so repeated calls for the same
+   * faulty server return the server chosen by the first successful call.
+   *
+   * @param shuffleId the shuffle whose handle info is updated
+   * @param partitionIds partition ids (decimal strings) whose blocks failed to be sent
+   * @param faultyShuffleServerId id of the shuffle server to be replaced
+   * @return the replacement server, or {@code null} if none could be assigned
+   */
+  @Override
+  public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
+      int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
+    return reassignedFaultyServers.computeIfAbsent(
+        faultyShuffleServerId,
+        id -> {
+          ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, id);
+          if (newAssignedServer == null) {
+            // Do not pollute the failover lists with null; returning null records no mapping,
+            // so a later call can try the assignment again.
+            return null;
+          }
+          ShuffleHandleInfo shuffleHandleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
+          for (String partitionId : partitionIds) {
+            // Both maps are keyed by Integer; a lookup with the String id would always miss.
+            Integer partitionIdInteger = Integer.valueOf(partitionId);
+            List<ShuffleServerInfo> shuffleServerInfoList =
+                shuffleHandleInfo.getPartitionToServers().get(partitionIdInteger);
+            if (shuffleServerInfoList == null) {
+              // Unknown partition for this shuffle: nothing to fail over.
+              continue;
+            }
+            for (int i = 0; i < shuffleServerInfoList.size(); i++) {
+              if (shuffleServerInfoList.get(i).getId().equals(faultyShuffleServerId)) {
+                // Index i is the replica slot the faulty server occupied for this partition.
+                shuffleHandleInfo
+                    .getFailoverPartitionServers()
+                    .computeIfAbsent(partitionIdInteger, k -> Maps.newHashMap())
+                    .computeIfAbsent(i, j -> Lists.newArrayList())
+                    .add(newAssignedServer);
+              }
+            }
+          }
+          return newAssignedServer;
+        });
+  }
+
+ /**
+ * Returns the map from a faulty shuffle server id to the replacement server assigned for it.
+ * The returned map is the internal field itself, not a copy.
+ */
+ public Map<String, ShuffleServerInfo> getReassignedFaultyServers() {
+ return reassignedFaultyServers;
+ }
+
+  /**
+   * Requests a single-partition, single-server assignment that excludes the given faulty
+   * server as well as all servers already recorded as failed.
+   *
+   * @param shuffleId the shuffle the assignment is requested for
+   * @param faultyShuffleServerId id of the server that must not be chosen
+   * @return the assigned server, or {@code null} unless partition 0 got exactly one server
+   */
+  private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
+    Set<String> excludedServerIds = Sets.newHashSet(faultyShuffleServerId);
+    excludedServerIds.addAll(failuresShuffleServerIds);
+    List<ShuffleServerInfo> candidates =
+        requestShuffleAssignment(shuffleId, 1, 1, 1, 1, excludedServerIds).get(0);
+    boolean exactlyOneServer = candidates != null && candidates.size() == 1;
+    return exactlyOneServer ? candidates.get(0) : null;
+  }
+
+  /**
+   * Asks for a shuffle assignment and registers the shuffle on the assigned servers, retrying
+   * the whole request with the configured interval and retry count.
+   *
+   * <p>Note that {@code faultyServerIds} is modified: all ids in {@code
+   * failuresShuffleServerIds} are added to it before the request is sent.
+   *
+   * @param shuffleId the shuffle to assign servers for
+   * @param partitionNum number of partitions to assign
+   * @param partitionNumPerRange number of partitions per range
+   * @param assignmentShuffleServerNumber number of shuffle servers requested
+   * @param estimateTaskConcurrency estimated task concurrency passed with the request
+   * @param faultyServerIds ids of servers passed to the request as faulty
+   * @return the partition to servers mapping of the assignment response
+   * @throws RssException if the request still fails after all retries
+   */
+  private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
+      int shuffleId,
+      int partitionNum,
+      int partitionNumPerRange,
+      int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency,
+      Set<String> faultyServerIds) {
+    Set<String> tags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    long intervalMs = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int maxRetries = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    faultyServerIds.addAll(failuresShuffleServerIds);
+
+    Map<Integer, List<ShuffleServerInfo>> assignedPartitionToServers;
+    try {
+      assignedPartitionToServers =
+          RetryUtils.retry(
+              () -> {
+                ShuffleAssignmentsInfo assignments =
+                    shuffleWriteClient.getShuffleAssignments(
+                        id.get(),
+                        shuffleId,
+                        partitionNum,
+                        partitionNumPerRange,
+                        tags,
+                        assignmentShuffleServerNumber,
+                        estimateTaskConcurrency,
+                        faultyServerIds);
+                registerShuffleServers(
+                    id.get(),
+                    shuffleId,
+                    assignments.getServerToPartitionRanges(),
+                    getRemoteStorageInfo());
+                return assignments.getPartitionToServers();
+              },
+              intervalMs,
+              maxRetries);
+    } catch (Throwable cause) {
+      throw new RssException("registerShuffle failed!", cause);
+    }
+    return assignedPartitionToServers;
+  }
+
+  /**
+   * Resolves the remote storage for this application, using the configured remote storage
+   * path (empty string if unset) as the default handed to {@code fetchRemoteStorage}.
+   *
+   * @return the remote storage info returned by {@code ClientUtils.fetchRemoteStorage}
+   */
+  private RemoteStorageInfo getRemoteStorageInfo() {
+    String configuredStorageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
+    String configuredPath = sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "");
+    RemoteStorageInfo fallbackStorage = new RemoteStorageInfo(configuredPath);
+    return ClientUtils.fetchRemoteStorage(
+        id.get(), fallbackStorage, dynamicConfEnabled, configuredStorageType, shuffleWriteClient);
+  }
+
public boolean isRssResubmitStage() {
return rssResubmitStage;
}
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 612f4a2..395f082 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -20,8 +20,8 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -65,8 +65,12 @@
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.impl.TrackingBlockStatus;
+import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
+import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.common.ClientType;
@@ -77,6 +81,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
@@ -97,9 +102,9 @@
private final long sendCheckTimeout;
private final long sendCheckInterval;
private final int bitmapSplitNum;
- private final Map<Integer, Set<Long>> partitionToBlockIds;
+ // server -> partitionId -> blockIds
+ private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds;
private final ShuffleWriteClient shuffleWriteClient;
- private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private final Set<ShuffleServerInfo> shuffleServersForData;
private final long[] partitionLengths;
private final boolean isMemoryShuffleEnabled;
@@ -107,6 +112,7 @@
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
private SparkConf sparkConf;
+ private boolean taskFailRetry;
/** used by columnar rss shuffle writer implementation */
protected final long taskAttemptId;
@@ -173,17 +179,20 @@
this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
- this.partitionToBlockIds = Maps.newHashMap();
+ this.serverToPartitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
this.partitionLengths = new long[partitioner.numPartitions()];
Arrays.fill(partitionLengths, 0);
- partitionToServers = shuffleHandleInfo.getPartitionToServers();
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
this.taskContext = context;
this.sparkConf = sparkConf;
+ this.taskFailRetry =
+ sparkConf.getBoolean(
+ RssClientConf.RSS_TASK_FAILED_RETRY_ENABLED.key(),
+ RssClientConf.RSS_TASK_FAILED_RETRY_ENABLED.defaultValue());
}
public RssShuffleWriter(
@@ -318,7 +327,14 @@
blockIds.add(blockId);
// update [partition, blockIds], it will be sent to shuffle server
int partitionId = sbi.getPartitionId();
- partitionToBlockIds.computeIfAbsent(partitionId, k -> Sets.newHashSet()).add(blockId);
+ sbi.getShuffleServerInfos()
+ .forEach(
+ shuffleServerInfo -> {
+ Map<Integer, Set<Long>> pToBlockIds =
+ serverToPartitionToBlockIds.computeIfAbsent(
+ shuffleServerInfo, k -> Maps.newHashMap());
+ pToBlockIds.computeIfAbsent(partitionId, v -> Sets.newHashSet()).add(blockId);
+ });
partitionLengths[partitionId] += sbi.getLength();
});
return postBlockEvent(shuffleBlockInfoList);
@@ -390,9 +406,16 @@
}
private void checkIfBlocksFailed() {
- Map<Long, BlockingQueue<ShuffleServerInfo>> failedBlockIdsWithShuffleServer =
- shuffleManager.getFailedBlockIdsWithShuffleServer(taskId);
- Set<Long> failedBlockIds = failedBlockIdsWithShuffleServer.keySet();
+ Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
+ if (taskFailRetry && !failedBlockIds.isEmpty()) {
+ Set<TrackingBlockStatus> shouldResendBlockSet = shouldResendBlockStatusSet(failedBlockIds);
+ try {
+ reSendFailedBlockIds(shouldResendBlockSet);
+ } catch (Exception e) {
+ LOG.error("resend failed blocks failed.", e);
+ }
+ failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
+ }
if (!failedBlockIds.isEmpty()) {
String errorMsg =
"Send failed: Task["
@@ -401,14 +424,107 @@
+ " failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server: "
- + failedBlockIdsWithShuffleServer.values().stream()
- .flatMap(Collection::stream)
- .collect(Collectors.toSet());
+ + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
LOG.error(errorMsg);
throw new RssSendFailedException(errorMsg);
}
}
+  /**
+   * Collects, for the given failed block ids, the tracked failure statuses that qualify for a
+   * resend. Currently only failures with status {@code NO_BUFFER} qualify.
+   *
+   * @param failedBlockIds ids of the blocks that failed to be sent
+   * @return the failure statuses that should be resent
+   */
+  private Set<TrackingBlockStatus> shouldResendBlockStatusSet(Set<Long> failedBlockIds) {
+    FailedBlockSendTracker tracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
+    Set<TrackingBlockStatus> statusesToResend = Sets.newHashSet();
+    for (Long blockId : failedBlockIds) {
+      for (TrackingBlockStatus status : tracker.getFailedBlockStatus(blockId)) {
+        // todo: more statuses need to trigger a reassignment
+        if (status.getStatusCode() == StatusCode.NO_BUFFER) {
+          statusesToResend.add(status);
+        }
+      }
+    }
+    return statusesToResend;
+  }
+
+  /**
+   * Re-sends blocks whose delivery failed, routing each block to the replacement of the
+   * faulty shuffle server it failed on.
+   *
+   * <p>The failed statuses are grouped by faulty server. For each group the replacement server
+   * is taken from the shuffle manager's reassigned-server map, or requested through {@code
+   * reAssignFaultyShuffleServer} and stored in that map when none is known yet. Afterwards the
+   * bookkeeping of the failed blocks is cleared, the re-targeted copies are sent and the send
+   * result is checked again.
+   *
+   * @param failedBlockStatusSet failure statuses selected for resending
+   */
+  private void reSendFailedBlockIds(Set<TrackingBlockStatus> failedBlockStatusSet) {
+    List<ShuffleBlockInfo> reAssignSeverBlockInfoList = Lists.newArrayList();
+    List<ShuffleBlockInfo> failedBlockInfoList = Lists.newArrayList();
+    Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions =
+        failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));
+    Map<String, ShuffleServerInfo> faultyServers = shuffleManager.getReassignedFaultyServers();
+    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
+        faultyServerToPartitions.entrySet()) {
+      String faultyServerId = entry.getKey().getId();
+      List<TrackingBlockStatus> blockStatusesOfServer = entry.getValue();
+      Set<String> partitionIds =
+          blockStatusesOfServer.stream()
+              .map(x -> String.valueOf(x.getShuffleBlockInfo().getPartitionId()))
+              .collect(Collectors.toSet());
+      ShuffleServerInfo dynamicShuffleServer = faultyServers.get(faultyServerId);
+      if (dynamicShuffleServer == null) {
+        dynamicShuffleServer = reAssignFaultyShuffleServer(partitionIds, faultyServerId);
+        faultyServers.put(faultyServerId, dynamicShuffleServer);
+      }
+
+      // Only the blocks that failed on THIS faulty server go to its replacement. Iterating the
+      // whole failedBlockStatusSet here would duplicate every block once per faulty server and
+      // send it to every replacement.
+      for (TrackingBlockStatus trackingBlockStatus : blockStatusesOfServer) {
+        ShuffleBlockInfo failedBlockInfo = trackingBlockStatus.getShuffleBlockInfo();
+        failedBlockInfoList.add(failedBlockInfo);
+        reAssignSeverBlockInfoList.add(
+            new ShuffleBlockInfo(
+                failedBlockInfo.getShuffleId(),
+                failedBlockInfo.getPartitionId(),
+                failedBlockInfo.getBlockId(),
+                failedBlockInfo.getLength(),
+                failedBlockInfo.getCrc(),
+                failedBlockInfo.getData(),
+                Lists.newArrayList(dynamicShuffleServer),
+                failedBlockInfo.getUncompressLength(),
+                failedBlockInfo.getFreeMemory(),
+                taskAttemptId));
+      }
+    }
+    clearFailedBlockIdsStates(failedBlockInfoList, faultyServers);
+    processShuffleBlockInfos(reAssignSeverBlockInfoList);
+    checkIfBlocksFailed();
+  }
+
+  /**
+   * Removes the bookkeeping of blocks that are about to be resent: the entry in the failed
+   * block tracker, the block id recorded under each of its faulty servers, and the block's
+   * contribution to the partition length.
+   *
+   * @param failedBlockInfoList the blocks whose failed state is cleared
+   * @param faultyServers faulty server id to replacement server; only its keys are consulted
+   */
+  private void clearFailedBlockIdsStates(
+      List<ShuffleBlockInfo> failedBlockInfoList, Map<String, ShuffleServerInfo> faultyServers) {
+    for (ShuffleBlockInfo failedBlock : failedBlockInfoList) {
+      shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(failedBlock.getBlockId());
+      for (ShuffleServerInfo server : failedBlock.getShuffleServerInfos()) {
+        if (!faultyServers.containsKey(server.getId())) {
+          continue;
+        }
+        serverToPartitionToBlockIds
+            .get(server)
+            .get(failedBlock.getPartitionId())
+            .remove(failedBlock.getBlockId());
+      }
+      partitionLengths[failedBlock.getPartitionId()] -= failedBlock.getLength();
+    }
+  }
+
+  /**
+   * Asks the shuffle manager, reached at the configured driver host and shuffle manager gRPC
+   * port, to assign a replacement for a faulty shuffle server.
+   *
+   * @param partitionIds partition ids affected by the faulty server
+   * @param faultyServerId id of the faulty shuffle server
+   * @return the shuffle server carried by a successful response
+   * @throws RssException if the response status is not {@code SUCCESS} or the call fails
+   */
+  private ShuffleServerInfo reAssignFaultyShuffleServer(
+      Set<String> partitionIds, String faultyServerId) {
+    RssConf conf = RssSparkConfig.toRssConf(sparkConf);
+    String driverHost = conf.getString("driver.host", "");
+    int managerPort = conf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
+    try (ShuffleManagerClient client = createShuffleManagerClient(driverHost, managerPort)) {
+      RssReassignFaultyShuffleServerRequest reassignRequest =
+          new RssReassignFaultyShuffleServerRequest(shuffleId, partitionIds, faultyServerId);
+      RssReassignFaultyShuffleServerResponse reply =
+          client.reassignFaultyShuffleServer(reassignRequest);
+      if (reply.getStatusCode() == StatusCode.SUCCESS) {
+        return reply.getShuffleServer();
+      }
+      throw new RssException(
+          "reassign server response with statusCode[" + reply.getStatusCode() + "]");
+    } catch (Exception e) {
+      throw new RssException(
+          "Failed to reassign a new server for faultyServerId server[" + faultyServerId + "]", e);
+    }
+  }
+
@VisibleForTesting
protected void sendCommit() {
ExecutorService executor = Executors.newSingleThreadExecutor();
@@ -454,13 +570,9 @@
public Option<MapStatus> stop(boolean success) {
try {
if (success) {
- Map<Integer, List<Long>> ptb = Maps.newHashMap();
- for (Map.Entry<Integer, Set<Long>> entry : partitionToBlockIds.entrySet()) {
- ptb.put(entry.getKey(), Lists.newArrayList(entry.getValue()));
- }
long start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(
- partitionToServers, appId, shuffleId, taskAttemptId, ptb, bitmapSplitNum);
+ serverToPartitionToBlockIds, appId, shuffleId, taskAttemptId, bitmapSplitNum);
LOG.info(
"Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
taskAttemptId,
@@ -492,7 +604,17 @@
@VisibleForTesting
Map<Integer, Set<Long>> getPartitionToBlockIds() {
- return partitionToBlockIds;
+ // Collapses server -> partitionId -> blockIds into partitionId -> blockIds.
+ // A partition that appears under one server only maps to that server's own Set instance
+ // (no copy); the merge function runs only for partitions found under several servers and
+ // returns a new HashSet holding the union.
+ return serverToPartitionToBlockIds.values().stream()
+ .flatMap(s -> s.entrySet().stream())
+ .collect(
+ Collectors.toMap(
+ Map.Entry::getKey,
+ Map.Entry::getValue,
+ (existingSet, newSet) -> {
+ Set<Long> mergedSet = new HashSet<>(existingSet);
+ mergedSet.addAll(newSet);
+ return mergedSet;
+ }));
}
@VisibleForTesting
@@ -511,13 +633,10 @@
private void throwFetchFailedIfNecessary(Exception e) {
// The shuffleServer is registered only when a Block fails to be sent
if (e instanceof RssSendFailedException) {
- Map<Long, BlockingQueue<ShuffleServerInfo>> failedBlockIds =
- shuffleManager.getFailedBlockIdsWithShuffleServer(taskId);
- List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList();
- for (Map.Entry<Long, BlockingQueue<ShuffleServerInfo>> longListEntry :
- failedBlockIds.entrySet()) {
- shuffleServerInfos.addAll(longListEntry.getValue());
- }
+ FailedBlockSendTracker blockIdsFailedSendTracker =
+ shuffleManager.getBlockIdsFailedSendTracker(taskId);
+ List<ShuffleServerInfo> shuffleServerInfos =
+ Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers());
RssReportShuffleWriteFailureRequest req =
new RssReportShuffleWriteFailureRequest(
appId,
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
index 2312424..cb2a7f9 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
@@ -19,13 +19,13 @@
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import org.apache.commons.lang3.SystemUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.writer.DataPusher;
-import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.common.ShuffleBlockInfo;
public class TestUtils {
@@ -36,13 +36,16 @@
Boolean isDriver,
DataPusher dataPusher,
Map<String, Set<Long>> successBlockIds,
- Map<String, Set<Long>> failBlockIds,
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
return new RssShuffleManager(
- conf, isDriver, dataPusher, successBlockIds, failBlockIds, taskToFailedBlockIdsAndServer);
+ conf, isDriver, dataPusher, successBlockIds, taskToFailedBlockSendTracker);
}
public static boolean isMacOnAppleSilicon() {
return SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64");
}
+
+  /**
+   * Creates a {@link ShuffleBlockInfo} carrying the given block id. Every other constructor
+   * argument is a fixed placeholder literal, the data is a one-byte array and the server list is
+   * {@code null}.
+   */
+  public static ShuffleBlockInfo createMockBlockOnlyBlockId(long blockId) {
+    final byte[] oneBytePayload = new byte[1];
+    return new ShuffleBlockInfo(1, 1, blockId, 1, 1, oneBytePayload, null, 1, 100, 1);
+  }
}
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 65c4b4f..b68d4b7 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -26,7 +26,6 @@
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -55,8 +54,10 @@
import org.junit.jupiter.api.Test;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
@@ -84,14 +85,14 @@
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name())
.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
- Map<String, Set<Long>> failBlocks = JavaUtils.newConcurrentMap();
Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
JavaUtils.newConcurrentMap();
Serializer kryoSerializer = new KryoSerializer(conf);
RssShuffleManager manager =
TestUtils.createShuffleManager(
- conf, false, null, successBlocks, failBlocks, taskToFailedBlockIdsAndServer);
+ conf, false, null, successBlocks, taskToFailedBlockSendTracker);
ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
Partitioner mockPartitioner = mock(Partitioner.class);
@@ -151,21 +152,19 @@
// case 3: partial blocks are sent failed, Runtime exception will be thrown
successBlocks.put("taskId", Sets.newHashSet(1L, 2L));
- failBlocks.put("taskId", Sets.newHashSet(3L));
- Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdToShuffleServerInfoMap =
- JavaUtils.newConcurrentMap();
- BlockingQueue blockingQueue = new LinkedBlockingQueue<>();
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ taskToFailedBlockSendTracker.put("taskId", failedBlockSendTracker);
ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo("127.0.0.1", 20001);
- blockingQueue.add(shuffleServerInfo);
- blockIdToShuffleServerInfoMap.put(3L, blockingQueue);
- taskToFailedBlockIdsAndServer.put("taskId", blockIdToShuffleServerInfoMap);
+ failedBlockSendTracker.add(
+ TestUtils.createMockBlockOnlyBlockId(3L), shuffleServerInfo, StatusCode.INTERNAL_ERROR);
Throwable e3 =
assertThrows(
RuntimeException.class,
() -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
+ System.out.println(e2.getMessage());
assertTrue(e3.getMessage().startsWith("Send failed:"));
successBlocks.clear();
- failBlocks.clear();
+ taskToFailedBlockSendTracker.clear();
}
static class FakedDataPusher extends DataPusher {
@@ -179,7 +178,7 @@
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
- Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
+ Map<String, FailedBlockSendTracker> failedBlockSendTracker,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime,
@@ -187,8 +186,7 @@
super(
shuffleWriteClient,
taskToSuccessBlockIds,
- taskToFailedBlockIds,
- taskToFailedBlockIdsAndServer,
+ failedBlockSendTracker,
failedTaskIds,
threadPoolSize,
threadKeepAliveTime);
@@ -237,16 +235,13 @@
final RssShuffleManager manager =
TestUtils.createShuffleManager(
- conf,
- false,
- dataPusher,
- successBlockIds,
- JavaUtils.newConcurrentMap(),
- JavaUtils.newConcurrentMap());
+ conf, false, dataPusher, successBlockIds, JavaUtils.newConcurrentMap());
WriteBufferManagerTest.FakedTaskMemoryManager fakedTaskMemoryManager =
new WriteBufferManagerTest.FakedTaskMemoryManager();
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newConcurrentMap();
+ partitionToServers.put(0, Lists.newArrayList(new ShuffleServerInfo("127.0.0.1", 1111)));
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
@@ -254,7 +249,7 @@
0,
bufferOptions,
new KryoSerializer(conf),
- Maps.newHashMap(),
+ partitionToServers,
fakedTaskMemoryManager,
new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf),
@@ -338,12 +333,7 @@
final RssShuffleManager manager =
TestUtils.createShuffleManager(
- conf,
- false,
- dataPusher,
- successBlockIds,
- JavaUtils.newConcurrentMap(),
- JavaUtils.newConcurrentMap());
+ conf, false, dataPusher, successBlockIds, JavaUtils.newConcurrentMap());
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
@@ -453,6 +443,7 @@
}
}
Map<Integer, Set<Long>> partitionToBlockIds = rssShuffleWriterSpy.getPartitionToBlockIds();
+ System.out.println(11111);
assertEquals(2, partitionToBlockIds.get(1).size());
assertEquals(2, partitionToBlockIds.get(0).size());
assertEquals(2, partitionToBlockIds.get(2).size());
@@ -503,7 +494,6 @@
false,
dataPusher,
Maps.newConcurrentMap(),
- Maps.newConcurrentMap(),
JavaUtils.newConcurrentMap()));
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
index 1c83124..ec83a09 100644
--- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
+++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java
@@ -86,7 +86,9 @@
private final Codec codec;
private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private final Set<Long> allBlockIds = Sets.newConcurrentHashSet();
- private final Map<Integer, List<Long>> partitionToBlocks = Maps.newConcurrentMap();
+ // server -> partitionId -> blockIds
+ private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
+ Maps.newConcurrentMap();
private final int numMaps;
private final boolean isMemoryShuffleEnabled;
private final long sendCheckInterval;
@@ -246,10 +248,17 @@
buffer.clear();
shuffleBlocks.add(block);
allBlockIds.add(block.getBlockId());
- if (!partitionToBlocks.containsKey(block.getPartitionId())) {
- partitionToBlocks.putIfAbsent(block.getPartitionId(), Lists.newArrayList());
- }
- partitionToBlocks.get(block.getPartitionId()).add(block.getBlockId());
+ block
+ .getShuffleServerInfos()
+ .forEach(
+ shuffleServerInfo -> {
+ Map<Integer, Set<Long>> pToBlockIds =
+ serverToPartitionToBlockIds.computeIfAbsent(
+ shuffleServerInfo, k -> Maps.newHashMap());
+ pToBlockIds
+ .computeIfAbsent(block.getPartitionId(), v -> Sets.newHashSet())
+ .add(block.getBlockId());
+ });
}
private void sendShuffleBlocks(List<ShuffleBlockInfo> shuffleBlocks) {
@@ -322,7 +331,7 @@
LOG.info(
"tezVertexID is {}, tezDAGID is {}, shuffleId is {}", tezVertexID, tezDAGID, shuffleId);
shuffleWriteClient.reportShuffleResult(
- partitionToServers, appId, shuffleId, taskAttemptId, partitionToBlocks, bitmapSplitNum);
+ serverToPartitionToBlockIds, appId, shuffleId, taskAttemptId, bitmapSplitNum);
LOG.info(
"Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
taskAttemptId,
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index bef29fa..d8d435b 100644
--- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -20,7 +20,6 @@
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -29,6 +28,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
@@ -55,6 +55,7 @@
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
@@ -64,11 +65,13 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
public class WriteBufferManagerTest {
@Test
@@ -152,7 +155,7 @@
true,
mapOutputByteCounter,
mapOutputRecordCounter);
-
+ partitionToServers.put(1, Lists.newArrayList(mock(ShuffleServerInfo.class)));
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -264,6 +267,7 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
bufferManager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
@@ -378,6 +382,7 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
bufferManager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
bufferManager.waitSendFinished();
@@ -485,6 +490,8 @@
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
+ partitionToServers.put(
+ partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
bufferManager.addRecord(
partitionId, new BytesWritable(key), new BytesWritable(value));
}
@@ -542,7 +549,12 @@
if (mode == 0) {
throw new RssException("send data failed.");
} else if (mode == 1) {
- return new SendShuffleDataResult(Sets.newHashSet(2L), Sets.newHashSet(1L));
+ FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
+ ShuffleBlockInfo failedBlock =
+ new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
+ failedBlockSendTracker.add(
+ failedBlock, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
+ return new SendShuffleDataResult(Sets.newHashSet(2L), failedBlockSendTracker);
} else {
if (mode == 3) {
try {
@@ -557,7 +569,7 @@
for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
successBlockIds.add(blockInfo.getBlockId());
}
- return new SendShuffleDataResult(successBlockIds, Collections.EMPTY_SET);
+ return new SendShuffleDataResult(successBlockIds, new FailedBlockSendTracker());
}
}
@@ -605,17 +617,21 @@
@Override
public void reportShuffleResult(
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
String appId,
int shuffleId,
long taskAttemptId,
- Map<Integer, List<Long>> partitionToBlockIds,
int bitmapNum) {
if (mode == 3) {
- mockedShuffleServer.addFinishedBlockInfos(
- partitionToBlockIds.values().stream()
- .flatMap(it -> it.stream())
- .collect(Collectors.toList()));
+ serverToPartitionToBlockIds
+ .values()
+ .forEach(
+ partitionToBlockIds -> {
+ mockedShuffleServer.addFinishedBlockInfos(
+ partitionToBlockIds.values().stream()
+ .flatMap(it -> it.stream())
+ .collect(Collectors.toList()));
+ });
}
}
diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index b28320c..88d97c3 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -62,11 +62,10 @@
RemoteStorageInfo fetchRemoteStorage(String appId);
void reportShuffleResult(
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
String appId,
int shuffleId,
long taskAttemptId,
- Map<Integer, List<Long>> partitionToBlockIds,
int bitmapNum);
default ShuffleAssignmentsInfo getShuffleAssignments(
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
new file mode 100644
index 0000000..0c239c7
--- /dev/null
+++ b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+/**
+ * Keeps, per block id, one {@link TrackingBlockStatus} for every recorded failed send of that
+ * block to a shuffle server.
+ *
+ * <p>The outer map is a concurrent map and each per-block list is a copy-on-write list, so
+ * {@link #add} may be called from several sender threads for the same block id (for example
+ * when one block is sent to several servers) while other threads read the lists.
+ */
+public class FailedBlockSendTracker {
+
+  // blockId -> statuses of the failed sends recorded for that block
+  private final Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap;
+
+  public FailedBlockSendTracker() {
+    this.trackingBlockStatusMap = Maps.newConcurrentMap();
+  }
+
+  /**
+   * Records that sending {@code shuffleBlockInfo} to {@code shuffleServerInfo} failed.
+   *
+   * @param shuffleBlockInfo the block that could not be sent; its block id is the tracking key
+   * @param shuffleServerInfo the server the send was addressed to
+   * @param statusCode the status recorded for the failure
+   */
+  public void add(
+      ShuffleBlockInfo shuffleBlockInfo,
+      ShuffleServerInfo shuffleServerInfo,
+      StatusCode statusCode) {
+    statusesOf(shuffleBlockInfo.getBlockId())
+        .add(new TrackingBlockStatus(shuffleBlockInfo, shuffleServerInfo, statusCode));
+  }
+
+  /**
+   * Appends every status recorded in {@code failedBlockSendTracker} to this tracker.
+   *
+   * <p>Statuses already recorded here for the same block id are kept, and the lists of the other
+   * tracker are copied rather than shared, so later changes to one tracker do not leak into the
+   * other. Merging a tracker into itself is a no-op.
+   *
+   * @param failedBlockSendTracker the tracker whose statuses are copied into this one
+   */
+  public void merge(FailedBlockSendTracker failedBlockSendTracker) {
+    if (failedBlockSendTracker == this) {
+      return;
+    }
+    failedBlockSendTracker.trackingBlockStatusMap.forEach(
+        (blockId, statuses) -> statusesOf(blockId).addAll(statuses));
+  }
+
+  /** Forgets every status recorded for {@code blockId}. */
+  public void remove(long blockId) {
+    trackingBlockStatusMap.remove(blockId);
+  }
+
+  /** Forgets all recorded statuses. */
+  public void clear() {
+    trackingBlockStatusMap.clear();
+  }
+
+  /** Returns the key set of the underlying map, i.e. a live view of the tracked block ids. */
+  public Set<Long> getFailedBlockIds() {
+    return trackingBlockStatusMap.keySet();
+  }
+
+  /**
+   * Returns the statuses recorded for {@code blockId}, or {@code null} when the block id is not
+   * tracked.
+   */
+  public List<TrackingBlockStatus> getFailedBlockStatus(Long blockId) {
+    return trackingBlockStatusMap.get(blockId);
+  }
+
+  /** Returns the distinct servers that appear in any recorded status. */
+  public Set<ShuffleServerInfo> getFaultyShuffleServers() {
+    return trackingBlockStatusMap.values().stream()
+        .flatMap(Collection::stream)
+        .map(TrackingBlockStatus::getShuffleServerInfo)
+        .collect(Collectors.toSet());
+  }
+
+  // Returns the status list of blockId, creating a thread-safe list on first use.
+  private List<TrackingBlockStatus> statusesOf(long blockId) {
+    return trackingBlockStatusMap.computeIfAbsent(blockId, id -> Lists.newCopyOnWriteArrayList());
+  }
+}
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 337869d..7f40085 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -20,16 +20,16 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
-import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
+import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.google.common.annotations.VisibleForTesting;
@@ -154,7 +154,7 @@
Map<ShuffleServerInfo, Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>> serverToBlocks,
Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
Map<Long, AtomicInteger> blockIdsSendSuccessTracker,
- Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsSendFailTracker,
+ FailedBlockSendTracker failedBlockSendTracker,
boolean allowFastFail,
Supplier<Boolean> needCancelRequest) {
@@ -205,13 +205,8 @@
LOG.debug("{} successfully.", logMsg);
}
} else {
- serverToBlockIds
- .get(ssi)
- .forEach(
- blockId ->
- blockIdsSendFailTracker
- .computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>())
- .add(ssi));
+ recordFailedBlocks(
+ failedBlockSendTracker, serverToBlocks, ssi, response.getStatusCode());
if (defectiveServers != null) {
defectiveServers.add(ssi);
}
@@ -219,13 +214,8 @@
return false;
}
} catch (Exception e) {
- serverToBlockIds
- .get(ssi)
- .forEach(
- blockId ->
- blockIdsSendFailTracker
- .computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>())
- .add(ssi));
+ recordFailedBlocks(
+ failedBlockSendTracker, serverToBlocks, ssi, StatusCode.INTERNAL_ERROR);
if (defectiveServers != null) {
defectiveServers.add(ssi);
}
@@ -254,6 +244,17 @@
return result;
}
+  /**
+   * Registers every block that was queued for {@code shuffleServerInfo} as failed with the given
+   * status code. Nothing is recorded when {@code serverToBlocks} has no entry for the server.
+   *
+   * @param blockIdsSendFailTracker tracker that receives one entry per block
+   * @param serverToBlocks blocks grouped by server, then by two nested integer keys
+   * @param shuffleServerInfo the server whose send failed
+   * @param statusCode status to record for each of its blocks
+   */
+  void recordFailedBlocks(
+      FailedBlockSendTracker blockIdsSendFailTracker,
+      Map<ShuffleServerInfo, Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>> serverToBlocks,
+      ShuffleServerInfo shuffleServerInfo,
+      StatusCode statusCode) {
+    Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> blocksOfServer =
+        serverToBlocks.getOrDefault(shuffleServerInfo, Collections.emptyMap());
+    for (Map<Integer, List<ShuffleBlockInfo>> innerMap : blocksOfServer.values()) {
+      for (List<ShuffleBlockInfo> blocks : innerMap.values()) {
+        for (ShuffleBlockInfo block : blocks) {
+          blockIdsSendFailTracker.add(block, shuffleServerInfo, statusCode);
+        }
+      }
+    }
+  }
+
void genServerToBlocks(
ShuffleBlockInfo sbi,
List<ShuffleServerInfo> serverList,
@@ -374,8 +375,7 @@
block ->
blockIdsSendSuccessTracker.computeIfAbsent(
block, id -> new AtomicInteger(0))));
- Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsSendFailTracker =
- JavaUtils.newConcurrentMap();
+ FailedBlockSendTracker blockIdsSendFailTracker = new FailedBlockSendTracker();
// sent the primary round of blocks.
boolean isAllSuccess =
@@ -419,8 +419,7 @@
blockIdsSendFailTracker.remove(successBlockId.getKey());
}
});
- return new SendShuffleDataResult(
- blockIdsSendSuccessSet, blockIdsSendFailTracker.keySet(), blockIdsSendFailTracker);
+ return new SendShuffleDataResult(blockIdsSendSuccessSet, blockIdsSendFailTracker);
}
/**
@@ -675,41 +674,27 @@
@Override
public void reportShuffleResult(
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
String appId,
int shuffleId,
long taskAttemptId,
- Map<Integer, List<Long>> partitionToBlockIds,
int bitmapNum) {
- Map<ShuffleServerInfo, List<Integer>> groupedPartitions = Maps.newHashMap();
- Map<Integer, Integer> partitionReportTracker = Maps.newHashMap();
- for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : partitionToServers.entrySet()) {
- int partitionIdx = entry.getKey();
- for (ShuffleServerInfo ssi : entry.getValue()) {
- if (!groupedPartitions.containsKey(ssi)) {
- groupedPartitions.put(ssi, Lists.newArrayList());
- }
- groupedPartitions.get(ssi).add(partitionIdx);
- }
- if (CollectionUtils.isNotEmpty(partitionToBlockIds.get(partitionIdx))) {
- partitionReportTracker.putIfAbsent(partitionIdx, 0);
- }
- }
-
- for (Map.Entry<ShuffleServerInfo, List<Integer>> entry : groupedPartitions.entrySet()) {
- Map<Integer, List<Long>> requestBlockIds = Maps.newHashMap();
- for (Integer partitionId : entry.getValue()) {
- List<Long> blockIds = partitionToBlockIds.get(partitionId);
- if (CollectionUtils.isNotEmpty(blockIds)) {
- requestBlockIds.put(partitionId, blockIds);
- }
- }
+ // record blockId count for quorum check, but this is not a good implementation.
+ Map<Long, Integer> blockReportTracker = createBlockReportTracker(serverToPartitionToBlockIds);
+ for (Map.Entry<ShuffleServerInfo, Map<Integer, Set<Long>>> entry :
+ serverToPartitionToBlockIds.entrySet()) {
+ Map<Integer, Set<Long>> requestBlockIds = entry.getValue();
if (requestBlockIds.isEmpty()) {
continue;
}
RssReportShuffleResultRequest request =
new RssReportShuffleResultRequest(
- appId, shuffleId, taskAttemptId, requestBlockIds, bitmapNum);
+ appId,
+ shuffleId,
+ taskAttemptId,
+ requestBlockIds.entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, e -> new ArrayList<>(e.getValue()))),
+ bitmapNum);
ShuffleServerInfo ssi = entry.getKey();
try {
RssReportShuffleResultResponse response =
@@ -723,9 +708,6 @@
+ "], shuffleId["
+ shuffleId
+ "] successfully");
- for (Integer partitionId : requestBlockIds.keySet()) {
- partitionReportTracker.put(partitionId, partitionReportTracker.get(partitionId) + 1);
- }
} else {
LOG.warn(
"Report shuffle result to "
@@ -736,6 +718,7 @@
+ shuffleId
+ "] failed with "
+ response.getStatusCode());
+ recordFailedBlockIds(blockReportTracker, requestBlockIds);
}
} catch (Exception e) {
LOG.warn(
@@ -746,19 +729,37 @@
+ "], shuffleId["
+ shuffleId
+ "]");
+ recordFailedBlockIds(blockReportTracker, requestBlockIds);
}
}
- // quorum check
- for (Map.Entry<Integer, Integer> entry : partitionReportTracker.entrySet()) {
- if (entry.getValue() < replicaWrite) {
- throw new RssException(
- "Quorum check of report shuffle result is failed for appId["
- + appId
- + "], shuffleId["
- + shuffleId
- + "]");
+ if (blockReportTracker.values().stream().anyMatch(cnt -> cnt < replicaWrite)) {
+ throw new RssException(
+ "Quorum check of report shuffle result is failed for appId["
+ + appId
+ + "], shuffleId["
+ + shuffleId
+ + "]");
+ }
+ }
+
+  /**
+   * Lowers the tracked count of every block id in {@code requestBlockIds} by one. A block id
+   * that is not in the tracker yet is inserted with the value -1.
+   *
+   * @param blockReportTracker blockId -> count, updated in place
+   * @param requestBlockIds partitionId -> block ids of the report that failed
+   */
+  private void recordFailedBlockIds(
+      Map<Long, Integer> blockReportTracker, Map<Integer, Set<Long>> requestBlockIds) {
+    for (Set<Long> partitionBlockIds : requestBlockIds.values()) {
+      for (Long blockId : partitionBlockIds) {
+        blockReportTracker.merge(blockId, -1, Integer::sum);
+      }
+    }
+  }
+
+ // Builds blockId -> number of times the id occurs across all servers' partition sets.
+ // reportShuffleResult lowers these counts via recordFailedBlockIds when a report fails and
+ // then compares them with replicaWrite.
+ private Map<Long, Integer> createBlockReportTracker(
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds) {
+ Map<Long, Integer> blockIdCount = new HashMap<>();
+ for (Map<Integer, Set<Long>> partitionToBlockIds : serverToPartitionToBlockIds.values()) {
+ for (Set<Long> blockIds : partitionToBlockIds.values()) {
+ for (Long blockId : blockIds) {
+ blockIdCount.put(blockId, blockIdCount.getOrDefault(blockId, 0) + 1);
+ }
}
}
+ return blockIdCount;
}
@Override
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/TrackingBlockStatus.java b/client/src/main/java/org/apache/uniffle/client/impl/TrackingBlockStatus.java
new file mode 100644
index 0000000..18dbeb9
--- /dev/null
+++ b/client/src/main/java/org/apache/uniffle/client/impl/TrackingBlockStatus.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+/**
+ * Groups a block, a shuffle server and a status code. {@link FailedBlockSendTracker} stores one
+ * instance per recorded failed send.
+ *
+ * <p>The three references are assigned once in the constructor and never change afterwards.
+ */
+public class TrackingBlockStatus {
+  private final ShuffleBlockInfo shuffleBlockInfo;
+  private final ShuffleServerInfo shuffleServerInfo;
+  private final StatusCode statusCode;
+
+  /**
+   * @param shuffleBlockInfo the block this status refers to
+   * @param shuffleServerInfo the server this status refers to
+   * @param statusCode the status code recorded for the block on that server
+   */
+  public TrackingBlockStatus(
+      ShuffleBlockInfo shuffleBlockInfo,
+      ShuffleServerInfo shuffleServerInfo,
+      StatusCode statusCode) {
+    this.shuffleBlockInfo = shuffleBlockInfo;
+    this.shuffleServerInfo = shuffleServerInfo;
+    this.statusCode = statusCode;
+  }
+
+  /** Returns the block passed to the constructor. */
+  public ShuffleBlockInfo getShuffleBlockInfo() {
+    return shuffleBlockInfo;
+  }
+
+  /** Returns the server passed to the constructor. */
+  public ShuffleServerInfo getShuffleServerInfo() {
+    return shuffleServerInfo;
+  }
+
+  /** Returns the status code passed to the constructor. */
+  public StatusCode getStatusCode() {
+    return statusCode;
+  }
+}
diff --git a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
index f2d820e..595de29 100644
--- a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
+++ b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
@@ -17,32 +17,19 @@
package org.apache.uniffle.client.response;
-import java.util.Map;
import java.util.Set;
-import java.util.concurrent.BlockingQueue;
-import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
public class SendShuffleDataResult {
private Set<Long> successBlockIds;
- private Set<Long> failedBlockIds;
- private Map<Long, BlockingQueue<ShuffleServerInfo>> sendFailedBlockIds;
-
- public SendShuffleDataResult(Set<Long> successBlockIds, Set<Long> failedBlockIds) {
- this.successBlockIds = successBlockIds;
- this.failedBlockIds = failedBlockIds;
- this.sendFailedBlockIds = JavaUtils.newConcurrentMap();
- }
+ private FailedBlockSendTracker failedBlockSendTracker;
public SendShuffleDataResult(
- Set<Long> successBlockIds,
- Set<Long> failedBlockIds,
- Map<Long, BlockingQueue<ShuffleServerInfo>> sendFailedBlockIds) {
+ Set<Long> successBlockIds, FailedBlockSendTracker failedBlockSendTracker) {
this.successBlockIds = successBlockIds;
- this.failedBlockIds = failedBlockIds;
- this.sendFailedBlockIds = sendFailedBlockIds;
+ this.failedBlockSendTracker = failedBlockSendTracker;
}
public Set<Long> getSuccessBlockIds() {
@@ -50,10 +37,10 @@
}
public Set<Long> getFailedBlockIds() {
- return failedBlockIds;
+ // The ids are no longer stored separately; they are read from the tracker on each call.
+ // NOTE(review): the tracker is dereferenced without a null check, so a result built with a
+ // null tracker throws NullPointerException here.
+ return failedBlockSendTracker.getFailedBlockIds();
}
- public Map<Long, BlockingQueue<ShuffleServerInfo>> getSendFailedBlockIds() {
- return sendFailedBlockIds;
+ public FailedBlockSendTracker getFailedBlockSendTracker() {
+ return failedBlockSendTracker;
}
}
diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 5b9a6fb..d14b764 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -112,7 +112,7 @@
return shuffleServerInfo;
}
- private static RssProtos.ShuffleServerId convertToShuffleServerId(
+ public static RssProtos.ShuffleServerId convertToShuffleServerId(
ShuffleServerInfo shuffleServerInfo) {
RssProtos.ShuffleServerId shuffleServerId =
RssProtos.ShuffleServerId.newBuilder()
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 3e61c61..006c9ab 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -171,4 +171,11 @@
.withDescription(
"This option is only valid when the remote storage path is specified. If ture, "
+ "the remote storage conf will use the client side hadoop configuration loaded from the classpath.");
+
+ /**
+ * Boolean client option {@code rss.task.failed.retry.enabled}; defaults to {@code false}. Per its
+ * description it toggles internal retry of failed task writes.
+ */
+ public static final ConfigOption<Boolean> RSS_TASK_FAILED_RETRY_ENABLED =
+ ConfigOptions.key("rss.task.failed.retry.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+ "Whether to support task write failed retry internal, default value is false.");
}
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
index 03acbc7..5b0a719 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
@@ -20,6 +20,7 @@
import java.io.File;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@@ -363,13 +364,13 @@
enableTimeout((MockedShuffleServer) shuffleServers.get(2), 500);
// report result should success
- Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
- partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator()));
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(
- 0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 0, 0L, partitionToBlockIds, 1);
+ Map<Integer, Set<Long>> partitionToBlockIds = Maps.newHashMap();
+ partitionToBlockIds.put(0, Sets.newHashSet(blockIdBitmap.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo0, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlockIds);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0L, 1);
Roaring64NavigableMap report =
shuffleWriteClientImpl.getShuffleResult(
"GRPC",
@@ -441,14 +442,14 @@
assertEquals(0, result.getSuccessBlockIds().size());
// report result should fail
- Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator()));
- partitionToServers.put(
- 0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
+ Map<Integer, Set<Long>> partitionToBlockIds = Maps.newHashMap();
+ partitionToBlockIds.put(0, Sets.newHashSet(blockIdBitmap.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo0, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlockIds);
try {
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 0, 0L, partitionToBlockIds, 1);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0L, 1);
fail(EXPECTED_EXCEPTION_MESSAGE);
} catch (Exception e) {
assertTrue(e.getMessage().startsWith("Quorum check of report shuffle result is failed"));
@@ -502,13 +503,14 @@
assertEquals(blockIdBitmap, succBlockIdBitmap);
assertEquals(0, failedBlockIdBitmap.getLongCardinality());
- Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
- partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator()));
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(
- 0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 0, 0L, partitionToBlockIds, 1);
+ Map<Integer, Set<Long>> partitionToBlockIds = Maps.newHashMap();
+ partitionToBlockIds.put(0, Sets.newHashSet(blockIdBitmap.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo0, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlockIds);
+
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0L, 1);
Roaring64NavigableMap report =
shuffleWriteClientImpl.getShuffleResult(
@@ -603,13 +605,13 @@
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
// report result should success
- Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
- partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator()));
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(
- 0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 0, 0L, partitionToBlockIds, 1);
+ Map<Integer, Set<Long>> partitionToBlockIds = Maps.newHashMap();
+ partitionToBlockIds.put(0, Sets.newHashSet(blockIdBitmap.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo0, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlockIds);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0L, 1);
Roaring64NavigableMap report =
shuffleWriteClientImpl.getShuffleResult(
"GRPC",
@@ -703,23 +705,21 @@
enableTimeout((MockedShuffleServer) shuffleServers.get(3), 500);
enableTimeout((MockedShuffleServer) shuffleServers.get(4), 500);
- Map<Integer, List<Long>> partitionToBlockIds = Maps.newHashMap();
- partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap0.stream().iterator()));
- partitionToBlockIds.put(1, Lists.newArrayList(blockIdBitmap1.stream().iterator()));
- partitionToBlockIds.put(2, Lists.newArrayList(blockIdBitmap2.stream().iterator()));
+ Map<Integer, Set<Long>> partitionToBlockIds = Maps.newHashMap();
+ partitionToBlockIds.put(0, Sets.newHashSet(blockIdBitmap0.stream().iterator()));
+ partitionToBlockIds.put(1, Sets.newHashSet(blockIdBitmap1.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo0, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlockIds);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(
- 0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
- partitionToServers.put(
- 1, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
- partitionToServers.put(
- 2, Lists.newArrayList(shuffleServerInfo2, shuffleServerInfo3, shuffleServerInfo4));
-
+ Map<Integer, Set<Long>> partitionToBlockIds2 = Maps.newHashMap();
+ partitionToBlockIds2.put(2, Sets.newHashSet(blockIdBitmap2.stream().iterator()));
+ serverToPartitionToBlockIds.put(shuffleServerInfo3, partitionToBlockIds2);
+ serverToPartitionToBlockIds.put(shuffleServerInfo4, partitionToBlockIds2);
// report result should fail because partition2 is failed to report server 3,4
try {
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 0, 0L, partitionToBlockIds, 1);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0L, 1);
fail(EXPECTED_EXCEPTION_MESSAGE);
} catch (Exception e) {
assertTrue(e.getMessage().startsWith("Quorum check of report shuffle result is failed"));
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index eee83cb..1442b21 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -20,6 +20,7 @@
import java.io.File;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Lists;
@@ -163,11 +164,12 @@
assertFalse(commitResult);
// Report will success when replica=2
- Map<Integer, List<Long>> ptb = Maps.newHashMap();
- ptb.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator()));
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(0, Lists.newArrayList(shuffleServerInfo1, fakeShuffleServerInfo));
- shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0, ptb, 2);
+ Map<Integer, Set<Long>> ptb = Maps.newHashMap();
+ ptb.put(0, Sets.newHashSet(blockIdBitmap.stream().iterator()));
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, ptb);
+ serverToPartitionToBlockIds.put(fakeShuffleServerInfo, ptb);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 0, 0, 2);
Roaring64NavigableMap report =
shuffleWriteClientImpl.getShuffleResult(
"GRPC", Sets.newHashSet(shuffleServerInfo1, fakeShuffleServerInfo), testAppId, 0, 0);
@@ -196,21 +198,17 @@
ShuffleDataDistributionType.NORMAL,
-1);
- Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
- partitionToServers.put(1, Lists.newArrayList(shuffleServerInfo1));
- partitionToServers.put(2, Lists.newArrayList(shuffleServerInfo2));
- Map<Integer, List<Long>> partitionToBlocks = Maps.newHashMap();
- List<Long> blockIds = Lists.newArrayList();
-
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ Map<Integer, Set<Long>> partitionToBlocks = Maps.newHashMap();
+ Set<Long> blockIds = Sets.newHashSet();
int partitionIdx = 1;
for (int i = 0; i < 5; i++) {
blockIds.add(ClientUtils.getBlockId(partitionIdx, 0, i));
}
partitionToBlocks.put(partitionIdx, blockIds);
-
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlocks);
// case1
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 1, 0, partitionToBlocks, 1);
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 1, 0, 1);
Roaring64NavigableMap bitmap =
shuffleWriteClientImpl.getShuffleResult(
"GRPC", Sets.newHashSet(shuffleServerInfo1), testAppId, 1, 0);
@@ -220,8 +218,8 @@
shuffleWriteClientImpl.getShuffleResult(
"GRPC", Sets.newHashSet(shuffleServerInfo1), testAppId, 1, partitionIdx);
assertEquals(5, bitmap.getLongCardinality());
- for (int i = 0; i < 5; i++) {
- assertTrue(bitmap.contains(partitionToBlocks.get(1).get(i)));
+ for (Long b : partitionToBlocks.get(1)) {
+ assertTrue(bitmap.contains(b));
}
}
@@ -250,19 +248,24 @@
Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
partitionToServers.putIfAbsent(1, Lists.newArrayList(shuffleServerInfo1));
partitionToServers.putIfAbsent(2, Lists.newArrayList(shuffleServerInfo2));
- Map<Integer, List<Long>> partitionToBlocks = Maps.newHashMap();
- List<Long> blockIds = Lists.newArrayList();
+ Map<Integer, Set<Long>> partitionToBlocks1 = Maps.newHashMap();
+ Set<Long> blockIds = Sets.newHashSet();
for (int i = 0; i < 5; i++) {
blockIds.add(ClientUtils.getBlockId(1, 0, i));
}
- partitionToBlocks.put(1, blockIds);
- blockIds = Lists.newArrayList();
+ partitionToBlocks1.put(1, blockIds);
+ Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = Maps.newHashMap();
+ serverToPartitionToBlockIds.put(shuffleServerInfo1, partitionToBlocks1);
+
+ Map<Integer, Set<Long>> partitionToBlocks2 = Maps.newHashMap();
+ blockIds = Sets.newHashSet();
for (int i = 0; i < 7; i++) {
blockIds.add(ClientUtils.getBlockId(2, 0, i));
}
- partitionToBlocks.put(2, blockIds);
- shuffleWriteClientImpl.reportShuffleResult(
- partitionToServers, testAppId, 1, 0, partitionToBlocks, 1);
+ partitionToBlocks2.put(2, blockIds);
+ serverToPartitionToBlockIds.put(shuffleServerInfo2, partitionToBlocks2);
+
+ shuffleWriteClientImpl.reportShuffleResult(serverToPartitionToBlockIds, testAppId, 1, 0, 1);
Roaring64NavigableMap bitmap =
shuffleWriteClientImpl.getShuffleResult(
@@ -273,8 +276,8 @@
shuffleWriteClientImpl.getShuffleResult(
"GRPC", Sets.newHashSet(shuffleServerInfo1), testAppId, 1, 1);
assertEquals(5, bitmap.getLongCardinality());
- for (int i = 0; i < 5; i++) {
- assertTrue(bitmap.contains(partitionToBlocks.get(1).get(i)));
+ for (Long b : partitionToBlocks1.get(1)) {
+ assertTrue(bitmap.contains(b));
}
bitmap =
@@ -296,8 +299,8 @@
shuffleWriteClientImpl.getShuffleResult(
"GRPC", Sets.newHashSet(shuffleServerInfo2), testAppId, 1, 2);
assertEquals(7, bitmap.getLongCardinality());
- for (int i = 0; i < 7; i++) {
- assertTrue(bitmap.contains(partitionToBlocks.get(2).get(i)));
+ for (Long b : partitionToBlocks2.get(2)) {
+ assertTrue(bitmap.contains(b));
}
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index ddba67b..77506a7 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -20,10 +20,12 @@
import java.io.Closeable;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
+import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
@@ -45,4 +47,7 @@
RssReportShuffleWriteFailureRequest req);
RssReassignServersReponse reassignShuffleServers(RssReassignServersRequest req);
+
+ /**
+ * Requests reassignment for a faulty shuffle server.
+ *
+ * @param request carries the shuffle id, the partition ids and the faulty shuffle server id
+ * @return the response, which exposes a status code, a message and a shuffle server
+ * (presumably the replacement server; confirm against the shuffle manager implementation)
+ */
+ RssReassignFaultyShuffleServerResponse reassignFaultyShuffleServer(
+ RssReassignFaultyShuffleServerRequest request);
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index b38113c..128f26d 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -24,10 +24,12 @@
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
+import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse;
+import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
@@ -115,4 +117,14 @@
getBlockingStub().reassignShuffleServers(reassignServersRequest);
return RssReassignServersReponse.fromProto(reassignServersReponse);
}
+
+  /**
+   * Converts the request to its protobuf form, invokes the {@code reassignFaultyShuffleServer} RPC
+   * on the blocking stub, and converts the protobuf reply back into a client response.
+   */
+  @Override
+  public RssReassignFaultyShuffleServerResponse reassignFaultyShuffleServer(
+      RssReassignFaultyShuffleServerRequest request) {
+    // Build the protobuf message before touching the stub, so the call order stays the same.
+    RssProtos.RssReassignFaultyShuffleServerRequest protoRequest = request.toProto();
+    return RssReassignFaultyShuffleServerResponse.fromProto(
+        getBlockingStub().reassignFaultyShuffleServer(protoRequest));
+  }
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 7297aec..1dee0ee 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
@@ -437,6 +438,7 @@
request.getShuffleIdToBlocks();
boolean isSuccessful = true;
+ AtomicReference<StatusCode> failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR);
// prepare rpc request based on shuffleId -> partitionId -> blocks
for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb :
@@ -523,6 +525,7 @@
+ response.getStatus()
+ ", errorMsg:"
+ response.getRetMsg();
+ failedStatusCode.set(StatusCode.fromCode(response.getStatus().getNumber()));
if (response.getStatus() == RssProtos.StatusCode.NO_REGISTER) {
throw new NotRetryException(msg);
} else {
@@ -546,7 +549,7 @@
if (isSuccessful) {
response = new RssSendShuffleDataResponse(StatusCode.SUCCESS);
} else {
- response = new RssSendShuffleDataResponse(StatusCode.INTERNAL_ERROR);
+ response = new RssSendShuffleDataResponse(failedStatusCode.get());
}
return response;
}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java
new file mode 100644
index 0000000..ac96a9e
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignFaultyShuffleServerRequest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.request;
+
+import java.util.Set;
+
+import org.apache.uniffle.proto.RssProtos;
+
+/**
+ * Client-side request for the {@code reassignFaultyShuffleServer} RPC. It carries a shuffle id, a
+ * set of partition ids (as strings) and the id of the faulty shuffle server, and can be converted
+ * to its protobuf form with {@link #toProto()}.
+ */
+public class RssReassignFaultyShuffleServerRequest {
+
+  private int shuffleId;
+  private Set<String> partitionIds;
+  private String faultyShuffleServerId;
+
+  /**
+   * Stores the given values as-is; the partition id set is not copied.
+   *
+   * @param shuffleId the shuffle id
+   * @param partitionIds the partition ids, as strings
+   * @param faultyShuffleServerId the id of the faulty shuffle server
+   */
+  public RssReassignFaultyShuffleServerRequest(
+      int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
+    this.shuffleId = shuffleId;
+    this.partitionIds = partitionIds;
+    this.faultyShuffleServerId = faultyShuffleServerId;
+  }
+
+  /** Returns the shuffle id. */
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  /** Returns the partition id set passed to the constructor. */
+  public Set<String> getPartitionIds() {
+    return partitionIds;
+  }
+
+  /** Returns the id of the faulty shuffle server. */
+  public String getFaultyShuffleServerId() {
+    return faultyShuffleServerId;
+  }
+
+  /** Builds the protobuf message from the shuffle id, the faulty server id and the partition ids. */
+  public RssProtos.RssReassignFaultyShuffleServerRequest toProto() {
+    return RssProtos.RssReassignFaultyShuffleServerRequest.newBuilder()
+        .setShuffleId(shuffleId)
+        .setFaultyShuffleServerId(faultyShuffleServerId)
+        .addAllPartitionIds(partitionIds)
+        .build();
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java
new file mode 100644
index 0000000..4c3b7c4
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssReassignFaultyShuffleServerResponse.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.response;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.proto.RssProtos;
+
+/**
+ * Client-side response of the {@code reassignFaultyShuffleServer} RPC: the status code and message
+ * passed to {@link ClientResponse}, plus a {@link ShuffleServerInfo} built from the protobuf
+ * reply.
+ */
+public class RssReassignFaultyShuffleServerResponse extends ClientResponse {
+
+  private ShuffleServerInfo shuffleServer;
+
+  /**
+   * @param statusCode the status code handed to the superclass
+   * @param message the message handed to the superclass
+   * @param shuffleServer the shuffle server carried by this response
+   */
+  public RssReassignFaultyShuffleServerResponse(
+      StatusCode statusCode, String message, ShuffleServerInfo shuffleServer) {
+    super(statusCode, message);
+    this.shuffleServer = shuffleServer;
+  }
+
+  /** Returns the shuffle server carried by this response. */
+  public ShuffleServerInfo getShuffleServer() {
+    return shuffleServer;
+  }
+
+  /**
+   * Converts the protobuf reply into a client response. The status is mapped by enum name, and the
+   * server info is built from the id, ip, port and netty port of the reply's {@code server} field.
+   */
+  public static RssReassignFaultyShuffleServerResponse fromProto(
+      RssProtos.RssReassignFaultyShuffleServerResponse response) {
+    StatusCode statusCode = StatusCode.valueOf(response.getStatus().name());
+    String message = response.getMsg();
+    RssProtos.ShuffleServerId server = response.getServer();
+    ShuffleServerInfo serverInfo =
+        new ShuffleServerInfo(
+            server.getId(), server.getIp(), server.getPort(), server.getNettyPort());
+    return new RssReassignFaultyShuffleServerResponse(statusCode, message, serverInfo);
+  }
+}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index aab38ef..932c15e 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -515,6 +515,8 @@
rpc reportShuffleWriteFailure (ReportShuffleWriteFailureRequest) returns (ReportShuffleWriteFailureResponse);
// Reassign the RPC interface of the ShuffleServer list
rpc reassignShuffleServers(ReassignServersRequest) returns (ReassignServersReponse);
+ // RPC interface for reassigning a new server to replace a faulty server
+ rpc reassignFaultyShuffleServer(RssReassignFaultyShuffleServerRequest) returns (RssReassignFaultyShuffleServerResponse);
}
message ReportShuffleFetchFailureRequest {
@@ -577,3 +579,17 @@
bool needReassign = 2;
string msg = 3;
}
+
+// Request message of the reassignFaultyShuffleServer rpc: the shuffle id, the partition ids
+// (as strings) and the id of the faulty shuffle server.
+message RssReassignFaultyShuffleServerRequest{
+ int32 shuffleId = 1;
+ repeated string partitionIds = 2;
+ string faultyShuffleServerId = 3;
+}
+
+// Response message of the reassignFaultyShuffleServer rpc: a status code, a server and a message.
+message RssReassignFaultyShuffleServerResponse{
+ StatusCode status = 1;
+ // Presumably the newly assigned server; confirm against the shuffle manager implementation.
+ ShuffleServerId server = 2;
+ string msg = 3;
+}
+
+