| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.shuffle.sort; |
| |
| import java.nio.channels.Channels; |
| import java.util.Arrays; |
| import java.util.Optional; |
| import javax.annotation.Nullable; |
| import java.io.*; |
| import java.nio.channels.FileChannel; |
| import java.nio.channels.WritableByteChannel; |
| import java.util.Iterator; |
| |
| import scala.Option; |
| import scala.Product2; |
| import scala.jdk.javaapi.CollectionConverters; |
| import scala.reflect.ClassTag; |
| import scala.reflect.ClassTag$; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.io.ByteStreams; |
| import com.google.common.io.Closeables; |
| |
| import org.apache.spark.*; |
| import org.apache.spark.annotation.Private; |
| import org.apache.spark.internal.config.package$; |
| import org.apache.spark.internal.Logger; |
| import org.apache.spark.internal.LoggerFactory; |
| import org.apache.spark.internal.LogKeys; |
| import org.apache.spark.internal.MDC; |
| import org.apache.spark.io.CompressionCodec; |
| import org.apache.spark.io.CompressionCodec$; |
| import org.apache.spark.io.NioBufferedFileInputStream; |
| import org.apache.spark.memory.TaskMemoryManager; |
| import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; |
| import org.apache.spark.network.util.LimitedInputStream; |
| import org.apache.spark.scheduler.MapStatus; |
| import org.apache.spark.scheduler.MapStatus$; |
| import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; |
| import org.apache.spark.serializer.SerializationStream; |
| import org.apache.spark.serializer.SerializerInstance; |
| import org.apache.spark.shuffle.ShuffleWriter; |
| import org.apache.spark.shuffle.api.ShuffleExecutorComponents; |
| import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; |
| import org.apache.spark.shuffle.api.ShufflePartitionWriter; |
| import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; |
| import org.apache.spark.shuffle.api.WritableByteChannelWrapper; |
| import org.apache.spark.storage.BlockManager; |
| import org.apache.spark.storage.TimeTrackingOutputStream; |
| import org.apache.spark.unsafe.Platform; |
| import org.apache.spark.util.Utils; |
| |
| @Private |
| public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { |
| |
| private static final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); |
| |
| private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); |
| |
| @VisibleForTesting |
| static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; |
| |
| private final BlockManager blockManager; |
| private final TaskMemoryManager memoryManager; |
| private final SerializerInstance serializer; |
| private final Partitioner partitioner; |
| private final ShuffleWriteMetricsReporter writeMetrics; |
| private final ShuffleExecutorComponents shuffleExecutorComponents; |
| private final int shuffleId; |
| private final long mapId; |
| private final TaskContext taskContext; |
| private final SparkConf sparkConf; |
| private final boolean transferToEnabled; |
| private final int initialSortBufferSize; |
| private final int inputBufferSizeInBytes; |
| |
| @Nullable private MapStatus mapStatus; |
| @Nullable private ShuffleExternalSorter sorter; |
| @Nullable private long[] partitionLengths; |
| private long peakMemoryUsedBytes = 0; |
| |
| /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ |
| private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { |
| MyByteArrayOutputStream(int size) { super(size); } |
| public byte[] getBuf() { return buf; } |
| } |
| |
| private MyByteArrayOutputStream serBuffer; |
| private SerializationStream serOutputStream; |
| |
| /** |
| * Are we in the process of stopping? Because map tasks can call stop() with success = true |
| * and then call stop() with success = false if they get an exception, we want to make sure |
| * we don't try deleting files, etc twice. |
| */ |
| private boolean stopping = false; |
| |
/**
 * Creates a writer for one map task's output and immediately opens its sorter and
 * serialization buffer.
 *
 * @param blockManager block manager used for spill handling and the shuffle server id.
 * @param memoryManager task memory manager handed to the sorter.
 * @param handle shuffle handle providing the dependency (serializer, partitioner, shuffle id).
 * @param mapId id of the map task whose output is written.
 * @param taskContext context of the running task.
 * @param sparkConf configuration the merge and buffer settings are read from.
 * @param writeMetrics sink for shuffle write metrics.
 * @param shuffleExecutorComponents factory for the map output writers.
 * @throws IllegalArgumentException if the shuffle has more reduce partitions than
 *         serialized mode supports.
 * @throws SparkException if opening the sorter fails.
 */
public UnsafeShuffleWriter(
    BlockManager blockManager,
    TaskMemoryManager memoryManager,
    SerializedShuffleHandle<K, V> handle,
    long mapId,
    TaskContext taskContext,
    SparkConf sparkConf,
    ShuffleWriteMetricsReporter writeMetrics,
    ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
  final ShuffleDependency<K, V, V> dep = handle.dependency();
  final int maxPartitions =
    SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE();
  if (dep.partitioner().numPartitions() > maxPartitions) {
    throw new IllegalArgumentException(
      "UnsafeShuffleWriter can only be used for shuffles with at most " +
      maxPartitions +
      " reduce partitions");
  }

  // Collaborators passed in by the caller.
  this.blockManager = blockManager;
  this.memoryManager = memoryManager;
  this.taskContext = taskContext;
  this.sparkConf = sparkConf;
  this.writeMetrics = writeMetrics;
  this.shuffleExecutorComponents = shuffleExecutorComponents;
  this.mapId = mapId;

  // Values derived from the shuffle dependency.
  this.shuffleId = dep.shuffleId();
  this.serializer = dep.serializer().newInstance();
  this.partitioner = dep.partitioner();

  // Settings read from the configuration.
  this.transferToEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_MERGE_PREFER_NIO());
  this.initialSortBufferSize =
    (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
  // The cast binds tighter than the multiplication: the configured value is narrowed to int
  // first and then scaled by 1024.
  this.inputBufferSizeInBytes =
    ((int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE())) * 1024;

  open();
}
| |
| private void updatePeakMemoryUsed() { |
| // sorter can be null if this writer is closed |
| if (sorter != null) { |
| long mem = sorter.getPeakMemoryUsedBytes(); |
| if (mem > peakMemoryUsedBytes) { |
| peakMemoryUsedBytes = mem; |
| } |
| } |
| } |
| |
| /** |
| * Return the peak memory used so far, in bytes. |
| */ |
| public long getPeakMemoryUsedBytes() { |
| updatePeakMemoryUsed(); |
| return peakMemoryUsedBytes; |
| } |
| |
| /** |
| * This convenience method should only be called in test code. |
| */ |
| @VisibleForTesting |
| public void write(Iterator<Product2<K, V>> records) throws IOException { |
| write(CollectionConverters.asScala(records)); |
| } |
| |
| @Override |
| public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException { |
| // Keep track of success so we know if we encountered an exception |
| // We do this rather than a standard try/catch/re-throw to handle |
| // generic throwables. |
| boolean success = false; |
| try { |
| while (records.hasNext()) { |
| insertRecordIntoSorter(records.next()); |
| } |
| closeAndWriteOutput(); |
| success = true; |
| } finally { |
| if (sorter != null) { |
| try { |
| sorter.cleanupResources(); |
| } catch (Exception e) { |
| // Only throw this error if we won't be masking another |
| // error. |
| if (success) { |
| throw e; |
| } else { |
| logger.error("In addition to a failure during writing, we failed during " + |
| "cleanup.", e); |
| } |
| } |
| } |
| } |
| } |
| |
| private void open() throws SparkException { |
| assert (sorter == null); |
| sorter = new ShuffleExternalSorter( |
| memoryManager, |
| blockManager, |
| taskContext, |
| initialSortBufferSize, |
| partitioner.numPartitions(), |
| sparkConf, |
| writeMetrics); |
| serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); |
| serOutputStream = serializer.serializeStream(serBuffer); |
| } |
| |
| @VisibleForTesting |
| void closeAndWriteOutput() throws IOException { |
| assert(sorter != null); |
| updatePeakMemoryUsed(); |
| serBuffer = null; |
| serOutputStream = null; |
| final SpillInfo[] spills = sorter.closeAndGetSpills(); |
| try { |
| partitionLengths = mergeSpills(spills); |
| } finally { |
| sorter = null; |
| for (SpillInfo spill : spills) { |
| if (spill.file.exists() && !spill.file.delete()) { |
| logger.error("Error while deleting spill file {}", |
| MDC.of(LogKeys.PATH$.MODULE$, spill.file.getPath())); |
| } |
| } |
| } |
| mapStatus = MapStatus$.MODULE$.apply( |
| blockManager.shuffleServerId(), partitionLengths, mapId); |
| } |
| |
| @VisibleForTesting |
| void insertRecordIntoSorter(Product2<K, V> record) throws IOException { |
| assert(sorter != null); |
| final K key = record._1(); |
| final int partitionId = partitioner.getPartition(key); |
| serBuffer.reset(); |
| serOutputStream.writeKey(key, OBJECT_CLASS_TAG); |
| serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); |
| serOutputStream.flush(); |
| |
| final int serializedRecordSize = serBuffer.size(); |
| assert (serializedRecordSize > 0); |
| |
| sorter.insertRecord( |
| serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); |
| } |
| |
| @VisibleForTesting |
| void forceSorterToSpill() throws IOException { |
| assert (sorter != null); |
| sorter.spill(); |
| } |
| |
| /** |
| * Merge zero or more spill files together, choosing the fastest merging strategy based on the |
| * number of spills and the IO compression codec. |
| * |
| * @return the partition lengths in the merged file. |
| */ |
| private long[] mergeSpills(SpillInfo[] spills) throws IOException { |
| long[] partitionLengths; |
| if (spills.length == 0) { |
| final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents |
| .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); |
| return mapWriter.commitAllPartitions( |
| ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); |
| } else if (spills.length == 1) { |
| Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter = |
| shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); |
| if (maybeSingleFileWriter.isPresent()) { |
| // Here, we don't need to perform any metrics updates because the bytes written to this |
| // output file would have already been counted as shuffle bytes written. |
| partitionLengths = spills[0].partitionLengths; |
| logger.debug("Merge shuffle spills for mapId {} with length {}", mapId, |
| partitionLengths.length); |
| maybeSingleFileWriter.get() |
| .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums()); |
| } else { |
| partitionLengths = mergeSpillsUsingStandardWriter(spills); |
| } |
| } else { |
| partitionLengths = mergeSpillsUsingStandardWriter(spills); |
| } |
| return partitionLengths; |
| } |
| |
| private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException { |
| long[] partitionLengths; |
| final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); |
| final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); |
| final boolean fastMergeEnabled = |
| (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); |
| final boolean fastMergeIsSupported = !compressionEnabled || |
| CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); |
| final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); |
| final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents |
| .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); |
| try { |
| // There are multiple spills to merge, so none of these spill files' lengths were counted |
| // towards our shuffle write count or shuffle write time. If we use the slow merge path, |
| // then the final output file's size won't necessarily be equal to the sum of the spill |
| // files' sizes. To guard against this case, we look at the output file's actual size when |
| // computing shuffle bytes written. |
| // |
| // We allow the individual merge methods to report their own IO times since different merge |
| // strategies use different IO techniques. We count IO during merge towards the shuffle |
| // write time, which appears to be consistent with the "not bypassing merge-sort" branch in |
| // ExternalSorter. |
| if (fastMergeEnabled && fastMergeIsSupported) { |
| // Compression is disabled or we are using an IO compression codec that supports |
| // decompression of concatenated compressed streams, so we can perform a fast spill merge |
| // that doesn't need to interpret the spilled bytes. |
| if (transferToEnabled && !encryptionEnabled) { |
| logger.debug("Using transferTo-based fast merge"); |
| mergeSpillsWithTransferTo(spills, mapWriter); |
| } else { |
| logger.debug("Using fileStream-based fast merge"); |
| mergeSpillsWithFileStream(spills, mapWriter, null); |
| } |
| } else { |
| logger.debug("Using slow merge"); |
| mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); |
| } |
| partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); |
| } catch (Exception e) { |
| try { |
| mapWriter.abort(e); |
| } catch (Exception e2) { |
| logger.warn("Failed to abort writing the map output.", e2); |
| e.addSuppressed(e2); |
| } |
| throw e; |
| } |
| return partitionLengths; |
| } |
| |
| /** |
| * Merges spill files using Java FileStreams. This code path is typically slower than |
| * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], |
| * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec |
| * does not support concatenation of compressed data, when encryption is enabled, or when |
| * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs. |
| * This code path might also be faster in cases where individual partition size in a spill |
| * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small |
| * disk ios which is inefficient. In those case, Using large buffers for input and output |
| * files helps reducing the number of disk ios, making the file merging faster. |
| * |
| * @param spills the spills to merge. |
| * @param mapWriter the map output writer to use for output. |
| * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. |
| * @return the partition lengths in the merged file. |
| */ |
| private void mergeSpillsWithFileStream( |
| SpillInfo[] spills, |
| ShuffleMapOutputWriter mapWriter, |
| @Nullable CompressionCodec compressionCodec) throws IOException { |
| logger.debug("Merge shuffle spills with FileStream for mapId {}", mapId); |
| final int numPartitions = partitioner.numPartitions(); |
| final InputStream[] spillInputStreams = new InputStream[spills.length]; |
| |
| boolean threwException = true; |
| try { |
| for (int i = 0; i < spills.length; i++) { |
| spillInputStreams[i] = new NioBufferedFileInputStream( |
| spills[i].file, |
| inputBufferSizeInBytes); |
| // Only convert the partitionLengths when debug level is enabled. |
| if (logger.isDebugEnabled()) { |
| logger.debug("Partition lengths for mapId {} in Spill {}: {}", mapId, i, |
| Arrays.toString(spills[i].partitionLengths)); |
| } |
| } |
| for (int partition = 0; partition < numPartitions; partition++) { |
| boolean copyThrewException = true; |
| ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); |
| OutputStream partitionOutput = writer.openStream(); |
| try { |
| partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); |
| partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); |
| if (compressionCodec != null) { |
| partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); |
| } |
| for (int i = 0; i < spills.length; i++) { |
| final long partitionLengthInSpill = spills[i].partitionLengths[partition]; |
| |
| if (partitionLengthInSpill > 0) { |
| InputStream partitionInputStream = null; |
| boolean copySpillThrewException = true; |
| try { |
| partitionInputStream = new LimitedInputStream(spillInputStreams[i], |
| partitionLengthInSpill, false); |
| partitionInputStream = blockManager.serializerManager().wrapForEncryption( |
| partitionInputStream); |
| if (compressionCodec != null) { |
| partitionInputStream = compressionCodec.compressedInputStream( |
| partitionInputStream); |
| } |
| ByteStreams.copy(partitionInputStream, partitionOutput); |
| copySpillThrewException = false; |
| } finally { |
| Closeables.close(partitionInputStream, copySpillThrewException); |
| } |
| } |
| } |
| copyThrewException = false; |
| } finally { |
| Closeables.close(partitionOutput, copyThrewException); |
| } |
| long numBytesWritten = writer.getNumBytesWritten(); |
| writeMetrics.incBytesWritten(numBytesWritten); |
| } |
| threwException = false; |
| } finally { |
| // To avoid masking exceptions that caused us to prematurely enter the finally block, only |
| // throw exceptions during cleanup if threwException == false. |
| for (InputStream stream : spillInputStreams) { |
| Closeables.close(stream, threwException); |
| } |
| } |
| } |
| |
| /** |
| * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. |
| * This is only safe when the IO compression codec and serializer support concatenation of |
| * serialized streams. |
| * |
| * @param spills the spills to merge. |
| * @param mapWriter the map output writer to use for output. |
| * @return the partition lengths in the merged file. |
| */ |
| private void mergeSpillsWithTransferTo( |
| SpillInfo[] spills, |
| ShuffleMapOutputWriter mapWriter) throws IOException { |
| logger.debug("Merge shuffle spills with TransferTo for mapId {}", mapId); |
| final int numPartitions = partitioner.numPartitions(); |
| final FileChannel[] spillInputChannels = new FileChannel[spills.length]; |
| final long[] spillInputChannelPositions = new long[spills.length]; |
| |
| boolean threwException = true; |
| try { |
| for (int i = 0; i < spills.length; i++) { |
| spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); |
| // Only convert the partitionLengths when debug level is enabled. |
| if (logger.isDebugEnabled()) { |
| logger.debug("Partition lengths for mapId {} in Spill {}: {}", mapId, i, |
| Arrays.toString(spills[i].partitionLengths)); |
| } |
| } |
| for (int partition = 0; partition < numPartitions; partition++) { |
| boolean copyThrewException = true; |
| ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); |
| WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() |
| .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); |
| try { |
| for (int i = 0; i < spills.length; i++) { |
| long partitionLengthInSpill = spills[i].partitionLengths[partition]; |
| final FileChannel spillInputChannel = spillInputChannels[i]; |
| final long writeStartTime = System.nanoTime(); |
| Utils.copyFileStreamNIO( |
| spillInputChannel, |
| resolvedChannel.channel(), |
| spillInputChannelPositions[i], |
| partitionLengthInSpill); |
| copyThrewException = false; |
| spillInputChannelPositions[i] += partitionLengthInSpill; |
| writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); |
| } |
| } finally { |
| Closeables.close(resolvedChannel, copyThrewException); |
| } |
| long numBytes = writer.getNumBytesWritten(); |
| writeMetrics.incBytesWritten(numBytes); |
| } |
| threwException = false; |
| } finally { |
| // To avoid masking exceptions that caused us to prematurely enter the finally block, only |
| // throw exceptions during cleanup if threwException == false. |
| for (int i = 0; i < spills.length; i++) { |
| assert(spillInputChannelPositions[i] == spills[i].file.length()); |
| Closeables.close(spillInputChannels[i], threwException); |
| } |
| } |
| } |
| |
| @Override |
| public Option<MapStatus> stop(boolean success) { |
| try { |
| taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); |
| |
| if (stopping) { |
| return Option.apply(null); |
| } else { |
| stopping = true; |
| if (success) { |
| if (mapStatus == null) { |
| throw new IllegalStateException("Cannot call stop(true) without having called write()"); |
| } |
| return Option.apply(mapStatus); |
| } else { |
| return Option.apply(null); |
| } |
| } |
| } finally { |
| if (sorter != null) { |
| // If sorter is non-null, then this implies that we called stop() in response to an error, |
| // so we need to clean up memory and spill files created by the sorter |
| sorter.cleanupResources(); |
| } |
| } |
| } |
| |
| private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { |
| try { |
| return writer.openStream(); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { |
| private final WritableByteChannel channel; |
| |
| StreamFallbackChannelWrapper(OutputStream fallbackStream) { |
| this.channel = Channels.newChannel(fallbackStream); |
| } |
| |
| @Override |
| public WritableByteChannel channel() { |
| return channel; |
| } |
| |
| @Override |
| public void close() throws IOException { |
| channel.close(); |
| } |
| } |
| |
| @Override |
| public long[] getPartitionLengths() { |
| return partitionLengths; |
| } |
| } |