[TOREE-466] Properly recognize higher order functions
Closes #153
diff --git a/scala-interpreter/src/main/scala-2.11/org/apache/toree/kernel/interpreter/scala/ScalaInterpreterSpecific.scala b/scala-interpreter/src/main/scala-2.11/org/apache/toree/kernel/interpreter/scala/ScalaInterpreterSpecific.scala
index 1bde367..902bb1d 100644
--- a/scala-interpreter/src/main/scala-2.11/org/apache/toree/kernel/interpreter/scala/ScalaInterpreterSpecific.scala
+++ b/scala-interpreter/src/main/scala-2.11/org/apache/toree/kernel/interpreter/scala/ScalaInterpreterSpecific.scala
@@ -247,10 +247,19 @@
override def read(variableName: String): Option[AnyRef] = {
require(iMain != null)
- iMain.eval(variableName) match {
- case null => None
- case str: String if str.isEmpty => None
- case res => Some(res)
+ try {
+ iMain.eval(variableName) match {
+ case null => None
+ case str: String if str.isEmpty => None
+ case res => Some(res)
+ }
+ } catch {
+ // if any error returns None
+ case e: Throwable => {
+ logger.debug(s"Error reading variable name: ${variableName}", e)
+ clearLastException()
+ None
+ }
}
}
diff --git a/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaInterpreter.scala b/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaInterpreter.scala
index d8a937a..cf763d8 100644
--- a/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaInterpreter.scala
+++ b/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaInterpreter.scala
@@ -204,6 +204,11 @@
val text = new StringBuilder
interpreterOutput.split("\n").foreach {
+
+ case HigherOrderFunction(name, func, funcType) =>
+
+ definitions.append(s"$name: $func$funcType").append("\n")
+
case NamedResult(name, vtype, value) if read(name).nonEmpty =>
val result = read(name)
@@ -409,6 +414,7 @@
object ScalaInterpreter {
+ val HigherOrderFunction: Regex = """(\w+):\s+(\(\s*.*=>\s*\w+\))(\w+)\s*.*""".r
val NamedResult: Regex = """(\w+):\s+([^=]+)\s+=\s*(.*)""".r
val Definition: Regex = """defined\s+(\w+)\s+(.+)""".r
val Import: Regex = """import\s+([\w\.,\{\}\s]+)""".r
diff --git a/scala-interpreter/src/test/scala-2.11/scala/ScalaInterpreterSpec.scala b/scala-interpreter/src/test/scala-2.11/scala/ScalaInterpreterSpec.scala
index fd0cf20..9d024f1 100644
--- a/scala-interpreter/src/test/scala-2.11/scala/ScalaInterpreterSpec.scala
+++ b/scala-interpreter/src/test/scala-2.11/scala/ScalaInterpreterSpec.scala
@@ -438,6 +438,21 @@
interpreter.stop()
}
+ it("should properly handle higher order functions") {
+ interpreter.start()
+ doReturn("myFunction: (x: Int, foo: Int => Int)Int").when(mockSparkIMain).eval("myFunction")
+
+ // Results that match
+ interpreter.prepareResult("myFunction: (x: Int, foo: Int => Int)Int") should be(
+ (None,
+ Some("myFunction: (x: Int, foo: Int => Int)Int\n"),
+ None))
+
+
+ interpreter.stop()
+
+ }
+
it("should truncate res results that have tuple values") {
//val t: (String, Int) = ("hello",1) ==> t: (String, Int) = (hello,1)
interpreter.start()