blob: aebb1cc2940fa3a31038b53f063a8354dc2539cf [file] [log] [blame]
package org.apache.s2graph.s2jobs.wal.process
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.process.params.AggregateParam
import org.apache.s2graph.s2jobs.wal.{WalLog, WalLogAgg}
import org.apache.spark.sql._
/**
 * Helpers that collapse WAL entries sharing the same `from` value by delegating
 * to `WalLogAgg.merge`.
 */
object WalLogAggregateProcess {

  /**
   * Merges already-aggregated entries that share the same `from` value.
   *
   * @param ss             session whose implicits supply the encoders used here
   * @param dataset        aggregated WAL entries to merge
   * @param aggregateParam parameters forwarded unchanged to `WalLogAgg.merge`
   * @param ord            ordering of `WalLog`, kept in implicit scope for `WalLogAgg.merge`
   * @return a DataFrame whose columns are named by `WalLogAgg.outputColumns`
   */
  def aggregate(ss: SparkSession,
                dataset: Dataset[WalLogAgg],
                aggregateParam: AggregateParam)(implicit ord: Ordering[WalLog]): DataFrame = {
    import ss.implicits._

    dataset
      .groupByKey(_.from)
      .flatMapGroups { case (_, aggs) =>
        WalLogAgg.merge(aggs, aggregateParam)
      }
      .toDF(WalLogAgg.outputColumns: _*)
  }

  /**
   * Same as [[aggregate]], but starts from raw `WalLog` entries: each entry is
   * first wrapped with `WalLogAgg(_)` and the group is then merged.
   *
   * @param ss             session whose implicits supply the encoders used here
   * @param dataset        raw WAL entries to aggregate
   * @param aggregateParam parameters forwarded unchanged to `WalLogAgg.merge`
   * @param ord            ordering of `WalLog`, kept in implicit scope for `WalLogAgg.merge`
   * @return a DataFrame whose columns are named by `WalLogAgg.outputColumns`
   */
  def aggregateRaw(ss: SparkSession,
                   dataset: Dataset[WalLog],
                   aggregateParam: AggregateParam)(implicit ord: Ordering[WalLog]): DataFrame = {
    import ss.implicits._

    dataset
      .groupByKey(walLog => walLog.from)
      .flatMapGroups { case (_, walLogs) =>
        // The grouping key is not needed; only the grouped entries are merged.
        WalLogAgg.merge(walLogs.map(WalLogAgg(_)), aggregateParam)
      }
      .toDF(WalLogAgg.outputColumns: _*)
  }
}
/**
 * Expects a DataFrame of WalLog, then groups the WalLogs by groupByKeys (default: from).
 * Produces a DataFrame of WalLogAgg, which abstracts a session consisting of a sequence of WalLogs ordered by timestamp (desc).
 *
 * One way to visualize this is as transforming (row, column, value) matrix entries into (row, SparseVector(column: value)).
 * Note that we only keep track of at most the topK latest WalLogs per groupByKeys.
 */
class WalLogAggregateProcess(taskConf: TaskConf) extends org.apache.s2graph.s2jobs.task.Process(taskConf) {

  import WalLogAggregateProcess._

  /**
   * Unions every DataFrame named in `taskConf.inputs` and aggregates the result.
   *
   * When `param.arrayType` is set, the unioned input is read as `WalLogAgg` and
   * passed to `aggregate`; otherwise it is read as `WalLog` and passed to
   * `aggregateRaw`.
   *
   * @param ss       active Spark session
   * @param inputMap DataFrames keyed by input name; every name in `taskConf.inputs` is looked up here
   * @return the aggregated DataFrame
   * @throws IllegalArgumentException if `taskConf.inputs` is empty
   */
  override def execute(ss: SparkSession, inputMap: Map[String, DataFrame]): DataFrame = {
    import ss.implicits._

    //TODO: Current implementation only expects taskConf.options as Map[String, String].
    //Once we change taskConf.options to JsObject, we can simply parse the input parameter as follows.
    //implicit val paramReads = Json.reads[AggregateParam]
    val param = AggregateParam.fromTaskConf(taskConf)

    // Only override the shuffle partition count when one was configured.
    param.numOfPartitions.foreach(d => ss.sqlContext.setConf("spark.sql.shuffle.partitions", d.toString))

    implicit val ord: Ordering[WalLog] = WalLog.orderByTsAsc

    // Fail with a clear message instead of the opaque "tail of empty list" error.
    val inputNames = taskConf.inputs
    require(inputNames.nonEmpty, "WalLogAggregateProcess requires at least one input in taskConf.inputs")

    val walLogs = inputNames.tail.foldLeft(inputMap(inputNames.head)) { case (prev, cur) =>
      prev.union(inputMap(cur))
    }

    if (param.arrayType) aggregate(ss, walLogs.as[WalLogAgg], param)
    else aggregateRaw(ss, walLogs.as[WalLog], param)
  }

  /** This process declares no mandatory options. */
  override def mandatoryOptions: Set[String] = Set.empty
}