/*
* Tencent is pleased to support the open source community by making
* Firestorm-Spark remote shuffle server available.
*
* Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.tencent.rss.test;

import java.util.Comparator;
import java.util.List;
import java.util.Map;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.rss.coordinator.CoordinatorConf;
import com.tencent.rss.server.ShuffleServerConf;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.RssClientConfig;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.BeforeClass;
import org.junit.Test;
import scala.collection.JavaConverters;

import static org.junit.Assert.assertTrue;

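/**
 * Integration test that runs a Spark job with Adaptive Query Execution (AQE) enabled
 * on top of the remote shuffle service and checks that AQE partition coalescing still
 * takes effect: the 10 initial shuffle partitions should be merged into fewer partitions.
 */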
public class AQERepartitionTest extends SparkIntegrationTestBase {
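
  // Spin up one coordinator and one shuffle server for all tests in this class.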
  @BeforeClass
  public static void setupServers() throws Exception {
    CoordinatorConf coordinatorConf = getCoordinatorConf();
    createCoordinatorServer(coordinatorConf);
    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
    createShuffleServer(shuffleServerConf);
    startServers();
  }

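  // Enable AQE with partition coalescing: the job starts from 10 shuffle partitions
  // and AQE may merge small ones into fewer partitions after each shuffle.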
  @Override
  public void updateCommonSparkConf(SparkConf sparkConf) {
    sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "true");
    sparkConf.set(SQLConf.COALESCE_PARTITIONS_ENABLED().key(), "true");
    sparkConf.set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM().key(), "10");
    sparkConf.set(SQLConf.SHUFFLE_PARTITIONS().key(), "10");
  }

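  // Use HDFS as the remote shuffle storage backend, rooted at HDFS_URI + "rss/test".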
  @Override
  public void updateSparkConfCustomer(SparkConf sparkConf) {
    sparkConf.set(RssClientConfig.RSS_STORAGE_TYPE, "HDFS");
    sparkConf.set(RssClientConfig.RSS_BASE_PATH, HDFS_URI + "rss/test");
  }

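  // run() is provided by SparkIntegrationTestBase; it is expected to execute runTest()
  // with and without the remote shuffle service and compare the returned results.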
  @Test
  public void resultCompareTest() throws Exception {
    run();
  }

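  /**
   * Repartitions a small range by the "id" column and returns a map from partition index
   * to the sorted values observed in that partition.
   */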
  @Override
  Map runTest(SparkSession spark, String fileName) throws Exception {
    // Give the shuffle server a moment to register with the coordinator before running the job.
    Thread.sleep(4000);
    List<Column> repartitionCols = Lists.newArrayList();
    repartitionCols.add(new Column("id"));
    Dataset<Long> df = spark.range(10).repartition(
        JavaConverters.asScalaBuffer(repartitionCols).toList());
    // Collect the contents of every partition so we can see how the rows were distributed
    // after the repartition.
    Long[][] result = (Long[][]) df.rdd().collectPartitions();
    Map<Integer, List<Long>> map = Maps.newHashMap();
    for (int i = 0; i < result.length; i++) {
      map.putIfAbsent(i, Lists.newArrayList());
      for (int j = 0; j < result[i].length; j++) {
        map.get(i).add(result[i][j]);
      }
    }
    // Sort the values inside each partition so the result comparison is deterministic.
    for (int i = 0; i < result.length; i++) {
      map.get(i).sort(new Comparator<Long>() {
        @Override
        public int compare(Long o1, Long o2) {
          return Long.compare(o1, o2);
        }
      });
    }
    // With coalescing enabled, AQE should merge the 10 initial shuffle partitions into fewer.
    assertTrue(result.length < 10);
    return map;
  }
}