// blob: 616c68fb0ee48e90395aeb5136e591f7fd7cd487 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package org.apache.toree.kernel.interpreter.scala
import java.io.{File, InputStream, OutputStream}
import java.net.{URLClassLoader, URL}
import org.apache.toree.interpreter.Results.Result
import org.apache.toree.interpreter._
import org.apache.toree.utils.TaskManager
import org.apache.spark.SparkConf
import org.apache.toree.kernel.interpreter.scala._
import org.apache.spark.repl.SparkIMain
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
import scala.concurrent.Future
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.{JPrintWriter, IR}
import scala.tools.nsc.util.ClassPath
/**
 * Unit tests for [[ScalaInterpreter]].
 *
 * The interpreter under test is wired with mocked collaborators
 * (SparkIMain, TaskManager, Settings) via [[StubbedStartInterpreter]],
 * and individual interpret pipeline stages are stubbed out with the
 * Stubbed* traits below so each test exercises a single behavior.
 */
class ScalaInterpreterSpec extends FunSpec
  with Matchers with MockitoSugar with BeforeAndAfter
{
  // Fresh fixtures created in `before` and released in `after`
  private var interpreter: ScalaInterpreter = _
  private var interpreterNoPrintStreams: ScalaInterpreter = _
  private var mockSparkIMain: SparkIMain = _
  private var mockTaskManager: TaskManager = _
  private var mockSettings: Settings = _

  /** No-ops print stream updates so tests never touch real stdio. */
  trait StubbedUpdatePrintStreams extends Interpreter {
    override def updatePrintStreams(
      in: InputStream,
      out: OutputStream,
      err: OutputStream
    ): Unit = {}
  }

  /** Collapses the recursive line-by-line interpret into one interpretLine call. */
  trait SingleLineInterpretLineRec extends StubbedStartInterpreter {
    override protected def interpretRec(lines: List[String], silent: Boolean, results: (Result, Either[ExecuteOutput, ExecuteFailure])): (Result, Either[ExecuteOutput, ExecuteFailure]) =
      interpretLine(lines.mkString("\n"))
  }

  /** Replaces task submission with a mocked future. */
  trait StubbedInterpretAddTask extends StubbedStartInterpreter {
    override protected def interpretAddTask(code: String, silent: Boolean) =
      mock[Future[IR.Result]]
  }

  /** Replaces the IR.Result -> custom result mapping stage with a mock. */
  trait StubbedInterpretMapToCustomResult extends StubbedStartInterpreter {
    override protected def interpretMapToCustomResult(future: Future[IR.Result]) =
      mock[Future[Results.Result with Product with Serializable]]
  }

  /** Replaces the result/output pairing stage with a mock. */
  trait StubbedInterpretMapToResultAndOutput extends StubbedStartInterpreter {
    override protected def interpretMapToResultAndOutput(future: Future[Results.Result]) =
      mock[Future[(Results.Result, String)]]
  }

  /** Replaces the result/execute-info pairing stage with a mock. */
  trait StubbedInterpretMapToResultAndExecuteInfo extends StubbedStartInterpreter {
    override protected def interpretMapToResultAndExecuteInfo(future: Future[(Results.Result, String)]) =
      mock[Future[(
        Results.Result with Product with Serializable,
        Either[ExecuteOutput, ExecuteFailure] with Product with Serializable
      )]]
  }

  /** Replaces execute-error construction with a mocked ExecuteError. */
  trait StubbedInterpretConstructExecuteError extends StubbedStartInterpreter {
    override protected def interpretConstructExecuteError(value: Option[AnyRef], output: String) =
      mock[ExecuteError]
  }

  /**
   * ScalaInterpreter whose factory methods return the shared mocks so
   * start() never spins up a real SparkIMain/TaskManager.
   */
  class StubbedStartInterpreter
    extends ScalaInterpreter
  {
    override def newSparkIMain(settings: Settings, out: JPrintWriter): SparkIMain = mockSparkIMain
    override def newTaskManager(): TaskManager = mockTaskManager
    override def newSettings(args: List[String]): Settings = mockSettings

    // Stubbed out (not testing this)
    override protected def updateCompilerClassPath(jars: URL*): Unit = {}
    override protected def reinitializeSymbols(): Unit = {}
    override protected def refreshDefinitions(): Unit = {}
  }

  before {
    mockSparkIMain = mock[SparkIMain]

    mockTaskManager = mock[TaskManager]

    // The classpath setter must be a silent no-op since Settings is mocked
    val mockSettingsClasspath = mock[Settings#PathSetting]
    doNothing().when(mockSettingsClasspath).value_=(any[Settings#PathSetting#T])

    mockSettings = mock[Settings]
    doReturn(mockSettingsClasspath).when(mockSettings).classpath
    doNothing().when(mockSettings).embeddedDefaults(any[ClassLoader])

    interpreter = new StubbedStartInterpreter

    interpreterNoPrintStreams =
      new StubbedStartInterpreter with StubbedUpdatePrintStreams
  }

  after {
    // Release every fixture created in `before` so no state (or the mocks
    // an interpreter closes over) leaks between tests
    mockSparkIMain = null
    mockTaskManager = null
    mockSettings = null
    interpreter = null
    interpreterNoPrintStreams = null
  }

  describe("ScalaInterpreter") {
    describe("#addJars") {
      it("should add each jar URL to the runtime classloader") {
        // Needed to access runtimeClassloader method
        import scala.language.reflectiveCalls

        // Create a new interpreter exposing the internal runtime classloader
        val itInterpreter = new StubbedStartInterpreter {
          // Expose the runtime classloader
          def runtimeClassloader = _runtimeClassloader
        }

        val url = new URL("file://expected")
        itInterpreter.addJars(url)

        itInterpreter.runtimeClassloader.getURLs should contain (url)
      }

      it("should add each jar URL to the interpreter classpath") {
        // Only verifies addJars does not blow up; classpath propagation is
        // stubbed out in StubbedStartInterpreter
        val url = new URL("file://expected")
        interpreter.addJars(url)
      }
    }

    describe("#buildClasspath") {
      it("should return classpath based on classloader hierarchy") {
        // Needed to access runtimeClassloader method
        import scala.language.reflectiveCalls

        // Create a new interpreter exposing the internal runtime classloader
        val itInterpreter = new StubbedStartInterpreter

        val parentUrls = Array(
          new URL("file:/some/dir/a.jar"),
          new URL("file:/some/dir/b.jar"),
          new URL("file:/some/dir/c.jar")
        )

        val theParentClassloader = new URLClassLoader(parentUrls, null)

        val urls = Array(
          new URL("file:/some/dir/1.jar"),
          new URL("file:/some/dir/2.jar"),
          new URL("file:/some/dir/3.jar")
        )

        val theClassloader = new URLClassLoader(urls, theParentClassloader)

        // Parent URLs are expected to come first in the joined classpath
        val expected = ClassPath.join((parentUrls ++ urls).map(_.toString) :_*)

        itInterpreter.buildClasspath(theClassloader) should be(expected)
      }
    }

    describe("#interrupt") {
      it("should fail a require if the interpreter is not started") {
        intercept[IllegalArgumentException] {
          interpreter.interrupt()
        }
      }

      it("should call restart() on the task manager") {
        interpreterNoPrintStreams.start()

        interpreterNoPrintStreams.interrupt()

        verify(mockTaskManager).restart()
      }
    }

    // TODO: Provide testing for the helper functions that return various
    //       mapped futures -- this was too difficult for me to figure out
    //       in a short amount of time
    describe("#interpret") {
      it("should fail if not started") {
        intercept[IllegalArgumentException] {
          interpreter.interpret("val x = 3")
        }
      }

      it("should add a new task to the task manager") {
        var taskManagerAddCalled = false
        val itInterpreter =
          new StubbedStartInterpreter
            with SingleLineInterpretLineRec
            with StubbedUpdatePrintStreams
            //with StubbedInterpretAddTask
            with StubbedInterpretMapToCustomResult
            with StubbedInterpretMapToResultAndOutput
            with StubbedInterpretMapToResultAndExecuteInfo
            with StubbedInterpretConstructExecuteError
            with TaskManagerProducerLike
          {
            // Must override this way since cannot figure out the signature
            // to verify this as a mock
            override def newTaskManager(): TaskManager = new TaskManager {
              override def add[T](taskFunction: => T): Future[T] = {
                taskManagerAddCalled = true
                mock[TaskManager].add(taskFunction)
              }
            }
          }

        itInterpreter.start()

        itInterpreter.interpret("val x = 3")

        taskManagerAddCalled should be (true)
      }
    }

    describe("#start") {
      it("should initialize the task manager") {
        interpreterNoPrintStreams.start()

        verify(mockTaskManager).start()
      }

      // TODO: Figure out how to trigger sparkIMain.beQuietDuring { ... }
      /*it("should add an import for SparkContext._") {
        interpreterNoPrintStreams.start()

        verify(mockSparkIMain).addImports("org.apache.spark.SparkContext._")
      }*/
    }

    describe("#stop") {
      describe("when interpreter already started") {
        it("should stop the task manager") {
          interpreterNoPrintStreams.start()
          interpreterNoPrintStreams.stop()

          verify(mockTaskManager).stop()
        }

        it("should stop the SparkIMain") {
          interpreterNoPrintStreams.start()
          interpreterNoPrintStreams.stop()

          verify(mockSparkIMain).close()
        }
      }
    }

    describe("#updatePrintStreams") {
      // TODO: Figure out how to trigger sparkIMain.beQuietDuring { ... }
    }

    //  describe("#classServerUri") {
    //    it("should fail a require if the interpreter is not started") {
    //      intercept[IllegalArgumentException] {
    //        interpreter.classServerURI
    //      }
    //    }

    // TODO: Find better way to test this
    //    it("should invoke the underlying SparkIMain implementation") {
    //       Using hack to access private class
    //      val securityManagerClass =
    //        java.lang.Class.forName("org.apache.spark.SecurityManager")
    //      val httpServerClass =
    //        java.lang.Class.forName("org.apache.spark.HttpServer")
    //      val httpServerConstructor = httpServerClass.getDeclaredConstructor(
    //        classOf[SparkConf], classOf[File], securityManagerClass, classOf[Int],
    //        classOf[String])
    //      val httpServer = httpServerConstructor.newInstance(
    //        null, null, null, 0: java.lang.Integer, "")
    //
    //      // Return the server instance (cannot mock a private class)
    //      // NOTE: Can mock the class through reflection, but cannot verify
    //      //       a method was called on it since treated as type Any
    //      //val mockHttpServer = org.mockito.Mockito.mock(httpServerClass)
    //      doAnswer(new Answer[String] {
    //        override def answer(invocation: InvocationOnMock): String = {
    //          val exceptionClass =
    //            java.lang.Class.forName("org.apache.spark.ServerStateException")
    //          val exception = exceptionClass
    //            .getConstructor(classOf[String])
    //            .newInstance("")
    //            .asInstanceOf[Exception]
    //          throw exception
    //        }
    //      }
    //      ).when(mockSparkIMain)
    //      interpreterNoPrintStreams.start()

    //       Not going to dig so deeply that we actually start a web server for
    //       this to work... just throwing this specific exception proves that
    //       we have called the uri method of the server
    //      try {
    //        interpreterNoPrintStreams.classServerURI
    //        fail()
    //      } catch {
    //        // Have to catch this way because... of course... the exception is
    //        // also private
    //        case ex: Throwable =>
    //          ex.getClass.getName should be ("org.apache.spark.ServerStateException")
    //      }
    //    }
    //  }

    describe("#read") {
      it("should fail a require if the interpreter is not started") {
        intercept[IllegalArgumentException] {
          interpreter.read("someVariable")
        }
      }

      it("should execute the underlying valueOfTerm method") {
        interpreter.start()
        interpreter.read("someVariable")

        verify(mockSparkIMain).valueOfTerm(anyString())
      }
    }

    describe("#doQuietly") {
      it("should fail a require if the interpreter is not started") {
        intercept[IllegalArgumentException] {
          interpreter.doQuietly {}
        }
      }

      // TODO: Figure out how to verify sparkIMain.beQuietDuring { ... }
      /*it("should invoke the underlying SparkIMain implementation") {
        interpreterNoPrintStreams.start()
        interpreterNoPrintStreams.doQuietly {}

        verify(mockSparkIMain).beQuietDuring(any[IR.Result])
      }*/
    }

    describe("#bind") {
      it("should fail a require if the interpreter is not started") {
        intercept[IllegalArgumentException] {
          interpreter.bind("", "", null, null)
        }
      }

      it("should invoke the underlying SparkIMain implementation") {
        interpreterNoPrintStreams.start()
        interpreterNoPrintStreams.bind("", "", null, null)

        verify(mockSparkIMain).bind(
          anyString(), anyString(), any[Any], any[List[String]])
      }
    }

    describe("#truncateResult") {
      it("should truncate result of res result") {
        // Results that match
        interpreter.truncateResult("res7: Int = 38") should be("38")
        interpreter.truncateResult("res7: Int = 38",true) should be("Int = 38")
        interpreter.truncateResult("res4: String = \nVector(1\n, 2\n)") should be ("Vector(1\n, 2\n)")
        interpreter.truncateResult("res4: String = \nVector(1\n, 2\n)",true) should be ("String = Vector(1\n, 2\n)")
        interpreter.truncateResult("res123") should be("")
        interpreter.truncateResult("res1") should be("")

        // Results that don't match
        interpreter.truncateResult("resabc: Int = 38") should be("")
      }

      it("should truncate res results that have tuple values") {
        interpreter.truncateResult("res0: (String, Int) = (hello,1)") should
          be("(hello,1)")
      }

      it("should truncate res results that have parameterized types") {
        interpreter.truncateResult(
          "res0: Class[_ <: (String, Int)] = class scala.Tuple2"
        ) should be("class scala.Tuple2")
      }
    }
  }
}