Merge branch 'master' into S2GRAPH-226
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/Schema.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/Schema.scala
index 58d3368..1c47f13 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/Schema.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/Schema.scala
@@ -22,10 +22,19 @@
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
object Schema {
- val BulkLoadSchema = StructType(Seq(
- StructField("timestamp", LongType, false),
- StructField("operation", StringType, false),
- StructField("elem", StringType, false),
+ /**
+ * root
+ * |-- timestamp: long (nullable = false)
+ * |-- operation: string (nullable = false)
+ * |-- elem: string (nullable = false)
+ */
+ val CommonFields = Seq(
+ StructField("timestamp", LongType, nullable = false),
+ StructField("operation", StringType, nullable = false),
+ StructField("elem", StringType, nullable = false)
+ )
+
+ val BulkLoadSchema = StructType(CommonFields ++ Seq(
StructField("from", StringType, false),
StructField("to", StringType, false),
StructField("label", StringType, false),
@@ -33,24 +42,63 @@
StructField("direction", StringType, true)
))
- val VertexSchema = StructType(Seq(
- StructField("timestamp", LongType, false),
- StructField("operation", StringType, false),
- StructField("elem", StringType, false),
+ /**
+ * root
+ * |-- timestamp: long (nullable = false)
+ * |-- operation: string (nullable = false)
+ * |-- elem: string (nullable = false)
+ * |-- id: string (nullable = false)
+ * |-- service: string (nullable = false)
+ * |-- column: string (nullable = false)
+ * |-- props: string (nullable = false)
+ */
+ val VertexSchema = StructType(CommonFields ++ Seq(
StructField("id", StringType, false),
StructField("service", StringType, false),
StructField("column", StringType, false),
StructField("props", StringType, false)
))
- val EdgeSchema = StructType(Seq(
- StructField("timestamp", LongType, false),
- StructField("operation", StringType, false),
- StructField("elem", StringType, false),
+
+ /**
+ * root
+ * |-- timestamp: long (nullable = false)
+ * |-- operation: string (nullable = false)
+ * |-- elem: string (nullable = false)
+ * |-- from: string (nullable = false)
+ * |-- to: string (nullable = false)
+ * |-- label: string (nullable = false)
+ * |-- props: string (nullable = false)
+ * |-- direction: string (nullable = true)
+ */
+ val EdgeSchema = StructType(CommonFields ++ Seq(
StructField("from", StringType, false),
StructField("to", StringType, false),
StructField("label", StringType, false),
StructField("props", StringType, false),
StructField("direction", StringType, true)
))
+
+ /**
+ * root
+ * |-- timestamp: long (nullable = false)
+ * |-- operation: string (nullable = false)
+ * |-- elem: string (nullable = false)
+ * |-- id: string (nullable = true)
+ * |-- service: string (nullable = true)
+ * |-- column: string (nullable = true)
+ * |-- from: string (nullable = true)
+ * |-- to: string (nullable = true)
+ * |-- label: string (nullable = true)
+ * |-- props: string (nullable = true)
+ */
+ val GraphElementSchema = StructType(CommonFields ++ Seq(
+ StructField("id", StringType, nullable = true),
+ StructField("service", StringType, nullable = true),
+ StructField("column", StringType, nullable = true),
+ StructField("from", StringType, nullable = true),
+ StructField("to", StringType, nullable = true),
+ StructField("label", StringType, nullable = true),
+ StructField("props", StringType, nullable = true)
+ ))
}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
index bfac62b..8e4e234 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
@@ -19,12 +19,13 @@
package org.apache.s2graph.s2jobs.task
-import org.apache.s2graph.core.Management
+import org.apache.s2graph.core.{JSONParser, Management}
import org.apache.s2graph.s2jobs.Schema
import org.apache.s2graph.s2jobs.loader.{HFileGenerator, SparkBulkLoaderTransformer}
import org.apache.s2graph.s2jobs.serde.reader.S2GraphCellReader
import org.apache.s2graph.s2jobs.serde.writer.RowDataFrameWriter
import org.apache.spark.sql.{DataFrame, SparkSession}
+import play.api.libs.json.{JsObject, Json}
/**
@@ -98,13 +99,21 @@
val paths = conf.options("paths").split(",")
val format = conf.options.getOrElse("format", DEFAULT_FORMAT)
val columnsOpt = conf.options.get("columns")
+ val readOptions = conf.options.get("read").map { s =>
+ Json.parse(s).as[JsObject].fields.map { case (k, jsValue) =>
+ k -> JSONParser.jsValueToString(jsValue)
+ }.toMap
+ }.getOrElse(Map.empty)
format match {
case "edgeLog" =>
ss.read.format("com.databricks.spark.csv").option("delimiter", "\t")
.schema(BulkLoadSchema).load(paths: _*)
- case _ => ss.read.format(format).load(paths: _*)
- val df = ss.read.format(format).load(paths: _*)
+ case _ =>
+ val df =
+ if (readOptions.isEmpty) ss.read.format(format).load(paths: _*)
+ else ss.read.options(readOptions).format(format).load(paths: _*)
+
if (columnsOpt.isDefined) df.toDF(columnsOpt.get.split(",").map(_.trim): _*) else df
}
}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Task.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Task.scala
index ab02900..62081df 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Task.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Task.scala
@@ -21,9 +21,13 @@
import org.apache.s2graph.core.S2GraphConfigs
import org.apache.s2graph.s2jobs.Logger
+import org.apache.s2graph.s2jobs.wal.transformer.Transformer
+import play.api.libs.json.Json
//import org.apache.s2graph.s2jobs.loader.GraphFileOptions
object TaskConf {
+ val Empty = new TaskConf(name = "empty", `type` = "empty", inputs = Nil, options = Map.empty)
+
// def toGraphFileOptions(taskConf: TaskConf): GraphFileOptions = {
// val args = taskConf.options.filterKeys(GraphFileOptions.OptionKeys)
// .flatMap(kv => Seq(kv._1, kv._2)).toSeq.toArray
@@ -42,6 +46,18 @@
def parseLocalCacheConfigs(taskConf: TaskConf): Map[String, Any] = {
taskConf.options.filterKeys(S2GraphConfigs.CacheConfigs.DEFAULTS.keySet).mapValues(_.toInt)
}
+
+ def parseTransformers(taskConf: TaskConf): Seq[Transformer] = {
+ val classes = Json.parse(taskConf.options.getOrElse("transformClasses",
+ """["org.apache.s2graph.s2jobs.wal.transformer.DefaultTransformer"]""")).as[Seq[String]]
+
+ classes.map { className =>
+ Class.forName(className)
+ .getConstructor(classOf[TaskConf])
+ .newInstance(taskConf)
+ .asInstanceOf[Transformer]
+ }
+ }
}
case class TaskConf(name:String, `type`:String, inputs:Seq[String] = Nil, options:Map[String, String] = Map.empty, cache:Option[Boolean]=None)
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/WalLog.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/WalLog.scala
new file mode 100644
index 0000000..ad696a9
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/WalLog.scala
@@ -0,0 +1,208 @@
+package org.apache.s2graph.s2jobs.wal
+
+import com.google.common.hash.Hashing
+import org.apache.s2graph.core.{GraphUtil, JSONParser}
+import org.apache.s2graph.s2jobs.wal.process.params.AggregateParam
+import org.apache.s2graph.s2jobs.wal.transformer.Transformer
+import org.apache.s2graph.s2jobs.wal.utils.BoundedPriorityQueue
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
+import play.api.libs.json.{JsObject, Json}
+
+import scala.util.Try
+
+object WalLogAgg {
+ val outputColumns = Seq("from", "vertices", "edges")
+
+ def isEdge(walLog: WalLog): Boolean = {
+ walLog.elem == "edge" || walLog.elem == "e"
+ }
+
+ def apply(walLog: WalLog): WalLogAgg = {
+ val (vertices, edges) =
+ if (isEdge(walLog)) (Nil, Seq(walLog))
+ else (Seq(walLog), Nil)
+
+ new WalLogAgg(walLog.from, vertices, edges)
+ }
+
+ def toFeatureHash(dimVal: DimVal): Long = toFeatureHash(dimVal.dim, dimVal.value)
+
+ def toFeatureHash(dim: String, value: String): Long = {
+ Hashing.murmur3_128().hashBytes(s"$dim:$value".getBytes("UTF-8")).asLong()
+ }
+
+ private def addToHeap(walLog: WalLog,
+ heap: BoundedPriorityQueue[WalLog],
+ now: Long,
+ validTimestampDuration: Option[Long]): Unit = {
+ val ts = walLog.timestamp
+ val isValid = validTimestampDuration.map(d => now - ts < d).getOrElse(true)
+
+ if (isValid) {
+ heap += walLog
+ }
+ }
+
+ private def addToHeap(iter: Seq[WalLog],
+ heap: BoundedPriorityQueue[WalLog],
+ now: Long,
+ validTimestampDuration: Option[Long]): Unit = {
+ iter.foreach(walLog => addToHeap(walLog, heap, now, validTimestampDuration))
+ }
+
+ private def toWalLogAgg(edgeHeap: BoundedPriorityQueue[WalLog],
+ vertexHeap: BoundedPriorityQueue[WalLog],
+ sortTopItems: Boolean): Option[WalLogAgg] = {
+ val topVertices = if (sortTopItems) vertexHeap.toArray.sortBy(-_.timestamp) else vertexHeap.toArray
+ val topEdges = if (sortTopItems) edgeHeap.toArray.sortBy(-_.timestamp) else edgeHeap.toArray
+
+ topEdges.headOption.map(head => WalLogAgg(head.from, topVertices, topEdges))
+ }
+
+ def mergeWalLogs(iter: Iterator[WalLog],
+ heapSize: Int,
+ now: Long,
+ validTimestampDuration: Option[Long],
+ sortTopItems: Boolean)(implicit ord: Ordering[WalLog]): Option[WalLogAgg] = {
+ val edgeHeap = new BoundedPriorityQueue[WalLog](heapSize)
+ val vertexHeap = new BoundedPriorityQueue[WalLog](heapSize)
+
+ iter.foreach { walLog =>
+ if (walLog.isVertex) addToHeap(walLog, vertexHeap, now, validTimestampDuration)
+ else addToHeap(walLog, edgeHeap, now, validTimestampDuration)
+ }
+
+ toWalLogAgg(edgeHeap, vertexHeap, sortTopItems)
+ }
+
+ def merge(iter: Iterator[WalLogAgg],
+ heapSize: Int,
+ now: Long,
+ validTimestampDuration: Option[Long],
+ sortTopItems: Boolean)(implicit ord: Ordering[WalLog]): Option[WalLogAgg] = {
+ val edgeHeap = new BoundedPriorityQueue[WalLog](heapSize)
+ val vertexHeap = new BoundedPriorityQueue[WalLog](heapSize)
+
+ iter.foreach { walLogAgg =>
+ addToHeap(walLogAgg.vertices, vertexHeap, now, validTimestampDuration)
+ addToHeap(walLogAgg.edges, edgeHeap, now, validTimestampDuration)
+ }
+
+ toWalLogAgg(edgeHeap, vertexHeap, sortTopItems)
+ }
+
+ def mergeWalLogs(iter: Iterator[WalLog],
+ param: AggregateParam)(implicit ord: Ordering[WalLog]): Option[WalLogAgg] = {
+ mergeWalLogs(iter, param.heapSize, param.now, param.validTimestampDuration, param.sortTopItems)
+ }
+
+ def merge(iter: Iterator[WalLogAgg],
+ param: AggregateParam)(implicit ord: Ordering[WalLog]): Option[WalLogAgg] = {
+ merge(iter, param.heapSize, param.now, param.validTimestampDuration, param.sortTopItems)
+ }
+
+
+ private def filterPropsInner(walLogs: Seq[WalLog],
+ transformers: Seq[Transformer],
+ validFeatureHashKeys: Set[Long]): Seq[WalLog] = {
+ walLogs.map { walLog =>
+ val fields = walLog.propsJson.fields.filter { case (propKey, propValue) =>
+ val filtered = transformers.flatMap { transformer =>
+ transformer.toDimValLs(walLog, propKey, JSONParser.jsValueToString(propValue)).filter(dimVal => validFeatureHashKeys(toFeatureHash(dimVal)))
+ }
+ filtered.nonEmpty
+ }
+
+ walLog.copy(props = Json.toJson(fields.toMap).as[JsObject].toString)
+ }
+ }
+
+ def filterProps(walLogAgg: WalLogAgg,
+ transformers: Seq[Transformer],
+ validFeatureHashKeys: Set[Long]) = {
+ val filteredVertices = filterPropsInner(walLogAgg.vertices, transformers, validFeatureHashKeys)
+ val filteredEdges = filterPropsInner(walLogAgg.edges, transformers, validFeatureHashKeys)
+
+ walLogAgg.copy(vertices = filteredVertices, edges = filteredEdges)
+ }
+}
+
+object DimValCountRank {
+ def fromRow(row: Row): DimValCountRank = {
+ val dim = row.getAs[String]("dim")
+ val value = row.getAs[String]("value")
+ val count = row.getAs[Long]("count")
+ val rank = row.getAs[Long]("rank")
+
+ new DimValCountRank(DimVal(dim, value), count, rank)
+ }
+}
+
+case class DimValCountRank(dimVal: DimVal, count: Long, rank: Long)
+
+case class DimValCount(dimVal: DimVal, count: Long)
+
+object DimVal {
+ def fromRow(row: Row): DimVal = {
+ val dim = row.getAs[String]("dim")
+ val value = row.getAs[String]("value")
+
+ new DimVal(dim, value)
+ }
+}
+
+case class DimVal(dim: String, value: String)
+
+case class WalLogAgg(from: String,
+ vertices: Seq[WalLog],
+ edges: Seq[WalLog])
+
+case class WalLog(timestamp: Long,
+ operation: String,
+ elem: String,
+ from: String,
+ to: String,
+ service: String,
+ label: String,
+ props: String) {
+ val isVertex = elem == "v" || elem == "vertex"
+ val id = from
+ val columnName = label
+ val serviceName = to
+ lazy val propsJson = Json.parse(props).as[JsObject]
+ lazy val propsKeyValues = propsJson.fields.map { case (key, jsValue) =>
+ key -> JSONParser.jsValueToString(jsValue)
+ }
+}
+
+object WalLog {
+ val orderByTsAsc = Ordering.by[WalLog, Long](walLog => walLog.timestamp)
+
+ val WalLogSchema = StructType(Seq(
+ StructField("timestamp", LongType, false),
+ StructField("operation", StringType, false),
+ StructField("elem", StringType, false),
+ StructField("from", StringType, false),
+ StructField("to", StringType, false),
+ StructField("service", StringType, true),
+ StructField("label", StringType, false),
+ StructField("props", StringType, false)
+ // StructField("direction", StringType, true)
+ ))
+
+ def fromRow(row: Row): WalLog = {
+ val timestamp = row.getAs[Long]("timestamp")
+ val operation = Try(row.getAs[String]("operation")).toOption.getOrElse("insert")
+ val elem = Try(row.getAs[String]("elem")).toOption.getOrElse("edge")
+ val from = row.getAs[String]("from")
+ val to = row.getAs[String]("to")
+ val service = row.getAs[String]("service")
+ val label = row.getAs[String]("label")
+ val props = Try(row.getAs[String]("props")).toOption.getOrElse("{}")
+
+ WalLog(timestamp, operation, elem, from, to, service, label, props)
+ }
+
+
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcess.scala
new file mode 100644
index 0000000..048b470
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcess.scala
@@ -0,0 +1,126 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.WalLogAgg.toFeatureHash
+import org.apache.s2graph.s2jobs.wal.process.params.BuildTopFeaturesParam
+import org.apache.s2graph.s2jobs.wal.transformer._
+import org.apache.s2graph.s2jobs.wal.udfs.WalLogUDF
+import org.apache.s2graph.s2jobs.wal.{DimVal, DimValCount, WalLog, WalLogAgg}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import play.api.libs.json.{JsObject, Json}
+
+import scala.collection.mutable
+
+object BuildTopFeaturesProcess {
+ def extractDimValuesWithCount(transformers: Seq[Transformer]) = {
+ udf((rows: Seq[Row]) => {
+ val logs = rows.map(WalLog.fromRow)
+ val dimValCounts = mutable.Map.empty[DimVal, Int]
+
+ logs.foreach { walLog =>
+ walLog.propsKeyValues.foreach { case (propsKey, propsValue) =>
+ transformers.foreach { transformer =>
+ transformer.toDimValLs(walLog, propsKey, propsValue).foreach { dimVal =>
+ val newCount = dimValCounts.getOrElse(dimVal, 0) + 1
+ dimValCounts += (dimVal -> newCount)
+ }
+ }
+ }
+ }
+
+ dimValCounts.toSeq.sortBy(-_._2)map { case (dimVal, count) =>
+ DimValCount(dimVal, count)
+ }
+ })
+ }
+
+ def extractDimValues(transformers: Seq[Transformer]) = {
+ udf((rows: Seq[Row]) => {
+ val logs = rows.map(WalLog.fromRow)
+ // TODO: this can be changed into a Map to count how many times each dimVal exists in the sequence of walLogs;
+ // then change this to mutable.Map.empty[DimVal, Int], then aggregate.
+ val distinctDimValues = mutable.Set.empty[DimVal]
+
+ logs.foreach { walLog =>
+ walLog.propsKeyValues.foreach { case (propsKey, propsValue) =>
+ transformers.foreach { transformer =>
+ transformer.toDimValLs(walLog, propsKey, propsValue).foreach { dimVal =>
+ distinctDimValues += dimVal
+ }
+ }
+ }
+ }
+
+ distinctDimValues.toSeq
+ })
+ }
+
+ def buildDictionary(ss: SparkSession,
+ allDimVals: DataFrame,
+ param: BuildTopFeaturesParam,
+ dimValColumnName: String = "dimVal"): DataFrame = {
+ import ss.implicits._
+
+ val rawFeatures = allDimVals
+ .select(col(param._countColumnName), col(s"$dimValColumnName.dim").as("dim"), col(s"$dimValColumnName.value").as("value"))
+ .groupBy("dim", "value")
+ .agg(countDistinct(param._countColumnName).as("count"))
+ .filter(s"count > ${param._minUserCount}")
+
+ val ds: Dataset[((String, Long), String)] =
+ rawFeatures.select("dim", "value", "count").as[(String, String, Long)]
+ .map { case (dim, value, uv) =>
+ (dim, uv) -> value
+ }
+
+
+ implicit val ord = Ordering.Tuple2(Ordering.String, Ordering.Long.reverse)
+
+ val rdd: RDD[(Long, (String, Long), String)] = WalLogUDF.appendRank(ds, param.numOfPartitions, param.samplePointsPerPartitionHint)
+
+ rdd.toDF("rank", "dim_count", "value")
+ .withColumn("dim", col("dim_count._1"))
+ .withColumn("count", col("dim_count._2"))
+ .select("dim", "value", "count", "rank")
+ }
+}
+
+case class BuildTopFeaturesProcess(taskConf: TaskConf) extends org.apache.s2graph.s2jobs.task.Process(taskConf) {
+
+ import BuildTopFeaturesProcess._
+
+ override def execute(ss: SparkSession, inputMap: Map[String, DataFrame]): DataFrame = {
+ val countColumnName = taskConf.options.getOrElse("countColumnName", "from")
+ val numOfPartitions = taskConf.options.get("numOfPartitions").map(_.toInt)
+ val samplePointsPerPartitionHint = taskConf.options.get("samplePointsPerPartitionHint").map(_.toInt)
+ val minUserCount = taskConf.options.get("minUserCount").map(_.toLong)
+
+ numOfPartitions.map { d => ss.sqlContext.setConf("spark.sql.shuffle.partitions", d.toString) }
+
+ val param = BuildTopFeaturesParam(minUserCount = minUserCount, countColumnName = Option(countColumnName),
+ numOfPartitions = numOfPartitions, samplePointsPerPartitionHint = samplePointsPerPartitionHint
+ )
+
+ val edges = taskConf.inputs.tail.foldLeft(inputMap(taskConf.inputs.head)) { case (prev, cur) =>
+ prev.union(inputMap(cur))
+ }
+
+ //TODO: users are expected to inject transformers that transform (WalLog, propertyKey, propertyValue) to Seq[DimVal].
+ val transformers = TaskConf.parseTransformers(taskConf)
+ val dimValExtractUDF = extractDimValues(transformers)
+ val dimValColumnName = "dimVal"
+
+ val rawFeatures = edges
+ .withColumn(dimValColumnName, explode(dimValExtractUDF(col("logs"))))
+
+ val dict = buildDictionary(ss, rawFeatures, param, dimValColumnName)
+
+ dict
+ }
+
+
+ override def mandatoryOptions: Set[String] = Set.empty
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcess.scala
new file mode 100644
index 0000000..ec7e645
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcess.scala
@@ -0,0 +1,90 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.WalLogAgg
+import org.apache.s2graph.s2jobs.wal.transformer.{DefaultTransformer, Transformer}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import play.api.libs.json.{JsObject, Json}
+
+object FilterTopFeaturesProcess {
+ private var validFeatureHashKeys: Set[Long] = null
+ def getValidFeatureHashKeys(validFeatureHashKeysBCast: Broadcast[Array[Long]]): Set[Long] = {
+ if (validFeatureHashKeys == null) {
+ validFeatureHashKeys = validFeatureHashKeysBCast.value.toSet
+ }
+
+ validFeatureHashKeys
+ }
+
+ def collectDistinctFeatureHashes(ss: SparkSession,
+ filteredDict: DataFrame): Array[Long] = {
+ import ss.implicits._
+
+ val featureHashUDF = udf((dim: String, value: String) => WalLogAgg.toFeatureHash(dim, value))
+
+ filteredDict.withColumn("featureHash", featureHashUDF(col("dim"), col("value")))
+ .select("featureHash")
+ .distinct().as[Long].collect()
+ }
+
+ def filterTopKsPerDim(dict: DataFrame,
+ maxRankPerDim: Broadcast[Map[String, Int]],
+ defaultMaxRank: Int): DataFrame = {
+ val filterUDF = udf((dim: String, rank: Long) => {
+ rank < maxRankPerDim.value.getOrElse(dim, defaultMaxRank)
+ })
+
+ dict.filter(filterUDF(col("dim"), col("rank")))
+ }
+
+ def filterWalLogAgg(ss: SparkSession,
+ walLogAgg: Dataset[WalLogAgg],
+ transformers: Seq[Transformer],
+ validFeatureHashKeysBCast: Broadcast[Array[Long]]) = {
+ import ss.implicits._
+ walLogAgg.mapPartitions { iter =>
+ val validFeatureHashKeys = getValidFeatureHashKeys(validFeatureHashKeysBCast)
+
+ iter.map { walLogAgg =>
+ WalLogAgg.filterProps(walLogAgg, transformers, validFeatureHashKeys)
+ }
+ }
+ }
+}
+
+class FilterTopFeaturesProcess(taskConf: TaskConf) extends org.apache.s2graph.s2jobs.task.Process(taskConf) {
+
+ import FilterTopFeaturesProcess._
+
+ /*
+ filter topKs per dim, then build valid dimValLs.
+ then broadcast valid dimValLs to original dataframe, and filter out not valid dimVal.
+ */
+ override def execute(ss: SparkSession, inputMap: Map[String, DataFrame]): DataFrame = {
+ import ss.implicits._
+
+ val maxRankPerDim = taskConf.options.get("maxRankPerDim").map { s =>
+ Json.parse(s).as[JsObject].fields.map { case (k, jsValue) =>
+ k -> jsValue.as[Int]
+ }.toMap
+ }
+ val maxRankPerDimBCast = ss.sparkContext.broadcast(maxRankPerDim.getOrElse(Map.empty))
+
+ val defaultMaxRank = taskConf.options.get("defaultMaxRank").map(_.toInt)
+
+ val featureDict = inputMap(taskConf.options("featureDict"))
+ val walLogAgg = inputMap(taskConf.options("walLogAgg")).as[WalLogAgg]
+
+ val transformers = TaskConf.parseTransformers(taskConf)
+
+ val filteredDict = filterTopKsPerDim(featureDict, maxRankPerDimBCast, defaultMaxRank.getOrElse(Int.MaxValue))
+ val validFeatureHashKeys = collectDistinctFeatureHashes(ss, filteredDict)
+ val validFeatureHashKeysBCast = ss.sparkContext.broadcast(validFeatureHashKeys)
+
+ filterWalLogAgg(ss, walLogAgg, transformers, validFeatureHashKeysBCast).toDF()
+ }
+
+ override def mandatoryOptions: Set[String] = Set("featureDict", "walLogAgg")
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcess.scala
new file mode 100644
index 0000000..e4aa4e1
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcess.scala
@@ -0,0 +1,60 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.process.params.AggregateParam
+import org.apache.s2graph.s2jobs.wal.{WalLog, WalLogAgg}
+import org.apache.spark.sql._
+
+object WalLogAggregateProcess {
+ def aggregate(ss: SparkSession,
+ dataset: Dataset[WalLogAgg],
+ aggregateParam: AggregateParam)(implicit ord: Ordering[WalLog]) = {
+ import ss.implicits._
+ dataset.groupByKey(_.from).flatMapGroups { case (_, iter) =>
+ WalLogAgg.merge(iter, aggregateParam)
+ }.toDF(WalLogAgg.outputColumns: _*)
+ }
+
+ def aggregateRaw(ss: SparkSession,
+ dataset: Dataset[WalLog],
+ aggregateParam: AggregateParam)(implicit ord: Ordering[WalLog]): DataFrame = {
+ import ss.implicits._
+
+ dataset.groupByKey(walLog => walLog.from).flatMapGroups { case (key, iter) =>
+ WalLogAgg.mergeWalLogs(iter, aggregateParam)
+ }.toDF(WalLogAgg.outputColumns: _*)
+ }
+}
+
+
+/**
+ * Expects a DataFrame of WalLog, then groups WalLog by groupByKeys (default: from).
+ * Produces a DataFrame of WalLogAgg, which abstracts a session as a sequence of WalLog ordered by timestamp (desc).
+ *
+ * One way to visualize this is as transforming (row, column, value) matrix entries into (row, Sparse Vector(column:value)).
+ * Note that we only keep track of at most topK latest walLogs per each groupByKeys.
+ */
+class WalLogAggregateProcess(taskConf: TaskConf) extends org.apache.s2graph.s2jobs.task.Process(taskConf) {
+
+ import WalLogAggregateProcess._
+
+ override def execute(ss: SparkSession, inputMap: Map[String, DataFrame]): DataFrame = {
+ import ss.implicits._
+
+ //TODO: Current implementation only expect taskConf.options as Map[String, String].
+ //Once we change taskConf.options to JsObject, then we can simply parse the input parameter as follows.
+ //implicit val paramReads = Json.reads[AggregateParam]
+ val param = AggregateParam.fromTaskConf(taskConf)
+ param.numOfPartitions.foreach(d => ss.sqlContext.setConf("spark.sql.shuffle.partitions", d.toString))
+
+ implicit val ord = WalLog.orderByTsAsc
+ val walLogs = taskConf.inputs.tail.foldLeft(inputMap(taskConf.inputs.head)) { case (prev, cur) =>
+ prev.union(inputMap(cur))
+ }
+
+ if (param.arrayType) aggregate(ss, walLogs.as[WalLogAgg], param)
+ else aggregateRaw(ss, walLogs.as[WalLog], param)
+ }
+
+ override def mandatoryOptions: Set[String] = Set.empty
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/AggregateParam.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/AggregateParam.scala
new file mode 100644
index 0000000..523ec8e
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/AggregateParam.scala
@@ -0,0 +1,46 @@
+package org.apache.s2graph.s2jobs.wal.process.params
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+
+object AggregateParam {
+ val defaultGroupByKeys = Seq("from")
+ val defaultTopK = 1000
+ val defaultIsArrayType = false
+ val defaultShouldSortTopItems = true
+
+ def fromTaskConf(taskConf: TaskConf): AggregateParam = {
+ val groupByKeys = taskConf.options.get("groupByKeys").map(_.split(",").filter(_.nonEmpty).toSeq)
+ val maxNumOfEdges = taskConf.options.get("maxNumOfEdges").map(_.toInt).getOrElse(defaultTopK)
+ val arrayType = taskConf.options.get("arrayType").map(_.toBoolean).getOrElse(defaultIsArrayType)
+ val sortTopItems = taskConf.options.get("sortTopItems").map(_.toBoolean).getOrElse(defaultShouldSortTopItems)
+ val numOfPartitions = taskConf.options.get("numOfPartitions").map(_.toInt)
+ val validTimestampDuration = taskConf.options.get("validTimestampDuration").map(_.toLong).getOrElse(Long.MaxValue)
+ val nowOpt = taskConf.options.get("now").map(_.toLong)
+
+ new AggregateParam(groupByKeys = groupByKeys,
+ topK = Option(maxNumOfEdges),
+ isArrayType = Option(arrayType),
+ shouldSortTopItems = Option(sortTopItems),
+ numOfPartitions = numOfPartitions,
+ validTimestampDuration = Option(validTimestampDuration),
+ nowOpt = nowOpt
+ )
+ }
+}
+
+case class AggregateParam(groupByKeys: Option[Seq[String]],
+ topK: Option[Int],
+ isArrayType: Option[Boolean],
+ shouldSortTopItems: Option[Boolean],
+ numOfPartitions: Option[Int],
+ validTimestampDuration: Option[Long],
+ nowOpt: Option[Long]) {
+
+ import AggregateParam._
+
+ val groupByColumns = groupByKeys.getOrElse(defaultGroupByKeys)
+ val heapSize = topK.getOrElse(defaultTopK)
+ val arrayType = isArrayType.getOrElse(defaultIsArrayType)
+ val sortTopItems = shouldSortTopItems.getOrElse(defaultShouldSortTopItems)
+ val now = nowOpt.getOrElse(System.currentTimeMillis())
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/BuildTopFeaturesParam.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/BuildTopFeaturesParam.scala
new file mode 100644
index 0000000..3e8bae5
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/BuildTopFeaturesParam.scala
@@ -0,0 +1,17 @@
+package org.apache.s2graph.s2jobs.wal.process.params
+
+object BuildTopFeaturesParam {
+ val defaultMinUserCount = 0L
+ val defaultCountColumnName = "from"
+}
+
+case class BuildTopFeaturesParam(minUserCount: Option[Long],
+ countColumnName: Option[String],
+ samplePointsPerPartitionHint: Option[Int],
+ numOfPartitions: Option[Int]) {
+
+ import BuildTopFeaturesParam._
+
+ val _countColumnName = countColumnName.getOrElse(defaultCountColumnName)
+ val _minUserCount = minUserCount.getOrElse(defaultMinUserCount)
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/FilterTopFeaturesParam.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/FilterTopFeaturesParam.scala
new file mode 100644
index 0000000..3b9d868
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/process/params/FilterTopFeaturesParam.scala
@@ -0,0 +1,7 @@
+package org.apache.s2graph.s2jobs.wal.process.params
+
+class FilterTopFeaturesParam(maxRankPerDim: Option[Map[String, Int]],
+ defaultMaxRank: Option[Int]) {
+ val _maxRankPerDim = maxRankPerDim.getOrElse(Map.empty)
+ val _defaultMaxRank = defaultMaxRank.getOrElse(Int.MaxValue)
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/DefaultTransformer.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/DefaultTransformer.scala
new file mode 100644
index 0000000..de328e5
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/DefaultTransformer.scala
@@ -0,0 +1,5 @@
+package org.apache.s2graph.s2jobs.wal.transformer
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+
+case class DefaultTransformer(taskConf: TaskConf) extends Transformer(taskConf)
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractDomain.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractDomain.scala
new file mode 100644
index 0000000..45bbe66
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractDomain.scala
@@ -0,0 +1,23 @@
+package org.apache.s2graph.s2jobs.wal.transformer
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.utils.UrlUtils
+import org.apache.s2graph.s2jobs.wal.{DimVal, WalLog}
+import play.api.libs.json.Json
+
+case class ExtractDomain(taskConf: TaskConf) extends Transformer(taskConf) {
+ val urlDimensions = Json.parse(taskConf.options.getOrElse("urlDimensions", "[]")).as[Set[String]]
+ val hostDimName = taskConf.options.getOrElse("hostDimName", "host")
+ val domainDimName= taskConf.options.getOrElse("domainDimName", "domain")
+ val keywordDimName = taskConf.options.getOrElse("keywordDimName", "uri_keywords")
+ override def toDimValLs(walLog: WalLog, propertyKey: String, propertyValue: String): Seq[DimVal] = {
+ if (!urlDimensions(propertyKey)) Nil
+ else {
+ val (_, domains, kwdOpt) = UrlUtils.extract(propertyValue)
+
+ domains.headOption.toSeq.map(DimVal(hostDimName, _)) ++
+ domains.map(DimVal(domainDimName, _)) ++
+ kwdOpt.toSeq.map(DimVal(keywordDimName, _))
+ }
+ }
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractServiceName.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractServiceName.scala
new file mode 100644
index 0000000..15efe9f
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/ExtractServiceName.scala
@@ -0,0 +1,24 @@
+package org.apache.s2graph.s2jobs.wal.transformer
+
+import org.apache.s2graph.core.JSONParser
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.{DimVal, WalLog}
+import play.api.libs.json.{JsObject, Json}
+
+class ExtractServiceName(taskConf: TaskConf) extends Transformer(taskConf) {
+ val serviceDims = Json.parse(taskConf.options.getOrElse("serviceDims", "[]")).as[Set[String]]
+ val domainServiceMap = Json.parse(taskConf.options.getOrElse("domainServiceMap", "{}")).as[JsObject].fields.map { case (k, v) =>
+ k -> JSONParser.jsValueToString(v)
+ }.toMap
+ val serviceDimName = taskConf.options.getOrElse("serviceDimName", "serviceDimName")
+
+ override def toDimValLs(walLog: WalLog, propertyKey: String, propertyValue: String): Seq[DimVal] = {
+ if (!serviceDims(propertyKey)) Nil
+ else {
+ val serviceName = domainServiceMap.getOrElse(propertyValue, propertyValue)
+
+ Seq(DimVal(serviceDimName, serviceName))
+ }
+ }
+}
+
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/Transformer.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/Transformer.scala
new file mode 100644
index 0000000..68be2ad
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/transformer/Transformer.scala
@@ -0,0 +1,16 @@
+package org.apache.s2graph.s2jobs.wal.transformer
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.{DimVal, WalLog}
+
+/**
+ * decide how to transform walLog's each property key value to Seq[DimVal]
+ */
+abstract class Transformer(taskConf: TaskConf) extends Serializable {
+ def toDimValLs(walLog: WalLog, propertyKey: String, propertyValue: String): Seq[DimVal] = {
+ val dim = s"${walLog.label}:${propertyKey}"
+ val value = propertyValue
+
+ Seq(DimVal(dim, value))
+ }
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAF.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAF.scala
new file mode 100644
index 0000000..81a1356
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAF.scala
@@ -0,0 +1,290 @@
+package org.apache.s2graph.s2jobs.wal.udafs
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types._
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+object WalLogUDAF {
+ type Element = (Long, String, String, String)
+
+ val emptyRow = new GenericRow(Array(-1L, "empty", "empty", "empty"))
+
+ val elementOrd = Ordering.by[Element, Long](_._1)
+
+ val rowOrdering = new Ordering[Row] {
+ override def compare(x: Row, y: Row): Int = {
+ x.getAs[Long](0).compareTo(y.getAs[Long](0))
+ }
+ }
+
+ val rowOrderingDesc = new Ordering[Row] {
+ override def compare(x: Row, y: Row): Int = {
+ -x.getAs[Long](0).compareTo(y.getAs[Long](0))
+ }
+ }
+
+ val fields = Seq(
+ StructField(name = "timestamp", LongType),
+ StructField(name = "to", StringType),
+ StructField(name = "label", StringType),
+ StructField(name = "props", StringType)
+ )
+
+ val arrayType = ArrayType(elementType = StructType(fields))
+
+ def apply(maxNumOfEdges: Int = 1000): GroupByAggOptimized = {
+ new GroupByAggOptimized(maxNumOfEdges)
+ }
+
+ def swap[T](array: mutable.Seq[T], i: Int, j: Int) = {
+ val tmp = array(i)
+ array(i) = array(j)
+ array(j) = tmp
+ }
+
+ @tailrec
+ def percolateDown[T](array: mutable.Seq[T], idx: Int)(implicit ordering: Ordering[T]): Unit = {
+ val left = 2 * idx + 1
+ val right = 2 * idx + 2
+ var smallest = idx
+
+ if (left < array.size && ordering.compare(array(left), array(smallest)) < 0) {
+ smallest = left
+ }
+
+ if (right < array.size && ordering.compare(array(right), array(smallest)) < 0) {
+ smallest = right
+ }
+
+ if (smallest != idx) {
+ swap(array, idx, smallest)
+ percolateDown(array, smallest)
+ }
+ }
+
+ def percolateUp[T](array: mutable.Seq[T],
+ idx: Int)(implicit ordering: Ordering[T]): Unit = {
+ var pos = idx
+ var parent = (pos - 1) / 2
+ while (parent >= 0 && ordering.compare(array(pos), array(parent)) < 0) {
+ // swap pos and parent, since array(parent) > array(pos)
+ swap(array, parent, pos)
+ pos = parent
+ parent = (pos - 1) / 2
+ }
+ }
+
+ def addToTopK[T](array: mutable.Seq[T],
+ size: Int,
+ newData: T)(implicit ordering: Ordering[T]): mutable.Seq[T] = {
+ // use array as minHeap to keep track of topK.
+ // parent = (i -1) / 2
+ // left child = 2 * i + 1
+ // right child = 2 * i + 2
+
+ // check if array is already full.
+ if (array.size >= size) {
+ // compare newData to min. newData < array(0)
+ val currentMin = array(0)
+ if (ordering.compare(newData, currentMin) < 0) {
+ // drop newData
+ } else {
+ // delete min
+ array(0) = newData
+ // percolate down
+ percolateDown(array, 0)
+ }
+ array
+ } else {
+ // append new element into the sequence since there is room left.
+ val newArray = array :+ newData
+ val idx = newArray.size - 1
+ // percolate up last element
+ percolateUp(newArray, idx)
+ newArray
+ }
+ }
+
+ def mergeTwoSeq[T](prev: Seq[T], cur: Seq[T], size: Int)(implicit ordering: Ordering[T]): Seq[T] = {
+ import scala.collection.mutable
+ val (n, m) = (cur.size, prev.size)
+
+ var (i, j) = (0, 0)
+ var idx = 0
+ val arr = new mutable.ArrayBuffer[T](size)
+
+ while (idx < size && i < n && j < m) {
+ if (ordering.compare(cur(i), prev(j)) < 0) {
+ arr += cur(i)
+ i += 1
+ } else {
+ arr += prev(j)
+ j += 1
+ }
+ idx += 1
+ }
+ while (idx < size && i < n) {
+ arr += cur(i)
+ i += 1
+ }
+ while (idx < size && j < m) {
+ arr += prev(j)
+ j += 1
+ }
+
+ arr
+ }
+}
+
+class GroupByAggOptimized(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {
+
+ import WalLogUDAF._
+
+ implicit val ord = rowOrdering
+
+ val arrayType = ArrayType(elementType = StructType(fields))
+
+ type ROWS = mutable.Seq[Row]
+
+ override def inputSchema: StructType = StructType(fields)
+
+ override def bufferSchema: StructType = StructType(Seq(
+ StructField(name = "edges", dataType = arrayType)
+ ))
+
+ override def dataType: DataType = arrayType
+
+ override def deterministic: Boolean = true
+
+ override def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer.update(0, mutable.ArrayBuffer.empty[Row])
+ }
+
+ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ val prev = buffer.getAs[ROWS](0)
+
+ val updated = addToTopK(prev, maxNumOfEdges, input)
+
+ buffer.update(0, updated)
+ }
+
+ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ var prev = buffer1.getAs[ROWS](0)
+ val cur = buffer2.getAs[ROWS](0)
+
+ cur.filter(_ != null).foreach { row =>
+ prev = addToTopK(prev, maxNumOfEdges, row)
+ }
+
+ buffer1.update(0, prev)
+ }
+
+ override def evaluate(buffer: Row): Any = {
+ val ls = buffer.getAs[ROWS](0)
+ takeTopK(ls, maxNumOfEdges)
+ }
+
+ private def takeTopK(ls: Seq[Row], k: Int) = {
+ val sorted = ls.sorted
+ if (sorted.size <= k) sorted else sorted.take(k)
+ }
+}
+
+class GroupByAgg(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {
+ import WalLogUDAF._
+
+ implicit val ord = rowOrderingDesc
+
+ val arrayType = ArrayType(elementType = StructType(fields))
+
+ override def inputSchema: StructType = StructType(fields)
+
+ override def bufferSchema: StructType = StructType(Seq(
+ StructField(name = "edges", dataType = arrayType),
+ StructField(name = "buffered", dataType = BooleanType)
+ ))
+
+ override def dataType: DataType = arrayType
+
+ override def deterministic: Boolean = true
+
+ override def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer.update(0, scala.collection.mutable.ListBuffer.empty[Element])
+ }
+
+ /* not optimized */
+ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ val element = input
+
+ val prev = buffer.getAs[Seq[Row]](0)
+ val appended = prev :+ element
+
+ buffer.update(0, appended)
+ buffer.update(1, false)
+ }
+
+ private def takeTopK(ls: Seq[Row], k: Int) = {
+ val sorted = ls.sorted
+ if (sorted.size <= k) sorted else sorted.take(k)
+ }
+ /* not optimized */
+ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ val cur = buffer2.getAs[Seq[Row]](0)
+ val prev = buffer1.getAs[Seq[Row]](0)
+
+ buffer1.update(0, takeTopK(prev ++ cur, maxNumOfEdges))
+ buffer1.update(1, true)
+ }
+
+ override def evaluate(buffer: Row): Any = {
+ val ls = buffer.getAs[Seq[Row]](0)
+ val buffered = buffer.getAs[Boolean](1)
+ if (buffered) ls
+ else takeTopK(ls, maxNumOfEdges)
+ }
+}
+
+class GroupByArrayAgg(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {
+ import WalLogUDAF._
+
+ implicit val ord = rowOrdering
+
+ import scala.collection.mutable
+
+ override def inputSchema: StructType = StructType(Seq(
+ StructField(name = "edges", dataType = arrayType)
+ ))
+
+ override def bufferSchema: StructType = StructType(Seq(
+ StructField(name = "edges", dataType = arrayType)
+ ))
+
+ override def dataType: DataType = arrayType
+
+ override def deterministic: Boolean = true
+
+ override def initialize(buffer: MutableAggregationBuffer): Unit =
+ buffer.update(0, mutable.ListBuffer.empty[Row])
+
+ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ val cur = input.getAs[Seq[Row]](0)
+ val prev = buffer.getAs[Seq[Row]](0)
+ val merged = mergeTwoSeq(cur, prev, maxNumOfEdges)
+
+ buffer.update(0, merged)
+ }
+
+ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ val cur = buffer2.getAs[Seq[Row]](0)
+ val prev = buffer1.getAs[Seq[Row]](0)
+
+ val merged = mergeTwoSeq(cur, prev, maxNumOfEdges)
+ buffer1.update(0, merged)
+ }
+
+ override def evaluate(buffer: Row): Any = buffer.getAs[Seq[Row]](0)
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udfs/WalLogUDF.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udfs/WalLogUDF.scala
new file mode 100644
index 0000000..d213d6c
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udfs/WalLogUDF.scala
@@ -0,0 +1,210 @@
+package org.apache.s2graph.s2jobs.wal.udfs
+
+import com.google.common.hash.Hashing
+import org.apache.s2graph.core.JSONParser
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{Dataset, Row}
+import play.api.libs.json._
+
+import scala.reflect.ClassTag
+
+object WalLogUDF {
+
+ import scala.collection.mutable
+
+ type MergedProps = Map[String, Seq[String]]
+ type MutableMergedProps = mutable.Map[String, mutable.Map[String, Int]]
+ type MutableMergedPropsInner = mutable.Map[String, Int]
+
+ def initMutableMergedPropsInner = mutable.Map.empty[String, Int]
+
+ def initMutableMergedProps = mutable.Map.empty[String, mutable.Map[String, Int]]
+
+// //TODO:
+// def toDimension(rawActivity: RawActivity, propertyKey: String): String = {
+// // val (ts, dst, label, _) = rawActivity
+// // label + "." + propertyKey
+// propertyKey
+// }
+//
+// def updateMutableMergedProps(mutableMergedProps: MutableMergedProps)(dimension: String,
+// dimensionValue: String,
+// count: Int = 1): Unit = {
+// val buffer = mutableMergedProps.getOrElseUpdate(dimension, initMutableMergedPropsInner)
+// val newCount = buffer.getOrElse(dimensionValue, 0) + count
+// buffer += (dimensionValue -> newCount)
+// }
+//
+// def groupByDimensionValues(rawActivity: RawActivity,
+// propsJson: JsObject,
+// mergedProps: MutableMergedProps,
+// toDimensionFunc: (RawActivity, String) => String,
+// excludePropKeys: Set[String] = Set.empty): Unit = {
+// propsJson.fields.filter(t => !excludePropKeys(t._1)).foreach { case (propertyKey, jsValue) =>
+// val values = jsValue match {
+// case JsString(s) => Seq(s)
+// case JsArray(arr) => arr.map(JSONParser.jsValueToString)
+// case _ => Seq(jsValue.toString())
+// }
+// val dimension = toDimensionFunc(rawActivity, propertyKey)
+//
+// values.foreach { value =>
+// updateMutableMergedProps(mergedProps)(dimension, value)
+// }
+// }
+// }
+//
+// def buildMergedProps(rawActivities: Seq[RawActivity],
+// toDimensionFunc: (RawActivity, String) => String,
+// defaultTopKs: Int = 100,
+// dimTopKs: Map[String, Int] = Map.empty,
+// excludePropKeys: Set[String] = Set.empty,
+// dimValExtractors: Seq[Extractor] = Nil): MergedProps = {
+// val mergedProps = initMutableMergedProps
+//
+// rawActivities.foreach { case rawActivity@(_, _, _, rawProps) =>
+// val propsJson = Json.parse(rawProps).as[JsObject]
+// groupByDimensionValues(rawActivity, propsJson, mergedProps, toDimensionFunc, excludePropKeys)
+// }
+// // work on extra dimVals.
+// dimValExtractors.foreach { extractor =>
+// extractor.extract(rawActivities, mergedProps)
+// }
+//
+// mergedProps.map { case (key, values) =>
+// val topK = dimTopKs.getOrElse(key, defaultTopKs)
+//
+// key -> values.toSeq.sortBy(-_._2).take(topK).map(_._1)
+// }.toMap
+// }
+//
+// def rowToRawActivity(row: Row): RawActivity = {
+// (row.getAs[Long](0), row.getAs[String](1), row.getAs[String](2), row.getAs[String](3))
+// }
+//
+// def appendMergeProps(toDimensionFunc: (RawActivity, String) => String = toDimension,
+// defaultTopKs: Int = 100,
+// dimTopKs: Map[String, Int] = Map.empty,
+// excludePropKeys: Set[String] = Set.empty,
+// dimValExtractors: Seq[Extractor] = Nil,
+// minTs: Long = 0,
+// maxTs: Long = Long.MaxValue) = udf((acts: Seq[Row]) => {
+// val rows = acts.map(rowToRawActivity).filter(act => act._1 >= minTs && act._1 < maxTs)
+//
+// buildMergedProps(rows, toDimensionFunc, defaultTopKs, dimTopKs, excludePropKeys, dimValExtractors)
+// })
+
+ val extractDimensionValues = {
+ udf((dimensionValues: Map[String, Seq[String]]) => {
+ dimensionValues.toSeq.flatMap { case (dimension, values) =>
+ values.map { value => dimension -> value }
+ }
+ })
+ }
+
+ def toHash(dimension: String, dimensionValue: String): Long = {
+ val key = s"$dimension.$dimensionValue"
+ Hashing.murmur3_128().hashBytes(key.toString.getBytes("UTF-8")).asLong()
+ }
+
+ def filterDimensionValues(validDimValues: Broadcast[Set[Long]]) = {
+ udf((dimensionValues: Map[String, Seq[String]]) => {
+ dimensionValues.map { case (dimension, values) =>
+ val filtered = values.filter { value =>
+ val hash = toHash(dimension, value)
+
+ validDimValues.value(hash)
+ }
+
+ dimension -> filtered
+ }
+ })
+ }
+
+ def appendRank[K1: ClassTag, K2: ClassTag, V: ClassTag](ds: Dataset[((K1, K2), V)],
+ numOfPartitions: Option[Int] = None,
+ samplePointsPerPartitionHint: Option[Int] = None)(implicit ordering: Ordering[(K1, K2)]) = {
+ import org.apache.spark.RangePartitioner
+ val rdd = ds.rdd
+
+ val partitioner = new RangePartitioner(numOfPartitions.getOrElse(rdd.partitions.size),
+ rdd,
+ true,
+ samplePointsPerPartitionHint = samplePointsPerPartitionHint.getOrElse(20)
+ )
+
+ val sorted = rdd.repartitionAndSortWithinPartitions(partitioner)
+
+ def rank(idx: Int, iter: Iterator[((K1, K2), V)]) = {
+ var curOffset = 1L
+ var curK1 = null.asInstanceOf[K1]
+
+ iter.map{ case ((key1, key2), value) =>
+ // println(s">>>[$idx] curK1: $curK1, curOffset: $curOffset")
+ val newOffset = if (curK1 == key1) curOffset + 1L else 1L
+ curOffset = newOffset
+ curK1 = key1
+ (idx, newOffset, key1, key2, value)
+ }
+ }
+
+ def getOffset(idx: Int, iter: Iterator[((K1, K2), V)]) = {
+ val buffer = mutable.Map.empty[K1, (Int, Long)]
+ if (!iter.hasNext) buffer.toIterator
+ else {
+ val ((k1, k2), v) = iter.next()
+ var prevKey1: K1 = k1
+ var size = 1L
+ iter.foreach { case ((k1, k2), v) =>
+ if (prevKey1 != k1) {
+ buffer += prevKey1 -> (idx, size)
+ prevKey1 = k1
+ size = 0L
+ }
+ size += 1L
+ }
+ if (size > 0) buffer += prevKey1 -> (idx, size)
+ buffer.iterator
+ }
+ }
+
+ val partRanks = sorted.mapPartitionsWithIndex(rank)
+ val _offsets = sorted.mapPartitionsWithIndex(getOffset)
+ val offsets = _offsets.groupBy(_._1).flatMap { case (k1, partitionWithSize) =>
+ val ls = partitionWithSize.toSeq.map(_._2).sortBy(_._1)
+ var sum = ls.head._2
+ val lss = ls.tail.map { case (partition, size) =>
+ val x = (partition, sum)
+ sum += size
+ x
+ }
+ lss.map { case (partition, offset) =>
+ (k1, partition) -> offset
+ }
+ }.collect()
+
+ println(offsets)
+
+ val offsetsBCast = ds.sparkSession.sparkContext.broadcast(offsets)
+
+ def adjust(iter: Iterator[(Int, Long, K1, K2, V)], startOffsets: Map[(K1, Int), Long]) = {
+ iter.map { case (partition, rankInPartition, key1, key2, value) =>
+ val startOffset = startOffsets.getOrElse((key1, partition), 0L)
+ val rank = startOffset + rankInPartition
+
+ (partition, rankInPartition, rank, (key1, key2), value)
+ }
+ }
+
+ val withRanks = partRanks
+ .mapPartitions { iter =>
+ val startOffsets = offsetsBCast.value.toMap
+ adjust(iter, startOffsets)
+ }.map { case (_, _, rank, (key1, key2), value) =>
+ (rank, (key1, key2), value)
+ }
+
+ withRanks
+ }
+}
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/BoundedPriorityQueue.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/BoundedPriorityQueue.scala
new file mode 100644
index 0000000..c146452
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/BoundedPriorityQueue.scala
@@ -0,0 +1,52 @@
+package org.apache.s2graph.s2jobs.wal.utils
+
+import java.util.{PriorityQueue => JPriorityQueue}
+
+import scala.collection.JavaConverters._
+import scala.collection.generic.Growable
+
+/**
+ * copied from org.apache.spark.util.BoundedPriorityQueue since it is package private.
+ * @param maxSize
+ * @param ord
+ * @tparam A
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
+ extends Iterable[A] with Growable[A] with Serializable {
+
+ private val underlying = new JPriorityQueue[A](maxSize, ord)
+
+ override def iterator: Iterator[A] = underlying.iterator.asScala
+
+ override def size: Int = underlying.size
+
+ override def ++=(xs: TraversableOnce[A]): this.type = {
+ xs.foreach { this += _ }
+ this
+ }
+
+ override def +=(elem: A): this.type = {
+ if (size < maxSize) {
+ underlying.offer(elem)
+ } else {
+ maybeReplaceLowest(elem)
+ }
+ this
+ }
+
+ override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+ this += elem1 += elem2 ++= elems
+ }
+
+ override def clear() { underlying.clear() }
+
+ private def maybeReplaceLowest(a: A): Boolean = {
+ val head = underlying.peek()
+ if (head != null && ord.gt(a, head)) {
+ underlying.poll()
+ underlying.offer(a)
+ } else {
+ false
+ }
+ }
+}
\ No newline at end of file
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/UrlUtils.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/UrlUtils.scala
new file mode 100644
index 0000000..2941357
--- /dev/null
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/utils/UrlUtils.scala
@@ -0,0 +1,57 @@
+package org.apache.s2graph.s2jobs.wal.utils
+
+import java.net.{URI, URLDecoder}
+
+import scala.util.matching.Regex
+
+object UrlUtils {
+ val pattern = new Regex("""(\\x[0-9A-Fa-f]{2}){3}""")
+ val koreanPattern = new scala.util.matching.Regex("([가-힣]+[\\-_a-zA-Z 0-9]*)+|([\\-_a-zA-Z 0-9]+[가-힣]+)")
+
+
+ // url extraction functions
+ def urlDecode(url: String): (Boolean, String) = {
+ try {
+ val decoded = URLDecoder.decode(url, "UTF-8")
+ (url != decoded, decoded)
+ } catch {
+ case e: Exception => (false, url)
+ }
+ }
+
+ def hex2String(url: String): String = {
+ pattern replaceAllIn(url, m => {
+ new String(m.toString.replaceAll("[^0-9A-Fa-f]", "").sliding(2, 2).toArray.map(Integer.parseInt(_, 16).toByte), "utf-8")
+ })
+ }
+
+ def toDomains(url: String, maxDepth: Int = 3): Seq[String] = {
+ val uri = new URI(url)
+ val domain = uri.getHost
+ if (domain == null) Nil
+ else {
+ val paths = uri.getPath.split("/")
+ if (paths.isEmpty) Seq(domain)
+ else {
+ val depth = Math.min(maxDepth, paths.size)
+ (1 to depth).map { ith =>
+ domain + paths.take(ith).mkString("/")
+ }
+ }
+ }
+ }
+
+ def extract(_url: String): (String, Seq[String], Option[String]) = {
+ try {
+ val url = hex2String(_url)
+ val (encoded, decodedUrl) = urlDecode(url)
+
+ val kwdOpt = koreanPattern.findAllMatchIn(decodedUrl).toList.map(_.group(0)).headOption.map(_.replaceAll("\\s", ""))
+ val domains = toDomains(url.replaceAll(" ", ""))
+ (decodedUrl, domains, kwdOpt)
+ } catch {
+ case e: Exception => (_url, Nil, None)
+ }
+ }
+}
+
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/ProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/ProcessTest.scala
index d9bbb5b..5539e43 100644
--- a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/ProcessTest.scala
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/ProcessTest.scala
@@ -26,7 +26,7 @@
test("SqlProcess execute sql") {
import spark.implicits._
-
+
val inputDF = Seq(
("a", "b", "friend"),
("a", "c", "friend"),
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TestData.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TestData.scala
new file mode 100644
index 0000000..fae9265
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TestData.scala
@@ -0,0 +1,42 @@
+package org.apache.s2graph.s2jobs.wal
+
+object TestData {
+ val testServiceName = "s2graph"
+ val walLogsLs = Seq(
+ WalLog(1L, "insert", "edge", "u1", "i1", s"$testServiceName", "click", """{"item_name":"awesome item"}"""),
+ WalLog(2L, "insert", "edge", "u1", "i1", s"$testServiceName", "purchase", """{"price":2}"""),
+ WalLog(3L, "insert", "edge", "u1", "q1", s"$testServiceName", "search", """{"referrer":"www.google.com"}"""),
+ WalLog(4L, "insert", "edge", "u2", "i1", s"$testServiceName", "click", """{"item_name":"awesome item"}"""),
+ WalLog(5L, "insert", "edge", "u2", "q2", s"$testServiceName", "search", """{"referrer":"www.bing.com"}"""),
+ WalLog(6L, "insert", "edge", "u3", "i2", s"$testServiceName", "click", """{"item_name":"bad item"}"""),
+ WalLog(7L, "insert", "edge", "u4", "q1", s"$testServiceName", "search", """{"referrer":"www.google.com"}""")
+ )
+
+ // order by from
+ val aggExpected = Array(
+ WalLogAgg("u1", Seq(
+ WalLog(3L, "insert", "edge", "u1", "q1", s"$testServiceName", "search", """{"referrer":"www.google.com"}"""),
+ WalLog(2L, "insert", "edge", "u1", "i1", s"$testServiceName", "purchase", """{"price":2}"""),
+ WalLog(1L, "insert", "edge", "u1", "i1", s"$testServiceName", "click", """{"item_name":"awesome item"}""")
+ ), 3L, 1L),
+ WalLogAgg("u2", Seq(
+ WalLog(5L, "insert", "edge", "u2", "q2", s"$testServiceName", "search", """{"referrer":"www.bing.com"}"""),
+ WalLog(4L, "insert", "edge", "u2", "i1", s"$testServiceName", "click", """{"item_name":"awesome item"}""")
+ ), 5L, 4L),
+ WalLogAgg("u3", Seq(
+ WalLog(6L, "insert", "edge", "u3", "i2", s"$testServiceName", "click", """{"item_name":"bad item"}""")
+ ), 6L, 6L),
+ WalLogAgg("u4", Seq(
+ WalLog(7L, "insert", "edge", "u4", "q1", s"$testServiceName", "search", """{"referrer":"www.google.com"}""")
+ ), 7L, 7L)
+ )
+
+ // order by dim, rank
+ val featureDictExpected = Array(
+ DimValCountRank(DimVal("click:item_name", "awesome item"), 2, 1),
+ DimValCountRank(DimVal("click:item_name", "bad item"), 1, 2),
+ DimValCountRank(DimVal("purchase:price", "2"), 1, 1),
+ DimValCountRank(DimVal("search:referrer", "www.google.com"), 2, 1),
+ DimValCountRank(DimVal("search:referrer", "www.bing.com"), 1, 2)
+ )
+}
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TransformerTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TransformerTest.scala
new file mode 100644
index 0000000..974002b
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/TransformerTest.scala
@@ -0,0 +1,32 @@
+package org.apache.s2graph.s2jobs.wal
+
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.transformer._
+import org.scalatest.{FunSuite, Matchers}
+import play.api.libs.json.Json
+
+class TransformerTest extends FunSuite with Matchers {
+ val walLog = WalLog(1L, "insert", "edge", "a", "b", "s2graph", "friends", """{"name": 1, "url": "www.google.com"}""")
+
+ test("test default transformer") {
+ val taskConf = TaskConf.Empty
+ val transformer = new DefaultTransformer(taskConf)
+ val dimVals = transformer.toDimValLs(walLog, "name", "1")
+
+ dimVals shouldBe Seq(DimVal("friends:name", "1"))
+ }
+
+ test("test ExtractDomain from URL") {
+ val taskConf = TaskConf.Empty.copy(options =
+ Map("urlDimensions" -> Json.toJson(Seq("url")).toString())
+ )
+ val transformer = new ExtractDomain(taskConf)
+ val dimVals = transformer.toDimValLs(walLog, "url", "http://www.google.com/abc")
+
+ dimVals shouldBe Seq(
+ DimVal("host", "www.google.com"),
+ DimVal("domain", "www.google.com"),
+ DimVal("domain", "www.google.com/abc")
+ )
+ }
+}
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcessTest.scala
new file mode 100644
index 0000000..4d2f079
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/BuildTopFeaturesProcessTest.scala
@@ -0,0 +1,31 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.DimValCountRank
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+class BuildTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {
+
+ import org.apache.s2graph.s2jobs.wal.TestData._
+
+ test("test entire process.") {
+ import spark.implicits._
+ val df = spark.createDataset(aggExpected).toDF()
+
+ val taskConf = new TaskConf(name = "test", `type` = "test", inputs = Seq("input"),
+ options = Map("minUserCount" -> "0")
+ )
+ val job = new BuildTopFeaturesProcess(taskConf)
+
+
+ val inputMap = Map("input" -> df)
+ val featureDicts = job.execute(spark, inputMap)
+ .orderBy("dim", "rank")
+ .map(DimValCountRank.fromRow)
+ .collect()
+
+ featureDicts shouldBe featureDictExpected
+
+ }
+}
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcessTest.scala
new file mode 100644
index 0000000..cd8295a
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/FilterTopFeaturesProcessTest.scala
@@ -0,0 +1,84 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal.transformer.DefaultTransformer
+import org.apache.s2graph.s2jobs.wal.{DimValCountRank, WalLogAgg}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+class FilterTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {
+ import org.apache.s2graph.s2jobs.wal.TestData._
+
+ test("test filterTopKsPerDim.") {
+ import spark.implicits._
+ val featureDf = spark.createDataset(featureDictExpected).map { x =>
+ (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
+ }.toDF("dim", "value", "count", "rank")
+
+ val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])
+
+ // filter nothing because all feature has rank < 10
+ val filtered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)
+
+ val real = filtered.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
+ real.zip(featureDictExpected).foreach { case (real, expected) =>
+ real shouldBe expected
+ }
+ // filter rank >= 2
+ val filtered2 = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 2)
+ val real2 = filtered2.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
+ real2 shouldBe featureDictExpected.filter(_.rank < 2)
+ }
+
+
+ test("test filterWalLogAgg.") {
+ import spark.implicits._
+ val walLogAgg = spark.createDataset(aggExpected)
+ val featureDf = spark.createDataset(featureDictExpected).map { x =>
+ (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
+ }.toDF("dim", "value", "count", "rank")
+ val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])
+
+ val transformers = Seq(DefaultTransformer(TaskConf.Empty))
+ // filter nothing. so input, output should be same.
+ val featureFiltered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)
+ val validFeatureHashKeys = FilterTopFeaturesProcess.collectDistinctFeatureHashes(spark, featureFiltered)
+ val validFeatureHashKeysBCast = spark.sparkContext.broadcast(validFeatureHashKeys)
+ val real = FilterTopFeaturesProcess.filterWalLogAgg(spark, walLogAgg, transformers, validFeatureHashKeysBCast)
+ .collect().sortBy(_.from)
+
+ real.zip(aggExpected).foreach { case (real, expected) =>
+ real shouldBe expected
+ }
+ }
+
+ test("test entire process. filter nothing.") {
+ import spark.implicits._
+ val df = spark.createDataset(aggExpected).toDF()
+ val featureDf = spark.createDataset(featureDictExpected).map { x =>
+ (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
+ }.toDF("dim", "value", "count", "rank")
+
+ val inputKey = "input"
+ val featureDictKey = "feature"
+ // filter nothing since we did not specified maxRankPerDim and defaultMaxRank.
+ val taskConf = new TaskConf(name = "test", `type` = "test",
+ inputs = Seq(inputKey, featureDictKey),
+ options = Map(
+ "featureDict" -> featureDictKey,
+ "walLogAgg" -> inputKey
+ )
+ )
+ val inputMap = Map(inputKey -> df, featureDictKey -> featureDf)
+ val job = new FilterTopFeaturesProcess(taskConf)
+ val filtered = job.execute(spark, inputMap)
+ .orderBy("from")
+ .as[WalLogAgg]
+ .collect()
+
+ filtered.zip(aggExpected).foreach { case (real, expected) =>
+ real shouldBe expected
+ }
+
+ }
+}
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcessTest.scala
new file mode 100644
index 0000000..1bb7426
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/process/WalLogAggregateProcessTest.scala
@@ -0,0 +1,31 @@
+package org.apache.s2graph.s2jobs.wal.process
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.wal._
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+class WalLogAggregateProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {
+ import org.apache.s2graph.s2jobs.wal.TestData._
+
+ test("test entire process") {
+ import spark.sqlContext.implicits._
+
+ val edges = spark.createDataset(walLogsLs).toDF()
+ val processKey = "agg"
+ val inputMap = Map(processKey -> edges)
+
+ val taskConf = new TaskConf(name = "test", `type` = "agg", inputs = Seq(processKey),
+ options = Map("maxNumOfEdges" -> "10")
+ )
+
+ val job = new WalLogAggregateProcess(taskConf = taskConf)
+ val processed = job.execute(spark, inputMap)
+
+ processed.printSchema()
+ processed.orderBy("from").as[WalLogAgg].collect().zip(aggExpected).foreach { case (real, expected) =>
+ real shouldBe expected
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAFTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAFTest.scala
new file mode 100644
index 0000000..aded56d
--- /dev/null
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/wal/udafs/WalLogUDAFTest.scala
@@ -0,0 +1,42 @@
+package org.apache.s2graph.s2jobs.wal.udafs
+
+import org.apache.s2graph.s2jobs.wal.utils.BoundedPriorityQueue
+import org.scalatest._
+
+import scala.collection.mutable
+import scala.util.Random
+
+class WalLogUDAFTest extends FunSuite with Matchers {
+
+ test("mergeTwoSeq") {
+ val prev: Array[Int] = Array(3, 2, 1)
+ val cur: Array[Int] = Array(4, 2, 2)
+
+ val ls = WalLogUDAF.mergeTwoSeq(prev, cur, 10)
+ println(ls.size)
+
+ ls.foreach { x =>
+ println(x)
+ }
+ }
+
+ test("addToTopK test.") {
+ import WalLogUDAF._
+ val numOfTest = 100
+ val numOfNums = 100
+ val maxNum = 10
+
+ (0 until numOfTest).foreach { testNum =>
+ val maxSize = 1 + Random.nextInt(numOfNums)
+ val pq = new BoundedPriorityQueue[Int](maxSize)
+ val arr = (0 until numOfNums).map(x => Random.nextInt(maxNum))
+ var result: mutable.Seq[Int] = mutable.ArrayBuffer.empty[Int]
+
+ arr.foreach { i =>
+ pq += i
+ result = addToTopK(result, maxSize, i)
+ }
+ result.sorted shouldBe pq.toSeq.sorted
+ }
+ }
+}