/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package datafu.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.junit.Assert
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.slf4j.LoggerFactory
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.datafu.types.SparkOverwriteUDAFs
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types._

@RunWith(classOf[JUnitRunner])
class UdafTests extends FunSuite with DataFrameSuiteBase {

import spark.implicits._

/**
* taken from https://github.com/holdenk/spark-testing-base/issues/234#issuecomment-390150835
*
* Works around a problem with Hive in Spark 2.3.0 when using spark-testing-base
*/
override def conf: SparkConf =
super.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
val logger = LoggerFactory.getLogger(this.getClass)
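
// Shared test input: a group key, an ordering int and a string payload,
// four rows across two groups ("a" x 3, "b" x 1), cached for reuse by the tests below.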
val inputSchema = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true)
)
lazy val inputRDD = sc.parallelize(
Seq(Row("a", 1, "asd1"),
Row("a", 2, "asd2"),
Row("a", 3, "asd3"),
Row("b", 1, "asd4")))
lazy val df =
sqlContext.createDataFrame(inputRDD, StructType(inputSchema)).cache
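
// Case classes used to build expected DataFrames with map-typed result columns.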
case class mapExp(map_col: Map[String, Int])
case class mapArrExp(map_col: Map[String, Array[String]])
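
// MultiSet aggregates a string column into a map of value -> occurrence count.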
test("test multiset simple") {
val ms = new SparkUDAFs.MultiSet()
val expected: DataFrame =
sqlContext.createDataFrame(List(mapExp(Map("b" -> 1, "a" -> 3))))
assertDataFrameEquals(expected, df.agg(ms($"col_grp").as("map_col")))
}
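
// MultiArraySet does the same for array columns: every array element is counted
// across all rows. This instance is shared by the next few tests.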
val mas = new SparkUDAFs.MultiArraySet[String]()
test("test multiarrayset simple") {
assertDataFrameEquals(
sqlContext.createDataFrame(List(mapExp(Map("tre" -> 1, "asd" -> 2)))),
spark
.sql("select array('asd','tre','asd') arr")
.groupBy()
.agg(mas($"arr").as("map_col"))
)
}
test("test multiarrayset all nulls") {
// edge case: every inserted array is null, so the resulting map must be empty
spark.sql("drop table if exists mas_table")
spark.sql("create table mas_table (arr array<string>)")
spark.sql(
"insert overwrite table mas_table select case when 1=2 then array('asd') end " +
"from (select 1)z")
spark.sql(
"insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
spark.sql(
"insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
spark.sql(
"insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
spark.sql(
"insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
val expected = sqlContext.createDataFrame(List(mapExp(Map())))
val actual =
spark.table("mas_table").groupBy().agg(mas($"arr").as("map_col"))
assertDataFrameEquals(expected, actual)
}
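
// maxKeys limits how many distinct keys the resulting map may hold;
// this test checks caps of 2 and 1 against the same table.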
test("test multiarrayset max keys") {
// max keys case
spark.sql("drop table if exists mas_table2")
spark.sql("create table mas_table2 (arr array<string>)")
spark.sql(
"insert overwrite table mas_table2 select array('asd','dsa') from (select 1)z")
spark.sql(
"insert into table mas_table2 select array('asd','abc') from (select 1)z")
spark.sql(
"insert into table mas_table2 select array('asd') from (select 1)z")
spark.sql(
"insert into table mas_table2 select array('asd') from (select 1)z")
spark.sql(
"insert into table mas_table2 select array('asd') from (select 1)z")
spark.sql(
"insert into table mas_table2 select array('asd2') from (select 1)z")
val mas2 = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
assertDataFrameEquals(
sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
assertDataFrameEquals(
sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
}
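
// Stress test: 100,000 rows, each with one shared element ('asd') and one unique element.
// The capped map must hold 3 keys, with 'asd' the most frequent at N occurrences.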
test("test multiarrayset big input") {
val N = 100000
val blah = spark.sparkContext
.parallelize(1 to N, 20)
.toDF("num")
.selectExpr("array('asd',concat('dsa',num)) as arr")
val mas = new SparkUDAFs.MultiArraySet[String](maxKeys = 3)
val time1 = System.currentTimeMillis()
val mp = blah
.groupBy()
.agg(mas($"arr"))
.collect()
.map(_.getMap[String, Int](0))
.head
Assert.assertEquals(3, mp.size)
Assert.assertEquals("asd", mp.maxBy(_._2)._1)
Assert.assertEquals(N, mp.maxBy(_._2)._2)
val time2 = System.currentTimeMillis()
logger.info("time taken: " + (time2 - time1) / 1000 + " secs")
}
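
// MapSetMerge merges map<string, array<string>> values from all rows into a single map.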
test("test mapmerge") {
val mapMerge = new SparkUDAFs.MapSetMerge()
spark.sql("drop table if exists mapmerge_table")
spark.sql("create table mapmerge_table (c map<string, array<string>>)")
spark.sql(
"insert overwrite table mapmerge_table select map('k1', array('v1')) from (select 1) z")
spark.sql(
"insert into table mapmerge_table select map('k1', array('v1')) from (select 1) z")
spark.sql(
"insert into table mapmerge_table select map('k2', array('v3')) from (select 1) z")
assertDataFrameEquals(
sqlContext.createDataFrame(
List(mapArrExp(Map("k1" -> Array("v1"), "k2" -> Array("v3"))))),
spark.table("mapmerge_table").groupBy().agg(mapMerge($"c").as("map_col"))
)
}
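
// minValueByKey returns, per group, the value of the second column taken from
// the row with the minimal value of the first (key) column.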
test("minKeyValue") {
assertDataFrameEquals(
sqlContext.createDataFrame(List(("b", "asd4"), ("a", "asd1"))),
df.groupBy($"col_grp".as("_1"))
.agg(SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").as("_2"))
)
}
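
// The same aggregation applied as a window function: every row gets the col_str
// of the row with the minimal col_ord within its col_grp partition.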
case class Exp4(colGrp: String, colOrd: Int, colStr: String, asd: String)
val minKeyValueWindowExpectedSchema = List(
StructField("col_grp", StringType, true),
StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true),
StructField("asd", StringType, true)
)
test("minKeyValue window") {
assertDataFrameEquals(
sqlContext.createDataFrame(
sc.parallelize(
Seq(
Row("b", 1, "asd4", "asd4"),
Row("a", 1, "asd1", "asd1"),
Row("a", 2, "asd2", "asd1"),
Row("a", 3, "asd3", "asd1")
)),
StructType(minKeyValueWindowExpectedSchema)
),
df.withColumn("asd",
SparkOverwriteUDAFs
.minValueByKey($"col_ord", $"col_str")
.over(Window.partitionBy("col_grp")))
)
}
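
// CountDistinctUpTo counts distinct values per group but stops at the given cap:
// group "a" has 4 distinct values, reported as 3 with a cap of 3 and as 4 with a cap of 6.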
case class Exp5(col_grp: String, col_ord: Option[Int])
test("countDistinctUpTo") {
import datafu.spark.SparkUDAFs.CountDistinctUpTo
val countDistinctUpTo3 = new CountDistinctUpTo(3)
val countDistinctUpTo6 = new CountDistinctUpTo(6)
val inputDF = sqlContext.createDataFrame(
List(
Exp5("b", Option(1)),
Exp5("a", Option(1)),
Exp5("a", Option(2)),
Exp5("a", Option(3)),
Exp5("a", Option(4))
))
val results3DF = sqlContext.createDataFrame(
List(
Exp5("b", Option(1)),
Exp5("a", Option(3))
))
val results6DF = sqlContext.createDataFrame(
List(
Exp5("b", Option(1)),
Exp5("a", Option(4))
))
inputDF
.groupBy("col_grp")
.agg(countDistinctUpTo3($"col_ord").as("col_ord"))
.show
assertDataFrameEquals(results3DF,
inputDF
.groupBy("col_grp")
.agg(countDistinctUpTo3($"col_ord").as("col_ord")))
assertDataFrameEquals(results6DF,
inputDF
.groupBy("col_grp")
.agg(countDistinctUpTo6($"col_ord").as("col_ord")))
}
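
// collectLimitedList behaves like collect_list but keeps at most maxSize elements per group.
// Group num1 = x has x rows, so the collected list size must be min(x, maxSize).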
test("test_limited_collect_list") {
val maxSize = 10
val rows = (1 to 30).flatMap(x => (1 to x).map(n => (x, n, "some-string " + n))).toDF("num1", "num2", "str")
rows.show(10, false)
import org.apache.spark.sql.functions._
val result = rows.groupBy("num1").agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
.withColumn("list_size", expr("size(list)"))
result.show(10, false)
// show up to 10 rows per num1 group via dedupRandomN (for visual inspection only; not asserted)
SparkDFUtils.dedupRandomN(rows, $"num1", 10).show(10, false)
val rows_different = result.filter(s"case when num1 > $maxSize then $maxSize else num1 end != list_size")
Assert.assertEquals(0, rows_different.count())
}
}