| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.shuffle.writer; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Optional; |
| import java.util.Set; |
| import java.util.concurrent.BlockingQueue; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.Future; |
| import java.util.concurrent.LinkedBlockingQueue; |
| import java.util.concurrent.TimeUnit; |
| import java.util.function.Function; |
| import java.util.stream.Collectors; |
| |
| import scala.Function1; |
| import scala.Option; |
| import scala.Product2; |
| import scala.collection.Iterator; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.collect.Lists; |
| import com.google.common.collect.Maps; |
| import com.google.common.collect.Sets; |
| import com.google.common.util.concurrent.Uninterruptibles; |
| import org.apache.commons.collections.CollectionUtils; |
| import org.apache.spark.Partitioner; |
| import org.apache.spark.ShuffleDependency; |
| import org.apache.spark.SparkConf; |
| import org.apache.spark.TaskContext; |
| import org.apache.spark.executor.ShuffleWriteMetrics; |
| import org.apache.spark.scheduler.MapStatus; |
| import org.apache.spark.shuffle.FetchFailedException; |
| import org.apache.spark.shuffle.RssShuffleHandle; |
| import org.apache.spark.shuffle.RssShuffleManager; |
| import org.apache.spark.shuffle.RssSparkConfig; |
| import org.apache.spark.shuffle.RssSparkShuffleUtils; |
| import org.apache.spark.shuffle.ShuffleHandleInfo; |
| import org.apache.spark.shuffle.ShuffleWriter; |
| import org.apache.spark.storage.BlockManagerId; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import org.apache.uniffle.client.api.ShuffleManagerClient; |
| import org.apache.uniffle.client.api.ShuffleWriteClient; |
| import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; |
| import org.apache.uniffle.client.impl.FailedBlockSendTracker; |
| import org.apache.uniffle.client.impl.TrackingBlockStatus; |
| import org.apache.uniffle.client.request.RssReassignFaultyShuffleServerRequest; |
| import org.apache.uniffle.client.request.RssReassignServersRequest; |
| import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; |
| import org.apache.uniffle.client.response.RssReassignFaultyShuffleServerResponse; |
| import org.apache.uniffle.client.response.RssReassignServersReponse; |
| import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; |
| import org.apache.uniffle.common.ClientType; |
| import org.apache.uniffle.common.ShuffleBlockInfo; |
| import org.apache.uniffle.common.ShuffleServerInfo; |
| import org.apache.uniffle.common.config.RssClientConf; |
| import org.apache.uniffle.common.config.RssConf; |
| import org.apache.uniffle.common.exception.RssException; |
| import org.apache.uniffle.common.exception.RssSendFailedException; |
| import org.apache.uniffle.common.exception.RssWaitFailedException; |
| import org.apache.uniffle.common.rpc.StatusCode; |
| import org.apache.uniffle.common.util.JavaUtils; |
| import org.apache.uniffle.storage.util.StorageType; |
| |
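/**
 * A {@link ShuffleWriter} that writes shuffle data to remote shuffle servers instead of local
 * disks. Records are buffered by a {@link WriteBufferManager}; whenever buffers spill, the
 * resulting {@link ShuffleBlockInfo}s are sent asynchronously to the shuffle servers assigned to
 * each partition. After all records are written, the writer waits until every block has been
 * acknowledged (see {@link #checkBlockSendResult}) and, unless memory-only shuffle is enabled,
 * commits the data via {@link #sendCommit}. When block-send retry is enabled, blocks that fail
 * to be sent are retried against replacement shuffle servers obtained from the shuffle manager.
 */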
| public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class); |
| private static final String DUMMY_HOST = "dummy_host"; |
| private static final int DUMMY_PORT = 99999; |
| |
| private final String appId; |
| private final int shuffleId; |
| private WriteBufferManager bufferManager; |
| private String taskId; |
| private final int numMaps; |
| private final ShuffleDependency<K, V, C> shuffleDependency; |
| private final Partitioner partitioner; |
| private final RssShuffleManager shuffleManager; |
| private final boolean shouldPartition; |
| private final long sendCheckTimeout; |
| private final long sendCheckInterval; |
| private final int bitmapSplitNum; |
| // server -> partitionId -> blockIds |
| private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds; |
| private final ShuffleWriteClient shuffleWriteClient; |
| private final Set<ShuffleServerInfo> shuffleServersForData; |
| private final long[] partitionLengths; |
| private final boolean isMemoryShuffleEnabled; |
| private final Function<String, Boolean> taskFailureCallback; |
| private final Set<Long> blockIds = Sets.newConcurrentHashSet(); |
| private TaskContext taskContext; |
| private SparkConf sparkConf; |
| private boolean blockFailSentRetryEnabled; |
| private int blockFailSentRetryMaxTimes = 1; |
| |
  /** Used by the columnar RSS shuffle writer implementation. */
| protected final long taskAttemptId; |
| |
| protected final ShuffleWriteMetrics shuffleWriteMetrics; |
| |
| private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>(); |
| |
| // shuffleServerId -> failoverShuffleServer |
| private final Map<String, ShuffleServerInfo> replacementShuffleServers = |
| JavaUtils.newConcurrentMap(); |
| |
| // Only for tests |
| @VisibleForTesting |
| public RssShuffleWriter( |
| String appId, |
| int shuffleId, |
| String taskId, |
| long taskAttemptId, |
| WriteBufferManager bufferManager, |
| ShuffleWriteMetrics shuffleWriteMetrics, |
| RssShuffleManager shuffleManager, |
| SparkConf sparkConf, |
| ShuffleWriteClient shuffleWriteClient, |
| RssShuffleHandle<K, V, C> rssHandle, |
| ShuffleHandleInfo shuffleHandleInfo, |
| TaskContext context) { |
| this( |
| appId, |
| shuffleId, |
| taskId, |
| taskAttemptId, |
| shuffleWriteMetrics, |
| shuffleManager, |
| sparkConf, |
| shuffleWriteClient, |
| rssHandle, |
| (tid) -> true, |
| shuffleHandleInfo, |
| context); |
| this.bufferManager = bufferManager; |
| } |
| |
| private RssShuffleWriter( |
| String appId, |
| int shuffleId, |
| String taskId, |
| long taskAttemptId, |
| ShuffleWriteMetrics shuffleWriteMetrics, |
| RssShuffleManager shuffleManager, |
| SparkConf sparkConf, |
| ShuffleWriteClient shuffleWriteClient, |
| RssShuffleHandle<K, V, C> rssHandle, |
| Function<String, Boolean> taskFailureCallback, |
| ShuffleHandleInfo shuffleHandleInfo, |
| TaskContext context) { |
    LOG.info("RssShuffleWriter starts to write data for taskAttemptId[" + taskAttemptId + "]");
| this.shuffleManager = shuffleManager; |
| this.appId = appId; |
| this.shuffleId = shuffleId; |
| this.taskId = taskId; |
| this.taskAttemptId = taskAttemptId; |
| this.numMaps = rssHandle.getNumMaps(); |
| this.shuffleWriteMetrics = shuffleWriteMetrics; |
| this.shuffleDependency = rssHandle.getDependency(); |
| this.partitioner = shuffleDependency.partitioner(); |
| this.shouldPartition = partitioner.numPartitions() > 1; |
| this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS); |
| this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS); |
| this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM); |
| this.serverToPartitionToBlockIds = Maps.newHashMap(); |
| this.shuffleWriteClient = shuffleWriteClient; |
| this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData(); |
| this.partitionLengths = new long[partitioner.numPartitions()]; |
| Arrays.fill(partitionLengths, 0); |
| this.isMemoryShuffleEnabled = |
| isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key())); |
| this.taskFailureCallback = taskFailureCallback; |
| this.taskContext = context; |
| this.sparkConf = sparkConf; |
| this.blockFailSentRetryEnabled = |
| sparkConf.getBoolean( |
| RssSparkConfig.SPARK_RSS_CONFIG_PREFIX |
| + RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(), |
| RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue()); |
| } |
| |
| public RssShuffleWriter( |
| String appId, |
| int shuffleId, |
| String taskId, |
| long taskAttemptId, |
| ShuffleWriteMetrics shuffleWriteMetrics, |
| RssShuffleManager shuffleManager, |
| SparkConf sparkConf, |
| ShuffleWriteClient shuffleWriteClient, |
| RssShuffleHandle<K, V, C> rssHandle, |
| Function<String, Boolean> taskFailureCallback, |
| TaskContext context, |
| ShuffleHandleInfo shuffleHandleInfo) { |
| this( |
| appId, |
| shuffleId, |
| taskId, |
| taskAttemptId, |
| shuffleWriteMetrics, |
| shuffleManager, |
| sparkConf, |
| shuffleWriteClient, |
| rssHandle, |
| taskFailureCallback, |
| shuffleHandleInfo, |
| context); |
| BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf); |
| final WriteBufferManager bufferManager = |
| new WriteBufferManager( |
| shuffleId, |
| taskId, |
| taskAttemptId, |
| bufferOptions, |
| rssHandle.getDependency().serializer(), |
| shuffleHandleInfo.getPartitionToServers(), |
| context.taskMemoryManager(), |
| shuffleWriteMetrics, |
| RssSparkConfig.toRssConf(sparkConf), |
| this::processShuffleBlockInfos); |
| this.bufferManager = bufferManager; |
| } |
| |
| private boolean isMemoryShuffleEnabled(String storageType) { |
| return StorageType.withMemory(StorageType.valueOf(storageType)); |
| } |
| |
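  /**
   * Writes all records of this map task to the remote shuffle servers. On any failure the task
   * failure callback is invoked first; if RSS stage resubmission is enabled, the failure is
   * reported to the shuffle manager and may be converted into a {@link FetchFailedException} so
   * that Spark resubmits the whole stage.
   */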
| @Override |
| public void write(Iterator<Product2<K, V>> records) { |
| try { |
| writeImpl(records); |
| } catch (Exception e) { |
| taskFailureCallback.apply(taskId); |
| if (shuffleManager.isRssResubmitStage()) { |
| throwFetchFailedIfNecessary(e); |
| } else { |
| throw e; |
| } |
| } |
| } |
| |
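  /**
   * The core write loop: applies the map-side combiner if configured, partitions each record,
   * and feeds it into the buffer manager. Spilled blocks are sent on the fly; the remaining
   * buffers are flushed after the last record, followed by a record-count sanity check, a wait
   * for all pending block sends, and (for non-memory storage types) a commit to the shuffle
   * servers.
   */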
| private void writeImpl(Iterator<Product2<K, V>> records) { |
| List<ShuffleBlockInfo> shuffleBlockInfos; |
| boolean isCombine = shuffleDependency.mapSideCombine(); |
| Function1<V, C> createCombiner = null; |
| if (isCombine) { |
| createCombiner = shuffleDependency.aggregator().get().createCombiner(); |
| } |
| long recordCount = 0; |
| while (records.hasNext()) { |
| recordCount++; |
| |
| checkDataIfAnyFailure(); |
| |
| Product2<K, V> record = records.next(); |
| K key = record._1(); |
| int partition = getPartition(key); |
| if (isCombine) { |
| Object c = createCombiner.apply(record._2()); |
| shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), c); |
| } else { |
| shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), record._2()); |
| } |
| if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) { |
| processShuffleBlockInfos(shuffleBlockInfos); |
| } |
| } |
| final long start = System.currentTimeMillis(); |
| shuffleBlockInfos = bufferManager.clear(); |
| if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) { |
| processShuffleBlockInfos(shuffleBlockInfos); |
| } |
| long checkStartTs = System.currentTimeMillis(); |
| checkSentRecordCount(recordCount); |
| checkBlockSendResult(blockIds); |
| long commitStartTs = System.currentTimeMillis(); |
| long checkDuration = commitStartTs - checkStartTs; |
| if (!isMemoryShuffleEnabled) { |
| sendCommit(); |
| } |
| long writeDurationMs = bufferManager.getWriteTime() + (System.currentTimeMillis() - start); |
| shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs)); |
| LOG.info( |
| "Finish write shuffle for appId[" |
| + appId |
| + "], shuffleId[" |
| + shuffleId |
| + "], taskId[" |
| + taskId |
| + "] with write " |
| + writeDurationMs |
| + " ms, include checkSendResult[" |
| + checkDuration |
| + "], commit[" |
| + (System.currentTimeMillis() - commitStartTs) |
| + "], " |
| + bufferManager.getManagerCostInfo()); |
| } |
| |
| private void checkSentRecordCount(long recordCount) { |
| if (recordCount != bufferManager.getRecordCount()) { |
      String errorMsg =
          "Potential record loss may have occurred while preparing to send blocks for task["
              + taskId
              + "]: read "
              + recordCount
              + " records but buffered "
              + bufferManager.getRecordCount();
| throw new RssSendFailedException(errorMsg); |
| } |
| } |
| |
  // Only push-based shuffle uses this interface, but RSS won't be used when push-based shuffle
  // is enabled, so an empty array is returned.
| public long[] getPartitionLengths() { |
| return new long[0]; |
| } |
| |
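  /**
   * Bookkeeping for spilled blocks before they are sent: records every blockId for the later
   * send-result check, tracks which blockIds were assigned to which server and partition (this
   * mapping is reported on {@link #stop}), accumulates the partition lengths, and finally posts
   * the blocks as send events.
   */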
| @VisibleForTesting |
| protected List<CompletableFuture<Long>> processShuffleBlockInfos( |
| List<ShuffleBlockInfo> shuffleBlockInfoList) { |
| if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) { |
| shuffleBlockInfoList.forEach( |
| sbi -> { |
| long blockId = sbi.getBlockId(); |
            // record the blockId; it is checked later to verify the block was sent successfully
            blockIds.add(blockId);
            // update [server -> partition -> blockIds]; this mapping is reported to the shuffle
            // servers when the task stops
| int partitionId = sbi.getPartitionId(); |
| sbi.getShuffleServerInfos() |
| .forEach( |
| shuffleServerInfo -> { |
| Map<Integer, Set<Long>> pToBlockIds = |
| serverToPartitionToBlockIds.computeIfAbsent( |
| shuffleServerInfo, k -> Maps.newHashMap()); |
| pToBlockIds.computeIfAbsent(partitionId, v -> Sets.newHashSet()).add(blockId); |
| }); |
| partitionLengths[partitionId] += sbi.getLength(); |
| }); |
| return postBlockEvent(shuffleBlockInfoList); |
| } |
| return Collections.emptyList(); |
| } |
| |
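  /**
   * Wraps the given blocks into {@link AddBlockEvent}s and submits them via
   * {@code shuffleManager.sendData}. When block-send retry is enabled, a completion callback is
   * attached to each block that releases its buffer only on success, keeping failed blocks
   * available for resending. Every event also signals {@code finishEventQueue} on completion so
   * that {@link #checkBlockSendResult} can wake up promptly.
   */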
| protected List<CompletableFuture<Long>> postBlockEvent( |
| List<ShuffleBlockInfo> shuffleBlockInfoList) { |
| List<CompletableFuture<Long>> futures = new ArrayList<>(); |
| for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { |
| if (blockFailSentRetryEnabled) { |
        // on success, release the block's buffer; on failure, keep it so the block can be resent.
| for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) { |
| block.withCompletionCallback( |
| (completionBlock, isSuccessful) -> { |
| if (isSuccessful) { |
| bufferManager.releaseBlockResource(completionBlock); |
| } |
| }); |
| } |
| } |
| event.addCallback( |
| () -> { |
| boolean ret = finishEventQueue.add(new Object()); |
| if (!ret) { |
              LOG.error("Failed to add event " + event + " to finishEventQueue");
| } |
| }); |
| futures.add(shuffleManager.sendData(event)); |
| } |
| return futures; |
| } |
| |
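  /**
   * Blocks until all given blockIds have been acknowledged as successfully sent, or until
   * {@code sendCheckTimeout} elapses. The given set is mutated: acknowledged blockIds are
   * removed from it. Failed blocks either trigger a resend (when retry is enabled) or fail the
   * task immediately.
   */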
| @VisibleForTesting |
| protected void checkBlockSendResult(Set<Long> blockIds) { |
| boolean interrupted = false; |
| |
| try { |
| long remainingMs = sendCheckTimeout; |
| long end = System.currentTimeMillis() + remainingMs; |
| |
| while (true) { |
| try { |
| finishEventQueue.clear(); |
| checkDataIfAnyFailure(); |
| Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId); |
| blockIds.removeAll(successBlockIds); |
| if (blockIds.isEmpty()) { |
| break; |
| } |
| if (finishEventQueue.isEmpty()) { |
| remainingMs = Math.max(end - System.currentTimeMillis(), 0); |
| Object event = finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS); |
| if (event == null) { |
| break; |
| } |
| } |
| } catch (InterruptedException e) { |
| interrupted = true; |
| } |
| } |
| if (!blockIds.isEmpty()) { |
| String errorMsg = |
| "Timeout: Task[" |
| + taskId |
| + "] failed because " |
| + blockIds.size() |
| + " blocks can't be sent to shuffle server in " |
| + sendCheckTimeout |
| + " ms."; |
| LOG.error(errorMsg); |
| throw new RssWaitFailedException(errorMsg); |
| } |
| } finally { |
| if (interrupted) { |
| Thread.currentThread().interrupt(); |
| } |
| } |
| } |
| |
| private void checkDataIfAnyFailure() { |
| if (blockFailSentRetryEnabled) { |
| collectFailedBlocksToResend(); |
| } else { |
| if (hasAnyBlockFailure()) { |
        throw new RssSendFailedException("Failed to send blocks for task[" + taskId + "]");
| } |
| } |
| } |
| |
| private boolean hasAnyBlockFailure() { |
| Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId); |
| if (!failedBlockIds.isEmpty()) { |
| LOG.error( |
| "Errors on sending blocks for task[{}]. {} blocks can't be sent to remote servers: {}", |
| taskId, |
| failedBlockIds.size(), |
| shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers()); |
| return true; |
| } |
| return false; |
| } |
| |
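  /**
   * Collects the blocks whose send attempts failed and schedules them for resending. If any
   * block has already reached {@code blockFailSentRetryMaxTimes}, the writer fails fast: the
   * buffers of all failed blocks are released and a {@link RssSendFailedException} is thrown.
   */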
| private void collectFailedBlocksToResend() { |
| if (!blockFailSentRetryEnabled) { |
| return; |
| } |
| |
| FailedBlockSendTracker failedTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId); |
| Set<Long> failedBlockIds = failedTracker.getFailedBlockIds(); |
| if (CollectionUtils.isEmpty(failedBlockIds)) { |
| return; |
| } |
| |
| boolean isFastFail = false; |
| Set<TrackingBlockStatus> resendCandidates = new HashSet<>(); |
    // check whether any failed block has reached the max resend count.
| for (Long blockId : failedBlockIds) { |
| List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId); |
| int retryIndex = failedBlockStatus.get(0).getShuffleBlockInfo().getRetryCnt(); |
| // todo: support retry times by config |
| if (retryIndex >= blockFailSentRetryMaxTimes) { |
        LOG.error(
            "Some blocks of task[{}] have reached the max retry times[{}]. Fast fail! "
                + "Faulty server list: {}",
| taskId, |
| blockFailSentRetryMaxTimes, |
| failedBlockStatus.stream() |
| .map(x -> x.getShuffleServerInfo()) |
| .collect(Collectors.toSet())); |
| isFastFail = true; |
| break; |
| } |
| |
      // todo: with multiple replicas, if another replica was sent successfully, there is no
      // need to resend
| resendCandidates.addAll(failedBlockStatus); |
| } |
| |
| if (isFastFail) { |
| // release data and allocated memory |
| for (Long blockId : failedBlockIds) { |
| List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId); |
| Optional<TrackingBlockStatus> blockStatus = failedBlockStatus.stream().findFirst(); |
| if (blockStatus.isPresent()) { |
| blockStatus.get().getShuffleBlockInfo().executeCompletionCallback(true); |
| } |
| } |
| |
      throw new RssSendFailedException(
          "Errors on resending the block data to the remote shuffle servers.");
| } |
| |
| resendFailedBlocks(resendCandidates); |
| } |
| |
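  /**
   * Resends the given failed blocks. The blocks are grouped by their faulty server; for each
   * faulty server a replacement is requested from the shuffle manager (and cached in
   * {@code replacementShuffleServers}), the per-block send state is cleared, and the blocks are
   * re-posted with the replacement server as their target.
   */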
| private void resendFailedBlocks(Set<TrackingBlockStatus> failedBlockStatusSet) { |
| List<ShuffleBlockInfo> reassignBlocks = Lists.newArrayList(); |
    Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToBlockStatuses =
        failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));

    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
        faultyServerToBlockStatuses.entrySet()) {
| Set<Integer> partitionIds = |
| entry.getValue().stream() |
| .map(x -> x.getShuffleBlockInfo().getPartitionId()) |
| .collect(Collectors.toSet()); |
| ShuffleServerInfo replacement = replacementShuffleServers.get(entry.getKey().getId()); |
| if (replacement == null) { |
| // todo: merge multiple requests into one. |
| replacement = reassignFaultyShuffleServer(partitionIds, entry.getKey().getId()); |
| replacementShuffleServers.put(entry.getKey().getId(), replacement); |
| } |
| |
      // only resend the blocks that failed on this faulty server
      for (TrackingBlockStatus blockStatus : entry.getValue()) {
        // clear the previous send state of the block before resending
        ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
        clearFailedBlockState(block);

        block.incrRetryCnt();
        block.reassignShuffleServers(Arrays.asList(replacement));

        reassignBlocks.add(block);
      }
| } |
| |
| processShuffleBlockInfos(reassignBlocks); |
| } |
| |
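  /**
   * Removes a block's previous send state before it is resent: drops it from the failed-send
   * tracker, removes its blockId from the bookkeeping of replaced servers, and rolls back its
   * contribution to the partition length.
   */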
| private void clearFailedBlockState(ShuffleBlockInfo block) { |
| shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId()); |
| block.getShuffleServerInfos().stream() |
| .filter(s -> replacementShuffleServers.containsKey(s.getId())) |
| .forEach( |
| s -> |
| serverToPartitionToBlockIds |
| .get(s) |
| .get(block.getPartitionId()) |
| .remove(block.getBlockId())); |
| partitionLengths[block.getPartitionId()] -= block.getLength(); |
| } |
| |
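  /**
   * Asks the shuffle manager running in the driver, via gRPC, for a replacement shuffle server
   * for the given faulty server and the affected partitions.
   */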
| private ShuffleServerInfo reassignFaultyShuffleServer( |
| Set<Integer> partitionIds, String faultyServerId) { |
| RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); |
| String driver = rssConf.getString("driver.host", ""); |
| int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); |
| try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) { |
| RssReassignFaultyShuffleServerRequest request = |
| new RssReassignFaultyShuffleServerRequest(shuffleId, partitionIds, faultyServerId); |
| RssReassignFaultyShuffleServerResponse response = |
| shuffleManagerClient.reassignFaultyShuffleServer(request); |
| if (response.getStatusCode() != StatusCode.SUCCESS) { |
| throw new RssException( |
| "reassign server response with statusCode[" + response.getStatusCode() + "]"); |
| } |
| if (response.getShuffleServer() == null) { |
        throw new RssException("no shuffle server returned for the reassignment!");
| } |
| return response.getShuffleServer(); |
| } catch (Exception e) { |
      throw new RssException(
          "Failed to reassign a new server for the faulty server[" + faultyServerId + "]", e);
| } |
| } |
| |
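  /**
   * Synchronously commits the written data on the shuffle servers, polling with exponential
   * backoff (from 200 ms up to 5 s) until the commit future completes. Only called when
   * memory-only shuffle is disabled.
   */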
| @VisibleForTesting |
| protected void sendCommit() { |
| ExecutorService executor = Executors.newSingleThreadExecutor(); |
| Future<Boolean> future = |
| executor.submit( |
| () -> shuffleWriteClient.sendCommit(shuffleServersForData, appId, shuffleId, numMaps)); |
| int maxWait = 5000; |
| int currentWait = 200; |
| long start = System.currentTimeMillis(); |
| while (!future.isDone()) { |
      LOG.info(
          "Waiting for the commit to shuffle servers for task["
              + taskAttemptId
              + "], "
              + (System.currentTimeMillis() - start)
              + " ms elapsed");
| Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS); |
| currentWait = Math.min(currentWait * 2, maxWait); |
| } |
| try { |
| if (!future.get()) { |
| throw new RssException("Failed to commit task to shuffle server"); |
| } |
| } catch (InterruptedException ie) { |
      LOG.warn("Ignore the InterruptedException which should be caused by an internal kill");
| } catch (Exception e) { |
      throw new RssException("Exception happened when getting the commit status", e);
| } finally { |
| executor.shutdown(); |
| } |
| } |
| |
| @VisibleForTesting |
| protected <T> int getPartition(T key) { |
| int result = 0; |
| if (shouldPartition) { |
| result = partitioner.getPartition(key); |
| } |
| return result; |
| } |
| |
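  /**
   * On success, reports the server-to-partition blockId mapping to the shuffle servers and
   * returns a {@link MapStatus} carrying a placeholder {@link BlockManagerId}, since the data
   * lives on the shuffle servers rather than in any Spark block manager. In all cases the
   * buffered memory and per-task metadata are released to avoid leaks in the executor.
   */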
| @Override |
| public Option<MapStatus> stop(boolean success) { |
| try { |
| if (success) { |
| long start = System.currentTimeMillis(); |
| shuffleWriteClient.reportShuffleResult( |
| serverToPartitionToBlockIds, appId, shuffleId, taskAttemptId, bitmapSplitNum); |
| LOG.info( |
| "Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms", |
| taskAttemptId, |
| bitmapSplitNum, |
| (System.currentTimeMillis() - start)); |
        // todo: we can replace the dummy host and port with the real shuffle server from which we
        // prefer to read
| final BlockManagerId blockManagerId = |
| BlockManagerId.apply( |
| appId + "_" + taskId, |
| DUMMY_HOST, |
| DUMMY_PORT, |
| Option.apply(Long.toString(taskAttemptId))); |
| MapStatus mapStatus = MapStatus.apply(blockManagerId, partitionLengths, taskAttemptId); |
| return Option.apply(mapStatus); |
| } else { |
| return Option.empty(); |
| } |
| } finally { |
| if (blockFailSentRetryEnabled) { |
| if (success) { |
| if (CollectionUtils.isNotEmpty(shuffleManager.getFailedBlockIds(taskId))) { |
| LOG.error( |
| "Errors on stopping writer due to the remaining failed blockIds. This should not happen."); |
| return Option.empty(); |
| } |
| } else { |
| shuffleManager.getBlockIdsFailedSendTracker(taskId).clearAndReleaseBlockResources(); |
| } |
| } |
      // free all memory & metadata, or a memory leak will happen in the executor
| if (bufferManager != null) { |
| bufferManager.freeAllMemory(); |
| } |
| if (shuffleManager != null) { |
| shuffleManager.clearTaskMeta(taskId); |
| } |
| } |
| } |
| |
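  /** Merges the per-server bookkeeping into a single partitionId -> blockIds view. */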
| @VisibleForTesting |
| Map<Integer, Set<Long>> getPartitionToBlockIds() { |
| return serverToPartitionToBlockIds.values().stream() |
| .flatMap(s -> s.entrySet().stream()) |
| .collect( |
| Collectors.toMap( |
| Map.Entry::getKey, |
| Map.Entry::getValue, |
| (existingSet, newSet) -> { |
| Set<Long> mergedSet = new HashSet<>(existingSet); |
| mergedSet.addAll(newSet); |
| return mergedSet; |
| })); |
| } |
| |
| @VisibleForTesting |
| public WriteBufferManager getBufferManager() { |
| return bufferManager; |
| } |
| |
| private static ShuffleManagerClient createShuffleManagerClient(String host, int port) |
| throws IOException { |
| ClientType grpc = ClientType.GRPC; |
| // Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is |
| // constructed. |
| return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port); |
| } |
| |
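  /**
   * Reports a block-send failure to the shuffle manager in the driver. If the manager decides to
   * resubmit the whole stage, the shuffle servers are reassigned and the original exception is
   * converted into a {@link FetchFailedException} (wrapped in a {@link RssException}) so that
   * Spark resubmits the stage; otherwise the exception is rethrown as a {@link RssException}.
   */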
| private void throwFetchFailedIfNecessary(Exception e) { |
    // the faulty shuffle servers are tracked only when a block fails to be sent
| if (e instanceof RssSendFailedException) { |
| FailedBlockSendTracker blockIdsFailedSendTracker = |
| shuffleManager.getBlockIdsFailedSendTracker(taskId); |
| List<ShuffleServerInfo> shuffleServerInfos = |
| Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers()); |
| RssReportShuffleWriteFailureRequest req = |
| new RssReportShuffleWriteFailureRequest( |
| appId, |
| shuffleId, |
| taskContext.stageAttemptNumber(), |
| shuffleServerInfos, |
| e.getMessage()); |
| RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); |
| String driver = rssConf.getString("driver.host", ""); |
| int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); |
| try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) { |
| RssReportShuffleWriteFailureResponse response = |
| shuffleManagerClient.reportShuffleWriteFailure(req); |
| if (response.getReSubmitWholeStage()) { |
| RssReassignServersRequest rssReassignServersRequest = |
| new RssReassignServersRequest( |
| taskContext.stageId(), |
| taskContext.stageAttemptNumber(), |
| shuffleId, |
| partitioner.numPartitions()); |
| RssReassignServersReponse rssReassignServersReponse = |
| shuffleManagerClient.reassignShuffleServers(rssReassignServersRequest); |
| LOG.info( |
| "Whether the reassignment is successful: {}", |
| rssReassignServersReponse.isNeedReassign()); |
        // since we are going to resubmit the whole stage, mapIndex shouldn't matter, hence -1 is
        // provided.
| FetchFailedException ffe = |
| RssSparkShuffleUtils.createFetchFailedException( |
| shuffleId, -1, taskContext.stageAttemptNumber(), e); |
| throw new RssException(ffe); |
| } |
| } catch (IOException ioe) { |
      LOG.error("Failed to create or close the shuffle manager client:", ioe);
| } |
| } |
| throw new RssException(e); |
| } |
| |
| @VisibleForTesting |
| protected void enableBlockFailSentRetry() { |
| this.blockFailSentRetryEnabled = true; |
| } |
| |
| @VisibleForTesting |
| protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) { |
| this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes; |
| } |
| |
| @VisibleForTesting |
  protected void addReassignmentShuffleServer(String serverId, ShuffleServerInfo replacement) {
    replacementShuffleServers.put(serverId, replacement);
| } |
| |
| @VisibleForTesting |
| protected void setTaskId(String taskId) { |
| this.taskId = taskId; |
| } |
| |
| @VisibleForTesting |
| protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> getServerToPartitionToBlockIds() { |
| return serverToPartitionToBlockIds; |
| } |
| } |