| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.predictionio.examples.classification |
| |
| import org.apache.predictionio.controller.P2LAlgorithm |
| import org.apache.predictionio.controller.Params |
| |
| import org.apache.spark.mllib.tree.RandomForest // CHANGED |
| import org.apache.spark.mllib.tree.model.RandomForestModel // CHANGED |
| import org.apache.spark.mllib.linalg.Vectors |
| import org.apache.spark.SparkContext |
| |
| // CHANGED |
/** Tuning parameters for [[RandomForestAlgorithm]].
  *
  * Every field is handed unchanged to MLlib's `RandomForest.trainClassifier`
  * when the algorithm is trained; accepted values are whatever that MLlib
  * call accepts (consult the MLlib documentation for the valid
  * `featureSubsetStrategy` and `impurity` strings).
  *
  * @param numClasses            number of classes passed to the classifier trainer
  * @param numTrees              number of trees passed to the classifier trainer
  * @param featureSubsetStrategy feature subset strategy name passed to the trainer
  * @param impurity              impurity criterion name passed to the trainer
  * @param maxDepth              maximum tree depth passed to the trainer
  * @param maxBins               maximum bin count passed to the trainer
  */
case class RandomForestAlgorithmParams(
    numClasses: Int,
    numTrees: Int,
    featureSubsetStrategy: String,
    impurity: String,
    maxDepth: Int,
    maxBins: Int)
  extends Params
| |
/** Classification algorithm backed by MLlib's random forest.
  *
  * This is a `P2LAlgorithm` because MLlib's `RandomForestModel` does not
  * contain an RDD.
  *
  * @param ap parameters forwarded to `RandomForest.trainClassifier`
  */
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams) // CHANGED
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] { // CHANGED

  /** Trains a random forest classifier on the prepared labeled points.
    *
    * @param sc   Spark context supplied by the framework (not used in this body)
    * @param data prepared data whose `labeledPoints` are the training input
    * @return the model produced by `RandomForest.trainClassifier`
    */
  // CHANGED
  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // CHANGED
    // An empty categorical-feature map indicates all features are continuous.
    val allFeaturesContinuous = Map.empty[Int, Int]

    RandomForest.trainClassifier(
      data.labeledPoints,
      ap.numClasses,
      allFeaturesContinuous,
      ap.numTrees,
      ap.featureSubsetStrategy,
      ap.impurity,
      ap.maxDepth,
      ap.maxBins
    )
  }

  /** Predicts the label for a single query.
    *
    * The feature vector is built from `attr0`, `attr1` and `attr2` of the
    * query, in that order.
    *
    * @param model trained random forest model
    * @param query query carrying the three attributes to classify
    * @return a [[PredictedResult]] wrapping the label returned by the model
    */
  def predict(
      model: RandomForestModel, // CHANGED
      query: Query): PredictedResult = {
    val features = Vectors.dense(
      Array(query.attr0, query.attr1, query.attr2)
    )
    PredictedResult(model.predict(features))
  }

}