blob: aebb5a41826e662debedc850a1b11e4fac2ac18e [file] [log] [blame]
package org.example.classification
import org.apache.predictionio.controller.OptionAverageMetric
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.Evaluation
/**
 * Precision for a single class label: TP / (TP + FP).
 *
 * Each query is scored only when the model predicted `label`:
 *  - `Some(1.0)` when the actual label matches (true positive),
 *  - `Some(0.0)` when it does not (false positive).
 * Queries predicted as any other label return `None`, marking them as
 * unrelated to this label's precision. `OptionAverageMetric` is expected to
 * leave those out of the average (confirm against the PredictionIO docs).
 *
 * @param label the class label whose precision is measured
 */
case class Precision(label: Double)
  extends OptionAverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {

  /** Column header reported for this metric. */
  override def header: String = s"Precision(label = $label)"

  /**
   * Scores one (query, prediction, actual) triple.
   *
   * @param query     the evaluated query (not used by this metric)
   * @param predicted the engine's prediction for `query`
   * @param actual    the ground-truth result for `query`
   * @return `None` if `predicted.label` is not `label`; otherwise `Some(1.0)`
   *         for a correct prediction and `Some(0.0)` for an incorrect one
   */
  override def calculate(
      query: Query,
      predicted: PredictedResult,
      actual: ActualResult): Option[Double] = {
    val predictedLabel = predicted.label
    if (predictedLabel != label) {
      // Not a positive prediction for this label: outside the precision denominator.
      None
    } else {
      // Positive prediction: a hit is a true positive, a miss is a false positive.
      val isCorrect = predictedLabel == actual.label
      Some(if (isCorrect) 1.0 else 0.0)
    }
  }
}
/**
 * Evaluation that scores `ClassificationEngine` by its precision on the
 * class label 1.0.
 */
object PrecisionEvaluation extends Evaluation {
  // Label treated as the "positive" class when measuring precision.
  private val PositiveLabel: Double = 1.0

  // Pair the engine under test with the metric used to score it.
  engineMetric = (ClassificationEngine(), Precision(label = PositiveLabel))
}