/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.concurrent.{ExecutorService, RejectedExecutionException}

import scala.language.existentials
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.CLASS_LOADER
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}

/**
 * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
 */
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
  extends Logging {

  private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)

  // Exposed for testing.
  protected val getTaskResultExecutor: ExecutorService =
    ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter")

  // Exposed for testing.
  protected val serializer = new ThreadLocal[SerializerInstance] {
    override def initialValue(): SerializerInstance = {
      sparkEnv.closureSerializer.newInstance()
    }
  }

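  // Thread-local instance of the data serializer, used to deserialize the task result value
  // itself; the TaskResult envelope above is handled by the closure serializer.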
  protected val taskResultSerializer = new ThreadLocal[SerializerInstance] {
    override def initialValue(): SerializerInstance = {
      sparkEnv.serializer.newInstance()
    }
  }

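  /**
   * Asynchronously deserializes a successful task result on the "task-result-getter" pool,
   * fetching the result bytes from the remote block manager first if they were shipped as an
   * [[IndirectTaskResult]]. If the total result size exceeds maxResultSize, the task is killed
   * instead of being reported as successful.
   */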
  def enqueueSuccessfulTask(
      taskSetManager: TaskSetManager,
      tid: Long,
      serializedData: ByteBuffer): Unit = {
    getTaskResultExecutor.execute(new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions {
        try {
          val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
            case directResult: DirectTaskResult[_] =>
              if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) {
                // Kill the task so that it does not become a zombie task.
                scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
                  "Tasks result size has exceeded maxResultSize"))
                return
              }
              // Deserialize "value" without holding any lock so that it won't block other
              // threads. Doing it here means that when it is called again in
              // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
              directResult.value(taskResultSerializer.get())
              (directResult, serializedData.limit().toLong)
            case IndirectTaskResult(blockId, size) =>
              if (!taskSetManager.canFetchMoreResults(size)) {
                // Remove the block stored by the executor, since the result exceeds
                // maxResultSize and will not be fetched.
                sparkEnv.blockManager.master.removeBlock(blockId)
                // Kill the task so that it does not become a zombie task.
                scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
                  "Tasks result size has exceeded maxResultSize"))
                return
              }
              logDebug(s"Fetching indirect task result for ${taskSetManager.taskName(tid)}")
              scheduler.handleTaskGettingResult(taskSetManager, tid)
              val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
              if (serializedTaskResult.isEmpty) {
                /* We won't be able to get the task result if the machine that ran the task failed
                 * between when the task ended and when we tried to fetch the result, or if the
                 * block manager had to flush the result. */
                scheduler.handleFailedTask(
                  taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                return
              }
              val deserializedResult = SerializerHelper
                .deserializeFromChunkedBuffer[DirectTaskResult[_]](
                  serializer.get(),
                  serializedTaskResult.get)

              // Force deserialization of the referenced value.
              deserializedResult.value(taskResultSerializer.get())
              sparkEnv.blockManager.master.removeBlock(blockId)
              (deserializedResult, size)
          }

          // Set the task result size in the accumulator updates received from the executors.
          // We need to do this here on the driver because if we did this on the executors then
          // we would have to serialize the result again after updating the size.
          result.accumUpdates = result.accumUpdates.map { a =>
            if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
              val acc = a.asInstanceOf[LongAccumulator]
              assert(acc.sum == 0L, "task result size should not have been set on the executors")
              acc.setValue(size)
              acc
            } else {
              a
            }
          }

          scheduler.handleSuccessfulTask(taskSetManager, tid, result)
        } catch {
          case cnf: ClassNotFoundException =>
            val loader = Thread.currentThread.getContextClassLoader
            taskSetManager.abort("ClassNotFound with classloader: " + loader)
          // Matching NonFatal so we don't catch the ControlThrowable from the "return" above.
          case NonFatal(ex) =>
            logError("Exception while getting task result", ex)
            taskSetManager.abort("Exception while getting task result: %s".format(ex))
        }
      }
    })
  }

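  /**
   * Asynchronously deserializes the [[TaskFailedReason]] of a failed task and forwards it to the
   * scheduler. If the reason cannot be deserialized, the scheduler is still notified with
   * [[UnknownReason]] so that the task does not appear to run forever.
   */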
  def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
      serializedData: ByteBuffer): Unit = {
    var reason: TaskFailedReason = UnknownReason
    try {
      getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
        val loader = Utils.getContextOrSparkClassLoader
        try {
          if (serializedData != null && serializedData.limit() > 0) {
            reason = serializer.get().deserialize[TaskFailedReason](
              serializedData, loader)
          }
        } catch {
          case _: ClassNotFoundException =>
            // Log an error but keep going here -- the task failed, so it's not catastrophic
            // if we can't deserialize the reason.
            logError(
              log"Could not deserialize TaskEndReason: ClassNotFound with classloader " +
                log"${MDC(CLASS_LOADER, loader)}")
          case _: Exception => // No-op
        } finally {
          // If there's an error while deserializing the TaskEndReason, this Runnable
          // will die. Still tell the scheduler about the task failure, to avoid a hang
          // where the scheduler thinks the task is still running.
          scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
        }
      })
    } catch {
      case e: RejectedExecutionException if sparkEnv.isStopped =>
        // ignore it
    }
  }

  // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
  // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
  // synchronized and may hurt the throughput of the scheduler.
  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
    getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
      scheduler.handlePartitionCompleted(stageId, partitionId)
    })
  }

  def stop(): Unit = {
    getTaskResultExecutor.shutdownNow()
  }
}