Merge branch 'patch-1' of https://github.com/kanwarpartapsingh/incubator-predictionio-template-text-classifier
diff --git a/src/main/scala/Preparator.scala b/src/main/scala/Preparator.scala
index c8b35d0..8a5cb5c 100644
--- a/src/main/scala/Preparator.scala
+++ b/src/main/scala/Preparator.scala
@@ -11,6 +11,14 @@
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import org.apache.lucene.analysis.standard.StandardAnalyzer
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
+import org.apache.lucene.util.Version
+
+import java.io.StringReader
+
+import scala.collection.mutable
+
/** Define Preparator parameters. Recall that for our data
* representation we are only required to input the n-gram window
* components.
@@ -26,7 +34,7 @@
def prepare(sc: SparkContext, td: TrainingData): PreparedData = {
- val tfHasher = new TFHasher(pp.numFeatures, pp.nGram)
+ val tfHasher = new TFHasher(pp.numFeatures, pp.nGram, td.stopWords)
// Convert trainingdata's observation text into TF vector
// and then fit a IDF model
@@ -57,14 +65,35 @@
class TFHasher(
val numFeatures: Int,
- val nGram: Int
+ val nGram: Int,
+ val stopWords: Set[String]
) extends Serializable {
private val hasher = new HashingTF(numFeatures = numFeatures)
+/** Use Lucene StandardAnalyzer to tokenize text. */
+ def tokenize(content: String): Seq[String] = {
+ val tReader = new StringReader(content)
+ val analyzer = new StandardAnalyzer(Version.LATEST)
+ val tStream = analyzer.tokenStream("contents", tReader)
+ val term = tStream.addAttribute(classOf[CharTermAttribute])
+ tStream.reset()
+
+ val result = mutable.ArrayBuffer.empty[String]
+ while (tStream.incrementToken()) {
+ val termValue = term.toString
+
+ result += termValue
+
+ }
+ result
+}
+
+
/** Hashing function: Text -> term frequency vector. */
def hashTF(text: String): Vector = {
- val newList : Array[String] = text.split(" ")
+ val newList : Array[String] = tokenize(text)
+ .filterNot(stopWords.contains(_))
.sliding(nGram)
.map(_.mkString)
.toArray
@@ -77,7 +106,7 @@
val hasher: TFHasher,
val idf: IDFModel
) extends Serializable {
-
+
/** trasform text to tf-idf vector. */
def transform(text: String): Vector = {
// Map(n-gram -> document tf)