blob: 04f865a0f59440fea56d64f5830b3cd79acdf7cb [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::{Deref, DerefMut};
use unchecked_index::UncheckedIndex;
pub trait ComparableForLoserTree {
fn lt(&self, other: &Self) -> bool;
}
/// An implementation of the tournament loser tree data structure.
pub struct LoserTree<T> {
losers: UncheckedIndex<Vec<usize>>,
values: UncheckedIndex<Vec<T>>,
}
#[allow(clippy::len_without_is_empty)]
impl<T: ComparableForLoserTree> LoserTree<T> {
pub fn new(values: Vec<T>) -> Self {
let mut tree = unsafe {
// safety:
// this component is performance critical, use unchecked index
// to avoid boundary checking.
Self {
losers: unchecked_index::unchecked_index(vec![]),
values: unchecked_index::unchecked_index(values),
}
};
tree.init_tree();
tree
}
pub fn values(&self) -> &[T] {
&self.values
}
pub fn values_mut(&mut self) -> &mut [T] {
&mut self.values
}
pub fn len(&self) -> usize {
self.values.len()
}
pub fn peek(&self) -> &T {
&self.values[0]
}
pub fn peek_mut(&mut self) -> LoserTreePeekMut<'_, T> {
LoserTreePeekMut {
tree: self,
dirty: false,
}
}
fn init_tree(&mut self) {
self.losers.resize(self.values.len(), usize::MAX);
for i in 0..self.values.len() {
let mut winner = i;
let mut cmp_node = (self.values.len() + i) / 2;
while cmp_node != 0 && self.losers[cmp_node] != usize::MAX {
let challenger = self.losers[cmp_node];
if self.values[challenger].lt(&self.values[winner]) {
self.losers[cmp_node] = winner;
winner = challenger;
} else {
self.losers[cmp_node] = challenger;
}
cmp_node /= 2;
}
self.losers[cmp_node] = winner;
}
}
fn adjust_tree(&mut self) {
let mut winner = self.losers[0];
let mut cmp_node = (self.values.len() + winner) / 2;
while cmp_node != 0 {
let challenger = self.losers[cmp_node];
if self.values[challenger].lt(&self.values[winner]) {
self.losers[cmp_node] = winner;
winner = challenger;
}
cmp_node /= 2;
}
self.losers[0] = winner;
}
}
/// A PeekMut structure to the loser tree, used to get smallest value and auto
/// adjusting after dropped.
pub struct LoserTreePeekMut<'a, T: ComparableForLoserTree> {
tree: &'a mut LoserTree<T>,
dirty: bool,
}
impl<T: ComparableForLoserTree> LoserTreePeekMut<'_, T> {
pub fn adjust(&mut self) {
if self.dirty {
self.tree.adjust_tree();
self.dirty = false;
}
}
}
impl<T: ComparableForLoserTree> Deref for LoserTreePeekMut<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.tree.values[self.tree.losers[0]]
}
}
impl<T: ComparableForLoserTree> DerefMut for LoserTreePeekMut<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.dirty = true;
&mut self.tree.values[self.tree.losers[0]]
}
}
impl<T: ComparableForLoserTree> Drop for LoserTreePeekMut<'_, T> {
fn drop(&mut self) {
if self.dirty {
self.tree.adjust_tree();
}
}
}
#[cfg(test)]
mod test {
use itertools::Itertools;
use rand::Rng;
use crate::algorithm::loser_tree::{ComparableForLoserTree, LoserTree};
#[test]
fn fuzztest() {
for _ in 0..10 {
let num_nodes = rand::rng().random_range(1..=999);
let mut nodes = vec![];
for _ in 0..num_nodes {
let node_len = rand::rng().random_range(1..=999);
let mut node = vec![];
for _ in 0..node_len {
node.push(rand::rng().random_range(1000..=9999));
}
nodes.push(node);
}
// expected
let expected = nodes
.clone()
.into_iter()
.flatten()
.sorted_unstable()
.collect_vec();
// actual
struct Cursor {
row_idx: usize,
values: Vec<u64>,
}
impl ComparableForLoserTree for Cursor {
fn lt(&self, other: &Self) -> bool {
match (
self.values.get(self.row_idx),
other.values.get(other.row_idx),
) {
(Some(v1), Some(v2)) => v1 < v2,
(None, _) => false,
(_, None) => true,
}
}
}
let mut loser_tree = LoserTree::new(
nodes
.into_iter()
.map(|node| Cursor {
row_idx: 0,
values: node.into_iter().sorted_unstable().collect_vec(),
})
.collect_vec(),
);
let mut actual = vec![];
loop {
let mut min = loser_tree.peek_mut();
if let Some(v) = min.values.get(min.row_idx) {
actual.push(*v);
min.row_idx += 1;
} else {
break;
}
}
for cursor in loser_tree.values() {
assert_eq!(cursor.row_idx, cursor.values.len());
}
assert_eq!(actual, expected);
}
}
}