| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| //! Contains functions and function factories to compare arrays. |
| |
| use std::cmp::Ordering; |
| |
| use crate::array::*; |
| use crate::datatypes::TimeUnit; |
| use crate::datatypes::*; |
| use crate::error::{ArrowError, Result}; |
| |
| use num::Float; |
| |
| /// Compare the values at two arbitrary indices in two arrays. |
| pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>; |
| |
| /// compares two floats, placing NaNs at last |
| fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering { |
| match (a.is_nan(), b.is_nan()) { |
| (true, true) => Ordering::Equal, |
| (true, false) => Ordering::Greater, |
| (false, true) => Ordering::Less, |
| _ => a.partial_cmp(b).unwrap(), |
| } |
| } |
| |
| fn compare_primitives<'a, T: ArrowPrimitiveType>( |
| left: &'a Array, |
| right: &'a Array, |
| ) -> DynComparator<'a> |
| where |
| T::Native: Ord, |
| { |
| let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| Box::new(move |i, j| left.value(i).cmp(&right.value(j))) |
| } |
| |
| fn compare_boolean<'a>(left: &'a Array, right: &'a Array) -> DynComparator<'a> { |
| let left = left.as_any().downcast_ref::<BooleanArray>().unwrap(); |
| let right = right.as_any().downcast_ref::<BooleanArray>().unwrap(); |
| Box::new(move |i, j| left.value(i).cmp(&right.value(j))) |
| } |
| |
| fn compare_float<'a, T: ArrowPrimitiveType>( |
| left: &'a Array, |
| right: &'a Array, |
| ) -> DynComparator<'a> |
| where |
| T::Native: Float, |
| { |
| let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); |
| Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j))) |
| } |
| |
| fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a> |
| where |
| T: StringOffsetSizeTrait, |
| { |
| let left = left |
| .as_any() |
| .downcast_ref::<GenericStringArray<T>>() |
| .unwrap(); |
| let right = right |
| .as_any() |
| .downcast_ref::<GenericStringArray<T>>() |
| .unwrap(); |
| Box::new(move |i, j| left.value(i).cmp(&right.value(j))) |
| } |
| |
| fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a> |
| where |
| T: ArrowDictionaryKeyType, |
| { |
| let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap(); |
| let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap(); |
| let left_keys = left.keys(); |
| let right_keys = right.keys(); |
| |
| let left_values = StringArray::from(left.values().data().clone()); |
| let right_values = StringArray::from(right.values().data().clone()); |
| |
| Box::new(move |i: usize, j: usize| { |
| let key_left = left_keys.value(i).to_usize().unwrap(); |
| let key_right = right_keys.value(j).to_usize().unwrap(); |
| let left = left_values.value(key_left); |
| let right = right_values.value(key_right); |
| left.cmp(&right) |
| }) |
| } |
| |
| /// returns a comparison function that compares two values at two different positions |
| /// between the two arrays. |
| /// The arrays' types must be equal. |
| /// # Example |
| /// ``` |
| /// use arrow::array::{build_compare, Int32Array}; |
| /// |
| /// # fn main() -> arrow::error::Result<()> { |
| /// let array1 = Int32Array::from(vec![1, 2]); |
| /// let array2 = Int32Array::from(vec![3, 4]); |
| /// |
| /// let cmp = build_compare(&array1, &array2)?; |
| /// |
| /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) |
| /// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); |
| /// # Ok(()) |
| /// # } |
| /// ``` |
| // This is a factory of comparisons. |
| // The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime. |
| pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> { |
| use DataType::*; |
| use IntervalUnit::*; |
| use TimeUnit::*; |
| Ok(match (left.data_type(), right.data_type()) { |
| (a, b) if a != b => { |
| return Err(ArrowError::InvalidArgumentError( |
| "Can't compare arrays of different types".to_string(), |
| )); |
| } |
| (Boolean, Boolean) => compare_boolean(left, right), |
| (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right), |
| (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right), |
| (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right), |
| (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right), |
| (Int8, Int8) => compare_primitives::<Int8Type>(left, right), |
| (Int16, Int16) => compare_primitives::<Int16Type>(left, right), |
| (Int32, Int32) => compare_primitives::<Int32Type>(left, right), |
| (Int64, Int64) => compare_primitives::<Int64Type>(left, right), |
| (Float32, Float32) => compare_float::<Float32Type>(left, right), |
| (Float64, Float64) => compare_float::<Float64Type>(left, right), |
| (Date32, Date32) => compare_primitives::<Date32Type>(left, right), |
| (Date64, Date64) => compare_primitives::<Date64Type>(left, right), |
| (Time32(Second), Time32(Second)) => { |
| compare_primitives::<Time32SecondType>(left, right) |
| } |
| (Time32(Millisecond), Time32(Millisecond)) => { |
| compare_primitives::<Time32MillisecondType>(left, right) |
| } |
| (Time64(Microsecond), Time64(Microsecond)) => { |
| compare_primitives::<Time64MicrosecondType>(left, right) |
| } |
| (Time64(Nanosecond), Time64(Nanosecond)) => { |
| compare_primitives::<Time64NanosecondType>(left, right) |
| } |
| (Timestamp(Second, _), Timestamp(Second, _)) => { |
| compare_primitives::<TimestampSecondType>(left, right) |
| } |
| (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => { |
| compare_primitives::<TimestampMillisecondType>(left, right) |
| } |
| (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => { |
| compare_primitives::<TimestampMicrosecondType>(left, right) |
| } |
| (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => { |
| compare_primitives::<TimestampNanosecondType>(left, right) |
| } |
| (Interval(YearMonth), Interval(YearMonth)) => { |
| compare_primitives::<IntervalYearMonthType>(left, right) |
| } |
| (Interval(DayTime), Interval(DayTime)) => { |
| compare_primitives::<IntervalDayTimeType>(left, right) |
| } |
| (Duration(Second), Duration(Second)) => { |
| compare_primitives::<DurationSecondType>(left, right) |
| } |
| (Duration(Millisecond), Duration(Millisecond)) => { |
| compare_primitives::<DurationMillisecondType>(left, right) |
| } |
| (Duration(Microsecond), Duration(Microsecond)) => { |
| compare_primitives::<DurationMicrosecondType>(left, right) |
| } |
| (Duration(Nanosecond), Duration(Nanosecond)) => { |
| compare_primitives::<DurationNanosecondType>(left, right) |
| } |
| (Utf8, Utf8) => compare_string::<i32>(left, right), |
| (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right), |
| ( |
| Dictionary(key_type_lhs, value_type_lhs), |
| Dictionary(key_type_rhs, value_type_rhs), |
| ) => { |
| if value_type_lhs.as_ref() != &DataType::Utf8 |
| || value_type_rhs.as_ref() != &DataType::Utf8 |
| { |
| return Err(ArrowError::InvalidArgumentError( |
| "Arrow still does not support comparisons of non-string dictionary arrays" |
| .to_string(), |
| )); |
| } |
| match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) { |
| (a, b) if a != b => { |
| return Err(ArrowError::InvalidArgumentError( |
| "Can't compare arrays of different types".to_string(), |
| )); |
| } |
| (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right), |
| (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right), |
| (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right), |
| (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right), |
| (Int8, Int8) => compare_dict_string::<Int8Type>(left, right), |
| (Int16, Int16) => compare_dict_string::<Int16Type>(left, right), |
| (Int32, Int32) => compare_dict_string::<Int32Type>(left, right), |
| (Int64, Int64) => compare_dict_string::<Int64Type>(left, right), |
| (lhs, _) => { |
| return Err(ArrowError::InvalidArgumentError(format!( |
| "Dictionaries do not support keys of type {:?}", |
| lhs |
| ))); |
| } |
| } |
| } |
| (lhs, _) => { |
| return Err(ArrowError::InvalidArgumentError(format!( |
| "The data type type {:?} has no natural order", |
| lhs |
| ))); |
| } |
| }) |
| } |
| |
| #[cfg(test)] |
| pub mod tests { |
| use super::*; |
| use crate::array::{Float64Array, Int32Array}; |
| use crate::error::Result; |
| use std::cmp::Ordering; |
| use std::iter::FromIterator; |
| |
| #[test] |
| fn test_i32() -> Result<()> { |
| let array = Int32Array::from(vec![1, 2]); |
| |
| let cmp = build_compare(&array, &array)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 1)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_i32_i32() -> Result<()> { |
| let array1 = Int32Array::from(vec![1]); |
| let array2 = Int32Array::from(vec![2]); |
| |
| let cmp = build_compare(&array1, &array2)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 0)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_f64() -> Result<()> { |
| let array = Float64Array::from(vec![1.0, 2.0]); |
| |
| let cmp = build_compare(&array, &array)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 1)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_f64_nan() -> Result<()> { |
| let array = Float64Array::from(vec![1.0, f64::NAN]); |
| |
| let cmp = build_compare(&array, &array)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 1)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_f64_zeros() -> Result<()> { |
| let array = Float64Array::from(vec![-0.0, 0.0]); |
| |
| let cmp = build_compare(&array, &array)?; |
| |
| assert_eq!(Ordering::Equal, (cmp)(0, 1)); |
| assert_eq!(Ordering::Equal, (cmp)(1, 0)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_dict() -> Result<()> { |
| let data = vec!["a", "b", "c", "a", "a", "c", "c"]; |
| let array = DictionaryArray::<Int16Type>::from_iter(data.into_iter()); |
| |
| let cmp = build_compare(&array, &array)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 1)); |
| assert_eq!(Ordering::Equal, (cmp)(3, 4)); |
| assert_eq!(Ordering::Greater, (cmp)(2, 3)); |
| Ok(()) |
| } |
| |
| #[test] |
| fn test_multiple_dict() -> Result<()> { |
| let d1 = vec!["a", "b", "c", "d"]; |
| let a1 = DictionaryArray::<Int16Type>::from_iter(d1.into_iter()); |
| let d2 = vec!["e", "f", "g", "a"]; |
| let a2 = DictionaryArray::<Int16Type>::from_iter(d2.into_iter()); |
| |
| let cmp = build_compare(&a1, &a2)?; |
| |
| assert_eq!(Ordering::Less, (cmp)(0, 0)); |
| assert_eq!(Ordering::Equal, (cmp)(0, 3)); |
| assert_eq!(Ordering::Greater, (cmp)(1, 3)); |
| Ok(()) |
| } |
| } |