blob: 66b2c9a44b9924bf3f75c207be2b6dc751c01fea [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.storage.util.StorageType;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link RssShuffleManager} covering data-distribution-type resolution,
 * shuffle-manager server creation, basic interface behavior, and shuffle registration
 * limits imposed by the configured block id layout.
 */
public class RssShuffleManagerTest extends RssShuffleManagerTestBase {

  private static final String SPARK_ADAPTIVE_EXECUTION_ENABLED_KEY = "spark.sql.adaptive.enabled";

  /**
   * The distribution type is LOCAL_ORDER when AQE is enabled, the config default when it is
   * disabled, and an explicit {@code spark.rss.*} setting always wins over the AQE-derived value.
   */
  @Test
  public void testGetDataDistributionType() {
    // case1: AQE enabled, no explicit type -> LOCAL_ORDER
    SparkConf sparkConf = new SparkConf();
    sparkConf.set(SPARK_ADAPTIVE_EXECUTION_ENABLED_KEY, "true");
    assertEquals(
        ShuffleDataDistributionType.LOCAL_ORDER,
        RssShuffleManager.getDataDistributionType(sparkConf));

    // case2: AQE disabled, no explicit type -> config default
    sparkConf = new SparkConf();
    sparkConf.set(SPARK_ADAPTIVE_EXECUTION_ENABLED_KEY, "false");
    assertEquals(
        RssClientConf.DATA_DISTRIBUTION_TYPE.defaultValue(),
        RssShuffleManager.getDataDistributionType(sparkConf));

    // case3: explicit NORMAL overrides the AQE-derived LOCAL_ORDER
    sparkConf = new SparkConf();
    sparkConf.set(SPARK_ADAPTIVE_EXECUTION_ENABLED_KEY, "true");
    sparkConf.set(
        "spark." + RssClientConf.DATA_DISTRIBUTION_TYPE.key(),
        ShuffleDataDistributionType.NORMAL.name());
    assertEquals(
        ShuffleDataDistributionType.NORMAL, RssShuffleManager.getDataDistributionType(sparkConf));

    // case4: explicit LOCAL_ORDER agrees with the AQE-derived value
    sparkConf = new SparkConf();
    sparkConf.set(SPARK_ADAPTIVE_EXECUTION_ENABLED_KEY, "true");
    sparkConf.set(
        "spark." + RssClientConf.DATA_DISTRIBUTION_TYPE.key(),
        ShuffleDataDistributionType.LOCAL_ORDER.name());
    assertEquals(
        ShuffleDataDistributionType.LOCAL_ORDER,
        RssShuffleManager.getDataDistributionType(sparkConf));

    // case5: nothing set explicitly -> follows whatever Spark's AQE default is, so branch
    // on the default to keep the test valid across Spark versions with different defaults
    sparkConf = new SparkConf();
    boolean aqeEnable = (boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED());
    if (aqeEnable) {
      assertEquals(
          ShuffleDataDistributionType.LOCAL_ORDER,
          RssShuffleManager.getDataDistributionType(sparkConf));
    } else {
      assertEquals(
          RssClientConf.DATA_DISTRIBUTION_TYPE.defaultValue(),
          RssShuffleManager.getDataDistributionType(sparkConf));
    }
  }

  /**
   * Constructing a driver-side manager with stage recompute enabled should start the shuffle
   * manager gRPC server and publish its (positive) port into the Spark conf.
   */
  @Test
  public void testCreateShuffleManagerServer() {
    setupMockedRssShuffleUtils(StatusCode.SUCCESS);

    SparkConf conf = new SparkConf();
    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
    conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002");
    conf.set("spark.driver.host", "localhost");
    conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name());
    conf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE, true);
    // enable stage recompute so the constructor starts the grpc server as a side effect
    conf.set("spark." + RssClientConfig.RSS_RESUBMIT_STAGE, "true");

    // the constructor's side effect is what is under test; the instance itself is not needed
    new RssShuffleManager(conf, true);
    assertTrue(conf.get(RSS_SHUFFLE_MANAGER_GRPC_PORT) > 0);
  }

  /**
   * Sanity-checks the ShuffleManager interface surface: max fetch failures is derived from
   * {@code spark.task.maxFailures - 1}, and the app id is null until registration.
   */
  @Test
  public void testRssShuffleManagerInterface() throws Exception {
    setupMockedRssShuffleUtils(StatusCode.SUCCESS);

    SparkConf conf = new SparkConf();
    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
    conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002");
    conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name());
    conf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE, true);
    conf.set("spark.task.maxFailures", "3");

    RssShuffleManager shuffleManager = new RssShuffleManager(conf, true);
    // JUnit convention: expected value first, actual second
    assertEquals(2, shuffleManager.getMaxFetchFailures());
    // by default, the appId is null
    assertNull(shuffleManager.getAppId());
  }

  /**
   * Registering a shuffle whose partition count exceeds what the configured block id layout can
   * address must fail with an {@link RssException} carrying a descriptive message.
   *
   * @param partitionIdBits number of bits reserved for the partition id in the block id layout
   */
  @SuppressWarnings("unchecked") // mock(ShuffleDependency.class) cannot carry generic type info
  @ParameterizedTest
  @ValueSource(ints = {16, 20, 24})
  public void testRssShuffleManagerRegisterShuffle(int partitionIdBits) {
    // sequenceNo gets the leftover bits; taskAttemptId gets two extra bits for attempt numbers
    BlockIdLayout layout =
        BlockIdLayout.from(
            63 - partitionIdBits - partitionIdBits - 2, partitionIdBits, partitionIdBits + 2);

    SparkConf conf = new SparkConf();
    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false");
    conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002");
    conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name());
    conf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE, true);
    conf.set("spark.task.maxFailures", "4");
    conf.set(
        RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
        String.valueOf(layout.sequenceNoBits));
    conf.set(
        RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
        String.valueOf(layout.partitionIdBits));
    conf.set(
        RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(),
        String.valueOf(layout.taskAttemptIdBits));

    // register a shuffle with too many partitions should fail
    Partitioner mockPartitioner = mock(Partitioner.class);
    when(mockPartitioner.numPartitions()).thenReturn(layout.maxNumPartitions + 1);
    ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
    when(mockDependency.partitioner()).thenReturn(mockPartitioner);

    RssShuffleManager shuffleManager = new RssShuffleManager(conf, true);
    RssException e =
        assertThrowsExactly(
            RssException.class, () -> shuffleManager.registerShuffle(0, mockDependency));
    assertEquals(
        "Cannot register shuffle with "
            + (layout.maxNumPartitions + 1)
            + " partitions because the configured block id layout supports at most "
            + layout.maxNumPartitions
            + " partitions.",
        e.getMessage());
  }
}