/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.executor
import java.io.{File, NotSerializableException}
import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
import java.util.{Locale, Properties}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import com.google.common.cache.{Cache, CacheBuilder, RemovalListener, RemovalNotification}
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.slf4j.MDC
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.{Logging, LogKeys, MDC => LogMDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config._
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.status.api.v1.ThreadStackTrace
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.ArrayImplicits._
private[spark] class IsolatedSessionState(
val sessionUUID: String,
var urlClassLoader: MutableURLClassLoader,
var replClassLoader: ClassLoader,
val currentFiles: HashMap[String, Long],
val currentJars: HashMap[String, Long],
val currentArchives: HashMap[String, Long],
val replClassDirUri: Option[String])
/**
* Spark executor, backed by a threadpool to run tasks.
*
* This can be used with YARN, Kubernetes, and the standalone scheduler.
* An internal RPC interface is used for communication with the driver.
*/
private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false,
uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler,
resources: immutable.Map[String, ResourceInformation])
extends Logging {
logInfo(s"Starting executor ID $executorId on host $executorHostname")
logInfo(s"OS info ${System.getProperty("os.name")}, ${System.getProperty("os.version")}, " +
s"${System.getProperty("os.arch")}")
logInfo(s"Java version ${System.getProperty("java.version")}")
private val executorShutdown = new AtomicBoolean(false)
val stopHookReference = ShutdownHookManager.addShutdownHook(
() => stop()
)
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
private[executor] val conf = env.conf
// SPARK-48131: Unify MDC key mdc.taskName and task_name in Spark 4.0 release.
private[executor] val taskNameMDCKey = if (conf.get(LEGACY_TASK_NAME_MDC_ENABLED)) {
"mdc.taskName"
} else {
LogKeys.TASK_NAME.name
}
// SPARK-40235: updateDependencies() uses a ReentrantLock instead of the `synchronized` keyword
// so that tasks can exit quickly if they are interrupted while waiting on another task to
// finish downloading dependencies.
private val updateDependenciesLock = new ReentrantLock()
// No ip or host:port - just hostname
Utils.checkHost(executorHostname)
// must not have port specified.
assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}
// Start worker thread pool
// Use UninterruptibleThread to run tasks so that we can allow running code without being
// interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
// will hang forever if some methods are interrupted.
private[executor] val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("Executor task launch worker-%d")
.setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
.build()
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}
private val schemes = conf.get(EXECUTOR_METRICS_FILESYSTEM_SCHEMES)
.toLowerCase(Locale.ROOT).split(",").map(_.trim).filter(_.nonEmpty)
private val executorSource = new ExecutorSource(threadPool, executorId, schemes)
// Pool used for threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
// For tasks which are in the process of being killed, this map holds the most recently created
// TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
// a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
// the integrity of the map's internal state). The purpose of this map is to prevent the creation
// of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
// track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
// create. The map key is a task id.
private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
val executorMetricsSource =
if (conf.get(METRICS_EXECUTORMETRICS_SOURCE_ENABLED)) {
Some(new ExecutorMetricsSource)
} else {
None
}
if (!isLocal) {
env.blockManager.initialize(conf.getAppId)
env.metricsSystem.registerSource(executorSource)
env.metricsSystem.registerSource(new JVMCPUSource())
executorMetricsSource.foreach(_.register(env.metricsSystem))
env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource)
} else {
// This enables the registration of the executor source in local mode.
// The actual registration happens in SparkContext;
// it cannot be done here because the appId is not available yet.
Executor.executorSourceLocalModeOnly = executorSource
}
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst = conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
// Whether to monitor killed / interrupted tasks
private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)
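// Max depth of the exception cause chain to inspect when deciding whether a task failure is a
// fatal error that should kill the executor (see Executor.isFatalError).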
private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)
private val systemLoader = Utils.getContextOrSparkClassLoader
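/**
 * Build the per-session state used for class loader isolation: fresh dependency bookkeeping maps
 * plus a URL class loader (stub-backed when UDF stubbing is enabled) and, if a REPL class URI is
 * set, a REPL class loader on top.
 */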
private def newSessionState(jobArtifactState: JobArtifactState): IsolatedSessionState = {
val currentFiles = new HashMap[String, Long]
val currentJars = new HashMap[String, Long]
val currentArchives = new HashMap[String, Long]
val urlClassLoader =
createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid))
val replClassLoader = addReplClassLoaderIfNeeded(
urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
new IsolatedSessionState(
jobArtifactState.uuid, urlClassLoader, replClassLoader,
currentFiles,
currentJars,
currentArchives,
jobArtifactState.replClassDirUri
)
}
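// Stubbing of missing UDF classes applies only to isolated (non-default) sessions, and only when
// stub class prefixes are configured.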
private def isStubbingEnabledForState(name: String) = {
!isDefaultState(name) &&
conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty
}
private def isDefaultState(name: String) = name == "default"
// Classloader isolation
// The default isolation group
val defaultSessionState: IsolatedSessionState = newSessionState(JobArtifactState("default", None))
val isolatedSessionCache: Cache[String, IsolatedSessionState] = CacheBuilder.newBuilder()
.maximumSize(100)
.expireAfterAccess(30, TimeUnit.MINUTES)
.removalListener(new RemovalListener[String, IsolatedSessionState]() {
override def onRemoval(
notification: RemovalNotification[String, IsolatedSessionState]): Unit = {
val state = notification.getValue
// Cache is always used for isolated sessions.
assert(!isDefaultState(state.sessionUUID))
val sessionBasedRoot = new File(SparkFiles.getRootDirectory(), state.sessionUUID)
if (sessionBasedRoot.isDirectory && sessionBasedRoot.exists()) {
Utils.deleteRecursively(sessionBasedRoot)
}
logInfo(s"Session evicted: ${state.sessionUUID}")
}
})
.build[String, IsolatedSessionState]
// Set the classloader for serializer
env.serializer.setDefaultClassLoader(defaultSessionState.replClassLoader)
// SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads
// for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
env.serializerManager.setDefaultClassLoader(defaultSessionState.replClassLoader)
// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back. This is guaranteed to be smaller than the max array size limit (2GB).
private val maxDirectResultSize = Math.min(
conf.get(TASK_MAX_DIRECT_RESULT_SIZE),
RpcUtils.maxMessageSizeBytes(conf))
private val maxResultSize = conf.get(MAX_RESULT_SIZE)
// Maintains the list of running tasks.
private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
// Kill mark TTL in milliseconds - 10 seconds.
private val KILL_MARK_TTL_MS = 10000L
// Kill marks with interruptThread flag, kill reason and timestamp.
// This is to avoid dropping the kill event when killTask() is called before launchTask().
private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
private val killMarkCleanupTask = new Runnable {
override def run(): Unit = {
val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
val iter = killMarks.entrySet().iterator()
while (iter.hasNext) {
if (iter.next().getValue._3 < oldest) {
iter.remove()
}
}
}
}
// Kill mark cleanup thread executor.
private val killMarkCleanupService =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
killMarkCleanupService.scheduleAtFixedRate(
killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS, TimeUnit.MILLISECONDS)
/**
* When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
* times, it should kill itself. The default value is 60. For example, if max failures is 60 and
* heartbeat interval is 10s, then it will try to send heartbeats for up to 600s (10 minutes).
*/
private val HEARTBEAT_MAX_FAILURES = conf.get(EXECUTOR_HEARTBEAT_MAX_FAILURES)
/**
* Whether to drop empty accumulators from heartbeats sent to the driver. Including the empty
* accumulators (that satisfy isZero) can make the size of the heartbeat message very large.
*/
private val HEARTBEAT_DROP_ZEROES = conf.get(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES)
/**
* Interval to send heartbeats, in milliseconds
*/
private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL)
/**
* Interval to poll for executor metrics, in milliseconds
*/
private val METRICS_POLLING_INTERVAL_MS = conf.get(EXECUTOR_METRICS_POLLING_INTERVAL)
private val pollOnHeartbeat = METRICS_POLLING_INTERVAL_MS <= 0
// Poller for the memory metrics. Visible for testing.
private[executor] val metricsPoller = new ExecutorMetricsPoller(
env.memoryManager,
METRICS_POLLING_INTERVAL_MS,
executorMetricsSource)
// Executor for the heartbeat task.
private val heartbeater = new Heartbeater(
() => Executor.this.reportHeartBeat(),
"executor-heartbeater",
HEARTBEAT_INTERVAL_MS)
// must be initialized before heartbeater.start() is called below
private val heartbeatReceiverRef =
RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
/**
* Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each
* successful heartbeat will reset it to 0.
*/
private var heartbeatFailures = 0
/**
* Flag to prevent launching new tasks while decommissioned. There could be a race condition
* accessing this, but decommissioning is only intended to help not be a hard stop.
*/
private var decommissioned = false
heartbeater.start()
private val appStartTime = conf.getLong("spark.app.startTime", 0)
// To allow users to distribute plugins and their required files
// specified by --jars, --files and --archives on application submission, those
// jars/files/archives should be downloaded and added to the class loader via
// updateDependencies. This should be done before plugin initialization below
// because executors search plugins from the class loader and initialize them.
private val Seq(initialUserJars, initialUserFiles, initialUserArchives) =
Seq("jar", "file", "archive").map { key =>
conf.getOption(s"spark.app.initial.$key.urls").map { urls =>
import org.apache.spark.util.ArrayImplicits._
immutable.Map(urls.split(",").map(url => (url, appStartTime)).toImmutableArraySeq: _*)
}.getOrElse(immutable.Map.empty)
}
updateDependencies(initialUserFiles, initialUserJars, initialUserArchives, defaultSessionState)
// Plugins and shuffle managers need to be loaded with a class loader that includes the executor's
// user classpath. Plugins also need to be initialized after the heartbeater has started
// to avoid blocking heartbeat sends (see SPARK-32175 and SPARK-45762).
private val plugins: Option[PluginContainer] =
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
PluginContainer(env, resources.asJava)
}
// Skip local mode because the ShuffleManager is already initialized
if (!isLocal) {
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
env.initializeShuffleManager()
}
}
metricsPoller.start()
private[executor] def numRunningTasks: Int = runningTasks.size()
/**
* Mark an executor for decommissioning and avoid launching new tasks.
*/
private[spark] def decommission(): Unit = {
decommissioned = true
}
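// Factory for TaskRunner instances; used by launchTask().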
private[executor] def createTaskRunner(context: ExecutorBackend,
taskDescription: TaskDescription) = new TaskRunner(context, taskDescription, plugins)
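/**
 * Register and start a TaskRunner for the given task. If a kill mark was recorded for this task
 * before it was launched, the kill is applied immediately.
 */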
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val taskId = taskDescription.taskId
val tr = createTaskRunner(context, taskDescription)
runningTasks.put(taskId, tr)
val killMark = killMarks.get(taskId)
if (killMark != null) {
tr.kill(killMark._1, killMark._2)
killMarks.remove(taskId)
}
threadPool.execute(tr)
if (decommissioned) {
log.error(s"Launching a task while in decommissioned state.")
}
}
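/**
 * Kill a (possibly not-yet-launched) task. A kill mark is recorded first so the kill is not lost
 * if killTask() races with launchTask(); if the task is already running, it is killed either
 * directly or via a TaskReaper, depending on `spark.task.reaper.enabled`.
 */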
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
val shouldCreateReaper = taskReaperForTask.get(taskId) match {
case None => true
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
}
if (shouldCreateReaper) {
val taskReaper = new TaskReaper(
taskRunner, interruptThread = interruptThread, reason = reason)
taskReaperForTask(taskId) = taskReaper
Some(taskReaper)
} else {
None
}
}
// Execute the TaskReaper from outside of the synchronized block.
maybeNewTaskReaper.foreach(taskReaperPool.execute)
} else {
taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
// Safe to remove the kill mark now that the kill has been delivered to the TaskRunner.
killMarks.remove(taskId)
}
}
/**
* Function to kill the running tasks in an executor.
* This can be called by executor back-ends to kill the
* tasks instead of taking the JVM down.
* @param interruptThread whether to interrupt the task thread
*/
def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
runningTasks.keys().asScala.foreach(t =>
killTask(t, interruptThread = interruptThread, reason = reason))
}
def stop(): Unit = {
if (!executorShutdown.getAndSet(true)) {
ShutdownHookManager.removeShutdownHook(stopHookReference)
env.metricsSystem.report()
try {
if (metricsPoller != null) {
metricsPoller.stop()
}
} catch {
case NonFatal(e) =>
logWarning("Unable to stop executor metrics poller", e)
}
try {
if (heartbeater != null) {
heartbeater.stop()
}
} catch {
case NonFatal(e) =>
logWarning("Unable to stop heartbeater", e)
}
ShuffleBlockPusher.stop()
if (threadPool != null) {
threadPool.shutdown()
}
if (killMarkCleanupService != null) {
killMarkCleanupService.shutdown()
}
if (defaultSessionState != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate cleanly
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
plugins.foreach(_.shutdown())
}
}
if (!isLocal) {
env.stop()
}
}
}
/** Returns the total amount of time this JVM process has spent in garbage collection. */
private def computeTotalGcTime(): Long = {
ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum
}
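/**
 * Runs a single task: deserializes the task binary, executes it, and reports status updates and
 * the (possibly indirect) result back to the ExecutorBackend.
 */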
class TaskRunner(
execBackend: ExecutorBackend,
val taskDescription: TaskDescription,
private val plugins: Option[PluginContainer])
extends Runnable {
val taskId = taskDescription.taskId
val taskName = taskDescription.name
val threadName = s"Executor task launch worker for $taskName"
val mdcProperties = taskDescription.properties.asScala
.filter(_._1.startsWith("mdc.")).toSeq
/** If specified, this task has been killed and this option contains the reason. */
@volatile private var reasonIfKilled: Option[String] = None
@volatile private var threadId: Long = -1
def getThreadId: Long = threadId
/** Whether this task has finished. */
@GuardedBy("TaskRunner.this")
private var finished = false
def isFinished: Boolean = synchronized { finished }
/** How much time the JVM process had spent in GC when the task started to run. */
@volatile var startGCTime: Long = _
/**
* The task to run. This will be set in run() by deserializing the task binary coming
* from the driver. Once it is set, it will never be changed.
*/
@volatile var task: Task[Any] = _
def kill(interruptThread: Boolean, reason: String): Unit = {
logInfo(s"Executor is trying to kill $taskName, reason: $reason")
reasonIfKilled = Some(reason)
if (task != null) {
synchronized {
if (!finished) {
task.kill(interruptThread, reason)
}
}
}
}
/**
* Set the finished flag to true and clear the current thread's interrupt status
*/
private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized {
this.finished = true
// SPARK-14234 - Reset the interrupted status of the thread to avoid the
// ClosedByInterruptException during execBackend.statusUpdate which causes
// Executor to crash
Thread.interrupted()
// Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
// is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
// is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
notifyAll()
}
/**
* Utility function to:
* 1. Report executor runtime and JVM gc time if possible
* 2. Collect accumulator updates
* 3. Set the finished flag to true and clear current thread's interrupt status
*/
private def collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs: Long) = {
// Report executor runtime and JVM gc time
Option(task).foreach(t => {
t.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis(
// SPARK-32898: it's possible that a task is killed while taskStartTimeNs still has its
// initial value (0). In this case, the executorRunTime should be considered as 0.
if (taskStartTimeNs > 0) System.nanoTime() - taskStartTimeNs else 0))
t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
})
// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty)
val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
setTaskFinishedAndClearInterruptStatus()
(accums, accUpdates)
}
override def run(): Unit = {
// Classloader isolation
val isolatedSession = taskDescription.artifacts.state match {
case Some(jobArtifactState) =>
isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState))
case _ => defaultSessionState
}
setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTimeNs = System.nanoTime()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStartTimeNs: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
var taskStarted: Boolean = false
try {
// Must be set before updateDependencies() is called, in case fetching dependencies
// requires access to properties contained within (e.g. for access control).
Executor.taskDeserializationProps.set(taskDescription.properties)
updateDependencies(
taskDescription.artifacts.files,
taskDescription.artifacts.jars,
taskDescription.artifacts.archives,
isolatedSession)
// Always reset the thread's context class loader so that, if the dependencies were updated, all
// threads (not only the thread that updated the dependencies) pick up the new class loader.
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
val killReason = reasonIfKilled
if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException(killReason.get)
}
// The purpose of updating the epoch here is to invalidate executor map output status cache
// in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
// MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
// we don't need to make any special calls here.
if (!isLocal) {
logDebug(s"$taskName's epoch is ${task.epoch}")
env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
}
metricsPoller.onTaskStart(taskId, task.stageId, task.stageAttemptId)
taskStarted = true
// Run the actual task and measure its runtime.
taskStartTimeNs = System.nanoTime()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
var threwException = true
// Convert resources amounts info to ResourceInformation
val resources = taskDescription.resources.map { case (rName, addressesAmounts) =>
rName -> new ResourceInformation(rName, addressesAmounts.keys.toSeq.sorted.toArray)
}
val value = Utils.tryWithSafeFinally {
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem,
cpus = taskDescription.cpus,
resources = resources,
plugins = plugins)
threwException = false
res
} {
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0 && !threwException) {
val errMsg = log"Managed memory leak detected; size = " +
log"${LogMDC(NUM_BYTES, freedMemory)} bytes, ${LogMDC(TASK_NAME, taskName)}"
if (conf.get(UNSAFE_EXCEPTION_ON_MEMORY_LEAK)) {
throw SparkException.internalError(errMsg.message, category = "EXECUTOR")
} else {
logWarning(errMsg)
}
}
if (releasedLocks.nonEmpty && !threwException) {
val errMsg =
s"${releasedLocks.size} block locks were not released by $taskName\n" +
releasedLocks.mkString("[", ", ", "]")
if (conf.get(STORAGE_EXCEPTION_PIN_LEAK)) {
throw SparkException.internalError(errMsg, category = "EXECUTOR")
} else {
logInfo(errMsg)
}
}
}
task.context.fetchFailed.foreach { fetchFailure =>
// Uh-oh. It appears the user code has caught the fetch failure without throwing any
// other exceptions. It's *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(log"${LogMDC(TASK_NAME, taskName)} completed successfully though internally " +
log"it encountered unrecoverable fetch failures! Most likely this means user code " +
log"is incorrectly swallowing Spark's internal " +
log"${LogMDC(CLASS_NAME, classOf[FetchFailedException])}", fetchFailure)
}
val taskFinishNs = System.nanoTime()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
// If the task has been killed, let's fail it.
task.context.killTaskIfInterrupted()
val resultSer = env.serializer.newInstance()
val beforeSerializationNs = System.nanoTime()
val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value)
val afterSerializationNs = System.nanoTime()
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
task.metrics.setExecutorDeserializeTime(TimeUnit.NANOSECONDS.toMillis(
(taskStartTimeNs - deserializeStartTimeNs) + task.executorDeserializeTimeNs))
task.metrics.setExecutorDeserializeCpuTime(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
task.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis(
(taskFinishNs - taskStartTimeNs) - task.executorDeserializeTimeNs))
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.metrics.setResultSerializationTime(TimeUnit.NANOSECONDS.toMillis(
afterSerializationNs - beforeSerializationNs))
// Expose task metrics using the Dropwizard metrics system.
// Update task metrics counters
executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
executorSource.METRIC_INPUT_BYTES_READ
.inc(task.metrics.inputMetrics.bytesRead)
executorSource.METRIC_INPUT_RECORDS_READ
.inc(task.metrics.inputMetrics.recordsRead)
executorSource.METRIC_OUTPUT_BYTES_WRITTEN
.inc(task.metrics.outputMetrics.bytesWritten)
executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
.inc(task.metrics.outputMetrics.recordsWritten)
executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
incrementShuffleMetrics(executorSource, task.metrics)
// Note: accumulator updates must be collected after TaskMetrics is updated
val accumUpdates = task.collectAccumulatorUpdates()
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks)
// try to estimate a reasonable upper bound of DirectTaskResult serialization
val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult,
valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8)
val resultSize = serializedDirectResult.size
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(log"Finished ${LogMDC(TASK_NAME, taskName)}. " +
log"Result is larger than maxResultSize " +
log"(${LogMDC(RESULT_SIZE_BYTES, Utils.bytesToString(resultSize))} > " +
log"${LogMDC(RESULT_SIZE_BYTES_MAX, Utils.bytesToString(maxResultSize))}), " +
log"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
serializedDirectResult,
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
// toByteBuffer is safe here, guarded by maxDirectResultSize
serializedDirectResult.toByteBuffer
}
}
executorSource.SUCCEEDED_TASKS.inc(1L)
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskSucceeded())
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName, reason: ${t.reason}")
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
// Here and below, put task metric peaks in an immutable.ArraySeq to expose them as an
// immutable.Seq without requiring a copy.
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq
val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName, reason: $killReason")
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq
val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
logWarning(log"${LogMDC(TASK_NAME, taskName)} encountered a " +
log"${LogMDC(CLASS_NAME, classOf[FetchFailedException].getName)} " +
log"and failed, but the " +
log"${LogMDC(CLASS_NAME, classOf[FetchFailedException].getName)} " +
log"was hidden by another exception. Spark is handling this like a fetch failure " +
log"and ignoring the other exception: ${LogMDC(ERROR, t)}")
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskCommitDeniedReason
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
case t: Throwable if env.isStopped =>
// Log the expected exception after executor.stop without stack traces
// see: SPARK-19147
logError(log"Exception in ${LogMDC(TASK_NAME, taskName)}: ${LogMDC(ERROR, t.getMessage)}")
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(log"Exception in ${LogMDC(TASK_NAME, taskName)}", t)
// SPARK-20904: Do not report failure to driver if it happened during shutdown. Because
// libraries may set up shutdown hooks that race with running tasks during shutdown,
// spurious failures may occur and can result in improper accounting in the driver (e.g.
// the task failure would not be ignored if the shutdown happened because of preemption,
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId).toImmutableArraySeq
val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
.withMetricPeaks(metricPeaks)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
.withMetricPeaks(metricPeaks)
(ef, ser.serialize(ef))
}
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(taskFailureReason))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
}
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Executor.isFatalError(t, killOnFatalErrorDepth)) {
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
cleanMDCForTask(taskName, mdcProperties)
runningTasks.remove(taskId)
if (taskStarted) {
// This means the task was successfully deserialized, its stageId and stageAttemptId
// are known, and metricsPoller.onTaskStart was called.
metricsPoller.onTaskCompletion(taskId, task.stageId, task.stageAttemptId)
}
}
}
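// Update the executor's Dropwizard shuffle read/write and push-based shuffle merge counters
// from the task's metrics.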
private def incrementShuffleMetrics(
executorSource: ExecutorSource,
metrics: TaskMetrics
): Unit = {
executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
.inc(metrics.shuffleReadMetrics.fetchWaitTime)
executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(metrics.shuffleWriteMetrics.writeTime)
executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
.inc(metrics.shuffleReadMetrics.totalBytesRead)
executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
.inc(metrics.shuffleReadMetrics.remoteBytesRead)
executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
.inc(metrics.shuffleReadMetrics.remoteBytesReadToDisk)
executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
.inc(metrics.shuffleReadMetrics.localBytesRead)
executorSource.METRIC_SHUFFLE_RECORDS_READ
.inc(metrics.shuffleReadMetrics.recordsRead)
executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
.inc(metrics.shuffleReadMetrics.remoteBlocksFetched)
executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
.inc(metrics.shuffleReadMetrics.localBlocksFetched)
executorSource.METRIC_SHUFFLE_REMOTE_REQS_DURATION
.inc(metrics.shuffleReadMetrics.remoteReqsDuration)
executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
.inc(metrics.shuffleWriteMetrics.bytesWritten)
executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
.inc(metrics.shuffleWriteMetrics.recordsWritten)
executorSource.METRIC_PUSH_BASED_SHUFFLE_CORRUPT_MERGED_BLOCK_CHUNKS
.inc(metrics.shuffleReadMetrics.corruptMergedBlockChunks)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_FETCH_FALLBACK_COUNT
.inc(metrics.shuffleReadMetrics.mergedFetchFallbackCount)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_REMOTE_BLOCKS_FETCHED
.inc(metrics.shuffleReadMetrics.remoteMergedBlocksFetched)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_LOCAL_BLOCKS_FETCHED
.inc(metrics.shuffleReadMetrics.localMergedBlocksFetched)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_REMOTE_CHUNKS_FETCHED
.inc(metrics.shuffleReadMetrics.remoteMergedChunksFetched)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_LOCAL_CHUNKS_FETCHED
.inc(metrics.shuffleReadMetrics.localMergedChunksFetched)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_REMOTE_BYTES_READ
.inc(metrics.shuffleReadMetrics.remoteMergedBytesRead)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_LOCAL_BYTES_READ
.inc(metrics.shuffleReadMetrics.localMergedBytesRead)
executorSource.METRIC_PUSH_BASED_SHUFFLE_MERGED_REMOTE_REQS_DURATION
.inc(metrics.shuffleReadMetrics.remoteMergedReqsDuration)
}
private def hasFetchFailure: Boolean = {
task != null && task.context != null && task.context.fetchFailed.isDefined
}
private[executor] def theadDump(): Option[ThreadStackTrace] = {
Utils.getThreadDumpForThread(getThreadId)
}
}
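// Populate the logging MDC for this thread with the task's user-provided mdc.* properties and
// the task name.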
private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
try {
mdc.foreach { case (key, value) => MDC.put(key, value) }
// avoid the task name being overridden by user-provided MDC properties
MDC.put(taskNameMDCKey, taskName)
} catch {
case _: NoSuchFieldError => logInfo("MDC is not supported.")
}
}
private def cleanMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
try {
mdc.foreach { case (key, _) => MDC.remove(key) }
MDC.remove(taskNameMDCKey)
} catch {
case _: NoSuchFieldError => logInfo("MDC is not supported.")
}
}
/**
* Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
* sending a Thread.interrupt(), and monitoring the task until it finishes.
*
* Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
* may not be interruptible or may not respond to their "killed" flags being set. If a significant
* fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
* remain running then this can lead to a situation where new jobs and tasks are starved of
* resources that are being used by these zombie tasks.
*
* The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
* tasks. For backwards-compatibility / backportability this component is disabled by default
* and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
*
* A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
* a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
* in case kill is called twice with different values for the `interrupt` parameter.
*
* Once created, a TaskReaper will run until its supervised task has finished running. If the
* TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
* `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
* if the supervised task never exits.
*/
private class TaskReaper(
taskRunner: TaskRunner,
val interruptThread: Boolean,
val reason: String)
extends Runnable {
private[this] val taskId: Long = taskRunner.taskId
private[this] val killPollingIntervalMs: Long = conf.get(TASK_REAPER_POLLING_INTERVAL)
private[this] val killTimeoutNs: Long = {
TimeUnit.MILLISECONDS.toNanos(conf.get(TASK_REAPER_KILL_TIMEOUT))
}
private[this] val takeThreadDump: Boolean = conf.get(TASK_REAPER_THREAD_DUMP)
override def run(): Unit = {
setMDCForTask(taskRunner.taskName, taskRunner.mdcProperties)
val startTimeNs = System.nanoTime()
def elapsedTimeNs = System.nanoTime() - startTimeNs
def timeoutExceeded(): Boolean = killTimeoutNs > 0 && elapsedTimeNs > killTimeoutNs
try {
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task
// has finished and wait()ing for it to finish.
var finished: Boolean = false
while (!finished && !timeoutExceeded()) {
taskRunner.synchronized {
// We need to synchronize on the TaskRunner while checking whether the task has
// finished in order to avoid a race where the task is marked as finished right after
// we check and before we call wait().
if (taskRunner.isFinished) {
finished = true
} else {
taskRunner.wait(killPollingIntervalMs)
}
}
if (taskRunner.isFinished) {
finished = true
} else {
val elapsedTimeMs = TimeUnit.NANOSECONDS.toMillis(elapsedTimeNs)
logWarning(log"Killed task ${LogMDC(TASK_ID, taskId)} " +
log"is still running after ${LogMDC(TIME_UNITS, elapsedTimeMs)} ms")
if (takeThreadDump) {
try {
taskRunner.theadDump().foreach { thread =>
if (thread.threadName == taskRunner.threadName) {
logWarning(log"Thread dump from task ${LogMDC(TASK_ID, taskId)}:\n" +
log"${LogMDC(THREAD, thread.toString)}")
}
}
} catch {
case NonFatal(e) =>
logWarning("Exception thrown while obtaining thread dump: ", e)
}
}
}
}
if (!taskRunner.isFinished && timeoutExceeded()) {
val killTimeoutMs = TimeUnit.NANOSECONDS.toMillis(killTimeoutNs)
if (isLocal) {
logError(log"Killed task ${LogMDC(TASK_ID, taskId)} could not be stopped within " +
log"${LogMDC(TIMEOUT, killTimeoutMs)} ms; " +
log"not killing JVM because we are running in local mode.")
} else {
// In non-local-mode, the exception thrown here will bubble up to the uncaught exception
// handler and cause the executor JVM to exit.
throw SparkException.internalError(
s"Killing executor JVM because killed task $taskId could not be stopped within " +
s"$killTimeoutMs ms.", category = "EXECUTOR")
}
}
} finally {
cleanMDCForTask(taskRunner.taskName, taskRunner.mdcProperties)
// Clean up entries in the taskReaperForTask map.
taskReaperForTask.synchronized {
taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
if (taskReaperInMap eq this) {
taskReaperForTask.remove(taskId)
} else {
// This must have been a TaskReaper with interruptThread == false whose map entry was
// overwritten by a subsequent killTask() call for the same task with
// interruptThread == true.
}
}
}
}
}
}
/**
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
private def createClassLoader(
currentJars: HashMap[String, Long],
useStub: Boolean): MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
currentJars(url.getPath().split("/").last) = now
}
// For each of the jars in currentJars, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}
createClassLoader(urls, useStub)
}
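// Log the effective classpath and create either a plain class loader or a stub-backed one,
// depending on useStub.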
private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = {
logInfo(
s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " +
urls.mkString("'", ",", "'")
)
if (useStub) {
createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
} else {
createClassLoader(urls)
}
}
private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = {
if (userClassPathFirst) {
new ChildFirstURLClassLoader(urls, systemLoader)
} else {
new MutableURLClassLoader(urls, systemLoader)
}
}
private def createClassLoaderWithStub(
urls: Array[URL],
binaryName: Seq[String]): MutableURLClassLoader = {
if (userClassPathFirst) {
// user -> (sys -> stub)
val stubClassLoader =
StubClassLoader(systemLoader, binaryName)
new ChildFirstURLClassLoader(urls, stubClassLoader)
} else {
// sys -> user -> stub
val stubClassLoader =
StubClassLoader(null, binaryName)
new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader)
}
}
/**
* If the REPL is in use, add another ClassLoader that will read
* new classes defined by the REPL as the user types code
*/
private def addReplClassLoaderIfNeeded(
parent: ClassLoader,
sessionClassUri: Option[String],
sessionUUID: String): ClassLoader = {
val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null))
val classLoader = if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst)
} else {
parent
}
logInfo(s"Created or updated repl class loader $classLoader for $sessionUUID.")
classLoader
}
/**
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
* Visible for testing.
*/
private[executor] def updateDependencies(
newFiles: immutable.Map[String, Long],
newJars: immutable.Map[String, Long],
newArchives: immutable.Map[String, Long],
state: IsolatedSessionState,
testStartLatch: Option[CountDownLatch] = None,
testEndLatch: Option[CountDownLatch] = None): Unit = {
var renewClassLoader = false
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
updateDependenciesLock.lockInterruptibly()
try {
// For testing, so we can simulate a slow file download:
testStartLatch.foreach(_.countDown())
// If the session ID was specified by the SparkSession, the request comes from a Spark Connect
// client, so use a dedicated directory for that client.
lazy val root = if (!isDefaultState(state.sessionUUID)) {
val newDest = new File(SparkFiles.getRootDirectory(), state.sessionUUID)
newDest.mkdir()
newDest
} else {
new File(SparkFiles.getRootDirectory())
}
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if state.currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch the file with the cache enabled; the cache is disabled in local mode.
Utils.fetchFile(name, root, conf, hadoopConf, timestamp, useCache = !isLocal)
state.currentFiles(name) = timestamp
}
for ((name, timestamp) <- newArchives if
state.currentArchives.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
val sourceURI = new URI(name)
val uriToDownload = Utils.getUriBuilder(sourceURI).fragment(null).build()
val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf,
hadoopConf, timestamp, useCache = !isLocal, shouldUntar = false)
val dest = new File(
root,
if (sourceURI.getFragment != null) sourceURI.getFragment else source.getName)
logInfo(
s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
Utils.deleteRecursively(dest)
Utils.unpack(source, dest)
state.currentArchives(name) = timestamp
}
for ((name, timestamp) <- newJars) {
val localName = new URI(name).getPath.split("/").last
val currentTimeStamp = state.currentJars.get(name)
.orElse(state.currentJars.get(localName))
.getOrElse(-1L)
if (currentTimeStamp < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch the file with the cache enabled; the cache is disabled in local mode.
Utils.fetchFile(name, root, conf,
hadoopConf, timestamp, useCache = !isLocal)
state.currentJars(name) = timestamp
// Add it to our class loader
val url = new File(root, localName).toURI.toURL
if (!state.urlClassLoader.getURLs().contains(url)) {
logInfo(s"Adding $url to class loader ${state.sessionUUID}")
state.urlClassLoader.addURL(url)
if (isStubbingEnabledForState(state.sessionUUID)) {
renewClassLoader = true
}
}
}
}
if (renewClassLoader) {
// Recreate the class loader to ensure all classes are updated.
state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
state.replClassLoader =
addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID)
}
// For testing, so we can simulate a slow file download:
testEndLatch.foreach(_.await())
} finally {
updateDependenciesLock.unlock()
}
}
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
// list of (task id, accumUpdates) to send back to the driver
val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]()
val curGCTime = computeTotalGcTime()
if (pollOnHeartbeat) {
metricsPoller.poll()
}
val executorUpdates = metricsPoller.getExecutorUpdates()
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner.task != null) {
taskRunner.task.metrics.mergeShuffleReadMetrics()
taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
val accumulatorsToReport =
if (HEARTBEAT_DROP_ZEROES) {
taskRunner.task.metrics.accumulators().filterNot(_.isZero)
} else {
taskRunner.task.metrics.accumulators()
}
accumUpdates += ((taskRunner.taskId, accumulatorsToReport))
}
}
val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId,
executorUpdates)
try {
val response = heartbeatReceiverRef.askSync[HeartbeatResponse](
message, new RpcTimeout(HEARTBEAT_INTERVAL_MS.millis, EXECUTOR_HEARTBEAT_INTERVAL.key))
if (!executorShutdown.get && response.reregisterBlockManager) {
logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
heartbeatFailures = 0
} catch {
case NonFatal(e) =>
logWarning("Issue communicating with driver in heartbeater", e)
heartbeatFailures += 1
if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) {
logError(log"Exit as unable to send heartbeats to driver " +
log"more than ${LogMDC(MAX_ATTEMPTS, HEARTBEAT_MAX_FAILURES)} times")
System.exit(ExecutorExitCode.HEARTBEAT_FAILURE)
}
}
}
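/** Returns a thread dump of the thread running the given task, if that task is still running. */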
def getTaskThreadDump(taskId: Long): Option[ThreadStackTrace] = {
val runner = runningTasks.get(taskId)
if (runner != null) {
runner.theadDump()
} else {
logWarning(log"Failed to dump thread for task ${LogMDC(TASK_ID, taskId)}")
None
}
}
}
private[spark] object Executor {
// This is reserved for internal use by components that need to read task properties before a
// task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
// used instead.
val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
// Used to store executorSource, for local mode only
var executorSourceLocalModeOnly: ExecutorSource = null
/**
* Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether
* to kill the executor.
*
* @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0
* means not checking any fatal error (in other words, return false), 1 means
* checking only the exception but not the cause, and so on. This is to avoid
* `StackOverflowError` when hitting a cycle in the exception chain.
*/
@scala.annotation.tailrec
def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
if (depthToCheck <= 0) {
false
} else {
t match {
case _: SparkOutOfMemoryError => false
case e if Utils.isFatalError(e) => true
case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
case _ => false
}
}
}
}