/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.reader;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import scala.Option;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
import org.apache.uniffle.storage.util.StorageType;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

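/**
 * Tests {@link RssShuffleReader} against shuffle data written to HDFS: a single-partition read,
 * a read spanning two partitions, and a read with an empty task id bitmap.
 */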
public class RssShuffleReaderTest extends AbstractRssReaderTest {
private static final Serializer KRYO_SERIALIZER = new KryoSerializer(new SparkConf(false));
@Test
public void readTest() throws Exception {
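    // One local shuffle server and two HDFS write handlers: one for partition 0, one for partition 1.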
ShuffleServerInfo ssi = new ShuffleServerInfo("127.0.0.1", 0);
String basePath = HDFS_URI + "readTest1";
HadoopShuffleWriteHandler writeHandler =
new HadoopShuffleWriteHandler("appId", 0, 0, 0, basePath, ssi.getId(), conf);
final HadoopShuffleWriteHandler writeHandler1 =
new HadoopShuffleWriteHandler("appId", 0, 1, 1, basePath, ssi.getId(), conf);
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
final Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
Map<String, String> expectedData = Maps.newHashMap();
final Roaring64NavigableMap blockIdBitmap1 = Roaring64NavigableMap.bitmapOf();
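
    // Write test data (2 blocks x 5 records) to partition 0; blockIdBitmap collects the written block ids.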
writeTestData(writeHandler, 2, 5, expectedData, blockIdBitmap, "key", KRYO_SERIALIZER, 0);
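
    // Mock the shuffle handle, its dependency, and the task context the reader needs.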
RssShuffleHandle<String, String, String> handleMock = mock(RssShuffleHandle.class);
ShuffleDependency<String, String, String> dependencyMock = mock(ShuffleDependency.class);
when(handleMock.getAppId()).thenReturn("appId");
when(handleMock.getDependency()).thenReturn(dependencyMock);
when(handleMock.getShuffleId()).thenReturn(1);
when(handleMock.getNumMaps()).thenReturn(1);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
partitionToServers.put(0, Lists.newArrayList(ssi));
partitionToServers.put(1, Lists.newArrayList(ssi));
when(handleMock.getPartitionToServers()).thenReturn(partitionToServers);
when(dependencyMock.serializer()).thenReturn(KRYO_SERIALIZER);
TaskContext contextMock = mock(TaskContext.class);
when(contextMock.attemptNumber()).thenReturn(1);
when(contextMock.taskAttemptId()).thenReturn(1L);
when(contextMock.taskMetrics()).thenReturn(new TaskMetrics());
doNothing().when(contextMock).killTaskIfInterrupted();
when(dependencyMock.aggregator()).thenReturn(Option.empty());
when(dependencyMock.keyOrdering()).thenReturn(Option.empty());
when(dependencyMock.mapSideCombine()).thenReturn(false);
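
    // Tell the reader which block ids to expect for partition 0.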
Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks = Maps.newHashMap();
partitionToExpectBlocks.put(0, blockIdBitmap);
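
    // Client config: HDFS storage with a small index read limit and read buffer size.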
RssConf rssConf = new RssConf();
rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
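
    // Read partition range [0, 1): expect the 10 records written to partition 0.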
RssShuffleReader<String, String> rssShuffleReaderSpy =
spy(
new RssShuffleReader<>(
0,
1,
0,
Integer.MAX_VALUE,
contextMock,
handleMock,
basePath,
conf,
1,
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
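
    // Write more test data (2 blocks x 4 records) to partition 1, then register its expected block ids.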
writeTestData(
writeHandler1, 2, 4, expectedData, blockIdBitmap1, "another_key", KRYO_SERIALIZER, 1);
partitionToExpectBlocks.put(1, blockIdBitmap1);
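
    // Read partition range [0, 2): expect all 18 records across both partitions.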
RssShuffleReader<String, String> rssShuffleReaderSpy1 =
spy(
new RssShuffleReader<>(
0,
2,
0,
Integer.MAX_VALUE,
contextMock,
handleMock,
basePath,
conf,
2,
partitionToExpectBlocks,
taskIdBitmap,
new ShuffleReadMetrics(),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
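
    // An empty task id bitmap filters out every written block, so this reader must return no data.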
RssShuffleReader<String, String> rssShuffleReaderSpy2 =
spy(
new RssShuffleReader<>(
0,
2,
0,
Integer.MAX_VALUE,
contextMock,
handleMock,
basePath,
conf,
2,
partitionToExpectBlocks,
Roaring64NavigableMap.bitmapOf(),
new ShuffleReadMetrics(),
rssConf,
ShuffleDataDistributionType.NORMAL,
partitionToServers));
validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
}
}