package org.example.recommendation
import org.apache.predictionio.controller.Evaluation
import org.apache.predictionio.controller.OptionAverageMetric
import org.apache.predictionio.controller.AverageMetric
import org.apache.predictionio.controller.EmptyEvaluationInfo
import org.apache.predictionio.controller.EngineParamsGenerator
import org.apache.predictionio.controller.EngineParams
import org.apache.predictionio.controller.MetricEvaluator
// Usage:
// $ pio eval org.example.recommendation.RecommendationEvaluation \
// org.example.recommendation.EngineParamsList
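// Precision@K: for each query, the fraction of the top k recommended items that the user actually
// rated at or above ratingThreshold. The denominator is min(k, number of positive items), so a
// query with fewer than k positives can still reach 1.0. For example, with k = 3 and five
// positively rated items, two relevant items in the top 3 gives Precision@3 = 2 / 3.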
case class PrecisionAtK(k: Int, ratingThreshold: Double = 2.0)
extends OptionAverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {
require(k > 0, "k must be greater than 0")
override def header = s"Precision@K (k=$k, threshold=$ratingThreshold)"
  override def calculate(q: Query, p: PredictedResult, a: ActualResult): Option[Double] = {
val positives: Set[String] = a.ratings.filter(_.rating >= ratingThreshold).map(_.item).toSet
    // If there are no positive results, precision is undefined. We do not count such queries in
    // the metric, hence we return None.
    if (positives.isEmpty) {
      None
    } else {
      val tpCount: Int = p.itemScores.take(k).count(is => positives(is.item))
      Some(tpCount.toDouble / math.min(k, positives.size))
}
}
}
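// PositiveCount: the number of test ratings at or above ratingThreshold for each query, averaged
// over all queries. Reported alongside Precision@K as a sanity check on how many positive labels
// each threshold leaves in the evaluation data.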
case class PositiveCount(ratingThreshold: Double = 2.0)
extends AverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {
override def header = s"PositiveCount (threshold=$ratingThreshold)"
  override def calculate(q: Query, p: PredictedResult, a: ActualResult): Double = {
    a.ratings.count(_.rating >= ratingThreshold).toDouble
}
}
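// Evaluation that tunes engine params against Precision@10 at a rating threshold of 4.0, while
// also reporting PositiveCount and Precision@10 at looser thresholds for context.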
object RecommendationEvaluation extends Evaluation {
engineEvaluator = (
RecommendationEngine(),
MetricEvaluator(
metric = PrecisionAtK(k = 10, ratingThreshold = 4.0),
otherMetrics = Seq(
PositiveCount(ratingThreshold = 4.0),
PrecisionAtK(k = 10, ratingThreshold = 2.0),
PositiveCount(ratingThreshold = 2.0),
PrecisionAtK(k = 10, ratingThreshold = 1.0),
PositiveCount(ratingThreshold = 1.0)
)))
}
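// A broader evaluation: the primary metric is Precision@3 at a rating threshold of 2.0, with
// PositiveCount and Precision@K reported for every combination of the thresholds and k values
// defined below.
// Usage:
// $ pio eval org.example.recommendation.ComprehensiveRecommendationEvaluation \
//   org.example.recommendation.EngineParamsList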
object ComprehensiveRecommendationEvaluation extends Evaluation {
val ratingThresholds = Seq(0.0, 2.0, 4.0)
val ks = Seq(1, 3, 10)
engineEvaluator = (
RecommendationEngine(),
MetricEvaluator(
metric = PrecisionAtK(k = 3, ratingThreshold = 2.0),
otherMetrics = (
(for (r <- ratingThresholds) yield PositiveCount(ratingThreshold = r)) ++
(for (r <- ratingThresholds; k <- ks) yield PrecisionAtK(k = k, ratingThreshold = r))
)))
}
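// Shared base engine params for evaluation. Replace "INVALID_APP_NAME" with the app name your
// events were imported under; evalParams controls the k-fold split and the number of items
// requested per evaluation query (see DataSource).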
trait BaseEngineParamsList extends EngineParamsGenerator {
protected val baseEP = EngineParams(
dataSourceParams = DataSourceParams(
appName = "INVALID_APP_NAME",
evalParams = Some(DataSourceEvalParams(kFold = 5, queryNum = 10))))
}
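// Grid of candidate hyperparameters: every combination of ALS rank and numIterations below is
// evaluated (9 engine params in total); the remaining ALSAlgorithmParams arguments, 0.01 and
// Some(3), are held fixed across the grid.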
object EngineParamsList extends BaseEngineParamsList {
engineParamsList = for(
rank <- Seq(5, 10, 20);
numIterations <- Seq(1, 5, 10))
yield baseEP.copy(
algorithmParamsList = Seq(
("als", ALSAlgorithmParams(rank, numIterations, 0.01, Some(3)))))
}