//! Functions for scoring a set of predictions, i.e. evaluating
//! how closely the predictions match the true values. All functions in this
//! module obey the convention that higher is better.
use libnum::{Zero, One};
use linalg::{BaseMatrix, Matrix};
use learning::toolkit::cost_fn::{CostFunc, MeanSqError};
// ************************************
// Classification Scores
// ************************************
/// Returns the fraction of outputs which match their target.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels.
/// * `targets` - Iterator of expected (actual) labels.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::accuracy;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
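/// // 3 of the 6 predictions match their targets, so the accuracy is 3/6 = 0.5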
/// assert_eq!(accuracy(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
pub fn accuracy<I1, I2, T>(outputs: I1, targets: I2) -> f64
where T: PartialEq,
I1: ExactSizeIterator + Iterator<Item=T>,
I2: ExactSizeIterator + Iterator<Item=T>
{
assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
let len = outputs.len() as f64;
let correct = outputs
.zip(targets)
.filter(|&(ref x, ref y)| x == y)
.count();
correct as f64 / len
}
/// Returns the fraction of output rows which match their target rows.
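///
/// # Examples
///
/// A minimal usage sketch, assuming `Matrix` is re-exported as
/// `rusty_machine::linalg::Matrix`:
///
/// ```
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::analysis::score::row_accuracy;
///
/// let outputs = Matrix::new(3, 2, vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
/// let targets = Matrix::new(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0]);
///
/// // the first two rows match their targets exactly, the third does not
/// assert_eq!(row_accuracy(&outputs, &targets), 2.0 / 3.0);
/// ```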
pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
accuracy(outputs.row_iter().map(|r| r.raw_slice()),
targets.row_iter().map(|r| r.raw_slice()))
}
/// Returns the precision score for 2-class classification.
///
/// Precision is calculated as true-positives / (true-positives + false-positives),
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0s and 1s.
/// * `targets` - Iterator of expected (actual) labels, containing only 0s and 1s.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::precision;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
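/// // 2 of the 3 predicted positives are true positives, so the precision is 2/3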
/// assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn precision<'a, I, T>(outputs: I, targets: I) -> f64
where I: ExactSizeIterator<Item=&'a T>,
T: 'a + PartialEq + Zero + One
{
assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
let mut tpfp = 0.0f64;
let mut tp = 0.0f64;
for (ref o, ref t) in outputs.zip(targets) {
if *o == &T::one() {
tpfp += 1.0f64;
if *t == &T::one() {
tp += 1.0f64;
}
}
if ((*t != &T::zero()) & (*t != &T::one())) |
((*o != &T::zero()) & (*o != &T::one())) {
panic!("precision must be used for 2 class classification")
}
}
tp / tpfp
}
/// Returns the recall score for 2-class classification.
///
/// Recall is calculated as true-positives / (true-positives + false-negatives),
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0s and 1s.
/// * `targets` - Iterator of expected (actual) labels, containing only 0s and 1s.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::recall;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
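/// // 2 of the 4 actual positives are predicted correctly, so the recall is 2/4 = 0.5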
/// assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn recall<'a, I, T>(outputs: I, targets: I) -> f64
where I: ExactSizeIterator<Item=&'a T>,
T: 'a + PartialEq + Zero + One
{
assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
let mut tpfn = 0.0f64;
let mut tp = 0.0f64;
for (ref o, ref t) in outputs.zip(targets) {
if *t == &T::one() {
tpfn += 1.0f64;
if *o == &T::one() {
tp += 1.0f64;
}
}
if ((*t != &T::zero()) & (*t != &T::one())) |
((*o != &T::zero()) & (*o != &T::one())) {
panic!("recall must be used for 2 class classification")
}
}
tp / tpfn
}
/// Returns the F1 score for 2-class classification.
///
/// The F1 score is calculated as 2 * precision * recall / (precision + recall),
/// see [F1 score](https://en.wikipedia.org/wiki/F1_score) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0s and 1s.
/// * `targets` - Iterator of expected (actual) labels, containing only 0s and 1s.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::f1;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
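/// // precision is 2/3 and recall is 1/2, so F1 = 2 * (2/3) * (1/2) / (2/3 + 1/2) = 4/7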
/// assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn f1<'a, I, T>(outputs: I, targets: I) -> f64
where I: ExactSizeIterator<Item=&'a T>,
T: 'a + PartialEq + Zero + One
{
assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
let mut tpos = 0.0f64;
let mut fpos = 0.0f64;
let mut fneg = 0.0f64;
for (ref o, ref t) in outputs.zip(targets) {
if (*o == &T::one()) & (*t == &T::one()) {
tpos += 1.0f64;
} else if *t == &T::one() {
// false negative: the target is positive but the prediction is not
fneg += 1.0f64;
} else if *o == &T::one() {
// false positive: the prediction is positive but the target is not
fpos += 1.0f64;
}
if ((*t != &T::zero()) & (*t != &T::one())) |
((*o != &T::zero()) & (*o != &T::one())) {
panic!("f1-score must be used for 2 class classification")
}
}
2.0f64 * tpos / (2.0f64 * tpos + fneg + fpos)
}
// ************************************
// Regression Scores
// ************************************
// TODO: generalise to accept arbitrary iterators of diff-able things
/// Returns the additive inverse of the mean squared error of the
/// outputs, so higher is better. The returned value is never positive,
/// and is zero only when the outputs match the targets exactly.
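///
/// # Examples
///
/// A small sketch, assuming `Matrix` is re-exported as
/// `rusty_machine::linalg::Matrix`:
///
/// ```
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::analysis::score::neg_mean_squared_error;
///
/// let outputs = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]);
/// let targets = Matrix::new(3, 1, vec![2f64, 4f64, 3f64]);
///
/// // squared errors are 1, 4 and 0, so the mean squared error is 5/3
/// assert_eq!(neg_mean_squared_error(&outputs, &targets), -5f64 / 3f64);
/// ```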
pub fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64
{
// MeanSqError divides the actual mean squared error by two.
-2f64 * MeanSqError::cost(outputs, targets)
}
#[cfg(test)]
mod tests {
use linalg::Matrix;
use super::{accuracy, precision, recall, f1, neg_mean_squared_error};
#[test]
fn test_accuracy() {
let outputs = [1, 2, 3, 4, 5, 6];
let targets = [1, 2, 3, 3, 5, 1];
assert_eq!(accuracy(outputs.iter(), targets.iter()), 2f64/3f64);
let outputs = [1, 1, 1, 0, 0, 0];
let targets = [1, 1, 1, 0, 0, 1];
assert_eq!(accuracy(outputs.iter(), targets.iter()), 5.0f64 / 6.0f64);
}
#[test]
fn test_precision() {
let outputs = [1, 1, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
let outputs = [1, 1, 1, 0, 1, 1];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(precision(outputs.iter(), targets.iter()), 0.8);
let outputs = [0, 0, 0, 1, 1, 1];
let targets = [1, 1, 1, 1, 1, 0];
assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
let outputs = [1, 1, 1, 1, 1, 0];
let targets = [0, 0, 0, 1, 1, 1];
assert_eq!(precision(outputs.iter(), targets.iter()), 0.4);
}
#[test]
#[should_panic]
fn test_precision_outputs_not_2class() {
let outputs = [1, 2, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
precision(outputs.iter(), targets.iter());
}
#[test]
#[should_panic]
fn test_precision_targets_not_2class() {
let outputs = [1, 0, 1, 0, 0, 0];
let targets = [1, 2, 0, 0, 1, 1];
precision(outputs.iter(), targets.iter());
}
#[test]
fn test_recall() {
let outputs = [1, 1, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);
let outputs = [1, 1, 1, 0, 1, 1];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(recall(outputs.iter(), targets.iter()), 1.0);
let outputs = [0, 0, 0, 1, 1, 1];
let targets = [1, 1, 1, 1, 1, 0];
assert_eq!(recall(outputs.iter(), targets.iter()), 0.4);
let outputs = [1, 1, 1, 1, 1, 0];
let targets = [0, 0, 0, 1, 1, 1];
assert_eq!(recall(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
}
#[test]
#[should_panic]
fn test_recall_outputs_not_2class() {
let outputs = [1, 2, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
recall(outputs.iter(), targets.iter());
}
#[test]
#[should_panic]
fn test_recall_targets_not_2class() {
let outputs = [1, 0, 1, 0, 0, 0];
let targets = [1, 2, 0, 0, 1, 1];
recall(outputs.iter(), targets.iter());
}
#[test]
fn test_f1() {
let outputs = [1, 1, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);
let outputs = [1, 1, 1, 0, 1, 1];
let targets = [1, 1, 0, 0, 1, 1];
assert_eq!(f1(outputs.iter(), targets.iter()), 0.8888888888888888);
let outputs = [0, 0, 0, 1, 1, 1];
let targets = [1, 1, 1, 1, 1, 0];
assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);
let outputs = [1, 1, 1, 1, 1, 0];
let targets = [0, 0, 0, 1, 1, 1];
assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);
}
#[test]
#[should_panic]
fn test_f1_outputs_not_2class() {
let outputs = [1, 2, 1, 0, 0, 0];
let targets = [1, 1, 0, 0, 1, 1];
f1(outputs.iter(), targets.iter());
}
#[test]
#[should_panic]
fn test_f1_targets_not_2class() {
let outputs = [1, 0, 1, 0, 0, 0];
let targets = [1, 2, 0, 0, 1, 1];
f1(outputs.iter(), targets.iter());
}
#[test]
fn test_neg_mean_squared_error_1d() {
let outputs = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]);
let targets = Matrix::new(3, 1, vec![2f64, 4f64, 3f64]);
assert_eq!(neg_mean_squared_error(&outputs, &targets), -5f64/3f64);
}
#[test]
fn test_neg_mean_squared_error_2d() {
let outputs = Matrix::new(3, 2, vec![
1f64, 2f64,
3f64, 4f64,
5f64, 6f64
]);
let targets = Matrix::new(3, 2, vec![
1.5f64, 2.5f64,
5f64, 6f64,
5.5f64, 6.5f64
]);
assert_eq!(neg_mean_squared_error(&outputs, &targets), -3f64);
}
}