| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.scheduler |
| |
| import java.nio.ByteBuffer |
| import java.util.concurrent.{ExecutorService, RejectedExecutionException} |
| |
| import scala.language.existentials |
| import scala.util.control.NonFatal |
| |
| import org.apache.spark._ |
| import org.apache.spark.TaskState.TaskState |
| import org.apache.spark.internal.{Logging, MDC} |
| import org.apache.spark.internal.LogKeys.CLASS_LOADER |
| import org.apache.spark.serializer.{SerializerHelper, SerializerInstance} |
| import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils} |
| |
| /** |
| * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. |
| */ |
| private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) |
| extends Logging { |
| |
| private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) |
| |
| // Exposed for testing. |
| protected val getTaskResultExecutor: ExecutorService = |
| ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter") |
| |
| // Exposed for testing. |
| protected val serializer = new ThreadLocal[SerializerInstance] { |
| override def initialValue(): SerializerInstance = { |
| sparkEnv.closureSerializer.newInstance() |
| } |
| } |
| |
| protected val taskResultSerializer = new ThreadLocal[SerializerInstance] { |
| override def initialValue(): SerializerInstance = { |
| sparkEnv.serializer.newInstance() |
| } |
| } |
| |
| def enqueueSuccessfulTask( |
| taskSetManager: TaskSetManager, |
| tid: Long, |
| serializedData: ByteBuffer): Unit = { |
| getTaskResultExecutor.execute(new Runnable { |
| override def run(): Unit = Utils.logUncaughtExceptions { |
| try { |
| val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { |
| case directResult: DirectTaskResult[_] => |
| if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) { |
| // kill the task so that it will not become zombie task |
| scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled( |
| "Tasks result size has exceeded maxResultSize")) |
| return |
| } |
| // deserialize "value" without holding any lock so that it won't block other threads. |
| // We should call it here, so that when it's called again in |
| // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. |
| directResult.value(taskResultSerializer.get()) |
| (directResult, serializedData.limit().toLong) |
| case IndirectTaskResult(blockId, size) => |
| if (!taskSetManager.canFetchMoreResults(size)) { |
| // dropped by executor if size is larger than maxResultSize |
| sparkEnv.blockManager.master.removeBlock(blockId) |
| // kill the task so that it will not become zombie task |
| scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled( |
| "Tasks result size has exceeded maxResultSize")) |
| return |
| } |
| logDebug(s"Fetching indirect task result for ${taskSetManager.taskName(tid)}") |
| scheduler.handleTaskGettingResult(taskSetManager, tid) |
| val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) |
| if (serializedTaskResult.isEmpty) { |
| /* We won't be able to get the task result if the machine that ran the task failed |
| * between when the task ended and when we tried to fetch the result, or if the |
| * block manager had to flush the result. */ |
| scheduler.handleFailedTask( |
| taskSetManager, tid, TaskState.FINISHED, TaskResultLost) |
| return |
| } |
| val deserializedResult = SerializerHelper |
| .deserializeFromChunkedBuffer[DirectTaskResult[_]]( |
| serializer.get(), |
| serializedTaskResult.get) |
| // force deserialization of referenced value |
| deserializedResult.value(taskResultSerializer.get()) |
| sparkEnv.blockManager.master.removeBlock(blockId) |
| (deserializedResult, size) |
| } |
| |
| // Set the task result size in the accumulator updates received from the executors. |
| // We need to do this here on the driver because if we did this on the executors then |
| // we would have to serialize the result again after updating the size. |
| result.accumUpdates = result.accumUpdates.map { a => |
| if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { |
| val acc = a.asInstanceOf[LongAccumulator] |
| assert(acc.sum == 0L, "task result size should not have been set on the executors") |
| acc.setValue(size) |
| acc |
| } else { |
| a |
| } |
| } |
| |
| scheduler.handleSuccessfulTask(taskSetManager, tid, result) |
| } catch { |
| case cnf: ClassNotFoundException => |
| val loader = Thread.currentThread.getContextClassLoader |
| taskSetManager.abort("ClassNotFound with classloader: " + loader) |
| // Matching NonFatal so we don't catch the ControlThrowable from the "return" above. |
| case NonFatal(ex) => |
| logError("Exception while getting task result", ex) |
| taskSetManager.abort("Exception while getting task result: %s".format(ex)) |
| } |
| } |
| }) |
| } |
| |
| def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, |
| serializedData: ByteBuffer): Unit = { |
| var reason : TaskFailedReason = UnknownReason |
| try { |
| getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions { |
| val loader = Utils.getContextOrSparkClassLoader |
| try { |
| if (serializedData != null && serializedData.limit() > 0) { |
| reason = serializer.get().deserialize[TaskFailedReason]( |
| serializedData, loader) |
| } |
| } catch { |
| case _: ClassNotFoundException => |
| // Log an error but keep going here -- the task failed, so not catastrophic |
| // if we can't deserialize the reason. |
| logError( |
| log"Could not deserialize TaskEndReason: ClassNotFound with classloader " + |
| log"${MDC(CLASS_LOADER, loader)}") |
| case _: Exception => // No-op |
| } finally { |
| // If there's an error while deserializing the TaskEndReason, this Runnable |
| // will die. Still tell the scheduler about the task failure, to avoid a hang |
| // where the scheduler thinks the task is still running. |
| scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) |
| } |
| }) |
| } catch { |
| case e: RejectedExecutionException if sparkEnv.isStopped => |
| // ignore it |
| } |
| } |
| |
| // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want |
| // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's |
| // synchronized and may hurt the throughput of the scheduler. |
| def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = { |
| getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions { |
| scheduler.handlePartitionCompleted(stageId, partitionId) |
| }) |
| } |
| |
| def stop(): Unit = { |
| getTaskResultExecutor.shutdownNow() |
| } |
| } |