DATAFU-171 - Improve DedupWithCombiner to support multiple orderBy / groupBy keys; add tests
diff --git a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
index adf4784..8cd1e3e 100644
--- a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
+++ b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
@@ -67,7 +67,9 @@
* those columns in the result
:return: DataFrame representing the data after the operation
"""
- jdf = _get_utils(df).dedupWithCombiner(df._jdf, group_col._jc, order_by_col._jc, desc, columns_filter, columns_filter_keep)
+ group_col = group_col if isinstance(group_col, list) else [group_col]
+ order_by_col = order_by_col if isinstance(order_by_col, list) else [order_by_col]
+ jdf = _get_utils(df).dedupWithCombiner(df._jdf, _cols_to_java_cols(group_col), _cols_to_java_cols(order_by_col), desc, columns_filter, columns_filter_keep)
return DataFrame(jdf, df.sql_ctx)
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 64c1f8b..daad469 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -19,6 +19,7 @@
package datafu.spark
import org.apache.spark.sql.{Column, DataFrame}
+import scala.language.implicitConversions
/**
* implicit class to enable easier usage e.g:
@@ -32,6 +33,7 @@
*/
object DataFrameOps {
+ implicit def columnToColumns(c: Column): Seq[Column] = Seq(c)
implicit class someDataFrameUtils(df: DataFrame) {
def dedupWithOrder(groupCol: Column, orderCols: Column*): DataFrame =
@@ -40,8 +42,8 @@
def dedupTopN(n: Int, groupCol: Column, orderCols: Column*): DataFrame =
SparkDFUtils.dedupTopN(df, n, groupCol, orderCols: _*)
- def dedupWithCombiner(groupCol: Column,
- orderByCol: Column,
+ def dedupWithCombiner(groupCol: Seq[Column],
+ orderByCol: Seq[Column],
desc: Boolean = true,
moreAggFunctions: Seq[Column] = Nil,
columnsFilter: Seq[String] = Nil,
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 240bcbc..2732460 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -25,6 +25,8 @@
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{LongType, StructType}
+import scala.language.implicitConversions
+import DataFrameOps.columnToColumns
import org.apache.spark.storage.StorageLevel
/**
@@ -51,16 +53,18 @@
}
def dedupWithCombiner(df: DataFrame,
- groupCol: Column,
- orderByCol: Column,
+ groupCol: JavaList[Column],
+ orderByCol: JavaList[Column],
desc: Boolean,
columnsFilter: JavaList[String],
columnsFilterKeep: Boolean): DataFrame = {
val columnsFilter_converted = convertJavaListToSeq(columnsFilter)
+ val groupCol_converted = convertJavaListToSeq(groupCol)
+ val orderByCol_converted = convertJavaListToSeq(orderByCol)
SparkDFUtils.dedupWithCombiner(
df = df,
- groupCol = groupCol,
- orderByCol = orderByCol,
+ groupCol = groupCol_converted,
+ orderByCol = orderByCol_converted,
desc = desc,
moreAggFunctions = Nil,
columnsFilter = columnsFilter_converted,
@@ -200,25 +204,25 @@
* @return DataFrame representing the data after the operation
*/
def dedupWithCombiner(df: DataFrame,
- groupCol: Column,
- orderByCol: Column,
+ groupCol: Seq[Column],
+ orderByCol: Seq[Column],
desc: Boolean = true,
moreAggFunctions: Seq[Column] = Nil,
columnsFilter: Seq[String] = Nil,
columnsFilterKeep: Boolean = true): DataFrame = {
val newDF =
if (columnsFilter == Nil) {
- df.withColumn("sort_by_column", orderByCol)
+ df.withColumn("sort_by_column", struct(orderByCol: _*))
} else {
if (columnsFilterKeep) {
- df.withColumn("sort_by_column", orderByCol)
+ df.withColumn("sort_by_column", struct(orderByCol: _*))
.select("sort_by_column", columnsFilter: _*)
} else {
df.select(
df.columns
.filter(colName => !columnsFilter.contains(colName))
.map(colName => new Column(colName)): _*)
- .withColumn("sort_by_column", orderByCol)
+ .withColumn("sort_by_column", struct(orderByCol: _*))
}
}
@@ -227,15 +231,18 @@
else SparkOverwriteUDAFs.minValueByKey(_: Column, _: Column)
val df2 = newDF
- .groupBy(groupCol.as("group_by_col"))
+ .groupBy(groupCol:_*)
.agg(aggFunc(expr("sort_by_column"), expr("struct(sort_by_column, *)"))
- .as("h1"),
- struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
- .as("h2"))
- .selectExpr("h2.*", "h1.*")
+ .as("h1"),
+ struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
+ .as("h2"))
+ .selectExpr("h1.*", "h2.*")
.drop("lit_placeholder_col")
.drop("sort_by_column")
- df2
+ val ns = StructType((df.schema++df2.schema.filter(s2 => !df.schema.map(_.name).contains(s2.name)))
+ .filter(s2 => columnsFilter == Nil || (columnsFilterKeep && columnsFilter.contains(s2.name)) || (!columnsFilterKeep && !columnsFilter.contains(s2.name))).toList)
+
+ df2.sparkSession.createDataFrame(df2.rdd,ns)
}
/**
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
index fa4060c..848c28e 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
@@ -67,16 +67,16 @@
.select($"col_grp", $"col_ord"))
}
- case class dedupExp(col2: String,
- col_grp: String,
- col_ord: Option[Int],
- col_str: String)
+ case class dedupExp( col_grp: String,
+ col_ord: Int,
+ col_str: String,
+ col2: String)
test("dedup2_by_int") {
val expectedByIntDf: DataFrame = sqlContext.createDataFrame(
- List(dedupExp("asd4", "b", Option(1), "asd4"),
- dedupExp("asd1", "a", Option(3), "asd3")))
+ List(dedupExp("b", 1, "asd4", "asd4"),
+ dedupExp("a", 3, "asd3", "asd1")))
val actual = inputDataFrame.dedupWithCombiner($"col_grp",
$"col_ord",
@@ -85,19 +85,57 @@
assertDataFrameEquals(expectedByIntDf, actual)
}
- case class dedupExp2(col_grp: String, col_ord: Option[Int], col_str: String)
+ case class dedupExp2(col_grp: String, col_ord: Int, col_str: String)
test("dedup2_by_string_asc") {
val actual = inputDataFrame.dedupWithCombiner($"col_grp", $"col_str", desc = false)
val expectedByStringDf: DataFrame = sqlContext.createDataFrame(
- List(dedupExp2("b", Option(1), "asd4"),
- dedupExp2("a", Option(1), "asd1")))
+ List(dedupExp2("b", 1, "asd4"),
+ dedupExp2("a", 1, "asd1")))
assertDataFrameEquals(expectedByStringDf, actual)
}
+ test("dedup2_with_filter") {
+
+ val df = sqlContext.createDataFrame(
+ Seq(("a", 2, "aa12", "a"),
+ ("a", 1, "aa11", "a"),
+ ("b", 2, "ab32", "a"),
+ ("b", 1, "ba11", "a"))
+ ).toDF("col_grp", "col_ord", "col_str", "filter_col")
+
+ // Test case 1 - filter keep false
+ val actual1 = df.dedupWithCombiner($"col_grp",
+ $"col_ord",
+ desc = false,
+ columnsFilter = List("filter_col"),
+ columnsFilterKeep = false)
+
+ val expectedFilter1: DataFrame = sqlContext.createDataFrame(
+ Seq(("a", 1, "aa11"),
+ ("b", 1, "ba11"))
+ ).toDF("col_grp", "col_ord", "col_str")
+
+ assertDataFrameNoOrderEquals(expectedFilter1, actual1)
+
+ // Test case 2 - filter keep true
+ val actual2 = df.dedupWithCombiner($"col_grp",
+ $"col_ord",
+ desc = false,
+ columnsFilter = List("col_grp", "col_ord", "filter_col"),
+ columnsFilterKeep = true)
+
+ val expectedFilter2: DataFrame = sqlContext.createDataFrame(
+ Seq(("a", 1, "a"),
+ ("b", 1, "a"))
+ ).toDF("col_grp", "col_ord", "filter_col")
+
+ assertDataFrameNoOrderEquals(expectedFilter2, actual2)
+ }
+
test("test_dedup2_by_complex_column") {
val actual = inputDataFrame.dedupWithCombiner($"col_grp",
@@ -105,17 +143,43 @@
desc = false)
val expectedComplex: DataFrame = sqlContext.createDataFrame(
- List(dedupExp2("b", Option(1), "asd4"),
- dedupExp2("a", Option(3), "asd3")))
+ List(dedupExp2("b", 1, "asd4"),
+ dedupExp2("a", 3, "asd3")))
assertDataFrameEquals(expectedComplex, actual)
}
+ test("test_dedup2_by_multi_column") {
+
+ val df = sqlContext.createDataFrame(
+ Seq(("a", "a", 1, 2, "aa12", "a"),
+ ("a", "a", 1, 1, "aa11", "a"),
+ ("a", "a", 2, 1, "aa21", "a"),
+ ("a", "b", 3, 2, "ab32", "a"),
+ ("b", "a", 1, 1, "ba11", "a"))
+ ).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str", "col_to_ignore")
+
+ val actual = df.dedupWithCombiner(List($"col_grp1", $"col_grp2"),
+ List($"col_ord1", $"col_ord2"),
+ desc = false,
+ columnsFilter = List("col_to_ignore"),
+ columnsFilterKeep = false)
+
+ val expectedMulti: DataFrame = sqlContext.createDataFrame(
+ Seq(("a", "a", 1, 1, "aa11"),
+ ("a", "b", 3, 2, "ab32"),
+ ("b", "a", 1, 1, "ba11"))
+ ).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str")
+
+ assertDataFrameNoOrderEquals(expectedMulti, actual)
+
+ }
+
case class Inner(col_grp: String, col_ord: Int)
case class expComplex(
col_grp: String,
- col_ord: Option[Int],
+ col_ord: Int,
col_str: String,
arr_col: Array[String],
struct_col: Inner,
@@ -124,29 +188,33 @@
test("test_dedup2_with_other_complex_column") {
- val actual = inputDataFrame
+ val df = inputDataFrame
.withColumn("arr_col", expr("array(col_grp, col_ord)"))
.withColumn("struct_col", expr("struct(col_grp, col_ord)"))
.withColumn("map_col", expr("map(col_grp, col_ord)"))
.withColumn("map_col_blah", expr("map(col_grp, col_ord)"))
- .dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
+
+ val schema = df.drop("map_col_blah").schema
+
+ val actual = df.dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
.drop("map_col_blah")
val expected: DataFrame = sqlContext.createDataFrame(
- List(
+ sqlContext.createDataFrame(
+ Seq(
expComplex("b",
- Option(1),
+ 1,
"asd4",
Array("b", "1"),
Inner("b", 1),
Map("b" -> 1)),
expComplex("a",
- Option(1),
+ 1,
"asd1",
Array("a", "1"),
Inner("a", 1),
Map("a" -> 1))
- ))
+ )).rdd, schema)
assertDataFrameEquals(expected, actual)
}
@@ -218,7 +286,7 @@
val expectedSchemaRangedJoinWithDedup = List(
StructField("col_grp", StringType, true),
- StructField("col_ord", IntegerType, true),
+ StructField("col_ord", IntegerType, false),
StructField("col_str", StringType, true),
StructField("start", IntegerType, true),
StructField("end", IntegerType, true),