// blob: d26e969827ec62c53ad598f87337d9917150291c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.uniffle.client.impl;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.response.RssGetShuffleResultResponse;
import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.netty.IOMode;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.RssUtils;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ShuffleWriteClientImplTest {
/**
 * Verifies that {@code sendShuffleData} returns promptly, and without issuing any RPC to the
 * shuffle server, when the caller-supplied cancellation check already reports {@code true}.
 */
@Test
public void testAbandonEventWhenTaskFailed() {
  ShuffleWriteClientImpl shuffleWriteClient =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(1)
          .replicaWrite(1)
          .replicaRead(1)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .build();
  ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
  ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
  try {
    doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
    // Tripwire: if the RPC is issued after all, it blocks far longer than the 1 second budget
    // below. The answer is typed with the stubbed method's real return type so that an
    // unexpected invocation surfaces as a timeout rather than a ClassCastException.
    when(mockShuffleServerClient.sendShuffleData(any()))
        .thenAnswer(
            (Answer<RssSendShuffleDataResponse>)
                invocation -> {
                  Thread.sleep(50000);
                  return new RssSendShuffleDataResponse(StatusCode.SUCCESS);
                });

    List<ShuffleServerInfo> shuffleServerInfoList =
        Lists.newArrayList(new ShuffleServerInfo("id", "host", 0));
    List<ShuffleBlockInfo> shuffleBlockInfoList =
        Lists.newArrayList(
            new ShuffleBlockInfo(
                0, 0, 10, 10, 10, new byte[] {10}, shuffleServerInfoList, 10, 100, 0));

    // It should exit directly and won't do any rpc request.
    Awaitility.await()
        .timeout(1, TimeUnit.SECONDS)
        .until(
            () -> {
              spyClient.sendShuffleData("appId", shuffleBlockInfoList, () -> true);
              return true;
            });

    // Make the "no rpc request" expectation explicit instead of relying on timing alone.
    verify(mockShuffleServerClient, Mockito.never()).sendShuffleData(any());
  } finally {
    // Release the client's resources, as testSettingRssClientConfigs does.
    spyClient.close();
  }
}
/**
 * Verifies that a block is reported in the failed block ids when the (mocked) shuffle server
 * answers every send request with {@code NO_BUFFER}.
 */
@Test
public void testSendData() {
  ShuffleWriteClientImpl shuffleWriteClient =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(1)
          .replicaWrite(1)
          .replicaRead(1)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .build();
  ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
  ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
  try {
    doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
    when(mockShuffleServerClient.sendShuffleData(any()))
        .thenReturn(new RssSendShuffleDataResponse(StatusCode.NO_BUFFER));

    List<ShuffleServerInfo> shuffleServerInfoList =
        Lists.newArrayList(new ShuffleServerInfo("id", "host", 0));
    // The third constructor argument (10) is the id looked up in the failed set below.
    List<ShuffleBlockInfo> shuffleBlockInfoList =
        Lists.newArrayList(
            new ShuffleBlockInfo(
                0, 0, 10, 10, 10, new byte[] {10}, shuffleServerInfoList, 10, 100, 0));

    SendShuffleDataResult result =
        spyClient.sendShuffleData("appId", shuffleBlockInfoList, () -> false);

    assertTrue(result.getFailedBlockIds().contains(10L));
  } finally {
    // Release the client's resources, as testSettingRssClientConfigs does.
    spyClient.close();
  }
}
/**
 * Verifies the per-application bookkeeping of shuffle servers: servers added through {@code
 * addShuffleServer} are visible through {@code getAllShuffleServers}, and are dropped again by
 * {@code unregisterShuffle}, either for a single shuffle id or for the whole application.
 */
@Test
public void testRegisterAndUnRegisterShuffleServer() {
  ShuffleWriteClientImpl shuffleWriteClient =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(1)
          .replicaWrite(1)
          .replicaRead(1)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .build();
  try {
    String appId1 = "testRegisterAndUnRegisterShuffleServer-1";
    String appId2 = "testRegisterAndUnRegisterShuffleServer-2";
    ShuffleServerInfo server1 = new ShuffleServerInfo("host1-0", "host1", 0);
    ShuffleServerInfo server2 = new ShuffleServerInfo("host2-0", "host2", 0);
    ShuffleServerInfo server3 = new ShuffleServerInfo("host3-0", "host3", 0);

    // Servers are tracked separately for each application.
    shuffleWriteClient.addShuffleServer(appId1, 0, server1);
    shuffleWriteClient.addShuffleServer(appId1, 1, server2);
    shuffleWriteClient.addShuffleServer(appId2, 1, server3);
    assertEquals(2, shuffleWriteClient.getAllShuffleServers(appId1).size());
    assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId2).size());

    // server1 now serves shuffles 0 and 1 of appId1; dropping shuffle 1 must keep it listed
    // because of shuffle 0, while server2 (shuffle 1 only) goes away.
    shuffleWriteClient.addShuffleServer(appId1, 1, server1);
    shuffleWriteClient.unregisterShuffle(appId1, 1);
    assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId1).size());

    // Unregistering without a shuffle id clears everything recorded for the application.
    shuffleWriteClient.unregisterShuffle(appId1);
    assertEquals(0, shuffleWriteClient.getAllShuffleServers(appId1).size());

    shuffleWriteClient.addShuffleServer(appId2, 2, server1);
    assertEquals(2, shuffleWriteClient.getAllShuffleServers(appId2).size());
    shuffleWriteClient.unregisterShuffle(appId2);
    assertEquals(0, shuffleWriteClient.getAllShuffleServers(appId2).size());
  } finally {
    // Release the client's resources, as testSettingRssClientConfigs does.
    shuffleWriteClient.close();
  }
}
/**
 * Verifies how the client tracks "defective" shuffle servers across consecutive sends with
 * replica = 3 and replicaWrite = 2, and how that tracking influences the servers picked by
 * {@code genServerToBlocks}.
 */
@Test
public void testSendDataWithDefectiveServers() {
  ShuffleWriteClientImpl shuffleWriteClient =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(3)
          .replicaWrite(2)
          .replicaRead(2)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .build();
  ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
  ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
  try {
    doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
    // First send: the first request is rejected, the following ones succeed.
    when(mockShuffleServerClient.sendShuffleData(any()))
        .thenReturn(
            new RssSendShuffleDataResponse(StatusCode.NO_BUFFER),
            new RssSendShuffleDataResponse(StatusCode.SUCCESS),
            new RssSendShuffleDataResponse(StatusCode.SUCCESS));

    String appId = "testSendDataWithDefectiveServers_appId";
    ShuffleServerInfo ssi1 = new ShuffleServerInfo("127.0.0.1", 0);
    ShuffleServerInfo ssi2 = new ShuffleServerInfo("127.0.0.1", 1);
    ShuffleServerInfo ssi3 = new ShuffleServerInfo("127.0.0.1", 2);
    List<ShuffleServerInfo> shuffleServerInfoList = Lists.newArrayList(ssi1, ssi2, ssi3);
    List<ShuffleBlockInfo> shuffleBlockInfoList =
        Lists.newArrayList(
            new ShuffleBlockInfo(
                0, 0, 10, 10, 10, new byte[] {10}, shuffleServerInfoList, 10, 100, 0));
    ShuffleBlockInfo block = shuffleBlockInfoList.get(0);

    SendShuffleDataResult result =
        spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
    assertEquals(0, result.getFailedBlockIds().size());

    // Send data for the second time, the first shuffle server will be moved to the last.
    when(mockShuffleServerClient.sendShuffleData(any()))
        .thenReturn(
            new RssSendShuffleDataResponse(StatusCode.SUCCESS),
            new RssSendShuffleDataResponse(StatusCode.SUCCESS));
    List<ShuffleServerInfo> excludeServers = new ArrayList<>();
    selectServers(spyClient, block, shuffleServerInfoList, 2, excludeServers, true);
    assertEquals(2, excludeServers.size());
    assertEquals(ssi2, excludeServers.get(0));
    assertEquals(ssi3, excludeServers.get(1));
    selectServers(spyClient, block, shuffleServerInfoList, 1, excludeServers, false);
    assertEquals(3, excludeServers.size());
    assertEquals(ssi1, excludeServers.get(2));
    result = spyClient.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
    assertEquals(0, result.getFailedBlockIds().size());

    // Send data for the third time, the first server will be removed from the defectiveServers
    // and the second server will be added to the defectiveServers.
    when(mockShuffleServerClient.sendShuffleData(any()))
        .thenReturn(
            new RssSendShuffleDataResponse(StatusCode.NO_BUFFER),
            new RssSendShuffleDataResponse(StatusCode.SUCCESS),
            new RssSendShuffleDataResponse(StatusCode.SUCCESS));
    List<ShuffleServerInfo> shuffleServerInfoList2 = Lists.newArrayList(ssi2, ssi1, ssi3);
    List<ShuffleBlockInfo> shuffleBlockInfoList2 =
        Lists.newArrayList(
            new ShuffleBlockInfo(
                0, 0, 10, 10, 10, new byte[] {10}, shuffleServerInfoList2, 10, 100, 0));
    result = spyClient.sendShuffleData(appId, shuffleBlockInfoList2, () -> false);
    assertEquals(0, result.getFailedBlockIds().size());
    assertEquals(1, spyClient.getDefectiveServers().size());
    assertEquals(ssi2, spyClient.getDefectiveServers().toArray()[0]);

    // Selection is checked against the original block and server order (ssi1, ssi2, ssi3).
    excludeServers = new ArrayList<>();
    selectServers(spyClient, block, shuffleServerInfoList, 2, excludeServers, true);
    assertEquals(2, excludeServers.size());
    assertEquals(ssi1, excludeServers.get(0));
    assertEquals(ssi3, excludeServers.get(1));
    selectServers(spyClient, block, shuffleServerInfoList, 1, excludeServers, false);
    assertEquals(3, excludeServers.size());
    assertEquals(ssi2, excludeServers.get(2));

    // Check whether it is normal when two shuffle servers in defectiveServers
    spyClient.getDefectiveServers().add(ssi1);
    assertEquals(2, spyClient.getDefectiveServers().size());
    excludeServers = new ArrayList<>();
    selectServers(spyClient, block, shuffleServerInfoList, 2, excludeServers, true);
    assertEquals(2, excludeServers.size());
    assertEquals(ssi3, excludeServers.get(0));
    assertEquals(ssi1, excludeServers.get(1));
  } finally {
    // Release the client's resources, as testSettingRssClientConfigs does.
    spyClient.close();
  }
}

/**
 * Calls {@code genServerToBlocks} on {@code client} with fresh, throwaway result maps so that
 * only its effect on {@code excludeServers} is observed by the caller.
 *
 * @param client client under test
 * @param block block passed as the first argument
 * @param candidates servers passed as the second argument
 * @param replicaNum replica count passed as the third argument
 * @param excludeServers list passed as the fourth argument; the test asserts that it grows with
 *     each call
 * @param skipDefective forwarded as the final boolean argument; presumably {@code true} makes
 *     the selection skip servers currently marked defective (the assertions in the test are
 *     consistent with that reading)
 */
private static void selectServers(
    ShuffleWriteClientImpl client,
    ShuffleBlockInfo block,
    List<ShuffleServerInfo> candidates,
    int replicaNum,
    List<ShuffleServerInfo> excludeServers,
    boolean skipDefective) {
  client.genServerToBlocks(
      block,
      candidates,
      replicaNum,
      excludeServers,
      Maps.newHashMap(),
      Maps.newHashMap(),
      skipDefective);
}
/**
 * Checks that an {@code RssConf} handed to the write-client builder is kept by the builder: the
 * Netty IO mode set on the conf beforehand is read back through {@code getRssConf()} after the
 * client has been built.
 */
@Test
public void testSettingRssClientConfigs() {
  RssConf conf = new RssConf();
  conf.set(RssClientConf.NETTY_IO_MODE, IOMode.EPOLL);

  ShuffleClientFactory.WriteClientBuilder builder =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC_NETTY.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(1)
          .replicaWrite(1)
          .replicaRead(1)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .rssConf(conf);

  ShuffleWriteClientImpl builtClient = builder.build();
  // Capture the value first, then close, so the client is released before asserting.
  IOMode configuredMode = builder.getRssConf().get(RssClientConf.NETTY_IO_MODE);
  builtClient.close();

  assertEquals(IOMode.EPOLL, configuredMode);
}
/**
 * Argument source for {@code testGetShuffleResult}: the default block id layout followed by a
 * custom layout built with {@code BlockIdLayout.from(20, 21, 22)}.
 */
public static Stream<Arguments> testBlockIdLayouts() {
  BlockIdLayout customLayout = BlockIdLayout.from(20, 21, 22);
  return Stream.of(BlockIdLayout.DEFAULT, customLayout).map(Arguments::of);
}
/**
 * Verifies that {@code getShuffleResult} forwards the block id layout configured through
 * {@code RssConf} in the request sent to the shuffle server, and that the serialized bitmap
 * returned by the server is decoded into the expected block ids.
 *
 * @param layout block id layout to configure on the client, supplied by {@code
 *     testBlockIdLayouts}
 */
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void testGetShuffleResult(BlockIdLayout layout) {
  RssConf rssConf = new RssConf();
  rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, layout.sequenceNoBits);
  rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, layout.partitionIdBits);
  rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, layout.taskAttemptIdBits);
  ShuffleWriteClientImpl shuffleWriteClient =
      ShuffleClientFactory.newWriteBuilder()
          .clientType(ClientType.GRPC.name())
          .retryMax(3)
          .retryIntervalMax(2000)
          .heartBeatThreadNum(4)
          .replica(1)
          .replicaWrite(1)
          .replicaRead(1)
          .replicaSkipEnabled(true)
          .dataTransferPoolSize(1)
          .dataCommitPoolSize(1)
          .unregisterThreadPoolSize(10)
          .unregisterRequestTimeSec(10)
          .rssConf(rssConf)
          .build();
  ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
  ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
  try {
    doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());

    // Canned server response carrying the serialized block ids 1, 2 and 5.
    RssGetShuffleResultResponse response;
    try {
      Roaring64NavigableMap res = Roaring64NavigableMap.bitmapOf(1L, 2L, 5L);
      response =
          new RssGetShuffleResultResponse(StatusCode.SUCCESS, RssUtils.serializeBitMap(res));
    } catch (Exception e) {
      throw new RssException(e);
    }
    when(mockShuffleServerClient.getShuffleResult(any())).thenReturn(response);

    Set<ShuffleServerInfo> shuffleServerInfoSet =
        Sets.newHashSet(new ShuffleServerInfo("id", "host", 0));
    Roaring64NavigableMap result =
        spyClient.getShuffleResult(
            ClientType.GRPC.name(), shuffleServerInfoSet, "appId", 1, 2);

    // Comparing from the known layout keeps the matcher safe if the request carries no layout.
    verify(mockShuffleServerClient)
        .getShuffleResult(argThat(request -> layout.equals(request.getBlockIdLayout())));
    // Expected value goes first so that a failure message labels the two sides correctly.
    assertArrayEquals(new long[] {1L, 2L, 5L}, result.stream().sorted().toArray());
  } finally {
    // Release the client's resources, as testSettingRssClientConfigs does.
    spyClient.close();
  }
}
}