blob: 5aaf23a718a1fb19d74541e9423ec4ffaba01234 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.uniffle.shuffle.manager;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc.ShuffleManagerImplBase;
import org.apache.uniffle.shuffle.BlockIdManager;
public class ShuffleManagerGrpcService extends ShuffleManagerImplBase {
private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);
// Fetch failure bookkeeping, keyed by shuffle id. Entries are created lazily by
// reportShuffleFetchFailure and removed by unregisterShuffle.
private final Map<Integer, RssShuffleStatus> shuffleStatus = JavaUtils.newConcurrentMap();
// Write failure bookkeeping, keyed by shuffle id: each record counts the write failures reported
// against every shuffle server. Entries are created lazily by reportShuffleWriteFailure.
private final Map<Integer, ShuffleServerFailureRecord> shuffleWrtieStatus =
JavaUtils.newConcurrentMap();
// The shuffle manager this service answers for; the handlers use it to validate app ids and
// delegate the actual shuffle operations to it.
private final RssShuffleManagerInterface shuffleManager;
/**
 * Creates the gRPC service on top of the given shuffle manager.
 *
 * @param shuffleManager the shuffle manager that the RPC handlers delegate to
 */
public ShuffleManagerGrpcService(RssShuffleManagerInterface shuffleManager) {
this.shuffleManager = shuffleManager;
}
/**
 * Handles a report that writing shuffle data to one or more shuffle servers failed.
 *
 * <p>The report is rejected with {@code INVALID_REQUEST} when it carries a different app id than
 * the shuffle manager's, or when it refers to a stage attempt older than the one already
 * recorded for the shuffle. Otherwise the per-server failure counters of the shuffle are updated
 * and the reply tells the caller whether the whole stage should be resubmitted.
 *
 * @param request the write failure report (app id, shuffle id, stage attempt, failed servers)
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void reportShuffleWriteFailure(
    RssProtos.ReportShuffleWriteFailureRequest request,
    StreamObserver<RssProtos.ReportShuffleWriteFailureResponse> responseObserver) {
  final String reportedAppId = request.getAppId();
  final int shuffleId = request.getShuffleId();
  final int reportedAttempt = request.getStageAttemptNumber();
  final List<RssProtos.ShuffleServerId> failedServerIds = request.getShuffleServerIdsList();

  // Start from the "rejected" outcome; only an accepted report overrides these two.
  RssProtos.StatusCode status = RssProtos.StatusCode.INVALID_REQUEST;
  boolean resubmit = false;
  String message;

  if (reportedAppId.equals(shuffleManager.getAppId())) {
    List<ShuffleServerInfo> failedServers = ShuffleServerInfo.fromProto(failedServerIds);
    // Zeroed counters used only if this is the first report seen for the shuffle.
    Map<String, AtomicInteger> initialCounters = JavaUtils.newConcurrentMap();
    for (ShuffleServerInfo server : failedServers) {
      initialCounters.put(server.getId(), new AtomicInteger(0));
    }
    ShuffleServerFailureRecord record =
        shuffleWrtieStatus.computeIfAbsent(
            shuffleId, id -> new ShuffleServerFailureRecord(initialCounters, reportedAttempt));

    if (record.resetStageAttemptIfNecessary(reportedAttempt)) {
      // The record already belongs to a newer stage attempt than the reported one.
      message =
          String.format(
              "got an old stage(%d vs %d) shuffle write failure report, which should be impossible.",
              record.getStageAttempt(), reportedAttempt);
      LOG.warn(message);
    } else {
      status = RssProtos.StatusCode.SUCCESS;
      // Bump the per-server counters; the result says whether the stage must be resubmitted.
      resubmit = record.incPartitionWriteFailure(reportedAttempt, failedServers, shuffleManager);
      if (resubmit) {
        message =
            String.format(
                "report shuffle write failure as maximum number(%d) of shuffle write is occurred",
                shuffleManager.getMaxFetchFailures());
      } else {
        message = "don't report shuffle write failure";
      }
    }
  } else {
    message =
        String.format(
            "got a wrong shuffle write failure report from appId: %s, expected appId: %s",
            reportedAppId, shuffleManager.getAppId());
    LOG.warn(message);
  }

  responseObserver.onNext(
      RssProtos.ReportShuffleWriteFailureResponse.newBuilder()
          .setStatus(status)
          .setReSubmitWholeStage(resubmit)
          .setMsg(message)
          .build());
  responseObserver.onCompleted();
}
/**
 * Handles a report that fetching one partition of a shuffle failed.
 *
 * <p>The report is rejected with {@code INVALID_REQUEST} when it carries a different app id than
 * the shuffle manager's, or when it refers to a stage attempt older than the one already
 * recorded for the shuffle. Otherwise the partition's failure counter is incremented and the
 * reply asks for a whole-stage resubmission once the counter reaches
 * {@code shuffleManager.getMaxFetchFailures()}.
 *
 * @param request the fetch failure report (app id, shuffle id, stage attempt, partition id)
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void reportShuffleFetchFailure(
    RssProtos.ReportShuffleFetchFailureRequest request,
    StreamObserver<RssProtos.ReportShuffleFetchFailureResponse> responseObserver) {
  final String reportedAppId = request.getAppId();
  final int reportedAttempt = request.getStageAttemptId();
  final int partition = request.getPartitionId();

  // Start from the "rejected" outcome; only an accepted report overrides these two.
  RssProtos.StatusCode status = RssProtos.StatusCode.INVALID_REQUEST;
  boolean resubmit = false;
  String message;

  if (reportedAppId.equals(shuffleManager.getAppId())) {
    // Lazily create the status of this shuffle, sized by its partition count.
    RssShuffleStatus shuffleState =
        shuffleStatus.computeIfAbsent(
            request.getShuffleId(),
            id -> new RssShuffleStatus(shuffleManager.getPartitionNum(id), reportedAttempt));

    if (shuffleState.resetStageAttemptIfNecessary(reportedAttempt) < 0) {
      // The recorded stage attempt is newer than the reported one.
      message =
          String.format(
              "got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.",
              shuffleState.getStageAttempt(), reportedAttempt);
      LOG.warn(message);
    } else {
      status = RssProtos.StatusCode.SUCCESS;
      shuffleState.incPartitionFetchFailure(reportedAttempt, partition);
      int failures = shuffleState.getPartitionFetchFailureNum(reportedAttempt, partition);
      resubmit = failures >= shuffleManager.getMaxFetchFailures();
      if (resubmit) {
        message =
            String.format(
                "report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred",
                shuffleManager.getMaxFetchFailures());
      } else {
        message = "don't report shuffle fetch failure";
      }
    }
  } else {
    message =
        String.format(
            "got a wrong shuffle fetch failure report from appId: %s, expected appId: %s",
            reportedAppId, shuffleManager.getAppId());
    LOG.warn(message);
  }

  responseObserver.onNext(
      RssProtos.ReportShuffleFetchFailureResponse.newBuilder()
          .setStatus(status)
          .setReSubmitWholeStage(resubmit)
          .setMsg(message)
          .build());
  responseObserver.onCompleted();
}
/**
 * Returns the shuffle handle info (partition to shuffle server assignment) of a shuffle.
 *
 * <p>Replies with {@code SUCCESS} and the serialized handle when the shuffle manager knows the
 * shuffle id, and with {@code INVALID_REQUEST} and no handle otherwise.
 *
 * @param request carries the shuffle id to look up
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void getPartitionToShufflerServer(
    RssProtos.PartitionToShuffleServerRequest request,
    StreamObserver<RssProtos.PartitionToShuffleServerResponse> responseObserver) {
  MutableShuffleHandleInfo handleInfo =
      (MutableShuffleHandleInfo)
          shuffleManager.getShuffleHandleInfoByShuffleId(request.getShuffleId());

  RssProtos.PartitionToShuffleServerResponse.Builder response =
      RssProtos.PartitionToShuffleServerResponse.newBuilder();
  if (handleInfo == null) {
    // Unknown shuffle id: answer with a status only.
    response.setStatus(RssProtos.StatusCode.INVALID_REQUEST);
  } else {
    response
        .setStatus(RssProtos.StatusCode.SUCCESS)
        .setShuffleHandleInfo(MutableShuffleHandleInfo.toProto(handleInfo));
  }

  responseObserver.onNext(response.build());
  responseObserver.onCompleted();
}
/**
 * Asks the shuffle manager to reassign all shuffle servers for a whole stage and reports back
 * whether a reassignment is needed.
 *
 * <p>The reply status is always {@code SUCCESS}; the outcome is carried in
 * {@code needReassign}.
 *
 * @param request carries the stage id, stage attempt number, shuffle id and partition count
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void reassignShuffleServers(
    RssProtos.ReassignServersRequest request,
    StreamObserver<RssProtos.ReassignServersReponse> responseObserver) {
  boolean reassignNeeded =
      shuffleManager.reassignAllShuffleServersForWholeStage(
          request.getStageId(),
          request.getStageAttemptNumber(),
          request.getShuffleId(),
          request.getNumPartitions());

  responseObserver.onNext(
      RssProtos.ReassignServersReponse.newBuilder()
          .setStatus(RssProtos.StatusCode.SUCCESS)
          .setNeedReassign(reassignNeeded)
          .build());
  responseObserver.onCompleted();
}
/**
 * Reassigns shuffle servers for the partitions whose blocks could not be sent, and returns the
 * updated shuffle handle.
 *
 * <p>Replies with {@code SUCCESS} and the serialized handle when the shuffle manager completes
 * the reassignment. Any exception raised on the way is logged and turned into an
 * {@code INTERNAL_ERROR} reply whose message describes the exception.
 *
 * @param request carries the shuffle id and the map of failed partitions to the servers that
 *     failed to receive their blocks
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void reassignOnBlockSendFailure(
    org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureRequest request,
    io.grpc.stub.StreamObserver<
            org.apache.uniffle.proto.RssProtos.RssReassignOnBlockSendFailureResponse>
        responseObserver) {
  RssProtos.StatusCode code = RssProtos.StatusCode.INTERNAL_ERROR;
  RssProtos.RssReassignOnBlockSendFailureResponse reply;
  try {
    MutableShuffleHandleInfo handle =
        shuffleManager.reassignOnBlockSendFailure(
            request.getShuffleId(),
            request.getFailurePartitionToServerIdsMap().entrySet().stream()
                .collect(
                    Collectors.toMap(
                        Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue()))));
    code = RssProtos.StatusCode.SUCCESS;
    reply =
        RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
            .setStatus(code)
            .setHandle(MutableShuffleHandleInfo.toProto(handle))
            .build();
  } catch (Exception e) {
    LOG.error("Errors on reassigning when block send failure.", e);
    // Exceptions such as NullPointerException may carry no message, and the protobuf string
    // setter rejects null. Passing null here would throw from inside this catch block and
    // leave the RPC without any response, so fall back to the exception's string form.
    String errorMsg = e.getMessage() != null ? e.getMessage() : e.toString();
    reply =
        RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder()
            // Build the failure reply with an explicit status instead of reusing `code`, which
            // may already have been switched to SUCCESS before the exception was thrown.
            .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
            .setMsg(errorMsg)
            .build();
  }
  responseObserver.onNext(reply);
  responseObserver.onCompleted();
}
/**
 * Remove the state kept for a shuffle id that is no longer used: both its fetch failure status
 * and its write failure record. This is called when ShuffleManager unregisters the corresponding
 * shuffle id.
 *
 * @param shuffleId the shuffle id to unregister.
 */
public void unregisterShuffle(int shuffleId) {
  shuffleStatus.remove(shuffleId);
  // The write failure record is keyed by the same shuffle id; without removing it here the
  // entry would stay in the map for the lifetime of the service.
  shuffleWrtieStatus.remove(shuffleId);
}
/**
 * Write failure bookkeeping for a single shuffle: counts, per shuffle server id, how many write
 * failures were reported during the current stage attempt. All reads and writes of the stage
 * attempt number and the updates of the counters go through a read/write lock.
 */
private static class ShuffleServerFailureRecord {
  private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
  private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock();
  private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock();
  // shuffle server id -> number of write failures reported in the current stage attempt
  private final Map<String, AtomicInteger> shuffleServerFailureRecordCount;
  private int stageAttemptNumber;

  private ShuffleServerFailureRecord(
      Map<String, AtomicInteger> shuffleServerFailureRecordCount, int stageAttemptNumber) {
    this.shuffleServerFailureRecordCount = shuffleServerFailureRecordCount;
    this.stageAttemptNumber = stageAttemptNumber;
  }

  /** Runs {@code fn} while holding the read lock and returns its result. */
  private <T> T withReadLock(Supplier<T> fn) {
    readLock.lock();
    try {
      return fn.get();
    } finally {
      readLock.unlock();
    }
  }

  /** Runs {@code fn} while holding the write lock and returns its result. */
  private <T> T withWriteLock(Supplier<T> fn) {
    writeLock.lock();
    try {
      return fn.get();
    } finally {
      writeLock.unlock();
    }
  }

  /** Returns the stage attempt number this record currently tracks. */
  public int getStageAttempt() {
    return withReadLock(() -> this.stageAttemptNumber);
  }

  /**
   * Compares the incoming stage attempt with the recorded one. A newer attempt clears all
   * counters and becomes the recorded attempt.
   *
   * @param stageAttemptNumber the incoming stage attempt number
   * @return {@code true} only when the incoming attempt is older than the recorded one, i.e. the
   *     report is stale; {@code false} when it is the same or newer
   */
  public boolean resetStageAttemptIfNecessary(int stageAttemptNumber) {
    return withWriteLock(
        () -> {
          if (this.stageAttemptNumber < stageAttemptNumber) {
            // a new stage attempt is issued. Record the shuffleServer status of the Map should be
            // clear and reset.
            shuffleServerFailureRecordCount.clear();
            this.stageAttemptNumber = stageAttemptNumber;
            return false;
          } else if (this.stageAttemptNumber > stageAttemptNumber) {
            return true;
          }
          return false;
        });
  }

  /**
   * Increments the failure counter of every given shuffle server and checks whether the server
   * with the most failures has exceeded {@code shuffleManager.getMaxFetchFailures()}. If so,
   * that server is handed to {@code shuffleManager.addFailuresShuffleServerInfos}.
   *
   * @param stageAttemptNumber the stage attempt the failures belong to; nothing is counted when
   *     it differs from the recorded attempt
   * @param shuffleServerInfos the shuffle servers that failed to receive the write
   * @param shuffleManager supplies the failure threshold and receives the failing server id
   * @return {@code true} when the most-failing server exceeded the threshold, {@code false}
   *     otherwise
   */
  public boolean incPartitionWriteFailure(
      int stageAttemptNumber,
      List<ShuffleServerInfo> shuffleServerInfos,
      RssShuffleManagerInterface shuffleManager) {
    return withWriteLock(
        () -> {
          if (this.stageAttemptNumber != stageAttemptNumber) {
            // do nothing here
            return false;
          }
          shuffleServerInfos.forEach(
              shuffleServerInfo -> {
                shuffleServerFailureRecordCount
                    .computeIfAbsent(shuffleServerInfo.getId(), k -> new AtomicInteger())
                    .incrementAndGet();
              });
          List<Map.Entry<String, AtomicInteger>> list =
              new ArrayList<>(shuffleServerFailureRecordCount.entrySet());
          if (!list.isEmpty()) {
            // Pick the server with the MOST failures. The threshold must be checked against
            // the worst server; checking the least-failing one would only fire once every
            // server had exceeded the limit. Integer.compare avoids subtraction overflow.
            Map.Entry<String, AtomicInteger> shuffleServerInfoIntegerEntry =
                Collections.max(
                    list,
                    (o1, o2) -> Integer.compare(o1.getValue().get(), o2.getValue().get()));
            if (shuffleServerInfoIntegerEntry.getValue().get()
                > shuffleManager.getMaxFetchFailures()) {
              shuffleManager.addFailuresShuffleServerInfos(
                  shuffleServerInfoIntegerEntry.getKey());
              return true;
            }
          }
          return false;
        });
  }
}
private static class RssShuffleStatus {
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock();
private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock();
private final int[] partitions;
private int stageAttempt;
private RssShuffleStatus(int partitionNum, int stageAttempt) {
this.stageAttempt = stageAttempt;
this.partitions = new int[partitionNum];
}
private <T> T withReadLock(Supplier<T> fn) {
readLock.lock();
try {
return fn.get();
} finally {
readLock.unlock();
}
}
private <T> T withWriteLock(Supplier<T> fn) {
writeLock.lock();
try {
return fn.get();
} finally {
writeLock.unlock();
}
}
// todo: maybe it's more performant to just use synchronized method here.
public int getStageAttempt() {
return withReadLock(() -> this.stageAttempt);
}
/**
* Check whether the input stage attempt is a new stage or not. If a new stage attempt is
* requested, reset partitions.
*
* @param stageAttempt the incoming stage attempt number
* @return 0 if stageAttempt == this.stageAttempt 1 if stageAttempt > this.stageAttempt -1 if
* stateAttempt < this.stageAttempt which means nothing happens
*/
public int resetStageAttemptIfNecessary(int stageAttempt) {
return withWriteLock(
() -> {
if (this.stageAttempt < stageAttempt) {
// a new stage attempt is issued. the partitions array should be clear and reset.
Arrays.fill(this.partitions, 0);
this.stageAttempt = stageAttempt;
return 1;
} else if (this.stageAttempt > stageAttempt) {
return -1;
}
return 0;
});
}
public void incPartitionFetchFailure(int stageAttempt, int partition) {
withWriteLock(
() -> {
if (this.stageAttempt != stageAttempt) {
// do nothing here
} else {
this.partitions[partition] = this.partitions[partition] + 1;
}
return null;
});
}
public int getPartitionFetchFailureNum(int stageAttempt, int partition) {
return withReadLock(
() -> {
if (this.stageAttempt != stageAttempt) {
return 0;
} else {
return this.partitions[partition];
}
});
}
}
/**
 * Returns the serialized block id bitmap of one partition of a shuffle.
 *
 * <p>Replies with {@code ACCESS_DENIED} when the request's app id differs from the shuffle
 * manager's, with {@code SUCCESS} and the serialized bitmap when serialization works, and with
 * {@code INTERNAL_ERROR} when serialization throws.
 *
 * @param request carries the app id, shuffle id and partition id
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void getShuffleResult(
    RssProtos.GetShuffleResultRequest request,
    StreamObserver<RssProtos.GetShuffleResultResponse> responseObserver) {
  final String requestAppId = request.getAppId();
  if (!requestAppId.equals(shuffleManager.getAppId())) {
    // Requests from another application are refused outright.
    responseObserver.onNext(
        RssProtos.GetShuffleResultResponse.newBuilder()
            .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
            .setRetMsg("Illegal appId: " + requestAppId)
            .build());
    responseObserver.onCompleted();
    return;
  }

  final int shuffleId = request.getShuffleId();
  final int partitionId = request.getPartitionId();
  Roaring64NavigableMap bitmap =
      shuffleManager.getBlockIdManager().get(shuffleId, partitionId);

  RssProtos.GetShuffleResultResponse.Builder response =
      RssProtos.GetShuffleResultResponse.newBuilder();
  try {
    byte[] bitmapBytes = RssUtils.serializeBitMap(bitmap);
    response
        .setStatus(RssProtos.StatusCode.SUCCESS)
        .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(bitmapBytes));
  } catch (Exception exception) {
    LOG.error("Errors on getting the blockId bitmap.", exception);
    // Discard anything set before the failure and answer with the error status only.
    response =
        RssProtos.GetShuffleResultResponse.newBuilder()
            .setStatus(RssProtos.StatusCode.INTERNAL_ERROR);
  }

  responseObserver.onNext(response.build());
  responseObserver.onCompleted();
}
/**
 * Returns one serialized bitmap holding the block ids of several partitions of a shuffle.
 *
 * <p>Replies with {@code ACCESS_DENIED} when the request's app id differs from the shuffle
 * manager's, with {@code SUCCESS} and the serialized merged bitmap when serialization works,
 * and with {@code INTERNAL_ERROR} when serialization throws.
 *
 * @param request carries the app id, shuffle id and the list of partition ids
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void getShuffleResultForMultiPart(
    RssProtos.GetShuffleResultForMultiPartRequest request,
    StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
  final String requestAppId = request.getAppId();
  if (!requestAppId.equals(shuffleManager.getAppId())) {
    // Requests from another application are refused outright.
    responseObserver.onNext(
        RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
            .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
            .setRetMsg("Illegal appId: " + requestAppId)
            .build());
    responseObserver.onCompleted();
    return;
  }

  final BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
  final int shuffleId = request.getShuffleId();

  // Merge the block ids of every requested partition into a single bitmap.
  final Roaring64NavigableMap merged = Roaring64NavigableMap.bitmapOf();
  for (int partitionId : request.getPartitionsList()) {
    Roaring64NavigableMap partitionBlockIds = blockIdManager.get(shuffleId, partitionId);
    partitionBlockIds.forEach(blockId -> merged.add(blockId));
  }

  RssProtos.GetShuffleResultForMultiPartResponse.Builder response =
      RssProtos.GetShuffleResultForMultiPartResponse.newBuilder();
  try {
    byte[] bitmapBytes = RssUtils.serializeBitMap(merged);
    response
        .setStatus(RssProtos.StatusCode.SUCCESS)
        .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(bitmapBytes));
  } catch (Exception exception) {
    LOG.error("Errors on getting the blockId bitmap.", exception);
    // Discard anything set before the failure and answer with the error status only.
    response =
        RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
            .setStatus(RssProtos.StatusCode.INTERNAL_ERROR);
  }

  responseObserver.onNext(response.build());
  responseObserver.onCompleted();
}
/**
 * Records the block ids reported for the partitions of a shuffle.
 *
 * <p>Replies with {@code ACCESS_DENIED} when the request's app id differs from the shuffle
 * manager's. Otherwise every reported partition's block ids are added to the block id manager
 * and the reply status is {@code SUCCESS}.
 *
 * @param request carries the app id, shuffle id and the block ids grouped by partition
 * @param responseObserver receives exactly one response and is then completed
 */
@Override
public void reportShuffleResult(
    RssProtos.ReportShuffleResultRequest request,
    StreamObserver<RssProtos.ReportShuffleResultResponse> responseObserver) {
  final String requestAppId = request.getAppId();
  final RssProtos.ReportShuffleResultResponse.Builder response =
      RssProtos.ReportShuffleResultResponse.newBuilder();

  if (requestAppId.equals(shuffleManager.getAppId())) {
    final BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
    final int shuffleId = request.getShuffleId();
    // Hand each partition's block ids to the manager, in the order they were reported.
    for (RssProtos.PartitionToBlockIds entry : request.getPartitionToBlockIdsList()) {
      blockIdManager.add(shuffleId, entry.getPartitionId(), entry.getBlockIdsList());
    }
    response.setStatus(RssProtos.StatusCode.SUCCESS);
  } else {
    // Requests from another application are refused outright.
    response
        .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
        .setRetMsg("Illegal appId: " + requestAppId);
  }

  responseObserver.onNext(response.build());
  responseObserver.onCompleted();
}
}