| package com.ibm.spark.utils |
| |
| import java.util.concurrent.atomic.AtomicInteger |
| |
| import org.slf4j.LoggerFactory |
| |
| import scala.concurrent.{promise, Future} |
| import java.util.concurrent._ |
| |
| import com.ibm.spark.security.KernelSecurityManager._ |
| import TaskManager._ |
| |
| import scala.util.Try |
| |
/**
 * Represents a processor of tasks that has X worker threads dedicated to
 * executing the tasks.
 *
 * Tasks are handed to a thread pool whose core size is adjusted to follow the
 * number of outstanding tasks (bounded by the minimum and maximum worker
 * counts). The manager must be started via `start()` before tasks are added.
 *
 * @param threadGroup The thread group to use with all worker threads
 * @param maxTasks The capacity of the bounded queue that holds tasks waiting
 *                 for a worker thread
 * @param minimumWorkers The number of workers to spawn initially and keep
 *                       alive even when idle
 * @param maximumWorkers The max number of worker threads to spawn, defaulting
 *                       to the number of processors on the machine
 * @param keepAliveTime The maximum time in milliseconds for workers to remain
 *                      idle before shutting down
 */
class TaskManager(
  private val threadGroup: ThreadGroup = DefaultThreadGroup,
  private val maxTasks: Int = DefaultMaxTasks,
  private val minimumWorkers: Int = DefaultMinimumWorkers,
  private val maximumWorkers: Int = DefaultMaximumWorkers,
  private val keepAliveTime: Long = DefaultKeepAliveTime
) {
  protected val logger = LoggerFactory.getLogger(this.getClass.getName)

  /** Creates every worker thread inside the configured thread group. */
  private class TaskManagerThreadFactory extends ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val thread = new Thread(threadGroup, r)

      logger.trace(s"Creating new thread named ${thread.getName}!")

      thread
    }
  }

  /**
   * Thread pool whose core pool size follows the number of outstanding tasks
   * (submitted but not yet finished), clamped between the minimum and maximum
   * worker counts of the enclosing task manager.
   */
  private[utils] class ScalingThreadPoolExecutor extends ThreadPoolExecutor(
    minimumWorkers,
    maximumWorkers,
    keepAliveTime,
    TimeUnit.MILLISECONDS,
    taskQueue,
    taskManagerThreadFactory
  ) {
    protected val logger = LoggerFactory.getLogger(this.getClass.getName)

    /** Used to keep track of tasks separately from the task queue. */
    private val taskCount = new AtomicInteger(0)

    /**
     * Syncs the core pool size of this executor with the current number of
     * tasks, using the minimum worker size and maximum worker size as the
     * bounds.
     */
    private def syncPoolLimits(): Unit = {
      val totalTasks = taskCount.get()
      val newCorePoolSize =
        math.max(minimumWorkers, math.min(totalTasks, maximumWorkers))

      logger.trace(Seq(
        s"Task execution count is $totalTasks!",
        s"Updating core pool size to $newCorePoolSize!"
      ).mkString(" "))

      // Resize this pool directly. Going through the task manager's current
      // executor would let a worker of a previously stopped pool resize the
      // pool created by a later start()/restart().
      setCorePoolSize(newCorePoolSize)
    }

    /**
     * Counts the task, grows the core pool if needed, and submits the task.
     *
     * @param r The task to run
     *
     * @throws RejectedExecutionException If the underlying pool refuses the
     *                                    task; the task count is restored
     *                                    before the exception is rethrown
     */
    override def execute(r: Runnable): Unit = {
      synchronized {
        if (taskCount.incrementAndGet() > maximumWorkers)
          logger.warn(s"Exceeded $maximumWorkers workers during processing!")

        syncPoolLimits()
      }

      try {
        super.execute(r)
      } catch {
        case ex: RejectedExecutionException =>
          // A rejected task never reaches afterExecute, so undo the increment
          // here; otherwise the count (and the core pool size) drifts upward.
          synchronized {
            taskCount.decrementAndGet()
            syncPoolLimits()
          }
          throw ex
      }
    }

    /**
     * Removes the finished task from the count and shrinks the core pool
     * accordingly.
     */
    override def afterExecute(r: Runnable, t: Throwable): Unit = {
      super.afterExecute(r, t)

      synchronized {
        taskCount.decrementAndGet()
        syncPoolLimits()
      }
    }
  }

  private val taskManagerThreadFactory = new TaskManagerThreadFactory
  private val taskQueue = new ArrayBlockingQueue[Runnable](maxTasks)

  /** The active executor; empty while the task manager is not started. */
  @volatile
  private[utils] var executor: Option[ScalingThreadPoolExecutor] = None

  /**
   * Adds a new task to the list to execute.
   *
   * The task manager must have been started. Exceptions thrown while handing
   * the task to the executor (e.g. a rejection) propagate to the caller.
   *
   * @param taskFunction The new task as a block of code
   *
   * @return Future representing the return value (or error) from the task
   */
  def add[T <: Any](taskFunction: => T): Future[T] = {
    // Read the volatile field exactly once so that a concurrent stop() cannot
    // slip in between the check and the submission and silently drop the task.
    val activeExecutor = executor
    assert(activeExecutor.nonEmpty, "Task manager not started!")

    val taskPromise = scala.concurrent.Promise[T]()

    // Construct runnable that completes the promise
    logger.trace(s"Queueing new task to be processed!")
    activeExecutor.foreach(_.execute(new Runnable {
      override def run(): Unit = {
        var threadName: String = "???"
        try {
          threadName = Try(Thread.currentThread().getName).getOrElse(threadName)
          logger.trace(s"(Thread $threadName) Executing task!")
          val result = taskFunction

          logger.trace(s"(Thread $threadName) Task finished successfully!")
          taskPromise.success(result)
        } catch {
          // Deliberately catches everything so that the returned future is
          // always completed, whatever the task throws.
          case ex: Throwable =>
            val exName = ex.getClass.getName
            val exMessage = Option(ex.getLocalizedMessage).getOrElse("???")
            logger.trace(
              s"(Thread $threadName) Task failed: ($exName) = $exMessage")
            taskPromise.tryFailure(ex)
        }
      }
    }))

    taskPromise.future
  }

  /**
   * Returns the count of tasks including the currently-running ones.
   *
   * @return The count of tasks
   */
  def size: Int = taskQueue.size() + executor.map(_.getActiveCount).getOrElse(0)

  /**
   * Returns whether or not there is a task in the queue to be processed.
   *
   * @return True if the internal queue is not empty, otherwise false
   */
  def hasTaskInQueue: Boolean = !taskQueue.isEmpty

  /**
   * Whether or not there is a task being executed currently.
   *
   * @return True if there is a task being executed, otherwise false
   */
  def isExecutingTask: Boolean = executor.exists(_.getActiveCount > 0)

  /**
   * Block execution (by sleeping) until all tasks currently queued up for
   * execution are processed.
   */
  def await(): Unit =
    while (!taskQueue.isEmpty || isExecutingTask) Thread.sleep(1)

  /**
   * Starts the task manager (begins processing tasks) by creating a new
   * executor, replacing any reference to a previous one.
   */
  def start(): Unit = {
    logger.trace(
      s"""
         |Initializing with the following settings:
         |- $minimumWorkers core worker pool
         |- $maximumWorkers maximum workers
         |- $keepAliveTime milliseconds keep alive time
       """.stripMargin.trim)
    executor = Some(new ScalingThreadPoolExecutor)
  }

  /**
   * Restarts internal processing of tasks (removing current task).
   */
  def restart(): Unit = {
    stop()
    start()
  }

  /**
   * Stops internal processing of tasks.
   *
   * The list of not-yet-started tasks returned by the executor's shutdown is
   * discarded, so this class never completes the futures of those tasks.
   */
  def stop(): Unit = {
    executor.foreach(_.shutdownNow())
    executor = None
  }
}
| |
/**
 * Holds the default settings and timeout constants associated with the task
 * manager.
 */
object TaskManager {
  /** Thread group handed to worker threads when the caller supplies none. */
  val DefaultThreadGroup: ThreadGroup = new ThreadGroup(RestrictedGroupName)

  /** Default upper bound on the number of tasks the task manager accepts. */
  val DefaultMaxTasks: Int = 200

  /**
   * Default number of workers that are spawned initially and kept alive
   * even while idle.
   */
  val DefaultMinimumWorkers: Int = 1

  /**
   * Default upper bound on spawned workers: the processor count reported by
   * the runtime when this object is initialized.
   */
  val DefaultMaximumWorkers: Int = Runtime.getRuntime.availableProcessors()

  /** Default time in milliseconds a worker waits for a task. */
  val DefaultKeepAliveTime: Int = 1000

  /**
   * Default time in milliseconds to wait before stopping a thread that
   * cannot be interrupted.
   */
  val InterruptTimeout: Int = 5 * 1000

  /** Longest time in milliseconds to wait when adding a task to the queue. */
  val MaximumTaskQueueTimeout: Int = 10 * 1000

  /**
   * Longest time in milliseconds to wait when queueing up a thread in the
   * thread factory.
   */
  val MaximumThreadQueueTimeout: Int = 10 * 1000
}