//! Binary Tree implementations
use std::borrow::Borrow;
use std::collections::VecDeque;
use linalg::{Matrix, BaseMatrix, Vector};
use learning::error::Error;
use super::{KNearest, KNearestSearch, get_distances, dist};
/// Binary tree
#[derive(Debug)]
pub struct BinaryTree<B: BinarySplit> {
// Binary tree leaf size
leafsize: usize,
// Search data
data: Option<Matrix<f64>>,
// Binary tree
root: Option<Node<B>>
}
impl<B: BinarySplit> Default for BinaryTree<B> {
/// Constructs the default binary-tree (kd-tree or ball-tree) search.
/// Each leaf contains at most 30 elements.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::knn::{KDTree, BallTree};
/// let _ = KDTree::default();
/// let _ = BallTree::default();
/// ```
fn default() -> Self {
BinaryTree {
leafsize: 30,
data: None,
root: None
}
}
}
/// Binary splittable
pub trait BinarySplit: Sized {
/// Build a branch from the passed arguments
fn build(data: &Matrix<f64>, remains: Vec<usize>,
dim: usize, split: f64, min: Vector<f64>, max: Vector<f64>,
left: Node<Self>, right: Node<Self>)
-> Node<Self>;
/// Return a tuple of the left and right nodes. The first node is the one
/// likely to be closer to the point.
///
/// Unsafe because implementations may index `point` by the split dimension
/// without bounds checking.
unsafe fn maybe_close<'s, 'p>(&'s self, point: &'p [f64])
-> (&'s Node<Self>, &'s Node<Self>);
/// Return the distance between the point and the region bounded by this branch
fn dist(&self, point: &[f64]) -> f64;
/// Return left node
fn left(&self) -> &Node<Self>;
/// Return right node
fn right(&self) -> &Node<Self>;
}
/// Kd-tree branch
#[derive(Debug)]
pub struct KDTreeBranch {
/// dimension (column) to split
dim: usize,
/// split value
split: f64,
/// min and max of bounding box
/// i.e. hyper-rectangle contained in the branch
min: Vector<f64>,
max: Vector<f64>,
// links to the left / right nodes
// - the left node contains rows whose value in the column specified by
//   ``dim`` is less than the ``split`` value
// - the right node contains rows whose value is greater than or equal to ``split``
left: Box<Node<KDTreeBranch>>,
right: Box<Node<KDTreeBranch>>,
}
/// Ball-tree branch
#[derive(Debug)]
pub struct BallTreeBranch {
/// dimension (column) to split
dim: usize,
/// split value
split: f64,
/// ball centroid and its radius
center: Vector<f64>,
radius: f64,
// link to left / right node, see KDTreeBranch comment
left: Box<Node<BallTreeBranch>>,
right: Box<Node<BallTreeBranch>>,
}
/// Kd-tree implementation
pub type KDTree = BinaryTree<KDTreeBranch>;
/// Ball-tree implementation
pub type BallTree = BinaryTree<BallTreeBranch>;
impl BinarySplit for KDTreeBranch {
fn build(_: &Matrix<f64>, _: Vec<usize>,
dim: usize, split: f64, min: Vector<f64>, max: Vector<f64>,
left: Node<Self>, right: Node<Self>) -> Node<Self> {
let b = KDTreeBranch {
dim: dim,
split: split,
min: min,
max: max,
left: Box::new(left),
right: Box::new(right)
};
Node::Branch(b)
}
unsafe fn maybe_close<'s, 'p>(&'s self, point: &'p [f64])
-> (&'s Node<Self>, &'s Node<Self>) {
if *point.get_unchecked(self.dim) < self.split {
(&self.left, &self.right)
} else {
(&self.right, &self.left)
}
}
fn dist(&self, point: &[f64]) -> f64 {
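// Distance from `point` to this branch's bounding hyper-rectangle:
// accumulate the squared distance per dimension (0 for dimensions where
// the point already lies within [min, max]), then take the square root.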
let mut d = 0.;
for ((&p, &mi), &ma) in point.iter()
.zip(self.min.iter())
.zip(self.max.iter()) {
if p < mi {
d += (mi - p) * (mi - p);
} else if ma < p {
d += (ma - p) * (ma - p);
}
// otherwise included in the hyper-rectangle
}
d.sqrt()
}
fn left(&self) -> &Node<Self> {
self.left.borrow()
}
fn right(&self) -> &Node<Self> {
self.right.borrow()
}
}
impl BinarySplit for BallTreeBranch {
fn build(data: &Matrix<f64>, remains: Vec<usize>,
dim: usize, split: f64, _: Vector<f64>, _: Vector<f64>,
left: Node<Self>, right: Node<Self>) -> Node<Self> {
// calculate centroid (mean)
// TODO: cleanup using .row()
let mut center: Vec<f64> = vec![0.; data.cols()];
for &i in &remains {
let row: Vec<f64> = data.select_rows(&[i]).into_vec();
for (c, r) in center.iter_mut().zip(row.iter()) {
*c += *r;
}
}
let len = remains.len() as f64;
for c in &mut center {
*c /= len;
}
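// the radius is the largest distance from the centroid to any row
// contained in this branch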
let mut radius = 0.;
for &i in &remains {
let row: Vec<f64> = data.select_rows(&[i]).into_vec();
let d = dist(&center, &row);
if d > radius {
radius = d;
}
}
let b = BallTreeBranch {
dim: dim,
split: split,
center: Vector::new(center),
radius: radius,
left: Box::new(left),
right: Box::new(right)
};
Node::Branch(b)
}
unsafe fn maybe_close<'s, 'p>(&'s self, point: &'p [f64])
-> (&'s Node<Self>, &'s Node<Self>) {
if *point.get_unchecked(self.dim) < self.split {
(&self.left, &self.right)
} else {
(&self.right, &self.left)
}
}
fn dist(&self, point: &[f64]) -> f64 {
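// Distance from `point` to the surface of the ball; 0 if the point
// lies inside the ball.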
let d = dist(self.center.data(), point);
if d < self.radius {
0.
} else {
d - self.radius
}
}
fn left(&self) -> &Node<Self> {
self.left.borrow()
}
fn right(&self) -> &Node<Self> {
self.right.borrow()
}
}
/// Binary tree node (either branch or leaf)
#[derive(Debug)]
pub enum Node<B: BinarySplit> {
/// Binary tree branch
Branch(B),
/// Binary tree leaf
Leaf(Leaf)
}
/// Binary tree leaf
#[derive(Debug)]
pub struct Leaf {
children: Vec<usize>
}
impl Leaf {
fn new(children: Vec<usize>) -> Self {
Leaf {
children: children
}
}
}
impl<B: BinarySplit> BinaryTree<B> {
/// Constructs a binary-tree (kd-tree or ball-tree) search.
/// `leafsize` is the maximum number of elements contained in each leaf.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::knn::{KDTree, BallTree};
/// let _ = KDTree::new(10);
/// let _ = BallTree::new(50);
/// ```
pub fn new(leafsize: usize) -> Self {
BinaryTree {
leafsize: leafsize,
data: None,
root: None
}
}
/// Select the next split dimension and value. Returns a tuple with 6 elements:
/// - split dim
/// - split value
/// - remains for left node
/// - remains for right node
/// - updated max for left node
/// - updated min for right node
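///
/// For example, with `dmin = [1., 0.]` and `dmax = [3., 10.]` the widest
/// range is in column #1, so the midpoint split is `0. + 10. / 2. = 5.`
/// (see `test_kdtree_dim_selection` below).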
fn select_split(&self, data: &Matrix<f64>, mut remains: Vec<usize>,
mut dmin: Vector<f64>, mut dmax: Vector<f64>)
-> (usize, f64, Vec<usize>, Vec<usize>, Vector<f64>, Vector<f64>){
// avoid recursive call
loop {
// split on the column which has the widest range
let (dim, d) = (&dmax - &dmin).argmax();
// Use midpoint rule, see "On the Efficiency of Nearest Neighbor Searching
// with Data Clustered in Lower Dimensions (Maneewongvatan and Mount, 1999)"
// ToDo: use unsafe get (v0.4.0?)
// https://github.com/AtheMathmo/rulinalg/pull/104
let split = unsafe {
dmin.data().get_unchecked(dim) + d / 2.0
};
// split remains
let mut l_remains: Vec<usize> = Vec::with_capacity(remains.len());
let mut r_remains: Vec<usize> = Vec::with_capacity(remains.len());
unsafe {
for r in remains {
if *data.get_unchecked([r, dim]) < split {
l_remains.push(r);
} else {
r_remains.push(r);
}
}
}
r_remains.shrink_to_fit();
l_remains.shrink_to_fit();
if l_remains.is_empty() {
// all rows are in r_remains. re-split r_remains
remains = r_remains;
dmin[dim] = split;
} else if r_remains.is_empty() {
// all rows are in l_remains. re-split l_remains
remains = l_remains;
dmax[dim] = split;
} else {
// new hyper-rectangle's min / max
let mut l_max = dmax.clone();
// ToDo: use unsafe mut (v0.4.0?)
// https://github.com/AtheMathmo/rulinalg/pull/104
l_max[dim] = split;
let mut r_min = dmin.clone();
r_min[dim] = split;
return (dim, split, l_remains, r_remains, l_max, r_min);
}
};
}
/// Find the next binary split
fn split(&self, data: &Matrix<f64>, remains: Vec<usize>,
dmin: Vector<f64>, dmax: Vector<f64>) -> Node<B> {
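// recurse until the remaining rows fit into a single leaf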
if remains.len() < self.leafsize {
Node::Leaf(Leaf::new(remains))
} else {
// ToDo: avoid this clone
let (dim, split, l_remains, r_remains, l_max, r_min) =
self.select_split(data, remains.clone(), dmin.clone(), dmax.clone());
let l_node = self.split(data, l_remains, dmin.clone(), l_max);
let g_node = self.split(data, r_remains, r_min, dmax.clone());
B::build(data, remains, dim, split, dmin, dmax, l_node, g_node)
}
}
/// Find the leaf which contains the search point
fn search_leaf<'s, 'p>(&'s self, point: &'p [f64], k: usize)
-> Result<(KNearest, VecDeque<&'s Node<B>>), Error> {
if let (&Some(ref root), &Some(ref data)) = (&self.root, &self.data) {
let mut queue: VecDeque<&Node<B>> = VecDeque::new();
queue.push_front(root);
loop {
// pop first element
let current: &Node<B> = queue.pop_front().unwrap();
match *current {
Node::Leaf(ref l) => {
let distances = get_distances(data, point, &l.children);
let kn = KNearest::new(k, l.children.clone(), distances);
return Ok((kn, queue));
},
Node::Branch(ref b) => {
// the current branch must contain the target point.
// push the child node which contains the target point to
// the front of the queue, and the other side to the back.
let (close, far) = unsafe {
b.maybe_close(point)
};
queue.push_front(close);
queue.push_back(far);
}
}
}
} else {
Err(Error::new_untrained())
}
}
}
/// Can search k-nearest items
impl<B: BinarySplit> KNearestSearch for BinaryTree<B> {
/// Build the data structure used to optimize subsequent searches
fn build(&mut self, data: Matrix<f64>) {
let remains: Vec<usize> = (0..data.rows()).collect();
let dmin = min(&data);
let dmax = max(&data);
self.root = Some(self.split(&data, remains, dmin, dmax));
self.data = Some(data);
}
/// Search for the k-nearest items close to the point
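///
/// # Examples
///
/// A minimal usage sketch (assuming `Matrix` and the `KNearestSearch`
/// trait are re-exported at the paths below):
///
/// ```
/// use rusty_machine::learning::knn::{KDTree, KNearestSearch};
/// use rusty_machine::linalg::Matrix;
///
/// let mut tree = KDTree::new(3);
/// tree.build(Matrix::new(3, 2, vec![1., 2.,
///                                   3., 6.,
///                                   0., 3.]));
/// let (indices, distances) = tree.search(&[3., 4.9], 2).unwrap();
/// assert_eq!(indices.len(), 2);
/// assert_eq!(distances.len(), 2);
/// ```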
fn search(&self, point: &[f64], k: usize) -> Result<(Vec<usize>, Vec<f64>), Error> {
if let Some(ref data) = self.data {
let (mut query, mut queue) = try!(self.search_leaf(point, k));
while !queue.is_empty() {
let current = queue.pop_front().unwrap();
match *current {
Node::Leaf(ref l) => {
let distances = get_distances(data, point, &l.children);
let mut current_dist = query.dist();
for (&i, d) in l.children.iter().zip(distances.into_iter()) {
if d < current_dist {
current_dist = query.add(i, d);
}
}
},
Node::Branch(ref b) => {
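// prune: revisit this branch only if its bounding region could
// contain a point closer than the current k-th nearest distance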
let d = b.dist(point);
if d < query.dist() {
queue.push_back(b.left());
queue.push_back(b.right());
}
}
}
}
Ok(query.get_results())
} else {
Err(Error::new_untrained())
}
}
}
/// Column-wise minimum of the matrix
fn min(data: &Matrix<f64>) -> Vector<f64> {
// ToDo: use rulinalg .min (v0.4.1?)
// https://github.com/AtheMathmo/rulinalg/pull/115
let mut results = Vec::with_capacity(data.cols());
for i in 0..data.cols() {
results.push(data[[0, i]]);
}
for row in data.row_iter() {
for (r, v) in results.iter_mut().zip(row.iter()) {
if *r > *v {
*r = *v;
}
}
}
Vector::new(results)
}
/// max
fn max(data: &Matrix<f64>) -> Vector<f64> {
// ToDo: use rulinalg .max (v0.4.1?)
// https://github.com/AtheMathmo/rulinalg/pull/115
let mut results = Vec::with_capacity(data.cols());
for i in 0..data.cols() {
results.push(data[[0, i]]);
}
for row in data.row_iter() {
for (r, v) in results.iter_mut().zip(row.iter()) {
if *r < *v {
*r = *v;
}
}
}
Vector::new(results)
}
#[cfg(test)]
mod tests {
use linalg::{Vector, Matrix, BaseMatrix};
use super::super::KNearestSearch;
use super::{KDTree, BallTree, min, max};
use super::{Node, BinarySplit, Leaf};
// return node's leaf reference, for testing purposes
fn as_leaf<B: BinarySplit>(n: &Node<B>) -> &Leaf {
match n {
&Node::Leaf(ref leaf) => leaf,
_ => panic!("Node is not leaf")
}
}
// return node's branch reference, for testing purposes
fn as_branch<B: BinarySplit>(n: &Node<B>) -> &B {
match n {
&Node::Branch(ref branch) => branch,
_ => panic!("Node is not branch")
}
}
#[test]
fn test_kdtree_construct() {
let m = Matrix::new(5, 2, vec![1., 2.,
8., 0.,
6., 10.,
3., 6.,
0., 3.]);
let mut tree = KDTree::new(3);
tree.build(m);
// split to [0, 1, 4] and [2, 3] with column #1
let root = tree.root.unwrap();
let b = as_branch(&root);
assert_eq!(b.dim, 1);
assert_eq!(b.split, 5.);
assert_eq!(b.min, Vector::new(vec![0., 0.]));
assert_eq!(b.max, Vector::new(vec![8., 10.]));
// split to [0, 4] and [1] with column #0
let bl = as_branch(b.left());
let br = as_leaf(b.right());
assert_eq!(bl.dim, 0);
assert_eq!(bl.split, 4.);
assert_eq!(bl.min, Vector::new(vec![0., 0.]));
assert_eq!(bl.max, Vector::new(vec![8., 5.]));
assert_eq!(br.children, vec![2, 3]);
let bll = as_leaf(bl.left());
let blr = as_leaf(bl.right());
assert_eq!(bll.children, vec![0, 4]);
assert_eq!(blr.children, vec![1]);
}
#[test]
fn test_kdtree_search() {
let m = Matrix::new(5, 2, vec![1., 2.,
8., 0.,
6., 10.,
3., 6.,
0., 3.]);
let mut tree = KDTree::new(3);
tree.build(m);
// search first leaf
let (kn, _) = tree.search_leaf(&vec![3., 4.9], 1).unwrap();
assert_eq!(kn.pairs, vec![(0, (2.0f64 * 2.0f64 + 2.9f64 * 2.9f64).sqrt())]);
// search tree
let (ind, dist) = tree.search(&vec![3., 4.9], 1).unwrap();
assert_eq!(ind, vec![3]);
assert_eq!(dist, vec![1.0999999999999996]);
let (ind, dist) = tree.search(&vec![3., 4.9], 3).unwrap();
assert_eq!(ind, vec![3, 0, 4]);
assert_eq!(dist, vec![1.0999999999999996, 3.5227829907617076, 3.551056180912941]);
// search first leaf
let (kn, _) = tree.search_leaf(&vec![3., 4.9], 2).unwrap();
assert_eq!(kn.pairs, vec![(0, (2.0f64 * 2.0f64 + 2.9f64 * 2.9f64).sqrt()),
(4, (3.0f64 * 3.0f64 + (4.9f64 - 3.0f64) * (4.9f64 - 3.0f64)).sqrt())]);
// search tree
let (ind, dist) = tree.search(&vec![3., 4.9], 2).unwrap();
assert_eq!(ind, vec![3, 0]);
assert_eq!(dist, vec![1.0999999999999996, 3.5227829907617076]);
}
#[cfg(feature = "datasets")]
#[test]
fn test_kdtree_search_iris_2cols() {
use super::super::super::super::datasets::iris;
let dataset = iris::load();
let data = dataset.data().select_cols(&[0, 1]);
let mut tree = KDTree::new(10);
tree.build(data);
// search tree
let (ind, dist) = tree.search(&vec![5.8, 3.6], 4).unwrap();
assert_eq!(ind, vec![18, 85, 36, 14]);
assert_eq!(dist, vec![0.22360679774997858, 0.2828427124746193, 0.31622776601683783, 0.3999999999999999]);
let (ind, dist) = tree.search(&vec![7.0, 2.6], 4).unwrap();
assert_eq!(ind, vec![76, 108, 102, 107]);
assert_eq!(dist, vec![0.28284271247461895, 0.31622776601683783, 0.41231056256176585, 0.4242640687119283]);
}
#[cfg(feature = "datasets")]
#[test]
fn test_kdtree_search_iris() {
use super::super::super::super::datasets::iris;
let dataset = iris::load();
let data = dataset.data();
let mut tree = KDTree::new(10);
tree.build(data.clone());
// search tree
let (ind, dist) = tree.search(&vec![5.8, 3.1, 3.8, 1.2], 8).unwrap();
assert_eq!(ind, vec![64, 88, 82, 95, 99, 96, 71, 61]);
assert_eq!(dist, vec![0.360555127546399, 0.3872983346207417, 0.41231056256176596,
0.4242640687119288, 0.4472135954999579, 0.4690415759823433,
0.4795831523312721, 0.5196152422706636]);
let (ind, dist) = tree.search(&vec![6.5, 3.5, 3.2, 1.3], 10).unwrap();
assert_eq!(ind, vec![71, 64, 74, 82, 79, 61, 65, 97, 75, 51]);
assert_eq!(dist, vec![1.1357816691600549, 1.1532562594670799, 1.2569805089976533,
1.2767145334803702, 1.2767145334803702, 1.284523257866513,
1.2845232578665131, 1.2884098726725122, 1.3076696830622023,
1.352774925846868]);
}
#[test]
fn test_kdtree_dim_selection() {
let m = Matrix::new(5, 2, vec![1., 2.,
3., 0.,
2., 10.,
3., 6.,
1., 3.]);
let mut tree = KDTree::new(3);
tree.build(m);
// split to [0, 1, 4] and [2, 3] with column #1
let root = tree.root.unwrap();
let b = as_branch(&root);
assert_eq!(b.dim, 1);
assert_eq!(b.split, 5.);
assert_eq!(b.min, Vector::new(vec![1., 0.]));
assert_eq!(b.max, Vector::new(vec![3., 10.]));
// split to [0, 1] and [4] with column #1
let bl = as_branch(b.left());
assert_eq!(bl.dim, 1);
assert_eq!(bl.split, 2.5);
assert_eq!(bl.min, Vector::new(vec![1., 0.]));
assert_eq!(bl.max, Vector::new(vec![3., 5.]));
let br = as_leaf(b.right());
assert_eq!(br.children, vec![2, 3]);
let bll = as_leaf(bl.left());
let blr = as_leaf(bl.right());
assert_eq!(bll.children, vec![0, 1]);
assert_eq!(blr.children, vec![4]);
}
#[test]
fn test_kdtree_dim_selection_biased() {
let m = Matrix::new(5, 2, vec![1., 0.,
3., 0.,
2., 20.,
3., 0.,
1., 0.]);
let mut tree = KDTree::new(3);
tree.build(m);
// split to [0, 1, 3, 4] and [2] with column #1
let root = tree.root.unwrap();
let b = as_branch(&root);
assert_eq!(b.dim, 1);
assert_eq!(b.split, 10.);
assert_eq!(b.min, Vector::new(vec![1., 0.]));
assert_eq!(b.max, Vector::new(vec![3., 20.]));
// split to [0, 4] and [1, 3] with column #0
let bl = as_branch(b.left());
assert_eq!(bl.dim, 0);
assert_eq!(bl.split, 2.);
assert_eq!(bl.min, Vector::new(vec![1., 0.]));
assert_eq!(bl.max, Vector::new(vec![3., 10.]));
let br = as_leaf(b.right());
assert_eq!(br.children, vec![2]);
let bll = as_leaf(bl.left());
let blr = as_leaf(bl.right());
assert_eq!(bll.children, vec![0, 4]);
assert_eq!(blr.children, vec![1, 3]);
}
#[test]
fn test_kdtree_untrained() {
let tree = KDTree::default();
let e = tree.search_leaf(&vec![3., 4.9], 1);
assert!(e.is_err());
let e = tree.search(&vec![3., 4.9], 1);
assert!(e.is_err());
}
#[test]
fn test_balltree_construct() {
let m = Matrix::new(5, 2, vec![1., 2.,
8., 0.,
6., 10.,
3., 6.,
0., 3.]);
let mut tree = BallTree::new(3);
tree.build(m);
// split to [0, 1, 4] and [2, 3] with column #1
let root = tree.root.unwrap();
let b = as_branch(&root);
assert_eq!(b.dim, 1);
assert_eq!(b.split, 5.);
assert_eq!(b.center, Vector::new(vec![18. / 5., 21. / 5.]));
// distance between the center and [2]
let exp_d: f64 = (6. - 3.6) * (6. - 3.6) + (10. - 4.2) * (10. - 4.2);
assert_eq!(b.radius, exp_d.sqrt());
// split to [0, 4] and [1] with column #0
let bl = as_branch(b.left());
let br = as_leaf(b.right());
assert_eq!(bl.dim, 0);
assert_eq!(bl.split, 4.);
assert_eq!(bl.center, Vector::new(vec![3., 5. / 3.]));
// distance between the center and [1]
let exp_d: f64 = (3. - 8.) * (3. - 8.) + 5. / 3. * 5. / 3.;
assert_eq!(bl.radius, exp_d.sqrt());
assert_eq!(br.children, vec![2, 3]);
let bll = as_leaf(bl.left());
let blr = as_leaf(bl.right());
assert_eq!(bll.children, vec![0, 4]);
assert_eq!(blr.children, vec![1]);
}
#[test]
fn test_balltree_search() {
let m = Matrix::new(5, 2, vec![1., 2.,
8., 0.,
6., 10.,
3., 6.,
0., 3.]);
let mut tree = BallTree::new(3);
tree.build(m);
// search first leaf
let (kn, _) = tree.search_leaf(&vec![3., 4.9], 1).unwrap();
assert_eq!(kn.pairs, vec![(0, (2.0f64 * 2.0f64 + 2.9f64 * 2.9f64).sqrt())]);
// search tree
let (ind, dist) = tree.search(&vec![3., 4.9], 1).unwrap();
assert_eq!(ind, vec![3]);
assert_eq!(dist, vec![1.0999999999999996]);
let (ind, dist) = tree.search(&vec![3., 4.9], 3).unwrap();
assert_eq!(ind, vec![3, 0, 4]);
assert_eq!(dist, vec![1.0999999999999996, 3.5227829907617076, 3.551056180912941]);
// search first leaf
let (kn, _) = tree.search_leaf(&vec![3., 4.9], 2).unwrap();
assert_eq!(kn.pairs, vec![(0, (2.0f64 * 2.0f64 + 2.9f64 * 2.9f64).sqrt()),
(4, (3.0f64 * 3.0f64 + (4.9f64 - 3.0f64) * (4.9f64 - 3.0f64)).sqrt())]);
// search tree
let (ind, dist) = tree.search(&vec![3., 4.9], 2).unwrap();
assert_eq!(ind, vec![3, 0]);
assert_eq!(dist, vec![1.0999999999999996, 3.5227829907617076]);
}
#[cfg(feature = "datasets")]
#[test]
fn test_balltree_search_iris() {
use super::super::super::super::datasets::iris;
let dataset = iris::load();
let data = dataset.data();
let mut tree = BallTree::new(10);
tree.build(data.clone());
// search tree
let (ind, dist) = tree.search(&vec![5.8, 3.1, 3.8, 1.2], 8).unwrap();
assert_eq!(ind, vec![64, 88, 82, 95, 99, 96, 71, 61]);
assert_eq!(dist, vec![0.360555127546399, 0.3872983346207417, 0.41231056256176596,
0.4242640687119288, 0.4472135954999579, 0.4690415759823433,
0.4795831523312721, 0.5196152422706636]);
let (ind, dist) = tree.search(&vec![6.5, 3.5, 3.2, 1.3], 10).unwrap();
assert_eq!(ind, vec![71, 64, 74, 82, 79, 61, 65, 97, 75, 51]);
assert_eq!(dist, vec![1.1357816691600549, 1.1532562594670799, 1.2569805089976533,
1.2767145334803702, 1.2767145334803702, 1.284523257866513,
1.2845232578665131, 1.2884098726725122, 1.3076696830622023,
1.352774925846868]);
}
#[test]
fn test_balltree_dim_selection_biased() {
let m = Matrix::new(5, 2, vec![1., 0.,
3., 0.,
2., 20.,
3., 0.,
1., 0.]);
let mut tree = BallTree::new(3);
tree.build(m);
// split to [0, 1, 3, 4] and [2] with column #1
let root = tree.root.unwrap();
let b = as_branch(&root);
assert_eq!(b.dim, 1);
assert_eq!(b.split, 10.);
assert_eq!(b.center, Vector::new(vec![10. / 5., 20. / 5.]));
// distance between the center and [2]
let exp_d: f64 = (2. - 2.) * (2. - 2.) + (4. - 20.) * (4. - 20.);
assert_eq!(b.radius, exp_d.sqrt());
// split to [0, 4] and [1, 3] with column #0
let bl = as_branch(b.left());
assert_eq!(bl.dim, 0);
assert_eq!(bl.split, 2.);
assert_eq!(bl.center, Vector::new(vec![8. / 4., 0.]));
// distance between the center and [0]
let exp_d: f64 = (2. - 1.) * (2. - 1.);
assert_eq!(bl.radius, exp_d.sqrt());
let br = as_leaf(b.right());
assert_eq!(br.children, vec![2]);
let bll = as_leaf(bl.left());
let blr = as_leaf(bl.right());
assert_eq!(bll.children, vec![0, 4]);
assert_eq!(blr.children, vec![1, 3]);
}
#[test]
fn test_balltree_untrained() {
let tree = BallTree::default();
let e = tree.search_leaf(&vec![3., 4.9], 1);
assert!(e.is_err());
let e = tree.search(&vec![3., 4.9], 1);
assert!(e.is_err());
}
#[test]
fn test_min_max() {
let data = Matrix::new(3, 2, vec![1., 2.,
2., 4.,
3., 1.]);
assert_eq!(min(&data), Vector::new(vec![1., 1.]));
assert_eq!(max(&data), Vector::new(vec![3., 4.]));
}
}