package org.example.classification

import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.store.PEventStore

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

import grizzled.slf4j.Logger

case class DataSourceParams(
  appName: String,
  evalK: Option[Int] // number of folds for k-fold evaluation; must be set for readEval.
) extends Params
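
// For reference, these parameters are supplied through the engine's
// configuration (conventionally engine.json). The values below are
// illustrative only; appName must match an app registered on your
// event server:
//
//   "datasource": {
//     "params": {
//       "appName": "MyApp1",
//       "evalK": 5
//     }
//   }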

class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, ActualResult] {

  // Loggers are not serializable; @transient lazy keeps this field out of
  // Spark task closures and re-creates it after deserialization.
  @transient lazy val logger = Logger[this.type]

  override
  def readTraining(sc: SparkContext): TrainingData = {

    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      // only keep entities with all of these required properties defined
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      // aggregateProperties() returns an RDD of (entityId, PropertyMap) pairs
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              properties.get[Double]("attr0"),
              properties.get[Double]("attr1"),
              properties.get[Double]("attr2")
            ))
          )
        } catch {
          case e: Exception =>
            logger.error(s"Failed to get properties $properties of" +
              s" $entityId. Exception: $e.")
            throw e
        }
      }.cache()

    new TrainingData(labeledPoints)
  }
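
  // readTraining backs `pio train`; readEval below backs `pio eval`,
  // producing one (training data, test data) split per fold.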

  override
  def readEval(sc: SparkContext)
    : Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    require(dsp.evalK.nonEmpty, "DataSourceParams.evalK must not be None")

    // The following code reads the data from the event store. It is
    // identical to readTraining above and is duplicated here purely for
    // illustration; the recommended approach is to factor this logic into a
    // helper and have both readTraining and readEval call it (see the
    // sketch after this method).
    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      // only keep entities with all of these required properties defined
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      // aggregateProperties() returns an RDD of (entityId, PropertyMap) pairs
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              properties.get[Double]("attr0"),
              properties.get[Double]("attr1"),
              properties.get[Double]("attr2")
            ))
          )
        } catch {
          case e: Exception =>
            logger.error(s"Failed to get properties $properties of" +
              s" $entityId. Exception: $e.")
            throw e
        }
      }.cache()
    // End of reading from the data store

    // K-fold splitting
    val evalK = dsp.evalK.get
    val indexedPoints: RDD[(LabeledPoint, Long)] = labeledPoints.zipWithIndex()
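
    // Each point's index modulo evalK assigns it to exactly one test fold:
    // e.g. with evalK = 3, points 0, 3, 6, ... form the test set of fold 0
    // and appear in the training sets of folds 1 and 2.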
    (0 until evalK).map { idx =>
      val trainingPoints = indexedPoints.filter(_._2 % evalK != idx).map(_._1)
      val testingPoints = indexedPoints.filter(_._2 % evalK == idx).map(_._1)

      (
        new TrainingData(trainingPoints),
        new EmptyEvaluationInfo(),
        testingPoints.map { p =>
          (Query(p.features(0), p.features(1), p.features(2)),
            ActualResult(p.label))
        }
      )
    }
  }
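
  // A minimal sketch of the refactoring suggested above. The helper name is
  // hypothetical (not part of the template); both readTraining and readEval
  // could delegate the duplicated event-store query to it. Error handling
  // is omitted here for brevity.
  private def readLabeledPoints(sc: SparkContext): RDD[LabeledPoint] = {
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      .map { case (_, properties) =>
        LabeledPoint(properties.get[Double]("plan"),
          Vectors.dense(Array(
            properties.get[Double]("attr0"),
            properties.get[Double]("attr1"),
            properties.get[Double]("attr2"))))
      }.cache()
  }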
}

class TrainingData(
  val labeledPoints: RDD[LabeledPoint]
) extends Serializable
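
// Query and ActualResult are defined elsewhere in the template
// (conventionally Engine.scala). Minimal compatible sketches, assuming the
// three Double attributes used above:
//
//   case class Query(attr0: Double, attr1: Double, attr2: Double)
//   case class ActualResult(label: Double) extends Serializable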