blob: 76dd7ca68dbc3d9d78a8a9011ec91b1617ac40c1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.examples.classification
import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params
import org.apache.spark.mllib.tree.RandomForest // CHANGED
import org.apache.spark.mllib.tree.model.RandomForestModel // CHANGED
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext
// CHANGED
/**
 * Tunable settings for [[RandomForestAlgorithm]].
 *
 * Every field is forwarded unchanged to MLlib's
 * `RandomForest.trainClassifier` when the model is trained; consult the
 * MLlib documentation for the values each setting accepts.
 *
 * @param numClasses            number of classes, passed as `numClasses`
 * @param numTrees              number of trees, passed as `numTrees`
 * @param featureSubsetStrategy passed as `featureSubsetStrategy`
 * @param impurity              passed as `impurity`
 * @param maxDepth              passed as `maxDepth`
 * @param maxBins               passed as `maxBins`
 */
case class RandomForestAlgorithmParams(
  numClasses: Int,
  numTrees: Int,
  featureSubsetStrategy: String,
  impurity: String,
  maxDepth: Int,
  maxBins: Int
) extends Params
// Extends P2LAlgorithm because MLlib's RandomForestModel does not
// contain an RDD.
/**
 * Random-forest classifier built on Spark MLlib.
 *
 * Training delegates to `RandomForest.trainClassifier`; prediction assembles
 * a dense feature vector from the three query attributes and asks the
 * trained model for a label.
 *
 * @param ap settings forwarded to MLlib at training time
 */
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams)
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] {

  /**
   * Trains a random-forest classification model.
   *
   * @param sc   Spark context (not referenced directly in this method)
   * @param data prepared data whose `labeledPoints` serve as training input
   * @return the model returned by `RandomForest.trainClassifier`
   */
  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // An empty map means no feature is declared categorical, i.e. all
    // features are treated as continuous.
    val noCategoricalFeatures = Map.empty[Int, Int]

    RandomForest.trainClassifier(
      data.labeledPoints,
      ap.numClasses,
      noCategoricalFeatures,
      ap.numTrees,
      ap.featureSubsetStrategy,
      ap.impurity,
      ap.maxDepth,
      ap.maxBins)
  }

  /**
   * Predicts a label for a single query.
   *
   * @param model trained random-forest model
   * @param query query supplying `attr0`, `attr1` and `attr2`
   * @return the model's prediction wrapped in a `PredictedResult`
   */
  def predict(model: RandomForestModel, query: Query): PredictedResult = {
    // Feature order must match the order used when the model was trained.
    val features = Array(query.attr0, query.attr1, query.attr2)
    PredictedResult(model.predict(Vectors.dense(features)))
  }
}