refactor lexico sort (#423)
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 25611cb..feef00c 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -17,14 +17,12 @@
//! Defines sort kernel for `ArrayRef`
-use std::cmp::Ordering;
-
use crate::array::*;
use crate::buffer::MutableBuffer;
use crate::compute::take;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
-
+use std::cmp::Ordering;
use TimeUnit::*;
/// Sort the `ArrayRef` using `SortOptions`.
@@ -835,26 +833,55 @@
));
};
- // map to data and DynComparator
- let flat_columns = columns
- .iter()
- .map(
- |column| -> Result<(&ArrayData, DynComparator, SortOptions)> {
- // flatten and convert build comparators
- // use ArrayData for is_valid checks later to avoid dynamic call
- let values = column.values.as_ref();
- let data = values.data_ref();
- Ok((
- data,
- build_compare(values, values)?,
- column.options.unwrap_or_default(),
- ))
- },
- )
- .collect::<Result<Vec<(&ArrayData, DynComparator, SortOptions)>>>()?;
+ let mut value_indices = (0..row_count).collect::<Vec<usize>>();
+ let mut len = value_indices.len();
- let lex_comparator = |a_idx: &usize, b_idx: &usize| -> Ordering {
- for (data, comparator, sort_option) in flat_columns.iter() {
+ if let Some(limit) = limit {
+ len = limit.min(len);
+ }
+
+ let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
+ sort_by(&mut value_indices, len, |a, b| {
+ lexicographical_comparator.compare(a, b)
+ });
+
+ Ok(UInt32Array::from(
+ (&value_indices)[0..len]
+ .iter()
+ .map(|i| *i as u32)
+ .collect::<Vec<u32>>(),
+ ))
+}
+
+/// It's unstable_sort, may not preserve the order of equal elements
+pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
+where
+ F: FnMut(&T, &T) -> Ordering,
+{
+ let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
+ before.sort_unstable_by(is_less);
+}
+
+type LexicographicalCompareItem<'a> = (
+ &'a ArrayData, // data
+ Box<dyn Fn(usize, usize) -> Ordering + 'a>, // comparator
+ SortOptions, // sort_option
+);
+
+/// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data
+/// at given two indices. The lifetime is the same at the data wrapped.
+pub(super) struct LexicographicalComparator<'a> {
+ compare_items: Vec<LexicographicalCompareItem<'a>>,
+}
+
+impl LexicographicalComparator<'_> {
+ /// lexicographically compare values at the wrapped columns with given indices.
+ pub(super) fn compare<'a, 'b>(
+ &'a self,
+ a_idx: &'b usize,
+ b_idx: &'b usize,
+ ) -> Ordering {
+ for (data, comparator, sort_option) in &self.compare_items {
match (data.is_valid(*a_idx), data.is_valid(*b_idx)) {
(true, true) => {
match (comparator)(*a_idx, *b_idx) {
@@ -889,31 +916,29 @@
}
Ordering::Equal
- };
-
- let mut value_indices = (0..row_count).collect::<Vec<usize>>();
- let mut len = value_indices.len();
-
- if let Some(limit) = limit {
- len = limit.min(len);
}
- sort_by(&mut value_indices, len, lex_comparator);
- Ok(UInt32Array::from(
- (&value_indices)[0..len]
+ /// Create a new lex comparator that will wrap the given sort columns and give comparison
+ /// results with two indices.
+ pub(super) fn try_new(
+ columns: &[SortColumn],
+ ) -> Result<LexicographicalComparator<'_>> {
+ let compare_items = columns
.iter()
- .map(|i| *i as u32)
- .collect::<Vec<u32>>(),
- ))
-}
-
-/// It's unstable_sort, may not preserve the order of equal elements
-pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
-where
- F: FnMut(&T, &T) -> Ordering,
-{
- let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
- before.sort_unstable_by(is_less);
+ .map(|column| {
+ // flatten and convert build comparators
+ // use ArrayData for is_valid checks later to avoid dynamic call
+ let values = column.values.as_ref();
+ let data = values.data_ref();
+ Ok((
+ data,
+ build_compare(values, values)?,
+ column.options.unwrap_or_default(),
+ ))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(LexicographicalComparator { compare_items })
+ }
}
#[cfg(test)]