/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package datafu.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
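
// Tests for the implicit DataFrame operations added by DataFrameOps,
// using spark-testing-base's DataFrameSuiteBase for DataFrame assertions.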
@RunWith(classOf[JUnitRunner])
class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
import DataFrameOps._
import spark.implicits._
val inputSchema = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true)
)
val dedupSchema = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false)
)
lazy val inputRDD = sc.parallelize(
Seq(Row("a", 1, "asd1"),
Row("a", 2, "asd2"),
Row("a", 3, "asd3"),
Row("b", 1, "asd4")))
lazy val inputDataFrame =
sqlContext.createDataFrame(inputRDD, StructType(inputSchema)).cache
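
  // dedupWithOrder keeps exactly one row per col_grp: the row that ranks
  // first under the given ordering (the highest col_ord here).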
test("dedup") {
val expected: DataFrame =
sqlContext.createDataFrame(sc.parallelize(Seq(Row("b", 1), Row("a", 3))),
StructType(dedupSchema))
assertDataFrameEquals(expected,
inputDataFrame
.dedupWithOrder($"col_grp", $"col_ord".desc)
.select($"col_grp", $"col_ord"))
}
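
  // dedupWithCombiner also keeps one row per group (by default the one with
  // the highest ordering value); moreAggFunctions appends extra aggregates,
  // here min(col_str), as additional result columns.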
  case class dedupExp(col_grp: String,
                      col_ord: Int,
                      col_str: String,
                      col2: String)
test("dedup2_by_int") {
val expectedByIntDf: DataFrame = sqlContext.createDataFrame(
List(dedupExp("b", 1, "asd4", "asd4"),
dedupExp("a", 3, "asd3", "asd1")))
val actual = inputDataFrame.dedupWithCombiner($"col_grp",
$"col_ord",
moreAggFunctions = Seq(min($"col_str")))
assertDataFrameEquals(expectedByIntDf, actual)
}
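
  // With desc = false the row with the smallest ordering value wins,
  // so each group keeps its lexicographically first col_str.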
case class dedupExp2(col_grp: String, col_ord: Int, col_str: String)
test("dedup2_by_string_asc") {
val actual = inputDataFrame.dedupWithCombiner($"col_grp", $"col_str", desc = false)
val expectedByStringDf: DataFrame = sqlContext.createDataFrame(
List(dedupExp2("b", 1, "asd4"),
dedupExp2("a", 1, "asd1")))
assertDataFrameEquals(expectedByStringDf, actual)
}
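
  // columnsFilter controls which columns survive the dedup:
  // columnsFilterKeep = false drops the listed columns from the result,
  // columnsFilterKeep = true keeps only the listed ones.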
test("dedup2_with_filter") {
val df = sqlContext.createDataFrame(
Seq(("a", 2, "aa12", "a"),
("a", 1, "aa11", "a"),
("b", 2, "ab32", "a"),
("b", 1, "ba11", "a"))
).toDF("col_grp", "col_ord", "col_str", "filter_col")
// Test case 1 - filter keep false
val actual1 = df.dedupWithCombiner($"col_grp",
$"col_ord",
desc = false,
columnsFilter = List("filter_col"),
columnsFilterKeep = false)
val expectedFilter1: DataFrame = sqlContext.createDataFrame(
Seq(("a", 1, "aa11"),
("b", 1, "ba11"))
).toDF("col_grp", "col_ord", "col_str")
assertDataFrameNoOrderEquals(expectedFilter1, actual1)
// Test case 2 - filter keep true
val actual2 = df.dedupWithCombiner($"col_grp",
$"col_ord",
desc = false,
columnsFilter = List("col_grp", "col_ord", "filter_col"),
columnsFilterKeep = true)
val expectedFilter2: DataFrame = sqlContext.createDataFrame(
Seq(("a", 1, "a"),
("b", 1, "a"))
).toDF("col_grp", "col_ord", "filter_col")
assertDataFrameNoOrderEquals(expectedFilter2, actual2)
}
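
  // The ordering may be an arbitrary expression: negating col_ord and
  // sorting ascending keeps the row with the highest col_ord.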
test("test_dedup2_by_complex_column") {
val actual = inputDataFrame.dedupWithCombiner($"col_grp",
expr("cast(concat('-',col_ord) as int)"),
desc = false)
val expectedComplex: DataFrame = sqlContext.createDataFrame(
List(dedupExp2("b", 1, "asd4"),
dedupExp2("a", 3, "asd3")))
assertDataFrameEquals(expectedComplex, actual)
}
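
  // Grouping and ordering also accept lists of columns, deduping by the
  // pair (col_grp1, col_grp2) ordered by (col_ord1, col_ord2).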
test("test_dedup2_by_multi_column") {
val df = sqlContext.createDataFrame(
Seq(("a", "a", 1, 2, "aa12", "a"),
("a", "a", 1, 1, "aa11", "a"),
("a", "a", 2, 1, "aa21", "a"),
("a", "b", 3, 2, "ab32", "a"),
("b", "a", 1, 1, "ba11", "a"))
).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str", "col_to_ignore")
val actual = df.dedupWithCombiner(List($"col_grp1", $"col_grp2"),
List($"col_ord1", $"col_ord2"),
desc = false,
columnsFilter = List("col_to_ignore"),
columnsFilterKeep = false)
val expectedMulti: DataFrame = sqlContext.createDataFrame(
Seq(("a", "a", 1, 1, "aa11"),
("a", "b", 3, 2, "ab32"),
("b", "a", 1, 1, "ba11"))
).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str")
assertDataFrameNoOrderEquals(expectedMulti, actual)
}
case class Inner(col_grp: String, col_ord: Int)
case class expComplex(
col_grp: String,
col_ord: Int,
col_str: String,
arr_col: Array[String],
struct_col: Inner,
map_col: Map[String, Int]
)
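
  // Array, struct and map columns should pass through the dedup intact.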
test("test_dedup2_with_other_complex_column") {
val df = inputDataFrame
.withColumn("arr_col", expr("array(col_grp, col_ord)"))
.withColumn("struct_col", expr("struct(col_grp, col_ord)"))
.withColumn("map_col", expr("map(col_grp, col_ord)"))
.withColumn("map_col_blah", expr("map(col_grp, col_ord)"))
val schema = df.drop("map_col_blah").schema
val actual = df.dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
.drop("map_col_blah")
val expected: DataFrame = sqlContext.createDataFrame(
sqlContext.createDataFrame(
Seq(
expComplex("b",
1,
"asd4",
Array("b", "1"),
Inner("b", 1),
Map("b" -> 1)),
expComplex("a",
1,
"asd1",
Array("a", "1"),
Inner("a", 1),
Map("a" -> 1))
)).rdd, schema)
assertDataFrameEquals(expected, actual)
}
val dedupTopNExpectedSchema = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false)
)
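
  // dedupTopN keeps at most N rows per group: the top 2 by col_ord here,
  // so "a" retains two of its three rows while "b" keeps its single row.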
test("test_dedup_top_n") {
val actual = inputDataFrame
.dedupTopN(2, $"col_grp", $"col_ord".desc)
.select($"col_grp", $"col_ord")
val expected = sqlContext.createDataFrame(
sc.parallelize(Seq(Row("b", 1), Row("a", 3), Row("a", 2))),
StructType(dedupTopNExpectedSchema))
assertDataFrameEquals(expected, actual)
}
val schema2 = List(
StructField("start", IntegerType, false),
StructField("end", IntegerType, false),
StructField("desc", StringType, true)
)
val expectedSchemaRangedJoin = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true),
StructField("start", IntegerType, true),
StructField("end", IntegerType, true),
StructField("desc", StringType, true)
)
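
  // joinWithRange joins each point (col_ord) to every range row with
  // start <= col_ord <= end, producing one output row per match.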
test("join_with_range") {
val joinWithRangeDataFrame =
sqlContext.createDataFrame(sc.parallelize(
Seq(Row(1, 2, "asd1"),
Row(1, 4, "asd2"),
Row(3, 5, "asd3"),
Row(3, 10, "asd4"))),
StructType(schema2))
val expected = sqlContext.createDataFrame(
sc.parallelize(
Seq(
Row("b", 1, "asd4", 1, 2, "asd1"),
Row("a", 2, "asd2", 1, 2, "asd1"),
Row("a", 1, "asd1", 1, 2, "asd1"),
Row("b", 1, "asd4", 1, 4, "asd2"),
Row("a", 3, "asd3", 1, 4, "asd2"),
Row("a", 2, "asd2", 1, 4, "asd2"),
Row("a", 1, "asd1", 1, 4, "asd2"),
Row("a", 3, "asd3", 3, 5, "asd3"),
Row("a", 3, "asd3", 3, 10, "asd4")
)),
StructType(expectedSchemaRangedJoin)
)
val actual = inputDataFrame.joinWithRange("col_ord",
joinWithRangeDataFrame,
"start",
"end")
assertDataFrameEquals(expected, actual)
}
val expectedSchemaRangedJoinWithDedup = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true),
StructField("start", IntegerType, true),
StructField("end", IntegerType, true),
StructField("desc", StringType, true)
)
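
  // joinWithRangeAndDedup collapses the multiple matches of joinWithRange:
  // judging by the expected rows, each point keeps only its narrowest
  // matching range, and within a range only the largest point survives
  // (("a", 1) loses to ("a", 2) in the range [1, 2]).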
test("join_with_range_and_dedup") {
val df = sc
.parallelize(
List(("a", 1, "asd1"),
("a", 2, "asd2"),
("a", 3, "asd3"),
("b", 1, "asd4")))
.toDF("col_grp", "col_ord", "col_str")
val dfr = sc
.parallelize(
List((1, 2, "asd1"), (1, 4, "asd2"), (3, 5, "asd3"), (3, 10, "asd4")))
.toDF("start", "end", "desc")
val expected = sqlContext.createDataFrame(
sc.parallelize(
Seq(
Row("b", 1, "asd4", 1, 2, "asd1"),
Row("a", 3, "asd3", 3, 5, "asd3"),
Row("a", 2, "asd2", 1, 2, "asd1")
)),
StructType(expectedSchemaRangedJoinWithDedup)
)
val actual = df.joinWithRangeAndDedup("col_ord", dfr, "start", "end")
assertDataFrameEquals(expected, actual)
}
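
  // broadcastJoinSkewed should produce the same rows as a plain join for
  // inner, left and right joins; the numeric argument is (presumably) the
  // number of most-frequent keys that get the broadcast treatment.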
test("randomJoinSkewedTests") {
def makeSkew(i: Int): Int = {
if (i < 200) 10 else 50
}
val skewed = sqlContext.createDataFrame((1 to 500).map(i => ((Math.random * makeSkew(i)).toInt, s"str$i")))
.toDF("key", "val_skewed")
val notSkewed = sqlContext
.createDataFrame((1 to 500).map(i => ((Math.random * 50).toInt, s"str$i")))
.toDF("key", "val")
val expected = notSkewed.join(skewed, Seq("key")).sort($"key", $"val", $"val_skewed")
val actual1 = notSkewed.broadcastJoinSkewed(skewed, "key", 1).sort($"key", $"val", $"val_skewed")
assertDataFrameEquals(expected, actual1)
val leftExpected = notSkewed.join(skewed, Seq("key"), "left").sort($"key", $"val", $"val_skewed")
val actual2 = notSkewed.broadcastJoinSkewed(skewed, "key", 1, joinType = "left").sort($"key", $"val", $"val_skewed")
assertDataFrameEquals(leftExpected, actual2)
val rightExpected = notSkewed.join(skewed, Seq("key"), "right").sort($"key", $"val", $"val_skewed")
val actual3 = notSkewed.broadcastJoinSkewed(skewed, "key", 2, joinType = "right").sort($"key", $"val", $"val_skewed")
assertDataFrameEquals(rightExpected, actual3)
}
  // because of the nulls in the expected data, an explicit schema
  // (via a case class) is needed
case class expJoinSkewed(str1: String,
str2: String,
str3: String,
str4: String)
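
  // joinSkewed joins on an arbitrary condition expression; the numeric
  // argument (3 here) controls how many ways the hot keys are split, and
  // "left_outer" is supported in addition to the default inner join.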
test("joinSkewed") {
val skewedList = List(("1", "a"),
("1", "b"),
("1", "c"),
("1", "d"),
("1", "e"),
("2", "k"),
("0", "k"))
val skewed =
sqlContext.createDataFrame(skewedList).toDF("key", "val_skewed")
val notSkewed = sqlContext
.createDataFrame((1 to 10).map(i => (i.toString, s"str$i")))
.toDF("key", "val")
val actual1 =
skewed.as("a").joinSkewed(notSkewed.as("b"), expr("a.key = b.key"), 3)
val expected1 = sqlContext
.createDataFrame(
List(
("1", "a", "1", "str1"),
("1", "b", "1", "str1"),
("1", "c", "1", "str1"),
("1", "d", "1", "str1"),
("1", "e", "1", "str1"),
("2", "k", "2", "str2")
))
.toDF("key", "val_skewed", "key", "val")
// assertDataFrameEquals cares about order but we don't
assertDataFrameEquals(expected1, actual1.sort($"val_skewed"))
val actual2 = skewed
.as("a")
.joinSkewed(notSkewed.as("b"), expr("a.key = b.key"), 3, "left_outer")
val expected2 = sqlContext
.createDataFrame(
List(
expJoinSkewed("1", "a", "1", "str1"),
expJoinSkewed("1", "b", "1", "str1"),
expJoinSkewed("1", "c", "1", "str1"),
expJoinSkewed("1", "d", "1", "str1"),
expJoinSkewed("1", "e", "1", "str1"),
expJoinSkewed("2", "k", "2", "str2"),
expJoinSkewed("0", "k", null, null)
))
.toDF("key", "val_skewed", "key", "val")
// assertDataFrameEquals cares about order but we don't
assertDataFrameEquals(expected2, actual2.sort($"val_skewed"))
}
val changedSchema = List(
StructField("fld1", StringType, true),
StructField("fld2", IntegerType, false),
StructField("fld3", StringType, true)
)
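
  // changeSchema renames the columns positionally, leaving data and types
  // untouched, so the expected frame reuses inputRDD with the new names.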
test("test_changeSchema") {
val actual = inputDataFrame.changeSchema("fld1", "fld2", "fld3")
val expected =
sqlContext.createDataFrame(inputRDD, StructType(changedSchema))
assertDataFrameEquals(expected, actual)
}
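
  // flatten replaces a struct column with the struct's fields as
  // top-level columns.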
test("test_flatten") {
val input = inputDataFrame
.withColumn("struct_col", expr("struct(col_grp, col_ord)"))
.select("struct_col")
val expected: DataFrame = inputDataFrame.select("col_grp", "col_ord")
val actual = input.flatten("struct_col")
assertDataFrameEquals(expected, actual)
}
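
  // explodeArray widens an array column into one column per element
  // (token0..token4), sized by the longest array in the data; shorter
  // or null arrays are padded with nulls.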
test("test_explode_array") {
val input = spark.createDataFrame(Seq(
(0.0, Seq("Hi", "I heard", "about", "Spark")),
(0.0, Seq("I wish", "Java", "could use", "case", "classes")),
(1.0, Seq("Logistic", "regression", "models", "are neat")),
(0.0, Seq()),
(1.0, null)
)).toDF("label", "sentence_arr")
val actual = input.explodeArray($"sentence_arr", "token")
    val expected = spark.createDataFrame(Seq(
      (0.0, Seq("Hi", "I heard", "about", "Spark"), "Hi", "I heard", "about", "Spark", null),
      (0.0, Seq("I wish", "Java", "could use", "case", "classes"), "I wish", "Java", "could use", "case", "classes"),
      (1.0, Seq("Logistic", "regression", "models", "are neat"), "Logistic", "regression", "models", "are neat", null),
      (0.0, Seq(), null, null, null, null, null),
      (1.0, null, null, null, null, null, null)
    )).toDF("label", "sentence_arr", "token0", "token1", "token2", "token3", "token4")
assertDataFrameEquals(expected, actual)
}
}