/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.writer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import scala.Function1;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.FetchFailedException;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.storage.util.StorageType;
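/**
 * Shuffle writer that buffers records per partition with a WriteBufferManager and sends them as
 * blocks to remote shuffle servers through the Uniffle shuffle write client.
 */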
public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
private static final String DUMMY_HOST = "dummy_host";
private static final int DUMMY_PORT = 99999;
// they will be used in commit phase
private final Set<ShuffleServerInfo> shuffleServersForData;
// server -> partitionId -> blockIds
private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds;
private final ShuffleWriteClient shuffleWriteClient;
private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private String appId;
private int numMaps;
private int shuffleId;
private int bitmapSplitNum;
private String taskId;
private long taskAttemptId;
private ShuffleDependency<K, V, C> shuffleDependency;
private ShuffleWriteMetrics shuffleWriteMetrics;
private Partitioner partitioner;
private boolean shouldPartition;
private WriteBufferManager bufferManager;
private RssShuffleManager shuffleManager;
private long sendCheckTimeout;
private long sendCheckInterval;
private boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
private SparkConf sparkConf;
public RssShuffleWriter(
String appId,
int shuffleId,
String taskId,
long taskAttemptId,
WriteBufferManager bufferManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
SimpleShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
this(
appId,
shuffleId,
taskId,
taskAttemptId,
shuffleWriteMetrics,
shuffleManager,
sparkConf,
shuffleWriteClient,
rssHandle,
(tid) -> true,
shuffleHandleInfo,
context);
this.bufferManager = bufferManager;
}
private RssShuffleWriter(
String appId,
int shuffleId,
String taskId,
long taskAttemptId,
ShuffleWriteMetrics shuffleWriteMetrics,
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
ShuffleHandleInfo shuffleHandleInfo,
TaskContext context) {
this.appId = appId;
this.shuffleId = shuffleId;
this.taskId = taskId;
this.taskAttemptId = taskAttemptId;
this.numMaps = rssHandle.getNumMaps();
this.shuffleDependency = rssHandle.getDependency();
this.shuffleWriteMetrics = shuffleWriteMetrics;
this.partitioner = shuffleDependency.partitioner();
this.shuffleManager = shuffleManager;
this.shouldPartition = partitioner.numPartitions() > 1;
this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.serverToPartitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
this.shuffleServersForData = shuffleHandleInfo.getServers();
this.partitionToServers = shuffleHandleInfo.getAvailablePartitionServersForWriter();
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
this.taskContext = context;
this.sparkConf = sparkConf;
}
public RssShuffleWriter(
String appId,
int shuffleId,
String taskId,
long taskAttemptId,
ShuffleWriteMetrics shuffleWriteMetrics,
RssShuffleManager shuffleManager,
SparkConf sparkConf,
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context,
ShuffleHandleInfo shuffleHandleInfo) {
this(
appId,
shuffleId,
taskId,
taskAttemptId,
shuffleWriteMetrics,
shuffleManager,
sparkConf,
shuffleWriteClient,
rssHandle,
taskFailureCallback,
shuffleHandleInfo,
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
new WriteBufferManager(
shuffleId,
taskId,
taskAttemptId,
bufferOptions,
rssHandle.getDependency().serializer(),
shuffleHandleInfo.getAvailablePartitionServersForWriter(),
context.taskMemoryManager(),
shuffleWriteMetrics,
RssSparkConfig.toRssConf(sparkConf),
this::processShuffleBlockInfos);
this.bufferManager = bufferManager;
}
private boolean isMemoryShuffleEnabled(String storageType) {
return StorageType.withMemory(StorageType.valueOf(storageType));
}
/** Create a dummy BlockManagerId that carries the taskAttemptId in its topologyInfo field. */
private BlockManagerId createDummyBlockManagerId(String executorId, long taskAttemptId) {
// dummy host and port are only used to pass the host/port checks in BlockManagerId
// hack: reuse the topologyInfo field of BlockManagerId to carry the taskAttemptId
return BlockManagerId.apply(
executorId, DUMMY_HOST, DUMMY_PORT, Option.apply(Long.toString(taskAttemptId)));
}
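/**
 * Writes all records through writeImpl. On failure the task failure callback is invoked; if
 * stage resubmission is enabled the error is reported to the driver and may be turned into a
 * FetchFailedException, otherwise the original exception is rethrown.
 */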
@Override
public void write(Iterator<Product2<K, V>> records) {
try {
writeImpl(records);
} catch (Exception e) {
taskFailureCallback.apply(taskId);
if (shuffleManager.isRssResubmitStage()) {
throwFetchFailedIfNecessary(e);
} else {
throw e;
}
}
}
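/**
 * Buffers every record by partition, flushes the remaining buffers, waits until all blocks have
 * been acknowledged by the shuffle servers, and commits the data when memory shuffle is not
 * enabled.
 */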
private void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
long recordCount = 0;
while (records.hasNext()) {
recordCount++;
Product2<K, V> record = records.next();
int partition = getPartition(record._1());
if (shuffleDependency.mapSideCombine()) {
Function1<V, C> createCombiner = shuffleDependency.aggregator().get().createCombiner();
Object c = createCombiner.apply(record._2());
shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), c);
} else {
shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), record._2());
}
processShuffleBlockInfos(shuffleBlockInfos);
}
final long start = System.currentTimeMillis();
shuffleBlockInfos = bufferManager.clear();
processShuffleBlockInfos(shuffleBlockInfos);
long s = System.currentTimeMillis();
checkSentRecordCount(recordCount);
checkBlockSendResult(blockIds);
final long checkDuration = System.currentTimeMillis() - s;
long commitDuration = 0;
if (!isMemoryShuffleEnabled) {
s = System.currentTimeMillis();
sendCommit();
commitDuration = System.currentTimeMillis() - s;
}
long writeDurationMs = bufferManager.getWriteTime() + (System.currentTimeMillis() - start);
shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
LOG.info(
"Finish write shuffle for appId["
+ appId
+ "], shuffleId["
+ shuffleId
+ "], taskId["
+ taskId
+ "] with write "
+ writeDurationMs
+ " ms, include checkSendResult["
+ checkDuration
+ "], commit["
+ commitDuration
+ "], "
+ bufferManager.getManagerCostInfo());
}
private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =
"Potential record loss may have occurred while preparing to send blocks for task["
+ taskId
+ "]";
throw new RssSendFailedException(errorMsg);
}
}
/**
 * Adds the given ShuffleBlocks to the send queue so they can be sent to the shuffle servers,
 * maintaining the following bookkeeping along the way: 1. each blockId is added to a set so its
 * send result can be checked later; 2. the shuffle server info is recorded, to be used in the
 * commit phase; 3. the [server -> partition -> blockIds] mapping is updated so that the shuffle
 * reader can later run an integrity check against it.
 *
 * @param shuffleBlockInfoList blocks produced by the buffer manager
 * @return futures for the posted block send events
 */
private List<CompletableFuture<Long>> processShuffleBlockInfos(
List<ShuffleBlockInfo> shuffleBlockInfoList) {
if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) {
shuffleBlockInfoList.stream()
.forEach(
sbi -> {
long blockId = sbi.getBlockId();
// add the blockId to the set so its send result can be checked later
blockIds.add(blockId);
// update [server -> partition -> blockIds]; it will be reported to the shuffle servers
int partitionId = sbi.getPartitionId();
sbi.getShuffleServerInfos()
.forEach(
shuffleServerInfo -> {
Map<Integer, Set<Long>> pToBlockIds =
serverToPartitionToBlockIds.computeIfAbsent(
shuffleServerInfo, k -> Maps.newHashMap());
pToBlockIds
.computeIfAbsent(partitionId, v -> Sets.newHashSet())
.add(blockId);
});
});
return postBlockEvent(shuffleBlockInfoList);
}
return Collections.emptyList();
}
// Don't send huge blocks to the shuffle server in a single event, otherwise the shuffle server
// may run out of memory when it receives more data than expected.
protected List<CompletableFuture<Long>> postBlockEvent(
List<ShuffleBlockInfo> shuffleBlockInfoList) {
List<CompletableFuture<Long>> futures = new ArrayList<>();
for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
futures.add(shuffleManager.sendData(event));
}
return futures;
}
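/**
 * Submits the commit RPC to all shuffle servers on a single-thread executor and polls for its
 * completion with exponential backoff, starting at 200 ms and capped at 5 seconds.
 */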
@VisibleForTesting
protected void sendCommit() {
ExecutorService executor = Executors.newSingleThreadExecutor();
Future<Boolean> future =
executor.submit(
() -> shuffleWriteClient.sendCommit(shuffleServersForData, appId, shuffleId, numMaps));
long start = System.currentTimeMillis();
int currentWait = 200;
int maxWait = 5000;
while (!future.isDone()) {
LOG.info(
"Wait commit to shuffle server for task["
+ taskAttemptId
+ "] cost "
+ (System.currentTimeMillis() - start)
+ " ms");
Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
currentWait = Math.min(currentWait * 2, maxWait);
}
try {
// check if commit/finish rpc is successful
if (!future.get()) {
throw new RssException("Failed to commit task to shuffle server");
}
} catch (InterruptedException ie) {
LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
} catch (Exception e) {
throw new RssException("Exception happened when get commit status", e);
} finally {
executor.shutdown();
}
}
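/**
 * Waits until every pending blockId has been acknowledged as successfully sent. Throws a
 * RssSendFailedException as soon as any block fails, and a RssWaitFailedException once
 * sendCheckTimeout is exceeded.
 */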
@VisibleForTesting
protected void checkBlockSendResult(Set<Long> blockIds) {
long start = System.currentTimeMillis();
while (true) {
Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
// if any block failed to be sent to the shuffle server, mark the task as failed
if (failedBlockIds.size() > 0) {
String errorMsg =
"Send failed: Task["
+ taskId
+ "] failed because "
+ failedBlockIds.size()
+ " blocks can't be sent to shuffle server: "
+ shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
LOG.error(errorMsg);
throw new RssSendFailedException(errorMsg);
}
// remove blockIds which were sent successfully; if none are left, all data has been sent
blockIds.removeAll(successBlockIds);
if (blockIds.isEmpty()) {
break;
}
LOG.info("Wait " + blockIds.size() + " blocks sent to shuffle server");
Uninterruptibles.sleepUninterruptibly(sendCheckInterval, TimeUnit.MILLISECONDS);
if (System.currentTimeMillis() - start > sendCheckTimeout) {
String errorMsg =
"Timeout: Task["
+ taskId
+ "] failed because "
+ blockIds.size()
+ " blocks can't be sent to shuffle server in "
+ sendCheckTimeout
+ " ms.";
LOG.error(errorMsg);
throw new RssWaitFailedException(errorMsg);
}
}
}
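/**
 * On success, reports the [server -> partition -> blockIds] mapping to the shuffle servers and
 * returns a MapStatus built from a dummy BlockManagerId; buffer memory and task metadata are
 * always released in the finally block.
 */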
@Override
public Option<MapStatus> stop(boolean success) {
try {
if (success) {
// fill partitionLengths with a non-zero dummy value so the map output tracker works
// correctly
long[] partitionLengths = new long[partitioner.numPartitions()];
Arrays.fill(partitionLengths, 1);
final BlockManagerId blockManagerId =
createDummyBlockManagerId(appId + "_" + taskId, taskAttemptId);
long start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(
serverToPartitionToBlockIds, appId, shuffleId, taskAttemptId, bitmapSplitNum);
LOG.info(
"Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms",
taskAttemptId,
bitmapSplitNum,
(System.currentTimeMillis() - start));
MapStatus mapStatus = MapStatus$.MODULE$.apply(blockManagerId, partitionLengths);
return Option.apply(mapStatus);
} else {
return Option.empty();
}
} finally {
// free all memory & metadata, or memory leak happen in executor
if (bufferManager != null) {
bufferManager.freeAllMemory();
}
if (shuffleManager != null) {
shuffleManager.clearTaskMeta(taskId);
}
}
}
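/** Returns the partition for the given key, or 0 when the shuffle has a single partition. */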
@VisibleForTesting
protected <T> int getPartition(T key) {
int result = 0;
if (shouldPartition) {
result = partitioner.getPartition(key);
}
return result;
}
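/** Merges the per-server [partition -> blockIds] maps into a single map; used by tests. */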
@VisibleForTesting
protected Map<Integer, Set<Long>> getPartitionToBlockIds() {
return serverToPartitionToBlockIds.values().stream()
.flatMap(s -> s.entrySet().stream())
.collect(
Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
(existingSet, newSet) -> {
Set<Long> mergedSet = new HashSet<>(existingSet);
mergedSet.addAll(newSet);
return mergedSet;
}));
}
@VisibleForTesting
protected ShuffleWriteMetrics getShuffleWriteMetrics() {
return shuffleWriteMetrics;
}
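/** Creates a gRPC client for talking to the driver-side ShuffleManager. */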
private static ShuffleManagerClient createShuffleManagerClient(String host, int port)
throws IOException {
ClientType grpc = ClientType.GRPC;
// Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is
// constructed.
return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port);
}
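/**
 * Reports a block send failure to the driver-side ShuffleManager. If the driver decides to
 * resubmit the whole stage, shuffle servers are reassigned and a FetchFailedException wrapped in
 * a RssException is thrown to trigger the stage retry; otherwise the original exception is
 * rethrown wrapped in a RssException.
 */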
private void throwFetchFailedIfNecessary(Exception e) {
// The faulty shuffle servers are reported only when a block failed to be sent
if (e instanceof RssSendFailedException) {
FailedBlockSendTracker blockIdsFailedSendTracker =
shuffleManager.getBlockIdsFailedSendTracker(taskId);
List<ShuffleServerInfo> shuffleServerInfos =
Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers());
RssReportShuffleWriteFailureRequest req =
new RssReportShuffleWriteFailureRequest(
appId,
shuffleId,
taskContext.stageAttemptNumber(),
shuffleServerInfos,
e.getMessage());
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
RssReportShuffleWriteFailureResponse response =
shuffleManagerClient.reportShuffleWriteFailure(req);
if (response.getReSubmitWholeStage()) {
RssReassignServersRequest rssReassignServersRequest =
new RssReassignServersRequest(
taskContext.stageId(),
taskContext.stageAttemptNumber(),
shuffleId,
partitioner.numPartitions());
RssReassignServersReponse rssReassignServersReponse =
shuffleManagerClient.reassignShuffleServers(rssReassignServersRequest);
LOG.info(
"Whether the reassignment is successful: {}",
rssReassignServersReponse.isNeedReassign());
// since the whole stage is going to be resubmitted, mapIndex doesn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, taskContext.stageAttemptNumber(), e);
throw new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
}
}
throw new RssException(e);
}
}