package org.example.classification

import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.store.PEventStore
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import grizzled.slf4j.Logger
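
// Note: evalK is only used by readEval (the evaluation workflow). It controls
// how many folds the data is split into, with each fold held out once as the
// test set; it can be left as None when only training and deploying.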
case class DataSourceParams(
  appName: String,
  evalK: Option[Int] // define the k-fold parameter.
) extends Params

class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, ActualResult] {

  @transient lazy val logger = Logger[this.type]

  override
  def readTraining(sc: SparkContext): TrainingData = {

    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      // only keep entities with these required properties defined
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      // aggregateProperties() returns an RDD of
      // (entity ID, aggregated properties) pairs
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              properties.get[Double]("attr0"),
              properties.get[Double]("attr1"),
              properties.get[Double]("attr2")
            ))
          )
        } catch {
          case e: Exception => {
            logger.error(s"Failed to get properties ${properties} of" +
              s" ${entityId}. Exception: ${e}.")
            throw e
          }
        }
      }.cache()

    new TrainingData(labeledPoints)
  }

  override
  def readEval(sc: SparkContext)
  : Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    require(dsp.evalK.nonEmpty, "DataSourceParams.evalK must not be None")

    // The following code reads the data from the data store. It is equivalent
    // to the readTraining method; the exact code is duplicated here for
    // illustration purposes. A recommended approach is to factor this logic
    // out into a helper function and have both readTraining and readEval call
    // it (a sketch of such a helper follows this method).
    val labeledPoints: RDD[LabeledPoint] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      // only keep entities with these required properties defined
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      // aggregateProperties() returns an RDD of
      // (entity ID, aggregated properties) pairs
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              properties.get[Double]("attr0"),
              properties.get[Double]("attr1"),
              properties.get[Double]("attr2")
            ))
          )
        } catch {
          case e: Exception => {
            logger.error(s"Failed to get properties ${properties} of" +
              s" ${entityId}. Exception: ${e}.")
            throw e
          }
        }
      }.cache()
    // End of reading from data store

    // K-fold splitting: each point is assigned to fold (index % evalK); on
    // iteration idx that fold is held out as the test set and the remaining
    // points form the training set.
    val evalK = dsp.evalK.get
    val indexedPoints: RDD[(LabeledPoint, Long)] = labeledPoints.zipWithIndex()

    (0 until evalK).map { idx =>
      val trainingPoints = indexedPoints.filter(_._2 % evalK != idx).map(_._1)
      val testingPoints = indexedPoints.filter(_._2 % evalK == idx).map(_._1)

      (
        new TrainingData(trainingPoints),
        new EmptyEvaluationInfo(),
        testingPoints.map {
          p => (Query(p.features(0), p.features(1), p.features(2)), ActualResult(p.label))
        }
      )
    }
  }
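
  // Sketch (not part of the original template) of the shared helper suggested
  // by the comment in readEval above; the method name is an assumption. Both
  // readTraining and readEval could build their LabeledPoint RDD by calling
  // this method instead of duplicating the read logic.
  private def readLabeledPoints(sc: SparkContext): RDD[LabeledPoint] = {
    PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user",
      // only keep entities with these required properties defined
      required = Some(List("plan", "attr0", "attr1", "attr2")))(sc)
      .map { case (entityId, properties) =>
        try {
          LabeledPoint(properties.get[Double]("plan"),
            Vectors.dense(Array(
              properties.get[Double]("attr0"),
              properties.get[Double]("attr1"),
              properties.get[Double]("attr2")
            ))
          )
        } catch {
          case e: Exception => {
            logger.error(s"Failed to get properties ${properties} of" +
              s" ${entityId}. Exception: ${e}.")
            throw e
          }
        }
      }.cache()
  }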
}

class TrainingData(
  val labeledPoints: RDD[LabeledPoint]
) extends Serializable
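
// Query, ActualResult, and the engine wiring are defined elsewhere in this
// template (typically Engine.scala in the same package). A minimal sketch of
// the shapes implied by readEval above, with field names as assumptions:
//
//   case class Query(attr0: Double, attr1: Double, attr2: Double) extends Serializable
//   case class ActualResult(label: Double) extends Serializable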