blob: 5c9f401b8afa0dbdffa4691bde053ded2446fef0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapred;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.io.serializer.Serializer;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
/**
 * Tests for {@link SortWriteBufferManager}: error propagation from the shuffle write client,
 * normal buffering/sending, commit handling when memory shuffle is disabled, and applying a
 * combiner to a sorted buffer.
 */
public class SortWriteBufferManagerTest {

  /**
   * Exercises three failure paths:
   *
   * <ol>
   *   <li>mode 0: every send throws, so {@code waitSendFinished} fails with a "Timeout" message;
   *   <li>mode 1: the client reports failed blocks, so both {@code addRecord} and {@code
   *       waitSendFinished} fail with "Send failed" and {@code failedBlocks} is populated;
   *   <li>a manager created with a much smaller first size argument (100 instead of 10240)
   *       rejects a single record as "too big".
   * </ol>
   */
  @Test
  public void testWriteException() throws Exception {
    JobConf jobConf = new JobConf(new Configuration());
    SerializationFactory serializationFactory = new SerializationFactory(jobConf);
    MockShuffleWriteClient client = new MockShuffleWriteClient();
    Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
    Set<Long> successBlocks = Sets.newConcurrentHashSet();
    Set<Long> failedBlocks = Sets.newConcurrentHashSet();
    Counters.Counter mapOutputByteCounter = new Counters.Counter();
    Counters.Counter mapOutputRecordCounter = new Counters.Counter();
    SortWriteBufferManager<BytesWritable, BytesWritable> manager;
    manager =
        new SortWriteBufferManager<BytesWritable, BytesWritable>(
            10240,
            1L,
            10,
            serializationFactory.getSerializer(BytesWritable.class),
            serializationFactory.getSerializer(BytesWritable.class),
            WritableComparator.get(BytesWritable.class),
            0.9,
            "test",
            client,
            500,
            5 * 1000,
            partitionToServers,
            successBlocks,
            failedBlocks,
            mapOutputByteCounter,
            mapOutputRecordCounter,
            1,
            100,
            1000,
            true,
            5,
            0.2f,
            1024000L,
            new RssConf(),
            null);

    // case 1: client is in mode 0, so every send throws.
    Random random = new Random();
    partitionToServers.put(1, Lists.newArrayList(mock(ShuffleServerInfo.class)));
    for (int i = 0; i < 1000; i++) {
      byte[] key = new byte[20];
      byte[] value = new byte[1024];
      random.nextBytes(key);
      random.nextBytes(value);
      manager.addRecord(1, new BytesWritable(key), new BytesWritable(value));
    }
    RssException rssException = assertThrows(RssException.class, manager::waitSendFinished);
    assertTrue(rssException.getMessage().contains("Timeout"));

    // case 2: client reports failed blocks; the failure must surface from addRecord as well as
    // from waitSendFinished.
    client.setMode(1);
    SortWriteBufferManager<BytesWritable, BytesWritable> finalManager = manager;
    rssException =
        assertThrows(
            RssException.class,
            () -> {
              for (int i = 0; i < 1000; i++) {
                byte[] key = new byte[20];
                byte[] value = new byte[1024];
                random.nextBytes(key);
                random.nextBytes(value);
                finalManager.addRecord(1, new BytesWritable(key), new BytesWritable(value));
              }
            });
    assertFalse(failedBlocks.isEmpty());
    assertTrue(rssException.getMessage().contains("Send failed"));
    rssException = assertThrows(RssException.class, finalManager::waitSendFinished);
    assertTrue(rssException.getMessage().contains("Send failed"));

    // case 3: a single record that does not fit a manager built with a 100-byte size argument.
    client.setMode(0);
    manager =
        new SortWriteBufferManager<BytesWritable, BytesWritable>(
            100,
            1L,
            10,
            serializationFactory.getSerializer(BytesWritable.class),
            serializationFactory.getSerializer(BytesWritable.class),
            WritableComparator.get(BytesWritable.class),
            0.9,
            "test",
            client,
            500,
            5 * 1000,
            partitionToServers,
            Sets.newConcurrentHashSet(),
            Sets.newConcurrentHashSet(),
            mapOutputByteCounter,
            mapOutputRecordCounter,
            1,
            100,
            1000,
            true,
            5,
            0.2f,
            1024000L,
            new RssConf(),
            null);
    byte[] key = new byte[20];
    byte[] value = new byte[1024];
    random.nextBytes(key);
    random.nextBytes(value);
    SortWriteBufferManager<BytesWritable, BytesWritable> finalManager1 = manager;
    rssException =
        assertThrows(
            RssException.class,
            () -> finalManager1.addRecord(1, new BytesWritable(key), new BytesWritable(value)));
    assertTrue(rssException.getMessage().contains("too big"));
  }

  /**
   * With a very small value (100) passed as the second-to-last size argument — presumably the
   * per-buffer limit — no buffer is ever left waiting after an {@code addRecord} call.
   */
  @Test
  public void testOnePartition() throws Exception {
    JobConf jobConf = new JobConf(new Configuration());
    SerializationFactory serializationFactory = new SerializationFactory(jobConf);
    MockShuffleWriteClient client = new MockShuffleWriteClient();
    client.setMode(2);
    Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
    Set<Long> successBlocks = Sets.newConcurrentHashSet();
    Set<Long> failedBlocks = Sets.newConcurrentHashSet();
    Counters.Counter mapOutputByteCounter = new Counters.Counter();
    Counters.Counter mapOutputRecordCounter = new Counters.Counter();
    SortWriteBufferManager<BytesWritable, BytesWritable> manager;
    manager =
        new SortWriteBufferManager<BytesWritable, BytesWritable>(
            10240,
            1L,
            10,
            serializationFactory.getSerializer(BytesWritable.class),
            serializationFactory.getSerializer(BytesWritable.class),
            WritableComparator.get(BytesWritable.class),
            0.9,
            "test",
            client,
            500,
            5 * 1000,
            partitionToServers,
            successBlocks,
            failedBlocks,
            mapOutputByteCounter,
            mapOutputRecordCounter,
            1,
            100,
            2000,
            true,
            5,
            0.2f,
            100L,
            new RssConf(),
            null);
    Random random = new Random();
    for (int i = 0; i < 1000; i++) {
      byte[] key = new byte[20];
      byte[] value = new byte[1024];
      random.nextBytes(key);
      random.nextBytes(value);
      int partitionId = random.nextInt(50);
      partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
      manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
      assertTrue(manager.getWaitSendBuffers().isEmpty());
    }
  }

  /**
   * Happy path: records are sent successfully (mode 2). After a second batch of 14 records of
   * growing size, only four buffers are still waiting, and each holds one record.
   */
  @Test
  public void testWriteNormal() throws Exception {
    JobConf jobConf = new JobConf(new Configuration());
    SerializationFactory serializationFactory = new SerializationFactory(jobConf);
    MockShuffleWriteClient client = new MockShuffleWriteClient();
    client.setMode(2);
    Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
    Set<Long> successBlocks = Sets.newConcurrentHashSet();
    Set<Long> failedBlocks = Sets.newConcurrentHashSet();
    Counters.Counter mapOutputByteCounter = new Counters.Counter();
    Counters.Counter mapOutputRecordCounter = new Counters.Counter();
    SortWriteBufferManager<BytesWritable, BytesWritable> manager;
    manager =
        new SortWriteBufferManager<BytesWritable, BytesWritable>(
            10240,
            1L,
            10,
            serializationFactory.getSerializer(BytesWritable.class),
            serializationFactory.getSerializer(BytesWritable.class),
            WritableComparator.get(BytesWritable.class),
            0.9,
            "test",
            client,
            500,
            5 * 1000,
            partitionToServers,
            successBlocks,
            failedBlocks,
            mapOutputByteCounter,
            mapOutputRecordCounter,
            1,
            100,
            2000,
            true,
            5,
            0.2f,
            1024000L,
            new RssConf(),
            null);
    Random random = new Random();
    for (int i = 0; i < 1000; i++) {
      byte[] key = new byte[20];
      byte[] value = new byte[1024];
      random.nextBytes(key);
      random.nextBytes(value);
      int partitionId = random.nextInt(50);
      partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
      manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
    }
    manager.waitSendFinished();
    assertTrue(manager.getWaitSendBuffers().isEmpty());

    for (int i = 0; i < 14; i++) {
      byte[] key = new byte[20];
      byte[] value = new byte[i * 100];
      random.nextBytes(key);
      random.nextBytes(value);
      // Register partition i explicitly instead of relying on the random loop above having
      // drawn it; existing registrations are left untouched.
      partitionToServers.putIfAbsent(i, Lists.newArrayList(mock(ShuffleServerInfo.class)));
      manager.addRecord(i, new BytesWritable(key), new BytesWritable(value));
    }
    // Only the four smallest records (partitions 3, 2, 1, 0) remain buffered, largest first.
    // The extra 28 bytes are the 20-byte key plus serialization overhead (presumably the two
    // 4-byte BytesWritable length prefixes).
    assertEquals(4, manager.getWaitSendBuffers().size());
    for (int i = 0; i < 4; i++) {
      int dataLength = manager.getWaitSendBuffers().get(i).getDataLength();
      assertEquals((3 - i) * 100 + 28, dataLength);
    }
    manager.waitSendFinished();
    assertTrue(manager.getWaitSendBuffers().isEmpty());
  }

  /**
   * With memory shuffle disabled (the {@code false} constructor argument), the commit must flush
   * every block whose result was reported. Mode 3 delays each send by 10 ms, presumably to
   * expose a commit that races ahead of in-flight sends.
   */
  @Test
  public void testCommitBlocksWhenMemoryShuffleDisabled() throws Exception {
    JobConf jobConf = new JobConf(new Configuration());
    SerializationFactory serializationFactory = new SerializationFactory(jobConf);
    MockShuffleWriteClient client = new MockShuffleWriteClient();
    client.setMode(3);
    Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
    Set<Long> successBlocks = Sets.newConcurrentHashSet();
    Set<Long> failedBlocks = Sets.newConcurrentHashSet();
    Counters.Counter mapOutputByteCounter = new Counters.Counter();
    Counters.Counter mapOutputRecordCounter = new Counters.Counter();
    SortWriteBufferManager<BytesWritable, BytesWritable> manager;
    manager =
        new SortWriteBufferManager<BytesWritable, BytesWritable>(
            10240,
            1L,
            10,
            serializationFactory.getSerializer(BytesWritable.class),
            serializationFactory.getSerializer(BytesWritable.class),
            WritableComparator.get(BytesWritable.class),
            0.9,
            "test",
            client,
            500,
            5 * 1000,
            partitionToServers,
            successBlocks,
            failedBlocks,
            mapOutputByteCounter,
            mapOutputRecordCounter,
            1,
            100,
            1,
            false,
            5,
            0.2f,
            1024000L,
            new RssConf(),
            null);
    Random random = new Random();
    for (int i = 0; i < 1000; i++) {
      byte[] key = new byte[20];
      byte[] value = new byte[1024];
      random.nextBytes(key);
      random.nextBytes(value);
      int partitionId = random.nextInt(50);
      partitionToServers.put(partitionId, Lists.newArrayList(mock(ShuffleServerInfo.class)));
      manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
    }
    manager.waitSendFinished();
    assertTrue(manager.getWaitSendBuffers().isEmpty());
    // When MEMORY storage is disabled, every block reported as finished must also be flushed.
    assertEquals(
        client.mockedShuffleServer.getFinishBlockSize(),
        client.mockedShuffleServer.getFlushBlockSize());
  }

  /**
   * Checks that {@code combineBuffer} applies the configured combiner. Each word is added once,
   * then 10000 random words are added. After combining, exactly one record per distinct word
   * remains.
   */
  @Test
  public void testCombineBuffer() throws Exception {
    JobConf jobConf = new JobConf(new Configuration());
    jobConf.setOutputKeyClass(Text.class);
    jobConf.setOutputValueClass(IntWritable.class);
    jobConf.setCombinerClass(Reduce.class);
    SerializationFactory serializationFactory = new SerializationFactory(jobConf);
    Serializer<Text> keySerializer = serializationFactory.getSerializer(Text.class);
    Serializer<IntWritable> valueSerializer = serializationFactory.getSerializer(IntWritable.class);
    WritableComparator comparator = WritableComparator.get(Text.class);
    Task.TaskReporter reporter = mock(Task.TaskReporter.class);
    final Counters.Counter combineInputCounter = new Counters.Counter();
    Task.CombinerRunner<Text, IntWritable> combinerRunner =
        Task.CombinerRunner.create(
            jobConf, new TaskAttemptID(), combineInputCounter, reporter, null);
    SortWriteBuffer<Text, IntWritable> buffer =
        new SortWriteBuffer<Text, IntWritable>(1, comparator, 3072, keySerializer, valueSerializer);
    List<String> wordTable =
        Lists.newArrayList(
            "apple", "banana", "fruit", "cherry", "Chinese", "America", "Japan", "tomato");
    final int randomRecords = 10000;
    Random random = new Random();
    // Add every word once so each key is guaranteed to appear, keeping the combined count
    // deterministic regardless of the random draws below.
    for (String word : wordTable) {
      buffer.addRecord(new Text(word), new IntWritable(1));
    }
    for (int i = 0; i < randomRecords; i++) {
      int index = random.nextInt(wordTable.size());
      buffer.addRecord(new Text(wordTable.get(index)), new IntWritable(1));
    }
    SortWriteBufferManager<Text, IntWritable> manager =
        new SortWriteBufferManager<Text, IntWritable>(
            10240,
            1L,
            10,
            keySerializer,
            valueSerializer,
            comparator,
            0.9,
            "test",
            null,
            500,
            5 * 1000,
            null,
            null,
            null,
            null,
            null,
            1,
            100,
            1,
            true,
            5,
            0.2f,
            1024000L,
            new RssConf(),
            combinerRunner);
    buffer.sort();
    SortWriteBuffer<Text, IntWritable> newBuffer = manager.combineBuffer(buffer);
    RawKeyValueIterator kvIterator1 = new SortWriteBuffer.SortBufferIterator<>(buffer);
    RawKeyValueIterator kvIterator2 = new SortWriteBuffer.SortBufferIterator<>(newBuffer);
    int count1 = 0;
    while (kvIterator1.next()) {
      count1++;
    }
    int count2 = 0;
    while (kvIterator2.next()) {
      count2++;
    }
    assertEquals(wordTable.size() + randomRecords, count1);
    assertEquals(wordTable.size(), count2);
  }

  /**
   * In-memory stand-in for a shuffle server. It records which blocks were cached, flushed and
   * reported as finished. Every method is synchronized because the manager's send threads may
   * call into it concurrently.
   */
  static class MockShuffleServer {
    private final List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
    private final List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
    private final List<Long> finishedBlockInfos = new ArrayList<>();

    /** Moves all currently cached blocks to the flushed list. */
    public synchronized void finishShuffle() {
      flushBlockInfos.addAll(cachedBlockInfos);
      // Clear the cache so a repeated commit cannot flush (and count) the same blocks twice.
      cachedBlockInfos.clear();
    }

    /** Records blocks as received (cached) by the server. */
    public synchronized void addCachedBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
      cachedBlockInfos.addAll(shuffleBlockInfoList);
    }

    /** Records block ids reported through {@code reportShuffleResult}. */
    public synchronized void addFinishedBlockInfos(List<Long> shuffleBlockInfoList) {
      finishedBlockInfos.addAll(shuffleBlockInfoList);
    }

    public synchronized int getFlushBlockSize() {
      return flushBlockInfos.size();
    }

    public synchronized int getFinishBlockSize() {
      return finishedBlockInfos.size();
    }
  }

  /**
   * Scriptable {@link ShuffleWriteClient}. The {@code mode} field selects its behavior:
   *
   * <ul>
   *   <li>0: {@code sendShuffleData} throws {@link RssException};
   *   <li>1: {@code sendShuffleData} reports block 2 as succeeded and one block as failed with
   *       {@link StatusCode#NO_BUFFER};
   *   <li>3: sends sleep 10 ms and record blocks as cached on {@link #mockedShuffleServer};
   *       commits and shuffle-result reports are also forwarded to it;
   *   <li>any other value: every sent block is reported as succeeded.
   * </ul>
   *
   * <p>In mode 3 (as in any mode other than 0 and 1), every sent block is also reported as
   * succeeded.
   */
  static class MockShuffleWriteClient implements ShuffleWriteClient {
    int mode = 0;
    MockShuffleServer mockedShuffleServer = new MockShuffleServer();
    int committedMaps = 0;

    public void setMode(int mode) {
      this.mode = mode;
    }

    @Override
    public SendShuffleDataResult sendShuffleData(
        String appId,
        List<ShuffleBlockInfo> shuffleBlockInfoList,
        Supplier<Boolean> needCancelRequest) {
      if (mode == 0) {
        throw new RssException("send data failed");
      } else if (mode == 1) {
        FailedBlockSendTracker failedBlockSendTracker = new FailedBlockSendTracker();
        ShuffleBlockInfo failedBlock =
            new ShuffleBlockInfo(1, 1, 3, 1, 1, new byte[1], null, 1, 100, 1);
        failedBlockSendTracker.add(
            failedBlock, new ShuffleServerInfo("host", 39998), StatusCode.NO_BUFFER);
        return new SendShuffleDataResult(Sets.newHashSet(2L), failedBlockSendTracker);
      } else {
        if (mode == 3) {
          try {
            Thread.sleep(10);
            mockedShuffleServer.addCachedBlockInfos(shuffleBlockInfoList);
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RssException(e);
          }
        }
        Set<Long> successBlockIds = Sets.newHashSet();
        for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
          successBlockIds.add(blockInfo.getBlockId());
        }
        return new SendShuffleDataResult(successBlockIds, new FailedBlockSendTracker());
      }
    }

    @Override
    public void sendAppHeartbeat(String appId, long timeoutMs) {}

    @Override
    public void registerApplicationInfo(String appId, long timeoutMs, String user) {}

    @Override
    public void registerShuffle(
        ShuffleServerInfo shuffleServerInfo,
        String appId,
        int shuffleId,
        List<PartitionRange> partitionRanges,
        RemoteStorageInfo remoteStorage,
        ShuffleDataDistributionType distributionType,
        int maxConcurrencyPerPartitionToWrite) {}

    /**
     * In mode 3, counts the commit and flushes the mock server once at least {@code numMaps}
     * commits were seen, then returns {@code true}. In every other mode, returns {@code false}.
     */
    @Override
    public boolean sendCommit(
        Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int shuffleId, int numMaps) {
      if (mode == 3) {
        committedMaps++;
        if (committedMaps >= numMaps) {
          mockedShuffleServer.finishShuffle();
        }
        return true;
      }
      return false;
    }

    @Override
    public void registerCoordinators(String coordinators) {}

    @Override
    public Map<String, String> fetchClientConf(int timeoutMs) {
      return null;
    }

    @Override
    public RemoteStorageInfo fetchRemoteStorage(String appId) {
      return null;
    }

    /** In mode 3, records every reported block id on the mock server as finished. */
    @Override
    public void reportShuffleResult(
        Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds,
        String appId,
        int shuffleId,
        long taskAttemptId,
        int bitmapNum) {
      if (mode == 3) {
        serverToPartitionToBlockIds
            .values()
            .forEach(
                partitionToBlockIds ->
                    mockedShuffleServer.addFinishedBlockInfos(
                        partitionToBlockIds.values().stream()
                            .flatMap(Set::stream)
                            .collect(Collectors.toList())));
      }
    }

    @Override
    public ShuffleAssignmentsInfo getShuffleAssignments(
        String appId,
        int shuffleId,
        int partitionNum,
        int partitionNumPerRange,
        Set<String> requiredTags,
        int assignmentShuffleServerNumber,
        int estimateTaskConcurrency) {
      return null;
    }

    @Override
    public Roaring64NavigableMap getShuffleResult(
        String clientType,
        Set<ShuffleServerInfo> shuffleServerInfoSet,
        String appId,
        int shuffleId,
        int partitionId) {
      return null;
    }

    @Override
    public Roaring64NavigableMap getShuffleResultForMultiPart(
        String clientType,
        Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
        String appId,
        int shuffleId,
        Set<Integer> failedPartitions,
        PartitionDataReplicaRequirementTracking tracking) {
      return null;
    }

    @Override
    public void close() {}

    @Override
    public void unregisterShuffle(String appId, int shuffleId) {}

    @Override
    public void unregisterShuffle(String appId) {}
  }

  /** Word-count style combiner: emits each key with the sum of its {@link IntWritable} values. */
  static class Reduce extends MapReduceBase
      implements Reducer<Text, IntWritable, Text, IntWritable> {

    Reduce() {}

    @Override
    public void reduce(
        Text key,
        Iterator<IntWritable> values,
        OutputCollector<Text, IntWritable> output,
        Reporter reporter)
        throws IOException {
      int sum = 0;
      while (values.hasNext()) {
        sum += values.next().get();
      }
      output.collect(key, new IntWritable(sum));
    }
  }
}