DATAFU-165 Added dedupRandomN method and collectLimitedList UDAF functionality
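
A minimal usage sketch (the DataFrame "events" and its "user_id" column are
illustrative only, not part of this change):

    import org.apache.spark.sql.functions.{col, expr}
    import datafu.spark.SparkDFUtils
    import org.apache.spark.sql.datafu.types.SparkOverwriteUDAFs

    // keep at most 3 rows per user_id, without ordering the data
    val sampled = SparkDFUtils.dedupRandomN(events, col("user_id"), maxSize = 3)

    // the new UDAF can also be used directly inside an aggregation
    val limited = events.groupBy("user_id")
      .agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), 3).as("list"))
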
Signed-off-by: Eyal Allweil <eyal@apache.org>
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
index 04ae90a..5b4f42d 100644
--- a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -103,6 +103,9 @@
def explodeArray(arrayCol: Column,
alias: String) =
- SparkDFUtils.explodeArray(df, arrayCol, alias)
+ SparkDFUtils.explodeArray(df, arrayCol, alias)
+
+ def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame =
+ SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
}
}
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
index 79b51eb..4fd068f 100644
--- a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -129,6 +129,10 @@
)
}
+ def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+ SparkDFUtils.dedupRandomN(df, groupCol, maxSize)
+ }
+
private def convertJavaListToSeq[T](list: JavaList[T]): Seq[T] = {
scala.collection.JavaConverters
.asScalaIteratorConverter(list.iterator())
@@ -550,4 +554,18 @@
val exprs = (0 until arrSize).map(i => arrayCol.getItem(i).alias(s"$alias$i"))
df.select((col("*") +: exprs):_*)
}
+
+ /**
+ * Used to get up to maxSize random records from each group. Uses an efficient
+ * implementation that doesn't order the data, so it can handle large amounts of data.
+ *
+ * @param df DataFrame to operate on
+ * @param groupCol column to group the records by
+ * @param maxSize the maximum number of rows to keep per group
+ * @return DataFrame representing the data after the operation
+ */
+ def dedupRandomN(df: DataFrame, groupCol: Column, maxSize: Int): DataFrame = {
+ df.groupBy(groupCol).agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+ .select(groupCol, expr("explode(list)"))
+ }
}
diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
index 04d68d6..2dcb5be 100644
--- a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
@@ -19,17 +19,23 @@
package org.apache.spark.sql.datafu.types
import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
object SparkOverwriteUDAFs {
def minValueByKey(key: Column, value: Column): Column =
Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
def maxValueByKey(key: Column, value: Column): Column =
Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
+ def collectLimitedList(e: Column, maxSize: Int): Column =
+ Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))
}
case class MinValueByKey(child1: Expression, child2: Expression)
@@ -88,3 +94,54 @@
override lazy val evaluateExpression: AttributeReference = data
}
+
+/**
+ * This code is copied from CollectList; only the class it extends was changed.
+ * Copied originally from
+ * https://github.com/apache/spark/blob/branch-2.3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+ */
+case class CollectLimitedList(child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0,
+ howMuchToTake: Int = 10) extends LimitedCollect[mutable.ArrayBuffer[Any]](howMuchToTake) {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+
+ override def prettyName: String = "collect_limited_list"
+
+}
+
+/**
+ * This modifies the collect list / set to keep only howMuchToTake random elements.
+ */
+abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTake: Int) extends Collect[T] with Serializable {
+
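+ // Add the input row only while the buffer holds fewer than howMuchToTake elements.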
+ override def update(buffer: T, input: InternalRow): T = {
+ if (buffer.size < howMuchToTake)
+ super.update(buffer, input)
+ else
+ buffer
+ }
+
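+ // Combine two partial buffers, keeping at most howMuchToTake elements in total.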
+ override def merge(buffer: T, other: T): T = {
+ if (buffer.size == howMuchToTake)
+ buffer
+ else if (other.size == howMuchToTake)
+ other
+ else {
+ val howMuchToTakeFromOtherBuffer = howMuchToTake - buffer.size
+ buffer ++= other.take(howMuchToTakeFromOtherBuffer)
+ }
+ }
+}
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
index e80e71f..aadd059 100644
--- a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -258,4 +258,26 @@
.agg(countDistinctUpTo6($"col_ord").as("col_ord")))
}
+ test("test_limited_collect_list") {
+
+ val maxSize = 10
+
+ val rows = (1 to 30).flatMap(x => (1 to x).map(n => (x, n, "some-string " + n))).toDF("num1", "num2", "str")
+
+ rows.show(10, false)
+
+ import org.apache.spark.sql.functions._
+
+ val result = rows.groupBy("num1").agg(SparkOverwriteUDAFs.collectLimitedList(expr("struct(*)"), maxSize).as("list"))
+ .withColumn("list_size", expr("size(list)"))
+
+ result.show(10, false)
+
+ SparkDFUtils.dedupRandomN(rows, $"num1", maxSize).show(10, false)
+
+ val rows_different = result.filter(s"case when num1 > $maxSize then $maxSize else num1 end != list_size")
+
+ Assert.assertEquals(0, rows_different.count())
+
+ }
}