/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package datafu.spark

import scala.collection.{mutable, Map}

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// TODO: add documentation and tests to all the functions; maybe also expose them in Python.
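/**
 * Spark SQL UDAFs, implemented against the UserDefinedAggregateFunction API.
 *
 * A registration sketch for use from SQL (illustrative; assumes a
 * SparkSession named `spark` and a registered temp view `t`):
 * {{{
 *   spark.udf.register("multiset", new SparkUDAFs.MultiSet())
 *   spark.sql("SELECT group_col, multiset(str) AS counts FROM t GROUP BY group_col")
 * }}}
 */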
object SparkUDAFs {

  /**
   * Like Google Guava's Multiset: an aggregate function that builds a map
   * from each key to its count.
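   *
   * A usage sketch (illustrative; assumes a DataFrame `df` with a string
   * column "str" and `org.apache.spark.sql.functions.col` in scope):
   * {{{
   *   val multiSet = new MultiSet()
   *   df.groupBy(col("group_col")).agg(multiSet(col("str")).as("counts"))
   * }}}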
   */
  class MultiSet() extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", StringType)

    def bufferSchema: StructType =
      new StructType().add("mp", MapType(StringType, IntegerType))

    def dataType: DataType = MapType(StringType, IntegerType, false)

    def deterministic: Boolean = true

    // Called once per group to initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Called once per input row of the group: increment the key's count
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val key = input.getString(0)
      if (key != null) {
        buffer(0) = buffer.getMap(0) + (key -> (buffer
          .getMap(0)
          .getOrElse(key, 0) + 1))
      }
    }

    // Merge two partial aggregates by summing the counts of each key
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[String, Int]() ++= buffer1.getMap(0)
      buffer2
        .getMap(0)
        .keys
        .foreach((key: String) =>
          if (key != null) {
            mp.put(key,
                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
          })
      buffer1(0) = mp
    }

    // Called after all the entries are exhausted
    def evaluate(buffer: Row): Any = {
      buffer(0)
    }
  }

  /**
   * Essentially the same as MultiSet, but takes an array column as input.
   * An extra option limits the number of keys kept (similar to CountDistinctUpTo).
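   *
   * A usage sketch (illustrative; assumes a DataFrame `df` with a column
   * "tags" of type array<string> and `org.apache.spark.sql.functions.col`
   * in scope):
   * {{{
   *   val maSet = new MultiArraySet[String](maxKeys = 2)
   *   df.groupBy(col("group_col")).agg(maSet(col("tags")).as("top_tags"))
   * }}}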
   */
  class MultiArraySet[T: Ordering](dt: DataType = StringType, maxKeys: Int = -1)
      extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", ArrayType(dt))

    def bufferSchema: StructType = new StructType().add("mp", dataType)

    def dataType: DataType = MapType(dt, IntegerType, false)

    def deterministic: Boolean = true

    // Called once per group to initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Called once per input row of the group: count each element of the array
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val mp = mutable.Map[T, Int]() ++= buffer.getMap(0)
      val keyArr: Seq[T] = Option(input.getAs[Seq[T]](0)).getOrElse(Nil)
      for (key <- keyArr; if key != null)
        mp.put(key, mp.getOrElse(key, 0) + 1)
      buffer(0) = limitKeys(mp, 3)
    }

    // Merge two partial aggregates by summing the counts of each key
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[T, Int]() ++= buffer1.getMap(0)
      buffer2
        .getMap(0)
        .keys
        .foreach((key: T) =>
          if (key != null) {
            mp.put(key,
                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
          })
      buffer1(0) = limitKeys(mp, 3)
    }
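    // Trims the map to (at most) maxKeys entries with the highest counts:
    // find k, the count of the maxKeys-th most frequent key, keep every
    // entry strictly above k, then drop just enough entries with count == k
    // to get back down to maxKeys. update() and merge() trim with a looser
    // bound of maxKeys * factor, presumably so that keys whose counts are
    // still accumulating are not dropped too early; evaluate() trims with
    // the exact bound (factor = 1).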
    private def limitKeys(mp: Map[T, Int], factor: Int = 1): Map[T, Int] = {
      if (maxKeys > 0 && maxKeys * factor < mp.size) {
        val k = mp.toList.map(_.swap).sorted.reverse(maxKeys - 1)._1
        var mp2 = mutable.Map[T, Int]() ++= mp.filter((t: (T, Int)) =>
          t._2 >= k)
        var toRemove = mp2.size - maxKeys
        if (toRemove > 0) {
          mp2 = mp2.filter((t: (T, Int)) => {
            if (t._2 > k) {
              true
            } else {
              if (toRemove >= 0) {
                toRemove = toRemove - 1
              }
              toRemove < 0
            }
          })
        }
        mp2
      } else {
        mp
      }
    }

    // Called after all the entries are exhausted
    def evaluate(buffer: Row): Map[T, Int] = {
      limitKeys(buffer.getMap(0).asInstanceOf[Map[T, Int]])
    }
  }

  /**
   * Merges maps of type string -> set<string> (sets are represented as
   * arrays); per key, the result is the distinct union of the values.
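   *
   * A usage sketch (illustrative; assumes a DataFrame `df` with a column
   * "m" of type map<string, array<string>> and
   * `org.apache.spark.sql.functions.col` in scope):
   * {{{
   *   val mapMerge = new MapSetMerge()
   *   df.groupBy(col("group_col")).agg(mapMerge(col("m")).as("merged"))
   * }}}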
   */
  class MapSetMerge extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", dataType)

    def bufferSchema: StructType = inputSchema

    def dataType: DataType = MapType(StringType, ArrayType(StringType))

    def deterministic: Boolean = true

    // Called once per group to initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Called once per input row of the group: merge the input map into the buffer
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val mp0 = input.getMap(0)
      if (mp0 != null) {
        val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= input
          .getMap(0)
        buffer(0) =
          merge(mp, buffer.getMap[String, mutable.WrappedArray[String]](0))
      }
    }

    // Merge two partial aggregates
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= buffer1
        .getMap(0)
      buffer1(0) =
        merge(mp, buffer2.getMap[String, mutable.WrappedArray[String]](0))
    }

    // Merge mp into mpBuffer: per key, concatenate the two arrays and deduplicate
    def merge(mpBuffer: mutable.Map[String, mutable.WrappedArray[String]],
              mp: Map[String, mutable.WrappedArray[String]])
      : mutable.Map[String, mutable.WrappedArray[String]] = {
      if (mp != null) {
        mp.keys.foreach((key: String) => {
          val current: mutable.WrappedArray[String] =
            mpBuffer.getOrElse(key, mutable.WrappedArray.empty)
          val incoming: mutable.WrappedArray[String] =
            mp.getOrElse(key, mutable.WrappedArray.empty)
          mpBuffer.put(
            key,
            mutable.WrappedArray.make(
              (Option(current).getOrElse(mutable.WrappedArray.empty) ++ Option(
                incoming).getOrElse(mutable.WrappedArray.empty)).toSet.toArray)
          )
        })
      }
      mpBuffer
    }

    // Called after all the entries are exhausted
    def evaluate(buffer: Row): Any = {
      buffer(0)
    }
  }

  /**
   * Counts the number of distinct values, but only up to a preset limit -
   * more efficient than an unbounded distinct count. Note that maxItems must
   * be positive: with the default of -1 the buffer never fills and the
   * result is always -1.
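   *
   * A usage sketch (illustrative; assumes a DataFrame `df` with a string
   * column "str" and `org.apache.spark.sql.functions.col` in scope):
   * {{{
   *   val countUpTo = new CountDistinctUpTo(100)
   *   df.groupBy(col("group_col")).agg(countUpTo(col("str")).as("cnt"))
   * }}}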
   */
  class CountDistinctUpTo(maxItems: Int = -1)
      extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", StringType)

    def bufferSchema: StructType =
      new StructType().add("mp", MapType(StringType, BooleanType))

    def dataType: DataType = IntegerType

    def deterministic: Boolean = true

    // Called once per group to initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Called once per input row of the group: record the key, but stop
    // growing the buffer once maxItems distinct keys have been seen
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (buffer.getMap(0).size < maxItems) {
        val key = input.getString(0)
        if (key != null) {
          buffer(0) = buffer.getMap(0) + (key -> true)
        }
      }
    }

    // Merge two partial aggregates; once buffer1 already holds maxItems keys
    // the result is capped anyway, so further merging is skipped
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      if (buffer1.getMap(0).size < maxItems) {
        val mp = mutable.Map[String, Boolean]() ++= buffer1.getMap(0)
        buffer2
          .getMap(0)
          .keys
          .foreach((key: String) =>
            if (key != null) {
              mp.put(key, true)
            })
        buffer1(0) = mp
      }
    }

    // Called after all the entries are exhausted: cap the result at maxItems
    def evaluate(buffer: Row): Int = {
      math.min(buffer.getMap(0).size, maxItems)
    }
  }
}