| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.toree.kernel.interpreter.scala |
| |
| import java.io.{BufferedReader, ByteArrayOutputStream, InputStreamReader, PrintStream} |
| import java.net.{URL, URLClassLoader} |
| import java.nio.charset.Charset |
| import java.util.concurrent.ExecutionException |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.repl.{SparkCommandLine, SparkIMain, SparkJLineCompletion} |
| import org.apache.spark.sql.SQLContext |
| import org.apache.toree.global.StreamState |
| import org.apache.toree.interpreter._ |
| import org.apache.toree.interpreter.imports.printers.{WrapperConsole, WrapperSystem} |
| import org.apache.toree.kernel.api.{KernelLike, KernelOptions} |
| import org.apache.toree.utils.{MultiOutputStream, TaskManager} |
| import org.slf4j.LoggerFactory |
| |
| import scala.annotation.tailrec |
| import scala.concurrent.{Await, Future} |
| import scala.language.reflectiveCalls |
| import scala.tools.nsc.backend.JavaPlatform |
| import scala.tools.nsc.interpreter.{IR, InputStream, JPrintWriter, OutputStream} |
| import scala.tools.nsc.io.AbstractFile |
| import scala.tools.nsc.util.{ClassPath, MergedClassPath} |
| import scala.tools.nsc.{Global, Settings, io} |
| import scala.util.{Try => UtilTry} |
| |
| /** |
| * Provides Scala version-specific features needed for the interpreter. |
| */ |
| trait ScalaInterpreterSpecific { this: ScalaInterpreter => |
| private val ExecutionExceptionName = "lastException" |
| |
| var sparkIMain: SparkIMain = _ |
| protected var jLineCompleter: SparkJLineCompletion = _ |
| |
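| // Runtime classloader that exposes URLClassLoader's protected addURL as a |
| // public addJar, so jars can be appended after construction |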
| val _runtimeClassloader = |
| new URLClassLoader(Array(), _thisClassloader) { |
| def addJar(url: URL) = this.addURL(url) |
| } |
| |
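| // Creates the interpreter backend and initializes it synchronously so it |
| // is ready for bind/interpret calls as soon as it is returned |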
| protected def newSparkIMain( |
| settings: Settings, out: JPrintWriter |
| ): SparkIMain = { |
| val s = new SparkIMain(settings, out) |
| s.initializeSynchronous() |
| s |
| } |
| |
| /** |
| * Adds jars to the runtime and compile-time classpaths. Does not support |
| * directories or wildcard (star) expansion within a path. |
| * @param jars The list of jar locations |
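| * |
| * Example (illustrative; the jar path below is hypothetical): |
| * {{{ |
| * interpreter.addJars(new java.net.URL("file:///tmp/extra-library.jar")) |
| * }}} |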
| */ |
| override def addJars(jars: URL*): Unit = { |
| // Enable Scala class support |
| reinitializeSymbols() |
| |
| jars.foreach(_runtimeClassloader.addJar) |
| updateCompilerClassPath(jars : _*) |
| |
| // Refresh all of our variables |
| refreshDefinitions() |
| } |
| |
| // TODO: Need to figure out a better way to compare the representation of |
| // an annotation (contained in AnnotationInfo) with various annotations |
| // like scala.transient |
| protected def convertAnnotationsToModifiers( |
| annotationInfos: List[Global#AnnotationInfo] |
| ) = annotationInfos map { |
| case a if a.toString == "transient" => "@transient" |
| case a => |
| logger.debug(s"Ignoring unknown annotation: $a") |
| "" |
| } filterNot { |
| _.isEmpty |
| } |
| |
| protected def convertScopeToModifiers(scopeSymbol: Global#Symbol) = { |
| // Only emit a modifier when the symbol actually carries one, so no |
| // empty-string modifiers leak into the rebinding code |
| if (scopeSymbol.isImplicit) List("implicit") else Nil |
| } |
| |
| protected def buildModifierList(termNameString: String) = { |
| import scala.language.existentials |
| val termSymbol = sparkIMain.symbolOfTerm(termNameString) |
| |
| convertAnnotationsToModifiers( |
| if (termSymbol.hasAccessorFlag) termSymbol.accessed.annotations |
| else termSymbol.annotations |
| ) ++ convertScopeToModifiers(termSymbol) |
| } |
| |
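| // Rebinds every previously defined term so that existing definitions |
| // remain visible after the classpath has changed |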
| protected def refreshDefinitions(): Unit = { |
| sparkIMain.definedTerms.foreach(termName => { |
| val termNameString = termName.toString |
| val termTypeString = sparkIMain.typeOfTerm(termNameString).toLongString |
| sparkIMain.valueOfTerm(termNameString) match { |
| case Some(termValue) => |
| val modifiers = buildModifierList(termNameString) |
| logger.debug(s"Rebinding of $termNameString as " + |
| s"${modifiers.mkString(" ")} $termTypeString") |
| UtilTry(sparkIMain.beSilentDuring { |
| sparkIMain.bind( |
| termNameString, termTypeString, termValue, modifiers |
| ) |
| }) |
| case None => |
| logger.debug(s"Ignoring rebinding of $termNameString") |
| } |
| }) |
| } |
| |
| protected def reinitializeSymbols(): Unit = { |
| val global = sparkIMain.global |
| import global._ |
| new Run // Starting a new compiler run initializes symbols needed for Scala class support |
| } |
| |
| protected def updateCompilerClassPath(jars: URL*): Unit = { |
| require(!sparkIMain.global.forMSIL) // Only support JavaPlatform |
| |
| val platform = sparkIMain.global.platform.asInstanceOf[JavaPlatform] |
| |
| val newClassPath = mergeJarsIntoClassPath(platform, jars: _*) |
| logger.debug(s"newClassPath: $newClassPath") |
| |
| // TODO: Investigate better way to set this... one thought is to provide |
| // a classpath in the currentClassPath (which is merged) that can be |
| // replaced using updateClasspath, but would that work more than once? |
| val fieldSetter = platform.getClass.getMethods |
| .find(_.getName.endsWith("currentClassPath_$eq")).get |
| fieldSetter.invoke(platform, Some(newClassPath)) |
| |
| // Reload all jars specified into our compiler |
| sparkIMain.global.invalidateClassPathEntries(jars.map(_.getPath): _*) |
| } |
| |
| protected def mergeJarsIntoClassPath( |
| platform: JavaPlatform, jars: URL* |
| ): MergedClassPath[AbstractFile] = { |
| // Collect our new jars and add them to the existing set of classpaths |
| val allClassPaths = ( |
| platform.classPath |
| .asInstanceOf[MergedClassPath[AbstractFile]].entries |
| ++ |
| jars.map(url => |
| platform.classPath.context.newClassPath( |
| io.AbstractFile.getFile(url.getPath)) |
| ) |
| ).distinct |
| |
| // Combine all of our classpaths (old and new) into one merged classpath |
| new MergedClassPath( |
| allClassPaths, |
| platform.classPath.context |
| ) |
| } |
| |
| /** |
| * Binds a variable in the interpreter to a value. |
| * @param variableName The name to expose the value in the interpreter |
| * @param typeName The type of the variable, must be the fully qualified class name |
| * @param value The value of the variable binding |
| * @param modifiers Any annotations, scoping modifiers, etc. on the variable |
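| * |
| * Example (illustrative; the name and value are hypothetical): |
| * {{{ |
| * bind("count", "scala.Int", 42, List("@transient")) |
| * }}} |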
| */ |
| override def bind( |
| variableName: String, |
| typeName: String, |
| value: Any, |
| modifiers: List[String] |
| ): Unit = { |
| require(sparkIMain != null) |
| sparkIMain.bind(variableName, typeName, value, modifiers) |
| } |
| |
| /** |
| * Executes the body without printing anything to the console during execution. |
| * @param body The function to execute |
| * @tparam T The return type of body |
| * @return The return value of body |
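| * |
| * Example (illustrative): interpret code without echoing REPL output: |
| * {{{ |
| * doQuietly { sparkIMain.interpret("val hidden = 42") } |
| * }}} |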
| */ |
| override def doQuietly[T](body: => T): T = { |
| require(sparkIMain != null) |
| sparkIMain.beQuietDuring[T](body) |
| } |
| |
| /** |
| * Stops the interpreter, removing any previous internal state. |
| * @return A reference to the interpreter |
| */ |
| override def stop(): Interpreter = { |
| logger.info("Shutting down interpreter") |
| |
| // Shut down the task manager (kills the current execution) |
| if (taskManager != null) taskManager.stop() |
| taskManager = null |
| |
| // Erase our completer |
| jLineCompleter = null |
| |
| // Close the entire interpreter (loses all state) |
| if (sparkIMain != null) sparkIMain.close() |
| sparkIMain = null |
| |
| this |
| } |
| |
| /** |
| * Returns the name of the variable created from the last execution. |
| * @return Some String name if a variable was created, otherwise None |
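| * |
| * For example, after interpreting "val x = 3" this typically returns |
| * Some("x"); a bare expression yields a synthetic name such as "res0". |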
| */ |
| override def lastExecutionVariableName: Option[String] = { |
| require(sparkIMain != null) |
| |
| // TODO: Get this API method changed back to public in Apache Spark |
| val lastRequestMethod = classOf[SparkIMain].getDeclaredMethod("lastRequest") |
| lastRequestMethod.setAccessible(true) |
| |
| val request = |
| lastRequestMethod.invoke(sparkIMain).asInstanceOf[SparkIMain#Request] |
| |
| val mostRecentVariableName = sparkIMain.mostRecentVar |
| |
| request.definedNames.map(_.toString).find(_ == mostRecentVariableName) |
| } |
| |
| /** |
| * Mask the Console and System objects with our wrapper implementations |
| * and dump the Console methods into the public namespace (similar to |
| * the Predef approach). |
| * @param in The new input stream |
| * @param out The new output stream |
| * @param err The new error stream |
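| * |
| * Example (illustrative; captures interpreter output in a buffer): |
| * {{{ |
| * val buffer = new ByteArrayOutputStream() |
| * updatePrintStreams(System.in, buffer, buffer) |
| * }}} |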
| */ |
| override def updatePrintStreams( |
| in: InputStream, |
| out: OutputStream, |
| err: OutputStream |
| ): Unit = { |
| val inReader = new BufferedReader(new InputStreamReader(in)) |
| val outPrinter = new PrintStream(out) |
| val errPrinter = new PrintStream(err) |
| |
| sparkIMain.beQuietDuring { |
| sparkIMain.bind( |
| "Console", classOf[WrapperConsole].getName, |
| new WrapperConsole(inReader, outPrinter, errPrinter), |
| List("""@transient""") |
| ) |
| sparkIMain.bind( |
| "System", classOf[WrapperSystem].getName, |
| new WrapperSystem(in, out, err), |
| List("""@transient""") |
| ) |
| sparkIMain.addImports("Console._") |
| } |
| } |
| |
| /** |
| * Retrieves the contents of the variable with the provided name from the |
| * interpreter. |
| * @param variableName The name of the variable whose contents to read |
| * @return An option containing the variable contents or None if the |
| * variable does not exist |
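| * |
| * Example (illustrative; assumes a prior execution defined "myVar"): |
| * {{{ |
| * val contents: Option[AnyRef] = read("myVar") |
| * }}} |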
| */ |
| override def read(variableName: String): Option[AnyRef] = { |
| require(sparkIMain != null) |
| val variable = sparkIMain.valueOfTerm(variableName) |
| if (variable == null || variable.isEmpty) None |
| else variable |
| } |
| |
| /** |
| * Starts the interpreter, initializing any internal state. |
| * You must call init before running this function. |
| * |
| * @return A reference to the interpreter |
| */ |
| override def start(): Interpreter = { |
| require(sparkIMain == null && taskManager == null) |
| |
| taskManager = newTaskManager() |
| |
| logger.debug("Initializing task manager") |
| taskManager.start() |
| |
| sparkIMain = |
| newSparkIMain(settings, new JPrintWriter(lastResultOut, true)) |
| |
| // Note: newSparkIMain already initializes the interpreter synchronously |
| |
| logger.debug("Initializing completer") |
| jLineCompleter = new SparkJLineCompletion(sparkIMain) |
| |
| sparkIMain.beQuietDuring { |
| //logger.info("Rerouting Console and System related input and output") |
| //updatePrintStreams(System.in, multiOutputStream, multiOutputStream) |
| |
| // Adding all imports generates too many classes; clients are responsible for adding their own imports |
| logger.debug("Adding org.apache.spark.SparkContext._ to imports") |
| sparkIMain.addImports("org.apache.spark.SparkContext._") |
| } |
| |
| this |
| } |
| |
| /** |
| * Attempts to perform code completion via the <TAB> command. |
| * @param code The current cell to complete |
| * @param pos The cursor position |
| * @return The cursor position and list of possible completions |
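| * |
| * Example (illustrative; candidates depend on interpreter state): |
| * {{{ |
| * val (cursor, candidates) = completion("sc.paral", 8) |
| * // candidates may include "parallelize" |
| * }}} |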
| */ |
| override def completion(code: String, pos: Int): (Int, List[String]) = { |
| require(jLineCompleter != null) |
| |
| logger.debug(s"Attempting code completion for ${code}") |
| val regex = """[0-9a-zA-Z._]+$""".r |
| val parsedCode = (regex findAllIn code).mkString("") |
| |
| logger.debug(s"Attempting code completion for ${parsedCode}") |
| val result = jLineCompleter.completer().complete(parsedCode, pos) |
| |
| (result.cursor, result.candidates) |
| } |
| |
| protected def newSettings(args: List[String]): Settings = |
| new SparkCommandLine(args).settings |
| |
| protected def interpretAddTask(code: String, silent: Boolean): Future[IR.Result] = { |
| if (sparkIMain == null) |
| throw new IllegalStateException("Cannot interpret on a stopped interpreter") |
| |
| taskManager.add { |
| // Add a task using the given state of our streams |
| StreamState.withStreams { |
| if (silent) { |
| sparkIMain.beSilentDuring { |
| sparkIMain.interpret(code) |
| } |
| } else { |
| sparkIMain.interpret(code) |
| } |
| } |
| } |
| } |
| |
| protected def interpretMapToResultAndExecuteInfo( |
| future: Future[(Results.Result, String)] |
| ): Future[(Results.Result, Either[ExecuteOutput, ExecuteFailure])] = { |
| import scala.concurrent.ExecutionContext.Implicits.global |
| future map { |
| case (Results.Success, output) => (Results.Success, Left(output)) |
| case (Results.Incomplete, output) => (Results.Incomplete, Left(output)) |
| case (Results.Aborted, _) => (Results.Aborted, Right(null)) |
| case (Results.Error, output) => |
| // Look up the exception captured by the interpreter, if any |
| val lastException = sparkIMain.valueOfTerm(ExecutionExceptionName) |
| ( |
| Results.Error, |
| Right( |
| interpretConstructExecuteError(lastException, output) |
| ) |
| ) |
| } |
| } |
| |
| protected def interpretConstructExecuteError( |
| value: Option[AnyRef], |
| output: String |
| ) = value match { |
| // Runtime error |
| case Some(e) if e != null => |
| val ex = e.asInstanceOf[Throwable] |
| // Clear the stored exception so it is not reported again on later failures |
| sparkIMain.directBind( |
| ExecutionExceptionName, |
| classOf[Throwable].getName, |
| null |
| ) |
| ExecuteError( |
| ex.getClass.getName, |
| ex.getLocalizedMessage, |
| ex.getStackTrace.map(_.toString).toList |
| ) |
| // Compile time error, need to check internal reporter |
| case _ => |
| if (sparkIMain.isReportingErrors) |
| // TODO: This wrapper is not needed when just getting compile |
| // error that we are not parsing... maybe have it be purely |
| // output and have the error check this? |
| ExecuteError( |
| "Compile Error", output, List() |
| ) |
| else |
| ExecuteError("Unknown", "Unable to retrieve error!", List()) |
| } |
| } |