blob: 952e339f8bca9566405e65cbdc9581fb8c39f2f9 [file] [log] [blame]
//! Brute-force k-nearest-neighbour search implementation.
use std::vec::*;
use linalg::{Matrix, BaseMatrix};
use learning::error::Error;
use super::{KNearest, KNearestSearch, get_distances, dist};
/// Brute-force k-nearest-neighbour searcher.
///
/// Holds no data until `build` is called; each search then compares the
/// query point against every stored row.
#[derive(Debug)]
pub struct BruteForce {
    /// Rows to search. `None` until `build` stores a matrix.
    data: Option<Matrix<f64>>,
}
impl Default for BruteForce {
/// Constructs new brute-force search
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::knn::BruteForce;
/// let _ = BruteForce::default();
/// ```
fn default() -> Self {
BruteForce {
data: None
}
}
}
impl BruteForce {
/// Constructs new brute-force search.
/// BruteForce accepts no parapeters.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::knn::BruteForce;
/// let _ = BruteForce::new();
/// ```
pub fn new() -> Self {
BruteForce::default()
}
}
/// Can search K-nearest items
impl KNearestSearch for BruteForce {
/// initialize BruteForce Searcher
fn build(&mut self, data: Matrix<f64>) {
self.data = Some(data);
}
/// Serch k-nearest items close to the point
fn search(&self, point: &[f64], k: usize) -> Result<(Vec<usize>, Vec<f64>), Error> {
if let Some(ref data) = self.data {
let indices: Vec<usize> = (0..k).collect();
let distances = get_distances(data, point, &indices);
let mut query = KNearest::new(k, indices, distances);
let mut current_dist = query.dist();
let mut i = k;
for row in data.row_iter().skip(k) {
let d = dist(point, row.raw_slice());
if d < current_dist {
current_dist = query.add(i, d);
}
i += 1;
}
Ok(query.get_results())
} else {
Err(Error::new_untrained())
}
}
}
#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::super::KNearestSearch;
    use super::BruteForce;

    /// Query point shared by the tests below.
    const QUERY: [f64; 2] = [3., 4.9];

    #[test]
    fn test_bruteforce_search() {
        let data = Matrix::new(5, 2, vec![1., 2.,
                                          8., 0.,
                                          6., 10.,
                                          3., 6.,
                                          0., 3.]);
        let mut searcher = BruteForce::new();
        searcher.build(data);

        // Nearest rows to QUERY, closest first; each k expects the first k entries.
        let expected_indices: [usize; 3] = [3, 0, 4];
        let expected_distances: [f64; 3] =
            [1.0999999999999996, 3.5227829907617076, 3.551056180912941];

        for k in 1..4 {
            let (indices, distances) = searcher.search(&QUERY, k).unwrap();
            assert_eq!(indices, expected_indices[..k].to_vec());
            assert_eq!(distances, expected_distances[..k].to_vec());
        }
    }

    #[test]
    fn test_bruteforce_untrained() {
        let searcher = BruteForce::new();
        assert!(searcher.search(&QUERY, 1).is_err());
    }
}