/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.livy.repl

import java.io.ByteArrayOutputStream

import scala.tools.nsc.interpreter.Results

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.json4s.DefaultFormats
import org.json4s.Extraction
import org.json4s.JsonAST._
import org.json4s.JsonDSL._

import org.apache.livy.Logging
import org.apache.livy.rsc.driver.SparkEntries

object AbstractSparkInterpreter {
  private[repl] val KEEP_NEWLINE_REGEX = """(?<=\n)""".r
  private val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r
}

abstract class AbstractSparkInterpreter extends Interpreter with Logging {
  import AbstractSparkInterpreter._

  private implicit def formats = DefaultFormats

  protected val outputStream = new ByteArrayOutputStream()

  protected var entries: SparkEntries = _

  def sparkEntries(): SparkEntries = entries

  final def kind: String = "spark"

  protected def isStarted(): Boolean

  protected def interpret(code: String): Results.Result

  protected def completeCandidates(code: String, cursor: Int): Array[String] = Array()

  protected def valueOfTerm(name: String): Option[Any]

  protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit

  protected def conf: SparkConf

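  // Invoked once the interpreter has started: create the SparkEntries and bind the Spark
  // entry points into the user session ("spark" and "sc" when SparkSession is available,
  // otherwise "sc" and "sqlContext"), along with the common implicit imports.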
  protected def postStart(): Unit = {
    entries = new SparkEntries(conf)

    if (isSparkSessionPresent()) {
      bind("spark",
        sparkEntries.sparkSession().getClass.getCanonicalName,
        sparkEntries.sparkSession(),
        List("""@transient"""))
      bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient"""))

      execute("import org.apache.spark.SparkContext._")
      execute("import spark.implicits._")
      execute("import spark.sql")
      execute("import org.apache.spark.sql.functions._")
    } else {
      bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient"""))

      val sqlContext = Option(sparkEntries.hivectx()).getOrElse(sparkEntries.sqlctx())
      bind("sqlContext", sqlContext.getClass.getCanonicalName, sqlContext, List("""@transient"""))

      execute("import org.apache.spark.SparkContext._")
      execute("import sqlContext.implicits._")
      execute("import sqlContext.sql")
      execute("import org.apache.spark.sql.functions._")
    }
  }

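  // Stop the Spark contexts held by SparkEntries when the interpreter is shut down.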
  override def close(): Unit = {
    if (entries != null) {
      entries.stop()
      entries = null
    }
  }

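  // SparkSession only exists from Spark 2.0 onwards, so probe for the class to decide
  // which entry points to bind.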
  private def isSparkSessionPresent(): Boolean = {
    try {
      Class.forName("org.apache.spark.sql.SparkSession")
      true
    } catch {
      case _: ClassNotFoundException | _: NoClassDefFoundError => false
    }
  }

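  // Execute a (possibly multi-line) block of code, line by line, restoring the thread's
  // context classloader afterwards.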
  override protected[repl] def execute(code: String): Interpreter.ExecuteResponse =
    restoreContextClassLoader {
      require(isStarted())

      executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject(
        (TEXT_PLAIN, JString(""))
      )))
    }

  override protected[repl] def complete(code: String, cursor: Int): Array[String] = {
    completeCandidates(code, cursor)
  }

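  // Dispatch Livy magic commands of the form "%json <name>" or "%table <name>".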
  private def executeMagic(magic: String, rest: String): Interpreter.ExecuteResponse = {
    magic match {
      case "json" => executeJsonMagic(rest)
      case "table" => executeTableMagic(rest)
      case _ =>
        Interpreter.ExecuteError("UnknownMagic", f"Unknown magic command $magic")
    }
  }

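  // Render the named session value as JSON; RDDs are truncated to their first 10 elements
  // before being decomposed.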
  private def executeJsonMagic(name: String): Interpreter.ExecuteResponse = {
    try {
      val value = valueOfTerm(name) match {
        case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10)
        case Some(obj) => obj
        case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist")
      }

      Interpreter.ExecuteSuccess(JObject(
        (APPLICATION_JSON, Extraction.decompose(value))
      ))
    } catch {
      case _: Throwable =>
        Interpreter.ExecuteError("ValueError", "Failed to convert value into a JSON value")
    }
  }

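  // Map a JSON value onto a Livy table column type; arrays or objects whose elements have
  // mixed types raise TypesDoNotMatch.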
  private class TypesDoNotMatch extends Exception

  private def convertTableType(value: JValue): String = {
    value match {
      case (JNothing | JNull) => "NULL_TYPE"
      case JBool(_) => "BOOLEAN_TYPE"
      case JString(_) => "STRING_TYPE"
      case JInt(_) => "BIGINT_TYPE"
      case JDouble(_) => "DOUBLE_TYPE"
      case JDecimal(_) => "DECIMAL_TYPE"
      case JArray(arr) =>
        if (allSameType(arr.iterator)) {
          "ARRAY_TYPE"
        } else {
          throw new TypesDoNotMatch
        }
      case JObject(obj) =>
        if (allSameType(obj.iterator.map(_._2))) {
          "MAP_TYPE"
        } else {
          throw new TypesDoNotMatch
        }
    }
  }

  private def allSameType(values: Iterator[JValue]): Boolean = {
    if (values.hasNext) {
      val type_name = convertTableType(values.next())
      values.forall { case value => type_name.equals(convertTableType(value)) }
    } else {
      true
    }
  }

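  // Render the named session value as a Livy table; RDDs are truncated to their first
  // 10 elements before being decomposed.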
  private def executeTableMagic(name: String): Interpreter.ExecuteResponse = {
    val value = valueOfTerm(name) match {
      case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10)
      case Some(obj) => obj
      case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist")
    }

    extractTableFromJValue(Extraction.decompose(value))
  }

  private def extractTableFromJValue(value: JValue): Interpreter.ExecuteResponse = {
    // Convert the value into JSON and map it to a table.
    val rows: List[JValue] = value match {
      case JArray(arr) => arr
      case _ => List(value)
    }

    try {
      val headers = scala.collection.mutable.Map[String, Map[String, String]]()

      val data = rows.map { case row =>
        val cols: List[JField] = row match {
          case JArray(arr: List[JValue]) =>
            arr.zipWithIndex.map { case (v, index) => JField(index.toString, v) }
          case JObject(obj) => obj.sortBy(_._1)
          case value: JValue => List(JField("0", value))
        }

        cols.map { case (k, v) =>
          val typeName = convertTableType(v)

          headers.get(k) match {
            case Some(header) =>
              if (header.get("type").get != typeName) {
                throw new TypesDoNotMatch
              }
            case None =>
              headers.put(k, Map(
                "type" -> typeName,
                "name" -> k
              ))
          }

          v
        }
      }

      Interpreter.ExecuteSuccess(
        APPLICATION_LIVY_TABLE_JSON -> (
          ("headers" -> headers.toSeq.sortBy(_._1).map(_._2)) ~ ("data" -> data)
        ))
    } catch {
      case _: TypesDoNotMatch =>
        Interpreter.ExecuteError("TypeError", "table rows have different types")
    }
  }

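  // Execute the statements one at a time. A statement that comes back as ExecuteIncomplete
  // is re-executed with the following line appended, so expressions split across lines are
  // still evaluated as a unit.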
  private def executeLines(
      lines: List[String],
      resultFromLastLine: Interpreter.ExecuteResponse): Interpreter.ExecuteResponse = {
    lines match {
      case Nil => resultFromLastLine
      case head :: tail =>
        val result = executeLine(head)

        result match {
          case Interpreter.ExecuteIncomplete() =>
            tail match {
              case Nil =>
                // ExecuteIncomplete could be caused by an actual incomplete statement
                // (e.g. "sc.") or by a statement that contains only comments.
                // To distinguish them, reissue the same statement wrapped in { }.
                // If it is an actual incomplete statement, the interpreter will return an error.
                // If it is just comments, the interpreter will return success.
                executeLine(s"{\n$head\n}") match {
                  case Interpreter.ExecuteIncomplete() | Interpreter.ExecuteError(_, _, _) =>
                    // Return the original error so users won't get a confusing error message.
                    result
                  case _ => resultFromLastLine
                }

              case next :: nextTail =>
                executeLines(head + "\n" + next :: nextTail, resultFromLastLine)
            }
          case Interpreter.ExecuteError(_, _, _) =>
            result

          case _ =>
            executeLines(tail, result)
        }
    }
  }

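  // Run a single statement: magic commands are dispatched directly, everything else goes
  // through the Scala interpreter with stdout captured into outputStream.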
  private def executeLine(code: String): Interpreter.ExecuteResponse = {
    code match {
      case MAGIC_REGEX(magic, rest) =>
        executeMagic(magic, rest)
      case _ =>
        scala.Console.withOut(outputStream) {
          interpret(code) match {
            case Results.Success =>
              Interpreter.ExecuteSuccess(
                TEXT_PLAIN -> readStdout()
              )
            case Results.Incomplete => Interpreter.ExecuteIncomplete()
            case Results.Error =>
              val (ename, traceback) = parseError(readStdout())
              Interpreter.ExecuteError("Error", ename, traceback)
          }
        }
    }
  }

  protected[repl] def parseError(stdout: String): (String, Seq[String]) = {
    // An example of a Scala compile error message:
    //   <console>:27: error: type mismatch;
    //    found   : Int
    //    required: Boolean
    // An example of a Scala runtime exception error message:
    //   java.lang.RuntimeException: message
    //     at .error(<console>:11)
    //     ... 32 elided
    // Return the first line as ename and the following lines as the traceback.
    val lines = KEEP_NEWLINE_REGEX.split(stdout)
    val ename = lines.headOption.map(_.trim).getOrElse("unknown error")
    val traceback = lines.tail

    (ename, traceback)
  }

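  // Save the thread's context classloader, run fn, and restore the classloader afterwards,
  // since interpreting code may replace it.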
  protected def restoreContextClassLoader[T](fn: => T): T = {
    val currentClassLoader = Thread.currentThread().getContextClassLoader()
    try {
      fn
    } finally {
      Thread.currentThread().setContextClassLoader(currentClassLoader)
    }
  }

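  // Drain and reset the buffer that captures the interpreter's stdout.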
  private def readStdout() = {
    val output = outputStream.toString("UTF-8").trim
    outputStream.reset()

    output
  }
}