| package org.example.vanilla |
| |
| import org.apache.predictionio.controller.PDataSource |
| import org.apache.predictionio.controller.EmptyEvaluationInfo |
| import org.apache.predictionio.controller.EmptyActualResult |
| import org.apache.predictionio.controller.Params |
| import org.apache.predictionio.data.storage.Event |
| import org.apache.predictionio.data.store.PEventStore |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.rdd.RDD |
| |
| import grizzled.slf4j.Logger |
| |
/** Parameters for [[DataSource]].
  *
  * @param appName name of the PredictionIO app whose events are read
  *                from the event store
  */
case class DataSourceParams(appName: String) extends Params
| |
/** Reads training events for the engine from the PredictionIO event store.
  *
  * @param dsp parameters identifying which app's events to load
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger = Logger[this.type]

  /** Loads every "EVENT" event between ENTITY_TYPE and TARGET_ENTITY_TYPE
    * entities for the configured app and wraps the resulting RDD in a
    * [[TrainingData]]. The event/entity names are template placeholders,
    * meant to be replaced when customizing the engine.
    */
  override
  def readTraining(sc: SparkContext): TrainingData = {
    new TrainingData(
      PEventStore.find(
        appName = dsp.appName,
        entityType = Some("ENTITY_TYPE"),
        eventNames = Some(List("EVENT")),
        // Option[Option[String]]: outer Some means "filter on target entity
        // type", inner Some is the type itself.
        targetEntityType = Some(Some("TARGET_ENTITY_TYPE")))(sc))
  }
}
| |
/** Holds the raw events a model is trained on.
  *
  * @param events events fetched from the event store
  */
class TrainingData(
  val events: RDD[Event]
) extends Serializable {
  // Render as a total count plus a two-event sample; useful for debug logs.
  // Note: count() and take() each trigger a Spark job.
  override def toString = {
    val sample = events.take(2).toList
    s"events: [${events.count()}] ($sample...)"
  }
}