Use an iterator for the partition kernel implementation (#438)

diff --git a/arrow/benches/partition_kernels.rs b/arrow/benches/partition_kernels.rs
index 6a9ce70..ae55fbd 100644
--- a/arrow/benches/partition_kernels.rs
+++ b/arrow/benches/partition_kernels.rs
@@ -48,7 +48,11 @@
         })
         .collect::<Vec<_>>();
 
-    criterion::black_box(lexicographical_partition_ranges(&columns).unwrap());
+    criterion::black_box(
+        lexicographical_partition_ranges(&columns)
+            .unwrap()
+            .collect::<Vec<_>>(),
+    );
 }
 
 fn create_sorted_low_cardinality_data(length: usize) -> Vec<ArrayRef> {
diff --git a/arrow/src/compute/kernels/partition.rs b/arrow/src/compute/kernels/partition.rs
index e91f80b..ad35e92 100644
--- a/arrow/src/compute/kernels/partition.rs
+++ b/arrow/src/compute/kernels/partition.rs
@@ -21,6 +21,7 @@
 use crate::compute::SortColumn;
 use crate::error::{ArrowError, Result};
 use std::cmp::Ordering;
+use std::iter::Iterator;
 use std::ops::Range;
 
 /// Given a list of already sorted columns, find partition ranges that would partition
@@ -34,65 +35,71 @@
 /// range.
 pub fn lexicographical_partition_ranges(
     columns: &[SortColumn],
-) -> Result<Vec<Range<usize>>> {
-    let partition_points = lexicographical_partition_points(columns)?;
-    Ok(partition_points
-        .iter()
-        .zip(partition_points[1..].iter())
-        .map(|(&start, &end)| Range { start, end })
-        .collect())
+) -> Result<impl Iterator<Item = Range<usize>> + '_> {
+    LexicographicalPartitionIterator::try_new(columns)
 }
 
-/// Given a list of already sorted columns, find partition ranges that would partition
-/// lexicographically equal values across columns.
-///
-/// Here LexicographicalComparator is used in conjunction with binary
-/// search so the columns *MUST* be pre-sorted already.
-///
-/// The returned vec would be of size k+1 where k is cardinality of the sorted values; the first and
-/// last value would be 0 and n.
-fn lexicographical_partition_points(columns: &[SortColumn]) -> Result<Vec<usize>> {
-    if columns.is_empty() {
-        return Err(ArrowError::InvalidArgumentError(
-            "Sort requires at least one column".to_string(),
-        ));
+struct LexicographicalPartitionIterator<'a> {
+    comparator: LexicographicalComparator<'a>,
+    num_rows: usize,
+    previous_partition_point: usize,
+    partition_point: usize,
+    value_indices: Vec<usize>,
+}
+
+impl<'a> LexicographicalPartitionIterator<'a> {
+    fn try_new(columns: &'a [SortColumn]) -> Result<LexicographicalPartitionIterator> {
+        if columns.is_empty() {
+            return Err(ArrowError::InvalidArgumentError(
+                "Sort requires at least one column".to_string(),
+            ));
+        }
+        let num_rows = columns[0].values.len();
+        if columns.iter().any(|item| item.values.len() != num_rows) {
+            return Err(ArrowError::ComputeError(
+                "Lexical sort columns have different row counts".to_string(),
+            ));
+        };
+
+        let comparator = LexicographicalComparator::try_new(columns)?;
+        let value_indices = (0..num_rows).collect::<Vec<usize>>();
+        Ok(LexicographicalPartitionIterator {
+            comparator,
+            num_rows,
+            previous_partition_point: 0,
+            partition_point: 0,
+            value_indices,
+        })
     }
-    let row_count = columns[0].values.len();
-    if columns.iter().any(|item| item.values.len() != row_count) {
-        return Err(ArrowError::ComputeError(
-            "Lexical sort columns have different row counts".to_string(),
-        ));
-    };
+}
 
-    let mut result = vec![];
-    if row_count == 0 {
-        return Ok(result);
+impl<'a> Iterator for LexicographicalPartitionIterator<'a> {
+    type Item = Range<usize>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.partition_point < self.num_rows {
+            // invariant:
+            // the rows at value_indices[0..partition_point] all compare <= the row at value_indices[partition_point],
+            // so in order to save time we only binary search value_indices[partition_point..]
+            // for the first row that compares greater than the row at partition_point; because the search returns
+            // an offset relative to that sub-slice, the offset is _added_ to partition_point.
+            //
+            // be careful that idx is of type &usize which points to the actual value within value_indices, which itself
+            // contains usize (0..num_rows), providing access to the comparator as row indices into the
+            // original columnar data.
+            self.partition_point += self.value_indices[self.partition_point..]
+                .partition_point(|idx| {
+                    self.comparator.compare(idx, &self.partition_point)
+                        != Ordering::Greater
+                });
+            let start = self.previous_partition_point;
+            let end = self.partition_point;
+            self.previous_partition_point = self.partition_point;
+            Some(Range { start, end })
+        } else {
+            None
+        }
     }
-
-    let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
-    let value_indices = (0..row_count).collect::<Vec<usize>>();
-
-    let mut previous_partition_point = 0;
-    result.push(previous_partition_point);
-    while previous_partition_point < row_count {
-        // invariant:
-        // value_indices[0..previous_partition_point] all are values <= value_indices[previous_partition_point]
-        // so in order to save time we can do binary search on the value_indices[previous_partition_point..]
-        // and find when any value is greater than value_indices[previous_partition_point]; because we are using
-        // new indices, the new offset is _added_ to the previous_partition_point.
-        //
-        // be careful that idx is of type &usize which points to the actual value within value_indices, which itself
-        // contains usize (0..row_count), providing access to lexicographical_comparator as pointers into the
-        // original columnar data.
-        previous_partition_point += value_indices[previous_partition_point..]
-            .partition_point(|idx| {
-                lexicographical_comparator.compare(idx, &previous_partition_point)
-                    != Ordering::Greater
-            });
-        result.push(previous_partition_point);
-    }
-
-    Ok(result)
 }
 
 #[cfg(test)]
@@ -104,16 +111,16 @@
     use std::sync::Arc;
 
     #[test]
-    fn test_lexicographical_partition_points_empty() {
+    fn test_lexicographical_partition_ranges_empty() {
         let input = vec![];
         assert!(
-            lexicographical_partition_points(&input).is_err(),
-            "lexicographical_partition_points should reject columns with empty rows"
+            lexicographical_partition_ranges(&input).is_err(),
+            "lexicographical_partition_ranges should reject columns with empty rows"
         );
     }
 
     #[test]
-    fn test_lexicographical_partition_points_unaligned_rows() {
+    fn test_lexicographical_partition_ranges_unaligned_rows() {
         let input = vec![
             SortColumn {
                 values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef,
@@ -125,8 +132,8 @@
             },
         ];
         assert!(
-            lexicographical_partition_points(&input).is_err(),
-            "lexicographical_partition_points should reject columns with different row counts"
+            lexicographical_partition_ranges(&input).is_err(),
+            "lexicographical_partition_ranges should reject columns with different row counts"
         );
     }
 
@@ -141,14 +148,10 @@
             }),
         }];
         {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1, 8, 9], results);
-        }
-        {
             let results = lexicographical_partition_ranges(&input)?;
             assert_eq!(
                 vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)],
-                results
+                results.collect::<Vec<_>>()
             );
         }
         Ok(())
@@ -163,13 +166,10 @@
                 nulls_first: true,
             }),
         }];
-        {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1000], results);
-        }
+
         {
             let results = lexicographical_partition_ranges(&input)?;
-            assert_eq!(vec![(0_usize..1000_usize)], results);
+            assert_eq!(vec![(0_usize..1000_usize)], results.collect::<Vec<_>>());
         }
         Ok(())
     }
@@ -193,12 +193,8 @@
             },
         ];
         {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1000], results);
-        }
-        {
             let results = lexicographical_partition_ranges(&input)?;
-            assert_eq!(vec![(0_usize..1000_usize)], results);
+            assert_eq!(vec![(0_usize..1000_usize)], results.collect::<Vec<_>>());
         }
         Ok(())
     }
@@ -223,12 +219,11 @@
             },
         ];
         {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1, 2], results);
-        }
-        {
             let results = lexicographical_partition_ranges(&input)?;
-            assert_eq!(vec![(0_usize..1_usize), (1_usize..2_usize)], results);
+            assert_eq!(
+                vec![(0_usize..1_usize), (1_usize..2_usize)],
+                results.collect::<Vec<_>>()
+            );
         }
         Ok(())
     }
@@ -257,14 +252,10 @@
             },
         ];
         {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1, 2, 3], results);
-        }
-        {
             let results = lexicographical_partition_ranges(&input)?;
             assert_eq!(
                 vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),],
-                results
+                results.collect::<Vec<_>>()
             );
         }
         Ok(())
@@ -299,14 +290,10 @@
             },
         ];
         {
-            let results = lexicographical_partition_points(&input)?;
-            assert_eq!(vec![0, 1, 3, 4], results);
-        }
-        {
             let results = lexicographical_partition_ranges(&input)?;
             assert_eq!(
                 vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),],
-                results
+                results.collect::<Vec<_>>()
             );
         }
         Ok(())