// blob: 6625551268938d0717073578ba70ce76a22a28a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.examples.classification
import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext
import grizzled.slf4j.Logger
/** Tunable parameters for [[NaiveBayesAlgorithm]].
  *
  * @param lambda value forwarded unchanged as the `lambda` argument of
  *               MLlib's `NaiveBayes.train` when the model is trained
  */
case class AlgorithmParams(lambda: Double) extends Params
/** Naive Bayes classifier that delegates training and prediction to Spark MLlib.
  *
  * Extends P2LAlgorithm because MLlib's NaiveBayesModel does not contain an RDD.
  *
  * @param ap algorithm parameters; `ap.lambda` is passed to `NaiveBayes.train`
  */
class NaiveBayesAlgorithm(val ap: AlgorithmParams)
  extends P2LAlgorithm[PreparedData, NaiveBayesModel, Query, PredictedResult] {

  // NOTE(review): transient + lazy presumably keeps the logger out of any
  // serialized form of this algorithm and recreates it on first use — confirm.
  @transient lazy val logger: Logger = Logger[this.type]

  /** Trains a Naive Bayes model on the prepared labeled points.
    *
    * @param sc   Spark context; not used by this implementation
    * @param data prepared training data whose `labeledPoints` feed the trainer
    * @return the model produced by `NaiveBayes.train` with `ap.lambda`
    * @throws IllegalArgumentException if `data.labeledPoints` yields no element
    */
  def train(sc: SparkContext, data: PreparedData): NaiveBayesModel = {
    // MLlib's NaiveBayes cannot handle empty training data, so fail fast with
    // a descriptive message; take(1) asks for at most one element.
    val hasTrainingData = data.labeledPoints.take(1).nonEmpty
    require(
      hasTrainingData,
      s"RDD[labeledPoints] in PreparedData cannot be empty." +
        " Please check if DataSource generates TrainingData" +
        " and Preparator generates PreparedData correctly.")

    NaiveBayes.train(data.labeledPoints, ap.lambda)
  }

  /** Predicts the label for a single query.
    *
    * @param model trained Naive Bayes model
    * @param query query supplying `featureA` through `featureD`
    * @return the label predicted by the model, wrapped in a `PredictedResult`
    */
  def predict(model: NaiveBayesModel, query: Query): PredictedResult = {
    // NOTE(review): the A, B, C, D ordering presumably has to match the order
    // used when the training vectors were built — confirm against the
    // DataSource/Preparator.
    val featureValues: Array[Double] =
      Array(query.featureA, query.featureB, query.featureC, query.featureD)
    val features = Vectors.dense(featureValues)
    val label = model.predict(features)
    PredictedResult(label)
  }
}