// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
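// Spark-compatible cast kernels used by Blaze: string/decimal/boolean and
// nested types get special handling below; everything else falls back to
// arrow's generic cast kernel.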
use std::{str::FromStr, sync::Arc};
use arrow::{array::*, datatypes::*};
use bigdecimal::{FromPrimitive, ToPrimitive};
use datafusion::common::{
cast::{as_float32_array, as_float64_array},
Result,
};
use num::{cast::AsPrimitive, Bounded, Integer, Signed};
use paste::paste;
use crate::df_execution_err;
// Casts `array` to `cast_type` with Spark-compatible semantics.
pub fn cast(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
    cast_impl(array, cast_type, false)
}

// Casts a scan input array; struct fields are matched by name and missing
// fields are filled with nulls.
pub fn cast_scan_input_array(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
    cast_impl(array, cast_type, true)
}
pub fn cast_impl(
array: &dyn Array,
cast_type: &DataType,
match_struct_fields: bool,
) -> Result<ArrayRef> {
Ok(match (&array.data_type(), cast_type) {
(&t1, t2) if t1 == t2 => make_array(array.to_data()),
(_, &DataType::Null) => Arc::new(NullArray::new(array.len())),
// float to int
(&DataType::Float32, &DataType::Int8) => Arc::new(cast_float_to_integer::<_, Int8Type>(
as_float32_array(array)?,
)),
(&DataType::Float32, &DataType::Int16) => Arc::new(cast_float_to_integer::<_, Int16Type>(
as_float32_array(array)?,
)),
(&DataType::Float32, &DataType::Int32) => Arc::new(cast_float_to_integer::<_, Int32Type>(
as_float32_array(array)?,
)),
(&DataType::Float32, &DataType::Int64) => Arc::new(cast_float_to_integer::<_, Int64Type>(
as_float32_array(array)?,
)),
(&DataType::Float64, &DataType::Int8) => Arc::new(cast_float_to_integer::<_, Int8Type>(
as_float64_array(array)?,
)),
(&DataType::Float64, &DataType::Int16) => Arc::new(cast_float_to_integer::<_, Int16Type>(
as_float64_array(array)?,
)),
(&DataType::Float64, &DataType::Int32) => Arc::new(cast_float_to_integer::<_, Int32Type>(
as_float64_array(array)?,
)),
(&DataType::Float64, &DataType::Int64) => Arc::new(cast_float_to_integer::<_, Int64Type>(
as_float64_array(array)?,
)),
(&DataType::Utf8, &DataType::Int8)
| (&DataType::Utf8, &DataType::Int16)
| (&DataType::Utf8, &DataType::Int32)
| (&DataType::Utf8, &DataType::Int64) => {
// spark compatible string to integer cast
try_cast_string_array_to_integer(array, cast_type)?
}
(&DataType::Utf8, &DataType::Decimal128(..)) => {
// spark compatible string to decimal cast
try_cast_string_array_to_decimal(array, cast_type)?
}
(&DataType::Decimal128(..), DataType::Utf8) => {
// spark compatible decimal to string cast
try_cast_decimal_array_to_string(array, cast_type)?
}
(&DataType::Timestamp(..), DataType::Float64) => {
// timestamp to f64 = timestamp to i64 to f64, only used in agg.sum()
arrow::compute::cast(
&arrow::compute::cast(array, &DataType::Int64)?,
&DataType::Float64,
)?
}
(&DataType::Boolean, DataType::Utf8) => {
// spark compatible boolean to string cast
try_cast_boolean_array_to_string(array, cast_type)?
}
(&DataType::List(_), DataType::List(to_field)) => {
let list = as_list_array(array);
let items = cast_impl(list.values(), to_field.data_type(), match_struct_fields)?;
make_array(
list.to_data()
.into_builder()
.data_type(DataType::List(to_field.clone()))
.child_data(vec![items.into_data()])
.build()?,
)
}
(&DataType::Struct(_), DataType::Struct(to_fields)) => {
let struct_ = as_struct_array(array);
            // Without field matching, target columns are cast positionally; with
            // field matching (scan inputs), target fields are looked up by name
            // and missing fields are filled with nulls.
            if !match_struct_fields {
if to_fields.len() != struct_.num_columns() {
df_execution_err!("cannot cast structs with different numbers of fields")?;
}
let casted_arrays = struct_
.columns()
.iter()
.zip(to_fields)
.map(|(column, to_field)| {
cast_impl(column, to_field.data_type(), match_struct_fields)
})
.collect::<Result<Vec<_>>>()?;
make_array(
struct_
.to_data()
.into_builder()
.data_type(DataType::Struct(to_fields.clone()))
.child_data(
casted_arrays
.into_iter()
.map(|array| array.into_data())
.collect(),
)
.build()?,
)
} else {
let mut null_column_name = vec![];
let casted_arrays = to_fields
.iter()
.map(|field| {
                        match struct_.column_by_name(field.name().as_str()) {
                            Some(column) => {
                                cast_impl(column, field.data_type(), match_struct_fields)
                            }
                            None => {
                                null_column_name.push(field.name().clone());
                                Ok(new_null_array(field.data_type(), struct_.len()))
                            }
                        }
})
.collect::<Result<Vec<_>>>()?;
let casted_fields = to_fields
.iter()
.map(|field: &FieldRef| {
if null_column_name.contains(field.name()) {
Arc::new(Field::new(field.name(), field.data_type().clone(), true))
} else {
field.clone()
}
})
.collect::<Vec<_>>();
make_array(
struct_
.to_data()
.into_builder()
.data_type(DataType::Struct(Fields::from(casted_fields)))
.child_data(
casted_arrays
.into_iter()
.map(|array| array.into_data())
.collect(),
)
.build()?,
)
}
}
(&DataType::Map(..), &DataType::Map(ref to_entries_field, to_sorted)) => {
let map = as_map_array(array);
let entries = cast_impl(
map.entries(),
to_entries_field.data_type(),
match_struct_fields,
)?;
make_array(
map.to_data()
.into_builder()
.data_type(DataType::Map(to_entries_field.clone(), to_sorted))
.child_data(vec![entries.into_data()])
.build()?,
)
}
_ => {
// default cast
arrow::compute::kernels::cast::cast(array, cast_type)?
}
})
}
fn try_cast_string_array_to_integer(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
macro_rules! cast {
($target_type:ident) => {{
type B = paste! {[<$target_type Builder>]};
let array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = B::new();
for v in array.iter() {
match v {
Some(s) => builder.append_option(to_integer(s)),
None => builder.append_null(),
}
}
Arc::new(builder.finish())
}};
}
Ok(match cast_type {
DataType::Int8 => cast!(Int8),
DataType::Int16 => cast!(Int16),
DataType::Int32 => cast!(Int32),
DataType::Int64 => cast!(Int64),
_ => arrow::compute::cast(array, cast_type)?,
})
}
fn try_cast_string_array_to_decimal(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
if let &DataType::Decimal128(precision, scale) = cast_type {
let array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = Decimal128Builder::new();
for v in array.iter() {
match v {
Some(s) => match to_decimal(s, precision, scale) {
Some(v) => builder.append_value(v),
None => builder.append_null(),
},
None => builder.append_null(),
}
}
return Ok(Arc::new(
builder
.finish()
.with_precision_and_scale(precision, scale)?,
));
}
unreachable!("cast_type must be DataType::Decimal")
}
fn try_cast_decimal_array_to_string(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
if let &DataType::Utf8 = cast_type {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
let mut builder = StringBuilder::new();
for v in 0..array.len() {
if array.is_valid(v) {
builder.append_value(array.value_as_string(v))
} else {
builder.append_null()
}
}
return Ok(Arc::new(builder.finish()));
}
unreachable!("cast_type must be DataType::Utf8")
}
fn try_cast_boolean_array_to_string(array: &dyn Array, cast_type: &DataType) -> Result<ArrayRef> {
if let &DataType::Utf8 = cast_type {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
return Ok(Arc::new(
array
.iter()
.map(|value| value.map(|value| if value { "true" } else { "false" }))
.collect::<StringArray>(),
));
}
unreachable!("cast_type must be DataType::Utf8")
}
// Casts a float array to an integer array via Rust's `as` conversion, which
// saturates at the target type's bounds and maps NaN to 0.
fn cast_float_to_integer<F: ArrowPrimitiveType, T: ArrowPrimitiveType>(
    array: &PrimitiveArray<F>,
) -> PrimitiveArray<T>
where
    F::Native: AsPrimitive<T::Native>,
{
    arrow::compute::unary(array, |v| v.as_())
}
// This implementation was originally copied from Spark's UTF8String.scala.
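// Behavior (illustrative, assuming an i32 target): a trailing fraction is
// truncated, so "123.45" parses to Some(123); malformed input ("12abc", ""),
// a bare sign ("+") and values that overflow the target type ("9999999999")
// all yield None, which the callers append as null.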
fn to_integer<T: Bounded + FromPrimitive + Integer + Signed + Copy>(input: &str) -> Option<T> {
let bytes = input.as_bytes();
if bytes.is_empty() {
return None;
}
let b = bytes[0];
let negative = b == b'-';
let mut offset = 0;
if negative || b == b'+' {
offset += 1;
if bytes.len() == 1 {
return None;
}
}
let separator = b'.';
let radix = T::from_usize(10).unwrap();
let stop_value = T::min_value() / radix;
let mut result = T::zero();
while offset < bytes.len() {
let b = bytes[offset];
offset += 1;
if b == separator {
            // Decimals are allowed and the integral part is returned truncated, so
            // this is not an error; the fractional part is validated below.
break;
}
let digit = if b.is_ascii_digit() {
b - b'0'
} else {
return None;
};
        // Process the new digit and accumulate into the result. Before doing so, if
        // the result is already smaller than stop_value (T::min_value() / radix),
        // then result * radix is guaranteed to be smaller than T::min_value(), so we
        // can stop early.
if result < stop_value {
return None;
}
result = result * radix - T::from_u8(digit).unwrap();
        // Since the previous result was greater than or equal to stop_value
        // (T::min_value() / radix), `result > 0` is enough to detect overflow; if the
        // result overflowed, stop.
if result > T::zero() {
return None;
}
}
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while offset < bytes.len() {
let current_byte = bytes[offset];
if !current_byte.is_ascii_digit() {
return None;
}
offset += 1;
}
if !negative {
result = -result;
if result < T::zero() {
return None;
}
}
Some(result)
}
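// Parses a string into the unscaled i128 value of a Decimal128 with the given
// precision and scale, returning None for input that is not a valid decimal.
// Illustrative example: to_decimal("1.23", 10, 2) yields Some(123), i.e. 1.23
// stored at scale 2.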
fn to_decimal(input: &str, precision: u8, scale: i8) -> Option<i128> {
let precision = precision as u64;
let scale = scale as i64;
bigdecimal::BigDecimal::from_str(input)
.ok()
.map(|decimal| decimal.with_prec(precision).with_scale(scale))
.and_then(|decimal| {
let (bigint, _exp) = decimal.as_bigint_and_exponent();
bigint.to_i128()
})
}
#[cfg(test)]
mod test {
use datafusion::common::cast::as_int32_array;
use crate::cast::*;
#[test]
fn test_float_to_int() {
let f64_array: ArrayRef = Arc::new(Float64Array::from_iter(vec![
None,
Some(123.456),
Some(987.654),
Some(i32::MAX as f64 + 10000.0),
Some(i32::MIN as f64 - 10000.0),
Some(f64::INFINITY),
Some(f64::NEG_INFINITY),
Some(f64::NAN),
]));
let casted = cast(&f64_array, &DataType::Int32).unwrap();
let i32_array = as_int32_array(&casted).unwrap();
assert_eq!(
i32_array,
&Int32Array::from_iter(vec![
None,
Some(123),
Some(987),
Some(i32::MAX),
Some(i32::MIN),
Some(i32::MAX),
Some(i32::MIN),
Some(0),
])
);
}
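
    // Illustrative test sketches (not part of the original suite) exercising the
    // Spark-compatible string casts above; expected values follow from to_integer
    // (truncate at '.', null on malformed or overflowing input) and to_decimal
    // (unscaled value at the target scale) as implemented in this file.
    #[test]
    fn test_string_to_int() {
        let string_array: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("123"),
            Some("+123"),
            Some("-123"),
            Some("123.45"),
            Some("12abc"),
            Some(""),
            Some("9999999999"),
        ]));
        let casted = cast(&string_array, &DataType::Int32).unwrap();
        let i32_array = as_int32_array(&casted).unwrap();
        assert_eq!(
            i32_array,
            &Int32Array::from_iter(vec![
                None,
                Some(123),
                Some(123),
                Some(-123),
                Some(123),
                None,
                None,
                None,
            ])
        );
    }

    #[test]
    fn test_string_to_decimal() {
        let string_array: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("1.23"),
            Some("-0.50"),
            Some("not a number"),
        ]));
        let casted = cast(&string_array, &DataType::Decimal128(10, 2)).unwrap();
        let decimal_array = casted.as_any().downcast_ref::<Decimal128Array>().unwrap();
        assert_eq!(
            decimal_array,
            &Decimal128Array::from_iter(vec![None, Some(123), Some(-50), None])
                .with_precision_and_scale(10, 2)
                .unwrap()
        );
    }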
}