Add numFeatures (HashingTF size) param to Preparator; precompute NB pi/theta pairs
diff --git a/engine.json b/engine.json
index da6bd9c..05c8440 100644
--- a/engine.json
+++ b/engine.json
@@ -9,7 +9,8 @@
},
"preparator": {
"params": {
- "nGram": 2
+ "nGram": 2,
+ "numFeatures": 15000
}
},
"algorithms": [
diff --git a/src/main/scala/org/template/textclassification/NBAlgorithm.scala b/src/main/scala/org/template/textclassification/NBAlgorithm.scala
index 512c2e9..1c1c31d 100644
--- a/src/main/scala/org/template/textclassification/NBAlgorithm.scala
+++ b/src/main/scala/org/template/textclassification/NBAlgorithm.scala
@@ -65,6 +65,8 @@
+ private val scoreArray = nb.pi.zip(nb.theta)
+
// 3. Given a document string, return a vector of corresponding
// class membership probabilities.
@@ -75,12 +77,7 @@
// Vectorize query,
val x: Vector = pd.transform(doc)
- normalize(
- nb.pi
- .zip(nb.theta)
- .map(
- e => exp(innerProduct(e._2, x.toArray) + e._1))
- )
+ normalize(scoreArray.map(e => exp(innerProduct(e._2, x.toArray) + e._1)))
}
// 4. Implement predict method for our model using
diff --git a/src/main/scala/org/template/textclassification/Preparator.scala b/src/main/scala/org/template/textclassification/Preparator.scala
index 871b7e4..badc49c 100644
--- a/src/main/scala/org/template/textclassification/Preparator.scala
+++ b/src/main/scala/org/template/textclassification/Preparator.scala
@@ -20,7 +20,8 @@
// components.
case class PreparatorParams(
- nGram : Int
+ nGram : Int,
+ numFeatures: Int = 15000
) extends Params
@@ -31,22 +32,23 @@
// Prepare your training data.
def prepare(sc : SparkContext, td: TrainingData): PreparedData = {
- new PreparedData(td, pp.nGram)
+ new PreparedData(td, pp.nGram, pp.numFeatures)
}
}
//------PreparedData------------------------
class PreparedData (
-val td : TrainingData,
-val nGram : Int
+ val td : TrainingData,
+ val nGram : Int,
+ val numFeatures: Int
) extends Serializable {
// 1. Hashing function: Text -> term frequency vector.
- private val hasher = new HashingTF()
+ private val hasher = new HashingTF(numFeatures = numFeatures)
private def hashTF (text : String) : Vector = {
val newList : Array[String] = text.split(" ")