use iterator for partition kernel implementation (#438)
diff --git a/arrow/benches/partition_kernels.rs b/arrow/benches/partition_kernels.rs
index 6a9ce70..ae55fbd 100644
--- a/arrow/benches/partition_kernels.rs
+++ b/arrow/benches/partition_kernels.rs
@@ -48,7 +48,11 @@
})
.collect::<Vec<_>>();
- criterion::black_box(lexicographical_partition_ranges(&columns).unwrap());
+ criterion::black_box(
+ lexicographical_partition_ranges(&columns)
+ .unwrap()
+ .collect::<Vec<_>>(),
+ );
}
fn create_sorted_low_cardinality_data(length: usize) -> Vec<ArrayRef> {
diff --git a/arrow/src/compute/kernels/partition.rs b/arrow/src/compute/kernels/partition.rs
index e91f80b..ad35e92 100644
--- a/arrow/src/compute/kernels/partition.rs
+++ b/arrow/src/compute/kernels/partition.rs
@@ -21,6 +21,7 @@
use crate::compute::SortColumn;
use crate::error::{ArrowError, Result};
use std::cmp::Ordering;
+use std::iter::Iterator;
use std::ops::Range;
/// Given a list of already sorted columns, find partition ranges that would partition
@@ -34,65 +35,71 @@
/// range.
pub fn lexicographical_partition_ranges(
columns: &[SortColumn],
-) -> Result<Vec<Range<usize>>> {
- let partition_points = lexicographical_partition_points(columns)?;
- Ok(partition_points
- .iter()
- .zip(partition_points[1..].iter())
- .map(|(&start, &end)| Range { start, end })
- .collect())
+) -> Result<impl Iterator<Item = Range<usize>> + '_> {
+ LexicographicalPartitionIterator::try_new(columns)
}
-/// Given a list of already sorted columns, find partition ranges that would partition
-/// lexicographically equal values across columns.
-///
-/// Here LexicographicalComparator is used in conjunction with binary
-/// search so the columns *MUST* be pre-sorted already.
-///
-/// The returned vec would be of size k+1 where k is cardinality of the sorted values; the first and
-/// last value would be 0 and n.
-fn lexicographical_partition_points(columns: &[SortColumn]) -> Result<Vec<usize>> {
- if columns.is_empty() {
- return Err(ArrowError::InvalidArgumentError(
- "Sort requires at least one column".to_string(),
- ));
+struct LexicographicalPartitionIterator<'a> {
+ comparator: LexicographicalComparator<'a>,
+ num_rows: usize,
+ previous_partition_point: usize,
+ partition_point: usize,
+ value_indices: Vec<usize>,
+}
+
+impl<'a> LexicographicalPartitionIterator<'a> {
+ fn try_new(columns: &'a [SortColumn]) -> Result<LexicographicalPartitionIterator> {
+ if columns.is_empty() {
+ return Err(ArrowError::InvalidArgumentError(
+ "Sort requires at least one column".to_string(),
+ ));
+ }
+ let num_rows = columns[0].values.len();
+ if columns.iter().any(|item| item.values.len() != num_rows) {
+ return Err(ArrowError::ComputeError(
+ "Lexical sort columns have different row counts".to_string(),
+ ));
+ };
+
+ let comparator = LexicographicalComparator::try_new(columns)?;
+ let value_indices = (0..num_rows).collect::<Vec<usize>>();
+ Ok(LexicographicalPartitionIterator {
+ comparator,
+ num_rows,
+ previous_partition_point: 0,
+ partition_point: 0,
+ value_indices,
+ })
}
- let row_count = columns[0].values.len();
- if columns.iter().any(|item| item.values.len() != row_count) {
- return Err(ArrowError::ComputeError(
- "Lexical sort columns have different row counts".to_string(),
- ));
- };
+}
- let mut result = vec![];
- if row_count == 0 {
- return Ok(result);
+impl<'a> Iterator for LexicographicalPartitionIterator<'a> {
+ type Item = Range<usize>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.partition_point < self.num_rows {
+ // invariant:
+ // on entry partition_point == previous_partition_point, and every row before it has already been emitted in
+ // an earlier range, so in order to save time we only binary search value_indices[partition_point..] for the
+ // first row that compares greater than the row at partition_point; because that offset is relative to the
+ // sub-slice, it is _added_ to partition_point.
+ //
+ // be careful that idx is of type &usize which points to the actual value within value_indices, which itself
+ // contains usize (0..num_rows), providing access to the comparator as pointers into the
+ // original columnar data.
+ self.partition_point += self.value_indices[self.partition_point..]
+ .partition_point(|idx| {
+ self.comparator.compare(idx, &self.partition_point)
+ != Ordering::Greater
+ });
+ let start = self.previous_partition_point;
+ let end = self.partition_point;
+ self.previous_partition_point = self.partition_point;
+ Some(Range { start, end })
+ } else {
+ None
+ }
}
-
- let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
- let value_indices = (0..row_count).collect::<Vec<usize>>();
-
- let mut previous_partition_point = 0;
- result.push(previous_partition_point);
- while previous_partition_point < row_count {
- // invariant:
- // value_indices[0..previous_partition_point] all are values <= value_indices[previous_partition_point]
- // so in order to save time we can do binary search on the value_indices[previous_partition_point..]
- // and find when any value is greater than value_indices[previous_partition_point]; because we are using
- // new indices, the new offset is _added_ to the previous_partition_point.
- //
- // be careful that idx is of type &usize which points to the actual value within value_indices, which itself
- // contains usize (0..row_count), providing access to lexicographical_comparator as pointers into the
- // original columnar data.
- previous_partition_point += value_indices[previous_partition_point..]
- .partition_point(|idx| {
- lexicographical_comparator.compare(idx, &previous_partition_point)
- != Ordering::Greater
- });
- result.push(previous_partition_point);
- }
-
- Ok(result)
}
#[cfg(test)]
@@ -104,16 +111,16 @@
use std::sync::Arc;
#[test]
- fn test_lexicographical_partition_points_empty() {
+ fn test_lexicographical_partition_ranges_empty() {
let input = vec![];
assert!(
- lexicographical_partition_points(&input).is_err(),
- "lexicographical_partition_points should reject columns with empty rows"
+ lexicographical_partition_ranges(&input).is_err(),
+ "lexicographical_partition_ranges should reject columns with empty rows"
);
}
#[test]
- fn test_lexicographical_partition_points_unaligned_rows() {
+ fn test_lexicographical_partition_ranges_unaligned_rows() {
let input = vec![
SortColumn {
values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef,
@@ -125,8 +132,8 @@
},
];
assert!(
- lexicographical_partition_points(&input).is_err(),
- "lexicographical_partition_points should reject columns with different row counts"
+ lexicographical_partition_ranges(&input).is_err(),
+ "lexicographical_partition_ranges should reject columns with different row counts"
);
}
@@ -141,14 +148,10 @@
}),
}];
{
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1, 8, 9], results);
- }
- {
let results = lexicographical_partition_ranges(&input)?;
assert_eq!(
vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)],
- results
+ results.collect::<Vec<_>>()
);
}
Ok(())
@@ -163,13 +166,10 @@
nulls_first: true,
}),
}];
- {
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1000], results);
- }
+
{
let results = lexicographical_partition_ranges(&input)?;
- assert_eq!(vec![(0_usize..1000_usize)], results);
+ assert_eq!(vec![(0_usize..1000_usize)], results.collect::<Vec<_>>());
}
Ok(())
}
@@ -193,12 +193,8 @@
},
];
{
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1000], results);
- }
- {
let results = lexicographical_partition_ranges(&input)?;
- assert_eq!(vec![(0_usize..1000_usize)], results);
+ assert_eq!(vec![(0_usize..1000_usize)], results.collect::<Vec<_>>());
}
Ok(())
}
@@ -223,12 +219,11 @@
},
];
{
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1, 2], results);
- }
- {
let results = lexicographical_partition_ranges(&input)?;
- assert_eq!(vec![(0_usize..1_usize), (1_usize..2_usize)], results);
+ assert_eq!(
+ vec![(0_usize..1_usize), (1_usize..2_usize)],
+ results.collect::<Vec<_>>()
+ );
}
Ok(())
}
@@ -257,14 +252,10 @@
},
];
{
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1, 2, 3], results);
- }
- {
let results = lexicographical_partition_ranges(&input)?;
assert_eq!(
vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),],
- results
+ results.collect::<Vec<_>>()
);
}
Ok(())
@@ -299,14 +290,10 @@
},
];
{
- let results = lexicographical_partition_points(&input)?;
- assert_eq!(vec![0, 1, 3, 4], results);
- }
- {
let results = lexicographical_partition_ranges(&input)?;
assert_eq!(
vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),],
- results
+ results.collect::<Vec<_>>()
);
}
Ok(())