Merge pull request #107 from ibm-et/ParallelizeTaskManager
Parallelize task manager
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
index e8db370..e91a2f2 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
@@ -53,6 +53,7 @@
* Sets the Interpreter for this map.
* @param interpreter The new Interpreter
*/
+ //@deprecated("Use setInterpreter with IncludeInterpreter!", "2015.05.06")
def setKernelInterpreter(interpreter: Interpreter) = {
internalMap(typeOf[IncludeKernelInterpreter]) =
PartialFunction[Magic, Unit](
@@ -123,7 +124,7 @@
/**
* Sets the MagicLoader for this map.
- * @param kernel The new Kernel
+ * @param magicLoader The new MagicLoader
*/
def setMagicLoader(magicLoader: MagicLoader) = {
internalMap(typeOf[IncludeMagicLoader]) =
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
index 60faf50..de19c07 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
@@ -19,6 +19,7 @@
import com.ibm.spark.interpreter.Interpreter
import com.ibm.spark.magic.Magic
+//@deprecated("Use IncludeInterpreter instead!", "2015.05.06")
trait IncludeKernelInterpreter {
this: Magic =>
diff --git a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
index 55b1b97..c3a43fc 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
@@ -1,93 +1,104 @@
-/*
- * Copyright 2014 IBM Corp.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
package com.ibm.spark.utils
-import java.util.UUID
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.AtomicInteger
-import scala.concurrent.{Future, Promise, promise}
+import org.slf4j.LoggerFactory
+
+import scala.concurrent.{promise, Future}
+import java.util.concurrent._
import com.ibm.spark.security.KernelSecurityManager._
+import TaskManager._
+
+import scala.util.Try
/**
- * Represents a generic manager of Runnable tasks that will be executed in a
- * separate thread (created inside the manager). Facilitates running tasks and
- * provides a method to kill
+ * Represents a processor of tasks that has X worker threads dedicated to
+ * executing the tasks.
+ *
+ * @param threadGroup The thread group to use with all worker threads
+ * @param minimumWorkers The number of workers to spawn initially and keep
+ * alive even when idle
+ * @param maximumWorkers The max number of worker threads to spawn, defaulting
+ * to the number of processors on the machine
+ * @param keepAliveTime The maximum time in milliseconds for workers to remain
+ * idle before shutting down
*/
class TaskManager(
- taskThreadGroup: ThreadGroup = new ThreadGroup(RestrictedGroupName),
- taskThreadName: String = "task-" + UUID.randomUUID().toString
+ private val threadGroup: ThreadGroup = DefaultThreadGroup,
+ private val maxTasks: Int = DefaultMaxTasks,
+ private val minimumWorkers: Int = DefaultMinimumWorkers,
+ private val maximumWorkers: Int = DefaultMaximumWorkers,
+ private val keepAliveTime: Long = DefaultKeepAliveTime
) {
- // Maximum time to wait (in milliseconds) before forcefully stopping this
- // thread when an interrupt fails
- private val InterruptTimeout = 5 * 1000
- private val _queueCapacity = 200
- private val _taskQueue =
- new ArrayBlockingQueue[(Runnable, Promise[_])](_queueCapacity)
- private var _taskThread: TaskThread = _
+ protected val logger = LoggerFactory.getLogger(this.getClass.getName)
- private val _currentPromise: AtomicReference[Promise[_]] =
- new AtomicReference[Promise[_]]()
+  /**
+   * Creates the worker threads used by every executor this manager starts.
+   * All workers are placed in the manager's thread group.
+   */
+  private class TaskManagerThreadFactory extends ThreadFactory {
+    /**
+     * Creates a worker thread for the given runnable.
+     *
+     * Workers are named "task-<uuid>", the scheme the previous single task
+     * thread used, so they stay identifiable in logs and thread dumps instead
+     * of falling back to the JVM default "Thread-N".
+     *
+     * @param r The runnable the new worker will execute
+     * @return The new (not yet started) worker thread
+     */
+    override def newThread(r: Runnable): Thread = {
+      val threadName = "task-" + java.util.UUID.randomUUID().toString
+      val thread = new Thread(threadGroup, r, threadName)
- private class TaskThread extends Thread(taskThreadGroup, taskThreadName) {
- private[TaskManager] var _currentTask: Runnable = _
- private[TaskManager] var _running = false
+      logger.trace(s"Creating new thread named ${thread.getName}!")
- override def start(): Unit = {
- _running = true
- super.start()
- }
-
- /**
- * Main execution loop of the task thread.
- *
- * Pulls tasks from an internal queue to be processed sequentially.
- */
- override def run(): Unit = {
- while (_running) {
- val element = _taskQueue.poll(1L, TimeUnit.MILLISECONDS)
- if (element != null) {
- _currentTask = element._1
- _currentPromise.set(element._2)
-
- if (_currentTask != null) _currentTask.run()
-
- _currentTask = null
- _currentPromise.set(null)
- }
- }
- }
-
- /**
- * Marks the internal flag for running new tasks to false.
- */
- def cancel(): Unit = {
- _running = false
+      thread
}
}
- /**
- * Represents the internal thread used to process tasks.
- *
- * @return The Thread instance wrapped in an Option, or None if not started
- */
- def thread: Option[Thread] =
- if (_taskThread != null) Some(_taskThread) else None
+  /**
+   * Thread pool executor whose core pool size follows the number of
+   * outstanding tasks, clamped to [minimumWorkers, maximumWorkers].
+   *
+   * With a bounded queue, a ThreadPoolExecutor only spawns threads beyond the
+   * core size once the queue is full. The core size is therefore raised as
+   * tasks are submitted and lowered again as they complete.
+   */
+  private[utils] class ScalingThreadPoolExecutor extends ThreadPoolExecutor(
+    minimumWorkers,
+    maximumWorkers,
+    keepAliveTime,
+    TimeUnit.MILLISECONDS,
+    taskQueue,
+    taskManagerThreadFactory
+  ) {
+    protected val logger = LoggerFactory.getLogger(this.getClass.getName)
+
+    /** Used to keep track of tasks separately from the task queue. */
+    private val taskCount = new AtomicInteger(0)
+
+    /**
+     * Syncs the core pool size of this executor with the current number of
+     * tasks, using the minimum worker size and maximum worker size as the
+     * bounds.
+     *
+     * Resizes this executor itself rather than going through the manager's
+     * `executor` reference. After restart() that reference points at a new
+     * pool, and tasks still finishing on the old pool must not resize it.
+     */
+    private def syncPoolLimits(): Unit = {
+      val totalTasks = taskCount.get()
+      val newCorePoolSize =
+        math.max(minimumWorkers, math.min(totalTasks, maximumWorkers))
+
+      logger.trace(Seq(
+        s"Task execution count is $totalTasks!",
+        s"Updating core pool size to $newCorePoolSize!"
+      ).mkString(" "))
+      setCorePoolSize(newCorePoolSize)
+    }
+
+    /**
+     * Counts the task, grows the core pool if needed and submits the task.
+     *
+     * @param r The task to execute
+     * @throws RejectedExecutionException if the task cannot be queued and no
+     *                                    more workers may be spawned, or if
+     *                                    the executor has been shut down
+     */
+    override def execute(r: Runnable): Unit = {
+      synchronized {
+        if (taskCount.incrementAndGet() > maximumWorkers)
+          logger.warn(s"Exceeded $maximumWorkers workers during processing!")
+
+        syncPoolLimits()
+      }
+
+      try {
+        super.execute(r)
+      } catch {
+        // A rejected task never reaches afterExecute, so undo its count here.
+        // Otherwise the core pool size would stay inflated permanently.
+        case ex: RejectedExecutionException =>
+          synchronized {
+            taskCount.decrementAndGet()
+            syncPoolLimits()
+          }
+          throw ex
+      }
+    }
+
+    /**
+     * Uncounts a finished task (successful or failed) and shrinks the core
+     * pool accordingly.
+     */
+    override def afterExecute(r: Runnable, t: Throwable): Unit = {
+      super.afterExecute(r, t)
+
+      synchronized {
+        taskCount.decrementAndGet()
+        syncPoolLimits()
+      }
+    }
+  }
+
+ // Single factory instance that spawns the worker threads of every executor
+ // created by start()
+ private val taskManagerThreadFactory = new TaskManagerThreadFactory
+ // Bounded (maxTasks) queue of tasks waiting for a worker. It is shared by
+ // every executor this manager creates, so a start() after stop() reuses it.
+ // Per ThreadPoolExecutor semantics, once it is full and maximumWorkers are
+ // busy, further submissions are rejected.
+ private val taskQueue = new ArrayBlockingQueue[Runnable](maxTasks)
+
+ // The running executor, or None before start() / after stop(). Volatile so
+ // that threads calling add()/size observe updates made by start()/stop().
+ @volatile
+ private[utils] var executor: Option[ScalingThreadPoolExecutor] = None
/**
* Adds a new task to the list to execute.
@@ -97,189 +108,130 @@
* @return Future representing the return value (or error) from the task
*/
def add[T <: Any](taskFunction: => T): Future[T] = {
+    // Read the volatile executor exactly once. Checking it and then using it
+    // through a second read lets a concurrent stop() slip in between; the task
+    // would then be silently dropped and the returned future never completed.
+    // An AssertionError with the same type and message Predef.assert produced
+    // is still thrown when the manager has not been started.
+    val currentExecutor = executor.getOrElse(
+      throw new AssertionError("assertion failed: Task manager not started!"))
+
val taskPromise = promise[T]()
// Construct runnable that completes the promise
- _taskQueue.add((new Runnable {
- override def run(): Unit =
+    logger.trace(s"Queueing new task to be processed!")
+    currentExecutor.execute(new Runnable {
+      override def run(): Unit = {
+        var threadName: String = "???"
try {
+          threadName = Try(Thread.currentThread().getName).getOrElse(threadName)
+          logger.trace(s"(Thread $threadName) Executing task!")
val result = taskFunction
+
+          logger.trace(s"(Thread $threadName) Task finished successfully!")
taskPromise.success(result)
} catch {
- case ex: Throwable => taskPromise.tryFailure(ex)
+          // Deliberately catches every Throwable so the future is always
+          // completed. This includes e.g. the InterruptedException raised when
+          // stop() interrupts a worker.
+          case ex: Throwable =>
+            val exName = ex.getClass.getName
+            val exMessage = Option(ex.getLocalizedMessage).getOrElse("???")
+            logger.trace(
+              s"(Thread $threadName) Task failed: ($exName) = $exMessage")
+            taskPromise.tryFailure(ex)
}
- }, taskPromise))
+      }
+    })
taskPromise.future
}
/**
- * Returns the count of tasks including the currently-running one.
+ * Returns the count of tasks including the currently-running ones.
*
* @return The count of tasks
*/
- def size: Int = _taskQueue.size()
+  def size: Int = {
+    // Tasks still waiting in the queue plus those a worker is running now
+    val queued = taskQueue.size()
+    val running = executor.fold(0)(_.getActiveCount)
+    queued + running
+  }
/**
* Returns whether or not there is a task in the queue to be processed.
*
* @return True if the internal queue is not empty, otherwise false
*/
- def hasTaskInQueue: Boolean = !_taskQueue.isEmpty
-
- /**
- * Returns the current executing task.
- *
- * @return The current task or None if no task is running
- */
- def currentTask: Option[Runnable] =
- if (_taskThread != null && _taskThread._currentTask != null)
- Some(_taskThread._currentTask)
- else None
-
- /**
- * Returns the sequence of tasks.
- *
- * @return The sequence of tasks as Runnables
- */
- def tasks: Seq[Runnable] = _taskQueue.toArray.map(_.asInstanceOf[(Runnable, _)]._1)
+  def hasTaskInQueue: Boolean = taskQueue.size() > 0
/**
* Whether or not there is a task being executed currently.
*
* @return True if there is a task being executed, otherwise false
*/
- def isExecutingTask: Boolean = currentTask.nonEmpty
-
- /**
- * Whether or not the task manager is processing new tasks.
- *
- * @return True if the manager is capable of consuming more tasks,
- * otherwise false
- */
- def isRunning: Boolean = _taskThread != null && _taskThread._running
+  def isExecutingTask: Boolean =
+    executor.fold(false)(pool => pool.getActiveCount > 0)
/**
* Block execution (by sleeping) until all tasks currently queued up for
* execution are processed.
*/
def await(): Unit =
- while (hasTaskInQueue || isExecutingTask) Thread.sleep(1)
+ // Polls every millisecond until the queue is empty and no worker is active.
+ // NOTE(review): ThreadPoolExecutor.getActiveCount is documented as
+ // approximate. A task being handed from the queue to a worker may be seen by
+ // neither check, so this could return early -- confirm callers tolerate that.
+ while (!taskQueue.isEmpty || isExecutingTask) Thread.sleep(1)
/**
- * Starts the task manager (begins processing tasks). Creates a new thread
+ * Starts the task manager (begins processing tasks). Creates X new threads
* in the process.
*/
- def start(): Unit = startTaskProcessingThread()
+  def start(): Unit = {
+    if (executor.nonEmpty) {
+      // Starting again would create a second pool draining the same shared
+      // task queue and orphan the current pool, which would never be shut
+      // down. Keep the running executor instead; restart() resets it.
+      logger.warn("Task manager already started! Use restart() to reset it.")
+    } else {
+      logger.trace(
+        s"""
+           |Initializing with the following settings:
+           |- $minimumWorkers core worker pool
+           |- $maximumWorkers maximum workers
+           |- $keepAliveTime milliseconds keep alive time
+        """.stripMargin.trim)
+      executor = Some(new ScalingThreadPoolExecutor)
+    }
+  }
/**
* Restarts internal processing of tasks (removing current task).
*/
- def restart(): Unit = restartTaskProcessingThread()
+ def restart(): Unit = {
+ // stop() shuts the current executor down with shutdownNow (interrupting
+ // running tasks and draining queued ones); start() then builds a fresh
+ // executor over the same task queue
+ stop()
+ start()
+ }
/**
* Stops internal processing of tasks.
- *
- * @param killThread Whether to kill the thread processing tasks if not it
- * is not responding to interrupts
- * @param killTimeout The period of time (in milliseconds) to wait before
- * attempting to kill the thread processing tasks
*/
- def stop(
- killThread: Boolean = true,
- killTimeout: Int = InterruptTimeout
- ): Unit = killTaskProcessingThread(killThread, killTimeout)
-
- /**
- * Creates a new task-processing thread and starts it.
- */
- private def startTaskProcessingThread(): Unit = {
- _taskThread = new TaskThread
- _taskThread.start()
- }
-
- /**
- * Attempts to cancel/interrupt the task manager's internal task-processing
- * thread. If unable to interrupt, will be forcefully stopped after the
- * interrupt timeout threshold is reached.
- *
- * @param killThread Whether to kill the thread processing tasks if not it
- * is not responding to interrupts
- * @param killTimeout The period of time (in milliseconds) to wait before
- * attempting to kill the thread processing tasks
- */
- private def killTaskProcessingThread(
- killThread: Boolean = true,
- killTimeout: Int = InterruptTimeout
- ): Unit = {
- // NOTE: Dirty hack to suppress deprecation warnings
- // See https://issues.scala-lang.org/browse/SI-7934 for discussion
- object Thread {
- @deprecated("", "") class Killer {
- def stop() = try { _taskThread.stop() }
- }
-
- object Killer extends Killer
- }
-
- _taskThread.cancel()
- _taskThread.interrupt()
-
- runIfTimeout(
- killTimeout,
- _taskThread.isAlive && killThread,
- // Still available (JDK8 now calls stop0 directly)
- //
- // Used because of discussion on:
- // https://issues.scala-lang.org/browse/SI-6302
- {
- //_taskThread.stop()
- Thread.Killer.stop()
- val currentPromise = _currentPromise.get()
- if (currentPromise != null)
- currentPromise.tryFailure(new ThreadDeath())
- }
- )
-
- _taskThread = null
- }
-
- /**
- * Shuts down the current thread used to execute tasks and starts a new
- * thread in its place.
- */
- private def restartTaskProcessingThread(): Unit = {
- killTaskProcessingThread()
- startTaskProcessingThread()
- }
-
- /**
- * Execute body if the millisecond timeout is reached and the condition
- * still holds true.
- *
- * @param milliseconds The timeout limit in milliseconds
- * @param condition The condition to test up until executing the body
- * @param body The body to execute
- *
- * @return Whether or not the body was executed
- */
- private def runIfTimeout(
- milliseconds: Int,
- condition: => Boolean,
- body: => Unit
- ): Boolean = {
- val endTime: Long = milliseconds + java.lang.System.currentTimeMillis()
-
- while (java.lang.System.currentTimeMillis() < endTime) {
- // Exit if our condition suddenly goes south
- if (!condition) return false
-
- // Allow check to be interrupted
- if (Thread.interrupted())
- throw new InterruptedException()
- }
-
- if (condition) { body; true } else false
+  def stop(): Unit = {
+    // Detach the executor before tearing it down so that size,
+    // isExecutingTask and add() stop referring to a pool that is shutting down
+    val stoppingExecutor = executor
+    executor = None
+
+    stoppingExecutor.foreach { pool =>
+      // shutdownNow interrupts running tasks and drains the queue. Drained
+      // tasks never run, so the futures add() returned for them are never
+      // completed; report that instead of dropping them silently.
+      val discarded = pool.shutdownNow()
+      if (!discarded.isEmpty)
+        logger.warn(s"Discarded ${discarded.size()} queued task(s) on stop!")
+    }
}
}
+/**
+ * Represents constants associated with the task manager (used as the
+ * constructor defaults of the TaskManager class).
+ */
+object TaskManager {
+ /**
+ * The default thread group to use with all worker threads. It is created
+ * once when this object is initialized, so every TaskManager relying on the
+ * default shares the same group.
+ */
+ val DefaultThreadGroup = new ThreadGroup(RestrictedGroupName)
+
+ /** The default number of maximum tasks accepted by the task manager. */
+ val DefaultMaxTasks = 200
+
+ /**
+ * The default number of workers to spawn initially and keep alive
+ * even when idle.
+ */
+ val DefaultMinimumWorkers = 1
+
+ /**
+ * The default maximum number of workers to spawn: the number of processors
+ * reported by the JVM when this object is initialized.
+ */
+ val DefaultMaximumWorkers = Runtime.getRuntime.availableProcessors()
+
+ /**
+ * The default timeout in milliseconds for workers waiting for tasks.
+ * Declared as an Int; it is widened to the Long keepAliveTime parameter.
+ */
+ val DefaultKeepAliveTime = 1000
+
+ /**
+ * The default timeout in milliseconds to wait before stopping a thread
+ * if it cannot be interrupted.
+ *
+ * NOTE(review): not referenced by the TaskManager code in this change
+ * (stop() no longer takes a kill timeout); confirm it is still needed.
+ */
+ val InterruptTimeout = 5000
+
+ /**
+ * The maximum time to wait to add a task to the queue in milliseconds.
+ *
+ * NOTE(review): not referenced by the TaskManager code in this change;
+ * confirm it is still needed.
+ */
+ val MaximumTaskQueueTimeout = 10000
+
+ /**
+ * The maximum time in milliseconds to wait to queue up a thread in the
+ * thread factory.
+ *
+ * NOTE(review): not referenced by the TaskManager code in this change;
+ * confirm it is still needed.
+ */
+ val MaximumThreadQueueTimeout = 10000
+}
\ No newline at end of file
diff --git a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
index e453dcb..a4184d4 100644
--- a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
+++ b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
@@ -16,10 +16,10 @@
package com.ibm.spark.utils
-import java.util.concurrent.ExecutionException
+import java.util.concurrent.{RejectedExecutionException, ExecutionException}
import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.scalatest.concurrent.{Eventually, ScalaFutures}
+import org.scalatest.concurrent.{Timeouts, Eventually, ScalaFutures}
import org.scalatest.mock.MockitoSugar
import org.scalatest.time.{Milliseconds, Seconds, Span}
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
@@ -30,7 +30,7 @@
class TaskManagerSpec extends FunSpec with Matchers with MockitoSugar
with BeforeAndAfter with ScalaFutures with UncaughtExceptionSuppression
- with Eventually
+ with Eventually with Timeouts
{
private var taskManager: TaskManager = _
implicit override val patienceConfig = PatienceConfig(
@@ -48,18 +48,37 @@
describe("TaskManager") {
describe("#add") {
- // TODO: How to verify the (Runnable, Promise[_]) stored in private queue?
+ it("should throw an exception if not started") {
+ intercept[AssertionError] {
+ taskManager.add {}
+ }
+ }
+
+      it("should throw an exception if more tasks are added than max task size") {
+        val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 1)
+
+        taskManager.start()
+
+        try {
+          // Should fail from having too many tasks added. The tasks block so
+          // the single worker and the single queue slot are guaranteed to be
+          // occupied; trivial tasks might drain fast enough to avoid rejection.
+          intercept[RejectedExecutionException] {
+            for (i <- 1 to 500) taskManager.add { Thread.sleep(10000) }
+          }
+        } finally {
+          // Interrupt the blocked worker so it does not outlive the test
+          taskManager.stop()
+        }
+      }
it("should return a Future[_] based on task provided") {
+ taskManager.start()
+
// Cannot check inner Future type due to type erasure
taskManager.add { } shouldBe an [Future[_]]
+
+ taskManager.stop()
}
it("should work for a task that returns nothing") {
- val f = taskManager.add { }
-
taskManager.start()
+ val f = taskManager.add { }
+
whenReady(f) { result =>
result shouldBe a [BoxedUnit]
taskManager.stop()
@@ -67,11 +86,11 @@
}
it("should construct a Runnable that invokes a Promise on success") {
+ taskManager.start()
+
val returnValue = 3
val f = taskManager.add { returnValue }
- taskManager.start()
-
whenReady(f) { result =>
result should be (returnValue)
taskManager.stop()
@@ -79,38 +98,26 @@
}
it("should construct a Runnable that invokes a Promise on failure") {
+ taskManager.start()
+
val error = new Throwable("ERROR")
val f = taskManager.add { throw error }
- taskManager.start()
-
whenReady(f.failed) { result =>
result should be (error)
taskManager.stop()
}
}
- }
- describe("#tasks") {
- it("should return a sequence of Runnables not yet executed") {
- // TODO: Investigate how to validate tasks better than just a count
- for (x <- 1 to 50) taskManager.add { }
+      it("should not block when adding more tasks than available threads") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
- taskManager.tasks should have size 50
- }
- }
-
- describe("#thread") {
- it("should return Some(Thread) if running") {
taskManager.start()
- taskManager.thread should not be (None)
-
- taskManager.stop()
- }
-
- it("should return None if not running") {
- taskManager.thread should be (None)
+
+        try {
+          failAfter(Span(100, Milliseconds)) {
+            taskManager.add { while (true) { Thread.sleep(1) } }
+            taskManager.add { while (true) { Thread.sleep(1) } }
+          }
+        } finally {
+          // Interrupt the never-ending tasks so their worker does not leak
+          taskManager.stop()
+        }
}
}
@@ -119,22 +126,27 @@
taskManager.size should be (0)
}
- it("should be one when a new task has been added") {
- taskManager.add {}
+      it("should reflect queued tasks and executing tasks") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
+        taskManager.start()
- taskManager.size should be (1)
- }
-
- it("should be zero when the only task is currently being executed") {
+
+        // Fill up the task manager and then add another task to the queue
+        taskManager.add { while (true) { Thread.sleep(1000) } }
taskManager.add { while (true) { Thread.sleep(1000) } }
+
+        try {
+          // The first task only counts as executing once its worker actually
+          // starts running it, so wait for that instead of asserting at once
+          eventually { taskManager.size should be (2) }
+        } finally {
+          // Interrupt the never-ending tasks so their worker does not leak
+          taskManager.stop()
+        }
+      }
+
+ it("should be one if there is only one executing task and no queued ones") {
taskManager.start()
+ taskManager.add { while (true) { Thread.sleep(1000) } }
+
// Wait until task is being executed to check if the task is still in
// the queue
while (!taskManager.isExecutingTask) Thread.sleep(1)
- taskManager.size should be (0)
+ taskManager.size should be (1)
taskManager.stop()
}
@@ -145,17 +157,22 @@
taskManager.hasTaskInQueue should be (false)
}
- it("should be true when one task has been added but not started") {
- taskManager.add {}
+      it("should be true where there are tasks remaining in the queue") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
+        taskManager.start()
+
+        // Fill up the task manager and then add another task to the queue
+        taskManager.add { while (true) { Thread.sleep(1000) } }
+        taskManager.add { while (true) { Thread.sleep(1000) } }
taskManager.hasTaskInQueue should be (true)
+
+        // Interrupt the never-ending tasks so their worker does not leak
+        taskManager.stop()
}
it("should be false when the only task is currently being executed") {
- taskManager.add { while (true) { Thread.sleep(1000) } }
-
taskManager.start()
+ taskManager.add { while (true) { Thread.sleep(1000) } }
+
// Wait until task is being executed to check if the task is still in
// the queue
while (!taskManager.isExecutingTask) Thread.sleep(1)
@@ -194,55 +211,16 @@
}
}
- describe("#currentTask") {
- it("should be None when there are no tasks") {
- taskManager.currentTask should be (None)
- }
-
- it("should be None when there are tasks, but none are running") {
- taskManager.add { }
-
- taskManager.currentTask should be (None)
- }
-
- it("should be Some(...) when a task is being executed") {
- taskManager.add { while (true) { Thread.sleep(1000) } }
- taskManager.start()
-
- // Wait until executing task
- while (!taskManager.isExecutingTask) Thread.sleep(1)
-
- taskManager.currentTask should not be (None)
-
- taskManager.stop()
- }
- }
-
- describe("#isRunning") {
- it("should be false when not started") {
- taskManager.isRunning should be (false)
- }
-
- it("should be true after being started") {
- taskManager.start()
- taskManager.isRunning should be (true)
- taskManager.stop()
- }
-
- it("should be false after being stopped") {
- taskManager.start(); taskManager.stop()
- taskManager.isRunning should be (false)
- }
- }
-
describe("#await") {
it("should block until all tasks are completed") {
- // TODO: Need better way to ensure tasks are still running while
- // awaiting their return
- for (x <- 1 to 50) taskManager.add { Thread.sleep(1) }
+ val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 500)
taskManager.start()
+ // TODO: Need better way to ensure tasks are still running while
+ // awaiting their return
+ for (x <- 1 to 500) taskManager.add { Thread.sleep(1) }
+
assume(taskManager.hasTaskInQueue)
taskManager.await()
@@ -254,10 +232,10 @@
}
describe("#start") {
- it("should create an internal thread and start it") {
+ it("should create an internal thread pool executor") {
taskManager.start()
- taskManager.thread should not be (None)
+ taskManager.executor should not be (None)
taskManager.stop()
}
@@ -267,11 +245,11 @@
it("should stop & erase the old internal thread and create a new one") {
taskManager.start()
- val oldThread = taskManager.thread
+ val oldExecutor = taskManager.executor
taskManager.restart()
- taskManager.thread should not be (oldThread)
+ taskManager.executor should not be (oldExecutor)
taskManager.stop()
}
@@ -279,8 +257,8 @@
describe("#stop") {
it("should attempt to interrupt the currently-running task") {
- val f = taskManager.add { while (true) { Thread.sleep(1000) } }
taskManager.start()
+ val f = taskManager.add { while (true) { Thread.sleep(1000) } }
// Wait for the task to start
while (!taskManager.isExecutingTask) Thread.sleep(1)
@@ -295,15 +273,17 @@
}
}
- it("should kill the thread if interrupts failed and kill enabled") {
- val f = taskManager.add { var x = 0; while (true) { x += 1 } }
+ // TODO: Refactoring task manager to be parallelizable broke this ability
+ // so this will need to be reimplemented or abandoned
+ ignore("should kill the thread if interrupts failed and kill enabled") {
taskManager.start()
+ val f = taskManager.add { var x = 0; while (true) { x += 1 } }
// Wait for the task to start
while (!taskManager.isExecutingTask) Thread.sleep(1)
// Kill the task
- taskManager.stop(true, 0)
+ taskManager.stop()
// Future should return ThreadDeath when killed
whenReady(f.failed) { result =>
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala b/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
index 6b1b170..0e71a55 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
@@ -79,6 +79,11 @@
parser.accepts("magic-url", "path to a magic jar")
.withRequiredArg().ofType(classOf[String])
+ // --max-interpreter-threads=<int>: surfaced in the config map as
+ // "max_interpreter_threads" (fallback value comes from reference.conf)
+ private val _max_interpreter_threads = parser.accepts(
+ "max-interpreter-threads",
+ "total number of worker threads to use to execute code"
+ ).withRequiredArg().ofType(classOf[Int])
+
private val options = parser.parse(args: _*)
/*
@@ -130,7 +135,8 @@
.flatMap(list => if (list.isEmpty) None else Some(list)),
"spark_configuration" -> getAll(_spark_configuration)
.map(list => KeyValuePairUtils.keyValuePairSeqToString(list))
- .flatMap(str => if (str.nonEmpty) Some(str) else None)
+ .flatMap(str => if (str.nonEmpty) Some(str) else None),
+ "max_interpreter_threads" -> get(_max_interpreter_threads)
).flatMap(removeEmptyOptions).asInstanceOf[Map[String, AnyRef]].asJava)
commandLineConfig.withFallback(profileConfig).withFallback(ConfigFactory.load)
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
index 6271459..f3db572 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
@@ -73,7 +73,7 @@
// Initialize components needed elsewhere
val (commStorage, commRegistrar, commManager, interpreter,
- kernelInterpreter, kernel, sparkContext, dependencyDownloader,
+ kernel, sparkContext, dependencyDownloader,
magicLoader, responseMap) =
initializeComponents(
config = config,
@@ -81,7 +81,7 @@
actorLoader = actorLoader
)
this.sparkContext = sparkContext
- this.interpreters ++= Seq(interpreter, kernelInterpreter)
+ this.interpreters ++= Seq(interpreter)
// Initialize our handlers that take care of processing messages
initializeHandlers(
@@ -96,8 +96,7 @@
// Initialize our hooks that handle various JVM events
initializeHooks(
- interpreter = interpreter,
- kernelInterpreter = kernelInterpreter
+ interpreter = interpreter
)
logger.debug("Initializing security manager")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
index 7f6c511..9011193 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
@@ -16,7 +16,6 @@
package com.ibm.spark.boot.layer
-import java.io.File
import java.util.concurrent.ConcurrentHashMap
import akka.actor.ActorRef
@@ -31,15 +30,12 @@
import com.ibm.spark.magic.MagicLoader
import com.ibm.spark.magic.builtin.BuiltinLoader
import com.ibm.spark.magic.dependencies.DependencyMap
-import com.ibm.spark.utils.{KeyValuePairUtils, LogLike}
+import com.ibm.spark.utils.{TaskManager, KeyValuePairUtils, LogLike}
import com.typesafe.config.Config
-import joptsimple.util.KeyValuePair
import org.apache.spark.{SparkContext, SparkConf}
import scala.collection.JavaConverters._
-import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter.JPrintWriter
import scala.util.Try
/**
@@ -56,7 +52,7 @@
*/
def initializeComponents(
config: Config, appName: String, actorLoader: ActorLoader
- ): (CommStorage, CommRegistrar, CommManager, Interpreter, Interpreter,
+ ): (CommStorage, CommRegistrar, CommManager, Interpreter,
Kernel, SparkContext, DependencyDownloader, MagicLoader,
collection.mutable.Map[String, ActorRef])
}
@@ -80,17 +76,16 @@
val (commStorage, commRegistrar, commManager) =
initializeCommObjects(actorLoader)
val interpreter = initializeInterpreter(config)
- val kernelInterpreter = initializeKernelInterpreter(config, interpreter)
val sparkContext = initializeSparkContext(
config, appName, actorLoader, interpreter)
val dependencyDownloader = initializeDependencyDownloader(config)
val magicLoader = initializeMagicLoader(
- config, interpreter, kernelInterpreter, sparkContext, dependencyDownloader)
+ config, interpreter, sparkContext, dependencyDownloader)
val kernel = initializeKernel(
- actorLoader, interpreter, kernelInterpreter, commManager, magicLoader)
+ actorLoader, interpreter, commManager, magicLoader)
val responseMap = initializeResponseMap()
- (commStorage, commRegistrar, commManager, interpreter, kernelInterpreter,
- kernel, sparkContext, dependencyDownloader, magicLoader, responseMap)
+ (commStorage, commRegistrar, commManager, interpreter, kernel,
+ sparkContext, dependencyDownloader, magicLoader, responseMap)
}
private def initializeCommObjects(actorLoader: ActorLoader) = {
@@ -117,13 +112,18 @@
protected def initializeInterpreter(config: Config) = {
val interpreterArgs = config.getStringList("interpreter_args").asScala.toList
+ val maxInterpreterThreads = config.getInt("max_interpreter_threads")
- logger.info("Constructing interpreter with arguments: " +
- interpreterArgs.mkString(" "))
+ logger.info(
+ s"Constructing interpreter with $maxInterpreterThreads threads and " +
+ "with arguments: " + interpreterArgs.mkString(" "))
val interpreter = new ScalaInterpreter(interpreterArgs, Console.out)
with StandardSparkIMainProducer
- with StandardTaskManagerProducer
- with StandardSettingsProducer
+ with TaskManagerProducerLike
+ with StandardSettingsProducer {
+ override def newTaskManager(): TaskManager =
+ new TaskManager(maximumWorkers = maxInterpreterThreads)
+ }
logger.debug("Starting interpreter")
interpreter.start()
@@ -131,30 +131,6 @@
interpreter
}
- private def initializeKernelInterpreter(
- config: Config, interpreter: Interpreter
- ) = {
- val interpreterArgs =
- config.getStringList("interpreter_args").asScala.toList
-
- // TODO: Refactor this construct to a cleaner implementation (for future
- // multi-interpreter design)
- logger.info("Constructing interpreter with arguments: " +
- interpreterArgs.mkString(" "))
- val kernelInterpreter = new ScalaInterpreter(interpreterArgs, Console.out)
- with StandardTaskManagerProducer
- with StandardSettingsProducer
- with SparkIMainProducerLike {
- override def newSparkIMain(settings: Settings, out: JPrintWriter) = {
- interpreter.asInstanceOf[ScalaInterpreter].sparkIMain
- }
- }
- logger.debug("Starting interpreter")
- kernelInterpreter.start()
-
- kernelInterpreter
- }
-
// TODO: Think of a better way to test without exposing this
protected[layer] def initializeSparkContext(
config: Config, appName: String, actorLoader: ActorLoader,
@@ -283,27 +259,24 @@
+ /**
+ * Creates the Kernel API object around the single interpreter, binds it
+ * into that same interpreter as the implicit `kernel` value, and registers
+ * it with the magic loader's dependency map.
+ *
+ * NOTE(review): previously a separate kernel interpreter (sharing the main
+ * interpreter's SparkIMain) was used here. If Kernel evaluates code through
+ * this interpreter from within a running cell, it presumably needs a free
+ * task-manager worker -- confirm behaviour with max_interpreter_threads = 1.
+ */
private def initializeKernel(
actorLoader: ActorLoader,
- interpreterToDoBinding: Interpreter,
- interpreterToBind: Interpreter,
+ interpreter: Interpreter,
commManager: CommManager,
magicLoader: MagicLoader
) = {
- //interpreter.doQuietly {
- val kernel = new Kernel(
- actorLoader, interpreterToBind, commManager, magicLoader
- )
- interpreterToDoBinding.bind(
- "kernel", "com.ibm.spark.kernel.api.Kernel",
- kernel, List( """@transient implicit""")
- )
- //}
+ val kernel = new Kernel(actorLoader, interpreter, commManager, magicLoader)
+ interpreter.doQuietly {
+ // Expose the Kernel API to user code as the implicit `kernel` value
+ interpreter.bind(
+ "kernel", "com.ibm.spark.kernel.api.Kernel",
+ kernel, List( """@transient implicit""")
+ )
+ }
magicLoader.dependencyMap.setKernel(kernel)
kernel
}
private def initializeMagicLoader(
- config: Config, interpreter: Interpreter, kernelInterpreter: Interpreter, sparkContext: SparkContext,
+ config: Config, interpreter: Interpreter, sparkContext: SparkContext,
dependencyDownloader: DependencyDownloader
) = {
logger.debug("Constructing magic loader")
@@ -311,8 +284,8 @@
logger.debug("Building dependency map")
val dependencyMap = new DependencyMap()
.setInterpreter(interpreter)
+ .setKernelInterpreter(interpreter) // This is deprecated
.setSparkContext(sparkContext)
- .setKernelInterpreter(kernelInterpreter)
.setDependencyDownloader(dependencyDownloader)
logger.debug("Creating BuiltinLoader")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
index 7e83188..2f23599 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
@@ -29,11 +29,8 @@
* Initializes and registers all hooks.
*
* @param interpreter The main interpreter
- * @param kernelInterpreter The interpreter bound to the kernel instance
*/
- def initializeHooks(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit
+ def initializeHooks(interpreter: Interpreter): Unit
}
/**
@@ -46,18 +43,13 @@
* Initializes and registers all hooks.
*
* @param interpreter The main interpreter
- * @param kernelInterpreter The interpreter bound to the kernel instance
*/
- def initializeHooks(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit = {
- registerInterruptHook(interpreter, kernelInterpreter)
+ def initializeHooks(interpreter: Interpreter): Unit = {
+ registerInterruptHook(interpreter)
registerShutdownHook()
}
- private def registerInterruptHook(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit = {
+ private def registerInterruptHook(interpreter: Interpreter): Unit = {
val self = this
import sun.misc.{Signal, SignalHandler}
@@ -73,7 +65,6 @@
if (currentTime - lastSignalReceived > MaxSignalTime) {
logger.info("Resetting code execution!")
interpreter.interrupt()
- kernelInterpreter.interrupt()
// TODO: Cancel group representing current code execution
//sparkContext.cancelJobGroup()
diff --git a/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala b/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
index 303037b..62de035 100644
--- a/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
+++ b/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
@@ -27,6 +27,19 @@
class CommandLineOptionsSpec extends FunSpec with Matchers {
describe("CommandLineOptions") {
+    describe("when received --max-interpreter-threads=<int>") {
+      it("should set the configuration to the specified value") {
+        // Round-trip a distinctive value through the parser into the config
+        val expectedThreads = 999
+        val config = new CommandLineOptions(
+          List(s"--max-interpreter-threads=$expectedThreads")
+        ).toConfig
+
+        config.getInt("max_interpreter_threads") should be (expectedThreads)
+      }
+    }
+
+
describe("when received --help") {
it("should set the help flag to true") {
val options = new CommandLineOptions("--help" :: Nil)
diff --git a/resources/compile/reference.conf b/resources/compile/reference.conf
index a677a96..376c9ed 100644
--- a/resources/compile/reference.conf
+++ b/resources/compile/reference.conf
@@ -48,3 +48,6 @@
spark_configuration = ""
spark_configuration = ${?SPARK_CONFIGURATION}
+
+max_interpreter_threads = 4
+max_interpreter_threads = ${?MAX_INTERPRETER_THREADS}
\ No newline at end of file
diff --git a/resources/test/reference.conf b/resources/test/reference.conf
index 8c47069..75ccbe8 100644
--- a/resources/test/reference.conf
+++ b/resources/test/reference.conf
@@ -49,3 +49,5 @@
spark_configuration = ""
spark_configuration = ${?SPARK_CONFIGURATION}
+max_interpreter_threads = 4
+max_interpreter_threads = ${?MAX_INTERPRETER_THREADS}