/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.test;

import java.io.File;
import java.nio.file.Files;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec;
import org.apache.spark.sql.execution.joins.SortMergeJoinExec;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.MockedGrpcServer;
import org.apache.uniffle.server.MockedShuffleServerGrpcService;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
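
/**
 * Integration test that runs a skewed join under Spark adaptive query execution against mocked
 * shuffle servers with multi-replica reads and writes, then verifies that the servers received
 * exactly the expected number of getShuffleResultForMultiPart requests.
 */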
public class GetShuffleReportForMultiPartTest extends SparkIntegrationTestBase {
  // Shuffle data is replicated to three servers (all three writes must succeed); shuffle
  // results are fetched from two replicas.
  private static final int replicateWrite = 3;
  private static final int replicateRead = 2;

@BeforeAll
public static void setupServers() throws Exception {
CoordinatorConf coordinatorConf = getCoordinatorConf();
Map<String, String> dynamicConf = Maps.newHashMap();
dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
addDynamicConf(coordinatorConf, dynamicConf);
createCoordinatorServer(coordinatorConf);
    // Create multiple shuffle servers
createShuffleServers();
startServers();
}

  private static void createShuffleServers() throws Exception {
for (int i = 0; i < 4; i++) {
      // Copied from IntegrationTestBase#getShuffleServerConf
File dataFolder = Files.createTempDirectory("rssdata" + i).toFile();
ShuffleServerConf serverConf = new ShuffleServerConf();
dataFolder.deleteOnExit();
serverConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + i);
serverConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE_HDFS.name());
serverConf.setString("rss.storage.basePath", dataFolder.getAbsolutePath());
serverConf.setString("rss.server.buffer.capacity", "671088640");
serverConf.setString("rss.server.memory.shuffle.highWaterMark", "50.0");
serverConf.setString("rss.server.memory.shuffle.lowWaterMark", "0.0");
serverConf.setString("rss.server.read.buffer.capacity", "335544320");
serverConf.setString("rss.coordinator.quorum", COORDINATOR_QUORUM);
serverConf.setString("rss.server.heartbeat.delay", "1000");
serverConf.setString("rss.server.heartbeat.interval", "1000");
serverConf.setInteger("rss.jetty.http.port", 18080 + i);
serverConf.setInteger("rss.jetty.corePool.size", 64);
serverConf.setInteger("rss.rpc.executor.size", 10);
serverConf.setString("rss.server.hadoop.dfs.replication", "2");
serverConf.setLong("rss.server.disk.capacity", 10L * 1024L * 1024L * 1024L);
serverConf.setBoolean("rss.server.health.check.enable", false);
createMockedShuffleServer(serverConf);
}
enableRecordGetShuffleResult();
}
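
  // Have every mocked gRPC server record the getShuffleResult requests it receives so that
  // validateRequestCount can assert on the per-partition request counts afterwards.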
private static void enableRecordGetShuffleResult() {
for (ShuffleServer shuffleServer : shuffleServers) {
((MockedGrpcServer) shuffleServer.getServer()).getService()
.enableRecordGetShuffleResult();
}
}
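
  // Enable AQE and set deliberately small skew and advisory-partition thresholds so the skewed
  // key is split into many partition ranges, fanning out the shuffle-result requests.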
@Override
public void updateCommonSparkConf(SparkConf sparkConf) {
sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "true");
sparkConf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), "-1");
sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM().key(), "1");
sparkConf.set(SQLConf.SHUFFLE_PARTITIONS().key(), "100");
sparkConf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD().key(), "800");
sparkConf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "800");
}

  @Override
public void updateSparkConfCustomer(SparkConf sparkConf) {
sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "HDFS");
sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
}

  @Override
public void updateSparkConfWithRss(SparkConf sparkConf) {
super.updateSparkConfWithRss(sparkConf);
    // Configure multi-replica writes and reads
sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA.key(), String.valueOf(replicateWrite));
sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key(), String.valueOf(replicateWrite));
sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_READ.key(), String.valueOf(replicateRead));
sparkConf.set("spark.shuffle.manager",
"org.apache.uniffle.test.GetShuffleReportForMultiPartTest$RssShuffleManagerWrapper");
}
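
  // run() (inherited from SparkIntegrationTestBase) executes runTest() both with and without
  // Uniffle enabled and asserts that the two result maps are identical.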
@Test
public void resultCompareTest() throws Exception {
run();
}

  @Override
Map runTest(SparkSession spark, String fileName) throws Exception {
    // Give the shuffle servers time to register with the coordinator (the server heartbeat
    // interval is configured to 1s above) before running the job.
    Thread.sleep(4000);
Map<Integer, String> map = Maps.newHashMap();
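    // Both datasets collapse every id below 250 into the single key 249, creating one heavily
    // skewed join key; df1 additionally maps ids above 750 to 1000, so those rows never match.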
Dataset<Row> df2 = spark.range(0, 1000, 1, 10)
.select(functions.when(functions.col("id").$less(250), 249)
.otherwise(functions.col("id")).as("key2"), functions.col("id").as("value2"));
Dataset<Row> df1 = spark.range(0, 1000, 1, 10)
.select(functions.when(functions.col("id").$less(250), 249)
.when(functions.col("id").$greater(750), 1000)
        .otherwise(functions.col("id")).as("key1"), functions.col("id").as("value1"));
Dataset<Row> df3 = df1.join(df2, df1.col("key1").equalTo(df2.col("key2")));
List<String> result = Lists.newArrayList();
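    // No action has run yet, so AQE still reports a non-final adaptive plan.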
assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=false"));
df3.collectAsList().forEach(row -> {
result.add(row.json());
});
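    // collect() forced execution, so AQE has finalized the plan and, given the tiny skew
    // thresholds above, the sort-merge join should have been rewritten as a skew join.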
assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=true"));
AdaptiveSparkPlanExec plan = (AdaptiveSparkPlanExec) df3.queryExecution().executedPlan();
SortMergeJoinExec joinExec = (SortMergeJoinExec) plan.executedPlan().children().iterator().next();
assertTrue(joinExec.isSkewJoin());
    result.sort(Comparator.naturalOrder());
int i = 0;
for (String str : result) {
map.put(i, str);
i++;
}
SparkConf conf = spark.sparkContext().conf();
if (conf.get("spark.shuffle.manager", "")
.equals("org.apache.uniffle.test.GetShuffleReportForMultiPartTest$RssShuffleManagerWrapper")) {
RssShuffleManagerWrapper mockRssShuffleManager =
(RssShuffleManagerWrapper) spark.sparkContext().env().shuffleManager();
int expectRequestNum = mockRssShuffleManager.getShuffleIdToPartitionNum().values().stream()
.mapToInt(x -> x.get()).sum();
      // Each partition read through the wrapper triggers one getShuffleResultForMultiPart
      // request per read replica; validate the observed request count before returning.
validateRequestCount(spark.sparkContext().applicationId(), expectRequestNum * replicateRead);
}
return map;
}
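
  /**
   * Sums the getShuffleResult requests recorded by all mocked shuffle servers for the given
   * application and asserts that the total equals the expected request count.
   */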
public void validateRequestCount(String appId, int expectRequestNum) {
for (ShuffleServer shuffleServer : shuffleServers) {
MockedShuffleServerGrpcService service = ((MockedGrpcServer) shuffleServer.getServer()).getService();
Map<String, Map<Integer, AtomicInteger>> serviceRequestCount = service.getShuffleIdToPartitionRequest();
int requestNum = serviceRequestCount.entrySet().stream().filter(x -> x.getKey().startsWith(appId))
.flatMap(x -> x.getValue().values().stream()).mapToInt(AtomicInteger::get).sum();
expectRequestNum -= requestNum;
}
assertEquals(0, expectRequestNum);
}
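
  /**
   * An {@link RssShuffleManager} that additionally records, per shuffle, how many partitions
   * readers request, so the test can compute the expected number of
   * getShuffleResultForMultiPart calls.
   */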
public static class RssShuffleManagerWrapper extends RssShuffleManager {
// shuffleId -> partShouldRequestNum
Map<Integer, AtomicInteger> shuffleToPartShouldRequestNum = Maps.newConcurrentMap();

    public RssShuffleManagerWrapper(SparkConf conf, boolean isDriver) {
super(conf, isDriver);
}

    @Override
public <K, C> ShuffleReader<K, C> getReaderImpl(
ShuffleHandle handle,
int startMapIndex,
int endMapIndex,
int startPartition,
int endPartition,
TaskContext context,
ShuffleReadMetricsReporter metrics,
Roaring64NavigableMap taskIdBitmap) {
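      // Count the partitions of this handle that fall into [startPartition, endPartition) and
      // accumulate the count per shuffle; the test multiplies the total by the read replica
      // count to obtain the expected number of getShuffleResultForMultiPart requests.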
int shuffleId = handle.shuffleId();
RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) handle;
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers = rssShuffleHandle.getPartitionToServers();
int partitionNum = (int) allPartitionToServers.entrySet().stream()
.filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition).count();
AtomicInteger partShouldRequestNum = shuffleToPartShouldRequestNum.computeIfAbsent(shuffleId,
x -> new AtomicInteger(0));
partShouldRequestNum.addAndGet(partitionNum);
return super.getReaderImpl(handle, startMapIndex, endMapIndex, startPartition, endPartition,
context, metrics, taskIdBitmap);
}

    public Map<Integer, AtomicInteger> getShuffleIdToPartitionNum() {
return shuffleToPartShouldRequestNum;
}
}
}