Merge pull request #107 from ibm-et/ParallelizeTaskManager

Parallelize task manager
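
This change replaces the TaskManager's single task-processing thread with a
ThreadPoolExecutor whose core pool size scales between minimumWorkers and
maximumWorkers based on the number of outstanding tasks. It also removes the
now-unnecessary second "kernel interpreter" and exposes the worker cap through
a new --max-interpreter-threads command line option (config key
max_interpreter_threads).

A rough sketch of the new TaskManager API, with illustrative values (note
that add now asserts that start has been called first):

    val taskManager = new TaskManager(
      maxTasks = 200,      // capacity of the backing task queue
      minimumWorkers = 1,  // workers kept alive even when idle
      maximumWorkers = 4   // upper bound on spawned workers
    )

    taskManager.start()                      // creates the executor

    val future = taskManager.add { 1 + 1 }   // Future[Int] completed by a worker

    taskManager.await()                      // block until all tasks finish
    taskManager.stop()                       // shut down the executor
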
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
index e8db370..e91a2f2 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
@@ -53,6 +53,7 @@
    * Sets the Interpreter for this map.
    * @param interpreter The new Interpreter
    */
+  //@deprecated("Use setInterpreter with IncludeInterpreter!", "2015.05.06")
   def setKernelInterpreter(interpreter: Interpreter) = {
     internalMap(typeOf[IncludeKernelInterpreter]) =
       PartialFunction[Magic, Unit](
@@ -123,7 +124,7 @@
 
   /**
    * Sets the MagicLoader for this map.
-   * @param kernel The new Kernel
+   * @param magicLoader The new MagicLoader
    */
   def setMagicLoader(magicLoader: MagicLoader) = {
     internalMap(typeOf[IncludeMagicLoader]) =
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
index 60faf50..de19c07 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
@@ -19,6 +19,7 @@
 import com.ibm.spark.interpreter.Interpreter
 import com.ibm.spark.magic.Magic
 
+//@deprecated("Use IncludeInterpreter instead!", "2015.05.06")
 trait IncludeKernelInterpreter {
   this: Magic =>
 
diff --git a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
index 55b1b97..c3a43fc 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
@@ -1,93 +1,104 @@
-/*
- * Copyright 2014 IBM Corp.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
 package com.ibm.spark.utils
 
-import java.util.UUID
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.AtomicInteger
 
-import scala.concurrent.{Future, Promise, promise}
+import org.slf4j.LoggerFactory
+
+import scala.concurrent.{promise, Future}
+import java.util.concurrent._
 
 import com.ibm.spark.security.KernelSecurityManager._
+import TaskManager._
+
+import scala.util.Try
 
 /**
- * Represents a generic manager of Runnable tasks that will be executed in a
- * separate thread (created inside the manager). Facilitates running tasks and
- * provides a method to kill
+ * Represents a processor of tasks that has a configurable pool of worker
+ * threads dedicated to executing the tasks.
+ *
+ * @param threadGroup The thread group to use with all worker threads
+ * @param minimumWorkers The number of workers to spawn initially and keep
+ *                       alive even when idle
+ * @param maximumWorkers The max number of worker threads to spawn, defaulting
+ *                       to the number of processors on the machine
+ * @param keepAliveTime The maximum time in milliseconds for workers to remain
+ *                      idle before shutting down
  */
 class TaskManager(
-  taskThreadGroup: ThreadGroup = new ThreadGroup(RestrictedGroupName),
-  taskThreadName: String = "task-" + UUID.randomUUID().toString
+  private val threadGroup: ThreadGroup = DefaultThreadGroup,
+  private val maxTasks: Int = DefaultMaxTasks,
+  private val minimumWorkers: Int = DefaultMinimumWorkers,
+  private val maximumWorkers: Int = DefaultMaximumWorkers,
+  private val keepAliveTime: Long = DefaultKeepAliveTime
 ) {
-  // Maximum time to wait (in milliseconds) before forcefully stopping this
-  // thread when an interrupt fails
-  private val InterruptTimeout = 5 * 1000
-  private val _queueCapacity = 200
-  private val _taskQueue =
-    new ArrayBlockingQueue[(Runnable, Promise[_])](_queueCapacity)
-  private var _taskThread: TaskThread = _
+  protected val logger = LoggerFactory.getLogger(this.getClass.getName)
 
-  private val _currentPromise: AtomicReference[Promise[_]] =
-    new AtomicReference[Promise[_]]()
+  private class TaskManagerThreadFactory extends ThreadFactory {
+    override def newThread(r: Runnable): Thread = {
+      val thread = new Thread(threadGroup, r)
 
-  private class TaskThread extends Thread(taskThreadGroup, taskThreadName) {
-    private[TaskManager] var _currentTask: Runnable = _
-    private[TaskManager] var _running = false
+      logger.trace(s"Creating new thread named ${thread.getName}!")
 
-    override def start(): Unit = {
-      _running = true
-      super.start()
-    }
-
-    /**
-     * Main execution loop of the task thread.
-     *
-     * Pulls tasks from an internal queue to be processed sequentially.
-     */
-    override def run(): Unit = {
-      while (_running) {
-        val element = _taskQueue.poll(1L, TimeUnit.MILLISECONDS)
-        if (element != null) {
-          _currentTask = element._1
-          _currentPromise.set(element._2)
-
-          if (_currentTask != null) _currentTask.run()
-
-          _currentTask = null
-          _currentPromise.set(null)
-        }
-      }
-    }
-
-    /**
-     * Marks the internal flag for running new tasks to false.
-     */
-    def cancel(): Unit = {
-      _running = false
+      thread
     }
   }
 
-  /**
-   * Represents the internal thread used to process tasks.
-   *
-   * @return The Thread instance wrapped in an Option, or None if not started
-   */
-  def thread: Option[Thread] =
-    if (_taskThread != null) Some(_taskThread) else None
+  private[utils] class ScalingThreadPoolExecutor extends ThreadPoolExecutor(
+    minimumWorkers,
+    maximumWorkers,
+    keepAliveTime,
+    TimeUnit.MILLISECONDS,
+    taskQueue,
+    taskManagerThreadFactory
+  ) {
+    protected val logger = LoggerFactory.getLogger(this.getClass.getName)
+
+    /** Used to keep track of tasks separately from the task queue. */
+    private val taskCount = new AtomicInteger(0)
+
+    /**
+     * Syncs the core pool size of the executor with the current number of
+     * tasks, using the minimum worker size and maximum worker size as the
+     * bounds.
+     */
+    private def syncPoolLimits(): Unit = {
+      val totalTasks = taskCount.get()
+      val newCorePoolSize =
+        math.max(minimumWorkers, math.min(totalTasks, maximumWorkers))
+
+      logger.trace(Seq(
+        s"Task execution count is $totalTasks!",
+        s"Updating core pool size to $newCorePoolSize!"
+      ).mkString(" "))
+      executor.foreach(_.setCorePoolSize(newCorePoolSize))
+    }
+
+    override def execute(r: Runnable): Unit = {
+      synchronized {
+        if (taskCount.incrementAndGet() > maximumWorkers)
+          logger.warn(s"Exceeded $maximumWorkers workers during processing!")
+
+        syncPoolLimits()
+      }
+
+      super.execute(r)
+    }
+
+    override def afterExecute(r: Runnable, t: Throwable): Unit = {
+      super.afterExecute(r, t)
+
+      synchronized {
+        taskCount.decrementAndGet()
+        syncPoolLimits()
+      }
+    }
+  }
+
+  private val taskManagerThreadFactory = new TaskManagerThreadFactory
+  private val taskQueue = new ArrayBlockingQueue[Runnable](maxTasks)
+
+  @volatile
+  private[utils] var executor: Option[ScalingThreadPoolExecutor] = None
 
   /**
    * Adds a new task to the list to execute.
@@ -97,189 +108,130 @@
    * @return Future representing the return value (or error) from the task
    */
   def add[T <: Any](taskFunction: => T): Future[T] = {
+    assert(executor.nonEmpty, "Task manager not started!")
+
     val taskPromise = promise[T]()
 
     // Construct runnable that completes the promise
-    _taskQueue.add((new Runnable {
-      override def run(): Unit =
+    logger.trace(s"Queueing new task to be processed!")
+    executor.foreach(_.execute(new Runnable {
+      override def run(): Unit = {
+        var threadName: String = "???"
         try {
+          threadName = Try(Thread.currentThread().getName).getOrElse(threadName)
+          logger.trace(s"(Thread $threadName) Executing task!")
           val result = taskFunction
+
+          logger.trace(s"(Thread $threadName) Task finished successfully!")
           taskPromise.success(result)
         } catch {
-          case ex: Throwable => taskPromise.tryFailure(ex)
+          case ex: Throwable =>
+            val exName = ex.getClass.getName
+            val exMessage = Option(ex.getLocalizedMessage).getOrElse("???")
+            logger.trace(
+              s"(Thread $threadName) Task failed: ($exName) = $exMessage")
+            taskPromise.tryFailure(ex)
         }
-    }, taskPromise))
+      }
+    }))
 
     taskPromise.future
   }
 
   /**
-   * Returns the count of tasks including the currently-running one.
+   * Returns the count of tasks including the currently-running ones.
    *
    * @return The count of tasks
    */
-  def size: Int = _taskQueue.size()
+  def size: Int = taskQueue.size() + executor.map(_.getActiveCount).getOrElse(0)
 
   /**
    * Returns whether or not there is a task in the queue to be processed.
    *
    * @return True if the internal queue is not empty, otherwise false
    */
-  def hasTaskInQueue: Boolean = !_taskQueue.isEmpty
-
-  /**
-   * Returns the current executing task.
-   *
-   * @return The current task or None if no task is running
-   */
-  def currentTask: Option[Runnable] =
-    if (_taskThread != null && _taskThread._currentTask != null)
-      Some(_taskThread._currentTask)
-    else None
-
-  /**
-   * Returns the sequence of tasks.
-   *
-   * @return The sequence of tasks as Runnables
-   */
-  def tasks: Seq[Runnable] = _taskQueue.toArray.map(_.asInstanceOf[(Runnable, _)]._1)
+  def hasTaskInQueue: Boolean = !taskQueue.isEmpty
 
   /**
    * Whether or not there is a task being executed currently.
    *
    * @return True if there is a task being executed, otherwise false
    */
-  def isExecutingTask: Boolean = currentTask.nonEmpty
-
-  /**
-   * Whether or not the task manager is processing new tasks.
-   *
-   * @return True if the manager is capable of consuming more tasks,
-   *         otherwise false
-   */
-  def isRunning: Boolean = _taskThread != null && _taskThread._running
+  def isExecutingTask: Boolean = executor.exists(_.getActiveCount > 0)
 
   /**
    * Block execution (by sleeping) until all tasks currently queued up for
    * execution are processed.
    */
   def await(): Unit =
-    while (hasTaskInQueue || isExecutingTask) Thread.sleep(1)
+    while (!taskQueue.isEmpty || isExecutingTask) Thread.sleep(1)
 
   /**
-   * Starts the task manager (begins processing tasks). Creates a new thread
+   * Starts the task manager (begins processing tasks). Creates new worker threads
    * in the process.
    */
-  def start(): Unit = startTaskProcessingThread()
+  def start(): Unit = {
+    logger.trace(
+      s"""
+         |Initializing with the following settings:
+         |- $minimumWorkers core worker pool
+         |- $maximumWorkers maximum workers
+         |- $keepAliveTime milliseconds keep alive time
+       """.stripMargin.trim)
+    executor = Some(new ScalingThreadPoolExecutor)
+  }
 
   /**
    * Restarts internal processing of tasks (removing current task).
    */
-  def restart(): Unit = restartTaskProcessingThread()
+  def restart(): Unit = {
+    stop()
+    start()
+  }
 
   /**
    * Stops internal processing of tasks.
-   *
-   * @param killThread Whether to kill the thread processing tasks if it
-   *                   is not responding to interrupts
-   * @param killTimeout The period of time (in milliseconds) to wait before
-   *                    attempting to kill the thread processing tasks
    */
-  def stop(
-    killThread: Boolean = true,
-    killTimeout: Int = InterruptTimeout
-  ): Unit = killTaskProcessingThread(killThread, killTimeout)
-
-  /**
-   * Creates a new task-processing thread and starts it.
-   */
-  private def startTaskProcessingThread(): Unit = {
-    _taskThread = new TaskThread
-    _taskThread.start()
-  }
-
-  /**
-   * Attempts to cancel/interrupt the task manager's internal task-processing
-   * thread. If unable to interrupt, will be forcefully stopped after the
-   * interrupt timeout threshold is reached.
-   *
-   * @param killThread Whether to kill the thread processing tasks if it
-   *                   is not responding to interrupts
-   * @param killTimeout The period of time (in milliseconds) to wait before
-   *                    attempting to kill the thread processing tasks
-   */
-  private def killTaskProcessingThread(
-    killThread: Boolean = true,
-    killTimeout: Int = InterruptTimeout
-  ): Unit = {
-    // NOTE: Dirty hack to suppress deprecation warnings
-    // See https://issues.scala-lang.org/browse/SI-7934 for discussion
-    object Thread {
-      @deprecated("", "") class Killer {
-        def stop() = try { _taskThread.stop() }
-      }
-
-      object Killer extends Killer
-    }
-
-    _taskThread.cancel()
-    _taskThread.interrupt()
-
-    runIfTimeout(
-      killTimeout,
-      _taskThread.isAlive && killThread,
-      // Still available (JDK8 now calls stop0 directly)
-      //
-      // Used because of discussion on:
-      // https://issues.scala-lang.org/browse/SI-6302
-      {
-        //_taskThread.stop()
-        Thread.Killer.stop()
-        val currentPromise = _currentPromise.get()
-        if (currentPromise != null)
-          currentPromise.tryFailure(new ThreadDeath())
-      }
-    )
-
-    _taskThread = null
-  }
-
-  /**
-   * Shuts down the current thread used to execute tasks and starts a new
-   * thread in its place.
-   */
-  private def restartTaskProcessingThread(): Unit = {
-    killTaskProcessingThread()
-    startTaskProcessingThread()
-  }
-
-  /**
-   * Execute body if the millisecond timeout is reached and the condition
-   * still holds true.
-   *
-   * @param milliseconds The timeout limit in milliseconds
-   * @param condition The condition to test up until executing the body
-   * @param body The body to execute
-   *
-   * @return Whether or not the body was executed
-   */
-  private def runIfTimeout(
-    milliseconds: Int,
-    condition: => Boolean,
-    body: => Unit
-  ): Boolean = {
-    val endTime: Long = milliseconds + java.lang.System.currentTimeMillis()
-
-    while (java.lang.System.currentTimeMillis() < endTime) {
-      // Exit if our condition suddenly goes south
-      if (!condition) return false
-
-      // Allow check to be interrupted
-      if (Thread.interrupted())
-        throw new InterruptedException()
-    }
-
-    if (condition) { body; true } else false
+  def stop(): Unit = {
+    executor.foreach(_.shutdownNow())
+    executor = None
   }
 }
 
+/**
+ * Represents constants associated with the task manager.
+ */
+object TaskManager {
+  /** The default thread group to use with all worker threads. */
+  val DefaultThreadGroup = new ThreadGroup(RestrictedGroupName)
+
+  /** The default number of maximum tasks accepted by the task manager. */
+  val DefaultMaxTasks = 200
+
+  /**
+   * The default number of workers to spawn initially and keep alive
+   * even when idle.
+   */
+  val DefaultMinimumWorkers = 1
+
+  /** The default maximum number of workers to spawn. */
+  val DefaultMaximumWorkers = Runtime.getRuntime.availableProcessors()
+
+  /** The default timeout in milliseconds for workers waiting for tasks. */
+  val DefaultKeepAliveTime = 1000
+
+  /**
+   * The default timeout in milliseconds to wait before stopping a thread
+   * if it cannot be interrupted.
+   */
+  val InterruptTimeout = 5000
+
+  /** The maximum time to wait to add a task to the queue in milliseconds. */
+  val MaximumTaskQueueTimeout = 10000
+
+  /**
+   * The maximum time in milliseconds to wait to queue up a thread in the
+   * thread factory.
+   */
+  val MaximumThreadQueueTimeout = 10000
+}
\ No newline at end of file
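
The heart of the new TaskManager is ScalingThreadPoolExecutor above. A
ThreadPoolExecutor only spawns threads beyond its core pool size once its
work queue is full, so simply setting a high maximum pool size would leave
queued tasks waiting on the minimum workers; syncing the core pool size with
the task count is what actually forces extra workers to spawn. A minimal
standalone sketch of the same pattern (class name, bounds, and queue capacity
here are illustrative, not part of this diff):

    import java.util.concurrent._
    import java.util.concurrent.atomic.AtomicInteger

    class ScalingExecutorSketch(min: Int, max: Int) extends ThreadPoolExecutor(
      min, max, 1000L, TimeUnit.MILLISECONDS,
      new ArrayBlockingQueue[Runnable](200)
    ) {
      // Tracked separately from the queue so queued *and* running tasks count
      private val taskCount = new AtomicInteger(0)

      // Clamp the core pool size to [min, max] based on outstanding tasks
      private def syncPoolLimits(): Unit =
        setCorePoolSize(math.max(min, math.min(taskCount.get(), max)))

      override def execute(r: Runnable): Unit = {
        synchronized { taskCount.incrementAndGet(); syncPoolLimits() }
        super.execute(r)
      }

      override def afterExecute(r: Runnable, t: Throwable): Unit = {
        super.afterExecute(r, t)
        synchronized { taskCount.decrementAndGet(); syncPoolLimits() }
      }
    }
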
diff --git a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
index e453dcb..a4184d4 100644
--- a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
+++ b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
@@ -16,10 +16,10 @@
 
 package com.ibm.spark.utils
 
-import java.util.concurrent.ExecutionException
+import java.util.concurrent.{RejectedExecutionException, ExecutionException}
 
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.scalatest.concurrent.{Eventually, ScalaFutures}
+import org.scalatest.concurrent.{Timeouts, Eventually, ScalaFutures}
 import org.scalatest.mock.MockitoSugar
 import org.scalatest.time.{Milliseconds, Seconds, Span}
 import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
@@ -30,7 +30,7 @@
 
 class TaskManagerSpec extends FunSpec with Matchers with MockitoSugar
   with BeforeAndAfter with ScalaFutures with UncaughtExceptionSuppression
-  with Eventually
+  with Eventually with Timeouts
 {
   private var taskManager: TaskManager = _
   implicit override val patienceConfig = PatienceConfig(
@@ -48,18 +48,37 @@
 
   describe("TaskManager") {
     describe("#add") {
-      // TODO: How to verify the (Runnable, Promise[_]) stored in private queue?
+      it("should throw an exception if not started") {
+        intercept[AssertionError] {
+          taskManager.add {}
+        }
+      }
+
+      it("should throw an exception if more tasks are added than max task size") {
+        val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 1)
+
+        taskManager.start()
+
+        // Should fail from having too many tasks added
+        intercept[RejectedExecutionException] {
+          for (i <- 1 to 500) taskManager.add {}
+        }
+      }
 
       it("should return a Future[_] based on task provided") {
+        taskManager.start()
+
         // Cannot check inner Future type due to type erasure
         taskManager.add { } shouldBe an [Future[_]]
+
+        taskManager.stop()
       }
 
       it("should work for a task that returns nothing") {
-        val f = taskManager.add { }
-
         taskManager.start()
 
+        val f = taskManager.add { }
+
         whenReady(f) { result =>
           result shouldBe a [BoxedUnit]
           taskManager.stop()
@@ -67,11 +86,11 @@
       }
 
       it("should construct a Runnable that invokes a Promise on success") {
+        taskManager.start()
+
         val returnValue = 3
         val f = taskManager.add { returnValue }
 
-        taskManager.start()
-
         whenReady(f) { result =>
           result should be (returnValue)
           taskManager.stop()
@@ -79,38 +98,26 @@
       }
 
       it("should construct a Runnable that invokes a Promise on failure") {
+        taskManager.start()
+
         val error = new Throwable("ERROR")
         val f = taskManager.add { throw error }
 
-        taskManager.start()
-
         whenReady(f.failed) { result =>
           result should be (error)
           taskManager.stop()
         }
       }
-    }
 
-    describe("#tasks") {
-      it("should return a sequence of Runnables not yet executed") {
-        // TODO: Investigate how to validate tasks better than just a count
-        for (x <- 1 to 50) taskManager.add { }
+      it("should not block when adding more tasks than available threads") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
 
-        taskManager.tasks should have size 50
-      }
-    }
-
-    describe("#thread") {
-      it("should return Some(Thread) if running") {
         taskManager.start()
 
-        taskManager.thread should not be (None)
-
-        taskManager.stop()
-      }
-
-      it("should return None if not running") {
-        taskManager.thread should be (None)
+        failAfter(Span(100, Milliseconds)) {
+          taskManager.add { while (true) { Thread.sleep(1) } }
+          taskManager.add { while (true) { Thread.sleep(1) } }
+        }
       }
     }
 
@@ -119,22 +126,27 @@
         taskManager.size should be (0)
       }
 
-      it("should be one when a new task has been added") {
-        taskManager.add {}
+      it("should reflect queued tasks and executing tasks") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
+        taskManager.start()
 
-        taskManager.size should be (1)
-      }
-
-      it("should be zero when the only task is currently being executed") {
+        // Fill up the task manager and then add another task to the queue
+        taskManager.add { while (true) { Thread.sleep(1000) } }
         taskManager.add { while (true) { Thread.sleep(1000) } }
 
+        taskManager.size should be (2)
+      }
+
+      it("should be one if there is only one executing task and no queued ones") {
         taskManager.start()
 
+        taskManager.add { while (true) { Thread.sleep(1000) } }
+
         // Wait until task is being executed to check if the task is still in
         // the queue
         while (!taskManager.isExecutingTask) Thread.sleep(1)
 
-        taskManager.size should be (0)
+        taskManager.size should be (1)
 
         taskManager.stop()
       }
@@ -145,17 +157,22 @@
         taskManager.hasTaskInQueue should be (false)
       }
 
-      it("should be true when one task has been added but not started") {
-        taskManager.add {}
+      it("should be true where there are tasks remaining in the queue") {
+        val taskManager = new TaskManager(maximumWorkers = 1)
+        taskManager.start()
+
+        // Fill up the task manager and then add another task to the queue
+        taskManager.add { while (true) { Thread.sleep(1000) } }
+        taskManager.add { while (true) { Thread.sleep(1000) } }
 
         taskManager.hasTaskInQueue should be (true)
       }
 
       it("should be false when the only task is currently being executed") {
-        taskManager.add { while (true) { Thread.sleep(1000) } }
-
         taskManager.start()
 
+        taskManager.add { while (true) { Thread.sleep(1000) } }
+
         // Wait until task is being executed to check if the task is still in
         // the queue
         while (!taskManager.isExecutingTask) Thread.sleep(1)
@@ -194,55 +211,16 @@
       }
     }
 
-    describe("#currentTask") {
-      it("should be None when there are no tasks") {
-        taskManager.currentTask should be (None)
-      }
-
-      it("should be None when there are tasks, but none are running") {
-        taskManager.add { }
-
-        taskManager.currentTask should be (None)
-      }
-
-      it("should be Some(...) when a task is being executed") {
-        taskManager.add { while (true) { Thread.sleep(1000) } }
-        taskManager.start()
-
-        // Wait until executing task
-        while (!taskManager.isExecutingTask) Thread.sleep(1)
-
-        taskManager.currentTask should not be (None)
-
-        taskManager.stop()
-      }
-    }
-
-    describe("#isRunning") {
-      it("should be false when not started") {
-        taskManager.isRunning should be (false)
-      }
-
-      it("should be true after being started") {
-        taskManager.start()
-        taskManager.isRunning should be (true)
-        taskManager.stop()
-      }
-
-      it("should be false after being stopped") {
-        taskManager.start(); taskManager.stop()
-        taskManager.isRunning should be (false)
-      }
-    }
-
     describe("#await") {
       it("should block until all tasks are completed") {
-        // TODO: Need better way to ensure tasks are still running while
-        // awaiting their return
-        for (x <- 1 to 50) taskManager.add { Thread.sleep(1) }
+        val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 500)
 
         taskManager.start()
 
+        // TODO: Need better way to ensure tasks are still running while
+        // awaiting their return
+        for (x <- 1 to 500) taskManager.add { Thread.sleep(1) }
+
         assume(taskManager.hasTaskInQueue)
         taskManager.await()
 
@@ -254,10 +232,10 @@
     }
 
     describe("#start") {
-      it("should create an internal thread and start it") {
+      it("should create an internal thread pool executor") {
         taskManager.start()
 
-        taskManager.thread should not be (None)
+        taskManager.executor should not be (None)
 
         taskManager.stop()
       }
@@ -267,11 +245,11 @@
       it("should stop & erase the old internal thread and create a new one") {
         taskManager.start()
 
-        val oldThread = taskManager.thread
+        val oldExecutor = taskManager.executor
 
         taskManager.restart()
 
-        taskManager.thread should not be (oldThread)
+        taskManager.executor should not be (oldExecutor)
 
         taskManager.stop()
       }
@@ -279,8 +257,8 @@
 
     describe("#stop") {
       it("should attempt to interrupt the currently-running task") {
-        val f = taskManager.add { while (true) { Thread.sleep(1000) } }
         taskManager.start()
+        val f = taskManager.add { while (true) { Thread.sleep(1000) } }
 
         // Wait for the task to start
         while (!taskManager.isExecutingTask) Thread.sleep(1)
@@ -295,15 +273,17 @@
         }
       }
 
-      it("should kill the thread if interrupts failed and kill enabled") {
-        val f = taskManager.add { var x = 0; while (true) { x += 1 } }
+      // TODO: Refactoring task manager to be parallelizable broke this ability
+      //       so this will need to be reimplemented or abandoned
+      ignore("should kill the thread if interrupts failed and kill enabled") {
         taskManager.start()
+        val f = taskManager.add { var x = 0; while (true) { x += 1 } }
 
         // Wait for the task to start
         while (!taskManager.isExecutingTask) Thread.sleep(1)
 
         // Kill the task
-        taskManager.stop(true, 0)
+        taskManager.stop()
 
         // Future should return ThreadDeath when killed
         whenReady(f.failed) { result =>
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala b/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
index 6b1b170..0e71a55 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/CommandLineOptions.scala
@@ -79,6 +79,11 @@
     parser.accepts("magic-url", "path to a magic jar")
       .withRequiredArg().ofType(classOf[String])
 
+  private val _max_interpreter_threads = parser.accepts(
+    "max-interpreter-threads",
+    "total number of worker threads to use to execute code"
+  ).withRequiredArg().ofType(classOf[Int])
+
   private val options = parser.parse(args: _*)
 
   /*
@@ -130,7 +135,8 @@
           .flatMap(list => if (list.isEmpty) None else Some(list)),
         "spark_configuration" -> getAll(_spark_configuration)
           .map(list => KeyValuePairUtils.keyValuePairSeqToString(list))
-          .flatMap(str => if (str.nonEmpty) Some(str) else None)
+          .flatMap(str => if (str.nonEmpty) Some(str) else None),
+        "max_interpreter_threads" -> get(_max_interpreter_threads)
     ).flatMap(removeEmptyOptions).asInstanceOf[Map[String, AnyRef]].asJava)
 
     commandLineConfig.withFallback(profileConfig).withFallback(ConfigFactory.load)
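
The new option feeds the max_interpreter_threads config key consumed in
ComponentInitialization below. A hypothetical invocation (the kernel launch
command name is illustrative):

    sparkkernel --max-interpreter-threads=8

On the consuming side the value is read through Typesafe Config:

    val maxInterpreterThreads = config.getInt("max_interpreter_threads")
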
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
index 6271459..f3db572 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
@@ -73,7 +73,7 @@
 
     // Initialize components needed elsewhere
     val (commStorage, commRegistrar, commManager, interpreter,
-      kernelInterpreter, kernel, sparkContext, dependencyDownloader,
+      kernel, sparkContext, dependencyDownloader,
       magicLoader, responseMap) =
       initializeComponents(
         config      = config,
@@ -81,7 +81,7 @@
         actorLoader = actorLoader
       )
     this.sparkContext = sparkContext
-    this.interpreters ++= Seq(interpreter, kernelInterpreter)
+    this.interpreters ++= Seq(interpreter)
 
     // Initialize our handlers that take care of processing messages
     initializeHandlers(
@@ -96,8 +96,7 @@
 
     // Initialize our hooks that handle various JVM events
     initializeHooks(
-      interpreter = interpreter,
-      kernelInterpreter = kernelInterpreter
+      interpreter = interpreter
     )
 
     logger.debug("Initializing security manager")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
index 7f6c511..9011193 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
@@ -16,7 +16,6 @@
 
 package com.ibm.spark.boot.layer
 
-import java.io.File
 import java.util.concurrent.ConcurrentHashMap
 
 import akka.actor.ActorRef
@@ -31,15 +30,12 @@
 import com.ibm.spark.magic.MagicLoader
 import com.ibm.spark.magic.builtin.BuiltinLoader
 import com.ibm.spark.magic.dependencies.DependencyMap
-import com.ibm.spark.utils.{KeyValuePairUtils, LogLike}
+import com.ibm.spark.utils.{TaskManager, KeyValuePairUtils, LogLike}
 import com.typesafe.config.Config
-import joptsimple.util.KeyValuePair
 import org.apache.spark.{SparkContext, SparkConf}
 
 import scala.collection.JavaConverters._
 
-import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter.JPrintWriter
 import scala.util.Try
 
 /**
@@ -56,7 +52,7 @@
    */
   def initializeComponents(
     config: Config, appName: String, actorLoader: ActorLoader
-  ): (CommStorage, CommRegistrar, CommManager, Interpreter, Interpreter,
+  ): (CommStorage, CommRegistrar, CommManager, Interpreter,
     Kernel, SparkContext, DependencyDownloader, MagicLoader,
     collection.mutable.Map[String, ActorRef])
 }
@@ -80,17 +76,16 @@
     val (commStorage, commRegistrar, commManager) =
       initializeCommObjects(actorLoader)
     val interpreter = initializeInterpreter(config)
-    val kernelInterpreter = initializeKernelInterpreter(config, interpreter)
     val sparkContext = initializeSparkContext(
       config, appName, actorLoader, interpreter)
     val dependencyDownloader = initializeDependencyDownloader(config)
     val magicLoader = initializeMagicLoader(
-      config, interpreter, kernelInterpreter, sparkContext, dependencyDownloader)
+      config, interpreter, sparkContext, dependencyDownloader)
     val kernel = initializeKernel(
-      actorLoader, interpreter, kernelInterpreter, commManager, magicLoader)
+      actorLoader, interpreter, commManager, magicLoader)
     val responseMap = initializeResponseMap()
-    (commStorage, commRegistrar, commManager, interpreter, kernelInterpreter,
-      kernel, sparkContext, dependencyDownloader, magicLoader, responseMap)
+    (commStorage, commRegistrar, commManager, interpreter, kernel,
+      sparkContext, dependencyDownloader, magicLoader, responseMap)
   }
 
   private def initializeCommObjects(actorLoader: ActorLoader) = {
@@ -117,13 +112,18 @@
 
   protected def initializeInterpreter(config: Config) = {
     val interpreterArgs = config.getStringList("interpreter_args").asScala.toList
+    val maxInterpreterThreads = config.getInt("max_interpreter_threads")
 
-    logger.info("Constructing interpreter with arguments: " +
-      interpreterArgs.mkString(" "))
+    logger.info(
+      s"Constructing interpreter with $maxInterpreterThreads threads and " +
+      "with arguments: " + interpreterArgs.mkString(" "))
     val interpreter = new ScalaInterpreter(interpreterArgs, Console.out)
       with StandardSparkIMainProducer
-      with StandardTaskManagerProducer
-      with StandardSettingsProducer
+      with TaskManagerProducerLike
+      with StandardSettingsProducer {
+      override def newTaskManager(): TaskManager =
+        new TaskManager(maximumWorkers = maxInterpreterThreads)
+    }
 
     logger.debug("Starting interpreter")
     interpreter.start()
@@ -131,30 +131,6 @@
     interpreter
   }
 
-  private def initializeKernelInterpreter(
-    config: Config, interpreter: Interpreter
-  ) = {
-    val interpreterArgs =
-      config.getStringList("interpreter_args").asScala.toList
-
-    // TODO: Refactor this construct to a cleaner implementation (for future
-    //       multi-interpreter design)
-    logger.info("Constructing interpreter with arguments: " +
-      interpreterArgs.mkString(" "))
-    val kernelInterpreter = new ScalaInterpreter(interpreterArgs, Console.out)
-      with StandardTaskManagerProducer
-      with StandardSettingsProducer
-      with SparkIMainProducerLike {
-      override def newSparkIMain(settings: Settings, out: JPrintWriter) = {
-        interpreter.asInstanceOf[ScalaInterpreter].sparkIMain
-      }
-    }
-    logger.debug("Starting interpreter")
-    kernelInterpreter.start()
-
-    kernelInterpreter
-  }
-
   // TODO: Think of a better way to test without exposing this
   protected[layer] def initializeSparkContext(
     config: Config, appName: String, actorLoader: ActorLoader,
@@ -283,27 +259,24 @@
 
   private def initializeKernel(
     actorLoader: ActorLoader,
-    interpreterToDoBinding: Interpreter,
-    interpreterToBind: Interpreter,
+    interpreter: Interpreter,
     commManager: CommManager,
     magicLoader: MagicLoader
   ) = {
-    //interpreter.doQuietly {
-    val kernel = new Kernel(
-      actorLoader, interpreterToBind, commManager, magicLoader
-    )
-    interpreterToDoBinding.bind(
-      "kernel", "com.ibm.spark.kernel.api.Kernel",
-      kernel, List( """@transient implicit""")
-    )
-    //}
+    val kernel = new Kernel(actorLoader, interpreter, commManager, magicLoader)
+    interpreter.doQuietly {
+      interpreter.bind(
+        "kernel", "com.ibm.spark.kernel.api.Kernel",
+        kernel, List( """@transient implicit""")
+      )
+    }
     magicLoader.dependencyMap.setKernel(kernel)
 
     kernel
   }
 
   private def initializeMagicLoader(
-    config: Config, interpreter: Interpreter, kernelInterpreter: Interpreter, sparkContext: SparkContext,
+    config: Config, interpreter: Interpreter, sparkContext: SparkContext,
     dependencyDownloader: DependencyDownloader
   ) = {
     logger.debug("Constructing magic loader")
@@ -311,8 +284,8 @@
     logger.debug("Building dependency map")
     val dependencyMap = new DependencyMap()
       .setInterpreter(interpreter)
+      .setKernelInterpreter(interpreter) // This is deprecated
       .setSparkContext(sparkContext)
-      .setKernelInterpreter(kernelInterpreter)
       .setDependencyDownloader(dependencyDownloader)
 
     logger.debug("Creating BuiltinLoader")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
index 7e83188..2f23599 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
@@ -29,11 +29,8 @@
    * Initializes and registers all hooks.
    *
    * @param interpreter The main interpreter
-   * @param kernelInterpreter The interpreter bound to the kernel instance
    */
-  def initializeHooks(
-    interpreter: Interpreter, kernelInterpreter: Interpreter
-  ): Unit
+  def initializeHooks(interpreter: Interpreter): Unit
 }
 
 /**
@@ -46,18 +43,13 @@
    * Initializes and registers all hooks.
    *
    * @param interpreter The main interpreter
-   * @param kernelInterpreter The interpreter bound to the kernel instance
    */
-  def initializeHooks(
-    interpreter: Interpreter, kernelInterpreter: Interpreter
-  ): Unit = {
-    registerInterruptHook(interpreter, kernelInterpreter)
+  def initializeHooks(interpreter: Interpreter): Unit = {
+    registerInterruptHook(interpreter)
     registerShutdownHook()
   }
 
-  private def registerInterruptHook(
-    interpreter: Interpreter, kernelInterpreter: Interpreter
-  ): Unit = {
+  private def registerInterruptHook(interpreter: Interpreter): Unit = {
     val self = this
 
     import sun.misc.{Signal, SignalHandler}
@@ -73,7 +65,6 @@
         if (currentTime - lastSignalReceived > MaxSignalTime) {
           logger.info("Resetting code execution!")
           interpreter.interrupt()
-          kernelInterpreter.interrupt()
 
           // TODO: Cancel group representing current code execution
           //sparkContext.cancelJobGroup()
diff --git a/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala b/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
index 303037b..62de035 100644
--- a/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
+++ b/kernel/src/test/scala/com/ibm/spark/boot/CommandLineOptionsSpec.scala
@@ -27,6 +27,19 @@
 class CommandLineOptionsSpec extends FunSpec with Matchers {
 
   describe("CommandLineOptions") {
+    describe("when received --max-interpreter-threads=<int>") {
+      it("should set the configuration to the specified value") {
+        val expected = 999
+        val options = new CommandLineOptions(
+          s"--max-interpreter-threads=$expected" :: Nil
+        )
+
+        val actual = options.toConfig.getInt("max_interpreter_threads")
+
+        actual should be (expected)
+      }
+    }
+
     describe("when received --help") {
       it("should set the help flag to true") {
         val options = new CommandLineOptions("--help" :: Nil)
diff --git a/resources/compile/reference.conf b/resources/compile/reference.conf
index a677a96..376c9ed 100644
--- a/resources/compile/reference.conf
+++ b/resources/compile/reference.conf
@@ -48,3 +48,6 @@
 
 spark_configuration = ""
 spark_configuration = ${?SPARK_CONFIGURATION}
+
+max_interpreter_threads = 4
+max_interpreter_threads = ${?MAX_INTERPRETER_THREADS}
\ No newline at end of file
diff --git a/resources/test/reference.conf b/resources/test/reference.conf
index 8c47069..75ccbe8 100644
--- a/resources/test/reference.conf
+++ b/resources/test/reference.conf
@@ -49,3 +49,5 @@
 spark_configuration = ""
 spark_configuration = ${?SPARK_CONFIGURATION}
 
+max_interpreter_threads = 4
+max_interpreter_threads = ${?MAX_INTERPRETER_THREADS}
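
In both reference.conf files the second line follows HOCON's optional
substitution rule: ${?MAX_INTERPRETER_THREADS} only overrides the value when
that environment variable is set, so the default of 4 worker threads stands
otherwise.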