DATAFU-165 Added dedupRandomN method and collectLimitedList UDAF functionality

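A rough usage sketch (for illustration only; the SparkSession named spark, the example
DataFrame, and the column names are made up, not taken from this change):

    import datafu.spark.DataFrameOps._
    import datafu.spark.SparkDFUtils
    import org.apache.spark.sql.functions.col

    // an example DataFrame with ten user_id groups of ten rows each
    val df = spark.range(100).selectExpr("id % 10 as user_id", "id as event_id")

    // keep at most 3 arbitrary rows per user_id, without sorting each group
    val sampled = df.dedupRandomN(col("user_id"), maxSize = 3)

    // the same operation through the utility object
    val sampled2 = SparkDFUtils.dedupRandomN(df, col("user_id"), 3)
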
Signed-off-by: Eyal Allweil <eyal@apache.org>
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 04ae90a..5b4f42d 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -103,6 +103,9 @@
   
     def explodeArray(arrayCol: Column,
                      alias: String) =
-      SparkDFUtils.explodeArray(df, arrayCol, alias)  
+      SparkDFUtils.explodeArray(df, arrayCol, alias)
+
+    def dedupRandomN(groupCol: Column, maxSize: Int): DataFrame =
+      SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
   }
 }
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 79b51eb..4fd068f 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -129,6 +129,10 @@
     )
   }
 
+  def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+    SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
+  }
+
   private def convertJavaListToSeq[T](list: JavaList[T]): Seq[T] = {
     scala.collection.JavaConverters
       .asScalaIteratorConverter(list.iterator())
@@ -550,4 +554,18 @@
     val exprs = (0 until arrSize).map(i => arrayCol.getItem(i).alias(s"$alias$i"))
     df.select((col("*") +: exprs):_*)
   }
+
+  /**
+    * Used to get up to maxSize random records from each group. Uses an efficient
+    * implementation that doesn't order the data, so it can handle large amounts of data.
+    *
+    * @param df        DataFrame to operate on
+    * @param groupCol  column to group the records by
+    * @param maxSize   the maximum number of rows to keep per group
+    * @return DataFrame representing the data after the operation
+    */
+  def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+    df.groupBy(groupCol)
+      .agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+      .select(groupCol, expr("explode(list)"))
+  }
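+
+  // Note on the output shape: exploding the collected array yields a struct column
+  // (named "col" by default) that contains all of df's original columns, so the
+  // result schema is roughly (groupCol, col); callers can select col.* to flatten it.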
 }
diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
index 04d68d6..2dcb5be 100644
--- a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
@@ -19,17 +19,23 @@
 package org.apache.spark.sql.datafu.types
 
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
 
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
 object SparkOverwriteUDAFs {
   def minValueByKey(key: Column, value: Column): Column =
     Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
   def maxValueByKey(key: Column, value: Column): Column =
     Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
+  def collectLimitedList(e: Column, maxSize: Int): Column =
+    Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))
 }
 
 case class MinValueByKey(child1: Expression, child2: Expression)
@@ -88,3 +94,54 @@
 
   override lazy val evaluateExpression: AttributeReference = data
 }
+
+/**
+ * This code is copied from Spark's CollectList; only the class it extends was changed.
+ * Copied originally from
+ * https://github.com/apache/spark/blob/branch-2.3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+ */
+case class CollectLimitedList(child: Expression,
+                              mutableAggBufferOffset: Int = 0,
+                              inputAggBufferOffset: Int = 0,
+                              howMuchToTake: Int = 10) extends LimitedCollect[mutable.ArrayBuffer[Any]](howMuchToTake) {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+
+  override def prettyName: String = "collect_limited_list"
+
+}
+
+/**
+ * Modifies collect_list / collect_set to keep only howMuchToTake random elements
+ * per group.
+ */
+abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTake: Int) extends Collect[T] with Serializable {
+
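+  // A per-partition buffer simply stops accepting new rows once it already holds
+  // howMuchToTake elements.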
+  override def update(buffer: T, input: InternalRow): T = {
+    if (buffer.size < howMuchToTake)
+      super.update(buffer, input)
+    else
+      buffer
+  }
+
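+  // When merging partial buffers, a buffer that is already full wins outright; otherwise
+  // the current buffer is topped up from the other one. For example, with howMuchToTake = 10,
+  // merging buffers of sizes 7 and 5 keeps the 7 elements plus the first 3 of the other.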
+  override def merge(buffer: T, other: T): T = {
+    if (buffer.size == howMuchToTake)
+      buffer
+    else if (other.size == howMuchToTake)
+      other
+    else {
+      val howMuchToTakeFromOtherBuffer = howMuchToTake - buffer.size
+      buffer ++= other.take(howMuchToTakeFromOtherBuffer)
+    }
+  }
+}
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
index e80e71f..aadd059 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -258,4 +258,26 @@
                             .agg(countDistinctUpTo6($"col_ord").as("col_ord")))
   }
 
+  test("test_limited_collect_list") {
+
+    val maxSize = 10
+
+    val rows = (1 to 30).flatMap(x => (1 to x).map(n => (x, n, "some-string " + n))).toDF("num1", "num2", "str")
+
+    rows.show(10, false)
+
+    import org.apache.spark.sql.functions._
+
+    val result = rows.groupBy("num1").agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+      .withColumn("list_size", expr("size(list)"))
+
+    result.show(10, false)
+
+    SparkDFUtils.dedupRandomN(rows, $"num1", maxSize).show(10, false)
+
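+    // Groups where num1 <= maxSize contain exactly num1 rows, and larger groups should be
+    // capped at maxSize, so no group's list size should differ from that expectation.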
+    val rows_different = result.filter(s"case when num1 > $maxSize then $maxSize else num1 end != list_size")
+
+    Assert.assertEquals(0, rows_different.count())
+
+  }
 }