| package org.template.classification |
| |
| import io.prediction.controller.P2LAlgorithm |
| import io.prediction.controller.Params |
| |
| import org.apache.spark.SparkContext |
| import org.apache.spark.mllib.classification.NaiveBayes |
| import org.apache.spark.mllib.classification.NaiveBayesModel |
| import org.apache.spark.mllib.linalg.Vectors |
| |
| /** |
|  * Tuning parameters for [[NaiveBayesAlgorithm]]. |
|  * |
|  * @param lambda value forwarded as the second argument of MLlib's `NaiveBayes.train` |
|  *               (documented by MLlib as the smoothing parameter) |
|  */ |
| case class NaiveBayesAlgorithmParams(lambda: Double) extends Params |
| |
| // Extends P2LAlgorithm because MLlib's NaiveBayesModel does not contain an RDD. |
| /** |
|  * Naive Bayes classifier backed by Spark MLlib. |
|  * |
|  * @param ap algorithm parameters; `ap.lambda` is forwarded to `NaiveBayes.train` |
|  */ |
| class NaiveBayesAlgorithm(val ap: NaiveBayesAlgorithmParams) |
|   extends P2LAlgorithm[PreparedData, NaiveBayesModel, Query, PredictedResult] { |
| |
|   /** |
|    * Fits a model on `data.labeledPoints` using `ap.lambda`. |
|    * |
|    * @param sc   Spark context supplied by the caller; not referenced in this body |
|    * @param data prepared training data whose `labeledPoints` are handed to MLlib |
|    * @return the model produced by `NaiveBayes.train` |
|    */ |
|   def train(sc: SparkContext, data: PreparedData): NaiveBayesModel = |
|     NaiveBayes.train(data.labeledPoints, ap.lambda) |
| |
|   /** |
|    * Scores a single query against a trained model. |
|    * |
|    * @param model model previously returned by `train` |
|    * @param query query whose `features` are wrapped in a dense vector |
|    * @return a `PredictedResult` built from the label returned by `model.predict` |
|    */ |
|   def predict(model: NaiveBayesModel, query: Query): PredictedResult = { |
|     val featureVector = Vectors.dense(query.features) |
|     new PredictedResult(model.predict(featureVector)) |
|   } |
| |
| } |