| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| //! Defines scalars used to construct groups, ex. in GROUP BY clauses. |
| |
| use ordered_float::OrderedFloat; |
| use std::convert::{From, TryFrom}; |
| |
| use crate::error::{DataFusionError, Result}; |
| use crate::scalar::ScalarValue; |
| |
| /// Enumeration of types that can be used in a GROUP BY expression |
| #[derive(Debug, PartialEq, Eq, Hash, Clone)] |
| pub(crate) enum GroupByScalar { |
| Float32(OrderedFloat<f32>), |
| Float64(OrderedFloat<f64>), |
| UInt8(u8), |
| UInt16(u16), |
| UInt32(u32), |
| UInt64(u64), |
| Int8(i8), |
| Int16(i16), |
| Int32(i32), |
| Int64(i64), |
| Utf8(Box<String>), |
| Boolean(bool), |
| TimeMillisecond(i64), |
| TimeMicrosecond(i64), |
| TimeNanosecond(i64), |
| Date32(i32), |
| } |
| |
| impl TryFrom<&ScalarValue> for GroupByScalar { |
| type Error = DataFusionError; |
| |
| fn try_from(scalar_value: &ScalarValue) -> Result<Self> { |
| Ok(match scalar_value { |
| ScalarValue::Float32(Some(v)) => { |
| GroupByScalar::Float32(OrderedFloat::from(*v)) |
| } |
| ScalarValue::Float64(Some(v)) => { |
| GroupByScalar::Float64(OrderedFloat::from(*v)) |
| } |
| ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v), |
| ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v), |
| ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v), |
| ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v), |
| ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v), |
| ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v), |
| ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), |
| ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), |
| ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), |
| ScalarValue::TimestampMillisecond(Some(v)) => { |
| GroupByScalar::TimeMillisecond(*v) |
| } |
| ScalarValue::TimestampMicrosecond(Some(v)) => { |
| GroupByScalar::TimeMicrosecond(*v) |
| } |
| ScalarValue::TimestampNanosecond(Some(v)) => { |
| GroupByScalar::TimeNanosecond(*v) |
| } |
| ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), |
| ScalarValue::Float32(None) |
| | ScalarValue::Float64(None) |
| | ScalarValue::Boolean(None) |
| | ScalarValue::Int8(None) |
| | ScalarValue::Int16(None) |
| | ScalarValue::Int32(None) |
| | ScalarValue::Int64(None) |
| | ScalarValue::UInt8(None) |
| | ScalarValue::UInt16(None) |
| | ScalarValue::UInt32(None) |
| | ScalarValue::UInt64(None) |
| | ScalarValue::Utf8(None) => { |
| return Err(DataFusionError::Internal(format!( |
| "Cannot convert a ScalarValue holding NULL ({:?})", |
| scalar_value |
| ))); |
| } |
| v => { |
| return Err(DataFusionError::Internal(format!( |
| "Cannot convert a ScalarValue with associated DataType {:?}", |
| v.get_datatype() |
| ))) |
| } |
| }) |
| } |
| } |
| |
| impl From<&GroupByScalar> for ScalarValue { |
| fn from(group_by_scalar: &GroupByScalar) -> Self { |
| match group_by_scalar { |
| GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())), |
| GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())), |
| GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)), |
| GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)), |
| GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)), |
| GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)), |
| GroupByScalar::Int64(v) => ScalarValue::Int64(Some(*v)), |
| GroupByScalar::UInt8(v) => ScalarValue::UInt8(Some(*v)), |
| GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)), |
| GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)), |
| GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)), |
| GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())), |
| GroupByScalar::TimeMillisecond(v) => { |
| ScalarValue::TimestampMillisecond(Some(*v)) |
| } |
| GroupByScalar::TimeMicrosecond(v) => { |
| ScalarValue::TimestampMicrosecond(Some(*v)) |
| } |
| GroupByScalar::TimeNanosecond(v) => { |
| ScalarValue::TimestampNanosecond(Some(*v)) |
| } |
| GroupByScalar::Date32(v) => ScalarValue::Date32(Some(*v)), |
| } |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| |
| use crate::error::DataFusionError; |
| |
| macro_rules! scalar_eq_test { |
| ($TYPE:expr, $VALUE:expr) => {{ |
| let scalar_value = $TYPE($VALUE); |
| let a = GroupByScalar::try_from(&scalar_value).unwrap(); |
| |
| let scalar_value = $TYPE($VALUE); |
| let b = GroupByScalar::try_from(&scalar_value).unwrap(); |
| |
| assert_eq!(a, b); |
| }}; |
| } |
| |
| #[test] |
| fn test_scalar_ne_non_std() { |
| // Test only Scalars with non native Eq, Hash |
| scalar_eq_test!(ScalarValue::Float32, Some(1.0)); |
| scalar_eq_test!(ScalarValue::Float64, Some(1.0)); |
| } |
| |
| macro_rules! scalar_ne_test { |
| ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{ |
| let scalar_value = $TYPE($LVALUE); |
| let a = GroupByScalar::try_from(&scalar_value).unwrap(); |
| |
| let scalar_value = $TYPE($RVALUE); |
| let b = GroupByScalar::try_from(&scalar_value).unwrap(); |
| |
| assert_ne!(a, b); |
| }}; |
| } |
| |
| #[test] |
| fn test_scalar_eq_non_std() { |
| // Test only Scalars with non native Eq, Hash |
| scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0)); |
| scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0)); |
| } |
| |
| #[test] |
| fn from_scalar_holding_none() { |
| let scalar_value = ScalarValue::Int8(None); |
| let result = GroupByScalar::try_from(&scalar_value); |
| |
| match result { |
| Err(DataFusionError::Internal(error_message)) => assert_eq!( |
| error_message, |
| String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))") |
| ), |
| _ => panic!("Unexpected result"), |
| } |
| } |
| |
| #[test] |
| fn from_scalar_unsupported() { |
| // Use any ScalarValue type not supported by GroupByScalar. |
| let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string())); |
| let result = GroupByScalar::try_from(&scalar_value); |
| |
| match result { |
| Err(DataFusionError::Internal(error_message)) => assert_eq!( |
| error_message, |
| String::from( |
| "Cannot convert a ScalarValue with associated DataType LargeUtf8" |
| ) |
| ), |
| _ => panic!("Unexpected result"), |
| } |
| } |
| |
| #[test] |
| fn size_of_group_by_scalar() { |
| assert_eq!(std::mem::size_of::<GroupByScalar>(), 16); |
| } |
| } |