[TOREE-462] Update dataframe magic to support showing nulls and arrays
Closes #148
diff --git a/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala b/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala
index ba9bade..ba75daa 100644
--- a/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala
+++ b/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala
@@ -21,17 +21,17 @@
import org.apache.toree.plugins.Plugin
import play.api.libs.json.{JsObject, Json}
-import scala.util.{Failure, Try}
+import scala.util.Try
import org.apache.toree.plugins.annotations.Init
+import DataFrameConverter._
+
class DataFrameConverter extends Plugin with LogLike {
@Init def init() = {
register(this)
}
- def convert(
- df: Dataset[Row], outputType: String, limit: Int = 10
- ): Try[String] = {
+ def convert(df: Dataset[Row], outputType: String, limit: Int = 10): Try[String] = {
Try(
outputType.toLowerCase() match {
case "html" =>
@@ -45,14 +45,13 @@
}
private def convertToHtml(df: Dataset[Row], limit: Int = 10): String = {
- import df.sqlContext.implicits._
val columnFields = df.schema.fieldNames.map(columnName => {
s"<th>${columnName}</th>"
}).reduce(_ + _)
val columns = s"<tr>${columnFields}</tr>"
val rows = df.rdd.map(row => {
val fieldValues = row.toSeq.map(field => {
- s"<td>${field.toString}</td>"
+ s"<td>${fieldToString(field)}</td>"
}).reduce(_ + _)
s"<tr>${fieldValues}</tr>"
}).take(limit).reduce(_ + _)
@@ -60,10 +59,9 @@
}
private def convertToJson(df: Dataset[Row], limit: Int = 10): String = {
- import df.sqlContext.implicits._
val schema = Json.toJson(df.schema.fieldNames)
val transformed = df.rdd.map(row =>
- row.toSeq.map(_.toString).toArray)
+ row.toSeq.map(fieldToString).toArray)
val rows = transformed.take(limit)
JsObject(Seq(
"columns" -> schema,
@@ -72,11 +70,22 @@
}
private def convertToCsv(df: Dataset[Row], limit: Int = 10): String = {
- import df.sqlContext.implicits._
val headers = df.schema.fieldNames.reduce(_ + "," + _)
val rows = df.rdd.map(row => {
- row.toSeq.map(field => field.toString).reduce(_ + "," + _)
+ row.toSeq.map(fieldToString).reduce(_ + "," + _)
}).take(limit).reduce(_ + "\n" + _)
s"${headers}\n${rows}"
}
-}
\ No newline at end of file
+
+}
+
object DataFrameConverter {

  /** Renders a single DataFrame cell value for textual output (html/json/csv).
    *
    * Unlike a bare `toString`, this:
    *   - renders SQL NULL (a Scala `null`) as the literal string "null"
    *   - renders sequence-typed columns as "[elem1, elem2, ...]"
    *     (Spark surfaces array columns as a `Seq`, e.g. `WrappedArray`)
    *   - also handles plain `Array[_]` values, which do NOT match
    *     `Seq[_]` in Scala and would otherwise print as "[Ljava..." junk
    *   - recurses into elements, so nested collections and nulls inside
    *     collections are rendered with the same rules
    *
    * @param any raw field value taken from a `Row`; may be null
    * @return a human-readable string representation of the value
    */
  def fieldToString(any: Any): String =
    any match {
      case null => "null"
      case seq: Seq[_] => seq.map(fieldToString).mkString("[", ", ", "]")
      case arr: Array[_] => arr.map(fieldToString).mkString("[", ", ", "]")
      case _ => any.toString
    }

}
diff --git a/kernel/src/test/scala/org/apache/toree/utils/DataFrameConverterSpec.scala b/kernel/src/test/scala/org/apache/toree/utils/DataFrameConverterSpec.scala
index d481a92..601a31c 100644
--- a/kernel/src/test/scala/org/apache/toree/utils/DataFrameConverterSpec.scala
+++ b/kernel/src/test/scala/org/apache/toree/utils/DataFrameConverterSpec.scala
@@ -17,30 +17,34 @@
package org.apache.toree.utils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
-import org.mockito.Matchers._
+import org.apache.spark.sql.{DataFrame, Row}
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
-import org.scalatest.{FunSpec, Matchers}
+import org.scalatest.{BeforeAndAfterAll, FunSpec, Matchers}
import play.api.libs.json.{JsArray, JsString, Json}
+import test.utils.SparkContextProvider
-class DataFrameConverterSpec extends FunSpec with MockitoSugar with Matchers {
+import scala.collection.mutable
+
+class DataFrameConverterSpec extends FunSpec with MockitoSugar with Matchers with BeforeAndAfterAll {
+
+ lazy val spark = SparkContextProvider.sparkContext
+
+ override protected def afterAll(): Unit = {
+ spark.stop()
+ super.afterAll()
+ }
+
val dataFrameConverter: DataFrameConverter = new DataFrameConverter
val mockDataFrame = mock[DataFrame]
- val mockRdd = mock[RDD[Any]]
+ val mockRdd = spark.parallelize(Seq(Row(new mutable.WrappedArray.ofRef(Array("test1", "test2")), 2, null)))
val mockStruct = mock[StructType]
val columns = Seq("foo", "bar").toArray
- val rowsOfArrays = Array( Array("a", "b"), Array("c", "d") )
- val rowsOfStrings = Array("test1","test2")
- val rowsOfString = Array("test1")
doReturn(mockStruct).when(mockDataFrame).schema
doReturn(columns).when(mockStruct).fieldNames
doReturn(mockRdd).when(mockDataFrame).rdd
- doReturn(mockRdd).when(mockRdd).map(any())(any())
- doReturn(rowsOfArrays).when(mockRdd).take(anyInt())
describe("DataFrameConverter") {
describe("#convert") {
@@ -49,24 +53,24 @@
val jsValue = Json.parse(someJson.get)
jsValue \ "columns" should be (JsArray(Seq(JsString("foo"), JsString("bar"))))
jsValue \ "rows" should be (JsArray(Seq(
- JsArray(Seq(JsString("a"), JsString("b"))),
- JsArray(Seq(JsString("c"), JsString("d")))))
- )
+ JsArray(Seq(JsString("[test1, test2]"), JsString("2"), JsString("null")))
+ )))
}
it("should convert to csv") {
- doReturn(rowsOfStrings).when(mockRdd).take(anyInt())
val csv = dataFrameConverter.convert(mockDataFrame, "csv").get
- val values = csv.split("\n").map(_.split(","))
- values(0) should contain allOf ("foo","bar")
+ val values = csv.split("\n")
+ values(0) shouldBe "foo,bar"
+ values(1) shouldBe "[test1, test2],2,null"
}
it("should convert to html") {
- doReturn(rowsOfStrings).when(mockRdd).take(anyInt())
val html = dataFrameConverter.convert(mockDataFrame, "html").get
html.contains("<th>foo</th>") should be(true)
html.contains("<th>bar</th>") should be(true)
+ html.contains("<td>[test1, test2]</td>") should be(true)
+ html.contains("<td>2</td>") should be(true)
+ html.contains("<td>null</td>") should be(true)
}
it("should convert limit the selection") {
- doReturn(rowsOfString).when(mockRdd).take(1)
val someLimited = dataFrameConverter.convert(mockDataFrame, "csv", 1)
val limitedLines = someLimited.get.split("\n")
limitedLines.length should be(2)