/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.Properties
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

/**
 * A [[TaskContext]] implementation.
 *
 * A small note on thread safety: the interrupted and fetchFailed fields are volatile, which
 * ensures that updates are always visible across threads. The completed and failed flags and
 * their callbacks are protected by locking on the context instance. For instance, this ensures
 * that a completion listener cannot be added in one thread while the task is being completed
 * (and the completion listeners invoked) in another thread. All other state is immutable, but
 * the exposed `TaskMetrics` and `MetricsSystem` objects are not thread safe.
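 *
 * Illustrative sketch of that guarantee (`closeResources` is a hypothetical cleanup helper,
 * not part of this file): a listener registered after the task has already completed is
 * invoked immediately on the registering thread, so the callback is never lost:
 *
 * {{{
 *   ctx.markTaskCompleted(None)
 *   ctx.addTaskCompletionListener(new TaskCompletionListener {
 *     override def onTaskCompletion(context: TaskContext): Unit = closeResources()
 *   })  // runs closeResources() right away, because `completed` is already true
 * }}}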
*/
private[spark] class TaskContextImpl(
    override val stageId: Int,
    override val stageAttemptNumber: Int,
    override val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    localProperties: Properties,
    @transient private val metricsSystem: MetricsSystem,
    // The default value is only used in tests.
    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends TaskContext
  with Logging {

  /** List of callback functions to execute when the task completes. */
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  /** List of callback functions to execute when the task fails. */
  @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]

  // If defined, the corresponding task has been killed and this option contains the reason.
  @volatile private var reasonIfKilled: Option[String] = None

  // Whether the task has completed.
  private var completed: Boolean = false

  // Whether the task has failed.
  private var failed: Boolean = false

  // The throwable that caused the task to fail, if any.
  private var failure: Throwable = _

  // If there was a fetch failure in the task, we store it here to make sure user code doesn't
  // hide the exception. See SPARK-19276.
  @volatile private var _fetchFailedException: Option[FetchFailedException] = None

  @GuardedBy("this")
  override def addTaskCompletionListener(listener: TaskCompletionListener)
      : this.type = synchronized {
    if (completed) {
      listener.onTaskCompletion(this)
    } else {
      onCompleteCallbacks += listener
    }
    this
  }

  @GuardedBy("this")
  override def addTaskFailureListener(listener: TaskFailureListener)
      : this.type = synchronized {
    if (failed) {
      listener.onTaskFailure(this, failure)
    } else {
      onFailureCallbacks += listener
    }
    this
  }

  /** Marks the task as failed and triggers the failure listeners. */
  @GuardedBy("this")
  private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized {
    if (failed) return
    failed = true
    failure = error
    invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
      _.onTaskFailure(this, error)
    }
  }

  /** Marks the task as completed and triggers the completion listeners. */
  @GuardedBy("this")
  private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
    if (completed) return
    completed = true
    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
      _.onTaskCompletion(this)
    }
  }

  private def invokeListeners[T](
      listeners: Seq[T],
      name: String,
      error: Option[Throwable])(
      callback: T => Unit): Unit = {
    val errorMsgs = new ArrayBuffer[String](2)
    // Process callbacks in the reverse order of registration.
    listeners.reverse.foreach { listener =>
      try {
        callback(listener)
      } catch {
        case e: Throwable =>
          errorMsgs += e.getMessage
          logError(s"Error in $name", e)
      }
    }
    if (errorMsgs.nonEmpty) {
      throw new TaskCompletionListenerException(errorMsgs, error)
    }
  }
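
  // Illustrative sketch (releaseA/releaseB are hypothetical cleanup calls): listeners run in
  // reverse registration order, so they unwind like a stack of acquired resources:
  //
  //   ctx.addTaskCompletionListener(_ => releaseA())  // registered first, runs last
  //   ctx.addTaskCompletionListener(_ => releaseB())  // registered second, runs first
  //
  // If a listener throws, the error is logged and its message collected, the remaining
  // listeners still run, and a single TaskCompletionListenerException carrying all of the
  // messages is thrown at the end.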

  /** Marks the task for interruption, i.e. cancellation. */
  private[spark] override def markInterrupted(reason: String): Unit = {
    reasonIfKilled = Some(reason)
  }

  private[spark] override def killTaskIfInterrupted(): Unit = {
    val reason = reasonIfKilled
    if (reason.isDefined) {
      throw new TaskKilledException(reason.get)
    }
  }

  private[spark] override def getKillReason(): Option[String] = {
    reasonIfKilled
  }
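
  // Illustrative sketch of cooperative cancellation (`process` and the loop are hypothetical):
  // long-running task code is expected to poll for interruption, e.g.
  //
  //   while (iter.hasNext) {
  //     ctx.killTaskIfInterrupted()  // throws TaskKilledException once markInterrupted() ran
  //     process(iter.next())
  //   }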

  @GuardedBy("this")
  override def isCompleted(): Boolean = synchronized(completed)

  override def isRunningLocally(): Boolean = false

  override def isInterrupted(): Boolean = reasonIfKilled.isDefined

  override def getLocalProperty(key: String): String = localProperties.getProperty(key)

  override def getMetricsSources(sourceName: String): Seq[Source] =
    metricsSystem.getSourcesByName(sourceName)

  private[spark] override def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
    taskMetrics.registerAccumulator(a)
  }

  private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
    this._fetchFailedException = Option(fetchFailed)
  }

  private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException
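
  // Illustrative note (the enforcement itself lives in the task runner, not in this class):
  // even if user code swallows the fetch failure, e.g.
  //
  //   try { iter.next() } catch { case _: Throwable => /* swallowed */ }
  //
  // the recorded `fetchFailed` can still be consulted after the task body returns, so the
  // failure is reported instead of hidden (see SPARK-19276).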

  private[spark] override def getLocalProperties(): Properties = localProperties
}
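
// Illustrative construction sketch (the null managers are hypothetical test-style arguments;
// real executors supply a live TaskMemoryManager and MetricsSystem):
//
//   val ctx = new TaskContextImpl(
//     stageId = 0,
//     stageAttemptNumber = 0,
//     partitionId = 0,
//     taskAttemptId = 0L,
//     attemptNumber = 0,
//     taskMemoryManager = null,
//     localProperties = new java.util.Properties,
//     metricsSystem = null)  // taskMetrics defaults to TaskMetrics.empty, as in tests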