| package com.test1 |
| |
| import io.prediction.controller.PDataSource |
| import io.prediction.controller.EmptyEvaluationInfo |
| import io.prediction.controller.EmptyActualResult |
| import io.prediction.controller.Params |
| import io.prediction.data.storage.Event |
| import io.prediction.data.storage.Storage |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.SparkContext._ |
| import org.apache.spark.rdd.RDD |
| import org.apache.spark.mllib.regression.LabeledPoint |
| import org.apache.spark.mllib.linalg.Vectors |
| |
| import grizzled.slf4j.Logger |
| |
| case class DataSourceParams(appId: Int) extends Params |
| |
| class DataSource(val dsp: DataSourceParams) |
| extends PDataSource[TrainingData, |
| EmptyEvaluationInfo, Query, EmptyActualResult] { |
| |
| @transient lazy val logger = Logger[this.type] |
| |
| override |
| def readTraining(sc: SparkContext): TrainingData = { |
| val eventsDb = Storage.getPEvents() |
| val gendersMap = Map("Male" -> 0.0, "Female" -> 1.0) |
| val educationMap = Map("No School" -> 0.0,"High School" -> 1.0,"College" -> 2.0) |
| val labeledPoints: RDD[LabeledPoint] = eventsDb.aggregateProperties( |
| appId = dsp.appId, |
| entityType = "user", |
| // only keep entities with these required properties defined |
| required = Some(List("plan", "gender", "age", "education")))(sc) |
| // aggregateProperties() returns RDD pair of |
| // entity ID and its aggregated properties |
| .map { case (entityId, properties) => |
| try { |
| LabeledPoint(properties.get[Double]("plan"), |
| Vectors.dense(Array( |
| gendersMap(properties.get[String]("gender")), |
| properties.get[Double]("age"), |
| educationMap(properties.get[String]("education")) |
| )) |
| ) |
| } catch { |
| case e: Exception => { |
| logger.error(s"Failed to get properties ${properties} of" + |
| s" ${entityId}. Exception: ${e}.") |
| throw e |
| } |
| } |
| }.cache() |
| |
| new TrainingData(labeledPoints, |
| gendersMap, |
| educationMap) |
| } |
| } |
| |
| class TrainingData( |
| val labeledPoints: RDD[LabeledPoint], |
| val gendersMap: Map[String,Double], |
| val educationMap: Map[String,Double] |
| ) extends Serializable |