/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.Serializable
import java.util.Properties
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
object TaskContext {
/**
* Return the currently active TaskContext. This can be called inside
* user functions to access contextual information about running tasks.
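*
* A minimal usage sketch (assuming an existing `RDD[String]` named `rdd`):
* {{{
* // `rdd` is an assumed, pre-existing RDD[String].
* rdd.map { line =>
*   // Inside a running task, get() returns the thread-local context.
*   val ctx = org.apache.spark.TaskContext.get()
*   s"partition ${ctx.partitionId()}: $line"
* }
* }}}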
*/
def get(): TaskContext = taskContext.get
/**
* Returns the partition id of the currently active TaskContext. It returns 0
* if there is no active TaskContext, for cases like local execution.
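*
* For illustration (assuming a live `SparkContext` named `sc`):
* {{{
* // On the driver there is no active TaskContext, so this yields 0.
* TaskContext.getPartitionId()
*
* // Inside a task it yields the id of the partition being computed.
* sc.parallelize(1 to 10, 2).map(_ => TaskContext.getPartitionId()).collect()
* }}}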
*/
def getPartitionId(): Int = {
val tc = taskContext.get()
if (tc eq null) {
0
} else {
tc.partitionId()
}
}
private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
// Note: protected[spark] instead of private[spark] to prevent the following two from
// showing up in JavaDoc.
/**
* Set the thread local TaskContext. Internal to Spark.
*/
protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc)
/**
* Unset the thread local TaskContext. Internal to Spark.
*/
protected[spark] def unset(): Unit = taskContext.remove()
/**
* An empty task context that does not represent an actual task. This is only used in tests.
*/
private[spark] def empty(): TaskContextImpl = {
new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
}
}
/**
* Contextual information about a task which can be read or mutated during
* execution. To access the TaskContext for a running task, use:
* {{{
* org.apache.spark.TaskContext.get()
* }}}
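*
* An illustrative sketch that reads a few of the fields defined below (assuming an
* existing RDD named `rdd`):
* {{{
* // `rdd` is an assumed, pre-existing RDD.
* rdd.foreachPartition { _ =>
*   val ctx = org.apache.spark.TaskContext.get()
*   println(s"stage ${ctx.stageId()}, partition ${ctx.partitionId()}, " +
*     s"attempt ${ctx.attemptNumber()}")
* }
* }}}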
*/
abstract class TaskContext extends Serializable {
// Note: TaskContext must NOT define a get method. Otherwise it will prevent the Scala compiler
// from generating a static get method (based on the companion object's get method).
// Note: Update JavaTaskContextCompileCheck when new methods are added to this class.
// Note: getters in this class are defined with parentheses to maintain backward compatibility.
/**
* Returns true if the task has completed.
*/
def isCompleted(): Boolean
/**
* Returns true if the task has been killed.
*/
def isInterrupted(): Boolean
/**
* Returns whether the task is running locally in the driver program. Local execution was
* removed, so this always returns false.
* @return false
*/
@deprecated("Local execution was removed, so this always returns false", "2.0.0")
def isRunningLocally(): Boolean
/**
* Adds a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation. Adding a listener
* to an already completed task will result in that listener being called immediately.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
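*
* A sketch of that stream-closing pattern (assuming `in` is an input stream opened by
* the task):
* {{{
* import org.apache.spark.util.TaskCompletionListener
*
* TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
*   // Runs on success, failure, or cancellation; `in` is an assumed open stream.
*   override def onTaskCompletion(context: TaskContext): Unit = in.close()
* })
* }}}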
*/
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
/**
* Adds a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation. Adding a listener
* to an already completed task will result in that listener being called immediately.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
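*
* The same pattern as a closure (assuming `in` is an input stream opened by the task; the
* explicit `[Unit]` type argument avoids ambiguity with the listener overload above):
* {{{
* // `in` is an assumed open stream; it is closed once the task finishes for any reason.
* TaskContext.get().addTaskCompletionListener[Unit] { _ => in.close() }
* }}}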
*/
def addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext = {
// Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make
// this function polymorphic for every scala version >= 2.12, otherwise an overloaded method
// resolution error occurs at compile time.
addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f(context)
})
}
/**
* Adds a listener to be executed on task failure. Adding a listener to an already failed task
* will result in that listener being called immediately.
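*
* A sketch that records the failure (assuming a hypothetical `log` object with a
* `warn(msg, throwable)` method is in scope):
* {{{
* import org.apache.spark.util.TaskFailureListener
*
* TaskContext.get().addTaskFailureListener(new TaskFailureListener {
*   override def onTaskFailure(context: TaskContext, error: Throwable): Unit =
*     log.warn(s"task attempt ${context.taskAttemptId()} failed", error)  // `log` is assumed
* })
* }}}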
*/
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
/**
* Adds a listener to be executed on task failure. Adding a listener to an already failed task
* will result in that listener being called immediately.
*/
def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
})
}
/**
* The ID of the stage that this task belongs to.
*/
def stageId(): Int
/**
* How many times the stage that this task belongs to has been attempted. The first stage attempt
* will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt
* numbers.
*/
def stageAttemptNumber(): Int
/**
* The ID of the RDD partition that is computed by this task.
*/
def partitionId(): Int
/**
* How many times this task has been attempted. The first task attempt will be assigned
* attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
*/
def attemptNumber(): Int
/**
* An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
* will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
*/
def taskAttemptId(): Long
/**
* Get a local property set upstream in the driver, or null if it is missing. See also
* `org.apache.spark.SparkContext.setLocalProperty`.
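*
* A sketch of the round trip (assuming a live `SparkContext` named `sc`; the key name is
* arbitrary):
* {{{
* // On the driver, in the thread that will submit the job:
* sc.setLocalProperty("my.tag", "nightly-run")
*
* // Inside a task of a job submitted from that thread:
* val tag = TaskContext.get().getLocalProperty("my.tag")
* }}}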
*/
def getLocalProperty(key: String): String
@DeveloperApi
def taskMetrics(): TaskMetrics
/**
* ::DeveloperApi::
* Returns all metrics sources with the given name that are associated with the instance
* running the task. For more information see `org.apache.spark.metrics.MetricsSystem`.
*/
@DeveloperApi
def getMetricsSources(sourceName: String): Seq[Source]
/**
* If the task is interrupted, throws TaskKilledException with the reason for the interrupt.
*/
private[spark] def killTaskIfInterrupted(): Unit
/**
* If the task is interrupted, returns the reason this task was killed; otherwise None.
*/
private[spark] def getKillReason(): Option[String]
/**
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager
/**
* Register an accumulator that belongs to this task. Accumulators must call this method when
* deserializing in executors.
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
/**
* Record that this task has failed due to a fetch failure from a remote host. This allows
* fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
*/
private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(reason: String): Unit
/** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit
/** Marks the task as completed and triggers the completion listeners. */
private[spark] def markTaskCompleted(error: Option[Throwable]): Unit
/** Optionally returns the stored fetch failure in the task. */
private[spark] def fetchFailed: Option[FetchFailedException]
/** Gets local properties set upstream in the driver. */
private[spark] def getLocalProperties: Properties
}