/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.test;

import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.TimeUnit;

import scala.Option;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Map;

import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.executor.TempShuffleReadMetrics;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.metrics.source.Source;
import org.apache.spark.resource.ResourceInformation;
import org.apache.spark.shuffle.FetchFailedException;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.AccumulatorV2;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.TaskFailureListener;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.rpc.ServerType;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
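
/**
 * Verifies reader creation through {@link RssShuffleManager}: remote storage info and the
 * per-storage conf items distributed via the coordinator's dynamic client conf should be
 * applied to the {@link RssShuffleReader}'s Hadoop configuration without leaking into the
 * SparkContext's Hadoop configuration.
 */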
public class GetReaderTest extends IntegrationTestBase {
@Test
public void test() throws Exception {
SparkConf sparkConf = new SparkConf();
sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager");
sparkConf.set(
"spark.shuffle.sort.io.plugin.class", "org.apache.spark.shuffle.RssShuffleDataIo");
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM);
sparkConf.setMaster("local[4]");
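// write a dynamic client conf file to HDFS: two candidate remote storages, extra conf items
// (k1, k2) bound to the h2 storage, and the client-side storage type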
final String remoteStorage1 = "hdfs://h1/p1";
final String remoteStorage2 = "hdfs://h2/p2";
String cfgFile = HDFS_URI + "/test/client_conf";
Path path = new Path(cfgFile);
FSDataOutputStream out = fs.create(path);
PrintWriter printWriter = new PrintWriter(new OutputStreamWriter(out));
printWriter.println(
CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key()
+ " "
+ String.join(Constants.COMMA_SPLIT_CHAR, remoteStorage1, remoteStorage2));
printWriter.println(
CoordinatorConf.COORDINATOR_REMOTE_STORAGE_CLUSTER_CONF.key() + " h2,k1=v1,k2=v2");
printWriter.println("spark.rss.storage.type " + StorageType.MEMORY_LOCALFILE_HDFS.name());
printWriter.flush();
printWriter.close();
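// enable dynamic client conf on the coordinator and shorten the update and storage-schedule
// intervals so the test can observe the second remote storage being handed out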
CoordinatorConf coordinatorConf = getCoordinatorConf();
coordinatorConf.setBoolean("rss.coordinator.dynamicClientConf.enabled", true);
coordinatorConf.setString("rss.coordinator.dynamicClientConf.path", cfgFile);
coordinatorConf.setInteger("rss.coordinator.dynamicClientConf.updateIntervalSec", 1);
coordinatorConf.setInteger("rss.coordinator.access.candidates.updateIntervalSec", 1);
coordinatorConf.setInteger("rss.coordinator.access.loadChecker.serverNum.threshold", 1);
coordinatorConf.setLong("rss.coordinator.remote.storage.schedule.time", 200);
coordinatorConf.setInteger("rss.coordinator.remote.storage.schedule.access.times", 1);
createCoordinatorServer(coordinatorConf);
ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
createShuffleServer(shuffleServerConf);
startServers();
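// give the coordinator time to pick up the dynamic client conf (refreshed every second)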
Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS);
SparkSession sparkSession = SparkSession.builder().config(sparkConf).getOrCreate();
JavaSparkContext jsc1 = new JavaSparkContext(sparkSession.sparkContext());
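// the first shuffle registered by this app should be assigned remoteStorage1, which has no
// extra conf items attached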
JavaPairRDD<String, Tuple2<Integer, Integer>> javaPairRDD1 =
TestUtils.combineByKeyRDD(TestUtils.getRDD(jsc1));
ShuffleDependency<?, ?, ?> shuffleDependency1 =
(ShuffleDependency<?, ?, ?>) javaPairRDD1.rdd().dependencies().head();
RssShuffleHandle<?, ?, ?> rssShuffleHandle1 =
(RssShuffleHandle<?, ?, ?>) shuffleDependency1.shuffleHandle();
RemoteStorageInfo remoteStorageInfo1 = rssShuffleHandle1.getRemoteStorage();
assertEquals(remoteStorage1, remoteStorageInfo1.getPath());
assertTrue(remoteStorageInfo1.getConfItems().isEmpty());
// emptyRDD case
JavaPairRDD<String, Tuple2<Integer, Integer>> javaEmptyPairRDD1 =
TestUtils.combineByKeyRDD(TestUtils.getEmptyRDD(jsc1));
ShuffleDependency<?, ?, ?> emptyShuffleDependency1 =
(ShuffleDependency<?, ?, ?>) javaEmptyPairRDD1.rdd().dependencies().head();
RssShuffleHandle<?, ?, ?> emptyRssShuffleHandle1 =
(RssShuffleHandle<?, ?, ?>) emptyShuffleDependency1.shuffleHandle();
assertEquals(0, javaEmptyPairRDD1.rdd().dependencies().head().rdd().getNumPartitions());
assertEquals(Collections.emptyMap(), emptyRssShuffleHandle1.getPartitionToServers());
assertEquals(RemoteStorageInfo.EMPTY_REMOTE_STORAGE, emptyRssShuffleHandle1.getRemoteStorage());
// the same app would get the same storage info
JavaPairRDD<String, Tuple2<Integer, Integer>> javaPairRDD2 =
TestUtils.combineByKeyRDD(TestUtils.getRDD(jsc1));
ShuffleDependency<?, ?, ?> shuffleDependency2 =
(ShuffleDependency<?, ?, ?>) javaPairRDD2.rdd().dependencies().head();
RssShuffleHandle<?, ?, ?> rssShuffleHandle2 =
(RssShuffleHandle<?, ?, ?>) shuffleDependency2.shuffleHandle();
RemoteStorageInfo remoteStorageInfo2 = rssShuffleHandle2.getRemoteStorage();
assertEquals(remoteStorage1, remoteStorageInfo2.getPath());
assertTrue(remoteStorageInfo2.getConfItems().isEmpty());
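// the reader built for this app must not see k1/k2, and the SparkContext's Hadoop conf
// should stay untouched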
RssShuffleManager rssShuffleManager =
(RssShuffleManager) sparkSession.sparkContext().env().shuffleManager();
RssShuffleHandle<?, ?, ?> rssShuffleHandle =
(RssShuffleHandle<?, ?, ?>) shuffleDependency2.shuffleHandle();
RssShuffleReader<?, ?> rssShuffleReader =
(RssShuffleReader<?, ?>)
rssShuffleManager.getReader(
rssShuffleHandle, 0, 0, new MockTaskContext(), new TempShuffleReadMetrics());
Configuration hadoopConf = rssShuffleReader.getHadoopConf();
assertNull(hadoopConf.get("k1"));
assertNull(hadoopConf.get("k2"));
Configuration commonHadoopConf = jsc1.hadoopConfiguration();
assertNull(commonHadoopConf.get("k1"));
assertNull(commonHadoopConf.get("k2"));
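// switch to a new app id: its shuffle should be assigned the h2 storage, which carries the
// k1/k2 conf items from the cluster conf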
rssShuffleManager = (RssShuffleManager) sparkSession.sparkContext().env().shuffleManager();
rssShuffleManager.setAppId("test2");
JavaSparkContext jsc2 = new JavaSparkContext(sparkSession.sparkContext());
JavaPairRDD<String, Tuple2<Integer, Integer>> javaPairRDD =
TestUtils.combineByKeyRDD(TestUtils.getRDD(jsc2));
ShuffleDependency<?, ?, ?> shuffleDependency =
(ShuffleDependency<?, ?, ?>) javaPairRDD.rdd().dependencies().head();
rssShuffleHandle = (RssShuffleHandle<?, ?, ?>) shuffleDependency.shuffleHandle();
// sleep briefly so that background threads can run before the handle's remote storage is read
Thread.sleep(500);
RemoteStorageInfo remoteStorageInfo3 = rssShuffleHandle.getRemoteStorage();
assertEquals(remoteStorage2, remoteStorageInfo3.getPath());
assertEquals(2, remoteStorageInfo3.getConfItems().size());
assertEquals("v1", remoteStorageInfo3.getConfItems().get("k1"));
assertEquals("v2", remoteStorageInfo3.getConfItems().get("k2"));
rssShuffleReader =
(RssShuffleReader<?, ?>)
rssShuffleManager.getReader(
rssShuffleHandle, 0, 0, new MockTaskContext(), new TempShuffleReadMetrics());
hadoopConf = rssShuffleReader.getHadoopConf();
assertEquals("v1", hadoopConf.get("k1"));
assertEquals("v2", hadoopConf.get("k2"));
// hadoop conf of reader and spark context should be isolated
commonHadoopConf = jsc2.hadoopConfiguration();
assertNull(commonHadoopConf.get("k1"));
assertNull(commonHadoopConf.get("k2"));
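// a second reader built from the same handle should see the same conf items, and the
// SparkContext's Hadoop conf should still be untouched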
rssShuffleReader =
(RssShuffleReader<?, ?>)
rssShuffleManager.getReader(
rssShuffleHandle, 0, 0, new MockTaskContext(), new TempShuffleReadMetrics());
hadoopConf = rssShuffleReader.getHadoopConf();
assertEquals("v1", hadoopConf.get("k1"));
assertEquals("v2", hadoopConf.get("k2"));
// hadoop conf of reader and spark context should be isolated
commonHadoopConf = jsc2.hadoopConfiguration();
assertNull(commonHadoopConf.get("k1"));
assertNull(commonHadoopConf.get("k2"));
sparkSession.close();
}
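
/**
 * A minimal {@link TaskContext} stub so that RssShuffleManager#getReader can be called
 * outside of a real Spark task; the overridden methods return inert defaults.
 */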
private static class MockTaskContext extends TaskContext {
@Override
public int attemptNumber() {
return 0;
}
@Override
public Option<FetchFailedException> fetchFailed() {
return null;
}
@Override
public int stageId() {
return 0;
}
@Override
public void markTaskFailed(Throwable error) {}
@Override
public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
return null;
}
@Override
public void markInterrupted(String reason) {}
@Override
public TaskMemoryManager taskMemoryManager() {
return null;
}
@Override
public void killTaskIfInterrupted() {}
@Override
public TaskMetrics taskMetrics() {
return null;
}
@Override
public Seq<Source> getMetricsSources(String sourceName) {
return null;
}
@Override
public boolean isInterrupted() {
return false;
}
@Override
public Map<String, ResourceInformation> resources() {
return null;
}
@Override
public TaskContext addTaskFailureListener(TaskFailureListener listener) {
return null;
}
@Override
public long taskAttemptId() {
return 0;
}
@Override
public int stageAttemptNumber() {
return 0;
}
@Override
public void registerAccumulator(AccumulatorV2<?, ?> a) {}
@Override
public void markTaskCompleted(Option<Throwable> error) {}
@Override
public String getLocalProperty(String key) {
return null;
}
public int cpus() {
return Runtime.getRuntime().availableProcessors();
}
@Override
public java.util.Map<String, ResourceInformation> resourcesJMap() {
return null;
}
@Override
public void setFetchFailed(FetchFailedException fetchFailed) {}
@Override
public boolean isCompleted() {
return false;
}
// isFailed() is only available starting with Spark 3.5.1; @Override is omitted so this class
// still compiles against earlier Spark versions.
public boolean isFailed() {
return false;
}
@Override
public Properties getLocalProperties() {
return null;
}
@Override
public int partitionId() {
return 0;
}
public int numPartitions() {
return 0;
}
@Override
public Option<String> getKillReason() {
return null;
}
}
}