/** Copyright 2014 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prediction.controller

import io.prediction.core.BaseMetrics
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import scala.reflect._
import scala.reflect.runtime.universe._

/** Base class of metrics.
*
* Metrics compare predictions with actual known values and produce numerical
* comparisons.
*
* @tparam MP Metrics parameters class.
 * @tparam DP Data parameters class.
* @tparam Q Input query class.
* @tparam P Output prediction class.
* @tparam A Actual value class.
* @tparam MU Metrics unit class.
* @tparam MR Metrics result class.
* @tparam MMR Multiple metrics results class.
* @group Metrics
*/
abstract class Metrics[
    MP <: Params : ClassTag, DP, Q, P, A, MU, MR, MMR <: AnyRef]
  extends BaseMetrics[MP, DP, Q, P, A, MU, MR, MMR] {

  // Bridges the tupled input used by the core API to the user-facing
  // computeUnit method.
  def computeUnitBase(input: (Q, P, A)): MU = {
    computeUnit(input._1, input._2, input._3)
  }
  /** Implement this method to calculate a unit of metrics, comparing a pair
   * of predicted and actual values.
   *
   * @param query Input query that produced the prediction.
   * @param prediction The predicted value.
   * @param actual The actual value.
   */
  def computeUnit(query: Q, prediction: P, actual: A): MU

  // Bridges the core API to the user-facing computeSet method.
  def computeSetBase(dataParams: DP, metricUnits: Seq[MU]): MR = {
    computeSet(dataParams, metricUnits)
  }
  /** Implement this method to calculate the metrics result of an evaluation.
   *
   * @param dataParams Data parameters that were used to generate data for
   *        this evaluation.
   * @param metricUnits A list of metric units from [[computeUnit]].
   */
  def computeSet(dataParams: DP, metricUnits: Seq[MU]): MR

  // Bridges the core API to the user-facing computeMultipleSets method.
  def computeMultipleSetsBase(input: Seq[(DP, MR)]): MMR = {
    computeMultipleSets(input)
  }
  /** Implement this method to aggregate all metrics results generated by
   * each evaluation's [[computeSet]] to produce the final result.
   *
   * @param input A list of data parameters and metrics result pairs to
   *        aggregate.
   */
  def computeMultipleSets(input: Seq[(DP, MR)]): MMR
}
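
/** A minimal sketch of a concrete [[Metrics]] subclass, shown here for
 * illustration only: this class and its names are not part of the framework.
 * It counts exact matches between predicted and actual string labels, using
 * `EmptyParams` as the metrics parameters class and `AnyRef` for data
 * parameters and queries.
 */
class ExactMatchCount extends Metrics[EmptyParams, AnyRef,
    AnyRef, String, String, Int, Int, String] {
  // 1 if the prediction equals the actual label, 0 otherwise.
  def computeUnit(query: AnyRef, prediction: String, actual: String): Int =
    if (prediction == actual) 1 else 0

  // Sums the per-query matches of one evaluation set.
  def computeSet(dataParams: AnyRef, metricUnits: Seq[Int]): Int =
    metricUnits.sum

  // Reports the match count of every evaluation set, one per line.
  def computeMultipleSets(input: Seq[(AnyRef, Int)]): String =
    input.map { case (dp, n) => s"Set: $dp Matches: $n" }.mkString("\n")
}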
/** Trait for nicely rendered metrics results.
 *
 * A metrics result can be rendered nicely by implementing the toHTML and
 * toJSON methods. These results are rendered through the dashboard.
 *
 * @group Metrics
 */
trait NiceRendering {
  /** HTML portion of the rendered metrics results. */
  def toHTML(): String

  /** JSON portion of the rendered metrics results. */
  def toJSON(): String
}
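
/** A hedged sketch of a [[NiceRendering]] implementation (illustrative, not
 * part of the framework): it wraps a single score and renders it as an HTML
 * fragment and a JSON object.
 */
case class ScoreResult(score: Double) extends NiceRendering {
  def toHTML(): String = s"<p>Score: $score</p>"
  def toJSON(): String = s"""{"score": $score}"""
}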
/** An implementation of mean square error metrics. `DP` is `AnyRef`, so it
 * supports any kind of data parameters.
*
* @group Metrics
*/
class MeanSquareError extends Metrics[EmptyParams, AnyRef,
    AnyRef, Double, Double, (Double, Double), String, String] {
  // A metric unit is simply the (prediction, actual) pair; the query is
  // not needed for this metric.
  def computeUnit(q: AnyRef, p: Double, a: Double): (Double, Double) = (p, a)

  // Averages the squared errors over one evaluation set.
  def computeSet(ep: AnyRef, data: Seq[(Double, Double)]): String = {
    val units = data.map(e => math.pow(e._1 - e._2, 2))
    val mse = units.sum / units.length
    f"Set: $ep Size: ${data.length} MSE: ${mse}%8.6f"
  }

  // Concatenates the per-set reports into the final result.
  def computeMultipleSets(input: Seq[(AnyRef, String)]): String = {
    input.map(_._2).mkString("\n")
  }
}
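
/** A usage sketch (illustrative, not part of the framework) that exercises
 * [[MeanSquareError]] on in-memory data to show how the three compute
 * methods fit together. It assumes the class can be instantiated directly;
 * in a real evaluation the framework drives these calls itself.
 */
object MeanSquareErrorExample {
  def main(args: Array[String]): Unit = {
    val metrics = new MeanSquareError()
    // (prediction, actual) pairs; the query is unused by this metrics class.
    val data = Seq((1.0, 1.5), (2.0, 2.0), (3.0, 2.5))
    val units = data.map { case (p, a) => metrics.computeUnit(null, p, a) }
    val result = metrics.computeSet("sample", units)
    // Prints something like: Set: sample Size: 3 MSE: 0.166667
    println(metrics.computeMultipleSets(Seq(("sample", result))))
  }
}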