/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.test;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import scala.Option;

import com.google.common.collect.Maps;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.ServerType;
import org.apache.uniffle.common.util.BlockId;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
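
/**
 * Integration test for the RSS shuffle manager: the block id layout (bit widths for sequence
 * number, partition id, and task attempt id) may come from the client-side Spark conf, from the
 * coordinator's dynamic client conf, or from the defaults, and a client-side setting overrides
 * the dynamic one. A small shuffle job verifies that written block ids follow the expected layout.
 */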
public class RssShuffleManagerTest extends SparkIntegrationTestBase {

  @BeforeAll
public static void beforeAll() throws Exception {
shutdownServers();
  }

  @AfterEach
public void after() throws Exception {
shutdownServers();
}
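
  /**
   * Starts a coordinator and a GRPC shuffle server. When a block id layout is given, its bit
   * widths are registered as dynamic client conf on the coordinator; the returned map holds that
   * dynamic conf so callers can re-apply it on the client side.
   */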
public static Map<String, String> startServers(BlockIdLayout dynamicConfLayout) throws Exception {
Map<String, String> dynamicConf = Maps.newHashMap();
dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
// configure block id layout (if set)
if (dynamicConfLayout != null) {
dynamicConf.put(
RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
String.valueOf(dynamicConfLayout.sequenceNoBits));
dynamicConf.put(
RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
String.valueOf(dynamicConfLayout.partitionIdBits));
dynamicConf.put(
RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
String.valueOf(dynamicConfLayout.taskAttemptIdBits));
}
CoordinatorConf coordinatorConf = getCoordinatorConf();
addDynamicConf(coordinatorConf, dynamicConf);
createCoordinatorServer(coordinatorConf);
createShuffleServer(getShuffleServerConf(ServerType.GRPC));
startServers();
return dynamicConf;
  }

  @Override
Map runTest(SparkSession spark, String fileName) throws Exception {
    // we don't need to run any Spark application here; just return an empty map
return new HashMap();
}
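
  // Block id layouts exercised below. A block id is assumed to pack three fields into the lower
  // 63 bits of a long, so the three bit widths must sum to 63:
  //   blockId = (sequenceNo << (partitionIdBits + taskAttemptIdBits))
  //           | (partitionId << taskAttemptIdBits)
  //           | taskAttemptId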
private static final BlockIdLayout DEFAULT = BlockIdLayout.from(21, 20, 22);
private static final BlockIdLayout CUSTOM1 = BlockIdLayout.from(20, 21, 22);
private static final BlockIdLayout CUSTOM2 = BlockIdLayout.from(22, 18, 23);
public static Stream<Arguments> testBlockIdLayouts() {
return Stream.of(Arguments.of(DEFAULT), Arguments.of(CUSTOM1), Arguments.of(CUSTOM2));
}
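
  // Neither the client conf nor the dynamic conf sets a layout, so the default is expected,
  // with dynamic client conf both disabled and enabled.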
@ParameterizedTest
@ValueSource(booleans = {false, true})
public void testRssShuffleManager(boolean enableDynamicClientConf) throws Exception {
doTestRssShuffleManager(null, null, DEFAULT, enableDynamicClientConf);
}
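
  // A layout set only in the client-side Spark conf must take effect.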
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void testRssShuffleManagerClientConf(BlockIdLayout layout) throws Exception {
doTestRssShuffleManager(layout, null, layout, true);
}
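
  // A layout pushed only through the coordinator's dynamic client conf must take effect.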
@ParameterizedTest
@MethodSource("testBlockIdLayouts")
public void testRssShuffleManagerDynamicClientConf(BlockIdLayout layout) throws Exception {
doTestRssShuffleManager(null, layout, layout, true);
}
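
  // When both sides configure a layout, the client-side one (CUSTOM1) must win over the
  // dynamic one (CUSTOM2).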
@ParameterizedTest
@ValueSource(booleans = {false, true})
public void testRssShuffleManagerClientConfOverride(boolean enableDynamicClientConf)
throws Exception {
doTestRssShuffleManager(CUSTOM1, CUSTOM2, CUSTOM1, enableDynamicClientConf);
}
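
  /**
   * Runs a small shuffle job and asserts that the block ids reported back by the shuffle
   * servers decode correctly under {@code expectedLayout}.
   *
   * @param clientConfLayout layout to set in the client-side Spark conf, or null for none
   * @param dynamicConfLayout layout to push via the coordinator's dynamic client conf, or null
   * @param expectedLayout layout the written block ids are expected to follow
   * @param enableDynamicClientConf whether the client fetches dynamic client conf
   */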
private void doTestRssShuffleManager(
BlockIdLayout clientConfLayout,
BlockIdLayout dynamicConfLayout,
BlockIdLayout expectedLayout,
      boolean enableDynamicClientConf)
throws Exception {
Map<String, String> dynamicConf = startServers(dynamicConfLayout);
SparkConf conf = createSparkConf();
updateSparkConfWithRss(conf);
// enable stage recompute
conf.set("spark." + RssClientConfig.RSS_RESUBMIT_STAGE, "true");
// enable dynamic client conf
    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED, enableDynamicClientConf);
// configure storage type
conf.set("spark." + RssClientConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE.name());
    // restarting the coordinator may cause "RssException: There isn't enough shuffle servers",
    // so retry quickly (the default retry interval is 65s)
conf.set("spark." + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL, "1000");
conf.set("spark." + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES, "10");
// configure client conf block id layout (if set)
if (clientConfLayout != null) {
conf.set(
"spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
String.valueOf(clientConfLayout.sequenceNoBits));
conf.set(
"spark." + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
String.valueOf(clientConfLayout.partitionIdBits));
conf.set(
"spark." + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
String.valueOf(clientConfLayout.taskAttemptIdBits));
}
JavaSparkContext sc = null;
try {
Option<SparkSession> spark = SparkSession.getActiveSession();
if (spark.nonEmpty()) {
spark.get().stop();
}
sc = new JavaSparkContext(SparkSession.builder().config(conf).getOrCreate().sparkContext());
      // create an RDD that triggers shuffle registration
long count = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).groupBy(x -> x).count();
assertEquals(5, count);
      // configure the expected block id layout so the derived RssConf and the shuffle manager use it
conf.set(
"spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
String.valueOf(expectedLayout.sequenceNoBits));
conf.set(
"spark." + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
String.valueOf(expectedLayout.partitionIdBits));
conf.set(
"spark." + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
String.valueOf(expectedLayout.taskAttemptIdBits));
// get written block ids (we know there is one shuffle where two task attempts wrote two
// partitions)
RssConf rssConf = RssSparkConfig.toRssConf(conf);
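      // When the layout comes only from the dynamic client conf, mirror what the driver does at
      // startup and fold the fetched dynamic conf into the Spark conf before resolving the layout.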
if (clientConfLayout == null && dynamicConfLayout != null) {
RssSparkShuffleUtils.applyDynamicClientConf(conf, dynamicConf);
}
RssShuffleManagerBase shuffleManager =
(RssShuffleManagerBase) SparkEnv.get().shuffleManager();
shuffleManager.configureBlockIdLayout(conf, rssConf);
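
      // Build a bare write client purely to query the block ids reported to the shuffle servers;
      // it does not write any shuffle data itself.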
ShuffleWriteClient shuffleWriteClient =
ShuffleClientFactory.newWriteBuilder()
.clientType(ClientType.GRPC.name())
.retryMax(3)
.retryIntervalMax(2000)
.heartBeatThreadNum(4)
.replica(1)
.replicaWrite(1)
.replicaRead(1)
.replicaSkipEnabled(true)
.dataTransferPoolSize(1)
.dataCommitPoolSize(1)
.unregisterThreadPoolSize(10)
.unregisterRequestTimeSec(10)
.rssConf(rssConf)
.build();
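
      // Resolve the set of shuffle servers that the partitions of shuffle 0 were assigned to.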
ShuffleHandleInfo handle = shuffleManager.getShuffleHandleInfoByShuffleId(0);
Set<ShuffleServerInfo> servers =
handle.getPartitionToServers().values().stream()
.flatMap(Collection::stream)
.collect(Collectors.toSet());
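
      // Each of the two partitions must hold exactly one block per task attempt, and every block
      // id must decode to sequence number 0, the partition id, and the writer's task attempt id
      // under the expected layout.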
for (int partitionId : new int[] {0, 1}) {
Roaring64NavigableMap blockIdLongs =
shuffleWriteClient.getShuffleResult(
ClientType.GRPC.name(), servers, shuffleManager.getAppId(), 0, partitionId);
List<BlockId> blockIds =
blockIdLongs.stream()
.sorted()
.mapToObj(expectedLayout::asBlockId)
.collect(Collectors.toList());
assertEquals(2, blockIds.size());
long taskAttemptId0 = shuffleManager.getTaskAttemptIdForBlockId(0, 0);
long taskAttemptId1 = shuffleManager.getTaskAttemptIdForBlockId(1, 0);
assertEquals(expectedLayout.asBlockId(0, partitionId, taskAttemptId0), blockIds.get(0));
assertEquals(expectedLayout.asBlockId(0, partitionId, taskAttemptId1), blockIds.get(1));
}
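
      // With stage resubmission enabled, unregistering all map output must keep the shuffle
      // itself registered in Spark's MapOutputTracker so the stage can be recomputed.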
shuffleManager.unregisterAllMapOutput(0);
MapOutputTrackerMaster master = (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
assertTrue(master.containsShuffle(0));
} finally {
if (sc != null) {
sc.stop();
}
}
}
}