/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import scala.Option;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
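/**
 * A {@link ShuffleManager} implementation for Spark 2.x that delegates shuffle data
 * storage to remote Uniffle shuffle servers instead of executor-local disks.
 *
 * A minimal sketch of enabling it in a job (host names below are illustrative):
 * <pre>
 *   spark.shuffle.manager=org.apache.spark.shuffle.RssShuffleManager
 *   spark.rss.coordinator.quorum=coordinator1:19999,coordinator2:19999
 * </pre>
 */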
public class RssShuffleManager extends RssShuffleManagerBase {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManager.class);
private final long heartbeatInterval;
private final long heartbeatTimeout;
private ScheduledExecutorService heartBeatScheduledExecutorService;
private SparkConf sparkConf;
private String appId = "";
private String clientType;
private ShuffleWriteClient shuffleWriteClient;
private Map<String, Set<Long>> taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
private Map<String, Set<Long>> taskToFailedBlockIds = JavaUtils.newConcurrentMap();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
private final boolean dataReplicaSkipEnabled;
private final int dataTransferPoolSize;
private final int dataCommitPoolSize;
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
private final String user;
private final String uuid;
private DataPusher dataPusher;
private final int maxConcurrencyPerPartitionToWrite;
private final Map<Integer, Integer> shuffleIdToPartitionNum = Maps.newConcurrentMap();
private final Map<Integer, Integer> shuffleIdToNumMapTasks = Maps.newConcurrentMap();
private GrpcServer shuffleManagerServer;
private ShuffleManagerGrpcService service;
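  /**
   * Constructs the shuffle manager on both the driver and the executors. On the driver
   * it may fetch dynamic client conf from the coordinator and, when stage resubmission
   * is enabled, start a shuffle manager gRPC server. The external shuffle service and
   * reduce-side locality are disabled everywhere, since shuffle data lives on remote
   * shuffle servers rather than on executor-local disks.
   */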
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
throw new IllegalArgumentException("Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false.");
}
this.sparkConf = sparkConf;
this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
// set & check replica config
this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
this.maxConcurrencyPerPartitionToWrite =
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
LOG.info("Check quorum config ["
+ dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + ":" + dataReplicaSkipEnabled + "]");
RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);
this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
this.heartbeatTimeout = sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2);
this.dynamicConfEnabled = sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
int unregisterThreadPoolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
int unregisterRequestTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
this.shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize,
dataCommitPoolSize, unregisterThreadPoolSize, unregisterRequestTimeoutSec, rssConf);
registerCoordinator();
    // fetch client conf from the coordinator and apply it when dynamic conf is enabled
if (isDriver && dynamicConfEnabled) {
Map<String, String> clusterClientConf = shuffleWriteClient.fetchClientConf(
sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS));
RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, clusterClientConf);
}
RssSparkShuffleUtils.validateRssClientConf(sparkConf);
// External shuffle service is not supported when using remote shuffle service
sparkConf.set("spark.shuffle.service.enabled", "false");
LOG.info("Disable external shuffle service in RssShuffleManager.");
// If we store shuffle data in distributed filesystem or in a disaggregated
// shuffle cluster, we don't need shuffle data locality
sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
LOG.info("Disable shuffle data locality in RssShuffleManager.");
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
LOG.info("stage resubmit is supported and enabled");
// start shuffle manager server
rssConf.set(RPC_SERVER_PORT, 0);
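          // port 0 lets the gRPC server bind an ephemeral port; the bound port is read back below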
ShuffleManagerServerFactory factory = new ShuffleManagerServerFactory(this, rssConf);
service = factory.getService();
shuffleManagerServer = factory.getServer(service);
try {
shuffleManagerServer.start();
            // expose the bound port as spark.rss.shuffle.manager.grpc.port so it is propagated to executors properly
sparkConf.set(RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, shuffleManagerServer.getPort());
} catch (Exception e) {
LOG.error("Failed to start shuffle manager server", e);
throw new RssException(e);
}
}
}
      // on executors, start the data pusher threads that send shuffle data to the shuffle servers
LOG.info("RSS data pusher is starting...");
int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
this.dataPusher = new DataPusher(
shuffleWriteClient,
taskToSuccessBlockIds,
taskToFailedBlockIds,
failedTaskIds,
poolSize,
keepAliveTime
);
}
}
  // This method is called on the Spark driver side. The driver makes decisions based
  // on the coordinator's response, e.g. which RSS servers to use, then returns a
  // ShuffleHandle that is passed to the executors (see getWriter/getReader).
@Override
public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, ShuffleDependency<K, V, C> dependency) {
    // Spark ships three kinds of serializers:
    // org.apache.spark.serializer.JavaSerializer,
    // org.apache.spark.sql.execution.UnsafeRowSerializer,
    // org.apache.spark.serializer.KryoSerializer.
    // Only org.apache.spark.serializer.JavaSerializer does not support relocation of
    // serialized objects, so we throw an exception when such a serializer is configured.
if (!SparkEnv.get().serializer().supportsRelocationOfSerializedObjects()) {
throw new IllegalArgumentException("Can't use serialized shuffle for shuffleId: " + shuffleId + ", because the"
+ " serializer: " + SparkEnv.get().serializer().getClass().getName() + " does not support object "
+ "relocation.");
}
    // If YARN is allowed to retry the ApplicationMaster, appId is not unique and shuffle
    // data would be incorrect; appId + uuid avoids that problem. The appId can't be read
    // in the constructor because SparkEnv is not created yet, so it is initialized only
    // once here, even though this method is called once per shuffle stage.
if ("".equals(appId)) {
appId = SparkEnv.get().conf().getAppId() + "_" + uuid;
dataPusher.setRssAppId(appId);
LOG.info("Generate application id used in rss: " + appId);
}
if (dependency.partitioner().numPartitions() == 0) {
shuffleIdToPartitionNum.putIfAbsent(shuffleId, 0);
shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length);
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum is 0, "
+ "return the empty RssShuffleHandle directly");
Broadcast<ShuffleHandleInfo> hdlInfoBd = RssSparkShuffleUtils.broadcastShuffleHdlInfo(
RssSparkShuffleUtils.getActiveSparkContext(), shuffleId, Collections.emptyMap(),
RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
return new RssShuffleHandle<>(shuffleId,
appId,
dependency.rdd().getNumPartitions(),
dependency,
hdlInfoBd);
}
String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
RemoteStorageInfo defaultRemoteStorage = new RemoteStorageInfo(
sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
RemoteStorageInfo remoteStorage = ClientUtils.fetchRemoteStorage(
appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
int partitionNumPerRange = sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
// get all register info according to coordinator's response
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);
int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
    // retryInterval must be bigger than `rss.server.heartbeat.interval`, otherwise the coordinator may return the same assignment
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
Map<Integer, List<ShuffleServerInfo>> partitionToServers;
try {
partitionToServers = RetryUtils.retry(() -> {
ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
appId, shuffleId, dependency.partitioner().numPartitions(),
partitionNumPerRange, assignmentTags, requiredShuffleServerNumber, -1);
registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage);
return response.getPartitionToServers();
}, retryInterval, retryTimes);
} catch (Throwable throwable) {
throw new RssException("registerShuffle failed!", throwable);
}
startHeartbeat();
shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions());
shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length);
Broadcast<ShuffleHandleInfo> hdlInfoBd = RssSparkShuffleUtils.broadcastShuffleHdlInfo(
RssSparkShuffleUtils.getActiveSparkContext(), shuffleId, partitionToServers, remoteStorage);
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size() + "]");
return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, hdlInfoBd);
}
private void startHeartbeat() {
shuffleWriteClient.registerApplicationInfo(appId, heartbeatTimeout, user);
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && !heartbeatStarted) {
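      // the first heartbeat fires after half an interval, then repeats every heartbeatInterval ms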
heartBeatScheduledExecutorService.scheduleAtFixedRate(
() -> {
try {
shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
LOG.info("Finish send heartbeat to coordinator and servers");
} catch (Exception e) {
LOG.warn("Fail to send heartbeat to coordinator and servers", e);
}
},
heartbeatInterval / 2,
heartbeatInterval,
TimeUnit.MILLISECONDS);
heartbeatStarted = true;
}
}
@VisibleForTesting
protected void registerShuffleServers(
String appId,
int shuffleId,
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
RemoteStorageInfo remoteStorage) {
if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
return;
}
LOG.info("Start to register shuffleId[" + shuffleId + "]");
long start = System.currentTimeMillis();
    serverToPartitionRanges.forEach((server, ranges) ->
        shuffleWriteClient.registerShuffle(
            server,
            appId,
            shuffleId,
            ranges,
            remoteStorage,
            ShuffleDataDistributionType.NORMAL,
            maxConcurrencyPerPartitionToWrite
        )
    );
LOG.info("Finish register shuffleId[" + shuffleId + "] with " + (System.currentTimeMillis() - start) + " ms");
}
@VisibleForTesting
protected void registerCoordinator() {
String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
LOG.info("Registering coordinators {}", coordinators);
shuffleWriteClient.registerCoordinators(coordinators);
}
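  /**
   * Hands the event to the data pusher and returns the pusher's completion future.
   * If the pusher or the event is null, a placeholder future that is never
   * completed is returned instead.
   */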
public CompletableFuture<Long> sendData(AddBlockEvent event) {
if (dataPusher != null && event != null) {
return dataPusher.send(event);
}
return new CompletableFuture<>();
}
// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
public <K, V> ShuffleWriter<K, V> getWriter(ShuffleHandle handle, int mapId,
TaskContext context) {
if (handle instanceof RssShuffleHandle) {
RssShuffleHandle<K, V, ?> rssHandle = (RssShuffleHandle<K, V, ?>) handle;
appId = rssHandle.getAppId();
dataPusher.setRssAppId(appId);
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
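      // the buffer manager accumulates serialized records per partition and hands
      // full buffers to the data pusher via this::sendData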
WriteBufferManager bufferManager = new WriteBufferManager(
shuffleId,
taskId,
context.taskAttemptId(),
bufferOptions,
rssHandle.getDependency().serializer(),
rssHandle.getPartitionToServers(),
context.taskMemoryManager(),
writeMetrics,
RssSparkConfig.toRssConf(sparkConf),
this::sendData
);
return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(), bufferManager,
writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
(Function<String, Boolean>) this::markFailedTask);
} else {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
}
// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
public <K, C> ShuffleReader<K, C> getReader(ShuffleHandle handle,
int startPartition, int endPartition, TaskContext context) {
if (handle instanceof RssShuffleHandle) {
RssShuffleHandle<K, C, ?> rssShuffleHandle = (RssShuffleHandle<K, C, ?>) handle;
final int partitionNumPerRange = sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
long start = System.currentTimeMillis();
Roaring64NavigableMap taskIdBitmap = getExpectedTasks(shuffleId, startPartition, endPartition);
LOG.info("Get taskId cost " + (System.currentTimeMillis() - start) + " ms, and request expected blockIds from "
+ taskIdBitmap.getLongCardinality() + " tasks for shuffleId[" + shuffleId + "], partitionId["
+ startPartition + "]");
start = System.currentTimeMillis();
Map<Integer, List<ShuffleServerInfo>> partitionToServers = rssShuffleHandle.getPartitionToServers();
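      // ask the shuffle servers assigned to startPartition which block ids were written to it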
Roaring64NavigableMap blockIdBitmap = shuffleWriteClient.getShuffleResult(
clientType, Sets.newHashSet(partitionToServers.get(startPartition)),
rssShuffleHandle.getAppId(), shuffleId, startPartition);
LOG.info("Get shuffle blockId cost " + (System.currentTimeMillis() - start) + " ms, and get "
+ blockIdBitmap.getLongCardinality() + " blockIds for shuffleId[" + shuffleId + "], partitionId["
+ startPartition + "]");
final RemoteStorageInfo shuffleRemoteStorageInfo = rssShuffleHandle.getRemoteStorage();
LOG.info("Shuffle reader using remote storage {}", shuffleRemoteStorageInfo);
final String shuffleRemoteStoragePath = shuffleRemoteStorageInfo.getPath();
Configuration readerHadoopConf = RssSparkShuffleUtils.getRemoteStorageHadoopConf(
sparkConf, shuffleRemoteStorageInfo);
return new RssShuffleReader<K, C>(
startPartition, endPartition, context,
rssShuffleHandle, shuffleRemoteStoragePath,
readerHadoopConf, partitionNumPerRange, partitionNum,
blockIdBitmap, taskIdBitmap, RssSparkConfig.toRssConf(sparkConf));
} else {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
}
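  // Range-based reads (startMapId/endMapId) are not implemented by this Spark 2 client.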
public <K, C> ShuffleReader<K, C> getReader(ShuffleHandle handle, int startPartition,
int endPartition, TaskContext context, int startMapId, int endMapId) {
return null;
}
@Override
public boolean unregisterShuffle(int shuffleId) {
try {
if (SparkEnv.get().executorId().equals("driver")) {
shuffleWriteClient.unregisterShuffle(appId, shuffleId);
shuffleIdToNumMapTasks.remove(shuffleId);
shuffleIdToPartitionNum.remove(shuffleId);
if (service != null) {
service.unregisterShuffle(shuffleId);
}
}
} catch (Exception e) {
LOG.warn("Errors on unregister to remote shuffle-servers", e);
}
return true;
}
@Override
public void stop() {
if (heartBeatScheduledExecutorService != null) {
heartBeatScheduledExecutorService.shutdownNow();
}
if (dataPusher != null) {
try {
dataPusher.close();
} catch (IOException e) {
LOG.warn("Errors on closing data pusher", e);
}
}
shuffleWriteClient.close();
}
@Override
public ShuffleBlockResolver shuffleBlockResolver() {
throw new RssException("RssShuffleManager.shuffleBlockResolver is not implemented");
}
  // When speculation is enabled, duplicate data is sent and reported to the shuffle
  // server; fetch the actual tasks and filter out duplicates from speculative attempts.
private Roaring64NavigableMap getExpectedTasks(int shuffleId, int startPartition, int endPartition) {
Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf();
    // In Spark 2.3 getMapSizesByExecutorId returns a Seq, while in 2.4 it returns an
    // Iterator, so we call toIterator() to support both Spark 2.3 and 2.4.
Iterator<Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>>> mapStatusIter =
SparkEnv.get().mapOutputTracker().getMapSizesByExecutorId(shuffleId, startPartition, endPartition)
.toIterator();
while (mapStatusIter.hasNext()) {
Tuple2<BlockManagerId, Seq<Tuple2<BlockId, Object>>> tuple2 = mapStatusIter.next();
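      // the RSS writer encodes the map task attempt id in BlockManagerId.topologyInfo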
Option<String> topologyInfo = tuple2._1().topologyInfo();
if (topologyInfo.isDefined()) {
taskIdBitmap.addLong(Long.parseLong(tuple2._1().topologyInfo().get()));
} else {
throw new RssException("Can't get expected taskAttemptId");
}
}
LOG.info("Got result from MapStatus for expected tasks " + taskIdBitmap.getLongCardinality());
return taskIdBitmap;
}
public Set<Long> getFailedBlockIds(String taskId) {
Set<Long> result = taskToFailedBlockIds.get(taskId);
if (result == null) {
result = Sets.newHashSet();
}
return result;
}
public Set<Long> getSuccessBlockIds(String taskId) {
Set<Long> result = taskToSuccessBlockIds.get(taskId);
if (result == null) {
result = Sets.newHashSet();
}
return result;
}
  @VisibleForTesting
  public void addFailedBlockIds(String taskId, Set<Long> blockIds) {
    // atomically create the per-task set if absent, then record the failed block ids
    taskToFailedBlockIds.computeIfAbsent(taskId, id -> Sets.newHashSet()).addAll(blockIds);
  }
  @VisibleForTesting
  public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
    // atomically create the per-task set if absent, then record the successful block ids
    taskToSuccessBlockIds.computeIfAbsent(taskId, id -> Sets.newHashSet()).addAll(blockIds);
  }
public void clearTaskMeta(String taskId) {
taskToSuccessBlockIds.remove(taskId);
taskToFailedBlockIds.remove(taskId);
}
@VisibleForTesting
public SparkConf getSparkConf() {
return sparkConf;
}
@VisibleForTesting
public void setAppId(String appId) {
this.appId = appId;
}
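  // Marks a task attempt as failed; isValidTask is consulted by callers such as the
  // data pusher so that data from failed or superseded speculative attempts is ignored.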
public boolean markFailedTask(String taskId) {
LOG.info("Mark the task: {} failed.", taskId);
failedTaskIds.add(taskId);
return true;
}
public boolean isValidTask(String taskId) {
return !failedTaskIds.contains(taskId);
}
public DataPusher getDataPusher() {
return dataPusher;
}
public void setDataPusher(DataPusher dataPusher) {
this.dataPusher = dataPusher;
}
  /**
   * @return the unique application id used by RSS for this Spark application
   */
@Override
public String getAppId() {
return appId;
}
/**
* @return the maximum number of fetch failures per shuffle partition before that shuffle stage should be recomputed
*/
@Override
public int getMaxFetchFailures() {
final String TASK_MAX_FAILURE = "spark.task.maxFailures";
return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1);
}
  /**
   * @param shuffleId the shuffle id to query
   * @return the number of partitions (i.e. reduce tasks) for the shuffle with the given id
   */
@Override
public int getPartitionNum(int shuffleId) {
return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
}
  /**
   * @param shuffleId the shuffle id to query
   * @return the number of map tasks for the shuffle with the given id
   */
@Override
public int getNumMaps(int shuffleId) {
return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
}
}