blob: 87b9abd51db7bb3bf9ce093cad7e3156ee0a63a9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.uniffle.server;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService {
private static final Logger LOG = LoggerFactory.getLogger(MockedShuffleServerGrpcService.class);
// appId -> shuffleId -> partitionRequestNum
private Map<String, Map<Integer, AtomicInteger>> appToPartitionRequest =
JavaUtils.newConcurrentMap();
private long mockedTimeout = -1L;
private boolean mockSendDataFailed = false;
private boolean recordGetShuffleResult = false;
private long numOfFailedReadRequest = 0;
private AtomicInteger failedGetShuffleResultRequest = new AtomicInteger(0);
private AtomicInteger failedGetShuffleResultForMultiPartRequest = new AtomicInteger(0);
private AtomicInteger failedGetMemoryShuffleDataRequest = new AtomicInteger(0);
private AtomicInteger failedGetLocalShuffleDataRequest = new AtomicInteger(0);
private AtomicInteger failedGetLocalShuffleIndexRequest = new AtomicInteger(0);
public void enableMockedTimeout(long timeout) {
mockedTimeout = timeout;
}
public void enableMockSendDataFailed(boolean mockSendDataFailed) {
this.mockSendDataFailed = mockSendDataFailed;
}
public void enableRecordGetShuffleResult() {
recordGetShuffleResult = true;
}
public void disableMockedTimeout() {
mockedTimeout = -1;
}
public void enableFirstNReadRequestToFail(int n) {
numOfFailedReadRequest = n;
}
public void resetFirstNReadRequestToFail() {
numOfFailedReadRequest = 0;
failedGetShuffleResultRequest.set(0);
failedGetShuffleResultForMultiPartRequest.set(0);
failedGetMemoryShuffleDataRequest.set(0);
failedGetLocalShuffleDataRequest.set(0);
failedGetLocalShuffleIndexRequest.set(0);
}
public MockedShuffleServerGrpcService(ShuffleServer shuffleServer) {
super(shuffleServer);
}
@Override
public void sendShuffleData(
RssProtos.SendShuffleDataRequest request,
StreamObserver<RssProtos.SendShuffleDataResponse> responseObserver) {
if (mockSendDataFailed) {
LOG.info("Add a mocked sendData failed on sendShuffleData");
throw new RuntimeException("This write request is failed as mocked failureļ¼");
}
if (mockedTimeout > 0) {
LOG.info("Add a mocked timeout on sendShuffleData");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
super.sendShuffleData(request, responseObserver);
}
@Override
public void reportShuffleResult(
RssProtos.ReportShuffleResultRequest request,
StreamObserver<RssProtos.ReportShuffleResultResponse> responseObserver) {
if (mockedTimeout > 0) {
LOG.info("Add a mocked timeout on reportShuffleResult");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
super.reportShuffleResult(request, responseObserver);
}
@Override
public void getShuffleResult(
RssProtos.GetShuffleResultRequest request,
StreamObserver<RssProtos.GetShuffleResultResponse> responseObserver) {
if (mockedTimeout > 0) {
LOG.info("Add a mocked timeout on getShuffleResult");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
if (numOfFailedReadRequest > 0) {
int currentFailedReadRequest = failedGetShuffleResultRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
throw new RuntimeException("This request is failed as mocked failure");
}
}
super.getShuffleResult(request, responseObserver);
}
@Override
public void getShuffleResultForMultiPart(
RssProtos.GetShuffleResultForMultiPartRequest request,
StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
if (mockedTimeout > 0) {
LOG.info("Add a mocked timeout on getShuffleResultForMultiPart");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
}
if (numOfFailedReadRequest > 0) {
int currentFailedReadRequest = failedGetShuffleResultForMultiPartRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
throw new RuntimeException("This request is failed as mocked failure");
}
}
if (recordGetShuffleResult) {
List<Integer> requestPartitions = request.getPartitionsList();
Map<Integer, AtomicInteger> shuffleIdToPartitionRequestNum =
appToPartitionRequest.computeIfAbsent(
request.getAppId(), x -> JavaUtils.newConcurrentMap());
AtomicInteger partitionRequestNum =
shuffleIdToPartitionRequestNum.computeIfAbsent(
request.getShuffleId(), x -> new AtomicInteger(0));
partitionRequestNum.addAndGet(requestPartitions.size());
}
super.getShuffleResultForMultiPart(request, responseObserver);
}
public Map<String, Map<Integer, AtomicInteger>> getShuffleIdToPartitionRequest() {
return appToPartitionRequest;
}
@Override
public void getMemoryShuffleData(
RssProtos.GetMemoryShuffleDataRequest request,
StreamObserver<RssProtos.GetMemoryShuffleDataResponse> responseObserver) {
if (numOfFailedReadRequest > 0) {
int currentFailedReadRequest = failedGetMemoryShuffleDataRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
StatusCode status = StatusCode.NO_BUFFER;
String msg =
"Can't require memory to get in memory shuffle data (This request is failed as mocked failure)";
RssProtos.GetMemoryShuffleDataResponse reply =
RssProtos.GetMemoryShuffleDataResponse.newBuilder()
.setData(UnsafeByteOperations.unsafeWrap(new byte[] {}))
.addAllShuffleDataBlockSegments(Lists.newArrayList())
.setStatus(status.toProto())
.setRetMsg(msg)
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
}
super.getMemoryShuffleData(request, responseObserver);
}
@Override
public void getLocalShuffleData(
RssProtos.GetLocalShuffleDataRequest request,
StreamObserver<RssProtos.GetLocalShuffleDataResponse> responseObserver) {
if (numOfFailedReadRequest > 0) {
int currentFailedReadRequest = failedGetLocalShuffleDataRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
StatusCode status = StatusCode.NO_BUFFER;
String msg =
"Can't require memory to get shuffle data (This request is failed as mocked failure)";
RssProtos.GetLocalShuffleDataResponse reply =
RssProtos.GetLocalShuffleDataResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(msg)
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
}
super.getLocalShuffleData(request, responseObserver);
}
@Override
public void getLocalShuffleIndex(
RssProtos.GetLocalShuffleIndexRequest request,
StreamObserver<RssProtos.GetLocalShuffleIndexResponse> responseObserver) {
if (numOfFailedReadRequest > 0) {
int currentFailedReadRequest = failedGetLocalShuffleIndexRequest.getAndIncrement();
if (currentFailedReadRequest < numOfFailedReadRequest) {
LOG.info(
"This request is failed as mocked failure, current/firstN: {}/{}",
currentFailedReadRequest,
numOfFailedReadRequest);
StatusCode status = StatusCode.NO_BUFFER;
String msg =
"Can't require memory to get shuffle index (This request is failed as mocked failure)";
RssProtos.GetLocalShuffleIndexResponse reply =
RssProtos.GetLocalShuffleIndexResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(msg)
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
}
super.getLocalShuffleIndex(request, responseObserver);
}
}