blob: 23faa3d1b4c98541e5874c298fdd2e62a698948e [file] [log] [blame]
package org.template.recommendation
import io.prediction.controller.PDataSource
import io.prediction.controller.EmptyEvaluationInfo
import io.prediction.controller.EmptyActualResult
import io.prediction.controller.Params
import io.prediction.data.storage.Event
import io.prediction.data.storage.Storage
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import grizzled.slf4j.Logger
/** Parameters of the [[DataSource]].
  *
  * @param appId id of the app whose events are read
  */
case class DataSourceParams(
  appId: Int
) extends Params
/** Reads "rate" and "buy" events of an app from the event store and turns
  * them into [[Rating]]s for training.
  *
  * @param dsp data source parameters; `dsp.appId` selects the app whose events are read
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  // @transient keeps the logger out of this object's serialized form;
  // lazy defers its creation until first use.
  @transient lazy val logger = Logger[this.type]

  /** Builds the training data from stored events.
    *
    * Every "rate" event contributes its "rating" property; every "buy" event
    * is treated as a rating of 4.0. An event that cannot be converted is
    * logged and the exception is rethrown.
    *
    * @param sc Spark context passed to the event store query
    * @return training data wrapping the cached RDD of ratings
    */
  override def readTraining(sc: SparkContext): TrainingData = {
    val eventsDb = Storage.getPEvents()
    val eventsRDD: RDD[Event] = eventsDb.find(
      appId = dsp.appId,
      entityType = Some("user"),
      eventNames = Some(List("rate", "buy")), // read "rate" and "buy" events
      // targetEntityType is an optional field of an event.
      targetEntityType = Some(Some("item")))(sc)

    val ratingsRDD: RDD[Rating] = eventsRDD.map { event =>
      try {
        val ratingValue: Double = event.event match {
          case "rate" => event.properties.get[Double]("rating")
          case "buy" => 4.0 // map a buy event to a rating value of 4
          case _ => throw new Exception(s"Unexpected event ${event} is read.")
        }
        // A missing target would otherwise fail with a bare "None.get";
        // keep the exception type but say which event is at fault.
        val itemId: String = event.targetEntityId.getOrElse(
          throw new NoSuchElementException(
            s"Event ${event} has no targetEntityId."))
        // entityId and targetEntityId are Strings
        Rating(event.entityId, itemId, ratingValue)
      } catch {
        case e: Exception =>
          // Log the offending event, then rethrow so the failure is not hidden.
          logger.error(s"Cannot convert ${event} to Rating. Exception: ${e}.")
          throw e
      }
    }.cache()

    new TrainingData(ratingsRDD)
  }
}
/** A single rating given by a user to an item.
  *
  * @param user   id of the rating user
  * @param item   id of the rated item
  * @param rating rating value
  */
case class Rating(user: String, item: String, rating: Double)
/** Training data: the ratings produced by the data source.
  *
  * @param ratings RDD of ratings
  */
class TrainingData(val ratings: RDD[Rating]) extends Serializable {

  /** Summarizes the data as the rating count followed by the first two ratings.
    *
    * Calls `count()` and `take(2)` on the underlying RDD.
    */
  override def toString: String = {
    val total = ratings.count()
    val preview = ratings.take(2).toList
    s"ratings: [${total}] (${preview}...)"
  }
}