| package org.template.classification |
| |
| import io.prediction.controller.P2LAlgorithm |
| import io.prediction.controller.Params |
| |
| import org.apache.spark.mllib.classification.NaiveBayes |
| import org.apache.spark.mllib.classification.NaiveBayesModel |
| import org.apache.spark.mllib.linalg.Vectors |
| |
| import grizzled.slf4j.Logger |
| |
/**
 * Parameters for the Naive Bayes algorithm.
 *
 * @param lambda value forwarded as the `lambda` argument of MLlib's
 *               `NaiveBayes.train` when the model is trained
 */
case class AlgorithmParams(lambda: Double) extends Params
| |
/**
 * Naive Bayes classifier that delegates training and prediction to Spark MLlib.
 *
 * Extends P2LAlgorithm because MLlib's NaiveBayesModel doesn't contain an RDD.
 *
 * @param ap algorithm parameters; `ap.lambda` is passed to `NaiveBayes.train`
 */
class NaiveBayesAlgorithm(val ap: AlgorithmParams)
  extends P2LAlgorithm[PreparedData, NaiveBayesModel, Query, PredictedResult] {

  // Marked @transient and lazy so the logger is not part of the serialized
  // instance and is created on first use instead.
  @transient lazy val logger: Logger = Logger[this.type]

  /**
   * Trains a Naive Bayes model on the labeled points held by `data`.
   *
   * @param data prepared training data; its `labeledPoints` must be non-empty
   * @return the model produced by `NaiveBayes.train`
   * @throws IllegalArgumentException if `data.labeledPoints` has no elements
   */
  def train(data: PreparedData): NaiveBayesModel = {
    // MLlib NaiveBayes cannot handle empty training data, so fail early with
    // an actionable message. Asking for a single element is enough to tell
    // whether anything is there.
    require(data.labeledPoints.take(1).nonEmpty,
      "RDD[labeledPoints] in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")

    NaiveBayes.train(data.labeledPoints, ap.lambda)
  }

  /**
   * Predicts the label for a single query.
   *
   * @param model a model previously returned by [[train]]
   * @param query the query whose `features` are wrapped in a dense vector
   * @return a `PredictedResult` carrying the label returned by the model
   */
  def predict(model: NaiveBayesModel, query: Query): PredictedResult = {
    val label = model.predict(Vectors.dense(query.features))
    new PredictedResult(label)
  }

}