/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.writer;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;

import scala.Product2;
import scala.Tuple2;
import scala.collection.mutable.MutableList;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.shuffle.TestUtils;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
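
/**
 * Unit tests for {@link RssShuffleWriter}: failed-block resend with server reassignment, block
 * send-result checking, data consistency when memory spill is triggered, the normal write path,
 * and batching of blocks into {@link AddBlockEvent}s.
 */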
public class RssShuffleWriterTest {
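/** Builds six key/value records in shuffled key order; shared as writer input by tests below. */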
private MutableList<Product2<String, String>> createMockRecords() {
MutableList<Product2<String, String>> data = new MutableList<>();
data.appendElem(new Tuple2<>("testKey2", "testValue2"));
data.appendElem(new Tuple2<>("testKey3", "testValue3"));
data.appendElem(new Tuple2<>("testKey4", "testValue4"));
data.appendElem(new Tuple2<>("testKey6", "testValue6"));
data.appendElem(new Tuple2<>("testKey1", "testValue1"));
data.appendElem(new Tuple2<>("testKey5", "testValue5"));
return data;
}
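
/**
 * Verifies that a block rejected by server "id1" on its first attempt is tracked as failed and
 * resent to the reassigned replacement server, and that the writer fails fast once the retry
 * limit is exceeded.
 */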
@Test
public void blockFailureResendTest() throws Exception {
SparkConf conf = new SparkConf();
conf.setAppName("testApp")
.setMaster("local[2]")
.set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
taskToFailedBlockSendTracker.put("taskId", new FailedBlockSendTracker());
AtomicInteger sentFailureCnt = new AtomicInteger();
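// A pusher that rejects each block's first attempt to "id1" with NO_BUFFER and accepts the rest.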
FakedDataPusher dataPusher =
new FakedDataPusher(
event -> {
assertEquals("taskId", event.getTaskId());
FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
boolean isSuccessful = true;
ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
if (shuffleServer.getId().equals("id1") && block.getRetryCnt() == 0) {
tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
sentFailureCnt.addAndGet(1);
isSuccessful = false;
} else {
successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
shuffleBlockInfos.add(block);
}
block.executeCompletionCallback(isSuccessful);
}
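// The returned future stays incomplete; the test observes results via the callback above.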
return new CompletableFuture<>();
});
final RssShuffleManager manager =
TestUtils.createShuffleManager(
conf, false, dataPusher, successBlockIds, taskToFailedBlockSendTracker);
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
when(mockDependency.serializer()).thenReturn(kryoSerializer);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
when(mockPartitioner.numPartitions()).thenReturn(3);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
List<ShuffleServerInfo> ssi12 =
Arrays.asList(
new ShuffleServerInfo("id1", "0.0.0.1", 100),
new ShuffleServerInfo("id2", "0.0.0.2", 100));
partitionToServers.put(0, ssi12);
List<ShuffleServerInfo> ssi34 =
Arrays.asList(
new ShuffleServerInfo("id3", "0.0.0.3", 100),
new ShuffleServerInfo("id4", "0.0.0.4", 100));
partitionToServers.put(1, ssi34);
List<ShuffleServerInfo> ssi56 =
Arrays.asList(
new ShuffleServerInfo("id5", "0.0.0.5", 100),
new ShuffleServerInfo("id6", "0.0.0.6", 100));
partitionToServers.put(2, ssi56);
when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
0,
bufferOptions,
kryoSerializer,
partitionToServers,
mockTaskMemoryManager,
shuffleWriteMetrics,
RssSparkConfig.toRssConf(conf));
bufferManager.setTaskId("taskId");
WriteBufferManager bufferManagerSpy = spy(bufferManager);
TaskContext contextMock = mock(TaskContext.class);
ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
0,
"taskId",
1L,
bufferManagerSpy,
shuffleWriteMetrics,
manager,
conf,
mockShuffleWriteClient,
mockHandle,
mockShuffleHandleInfo,
contextMock);
rssShuffleWriter.enableBlockFailSentRetry();
doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
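// Blocks that failed on "id1" should be retried against this replacement server.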
ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 100);
rssShuffleWriter.addReassignmentShuffleServer("id1", replacement);
RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
doNothing().when(rssShuffleWriterSpy).sendCommit();
// case 1: failed blocks are resent and eventually all blocks succeed
MutableList<Product2<String, String>> data = createMockRecords();
rssShuffleWriterSpy.write(data.iterator());
Awaitility.await()
.timeout(Duration.ofSeconds(5))
.until(() -> successBlockIds.get("taskId").size() == data.size());
assertEquals(2, sentFailureCnt.get());
assertEquals(0, taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().size());
assertEquals(6, shuffleWriteMetrics.recordsWritten());
assertEquals(
shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
shuffleWriteMetrics.bytesWritten());
assertEquals(6, shuffleBlockInfos.size());
assertEquals(0, bufferManagerSpy.getUsedBytes());
assertEquals(0, bufferManagerSpy.getInSendListBytes());
// check the blockId -> servers mapping.
// server -> partitionId -> blockIds
Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
rssShuffleWriterSpy.getServerToPartitionToBlockIds();
assertEquals(2, serverToPartitionToBlockIds.get(replacement).get(0).size());
// case 2: if the max retry times is exceeded, the write fails fast
rssShuffleWriter.setBlockFailSentRetryMaxTimes(1);
rssShuffleWriter.setTaskId("taskId2");
rssShuffleWriter.getBufferManager().setTaskId("taskId2");
taskToFailedBlockSendTracker.put("taskId2", new FailedBlockSendTracker());
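// This pusher always rejects blocks bound for "id1", so their retries can never succeed.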
FakedDataPusher alwaysFailedDataPusher =
new FakedDataPusher(
event -> {
assertEquals("taskId2", event.getTaskId());
FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
boolean isSuccessful = true;
ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
if (shuffleServer.getId().equals("id1")) {
tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
isSuccessful = false;
} else {
successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
successBlockIds.get(event.getTaskId()).add(block.getBlockId());
}
block.executeCompletionCallback(isSuccessful);
}
return new CompletableFuture<>();
});
manager.setDataPusher(alwaysFailedDataPusher);
MutableList<Product2<String, String>> mockedData = createMockRecords();
assertThrows(Exception.class, () -> rssShuffleWriter.write(mockedData.iterator()));
assertEquals(0, bufferManagerSpy.getUsedBytes());
assertEquals(0, bufferManagerSpy.getInSendListBytes());
}
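
/**
 * Exercises checkBlockSendResult(): it returns when every block id is acknowledged, throws on
 * timeout while blocks are still pending, and throws immediately for blocks tracked as failed.
 */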
@Test
public void checkBlockSendResultTest() {
SparkConf conf = new SparkConf();
conf.setAppName("testApp")
.setMaster("local[2]")
.set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000")
.set(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key(), "10")
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name())
.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
Serializer kryoSerializer = new KryoSerializer(conf);
RssShuffleManager manager =
TestUtils.createShuffleManager(
conf, false, null, successBlocks, taskToFailedBlockSendTracker);
ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
Partitioner mockPartitioner = mock(Partitioner.class);
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
when(mockPartitioner.numPartitions()).thenReturn(2);
TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
when(mockHandle.getPartitionToServers()).thenReturn(Maps.newHashMap());
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
0,
bufferOptions,
kryoSerializer,
Maps.newHashMap(),
mockTaskMemoryManager,
new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf));
WriteBufferManager bufferManagerSpy = spy(bufferManager);
TaskContext contextMock = mock(TaskContext.class);
ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
0,
"taskId",
1L,
bufferManagerSpy,
(new TaskMetrics()).shuffleWriteMetrics(),
manager,
conf,
mockShuffleWriteClient,
mockHandle,
mockShuffleHandleInfo,
contextMock);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
// case 1: all blocks are sent successfully
successBlocks.put("taskId", Sets.newHashSet(1L, 2L, 3L));
rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L));
successBlocks.clear();
// case 2: some blocks are not sent before spark.rss.writer.send.check.timeout expires,
// so a RuntimeException is thrown
successBlocks.put("taskId", Sets.newHashSet(1L, 2L));
Throwable e2 =
assertThrows(
RuntimeException.class,
() -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
assertTrue(e2.getMessage().startsWith("Timeout:"));
successBlocks.clear();
// case 3: some blocks failed to send, so a RuntimeException is thrown
successBlocks.put("taskId", Sets.newHashSet(1L, 2L));
FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
taskToFailedBlockSendTracker.put("taskId", failedBlockSendTracker);
ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo("127.0.0.1", 20001);
failedBlockSendTracker.add(
TestUtils.createMockBlockOnlyBlockId(3L), shuffleServerInfo, StatusCode.INTERNAL_ERROR);
Throwable e3 =
assertThrows(
RuntimeException.class,
() -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
assertTrue(e3.getMessage().startsWith("Fail to send the block"));
successBlocks.clear();
taskToFailedBlockSendTracker.clear();
}
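
/**
 * A {@link DataPusher} stub whose send() delegates to an injected function, letting each test
 * script per-block success or failure without real shuffle servers.
 */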
static class FakedDataPusher extends DataPusher {
private final Function<AddBlockEvent, CompletableFuture<Long>> sendFunc;
FakedDataPusher(Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) {
this(null, null, null, null, null, 1, 1, sendFunc);
}
private FakedDataPusher(
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Map<String, FailedBlockSendTracker> failedBlockSendTracker,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime,
Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) {
super(
shuffleWriteClient,
taskToSuccessBlockIds,
failedBlockSendTracker,
failedTaskIds,
threadPoolSize,
threadKeepAliveTime);
this.sendFunc = sendFunc;
}
@Override
public CompletableFuture<Long> send(AddBlockEvent event) {
return sendFunc.apply(event);
}
}
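
/**
 * Ensures that when memory spill is triggered mid-write, every spilled block is flushed and
 * acknowledged, and each flush reports the expected amount of freed buffer memory.
 */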
@Test
public void dataConsistencyWhenSpillTriggeredTest() throws Exception {
SparkConf conf = new SparkConf();
conf.set("spark.rss.client.memory.spill.enabled", "true");
conf.setAppName("dataConsistencyWhenSpillTriggeredTest_app")
.setMaster("local[2]")
.set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "32")
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "100000")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY.name())
.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
List<Long> freeMemoryList = new ArrayList<>();
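// Runs the processed callbacks inline and records the freed memory reported by each event.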
FakedDataPusher dataPusher =
new FakedDataPusher(
event -> {
event.getProcessedCallbackChain().forEach(Runnable::run);
long sum =
event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
freeMemoryList.add(sum);
successBlockIds.putIfAbsent(event.getTaskId(), new HashSet<>());
successBlockIds
.get(event.getTaskId())
.add(event.getShuffleDataInfoList().get(0).getBlockId());
return CompletableFuture.completedFuture(sum);
});
final RssShuffleManager manager =
TestUtils.createShuffleManager(
conf, false, dataPusher, successBlockIds, JavaUtils.newConcurrentMap());
WriteBufferManagerTest.FakedTaskMemoryManager fakedTaskMemoryManager =
new WriteBufferManagerTest.FakedTaskMemoryManager();
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newConcurrentMap();
partitionToServers.put(0, Lists.newArrayList(new ShuffleServerInfo("127.0.0.1", 1111)));
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
"taskId",
0,
bufferOptions,
new KryoSerializer(conf),
partitionToServers,
fakedTaskMemoryManager,
new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf),
null);
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
when(mockDependency.serializer()).thenReturn(kryoSerializer);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
when(mockPartitioner.numPartitions()).thenReturn(1);
TaskContext contextMock = mock(TaskContext.class);
ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
0,
"taskId",
1L,
bufferManager,
new ShuffleWriteMetrics(),
manager,
conf,
mockShuffleWriteClient,
mockHandle,
mockShuffleHandleInfo,
contextMock);
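// Wire the buffer manager's spill path back into the writer so spilled data is posted as blocks.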
rssShuffleWriter.getBufferManager().setSpillFunc(rssShuffleWriter::processShuffleBlockInfos);
MutableList<Product2<String, String>> data = new MutableList<>();
// One record is 26 bytes
data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
data.appendElem(new Tuple2<>("Key", "Value11111111111111"));
// case 1: all blocks are sent and pass the block check when spill is triggered
rssShuffleWriter.write(data.iterator());
assertEquals(4, successBlockIds.get("taskId").size());
for (int i = 0; i < 4; i++) {
assertEquals(32, freeMemoryList.get(i));
}
}
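
/**
 * Covers the normal write path: records fan out to three partitions, write metrics reflect the
 * records and bytes written, and each block is routed to the servers of its partition.
 */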
@Test
public void writeTest() throws Exception {
SparkConf conf = new SparkConf();
conf.setAppName("testApp")
.setMaster("local[2]")
.set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
.set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
.set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name())
.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
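// Records every pushed block and immediately marks it as successful.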
FakedDataPusher dataPusher =
new FakedDataPusher(
event -> {
assertEquals("taskId", event.getTaskId());
shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
Set<Long> blockIds =
event.getShuffleDataInfoList().stream()
.map(ShuffleBlockInfo::getBlockId)
.collect(Collectors.toSet());
successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
successBlockIds.get(event.getTaskId()).addAll(blockIds);
return new CompletableFuture<>();
});
final RssShuffleManager manager =
TestUtils.createShuffleManager(
conf, false, dataPusher, successBlockIds, JavaUtils.newConcurrentMap());
Serializer kryoSerializer = new KryoSerializer(conf);
Partitioner mockPartitioner = mock(Partitioner.class);
final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
when(mockDependency.serializer()).thenReturn(kryoSerializer);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
when(mockPartitioner.numPartitions()).thenReturn(3);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
List<ShuffleServerInfo> ssi34 =
Arrays.asList(
new ShuffleServerInfo("id3", "0.0.0.3", 100),
new ShuffleServerInfo("id4", "0.0.0.4", 100));
partitionToServers.put(1, ssi34);
List<ShuffleServerInfo> ssi56 =
Arrays.asList(
new ShuffleServerInfo("id5", "0.0.0.5", 100),
new ShuffleServerInfo("id6", "0.0.0.6", 100));
partitionToServers.put(2, ssi56);
List<ShuffleServerInfo> ssi12 =
Arrays.asList(
new ShuffleServerInfo("id1", "0.0.0.1", 100),
new ShuffleServerInfo("id2", "0.0.0.2", 100));
partitionToServers.put(0, ssi12);
when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
0,
bufferOptions,
kryoSerializer,
partitionToServers,
mockTaskMemoryManager,
shuffleWriteMetrics,
RssSparkConfig.toRssConf(conf));
bufferManager.setTaskId("taskId");
WriteBufferManager bufferManagerSpy = spy(bufferManager);
TaskContext contextMock = mock(TaskContext.class);
ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
RssShuffleWriter<String, String, String> rssShuffleWriter =
new RssShuffleWriter<>(
"appId",
0,
"taskId",
1L,
bufferManagerSpy,
shuffleWriteMetrics,
manager,
conf,
mockShuffleWriteClient,
mockHandle,
mockShuffleHandleInfo,
contextMock);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
doNothing().when(rssShuffleWriterSpy).sendCommit();
// case 1: six records spanning all three partitions
MutableList<Product2<String, String>> data = new MutableList<>();
data.appendElem(new Tuple2<>("testKey2", "testValue2"));
data.appendElem(new Tuple2<>("testKey3", "testValue3"));
data.appendElem(new Tuple2<>("testKey4", "testValue4"));
data.appendElem(new Tuple2<>("testKey6", "testValue6"));
data.appendElem(new Tuple2<>("testKey1", "testValue1"));
data.appendElem(new Tuple2<>("testKey5", "testValue5"));
rssShuffleWriterSpy.write(data.iterator());
assertTrue(shuffleWriteMetrics.writeTime() > 0);
assertEquals(6, shuffleWriteMetrics.recordsWritten());
assertEquals(
shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
shuffleWriteMetrics.bytesWritten());
assertEquals(6, shuffleBlockInfos.size());
for (ShuffleBlockInfo shuffleBlockInfo : shuffleBlockInfos) {
assertEquals(22, shuffleBlockInfo.getUncompressLength());
assertEquals(0, shuffleBlockInfo.getShuffleId());
if (shuffleBlockInfo.getPartitionId() == 0) {
assertEquals(ssi12, shuffleBlockInfo.getShuffleServerInfos());
}
if (shuffleBlockInfo.getPartitionId() == 1) {
assertEquals(ssi34, shuffleBlockInfo.getShuffleServerInfos());
}
if (shuffleBlockInfo.getPartitionId() == 2) {
assertEquals(ssi56, shuffleBlockInfo.getShuffleServerInfos());
}
if (shuffleBlockInfo.getPartitionId() < 0 || shuffleBlockInfo.getPartitionId() > 2) {
fail("Unexpected partition id: " + shuffleBlockInfo.getPartitionId());
}
}
Map<Integer, Set<Long>> partitionToBlockIds = rssShuffleWriterSpy.getPartitionToBlockIds();
assertEquals(2, partitionToBlockIds.get(1).size());
assertEquals(2, partitionToBlockIds.get(0).size());
assertEquals(2, partitionToBlockIds.get(2).size());
partitionToBlockIds.clear();
}
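
/**
 * Validates how posted blocks are batched into AddBlockEvents under a 64-byte send size limit:
 * small batches stay in a single event, while larger ones overflow into a second event.
 */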
@Test
public void postBlockEventTest() throws Exception {
SparkConf conf = new SparkConf();
conf.set(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
+ RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(),
"64")
.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager =
new WriteBufferManager(
0,
0,
bufferOptions,
new KryoSerializer(conf),
Maps.newHashMap(),
mock(TaskMemoryManager.class),
new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf));
WriteBufferManager bufferManagerSpy = spy(bufferManager);
ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
Partitioner mockPartitioner = mock(Partitioner.class);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
SparkConf sparkConf = new SparkConf();
when(mockPartitioner.numPartitions()).thenReturn(2);
List<AddBlockEvent> events = Lists.newArrayList();
FakedDataPusher dataPusher =
new FakedDataPusher(
event -> {
events.add(event);
return new CompletableFuture<>();
});
RssShuffleManager mockShuffleManager =
spy(
TestUtils.createShuffleManager(
sparkConf,
false,
dataPusher,
Maps.newConcurrentMap(),
JavaUtils.newConcurrentMap()));
RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
TaskContext contextMock = mock(TaskContext.class);
ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 31);
RssShuffleWriter<String, String, String> writer =
new RssShuffleWriter<>(
"appId",
0,
"taskId",
1L,
bufferManagerSpy,
mockMetrics,
mockShuffleManager,
conf,
mockWriteClient,
mockHandle,
mockShuffleHandleInfo,
contextMock);
writer.postBlockEvent(shuffleBlockInfoList);
Awaitility.await().timeout(Duration.ofSeconds(1)).until(() -> events.size() == 1);
assertEquals(1, events.get(0).getShuffleDataInfoList().size());
events.clear();
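// Batches that fit under the send size limit yield one event; the last two calls overflow.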
testSingleEvent(events, writer, 2, 15);
testSingleEvent(events, writer, 1, 33);
testSingleEvent(events, writer, 2, 16);
testSingleEvent(events, writer, 2, 15);
testSingleEvent(events, writer, 2, 17);
testSingleEvent(events, writer, 2, 32);
testTwoEvents(events, writer, 2, 33, 1, 1);
testTwoEvents(events, writer, 3, 17, 2, 1);
}
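
/**
 * Posts {@code blockNum} blocks of {@code blockLength} bytes each and asserts that they are
 * split into exactly two events of the given sizes.
 */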
private void testTwoEvents(
List<AddBlockEvent> events,
RssShuffleWriter<String, String, String> writer,
int blockNum,
int blockLength,
int firstEventSize,
int secondEventSize)
throws InterruptedException {
List<ShuffleBlockInfo> shuffleBlockInfoList;
shuffleBlockInfoList = createShuffleBlockList(blockNum, blockLength);
writer.postBlockEvent(shuffleBlockInfoList);
Thread.sleep(500);
assertEquals(2, events.size());
assertEquals(firstEventSize, events.get(0).getShuffleDataInfoList().size());
assertEquals(secondEventSize, events.get(1).getShuffleDataInfoList().size());
events.clear();
}
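
/**
 * Posts {@code blockNum} blocks of {@code blockLength} bytes each and asserts that they all
 * land in a single event.
 */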
private void testSingleEvent(
List<AddBlockEvent> events,
RssShuffleWriter<String, String, String> writer,
int blockNum,
int blockLength)
throws InterruptedException {
List<ShuffleBlockInfo> shuffleBlockInfoList;
shuffleBlockInfoList = createShuffleBlockList(blockNum, blockLength);
writer.postBlockEvent(shuffleBlockInfoList);
Thread.sleep(500);
assertEquals(1, events.size());
assertEquals(blockNum, events.get(0).getShuffleDataInfoList().size());
events.clear();
}
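
/** Creates {@code blockNum} blocks of {@code blockLength} bytes, all bound for one fake server. */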
private List<ShuffleBlockInfo> createShuffleBlockList(int blockNum, int blockLength) {
List<ShuffleServerInfo> shuffleServerInfoList =
Lists.newArrayList(new ShuffleServerInfo("id", "host", 0));
List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList();
for (int i = 0; i < blockNum; i++) {
shuffleBlockInfoList.add(
new ShuffleBlockInfo(
0,
0,
10,
blockLength,
10,
new byte[] {1},
shuffleServerInfoList,
blockLength,
10,
0));
}
return shuffleBlockInfoList;
}
}