package org.apache.s2graph.s2jobs.wal.process

import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.WalLogAgg
import org.apache.s2graph.s2jobs.wal.transformer.{DefaultTransformer, Transformer}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import play.api.libs.json.{JsObject, Json}

object FilterTopFeaturesProcess {
  // Per-executor cache: the broadcast Array[Long] is materialized into a Set[Long] only once
  // per executor JVM instead of once per record.
  private var validFeatureHashKeys: Set[Long] = null

  def getValidFeatureHashKeys(validFeatureHashKeysBCast: Broadcast[Array[Long]]): Set[Long] = {
    if (validFeatureHashKeys == null) {
      validFeatureHashKeys = validFeatureHashKeysBCast.value.toSet
    }
    validFeatureHashKeys
  }
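
  // Hashes each (dim, value) row of the filtered dictionary via WalLogAgg.toFeatureHash and
  // collects the distinct hashes to the driver so they can be broadcast to executors.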
  def collectDistinctFeatureHashes(ss: SparkSession,
                                   filteredDict: DataFrame): Array[Long] = {
    import ss.implicits._

    val featureHashUDF = udf((dim: String, value: String) => WalLogAgg.toFeatureHash(dim, value))

    filteredDict.withColumn("featureHash", featureHashUDF(col("dim"), col("value")))
      .select("featureHash")
      .distinct().as[Long].collect()
  }
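
  // Keeps only dictionary rows whose rank is below the per-dim limit, falling back to
  // defaultMaxRank for dims without an explicit entry in maxRankPerDim.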
  def filterTopKsPerDim(dict: DataFrame,
                        maxRankPerDim: Broadcast[Map[String, Int]],
                        defaultMaxRank: Int): DataFrame = {
    val filterUDF = udf((dim: String, rank: Long) => {
      rank < maxRankPerDim.value.getOrElse(dim, defaultMaxRank)
    })

    dict.filter(filterUDF(col("dim"), col("rank")))
  }
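
  // Filters the properties of each WalLogAgg per partition, materializing the broadcast
  // hash set once per executor before mapping over the records.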
  def filterWalLogAgg(ss: SparkSession,
                      walLogAgg: Dataset[WalLogAgg],
                      transformers: Seq[Transformer],
                      validFeatureHashKeysBCast: Broadcast[Array[Long]]) = {
    import ss.implicits._
    walLogAgg.mapPartitions { iter =>
      val validFeatureHashKeys = getValidFeatureHashKeys(validFeatureHashKeysBCast)

      iter.map { walLogAgg =>
        WalLogAgg.filterProps(walLogAgg, transformers, validFeatureHashKeys)
      }
    }
  }
}

class FilterTopFeaturesProcess(taskConf: TaskConf) extends org.apache.s2graph.s2jobs.task.Process(taskConf) {

  import FilterTopFeaturesProcess._

  /*
   * Filter the feature dictionary down to the top-K entries per dim, collect the surviving
   * (dim, value) feature hashes, then broadcast that set and drop every property of the
   * incoming WalLogAgg dataset whose feature hash is not in it.
   */
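  /*
   * Illustrative sketch of the task options this process reads; the option keys come from
   * execute and mandatoryOptions below, while the dim names and values are made-up examples:
   *
   *   "options": {
   *     "maxRankPerDim": "{\"gender\": 2, \"age\": 10}",   // JSON string of dim -> max rank
   *     "defaultMaxRank": "100",                           // limit for dims not listed above
   *     "featureDict": "feature_dict",                     // input DataFrame with dim, value, rank columns
   *     "walLogAgg": "wal_log_agg"                         // input Dataset[WalLogAgg] to filter
   *   }
   */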
  override def execute(ss: SparkSession, inputMap: Map[String, DataFrame]): DataFrame = {
    import ss.implicits._

    // Optional per-dim rank limits, given as a JSON object string of dim -> max rank.
    val maxRankPerDim = taskConf.options.get("maxRankPerDim").map { s =>
      Json.parse(s).as[JsObject].fields.map { case (k, jsValue) =>
        k -> jsValue.as[Int]
      }.toMap
    }
    val maxRankPerDimBCast = ss.sparkContext.broadcast(maxRankPerDim.getOrElse(Map.empty))

    // Rank limit for dims without an explicit entry; effectively unlimited when absent.
    val defaultMaxRank = taskConf.options.get("defaultMaxRank").map(_.toInt)

    val featureDict = inputMap(taskConf.options("featureDict"))
    val walLogAgg = inputMap(taskConf.options("walLogAgg")).as[WalLogAgg]

    val transformers = TaskConf.parseTransformers(taskConf)

    val filteredDict = filterTopKsPerDim(featureDict, maxRankPerDimBCast, defaultMaxRank.getOrElse(Int.MaxValue))
    val validFeatureHashKeys = collectDistinctFeatureHashes(ss, filteredDict)
    val validFeatureHashKeysBCast = ss.sparkContext.broadcast(validFeatureHashKeys)

    filterWalLogAgg(ss, walLogAgg, transformers, validFeatureHashKeysBCast).toDF()
  }

  override def mandatoryOptions: Set[String] = Set("featureDict", "walLogAgg")
}