blob: c40b187a0aa7570d3d312c7b4edd7d2dd8d416e5 [file] [log] [blame]
package org.template.textclassification
import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import grizzled.slf4j.Logger
import java.nio.file.{Files, Paths}
import vw.VW
/**
 * Hyper-parameters for [[VowpalLogisticRegressionWithSGDAlgorithm]].
 *
 * @param maxIter      intended number of training passes (not read by `train`, which
 *                     feeds each example to VW once)
 * @param regParam     L2 regularization strength, passed to VW as `--l2`
 * @param stepSize     learning rate, passed to VW as `-l`
 * @param bitPrecision number of feature-hash bits, passed to VW as `-b`
 * @param modelName    path of the model file VW writes during training (`-f`)
 * @param namespace    VW namespace that holds each example's features
 * @param ngram        n-gram length, passed to VW as `--ngram`
 */
case class AlgorithmParams(
    maxIter: Int,
    regParam: Double,
    stepSize: Double,
    bitPrecision: Int,
    modelName: String,
    namespace: String,
    ngram: Int
) extends Params
// Extends P2LAlgorithm because the trained VW model is a local byte array, not an RDD.
/**
 * Binary text classifier backed by a Vowpal Wabbit (VW) logistic-regression model.
 *
 * Training collects the prepared data to the driver and feeds it to a single local
 * VW instance. The trained model is returned as the raw bytes of VW's model file.
 *
 * @param ap algorithm hyper-parameters
 */
class VowpalLogisticRegressionWithSGDAlgorithm(val ap: AlgorithmParams)
  extends P2LAlgorithm[PreparedData, Array[Byte], Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  /**
   * Trains a VW logistic-regression model and returns the bytes of its model file.
   *
   * Points with label 0.0 become VW's negative class (`-1.0`); every other label
   * becomes the positive class (`1.0`). Each example's namespace holds the sanitized
   * raw text followed by the prepared feature vector as `index:value` pairs.
   *
   * Side effects: VW is told to write the model to `ap.modelName` (`-f`) and a
   * human-readable dump to `readable.model` (`--invert_hash`).
   *
   * NOTE(review): `ap.maxIter` is not used; each example is learned exactly once.
   *
   * @param sc   Spark context (not used; the data is collected to the driver)
   * @param data prepared training data
   * @return the serialized VW model
   * @throws IllegalArgumentException if `data.td.data` is empty
   */
  def train(sc: SparkContext, data: PreparedData): Array[Byte] = {
    require(!data.td.data.take(1).isEmpty,
      "RDD[labeledPoints] in PreparedData cannot be empty." +
        " Please check if DataSource generates TrainingData" +
        " and Preparator generates PreparedData correctly.")

    val vwArgs = Seq(
      "--loss_function logistic",
      "--invert_hash readable.model",
      s"-b ${ap.bitPrecision}",
      s"-f ${ap.modelName}",
      s"--l2 ${ap.regParam}",
      s"-l ${ap.stepSize}",
      s"--ngram ${ap.ngram}"
    ).mkString(" ")

    // One (isPositive, VW-formatted example line) pair per training point.
    val examples = data.transformedData.collect.map { p =>
      val isPositive = p.point.label.toDouble != 0.0
      val label = if (isPositive) "1.0" else "-1.0"
      val text = rawTextToVWFormattedString(p.text)
      val features = vectorToVWFormattedString(p.point.features)
      (isPositive, s"$label |${ap.namespace} $text $features")
    }

    val vw = new VW(vwArgs)
    try {
      examples.foreach { case (_, example) => vw.learn(example) }

      if (examples.nonEmpty) {
        // This instance is created without `--link logistic` (unlike the one in
        // predict()). Its predict() is therefore expected to return the raw score
        // rather than a probability, so the decision boundary is 0, not 0.5.
        val correct = examples.count { case (isPositive, example) =>
          (vw.predict(example).toDouble > 0.0) == isPositive
        }
        logger.info(s"Accuracy on Training set: ${correct.toDouble / examples.length}")
      }
    } finally {
      vw.close()
    }

    // Read the model file only after close(), presumably when VW flushes `-f` output.
    Files.readAllBytes(Paths.get(ap.modelName))
  }

  /**
   * Classifies `query.text` with the serialized VW model.
   *
   * The model bytes go to a per-call temporary file, which is deleted afterwards.
   * This stops concurrent queries from reading a shared, partially rewritten file.
   *
   * NOTE(review): only the raw text is fed to the model here, whereas training also
   * appended the prepared feature vector to every example. Confirm this is intended.
   *
   * @param byteArray serialized VW model as returned by `train`
   * @param query     query whose `text` is classified
   * @return category "1" or "0" together with the probability of that category
   */
  def predict(byteArray: Array[Byte], query: Query): PredictedResult = {
    val modelFile = Files.createTempFile("vw-", ".model")
    try {
      Files.write(modelFile, byteArray)
      val vw = new VW(s"--link logistic -i ${modelFile.toAbsolutePath}")
      try {
        val example = s"|${ap.namespace} ${rawTextToVWFormattedString(query.text)}"
        // With `--link logistic` the prediction is treated as the probability of "1".
        val pPositive = vw.predict(example).toDouble
        if (pPositive > 0.5) new PredictedResult("1", pPositive)
        else new PredictedResult("0", 1.0 - pPositive)
      } finally {
        vw.close()
      }
    } finally {
      Files.deleteIfExists(modelFile)
    }
  }

  /**
   * Replaces `|` and `:` with spaces. VW input cannot contain these characters:
   * they mark namespaces and separate feature names from values.
   */
  def rawTextToVWFormattedString(str: String): String =
    str.replaceAll("[|:]", " ")

  /**
   * Formats every entry of `vec`, zeros included, as `index:value`.
   * The entries are separated by single spaces.
   */
  def vectorToVWFormattedString(vec: Vector): String =
    vec.toArray.zipWithIndex.map { case (value, index) => s"$index:$value" }.mkString(" ")
}