Refactored KernelInterpreter usage to Interpreter, fixed task manager executor usage
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
index e8db370..e91a2f2 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/DependencyMap.scala
@@ -53,6 +53,7 @@
* Sets the Interpreter for this map.
* @param interpreter The new Interpreter
*/
+ //@deprecated("Use setInterpreter with IncludeInterpreter!", "2015.05.06")
def setKernelInterpreter(interpreter: Interpreter) = {
internalMap(typeOf[IncludeKernelInterpreter]) =
PartialFunction[Magic, Unit](
@@ -123,7 +124,7 @@
/**
* Sets the MagicLoader for this map.
- * @param kernel The new Kernel
+ * @param magicLoader The new MagicLoader
*/
def setMagicLoader(magicLoader: MagicLoader) = {
internalMap(typeOf[IncludeMagicLoader]) =
diff --git a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
index 60faf50..de19c07 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/magic/dependencies/IncludeKernelInterpreter.scala
@@ -19,6 +19,7 @@
import com.ibm.spark.interpreter.Interpreter
import com.ibm.spark.magic.Magic
+//@deprecated("Use IncludeInterpreter instead!", "2015.05.06")
trait IncludeKernelInterpreter {
this: Magic =>
diff --git a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
index 9d29227..c3a43fc 100644
--- a/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
+++ b/kernel-api/src/main/scala/com/ibm/spark/utils/TaskManager.scala
@@ -1,18 +1,25 @@
package com.ibm.spark.utils
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.slf4j.LoggerFactory
+
import scala.concurrent.{promise, Future}
import java.util.concurrent._
import com.ibm.spark.security.KernelSecurityManager._
import TaskManager._
+import scala.util.Try
+
/**
* Represents a processor of tasks that has X worker threads dedicated to
* executing the tasks.
*
* @param threadGroup The thread group to use with all worker threads
- * @param initialWorkers The number of workers to spawn initially
- * @param maxWorkers The max number of worker threads to spawn, defaulting
+ * @param minimumWorkers The number of workers to spawn initially and keep
+ * alive even when idle
+ * @param maximumWorkers The max number of worker threads to spawn, defaulting
* to the number of processors on the machine
* @param keepAliveTime The maximum time in milliseconds for workers to remain
* idle before shutting down
@@ -20,20 +27,78 @@
class TaskManager(
private val threadGroup: ThreadGroup = DefaultThreadGroup,
private val maxTasks: Int = DefaultMaxTasks,
- private val initialWorkers: Int = DefaultInitialWorkers,
- private val maxWorkers: Int = DefaultMaxWorkers,
+ private val minimumWorkers: Int = DefaultMinimumWorkers,
+ private val maximumWorkers: Int = DefaultMaximumWorkers,
private val keepAliveTime: Long = DefaultKeepAliveTime
) {
+ protected val logger = LoggerFactory.getLogger(this.getClass.getName)
+
private class TaskManagerThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = {
- new Thread(threadGroup, r)
+ val thread = new Thread(threadGroup, r)
+
+ logger.trace(s"Creating new thread named ${thread.getName}!")
+
+ thread
+ }
+ }
+
+ private[utils] class ScalingThreadPoolExecutor extends ThreadPoolExecutor(
+ minimumWorkers,
+ maximumWorkers,
+ keepAliveTime,
+ TimeUnit.MILLISECONDS,
+ taskQueue,
+ taskManagerThreadFactory
+ ) {
+ protected val logger = LoggerFactory.getLogger(this.getClass.getName)
+
+ /** Used to keep track of tasks separately from the task queue. */
+ private val taskCount = new AtomicInteger(0)
+
+ /**
+ * Syncs the core pool size of the executor with the current number of
+ * tasks, using the minimum worker size and maximum worker size as the
+ * bounds.
+ */
+ private def syncPoolLimits(): Unit = {
+ val totalTasks = taskCount.get()
+ val newCorePoolSize =
+ math.max(minimumWorkers, math.min(totalTasks, maximumWorkers))
+
+ logger.trace(Seq(
+ s"Task execution count is $totalTasks!",
+ s"Updating core pool size to $newCorePoolSize!"
+ ).mkString(" "))
+ executor.foreach(_.setCorePoolSize(newCorePoolSize))
+ }
+
+ override def execute(r: Runnable): Unit = {
+ synchronized {
+ if (taskCount.incrementAndGet() > maximumWorkers)
+          logger.warn(s"Task count exceeded the $maximumWorkers available workers; excess tasks will queue!")
+
+ syncPoolLimits()
+ }
+
+ super.execute(r)
+ }
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+
+ synchronized {
+ taskCount.decrementAndGet()
+ syncPoolLimits()
+ }
}
}
private val taskManagerThreadFactory = new TaskManagerThreadFactory
private val taskQueue = new ArrayBlockingQueue[Runnable](maxTasks)
- @volatile private[utils] var executor: Option[ThreadPoolExecutor] = None
+ @volatile
+ private[utils] var executor: Option[ScalingThreadPoolExecutor] = None
/**
* Adds a new task to the list to execute.
@@ -48,25 +113,37 @@
val taskPromise = promise[T]()
// Construct runnable that completes the promise
+ logger.trace(s"Queueing new task to be processed!")
executor.foreach(_.execute(new Runnable {
- override def run(): Unit =
+ override def run(): Unit = {
+ var threadName: String = "???"
try {
+ threadName = Try(Thread.currentThread().getName).getOrElse(threadName)
+ logger.trace(s"(Thread $threadName) Executing task!")
val result = taskFunction
+
+ logger.trace(s"(Thread $threadName) Task finished successfully!")
taskPromise.success(result)
} catch {
- case ex: Throwable => taskPromise.tryFailure(ex)
+ case ex: Throwable =>
+ val exName = ex.getClass.getName
+ val exMessage = Option(ex.getLocalizedMessage).getOrElse("???")
+ logger.trace(
+ s"(Thread $threadName) Task failed: ($exName) = $exMessage")
+ taskPromise.tryFailure(ex)
}
+ }
}))
taskPromise.future
}
/**
- * Returns the count of tasks including the currently-running one.
+ * Returns the count of tasks including the currently-running ones.
*
* @return The count of tasks
*/
- def size: Int = taskQueue.size()
+ def size: Int = taskQueue.size() + executor.map(_.getActiveCount).getOrElse(0)
/**
* Returns whether or not there is a task in the queue to be processed.
@@ -93,14 +170,16 @@
* Starts the task manager (begins processing tasks). Creates X new threads
* in the process.
*/
- def start(): Unit = executor = Some(new ThreadPoolExecutor(
- initialWorkers,
- maxWorkers,
- keepAliveTime,
- TimeUnit.MILLISECONDS,
- taskQueue,
- taskManagerThreadFactory
- ))
+ def start(): Unit = {
+ logger.trace(
+ s"""
+ |Initializing with the following settings:
+ |- $minimumWorkers core worker pool
+ |- $maximumWorkers maximum workers
+ |- $keepAliveTime milliseconds keep alive time
+ """.stripMargin.trim)
+ executor = Some(new ScalingThreadPoolExecutor)
+ }
/**
* Restarts internal processing of tasks (removing current task).
@@ -129,11 +208,14 @@
/** The default number of maximum tasks accepted by the task manager. */
val DefaultMaxTasks = 200
- /** The default number of workers to spawn initially. */
- val DefaultInitialWorkers = 1
+ /**
+ * The default number of workers to spawn initially and keep alive
+ * even when idle.
+ */
+ val DefaultMinimumWorkers = 1
/** The default maximum number of workers to spawn. */
- val DefaultMaxWorkers = Runtime.getRuntime.availableProcessors()
+ val DefaultMaximumWorkers = Runtime.getRuntime.availableProcessors()
/** The default timeout in milliseconds for workers waiting for tasks. */
val DefaultKeepAliveTime = 1000
@@ -145,5 +227,11 @@
val InterruptTimeout = 5000
/** The maximum time to wait to add a task to the queue in milliseconds. */
- val MaxQueueTimeout = 10000
+ val MaximumTaskQueueTimeout = 10000
+
+ /**
+ * The maximum time in milliseconds to wait to queue up a thread in the
+ * thread factory.
+ */
+ val MaximumThreadQueueTimeout = 10000
}
\ No newline at end of file
diff --git a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
index 17e42fa..a4184d4 100644
--- a/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
+++ b/kernel-api/src/test/scala/com/ibm/spark/utils/TaskManagerSpec.scala
@@ -16,10 +16,10 @@
package com.ibm.spark.utils
-import java.util.concurrent.ExecutionException
+import java.util.concurrent.{RejectedExecutionException, ExecutionException}
import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.scalatest.concurrent.{Eventually, ScalaFutures}
+import org.scalatest.concurrent.{Timeouts, Eventually, ScalaFutures}
import org.scalatest.mock.MockitoSugar
import org.scalatest.time.{Milliseconds, Seconds, Span}
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
@@ -30,7 +30,7 @@
class TaskManagerSpec extends FunSpec with Matchers with MockitoSugar
with BeforeAndAfter with ScalaFutures with UncaughtExceptionSuppression
- with Eventually
+ with Eventually with Timeouts
{
private var taskManager: TaskManager = _
implicit override val patienceConfig = PatienceConfig(
@@ -48,14 +48,23 @@
describe("TaskManager") {
describe("#add") {
- // TODO: How to verify the (Runnable, Promise[_]) stored in private queue?
-
it("should throw an exception if not started") {
intercept[AssertionError] {
taskManager.add {}
}
}
+      it("should throw an exception if more tasks are added than the maximum task queue size") {
+ val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 1)
+
+ taskManager.start()
+
+ // Should fail from having too many tasks added
+ intercept[RejectedExecutionException] {
+ for (i <- 1 to 500) taskManager.add {}
+ }
+ }
+
it("should return a Future[_] based on task provided") {
taskManager.start()
@@ -99,6 +108,17 @@
taskManager.stop()
}
}
+
+ it("should not block when adding more tasks than available threads") {
+ val taskManager = new TaskManager(maximumWorkers = 1)
+
+ taskManager.start()
+
+ failAfter(Span(100, Milliseconds)) {
+ taskManager.add { while (true) { Thread.sleep(1) } }
+ taskManager.add { while (true) { Thread.sleep(1) } }
+ }
+ }
}
describe("#size") {
@@ -106,18 +126,18 @@
taskManager.size should be (0)
}
- it("should be one when a there is one task in the queue") {
- val taskManager = new TaskManager(maxWorkers = 1)
+ it("should reflect queued tasks and executing tasks") {
+ val taskManager = new TaskManager(maximumWorkers = 1)
taskManager.start()
// Fill up the task manager and then add another task to the queue
taskManager.add { while (true) { Thread.sleep(1000) } }
taskManager.add { while (true) { Thread.sleep(1000) } }
- taskManager.size should be (1)
+ taskManager.size should be (2)
}
- it("should be zero when the only task is currently being executed") {
+ it("should be one if there is only one executing task and no queued ones") {
taskManager.start()
taskManager.add { while (true) { Thread.sleep(1000) } }
@@ -126,7 +146,7 @@
// the queue
while (!taskManager.isExecutingTask) Thread.sleep(1)
- taskManager.size should be (0)
+ taskManager.size should be (1)
taskManager.stop()
}
@@ -138,7 +158,7 @@
}
it("should be true where there are tasks remaining in the queue") {
- val taskManager = new TaskManager(maxWorkers = 1)
+ val taskManager = new TaskManager(maximumWorkers = 1)
taskManager.start()
// Fill up the task manager and then add another task to the queue
@@ -193,11 +213,13 @@
describe("#await") {
it("should block until all tasks are completed") {
+ val taskManager = new TaskManager(maximumWorkers = 1, maxTasks = 500)
+
taskManager.start()
// TODO: Need better way to ensure tasks are still running while
// awaiting their return
- for (x <- 1 to 50) taskManager.add { Thread.sleep(1) }
+ for (x <- 1 to 500) taskManager.add { Thread.sleep(1) }
assume(taskManager.hasTaskInQueue)
taskManager.await()
@@ -251,9 +273,11 @@
}
}
+ // TODO: Refactoring task manager to be parallelizable broke this ability
+ // so this will need to be reimplemented or abandoned
ignore("should kill the thread if interrupts failed and kill enabled") {
- val f = taskManager.add { var x = 0; while (true) { x += 1 } }
taskManager.start()
+ val f = taskManager.add { var x = 0; while (true) { x += 1 } }
// Wait for the task to start
while (!taskManager.isExecutingTask) Thread.sleep(1)
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
index 6271459..f3db572 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/KernelBootstrap.scala
@@ -73,7 +73,7 @@
// Initialize components needed elsewhere
val (commStorage, commRegistrar, commManager, interpreter,
- kernelInterpreter, kernel, sparkContext, dependencyDownloader,
+ kernel, sparkContext, dependencyDownloader,
magicLoader, responseMap) =
initializeComponents(
config = config,
@@ -81,7 +81,7 @@
actorLoader = actorLoader
)
this.sparkContext = sparkContext
- this.interpreters ++= Seq(interpreter, kernelInterpreter)
+ this.interpreters ++= Seq(interpreter)
// Initialize our handlers that take care of processing messages
initializeHandlers(
@@ -96,8 +96,7 @@
// Initialize our hooks that handle various JVM events
initializeHooks(
- interpreter = interpreter,
- kernelInterpreter = kernelInterpreter
+ interpreter = interpreter
)
logger.debug("Initializing security manager")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
index f1fbf71..9011193 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/ComponentInitialization.scala
@@ -16,7 +16,6 @@
package com.ibm.spark.boot.layer
-import java.io.File
import java.util.concurrent.ConcurrentHashMap
import akka.actor.ActorRef
@@ -33,13 +32,10 @@
import com.ibm.spark.magic.dependencies.DependencyMap
import com.ibm.spark.utils.{TaskManager, KeyValuePairUtils, LogLike}
import com.typesafe.config.Config
-import joptsimple.util.KeyValuePair
import org.apache.spark.{SparkContext, SparkConf}
import scala.collection.JavaConverters._
-import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter.JPrintWriter
import scala.util.Try
/**
@@ -56,7 +52,7 @@
*/
def initializeComponents(
config: Config, appName: String, actorLoader: ActorLoader
- ): (CommStorage, CommRegistrar, CommManager, Interpreter, Interpreter,
+ ): (CommStorage, CommRegistrar, CommManager, Interpreter,
Kernel, SparkContext, DependencyDownloader, MagicLoader,
collection.mutable.Map[String, ActorRef])
}
@@ -80,17 +76,16 @@
val (commStorage, commRegistrar, commManager) =
initializeCommObjects(actorLoader)
val interpreter = initializeInterpreter(config)
- val kernelInterpreter = initializeKernelInterpreter(config, interpreter)
val sparkContext = initializeSparkContext(
config, appName, actorLoader, interpreter)
val dependencyDownloader = initializeDependencyDownloader(config)
val magicLoader = initializeMagicLoader(
- config, interpreter, kernelInterpreter, sparkContext, dependencyDownloader)
+ config, interpreter, sparkContext, dependencyDownloader)
val kernel = initializeKernel(
- actorLoader, interpreter, kernelInterpreter, commManager, magicLoader)
+ actorLoader, interpreter, commManager, magicLoader)
val responseMap = initializeResponseMap()
- (commStorage, commRegistrar, commManager, interpreter, kernelInterpreter,
- kernel, sparkContext, dependencyDownloader, magicLoader, responseMap)
+ (commStorage, commRegistrar, commManager, interpreter, kernel,
+ sparkContext, dependencyDownloader, magicLoader, responseMap)
}
private def initializeCommObjects(actorLoader: ActorLoader) = {
@@ -127,7 +122,7 @@
with TaskManagerProducerLike
with StandardSettingsProducer {
override def newTaskManager(): TaskManager =
- new TaskManager(maxWorkers = maxInterpreterThreads)
+ new TaskManager(maximumWorkers = maxInterpreterThreads)
}
logger.debug("Starting interpreter")
@@ -136,35 +131,6 @@
interpreter
}
- private def initializeKernelInterpreter(
- config: Config, interpreter: Interpreter
- ) = {
- val interpreterArgs =
- config.getStringList("interpreter_args").asScala.toList
- val maxInterpreterThreads = config.getInt("max_interpreter_threads")
-
- // TODO: Refactor this construct to a cleaner implementation (for future
- // multi-interpreter design)
- logger.info(
- s"Constructing interpreter with $maxInterpreterThreads threads and " +
- "with arguments: " + interpreterArgs.mkString(" "))
- val kernelInterpreter = new ScalaInterpreter(interpreterArgs, Console.out)
- with TaskManagerProducerLike
- with StandardSettingsProducer
- with SparkIMainProducerLike {
- override def newTaskManager(): TaskManager =
- new TaskManager(maxWorkers = maxInterpreterThreads)
-
- override def newSparkIMain(settings: Settings, out: JPrintWriter) = {
- interpreter.asInstanceOf[ScalaInterpreter].sparkIMain
- }
- }
- logger.debug("Starting interpreter")
- kernelInterpreter.start()
-
- kernelInterpreter
- }
-
// TODO: Think of a better way to test without exposing this
protected[layer] def initializeSparkContext(
config: Config, appName: String, actorLoader: ActorLoader,
@@ -293,27 +259,24 @@
private def initializeKernel(
actorLoader: ActorLoader,
- interpreterToDoBinding: Interpreter,
- interpreterToBind: Interpreter,
+ interpreter: Interpreter,
commManager: CommManager,
magicLoader: MagicLoader
) = {
- //interpreter.doQuietly {
- val kernel = new Kernel(
- actorLoader, interpreterToBind, commManager, magicLoader
- )
- interpreterToDoBinding.bind(
- "kernel", "com.ibm.spark.kernel.api.Kernel",
- kernel, List( """@transient implicit""")
- )
- //}
+ val kernel = new Kernel(actorLoader, interpreter, commManager, magicLoader)
+ interpreter.doQuietly {
+ interpreter.bind(
+ "kernel", "com.ibm.spark.kernel.api.Kernel",
+ kernel, List( """@transient implicit""")
+ )
+ }
magicLoader.dependencyMap.setKernel(kernel)
kernel
}
private def initializeMagicLoader(
- config: Config, interpreter: Interpreter, kernelInterpreter: Interpreter, sparkContext: SparkContext,
+ config: Config, interpreter: Interpreter, sparkContext: SparkContext,
dependencyDownloader: DependencyDownloader
) = {
logger.debug("Constructing magic loader")
@@ -321,8 +284,8 @@
logger.debug("Building dependency map")
val dependencyMap = new DependencyMap()
.setInterpreter(interpreter)
+      .setKernelInterpreter(interpreter) // Deprecated: retained only for magics still mixing in IncludeKernelInterpreter
.setSparkContext(sparkContext)
- .setKernelInterpreter(kernelInterpreter)
.setDependencyDownloader(dependencyDownloader)
logger.debug("Creating BuiltinLoader")
diff --git a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
index 7e83188..2f23599 100644
--- a/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
+++ b/kernel/src/main/scala/com/ibm/spark/boot/layer/HookInitialization.scala
@@ -29,11 +29,8 @@
* Initializes and registers all hooks.
*
* @param interpreter The main interpreter
- * @param kernelInterpreter The interpreter bound to the kernel instance
*/
- def initializeHooks(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit
+ def initializeHooks(interpreter: Interpreter): Unit
}
/**
@@ -46,18 +43,13 @@
* Initializes and registers all hooks.
*
* @param interpreter The main interpreter
- * @param kernelInterpreter The interpreter bound to the kernel instance
*/
- def initializeHooks(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit = {
- registerInterruptHook(interpreter, kernelInterpreter)
+ def initializeHooks(interpreter: Interpreter): Unit = {
+ registerInterruptHook(interpreter)
registerShutdownHook()
}
- private def registerInterruptHook(
- interpreter: Interpreter, kernelInterpreter: Interpreter
- ): Unit = {
+ private def registerInterruptHook(interpreter: Interpreter): Unit = {
val self = this
import sun.misc.{Signal, SignalHandler}
@@ -73,7 +65,6 @@
if (currentTime - lastSignalReceived > MaxSignalTime) {
logger.info("Resetting code execution!")
interpreter.interrupt()
- kernelInterpreter.interrupt()
// TODO: Cancel group representing current code execution
//sparkContext.cancelJobGroup()