[LIVY-7] added autocompletion api and implementation for scala
I started an implementation of the very old feature request (LIVY-7) for code autocompletion.
This implementation works with scala 2.11 and scala 2.10.
I'd be happy if somebody could review and comment it.
As for the API: I chose a synchronous call because resolving the code options shouldn't be a very long process (and if it were it wouldn't make sense anyway).
Author: Pascal Pellmont <github@ppo2.ch>
Closes #51 from pellmont/master.
diff --git a/core/src/main/scala/org/apache/livy/msgs.scala b/core/src/main/scala/org/apache/livy/msgs.scala
index 048aa7f..bb89a29 100644
--- a/core/src/main/scala/org/apache/livy/msgs.scala
+++ b/core/src/main/scala/org/apache/livy/msgs.scala
@@ -60,4 +60,8 @@
case class ExecuteResponse(id: Int, input: Seq[String], output: Seq[String])
+case class CompletionRequest(code: String, kind: String, cursor: Int) extends Content
+
+case class CompletionResponse(candidates: List[String]) extends Content
+
case class ShutdownRequest() extends Content
diff --git a/docs/rest-api.md b/docs/rest-api.md
index 3bd8eab..f76ce09 100644
--- a/docs/rest-api.md
+++ b/docs/rest-api.md
@@ -310,6 +310,42 @@
</tr>
</table>
+### POST /sessions/{sessionId}/completion
+
+Runs a statement in a session.
+
+#### Request Body
+
+<table class="table">
+ <tr><th>Name</th><th>Description</th><th>Type</th></tr>
+ <tr>
+ <td>code</td>
+ <td>The code for which completion proposals are requested</td>
+ <td>string</td>
+ </tr>
+ <tr>
+ <td>kind</td>
+ <td>The kind of code to execute<sup><a href="#footnote2">[2]</a></sup></td>
+ <td><a href="#session-kind">code kind</a></td>
+ </tr>
+ <tr>
+ <td>cursor</td>
+ <td>cursor position to get proposals</td>
+ <td>string</td>
+ </tr>
+</table>
+
+#### Response Body
+
+<table class="table">
+ <tr><th>Name</th><th>Description</th><th>Type</th></tr>
+ <tr>
+ <td>candidates</td>
+ <td>Code completions proposals</td>
+ <td>array[string]</td>
+ </tr>
+</table>
+
### GET /batches
Returns all the active batch sessions.
diff --git a/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala b/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala
index 6d319c7..eaa023a 100644
--- a/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala
+++ b/integration-test/src/main/scala/org/apache/livy/test/framework/LivyRestClient.scala
@@ -46,6 +46,8 @@
@JsonIgnoreProperties(ignoreUnknown = true)
private case class StatementResult(id: Int, state: String, output: Map[String, Any])
+ private case class CompletionResult(candidates: Seq[String])
+
@JsonIgnoreProperties(ignoreUnknown = true)
case class StatementError(ename: String, evalue: String, stackTrace: Seq[String])
@@ -188,8 +190,36 @@
}
}
+ class Completion(code: String, kind: String, cursor: Int) {
+ val completions = {
+ val requestBody = Map("code" -> code, "cursor" -> cursor, "kind" -> kind)
+ val r = httpClient.preparePost(s"$url/completion")
+ .setBody(mapper.writeValueAsString(requestBody))
+ .execute()
+ .get()
+ assertStatusCode(r, HttpServletResponse.SC_OK)
+
+ val res = mapper.readValue(r.getResponseBodyAsStream, classOf[CompletionResult])
+ res.candidates
+ }
+
+ final def result(): Seq[String] = completions
+
+ def verifyContaining(expected: List[String]): Unit = {
+ assert(result().toSet.forall(x => expected.contains(x)))
+ }
+
+ def verifyNone(): Unit = {
+ assert(result() == List(), s"Expected no completion proposals but found $completions")
+ }
+ }
+
def run(code: String): Statement = { new Statement(code) }
+ def complete(code: String, kind: String, cursor: Int): Completion = {
+ new Completion(code, kind, cursor)
+ }
+
def runFatalStatement(code: String): Unit = {
val requestBody = Map("code" -> code)
val r = httpClient.preparePost(s"$url/statements")
diff --git a/integration-test/src/test/scala/org/apache/livy/test/InteractiveIT.scala b/integration-test/src/test/scala/org/apache/livy/test/InteractiveIT.scala
index 728fd1c..0b341e2 100644
--- a/integration-test/src/test/scala/org/apache/livy/test/InteractiveIT.scala
+++ b/integration-test/src/test/scala/org/apache/livy/test/InteractiveIT.scala
@@ -48,6 +48,9 @@
s.run("""sc.getConf.getAll.exists(_._1.startsWith("spark.__livy__."))""")
.verifyResult(".*false")
s.run("""sys.props.exists(_._1.startsWith("spark.__livy__."))""").verifyResult(".*false")
+ s.run("""val str = "str"""")
+ s.complete("str.", "scala", 4).verifyContaining(List("compare", "contains"))
+ s.complete("str2.", "scala", 5).verifyNone()
// Make sure appInfo is reported correctly.
val state = s.snapshot()
diff --git a/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala b/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
index 39009b2..e86d47d 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
@@ -28,6 +28,7 @@
import org.apache.spark.SparkConf
import org.apache.spark.repl.SparkIMain
+import org.apache.spark.repl.SparkJLineCompletion
import org.apache.livy.rsc.driver.SparkEntries
@@ -133,6 +134,11 @@
sparkIMain.interpret(code)
}
+ override protected def completeCandidates(code: String, cursor: Int) : Array[String] = {
+ val completer = new SparkJLineCompletion(sparkIMain)
+ completer.completer().complete(code, cursor).candidates.toArray
+ }
+
override protected[repl] def parseError(stdout: String): (String, Seq[String]) = {
// An example of Scala 2.10 runtime exception error message:
// java.lang.Exception: message
diff --git a/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
index 9d19ef3..884df5c 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala
@@ -22,6 +22,9 @@
import java.nio.file.{Files, Paths}
import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter.Completion.ScalaCompleter
+import scala.tools.nsc.interpreter.IMain
+import scala.tools.nsc.interpreter.JLineCompletion
import scala.tools.nsc.interpreter.JPrintWriter
import scala.tools.nsc.interpreter.Results.Result
import scala.util.control.NonFatal
@@ -117,6 +120,19 @@
sparkILoop.interpret(code)
}
+ override protected def completeCandidates(code: String, cursor: Int) : Array[String] = {
+ val completer : ScalaCompleter = {
+ try {
+ val cls = Class.forName("scala.tools.nsc.interpreter.PresentationCompilerCompleter")
+ cls.getDeclaredConstructor(classOf[IMain]).newInstance(sparkILoop.intp)
+ .asInstanceOf[ScalaCompleter]
+ } catch {
+ case e : ClassNotFoundException => new JLineCompletion(sparkILoop.intp).completer
+ }
+ }
+ completer.complete(code, cursor).candidates.toArray
+ }
+
override protected def valueOfTerm(name: String): Option[Any] = {
// IMain#valueOfTerm will always return None, so use other way instead.
Option(sparkILoop.lastRequest.lineRep.call("$result"))
diff --git a/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala
index b058d00..fab8d95 100644
--- a/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala
+++ b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala
@@ -53,6 +53,8 @@
protected def interpret(code: String): Results.Result
+ protected def completeCandidates(code: String, cursor: Int) : Array[String] = Array()
+
protected def valueOfTerm(name: String): Option[Any]
protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit
@@ -110,6 +112,10 @@
)))
}
+ override protected[repl] def complete(code: String, cursor: Int): Array[String] = {
+ completeCandidates(code, cursor)
+ }
+
private def executeMagic(magic: String, rest: String): Interpreter.ExecuteResponse = {
magic match {
case "json" => executeJsonMagic(rest)
diff --git a/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala
index 860513e..440dcd0 100644
--- a/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala
+++ b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala
@@ -17,6 +17,7 @@
package org.apache.livy.repl
+import org.json4s.JArray
import org.json4s.JObject
object Interpreter {
@@ -46,6 +47,9 @@
*/
protected[repl] def execute(code: String): ExecuteResponse
+ protected[repl] def complete(code: String, cursor: Int): Array[String]
+ = Array()
+
/** Shut down the interpreter. */
def close(): Unit
}
diff --git a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala
index 7f35982..b90076d 100644
--- a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala
+++ b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala
@@ -62,6 +62,10 @@
session.cancel(msg.id)
}
+ def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplCompleteRequest): Array[String] = {
+ session.complete(msg.code, msg.codeType, msg.cursor)
+ }
+
/**
* Return statement results. Results are sorted by statement id.
*/
diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala
index f245193..7d32fc5 100644
--- a/repl/src/main/scala/org/apache/livy/repl/Session.scala
+++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala
@@ -169,6 +169,12 @@
statementId
}
+ def complete(code: String, codeType: String, cursor: Int): Array[String] = {
+ val tpe = Kind(codeType)
+ val interp = interpreter(tpe)
+ interp.complete(code, cursor)
+ }
+
def cancel(statementId: Int): Unit = {
val statementOpt = _statements.synchronized { _statements.get(statementId) }
if (statementOpt.isEmpty) {
diff --git a/repl/src/test/scala/org/apache/livy/repl/ScalaInterpreterSpec.scala b/repl/src/test/scala/org/apache/livy/repl/ScalaInterpreterSpec.scala
index 3e9ee82..5715e4d 100644
--- a/repl/src/test/scala/org/apache/livy/repl/ScalaInterpreterSpec.scala
+++ b/repl/src/test/scala/org/apache/livy/repl/ScalaInterpreterSpec.scala
@@ -203,4 +203,14 @@
Interpreter.ExecuteSuccess(TEXT_PLAIN -> s"r: String =\n$stringWithComment"))
}
}
+
+ it should "return code completion candidates" in withInterpreter { interpreter =>
+ val code =
+ """"a".""".stripMargin
+ val actual = interpreter.complete(code, code.length)
+ actual should contain ("+")
+ actual should contain ("charAt")
+ actual should contain ("compareTo")
+ }
+
}
diff --git a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java
index 823e71b..6b7bab1 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java
@@ -202,6 +202,22 @@
}
}
+ public static class ReplCompleteRequest {
+ public final String code;
+ public final String codeType;
+ public final int cursor;
+
+ public ReplCompleteRequest(String code, String codeType, int cursor) {
+ this.code = code;
+ this.codeType = codeType;
+ this.cursor = cursor;
+ }
+
+ public ReplCompleteRequest() {
+ this(null, null, 0);
+ }
+ }
+
protected static class ReplState {
public final String state;
diff --git a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java
index 3fc3348..77d45c7 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java
@@ -302,6 +302,12 @@
return deferredCall(new BaseProtocol.GetReplJobResults(), ReplJobResults.class);
}
+ public Future<String[]> completeReplCode(String code, String codeType, int cursor)
+ throws Exception {
+ return deferredCall(new BaseProtocol.ReplCompleteRequest(code, codeType, cursor),
+ String[].class);
+ }
+
/**
* @return Return the repl state. If this's not connected to a repl session, it will return null.
*/
diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala
index 0462e80..fd7a87a 100644
--- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala
+++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala
@@ -501,6 +501,15 @@
client.get.cancelReplCode(statementId)
}
+ def completion(content: CompletionRequest): CompletionResponse = {
+ ensureRunning()
+ recordActivity()
+
+ val proposals = client.get.completeReplCode(content.code, content.kind,
+ content.cursor).get
+ CompletionResponse(proposals.toList)
+ }
+
def runJob(job: Array[Byte], jobType: String): Long = {
performOperation(job, jobType, true)
}
diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala
index 3800856..54046a1 100644
--- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala
+++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala
@@ -28,7 +28,7 @@
import org.scalatra._
import org.scalatra.servlet.FileUploadSupport
-import org.apache.livy.{ExecuteRequest, JobHandle, LivyConf, Logging}
+import org.apache.livy.{CompletionRequest, ExecuteRequest, JobHandle, LivyConf, Logging}
import org.apache.livy.client.common.HttpMessages
import org.apache.livy.client.common.HttpMessages._
import org.apache.livy.server.{AccessManager, SessionServlet}
@@ -131,6 +131,13 @@
}
}
+ jpost[CompletionRequest]("/:id/completion") { req =>
+ withModifyAccessSession { session =>
+ val compl = session.completion(req)
+ Ok(Map("candidates" -> compl.candidates))
+ }
+ }
+
post("/:id/statements/:statementId/cancel") {
withModifyAccessSession { session =>
val statementId = params("statementId")