package com.test1

import io.prediction.controller.PDataSource
import io.prediction.controller.EmptyEvaluationInfo
import io.prediction.controller.EmptyActualResult
import io.prediction.controller.Params
import io.prediction.data.storage.Event
import io.prediction.data.storage.Storage
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import grizzled.slf4j.Logger
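
// Parameters for this DataSource: appId selects which app's events to read.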
case class DataSourceParams(appId: Int) extends Params

class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger = Logger[this.type]
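
  // Read "user" entities from the event store and turn each one into a
  // LabeledPoint (label = "plan"; features = encoded gender, age, education).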
  override
  def readTraining(sc: SparkContext): TrainingData = {
    val eventsDb = Storage.getPEvents()
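
    // Encode the categorical "gender" and "education" properties as numeric
    // values; MLlib feature vectors hold Doubles only.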
    val gendersMap = Map("Male" -> 0.0, "Female" -> 1.0)
    val educationMap = Map(
      "No School" -> 0.0,
      "High School" -> 1.0,
      "College" -> 2.0)

    val labeledPoints: RDD[LabeledPoint] = eventsDb.aggregateProperties(
      appId = dsp.appId,
      entityType = "user",
      // only keep entities with these required properties defined
      required = Some(List("plan", "gender", "age", "education")))(sc)
      // aggregateProperties() returns an RDD of
      // (entity ID, aggregated properties) pairs
      .map { case (entityId, properties) =>
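        // Convert each user's aggregated properties into a LabeledPoint:
        // "plan" is the label; gender, age, and education are the features.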
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              gendersMap(properties.get[String]("gender")),
              properties.get[Double]("age"),
              educationMap(properties.get[String]("education"))
            ))
          )
        } catch {
          case e: Exception =>
            logger.error(s"Failed to get properties ${properties} of" +
              s" ${entityId}. Exception: ${e}.")
            throw e
        }
      }.cache()
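
    // Hand the labeled points back together with the encoding maps so that
    // later phases can translate raw property values into feature values.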
    new TrainingData(labeledPoints,
      gendersMap,
      educationMap)
  }
}
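
// Container for the prepared training data: the labeled feature vectors and
// the maps used to encode the categorical properties.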
class TrainingData(
  val labeledPoints: RDD[LabeledPoint],
  val gendersMap: Map[String, Double],
  val educationMap: Map[String, Double]
) extends Serializable