package org.template.textclassification

import grizzled.slf4j.Logger
import io.prediction.controller.EmptyEvaluationInfo
import io.prediction.controller.Params
import io.prediction.controller.PDataSource
import io.prediction.controller.SanityCheck
import io.prediction.data.store.PEventStore
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// 1. Initialize your Data Source parameters. This is
// where you specify your application name, MyTextApp,
// and the number of folds that are to be used for
// cross validation.
case class DataSourceParams(
  appName: String,
  evalK: Option[Int]
) extends Params
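
// These parameters are typically supplied through engine.json. A minimal
// sketch of the relevant block (the appName and fold count shown here are
// illustrative):
//
//   "datasource": {
//     "params": {
//       "appName": "MyTextApp",
//       "evalK": 3
//     }
//   }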

// 2. Define your DataSource component. Remember, you must
// implement a readTraining method, and, optionally, a
// readEval method.
class DataSource(
  val dsp: DataSourceParams
) extends PDataSource[TrainingData, EmptyEvaluationInfo, Query, ActualResult] {

  @transient lazy val logger = Logger[this.type]

  // Helper function used to read event data given
  // a SparkContext.
  private def readEventData(sc: SparkContext): RDD[Observation] = {
    // Get RDD of Events.
    PEventStore.find(
      appName = dsp.appName,
      entityType = Some("content"), // specify data entity type
      eventNames = Some(List("e-mail")) // specify data event name
    )(sc).map(e => {
      // Convert each collected event to an Observation object.
      val label: String = e.properties.get[String]("label")
      Observation(
        if (label == "spam") 1.0 else 0.0,
        e.properties.get[String]("text"),
        label
      )
    }).cache
  }
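
  // For reference, readEventData assumes events were imported with the
  // entity type and event name used above. A sketch of one such event
  // (the entityId and property values are illustrative):
  //
  //   {
  //     "event": "e-mail",
  //     "entityType": "content",
  //     "entityId": "42",
  //     "properties": { "label": "spam", "text": "Win a free prize now!" }
  //   }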

  // Helper function used to read stop words from the
  // event server.
  private def readStopWords(sc: SparkContext): Set[String] = {
    PEventStore.find(
      appName = dsp.appName,
      entityType = Some("resource"),
      eventNames = Some(List("stopwords"))
    )(sc)
      // Convert the collected RDD of strings to a string set.
      .map(e => e.properties.get[String]("word"))
      .collect
      .toSet
  }
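
  // The stop words are likewise assumed to come from one event per word,
  // shaped as in this sketch (entityId and word are illustrative):
  //
  //   {
  //     "event": "stopwords",
  //     "entityType": "resource",
  //     "entityId": "0",
  //     "properties": { "word": "the" }
  //   }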

  // Read in data and stop words from the event server
  // and store them in a TrainingData instance.
  override
  def readTraining(sc: SparkContext): TrainingData = {
    new TrainingData(readEventData(sc), readStopWords(sc))
  }

  // Used for evaluation: reads in event data and creates
  // cross-validation folds.
  override
  def readEval(sc: SparkContext):
    Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
    // Zip your RDD of events read from the server with indices
    // for the purpose of creating folds.
    val data = readEventData(sc).zipWithIndex()

    // Create cross-validation folds by partitioning indices
    // based on their value modulo the number of folds.
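    // For example, with evalK = 3, the observation at index 7 falls into
    // the test set of fold 1 (since 7 % 3 == 1) and into the training
    // sets of folds 0 and 2.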
    (0 until dsp.evalK.get).map { k =>
      // Prepare training data for the fold.
      val train = new TrainingData(
        data.filter(_._2 % dsp.evalK.get != k).map(_._1),
        readStopWords(sc)
      )

      // Prepare test data for the fold.
      val test = data.filter(_._2 % dsp.evalK.get == k)
        .map(_._1)
        .map(e => (new Query(e.text), new ActualResult(e.category)))

      (train, new EmptyEvaluationInfo, test)
    }
  }
}

// 3. Observation class serving as a wrapper for both our
// data's class label and document string.
case class Observation(
  label: Double,
  text: String,
  category: String
) extends Serializable
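
// For instance, an e-mail labeled "spam" would be wrapped as
// Observation(1.0, "Win a free prize now!", "spam"), where the text
// shown is purely illustrative.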

// 4. TrainingData class serving as a wrapper for all of the data
// read in from the event server.
class TrainingData(
  val data: RDD[Observation],
  val stopWords: Set[String]
) extends Serializable with SanityCheck {

  // Sanity check to make sure your data is being fed in correctly.
  // Prints the labels of up to five sampled observations, or a warning
  // if no data was read.
  def sanityCheck(): Unit = {
    val obs: Array[Double] = data.takeSample(false, 5).map(_.label)

    if (obs.isEmpty) {
      println()
      println("Data set is empty, make sure event fields match imported data.")
      println()
    } else {
      println()
      obs.zipWithIndex.foreach { case (label, k) =>
        println("Observation " + (k + 1) + " label: " + label)
      }
      println()
    }
  }
}