blob: 1666ae8e2d9e175b6978e3914892534b41a5ee05 [file] [log] [blame]
package org.template.classification
import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.Vectors
/** Parameters for [[NaiveBayesAlgorithm]].
  *
  * @param lambda value forwarded unchanged as the `lambda` argument of
  *               MLlib's `NaiveBayes.train` (MLlib documents it as the
  *               smoothing parameter)
  */
case class NaiveBayesAlgorithmParams(lambda: Double) extends Params
// Extends P2LAlgorithm because MLlib's NaiveBayesModel does not contain an RDD.
/** Naive Bayes classifier that delegates training and prediction to Spark MLlib.
  *
  * It is a `P2LAlgorithm` because MLlib's `NaiveBayesModel` does not contain
  * an RDD.
  *
  * @param ap algorithm parameters; `ap.lambda` is handed to `NaiveBayes.train`
  */
class NaiveBayesAlgorithm(val ap: NaiveBayesAlgorithmParams)
  extends P2LAlgorithm[PreparedData, NaiveBayesModel, Query, PredictedResult] {

  /** Trains an MLlib Naive Bayes model on the prepared labeled points.
    *
    * @param sc   Spark context supplied by the framework (not used directly here)
    * @param data prepared data whose `labeledPoints` are used as the training set
    * @return the model produced by `NaiveBayes.train`
    * @throws IllegalArgumentException if `data.labeledPoints` has no elements
    */
  def train(sc: SparkContext, data: PreparedData): NaiveBayesModel = {
    // Fail fast with a clear message instead of letting MLlib fail on empty
    // input. take(1) only needs to find a single element, so the check is cheap.
    require(
      data.labeledPoints.take(1).nonEmpty,
      "Training data (PreparedData.labeledPoints) must not be empty.")

    NaiveBayes.train(data.labeledPoints, ap.lambda)
  }

  /** Predicts the label for a single query.
    *
    * @param model model returned by [[train]]
    * @param query query whose `features` are wrapped in a dense MLlib vector
    *              (presumably an `Array[Double]` -- confirm against `Query`)
    * @return a `PredictedResult` holding the label returned by the model
    */
  def predict(model: NaiveBayesModel, query: Query): PredictedResult = {
    val featureVector = Vectors.dense(query.features)
    val label = model.predict(featureVector)
    new PredictedResult(label)
  }
}