blob: eea3ae6c6776b3f4a776dc24bb5364bd578c3435 [file] [log] [blame]
package org.template.recommendation
import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.EmptyActualResult
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.storage.Event
import org.apache.predictionio.data.store.PEventStore
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import grizzled.slf4j.Logger
/**
 * Settings used only when the data source is read for evaluation.
 *
 * @param kFold    number of folds the ratings are split into by `DataSource.readEval`
 * @param queryNum value passed as the second argument of every evaluation `Query`
 */
case class DataSourceEvalParams(
  kFold: Int,
  queryNum: Int
)
/**
 * Parameters of [[DataSource]].
 *
 * @param appName    name of the application whose events are read from the event store
 * @param evalParams evaluation settings; must be defined before `DataSource.readEval`
 *                   is called, and is not read by `DataSource.readTraining`
 */
case class DataSourceParams(
  appName: String,
  evalParams: Option[DataSourceEvalParams]
) extends Params
/**
 * Reads user-to-item "rate" and "buy" events from the event store and exposes
 * them as [[Rating]]s, both for training and for k-fold evaluation.
 *
 * @param dsp application name and optional evaluation settings
 */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, ActualResult] {

  @transient lazy val logger = Logger[this.type]

  /**
   * Loads every "rate" and "buy" event sent by a "user" entity towards an
   * "item" entity and converts each one into a [[Rating]].
   *
   * A "rate" event takes its value from the event's "rating" property; a "buy"
   * event is mapped to a fixed rating of 4.0. An event that cannot be converted
   * is logged and the original exception is rethrown.
   *
   * @param sc Spark context used to query the event store
   * @return the cached RDD of ratings
   */
  def getRatings(sc: SparkContext): RDD[Rating] = {
    val userItemEvents: RDD[Event] = PEventStore.find(
      appName = dsp.appName,
      entityType = Some("user"),
      // Only "rate" and "buy" events are read.
      eventNames = Some(List("rate", "buy")),
      // targetEntityType is an optional field of an event, hence the nested Option.
      targetEntityType = Some(Some("item")))(sc)

    userItemEvents.map { event =>
      try {
        val score: Double = event.event match {
          case "rate" => event.properties.get[Double]("rating")
          case "buy" => 4.0 // a purchase counts as a rating of 4
          case _ => throw new Exception(s"Unexpected event ${event} is read.")
        }
        // Both entityId and targetEntityId are Strings.
        Rating(event.entityId, event.targetEntityId.get, score)
      } catch {
        case e: Exception =>
          logger.error(s"Cannot convert ${event} to Rating. Exception: ${e}.")
          throw e
      }
    }.cache()
  }

  /**
   * Builds the training data from all ratings returned by [[getRatings]].
   *
   * @param sc Spark context used to query the event store
   */
  override
  def readTraining(sc: SparkContext): TrainingData =
    new TrainingData(getRatings(sc))

  /**
   * Builds one (training data, evaluation info, query/actual pairs) triple per fold.
   *
   * Every rating is paired with a unique id; a rating belongs to the test set
   * of fold `i` when `id % kFold == i` and to its training set otherwise. The
   * test ratings of a fold are grouped by user, and each user yields one
   * `Query` together with an `ActualResult` holding that user's test ratings.
   *
   * @param sc Spark context used to query the event store
   * @throws IllegalArgumentException if `dsp.evalParams` is not defined
   */
  override
  def readEval(sc: SparkContext)
  : Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    require(dsp.evalParams.isDefined, "Must specify evalParams")
    val evalParams = dsp.evalParams.get
    val kFold = evalParams.kFold
    val queryNum = evalParams.queryNum

    val indexedRatings: RDD[(Rating, Long)] = getRatings(sc).zipWithUniqueId
    indexedRatings.cache

    (0 until kFold).map { fold =>
      val trainingRatings = indexedRatings
        .filter { case (_, id) => id % kFold != fold }
        .map { case (rating, _) => rating }
      val testingRatings = indexedRatings
        .filter { case (_, id) => id % kFold == fold }
        .map { case (rating, _) => rating }

      val ratingsByUser: RDD[(String, Iterable[Rating])] =
        testingRatings.groupBy(_.user)
      val queriesWithActuals = ratingsByUser.map {
        case (user, userRatings) =>
          (Query(user, queryNum), ActualResult(userRatings.toArray))
      }

      (new TrainingData(trainingRatings),
        new EmptyEvaluationInfo(),
        queriesWithActuals)
    }
  }
}
/**
 * A single rating given by a user to an item.
 *
 * @param user   id of the rating user
 * @param item   id of the rated item
 * @param rating rating value
 */
case class Rating(user: String, item: String, rating: Double)
/**
 * Training input: the ratings to learn from.
 *
 * @param ratings RDD of all training ratings
 */
class TrainingData(val ratings: RDD[Rating]) extends Serializable {

  /**
   * Summary containing the number of ratings and up to two sample ratings.
   * Note that this calls `count()` and `take(2)` on the underlying RDD each
   * time it is invoked.
   */
  override def toString: String =
    s"ratings: [${ratings.count()}] (${ratings.take(2).toList}...)"
}