blob: d420dad23069e2ffd7982301c6162516f0baadd3 [file] [log] [blame]
package org.example.vanilla
import org.apache.predictionio.controller.PDataSource
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.EmptyActualResult
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.storage.Event
import org.apache.predictionio.data.store.PEventStore
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import grizzled.slf4j.Logger
/** Parameters for [[DataSource]].
  *
  * @param appName the PredictionIO app name whose event data is read at training time
  */
case class DataSourceParams(appName: String) extends Params
/** Reads training events from the PredictionIO event store.
  *
  * @param dsp data source parameters (the app name to read events from)
  */
class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger = Logger[this.type]

  /** Fetch all "EVENT" events between ENTITY_TYPE and TARGET_ENTITY_TYPE
    * entities and wrap them as [[TrainingData]].
    *
    * @param sc the Spark context used to load the event RDD
    * @return training data backed by the fetched event RDD
    */
  override def readTraining(sc: SparkContext): TrainingData = {
    val events: RDD[Event] =
      PEventStore.find(
        appName = dsp.appName,
        entityType = Some("ENTITY_TYPE"),
        eventNames = Some(List("EVENT")),
        // Some(Some(...)): outer Some enables the filter, inner Some is the type value
        targetEntityType = Some(Some("TARGET_ENTITY_TYPE")))(sc)
    new TrainingData(events)
  }
}
/** Immutable holder for the training event RDD.
  *
  * @param events all events fetched by [[DataSource.readTraining]]
  */
class TrainingData(
  val events: RDD[Event]
) extends Serializable {

  /** Debug summary: total event count plus the first two events.
    * NOTE: triggers two Spark actions (`count` and `take`) each call.
    */
  override def toString: String = {
    val total = events.count()
    val sample = events.take(2).toList
    s"events: [$total] ($sample...)"
  }
}