blob: de80b76acc79b0134aad9882fa4d5d84cc262c3c [file] [log] [blame]
package org.template.classification
import io.prediction.controller.OptionAverageMetric
import io.prediction.controller.EmptyEvaluationInfo
import io.prediction.controller.Evaluation
/**
 * Precision metric for a single class of the classification engine.
 *
 * A query produces a score only when the engine predicted `label`:
 * `Some(1.0)` when that prediction agrees with the actual label (true
 * positive) and `Some(0.0)` when it does not (false positive). Any other
 * prediction produces `None`, because it says nothing about precision for
 * `label`. NOTE(review): `OptionAverageMetric` is presumably expected to
 * average only the `Some` scores, which gives TP / (TP + FP). Confirm this
 * against the PredictionIO docs.
 *
 * @param label the class label whose precision is measured
 */
case class Precision(label: Double)
  extends OptionAverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {

  override def header: String = s"Precision(label = $label)"

  /**
   * Scores one evaluated query.
   *
   * @param query     the query that was evaluated (not inspected)
   * @param predicted the engine's prediction for `query`
   * @param actual    the ground-truth result for `query`
   * @return `None` if `predicted.label` is not `label`; otherwise
   *         `Some(1.0)` for a correct prediction and `Some(0.0)` for a wrong one
   */
  def calculate(
      query: Query,
      predicted: PredictedResult,
      actual: ActualResult): Option[Double] =
    // Guard first: only predictions of `label` count toward its precision.
    if (predicted.label != label) None
    else Some(if (predicted.label == actual.label) 1.0 else 0.0)
}
/**
 * Evaluation that scores `ClassificationEngine` by its precision on the
 * class labelled 1.0.
 */
object PrecisionEvaluation extends Evaluation {
  // Pair the engine under test with the metric used to score it.
  engineMetric = (ClassificationEngine(), Precision(label = 1.0))
}