/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.storage

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel}
import java.nio.channels.FileChannel.MapMode
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import scala.collection.mutable.ListBuffer

import com.google.common.io.Closeables
import io.netty.channel.DefaultFileRegion
import org.apache.commons.io.FileUtils

import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.internal.{config, Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils}
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBuffer

/**
 * Stores BlockManager blocks on disk.
 */
private[spark] class DiskStore(
    conf: SparkConf,
    diskManager: DiskBlockManager,
    securityManager: SecurityManager) extends Logging {

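  // `minMemoryMapBytes` (spark.storage.memoryMapThreshold) is the size below which a block is
  // read directly into a heap buffer rather than memory mapped (see DiskBlockData.toByteBuffer),
  // while `maxMemoryMapBytes` caps the size of any single mapped or allocated chunk and, judging
  // by its config name, is mainly lowered in tests. `blockSizes` records the logical size of each
  // stored block, i.e. the bytes counted by CountingWritableChannel before any encryption.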
  private val minMemoryMapBytes = conf.get(config.STORAGE_MEMORY_MAP_THRESHOLD)
  private val maxMemoryMapBytes = conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS)
  private val blockSizes = new ConcurrentHashMap[BlockId, Long]()

  private val shuffleServiceFetchRddEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) &&
    conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)

  def getSize(blockId: BlockId): Long = blockSizes.get(blockId)

  /**
   * Invokes the provided callback function to write the specific block.
   *
   * If the block already exists in the disk store, the existing file is deleted first; a
   * SparkException is thrown if that file cannot be deleted.
   */
  def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = {
    if (contains(blockId)) {
      logWarning(log"Block ${MDC(BLOCK_ID, blockId)} is already present in the disk store")
      try {
        diskManager.getFile(blockId).delete()
      } catch {
        case e: Exception =>
          throw SparkException.internalError(
            s"Block $blockId is already present in the disk store and could not be deleted: $e",
            category = "STORAGE")
      }
    }
    logDebug(s"Attempting to put block $blockId")
    val startTimeNs = System.nanoTime()
    val file = diskManager.getFile(blockId)
    // SPARK-37618: If fetching cached RDDs from the shuffle service is enabled, we must make
    // the file world readable, as it will not be owned by the group running the shuffle service
    // in a secure environment. This is due to changing directory permissions to allow deletion.
    if (shuffleServiceFetchRddEnabled) {
      diskManager.createWorldReadableFile(file)
    }
    val out = new CountingWritableChannel(openForWrite(file))
    var threwException: Boolean = true
    try {
      writeFunc(out)
      blockSizes.put(blockId, out.getCount)
      threwException = false
    } finally {
      try {
        out.close()
      } catch {
        case ioe: IOException =>
          if (!threwException) {
            threwException = true
            throw ioe
          }
      } finally {
        if (threwException) {
          remove(blockId)
        }
      }
    }
    logDebug(s"Block ${file.getName} stored as ${Utils.bytesToString(file.length())} file" +
      s" on disk in ${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
  }

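  // Illustrative usage: callers pass a function that streams the block's bytes into the
  // provided channel, e.g. put(blockId) { channel => bytes.writeFully(channel) }, as putBytes
  // below does; if the write fails partway, the partially written file is removed.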
  def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
    put(blockId) { channel =>
      bytes.writeFully(channel)
    }
  }

  def getBytes(blockId: BlockId): BlockData = {
    getBytes(diskManager.getFile(blockId.name), getSize(blockId))
  }

  def getBytes(f: File, blockSize: Long): BlockData = securityManager.getIOEncryptionKey() match {
    case Some(key) =>
      // Encrypted blocks cannot be memory mapped; return a special object that does decryption
      // and provides InputStream / FileRegion implementations for reading the data.
      new EncryptedBlockData(f, blockSize, conf, key)

    case _ =>
      new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, f, blockSize)
  }

  def remove(blockId: BlockId): Boolean = {
    blockSizes.remove(blockId)
    val file = diskManager.getFile(blockId.name)
    if (file.exists()) {
      val ret = file.delete()
      if (!ret) {
        logWarning(log"Error deleting ${MDC(PATH, file.getPath())}")
      }
      ret
    } else {
      false
    }
  }

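  // Note that remove() above is also invoked from the cleanup path in put() to discard a
  // partially written block when the write function or close() throws.
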
  /**
   * @param blockSize if encryption is configured, the file is assumed to already be encrypted and
   *                  blockSize should be the decrypted size
   */
  def moveFileToBlock(sourceFile: File, blockSize: Long, targetBlockId: BlockId): Unit = {
    blockSizes.put(targetBlockId, blockSize)
    val targetFile = diskManager.getFile(targetBlockId.name)
    logDebug(s"${sourceFile.getPath()} -> ${targetFile.getPath()}")
    FileUtils.moveFile(sourceFile, targetFile)
  }

  def contains(blockId: BlockId): Boolean = diskManager.containsBlock(blockId)

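  // Opens a channel for writing `file`, layering an encrypting channel on top of the raw file
  // channel when an IO encryption key is configured; on failure the raw channel is closed and
  // the partially created file is deleted.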
  private def openForWrite(file: File): WritableByteChannel = {
    val out = new FileOutputStream(file).getChannel()
    try {
      securityManager.getIOEncryptionKey().map { key =>
        CryptoStreamUtils.createWritableChannel(out, conf, key)
      }.getOrElse(out)
    } catch {
      case e: Exception =>
        Closeables.close(out, true)
        file.delete()
        throw e
    }
  }
}

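/**
 * Block data backed by an unencrypted file on disk. When read back as a single buffer, blocks
 * smaller than `minMemoryMapBytes` are copied into a heap buffer while larger ones are memory
 * mapped, and `toChunkedByteBuffer` splits the file into chunks of at most `maxMemoryMapBytes`
 * bytes so each chunk fits in a single `ByteBuffer`.
 */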
private class DiskBlockData(
    minMemoryMapBytes: Long,
    maxMemoryMapBytes: Long,
    file: File,
    blockSize: Long) extends BlockData {

  override def toInputStream(): InputStream = new FileInputStream(file)

  /**
   * Returns a Netty-friendly wrapper for the block's data.
   *
   * Please see `ManagedBuffer.convertToNetty()` for more details.
   */
  override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size)

  /**
   * Returns a Netty-friendly wrapper for the block's data.
   *
   * Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
   */
  override def toNettyForSsl(): AnyRef =
    toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl

  override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = {
    Utils.tryWithResource(open()) { channel =>
      var remaining = blockSize
      val chunks = new ListBuffer[ByteBuffer]()
      while (remaining > 0) {
        val chunkSize = math.min(remaining, maxMemoryMapBytes)
        val chunk = allocator(chunkSize.toInt)
        remaining -= chunkSize
        JavaUtils.readFully(channel, chunk)
        chunk.flip()
        chunks += chunk
      }
      new ChunkedByteBuffer(chunks.toArray)
    }
  }

  override def toByteBuffer(): ByteBuffer = {
    require(blockSize < maxMemoryMapBytes,
      s"can't create a byte buffer of size $blockSize" +
      s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.")
    Utils.tryWithResource(open()) { channel =>
      if (blockSize < minMemoryMapBytes) {
        // For small files, directly read rather than memory map.
        val buf = ByteBuffer.allocate(blockSize.toInt)
        JavaUtils.readFully(channel, buf)
        buf.flip()
        buf
      } else {
        channel.map(MapMode.READ_ONLY, 0, file.length)
      }
    }
  }

  override def size: Long = blockSize

  override def dispose(): Unit = {}

  private def open() = new FileInputStream(file).getChannel
}

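/**
 * Block data stored in a file that was encrypted with the configured IO encryption key. The
 * data cannot be memory mapped, so reads go through a decrypting channel, and
 * `toChunkedByteBuffer` chunks the decrypted bytes at `MAX_ROUNDED_ARRAY_LENGTH` so each chunk
 * fits in a Java array (unlike the unencrypted path above, which chunks at `maxMemoryMapBytes`).
 */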
private[spark] class EncryptedBlockData(
    file: File,
    blockSize: Long,
    conf: SparkConf,
    key: Array[Byte]) extends BlockData {

  override def toInputStream(): InputStream = Channels.newInputStream(open())

  override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)

  override def toNettyForSsl(): AnyRef =
    toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl

  override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
    val source = open()
    try {
      var remaining = blockSize
      val chunks = new ListBuffer[ByteBuffer]()
      while (remaining > 0) {
        val chunkSize = math.min(remaining, ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)
        val chunk = allocator(chunkSize.toInt)
        remaining -= chunkSize
        JavaUtils.readFully(source, chunk)
        chunk.flip()
        chunks += chunk
      }
      new ChunkedByteBuffer(chunks.toArray)
    } finally {
      source.close()
    }
  }

  override def toByteBuffer(): ByteBuffer = {
    // This is used by the block transfer service to replicate blocks. The upload code reads
    // all bytes into memory to send the block to the remote executor, so it's ok to do this
    // as long as the block fits in a Java array.
    assert(blockSize <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH,
      "Block is too large to be wrapped in a byte buffer.")
    val dst = ByteBuffer.allocate(blockSize.toInt)
    val in = open()
    try {
      JavaUtils.readFully(in, dst)
      dst.flip()
      dst
    } finally {
      Closeables.close(in, true)
    }
  }

  override def size: Long = blockSize

  override def dispose(): Unit = { }

  private def open(): ReadableByteChannel = {
    val channel = new FileInputStream(file).getChannel()
    try {
      CryptoStreamUtils.createReadableChannel(channel, conf, key)
    } catch {
      case e: Exception =>
        Closeables.close(channel, true)
        throw e
    }
  }
}

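/**
 * A `ManagedBuffer` view over an `EncryptedBlockData`. Because the bytes live on disk rather
 * than in reference-counted memory, `retain` and `release` are no-ops.
 */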
private[spark] class EncryptedManagedBuffer(
    val blockData: EncryptedBlockData) extends ManagedBuffer {

  // This is the size of the decrypted data
  override def size(): Long = blockData.size

  override def nioByteBuffer(): ByteBuffer = blockData.toByteBuffer()

  override def convertToNetty(): AnyRef = blockData.toNetty()

  override def convertToNettyForSsl(): AnyRef = blockData.toNettyForSsl()

  override def createInputStream(): InputStream = blockData.toInputStream()

  override def retain(): ManagedBuffer = this

  override def release(): ManagedBuffer = this
}

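/**
 * A `FileRegion` that serves data from a `ReadableByteChannel` (e.g. a decrypting channel) by
 * staging it through a 64 KB direct buffer: `transferTo` refills the buffer from the source
 * whenever it is drained and writes it to the target until the source is exhausted or a write
 * makes no progress.
 */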
private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)
  extends AbstractFileRegion {

  private var _transferred = 0L

  private val buffer = ByteBuffer.allocateDirect(64 * 1024)
  buffer.flip()

  override def count(): Long = blockSize

  override def position(): Long = 0

  override def transferred(): Long = _transferred

  override def transferTo(target: WritableByteChannel, pos: Long): Long = {
    assert(pos == transferred(), "Invalid position.")

    var written = 0L
    var lastWrite = -1L
    while (lastWrite != 0) {
      if (!buffer.hasRemaining()) {
        buffer.clear()
        source.read(buffer)
        buffer.flip()
      }
      if (buffer.hasRemaining()) {
        lastWrite = target.write(buffer)
        written += lastWrite
      } else {
        lastWrite = 0
      }
    }

    _transferred += written
    written
  }

  override def deallocate(): Unit = source.close()
}

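/**
 * A `WritableByteChannel` wrapper that counts the bytes written through it, which lets
 * `DiskStore.put` record the logical size of the block it just wrote.
 */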
private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel {

  private var count = 0L

  def getCount: Long = count

  override def write(src: ByteBuffer): Int = {
    val written = sink.write(src)
    if (written > 0) {
      count += written
    }
    written
  }

  override def isOpen(): Boolean = sink.isOpen()

  override def close(): Unit = sink.close()
}