DATAFU-171 - Improve dedupWithCombiner to support multiple orderBy / groupBy keys; add tests
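
The Scala and Python APIs now accept either a single column or a list of columns for the group-by and order-by arguments. Below is a minimal usage sketch of the multi-column form; the SparkSession setup and data values are illustrative, mirroring the test_dedup2_by_multi_column test added further down.

    import org.apache.spark.sql.SparkSession
    import datafu.spark.DataFrameOps._

    val spark = SparkSession.builder().master("local[*]").appName("dedup-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("a", "a", 1, 2, "aa12"),
      ("a", "a", 1, 1, "aa11"),
      ("a", "b", 3, 2, "ab32"),
      ("b", "a", 1, 1, "ba11")
    ).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str")

    // Keep one row per (col_grp1, col_grp2), picking the smallest (col_ord1, col_ord2).
    val deduped = df.dedupWithCombiner(
      groupCol = List($"col_grp1", $"col_grp2"),
      orderByCol = List($"col_ord1", $"col_ord2"),
      desc = false)

    deduped.show()
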
diff --git a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
index adf4784..8cd1e3e 100644
--- a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
+++ b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
@@ -67,7 +67,9 @@
 *                          those columns in the result
     :return: DataFrame representing the data after the operation
     """
-    jdf = _get_utils(df).dedupWithCombiner(df._jdf, group_col._jc, order_by_col._jc, desc, columns_filter, columns_filter_keep)
+    group_col = group_col if isinstance(group_col, list) else [group_col]
+    order_by_col = order_by_col if isinstance(order_by_col, list) else [order_by_col]
+    jdf = _get_utils(df).dedupWithCombiner(df._jdf, _cols_to_java_cols(group_col), _cols_to_java_cols(order_by_col), desc, columns_filter, columns_filter_keep)
     return DataFrame(jdf, df.sql_ctx)
 
 
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 64c1f8b..daad469 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -19,6 +19,7 @@
 package datafu.spark
 
 import org.apache.spark.sql.{Column, DataFrame}
+import scala.language.implicitConversions
 
 /**
  * implicit class to enable easier usage e.g:
@@ -32,6 +33,7 @@
  */
 object DataFrameOps {
 
+  implicit def columnToColumns(c: Column): Seq[Column] = Seq(c)
   implicit class someDataFrameUtils(df: DataFrame) {
 
     def dedupWithOrder(groupCol: Column, orderCols: Column*): DataFrame =
@@ -40,8 +42,8 @@
     def dedupTopN(n: Int, groupCol: Column, orderCols: Column*): DataFrame =
       SparkDFUtils.dedupTopN(df, n, groupCol, orderCols: _*)
 
-    def dedupWithCombiner(groupCol: Column,
-                          orderByCol: Column,
+    def dedupWithCombiner(groupCol: Seq[Column],
+                          orderByCol: Seq[Column],
                           desc: Boolean = true,
                           moreAggFunctions: Seq[Column] = Nil,
                           columnsFilter: Seq[String] = Nil,
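
The columnToColumns implicit above lifts a single Column to Seq(column), which is what keeps existing single-column call sites compiling against the new Seq[Column] signature. A hedged sketch, assuming datafu.spark.DataFrameOps._ is imported at the call site (as the existing tests do) and illustrative data:

    import org.apache.spark.sql.SparkSession
    import datafu.spark.DataFrameOps._  // brings columnToColumns and the implicit class into scope

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 3, "asd3"), ("a", 1, "asd1"), ("b", 1, "asd4"))
      .toDF("col_grp", "col_ord", "col_str")

    // A single Column still compiles: columnToColumns lifts it to Seq(column).
    val single = df.dedupWithCombiner($"col_grp", $"col_ord", desc = false)

    // The same call written explicitly against the new Seq[Column] signature.
    val multi = df.dedupWithCombiner(List($"col_grp"), List($"col_ord"), desc = false)
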
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 240bcbc..2732460 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -25,6 +25,8 @@
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{LongType, StructType}
+import scala.language.implicitConversions
+import DataFrameOps.columnToColumns
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -51,16 +53,18 @@
   }
 
   def dedupWithCombiner(df: DataFrame,
-             groupCol: Column,
-             orderByCol: Column,
+             groupCol: JavaList[Column],
+             orderByCol: JavaList[Column],
              desc: Boolean,
              columnsFilter: JavaList[String],
              columnsFilterKeep: Boolean): DataFrame = {
     val columnsFilter_converted = convertJavaListToSeq(columnsFilter)
+    val groupCol_converted = convertJavaListToSeq(groupCol)
+    val orderByCol_converted = convertJavaListToSeq(orderByCol)
     SparkDFUtils.dedupWithCombiner(
       df = df,
-      groupCol = groupCol,
-      orderByCol = orderByCol,
+      groupCol = groupCol_converted,
+      orderByCol = orderByCol_converted,
       desc = desc,
       moreAggFunctions = Nil,
       columnsFilter = columnsFilter_converted,
@@ -200,25 +204,25 @@
     * @return DataFrame representing the data after the operation
     */
   def dedupWithCombiner(df: DataFrame,
-                        groupCol: Column,
-                        orderByCol: Column,
+                        groupCol: Seq[Column],
+                        orderByCol: Seq[Column],
                         desc: Boolean = true,
                         moreAggFunctions: Seq[Column] = Nil,
                         columnsFilter: Seq[String] = Nil,
                         columnsFilterKeep: Boolean = true): DataFrame = {
     val newDF =
       if (columnsFilter == Nil) {
-        df.withColumn("sort_by_column", orderByCol)
+        df.withColumn("sort_by_column", struct(orderByCol: _*))
       } else {
         if (columnsFilterKeep) {
-          df.withColumn("sort_by_column", orderByCol)
+          df.withColumn("sort_by_column", struct(orderByCol: _*))
             .select("sort_by_column", columnsFilter: _*)
         } else {
           df.select(
             df.columns
               .filter(colName => !columnsFilter.contains(colName))
               .map(colName => new Column(colName)): _*)
-            .withColumn("sort_by_column", orderByCol)
+            .withColumn("sort_by_column", struct(orderByCol: _*))
         }
       }
 
@@ -227,15 +231,18 @@
       else SparkOverwriteUDAFs.minValueByKey(_: Column, _: Column)
 
     val df2 = newDF
-      .groupBy(groupCol.as("group_by_col"))
+      .groupBy(groupCol: _*)
       .agg(aggFunc(expr("sort_by_column"), expr("struct(sort_by_column, *)"))
-             .as("h1"),
-           struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
-             .as("h2"))
-      .selectExpr("h2.*", "h1.*")
+        .as("h1"),
+        struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
+          .as("h2"))
+      .selectExpr("h1.*", "h2.*")
       .drop("lit_placeholder_col")
       .drop("sort_by_column")
-    df2
+    val ns = StructType((df.schema ++ df2.schema.filter(s2 => !df.schema.map(_.name).contains(s2.name)))
+      .filter(s2 => columnsFilter == Nil || (columnsFilterKeep && columnsFilter.contains(s2.name)) || (!columnsFilterKeep && !columnsFilter.contains(s2.name))).toList)
+
+    df2.sparkSession.createDataFrame(df2.rdd, ns)
   }
 
   /**
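
The final createDataFrame(df2.rdd, ns) step rebuilds the result against a schema derived from the input DataFrame, so the output keeps the input's column order and nullability instead of the internal h1.*/h2.* layout. A small sketch of the observable effect, assuming no columnsFilter and illustrative data:

    import org.apache.spark.sql.SparkSession
    import datafu.spark.DataFrameOps._

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 3, "asd3"), ("a", 1, "asd1"), ("b", 1, "asd4"))
      .toDF("col_grp", "col_ord", "col_str")

    val out = df.dedupWithCombiner(List($"col_grp"), List($"col_ord"), desc = false)

    // Columns come back in the input order rather than the internal h1.*/h2.* order,
    // and col_ord keeps the non-nullable IntegerType it had in the input schema.
    assert(out.columns.toSeq == Seq("col_grp", "col_ord", "col_str"))
    assert(!out.schema("col_ord").nullable)
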
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
index fa4060c..848c28e 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
@@ -67,16 +67,16 @@
                             .select($"col_grp", $"col_ord"))
   }
 
-  case class dedupExp(col2: String,
-                       col_grp: String,
-                       col_ord: Option[Int],
-                       col_str: String)
+  case class dedupExp(col_grp: String,
+                      col_ord: Int,
+                      col_str: String,
+                      col2: String)
 
   test("dedup2_by_int") {
 
     val expectedByIntDf: DataFrame = sqlContext.createDataFrame(
-      List(dedupExp("asd4", "b", Option(1), "asd4"),
-        dedupExp("asd1", "a", Option(3), "asd3")))
+      List(dedupExp("b", 1, "asd4", "asd4"),
+        dedupExp("a", 3, "asd3", "asd1")))
 
     val actual = inputDataFrame.dedupWithCombiner($"col_grp",
                                        $"col_ord",
@@ -85,19 +85,57 @@
     assertDataFrameEquals(expectedByIntDf, actual)
   }
 
-  case class dedupExp2(col_grp: String, col_ord: Option[Int], col_str: String)
+  case class dedupExp2(col_grp: String, col_ord: Int, col_str: String)
 
   test("dedup2_by_string_asc") {
 
     val actual = inputDataFrame.dedupWithCombiner($"col_grp", $"col_str", desc = false)
 
     val expectedByStringDf: DataFrame = sqlContext.createDataFrame(
-      List(dedupExp2("b", Option(1), "asd4"),
-        dedupExp2("a", Option(1), "asd1")))
+      List(dedupExp2("b", 1, "asd4"),
+        dedupExp2("a", 1, "asd1")))
 
     assertDataFrameEquals(expectedByStringDf, actual)
   }
 
+  test("dedup2_with_filter") {
+
+    val df = sqlContext.createDataFrame(
+      Seq(("a", 2, "aa12", "a"),
+        ("a", 1, "aa11", "a"),
+        ("b", 2, "ab32", "a"),
+        ("b", 1, "ba11", "a"))
+    ).toDF("col_grp", "col_ord", "col_str", "filter_col")
+
+    // Test case 1 - filter keep false
+    val actual1 = df.dedupWithCombiner($"col_grp",
+      $"col_ord",
+      desc = false,
+      columnsFilter = List("filter_col"),
+      columnsFilterKeep = false)
+
+    val expectedFilter1: DataFrame = sqlContext.createDataFrame(
+      Seq(("a", 1, "aa11"),
+        ("b", 1, "ba11"))
+    ).toDF("col_grp", "col_ord", "col_str")
+
+    assertDataFrameNoOrderEquals(expectedFilter1, actual1)
+
+    // Test case 2 - filter keep true
+    val actual2 = df.dedupWithCombiner($"col_grp",
+      $"col_ord",
+      desc = false,
+      columnsFilter = List("col_grp", "col_ord", "filter_col"),
+      columnsFilterKeep = true)
+
+    val expectedFilter2: DataFrame = sqlContext.createDataFrame(
+      Seq(("a", 1, "a"),
+        ("b", 1, "a"))
+    ).toDF("col_grp", "col_ord", "filter_col")
+
+    assertDataFrameNoOrderEquals(expectedFilter2, actual2)
+  }
+
   test("test_dedup2_by_complex_column") {
 
     val actual = inputDataFrame.dedupWithCombiner($"col_grp",
@@ -105,17 +143,43 @@
                                        desc = false)
 
     val expectedComplex: DataFrame = sqlContext.createDataFrame(
-      List(dedupExp2("b", Option(1), "asd4"),
-        dedupExp2("a", Option(3), "asd3")))
+      List(dedupExp2("b", 1, "asd4"),
+        dedupExp2("a", 3, "asd3")))
 
     assertDataFrameEquals(expectedComplex, actual)
   }
 
+  test("test_dedup2_by_multi_column") {
+
+    val df = sqlContext.createDataFrame(
+          Seq(("a", "a", 1, 2, "aa12", "a"),
+            ("a", "a", 1, 1, "aa11", "a"),
+            ("a", "a", 2, 1, "aa21", "a"),
+            ("a", "b", 3, 2, "ab32", "a"),
+            ("b", "a", 1, 1, "ba11", "a"))
+        ).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str", "col_to_ignore")
+
+    val actual = df.dedupWithCombiner(List($"col_grp1", $"col_grp2"),
+                                      List($"col_ord1", $"col_ord2"),
+                                      desc = false,
+                                      columnsFilter = List("col_to_ignore"),
+                                      columnsFilterKeep = false)
+
+    val expectedMulti: DataFrame = sqlContext.createDataFrame(
+      Seq(("a", "a", 1, 1, "aa11"),
+        ("a", "b", 3, 2, "ab32"),
+        ("b", "a", 1, 1, "ba11"))
+      ).toDF("col_grp1", "col_grp2", "col_ord1", "col_ord2", "col_str")
+
+    assertDataFrameNoOrderEquals(expectedMulti, actual)
+
+  }
+
   case class Inner(col_grp: String, col_ord: Int)
 
   case class expComplex(
                          col_grp: String,
-                         col_ord: Option[Int],
+                         col_ord: Int,
                          col_str: String,
                          arr_col: Array[String],
                          struct_col: Inner,
@@ -124,29 +188,33 @@
 
   test("test_dedup2_with_other_complex_column") {
 
-    val actual = inputDataFrame
+    val df = inputDataFrame
       .withColumn("arr_col", expr("array(col_grp, col_ord)"))
       .withColumn("struct_col", expr("struct(col_grp, col_ord)"))
       .withColumn("map_col", expr("map(col_grp, col_ord)"))
       .withColumn("map_col_blah", expr("map(col_grp, col_ord)"))
-      .dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
+
+    val schema = df.drop("map_col_blah").schema
+
+    val actual = df.dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
       .drop("map_col_blah")
 
     val expected: DataFrame = sqlContext.createDataFrame(
-      List(
+      sqlContext.createDataFrame(
+      Seq(
         expComplex("b",
-                   Option(1),
+                   1,
                    "asd4",
                    Array("b", "1"),
                    Inner("b", 1),
                    Map("b" -> 1)),
         expComplex("a",
-                   Option(1),
+                   1,
                    "asd1",
                    Array("a", "1"),
                    Inner("a", 1),
                    Map("a" -> 1))
-      ))
+      )).rdd, schema)
 
     assertDataFrameEquals(expected, actual)
   }
@@ -218,7 +286,7 @@
 
   val expectedSchemaRangedJoinWithDedup = List(
     StructField("col_grp", StringType, true),
-    StructField("col_ord", IntegerType, true),
+    StructField("col_ord", IntegerType, false),
     StructField("col_str", StringType, true),
     StructField("start", IntegerType, true),
     StructField("end", IntegerType, true),