blob: 53f92294467baebbd39f055f5b8dc86ffde8842e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.examples.classification
import org.apache.predictionio.controller.AverageMetric
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.EngineParams
import org.apache.predictionio.controller.EngineParamsGenerator
import org.apache.predictionio.controller.Evaluation
/** Classification accuracy metric.
  *
  * Scores each evaluated query 1.0 when the predicted label equals the actual
  * label and 0.0 otherwise; as an [[AverageMetric]], the reported result is the
  * mean of these per-query scores.
  */
case class Accuracy()
    extends AverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {

  def calculate(query: Query, predicted: PredictedResult, actual: ActualResult): Double = {
    val isCorrect = predicted.label == actual.label
    if (isCorrect) 1.0 else 0.0
  }
}
/** Evaluation that scores [[ClassificationEngine]] with the [[Accuracy]] metric. */
object AccuracyEvaluation extends Evaluation {
  // The engine under evaluation; it is paired with its scoring metric below.
  private[this] val engine = ClassificationEngine()
  engineMetric = (engine, Accuracy())
}
/** Generates the list of [[EngineParams]] compared during evaluation. */
object EngineParamsList extends EngineParamsGenerator {
  // First, define the base engine params shared by every candidate. Its
  // DataSourceParams specify the app name from which the data is read, and an
  // evalK parameter that configures cross-validation (here 5 — presumably the
  // number of folds used by the DataSource; confirm against DataSource.readEval).
  // NOTE: "INVALID_APP_NAME" is a placeholder; replace it with the name of the
  // app that holds your events before running an evaluation.
  private[this] val baseEP = EngineParams(
    dataSourceParams = DataSourceParams(appName = "INVALID_APP_NAME", evalK = Some(5)))

  // Candidate values passed to AlgorithmParams for the "naive" algorithm.
  // One EngineParams is generated per value, in this order; extend this list
  // to evaluate more settings.
  private[this] val naiveParamCandidates: Seq[Double] = Seq(10.0, 100.0, 1000.0)

  // Second, derive the engine params list from the candidates: each entry is
  // the base params with a different algorithm params value (3 entries here).
  engineParamsList = naiveParamCandidates.map { candidate =>
    baseEP.copy(algorithmParamsList = Seq(("naive", AlgorithmParams(candidate))))
  }
}