package org.template.textclassification
import io.prediction.controller.PPreparator
import io.prediction.controller.Params
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{IDF, IDFModel, HashingTF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import scala.collection.immutable.HashMap
import scala.collection.JavaConversions._
import scala.math._
// 1. Initialize Preparator parameters. Recall that for our data
// representation we only need to supply the n-gram window size.
case class PreparatorParams(
  nGram: Int
) extends Params
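// As a rough illustration (assuming the standard PredictionIO engine.json layout used by
// this template), a 2-gram window could be configured with:
//
//   "preparator": {
//     "params": {
//       "nGram": 2
//     }
//   }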
// 2. Initialize your Preparator class.
class Preparator(pp: PreparatorParams) extends PPreparator[TrainingData, PreparedData] {

  // Prepare your training data.
  def prepare(sc: SparkContext, td: TrainingData): PreparedData = {
    new PreparedData(td, pp.nGram)
  }
}
//------PreparedData------------------------
class PreparedData(
  val td: TrainingData,
  val nGram: Int
) extends Serializable {
  // 1. Hashing function: text -> term frequency vector.
  private val hasher = new HashingTF()

  private def hashTF(text: String): Vector = {
    val newList: Array[String] = text.split(" ")
      .sliding(nGram)
      .map(_.mkString)
      .toArray

    hasher.transform(newList)
  }
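  // Illustrative walk-through (hypothetical input, with nGram = 2):
  //   hashTF("the quick brown fox")
  //   // split(" ")                  => Array("the", "quick", "brown", "fox")
  //   // sliding(2).map(_.mkString)  => "thequick", "quickbrown", "brownfox"
  //   // hasher.transform(...)       => sparse vector of hashed bigram counts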
  // 2. Term frequency vector -> tf-idf vector.
  val idf: IDFModel = new IDF().fit(td.data.map(e => hashTF(e.text)))
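  // For orientation: Spark MLlib's IDF weights each hashed term t as
  // idf(t) = log((m + 1) / (df(t) + 1)), where m is the number of training documents and
  // df(t) is the number of documents containing t (see the MLlib docs for the
  // authoritative definition).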
  // 3. Document transformer: text -> tf-idf vector.
  def transform(text: String): Vector = {
    // Hash the text into a term frequency vector, then re-weight it with the fitted IDF model.
    idf.transform(hashTF(text))
  }
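  // For example, at query time the engine's Algorithm component can call transform on the
  // incoming text to obtain a tf-idf vector in the same hashed feature space as the
  // training data (sketch with a hypothetical query string):
  //   val features: Vector = transform("an unseen document to classify")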
  // 4. Data transformer: RDD of documents -> RDD of LabeledPoints.
  val transformedData: RDD[LabeledPoint] = {
    td.data.map(e => LabeledPoint(e.label, transform(e.text)))
  }
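  // Downstream, this RDD can be handed directly to an MLlib classifier. A minimal sketch
  // (the smoothing value is an illustrative choice, not prescribed by this class):
  //   import org.apache.spark.mllib.classification.NaiveBayes
  //   val model = NaiveBayes.train(transformedData, lambda = 1.0)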
  // 5. Finally, extract the category map, associating each label with its category.
  val categoryMap = td.data.map(e => (e.label, e.category)).collectAsMap
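  // categoryMap can be used to translate a predicted numeric label back into its
  // human-readable category, e.g. categoryMap(predictedLabel), where predictedLabel is a
  // hypothetical Double produced by the trained model.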
}