Fix subtraction underflow when sorting string arrays with many nulls (#285)
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 9287425..7cd463d 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -410,24 +410,27 @@
len = limit.min(len);
}
if !descending {
- sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp(a.1, b.1)
+ });
} else {
- sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp(a.1, b.1).reverse()
+ });
// reverse to keep a stable ordering
nulls.reverse();
}
// collect results directly into a buffer instead of a vec to avoid another aligned allocation
- let mut result = MutableBuffer::new(values.len() * std::mem::size_of::<u32>());
+ let result_capacity = len * std::mem::size_of::<u32>();
+ let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
- result.resize(values.len() * std::mem::size_of::<u32>(), 0);
+ result.resize(result_capacity, 0);
let result_slice: &mut [u32] = result.typed_data_mut();
- debug_assert_eq!(result_slice.len(), nulls_len + valids_len);
-
if options.nulls_first {
let size = nulls_len.min(len);
- result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls);
+ result_slice[0..size].copy_from_slice(&nulls[0..size]);
if nulls_len < len {
insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
}
@@ -626,9 +629,13 @@
len = limit.min(len);
}
if !descending {
- sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp(a.1, b.1)
+ });
} else {
- sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp(a.1, b.1).reverse()
+ });
// reverse to keep a stable ordering
nulls.reverse();
}
@@ -689,11 +696,11 @@
len = limit.min(len);
}
if !descending {
- sort_by(&mut valids, len - nulls_len, |a, b| {
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
cmp_array(a.1.as_ref(), b.1.as_ref())
});
} else {
- sort_by(&mut valids, len - nulls_len, |a, b| {
+ sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
});
// reverse to keep a stable ordering
@@ -1285,6 +1292,48 @@
None,
vec![5, 0, 2, 1, 4, 3],
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_to_indices_primitive_arrays::<Float64Type>(
+ vec![Some(2.0), None, None, Some(1.0)],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![3, 0, 1],
+ );
+
+ test_sort_to_indices_primitive_arrays::<Float64Type>(
+ vec![Some(2.0), None, None, Some(1.0)],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![1, 2, 3],
+ );
+
+ // more nulls than limit
+ test_sort_to_indices_primitive_arrays::<Float64Type>(
+ vec![Some(1.0), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![1, 2],
+ );
+
+ test_sort_to_indices_primitive_arrays::<Float64Type>(
+ vec![Some(1.0), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![0, 1],
+ );
}
#[test]
@@ -1329,6 +1378,48 @@
Some(3),
vec![5, 0, 2],
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_to_indices_boolean_arrays(
+ vec![Some(true), None, None, Some(false)],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![3, 0, 1],
+ );
+
+ test_sort_to_indices_boolean_arrays(
+ vec![Some(true), None, None, Some(false)],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![1, 2, 3],
+ );
+
+ // more nulls than limit
+ test_sort_to_indices_boolean_arrays(
+ vec![Some(true), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![1, 2],
+ );
+
+ test_sort_to_indices_boolean_arrays(
+ vec![Some(true), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![0, 1],
+ );
}
#[test]
@@ -1686,6 +1777,48 @@
Some(3),
vec![3, 0, 2],
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_to_indices_string_arrays(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![3, 0, 1],
+ );
+
+ test_sort_to_indices_string_arrays(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![1, 2, 3],
+ );
+
+ // more nulls than limit
+ test_sort_to_indices_string_arrays(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![1, 2],
+ );
+
+ test_sort_to_indices_string_arrays(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![0, 1],
+ );
}
#[test]
@@ -1799,6 +1932,48 @@
Some(3),
vec![None, None, Some("sad")],
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_string_arrays(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![Some("abc"), Some("def"), None],
+ );
+
+ test_sort_string_arrays(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some("abc")],
+ );
+
+ // more nulls than limit
+ test_sort_string_arrays(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![None, None],
+ );
+
+ test_sort_string_arrays(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![Some("def"), None],
+ );
}
#[test]
@@ -1912,6 +2087,48 @@
Some(3),
vec![None, None, Some("sad")],
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_string_dict_arrays::<Int16Type>(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![Some("abc"), Some("def"), None],
+ );
+
+ test_sort_string_dict_arrays::<Int16Type>(
+ vec![Some("def"), None, None, Some("abc")],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some("abc")],
+ );
+
+ // more nulls than limit
+ test_sort_string_dict_arrays::<Int16Type>(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![None, None],
+ );
+
+ test_sort_string_dict_arrays::<Int16Type>(
+ vec![Some("def"), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![Some("def"), None],
+ );
}
#[test]
@@ -1999,6 +2216,52 @@
vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])],
None,
);
+
+ // fewer valid values than the limit, with extra nulls
+ test_sort_list_arrays::<Int32Type>(
+ vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![Some(vec![Some(1)]), Some(vec![Some(2)]), None],
+ None,
+ );
+
+ test_sort_list_arrays::<Int32Type>(
+ vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some(vec![Some(2)])],
+ None,
+ );
+
+ // more nulls than limit
+ test_sort_list_arrays::<Int32Type>(
+ vec![Some(vec![Some(1)]), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(2),
+ vec![None, None],
+ None,
+ );
+
+ test_sort_list_arrays::<Int32Type>(
+ vec![Some(vec![Some(1)]), None, None, None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ Some(2),
+ vec![Some(vec![Some(1)]), None],
+ None,
+ );
}
#[test]