// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides `ScalarValue`, an enum used to store a single (possibly null) value.
| |
use std::{convert::TryFrom, fmt, sync::Arc};

use arrow::array::{
    Array, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
    Int8Array, LargeStringArray, ListArray, StringArray, UInt16Array, UInt32Array,
    UInt64Array, UInt8Array,
};
use arrow::array::{
    BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder,
    Int64Builder, Int8Builder, ListBuilder, UInt16Builder, UInt32Builder, UInt64Builder,
    UInt8Builder,
};
use arrow::{
    array::{ArrayRef, PrimitiveArrayOps},
    datatypes::DataType,
};

use crate::error::{ExecutionError, Result};
| |
| /// Represents a dynamically typed, nullable single value. |
| /// This is the single-valued counter-part of arrow’s `Array`. |
| #[derive(Clone, PartialEq)] |
| pub enum ScalarValue { |
| /// true or false value |
| Boolean(Option<bool>), |
| /// 32bit float |
| Float32(Option<f32>), |
| /// 64bit float |
| Float64(Option<f64>), |
| /// signed 8bit int |
| Int8(Option<i8>), |
| /// signed 16bit int |
| Int16(Option<i16>), |
| /// signed 32bit int |
| Int32(Option<i32>), |
| /// signed 64bit int |
| Int64(Option<i64>), |
| /// unsigned 8bit int |
| UInt8(Option<u8>), |
| /// unsigned 16bit int |
| UInt16(Option<u16>), |
| /// unsigned 32bit int |
| UInt32(Option<u32>), |
| /// unsigned 64bit int |
| UInt64(Option<u64>), |
| /// utf-8 encoded string. |
| Utf8(Option<String>), |
| /// utf-8 encoded string representing a LargeString's arrow type. |
| LargeUtf8(Option<String>), |
| /// list of nested ScalarValue |
| List(Option<Vec<ScalarValue>>, DataType), |
| } |
| |
/// Reads the element at `$index` of `$array` as a `ScalarValue::$SCALAR`.
///
/// `$array` is downcast to `$ARRAYTYPE`; a null slot becomes `None`, any other
/// slot becomes `Some(value.into())`.
///
/// The expansion uses `?`, so it must be invoked inside a function whose error
/// type accepts `ExecutionError`. A failed downcast is reported as
/// `ExecutionError::InternalError` instead of panicking.
macro_rules! typed_cast {
    ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
        let array = $array
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .ok_or_else(|| {
                ExecutionError::InternalError(format!(
                    "Failed to downcast array to {}",
                    stringify!($ARRAYTYPE)
                ))
            })?;
        ScalarValue::$SCALAR(if array.is_null($index) {
            None
        } else {
            Some(array.value($index).into())
        })
    }};
}
| |
/// Builds a one-row list array from `$VALUES`, an optional collection of
/// `ScalarValue::$SCALAR_TY` elements.
///
/// * no collection: the single row is appended as null and has no child values;
/// * a collection: every element is appended to a `$VALUE_BUILDER_TY` (null
///   elements stay null) and the single row is appended as valid.
///
/// Panics if an element is not a `ScalarValue::$SCALAR_TY`, or if any builder
/// call returns an error.
macro_rules! build_list {
    ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr) => {{
        if let Some(elements) = $VALUES {
            // Size the child builder for exactly the elements we are about to add.
            let mut list_builder =
                ListBuilder::new($VALUE_BUILDER_TY::new(elements.len()));

            for element in elements {
                match element {
                    ScalarValue::$SCALAR_TY(Some(inner)) => {
                        list_builder.values().append_value(*inner).unwrap()
                    }
                    ScalarValue::$SCALAR_TY(None) => {
                        list_builder.values().append_null().unwrap()
                    }
                    _ => panic!("Incompatible ScalarValue for list"),
                }
            }

            list_builder.append(true).unwrap();
            list_builder.finish()
        } else {
            let mut list_builder = ListBuilder::new($VALUE_BUILDER_TY::new(0));
            list_builder.append(false).unwrap();
            list_builder.finish()
        }
    }};
}
| |
| impl ScalarValue { |
| /// Getter for the `DataType` of the value |
| pub fn get_datatype(&self) -> DataType { |
| match self { |
| ScalarValue::Boolean(_) => DataType::Boolean, |
| ScalarValue::UInt8(_) => DataType::UInt8, |
| ScalarValue::UInt16(_) => DataType::UInt16, |
| ScalarValue::UInt32(_) => DataType::UInt32, |
| ScalarValue::UInt64(_) => DataType::UInt64, |
| ScalarValue::Int8(_) => DataType::Int8, |
| ScalarValue::Int16(_) => DataType::Int16, |
| ScalarValue::Int32(_) => DataType::Int32, |
| ScalarValue::Int64(_) => DataType::Int64, |
| ScalarValue::Float32(_) => DataType::Float32, |
| ScalarValue::Float64(_) => DataType::Float64, |
| ScalarValue::Utf8(_) => DataType::Utf8, |
| ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, |
| ScalarValue::List(_, data_type) => { |
| DataType::List(Box::new(data_type.clone())) |
| } |
| } |
| } |
| |
| /// whether this value is null or not. |
| pub fn is_null(&self) -> bool { |
| match *self { |
| ScalarValue::Boolean(None) |
| | ScalarValue::UInt8(None) |
| | ScalarValue::UInt16(None) |
| | ScalarValue::UInt32(None) |
| | ScalarValue::UInt64(None) |
| | ScalarValue::Int8(None) |
| | ScalarValue::Int16(None) |
| | ScalarValue::Int32(None) |
| | ScalarValue::Int64(None) |
| | ScalarValue::Float32(None) |
| | ScalarValue::Float64(None) |
| | ScalarValue::Utf8(None) |
| | ScalarValue::LargeUtf8(None) |
| | ScalarValue::List(None, _) => true, |
| _ => false, |
| } |
| } |
| |
| /// Converts a scalar value into an 1-row array. |
| pub fn to_array(&self) -> ArrayRef { |
| match self { |
| ScalarValue::Boolean(e) => Arc::new(BooleanArray::from(vec![*e])) as ArrayRef, |
| ScalarValue::Float64(e) => Arc::new(Float64Array::from(vec![*e])) as ArrayRef, |
| ScalarValue::Float32(e) => Arc::new(Float32Array::from(vec![*e])), |
| ScalarValue::Int8(e) => Arc::new(Int8Array::from(vec![*e])), |
| ScalarValue::Int16(e) => Arc::new(Int16Array::from(vec![*e])), |
| ScalarValue::Int32(e) => Arc::new(Int32Array::from(vec![*e])), |
| ScalarValue::Int64(e) => Arc::new(Int64Array::from(vec![*e])), |
| ScalarValue::UInt8(e) => Arc::new(UInt8Array::from(vec![*e])), |
| ScalarValue::UInt16(e) => Arc::new(UInt16Array::from(vec![*e])), |
| ScalarValue::UInt32(e) => Arc::new(UInt32Array::from(vec![*e])), |
| ScalarValue::UInt64(e) => Arc::new(UInt64Array::from(vec![*e])), |
| ScalarValue::Utf8(e) => Arc::new(StringArray::from(vec![e.as_deref()])), |
| ScalarValue::LargeUtf8(e) => { |
| Arc::new(LargeStringArray::from(vec![e.as_deref()])) |
| } |
| ScalarValue::List(values, data_type) => Arc::new(match data_type { |
| DataType::Int8 => build_list!(Int8Builder, Int8, values), |
| DataType::Int16 => build_list!(Int16Builder, Int16, values), |
| DataType::Int32 => build_list!(Int32Builder, Int32, values), |
| DataType::Int64 => build_list!(Int64Builder, Int64, values), |
| DataType::UInt8 => build_list!(UInt8Builder, UInt8, values), |
| DataType::UInt16 => build_list!(UInt16Builder, UInt16, values), |
| DataType::UInt32 => build_list!(UInt32Builder, UInt32, values), |
| DataType::UInt64 => build_list!(UInt64Builder, UInt64, values), |
| _ => panic!("Unexpected DataType for list"), |
| }), |
| } |
| } |
| |
| /// Converts a value in `array` at `index` into a ScalarValue |
| pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> { |
| Ok(match array.data_type() { |
| DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), |
| DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), |
| DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), |
| DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), |
| DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), |
| DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), |
| DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), |
| DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), |
| DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), |
| DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), |
| DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), |
| DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), |
| DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), |
| DataType::List(nested_type) => { |
| let list_array = array.as_any().downcast_ref::<ListArray>().ok_or( |
| ExecutionError::InternalError( |
| "Failed to downcast ListArray".to_string(), |
| ), |
| )?; |
| let value = match list_array.is_null(index) { |
| true => None, |
| false => { |
| let nested_array = list_array.value(index); |
| let scalar_vec = (0..nested_array.len()) |
| .map(|i| ScalarValue::try_from_array(&nested_array, i)) |
| .collect::<Result<Vec<_>>>()?; |
| Some(scalar_vec) |
| } |
| }; |
| ScalarValue::List(value, *nested_type.clone()) |
| } |
| other => { |
| return Err(ExecutionError::NotImplemented(format!( |
| "Can't create a scalar of array of type \"{:?}\"", |
| other |
| ))) |
| } |
| }) |
| } |
| } |
| |
| impl From<f64> for ScalarValue { |
| fn from(value: f64) -> Self { |
| ScalarValue::Float64(Some(value)) |
| } |
| } |
| |
| impl From<f32> for ScalarValue { |
| fn from(value: f32) -> Self { |
| ScalarValue::Float32(Some(value)) |
| } |
| } |
| |
| impl From<i8> for ScalarValue { |
| fn from(value: i8) -> Self { |
| ScalarValue::Int8(Some(value)) |
| } |
| } |
| |
| impl From<i16> for ScalarValue { |
| fn from(value: i16) -> Self { |
| ScalarValue::Int16(Some(value)) |
| } |
| } |
| |
| impl From<i32> for ScalarValue { |
| fn from(value: i32) -> Self { |
| ScalarValue::Int32(Some(value)) |
| } |
| } |
| |
| impl From<i64> for ScalarValue { |
| fn from(value: i64) -> Self { |
| ScalarValue::Int64(Some(value)) |
| } |
| } |
| |
| impl From<bool> for ScalarValue { |
| fn from(value: bool) -> Self { |
| ScalarValue::Boolean(Some(value)) |
| } |
| } |
| |
| impl From<u8> for ScalarValue { |
| fn from(value: u8) -> Self { |
| ScalarValue::UInt8(Some(value)) |
| } |
| } |
| |
| impl From<u16> for ScalarValue { |
| fn from(value: u16) -> Self { |
| ScalarValue::UInt16(Some(value)) |
| } |
| } |
| |
| impl From<u32> for ScalarValue { |
| fn from(value: u32) -> Self { |
| ScalarValue::UInt32(Some(value)) |
| } |
| } |
| |
| impl From<u64> for ScalarValue { |
| fn from(value: u64) -> Self { |
| ScalarValue::UInt64(Some(value)) |
| } |
| } |
| |
| impl TryFrom<&DataType> for ScalarValue { |
| type Error = ExecutionError; |
| |
| fn try_from(datatype: &DataType) -> Result<Self> { |
| Ok(match datatype { |
| &DataType::Boolean => ScalarValue::Boolean(None), |
| &DataType::Float64 => ScalarValue::Float64(None), |
| &DataType::Float32 => ScalarValue::Float32(None), |
| &DataType::Int8 => ScalarValue::Int8(None), |
| &DataType::Int16 => ScalarValue::Int16(None), |
| &DataType::Int32 => ScalarValue::Int32(None), |
| &DataType::Int64 => ScalarValue::Int64(None), |
| &DataType::UInt8 => ScalarValue::UInt8(None), |
| &DataType::UInt16 => ScalarValue::UInt16(None), |
| &DataType::UInt32 => ScalarValue::UInt32(None), |
| &DataType::UInt64 => ScalarValue::UInt64(None), |
| &DataType::Utf8 => ScalarValue::Utf8(None), |
| &DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), |
| &DataType::List(ref nested_type) => { |
| ScalarValue::List(None, *nested_type.clone()) |
| } |
| _ => { |
| return Err(ExecutionError::NotImplemented(format!( |
| "Can't create a scalar of type \"{:?}\"", |
| datatype |
| ))) |
| } |
| }) |
| } |
| } |
| |
/// Writes an optional value to the formatter `$F`: the inner value via its
/// `Display` form when present, the literal text `NULL` otherwise.
///
/// Evaluates to the result of the underlying `write!` call.
macro_rules! format_option {
    ($F:expr, $EXPR:expr) => {{
        if let Some(inner) = $EXPR {
            write!($F, "{}", inner)
        } else {
            write!($F, "NULL")
        }
    }};
}
| |
| impl fmt::Display for ScalarValue { |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| match self { |
| ScalarValue::Boolean(e) => format_option!(f, e)?, |
| ScalarValue::Float32(e) => format_option!(f, e)?, |
| ScalarValue::Float64(e) => format_option!(f, e)?, |
| ScalarValue::Int8(e) => format_option!(f, e)?, |
| ScalarValue::Int16(e) => format_option!(f, e)?, |
| ScalarValue::Int32(e) => format_option!(f, e)?, |
| ScalarValue::Int64(e) => format_option!(f, e)?, |
| ScalarValue::UInt8(e) => format_option!(f, e)?, |
| ScalarValue::UInt16(e) => format_option!(f, e)?, |
| ScalarValue::UInt32(e) => format_option!(f, e)?, |
| ScalarValue::UInt64(e) => format_option!(f, e)?, |
| ScalarValue::Utf8(e) => format_option!(f, e)?, |
| ScalarValue::LargeUtf8(e) => format_option!(f, e)?, |
| ScalarValue::List(e, _) => match e { |
| Some(l) => write!( |
| f, |
| "{}", |
| l.iter() |
| .map(|v| format!("{}", v)) |
| .collect::<Vec<_>>() |
| .join(",") |
| )?, |
| None => write!(f, "NULL")?, |
| }, |
| }; |
| Ok(()) |
| } |
| } |
| |
| impl fmt::Debug for ScalarValue { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| match self { |
| ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), |
| ScalarValue::Float32(_) => write!(f, "Float32({})", self), |
| ScalarValue::Float64(_) => write!(f, "Float64({})", self), |
| ScalarValue::Int8(_) => write!(f, "Int8({})", self), |
| ScalarValue::Int16(_) => write!(f, "Int16({})", self), |
| ScalarValue::Int32(_) => write!(f, "Int32({})", self), |
| ScalarValue::Int64(_) => write!(f, "Int64({})", self), |
| ScalarValue::UInt8(_) => write!(f, "UInt8({})", self), |
| ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), |
| ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), |
| ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), |
| ScalarValue::Utf8(_) => write!(f, "Utf8(\"{}\")", self), |
| ScalarValue::LargeUtf8(_) => write!(f, "LargeUtf8(\"{}\")", self), |
| ScalarValue::List(_, _) => write!(f, "List([{}])", self), |
| } |
| } |
| } |
| |
#[cfg(test)]
mod tests {
    use super::*;

    /// A null list scalar converts to one null row with no child values.
    #[test]
    fn scalar_list_null_to_array() -> Result<()> {
        let null_list = ScalarValue::List(None, DataType::UInt64);
        let array = null_list.to_array();
        let list = array.as_any().downcast_ref::<ListArray>().unwrap();

        assert!(list.is_null(0));
        assert_eq!(list.len(), 1);
        assert_eq!(list.values().len(), 0);

        Ok(())
    }

    /// A populated list scalar converts to one row whose child array keeps
    /// both the values and the null element, in order.
    #[test]
    fn scalar_list_to_array() -> Result<()> {
        let elements = vec![
            ScalarValue::UInt64(Some(100)),
            ScalarValue::UInt64(None),
            ScalarValue::UInt64(Some(101)),
        ];
        let array = ScalarValue::List(Some(elements), DataType::UInt64).to_array();

        let list = array.as_any().downcast_ref::<ListArray>().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list.values().len(), 3);

        let row = list.value(0);
        let row_values = row.as_any().downcast_ref::<UInt64Array>().unwrap();
        assert_eq!(row_values.len(), 3);
        assert_eq!(row_values.value(0), 100);
        assert!(row_values.is_null(1));
        assert_eq!(row_values.value(2), 101);

        Ok(())
    }
}