blob: f5a149654272d929ab6692c821d7c6ac2a500797 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
use arrow_array::{Array, ArrowNativeTypeOp};
use arrow_schema::DataType;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue};
use std::{cmp::min, sync::Arc};
/// Rounds an integer to the nearest multiple of `$DIV`, with ties going away from zero.
///
/// * `$X`    - the value to round
/// * `$DIV`  - the (positive) divisor, i.e. the power of ten being rounded to
/// * `$HALF` - half of `$DIV`; remainders at or beyond it round away from zero
///
/// The final adjustment uses wrapping arithmetic (`sub_wrapping` / `add_wrapping`), so a
/// result outside the integer type's range wraps around instead of panicking.
///
/// Every argument is bound to a local exactly once, so passing an expression with side
/// effects does not evaluate it repeatedly.
macro_rules! integer_round {
    ($X:expr, $DIV:expr, $HALF:expr) => {{
        let value = $X;
        let divisor = $DIV;
        let half = $HALF;
        // `%` keeps the sign of the dividend, so `rem` has the same sign as `value`
        // and `value - rem` cannot overflow.
        let rem = value % divisor;
        let truncated = value - rem;
        if rem <= -half {
            truncated.sub_wrapping(divisor)
        } else if rem >= half {
            truncated.add_wrapping(divisor)
        } else {
            truncated
        }
    }};
}
/// Rounds every element of an integer array to a negative decimal position.
///
/// * `$ARRAY`  - array whose concrete type is `$TYPE` (it is downcast through `as_any`)
/// * `$POINT`  - `&i64` rounding position; every call site guards with `*point < 0`
/// * `$TYPE`   - concrete array type, e.g. `Int64Array`
/// * `$NATIVE` - native integer type matching `$TYPE`, e.g. `i64`
///
/// Evaluates to `Ok(ColumnarValue::Array(..))`. When `10^(-point)` does not fit in
/// `$NATIVE`, every element is mapped to `0`.
macro_rules! round_integer_array {
    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
        // The call sites match on the array's data type first, so the downcast is
        // expected to succeed.
        let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
        let ten: $NATIVE = 10;
        // `unsigned_abs` cannot overflow (unlike negating `i64::MIN`), and saturating
        // to `u32::MAX` avoids an `as u32` truncation that could turn a huge magnitude
        // into a small exponent. A saturated exponent simply makes `checked_pow` fail.
        let exponent = (*$POINT).unsigned_abs().min(u64::from(u32::MAX)) as u32;
        let result: $TYPE = if let Some(div) = ten.checked_pow(exponent) {
            let half = div / 2;
            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
        } else {
            // The divisor exceeds the type's range, so every value rounds to zero.
            arrow::compute::kernels::arity::unary(array, |_| 0)
        };
        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
/// Rounds an optional integer scalar to a negative decimal position.
///
/// * `$SCALAR` - `&Option<$NATIVE>` holding the value (`None` is a NULL scalar)
/// * `$POINT`  - `&i64` rounding position; every call site guards with `*point < 0`
/// * `$TYPE`   - `ScalarValue` variant constructor, e.g. `ScalarValue::Int64`
/// * `$NATIVE` - native integer type matching `$TYPE`, e.g. `i64`
///
/// Evaluates to `Ok(ColumnarValue::Scalar(..))`. A `None` input always yields `None`.
/// When `10^(-point)` does not fit in `$NATIVE`, a non-null input yields `0`.
macro_rules! round_integer_scalar {
    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
        let ten: $NATIVE = 10;
        // `unsigned_abs` cannot overflow (unlike negating `i64::MIN`), and saturating
        // to `u32::MAX` avoids an `as u32` truncation that could turn a huge magnitude
        // into a small exponent. A saturated exponent simply makes `checked_pow` fail.
        let exponent = (*$POINT).unsigned_abs().min(u64::from(u32::MAX)) as u32;
        if let Some(div) = ten.checked_pow(exponent) {
            let half = div / 2;
            Ok(ColumnarValue::Scalar($TYPE(
                $SCALAR.map(|x| integer_round!(x, div, half)),
            )))
        } else {
            // The divisor exceeds the type's range: non-null values round to zero,
            // while a NULL input must remain NULL.
            let zero: $NATIVE = 0;
            Ok(ColumnarValue::Scalar($TYPE($SCALAR.map(|_| zero))))
        }
    }};
}
/// `round` function that simulates Spark `round` expression
///
/// # Arguments
///
/// * `args` - `args[0]` is the value to round (array or scalar); `args[1]` must be a
///   non-null `Int64` scalar giving the decimal position to round to.
/// * `data_type` - the result type; it is only consulted for decimal inputs, where it
///   supplies the precision and scale of the output.
///
/// # Supported inputs
///
/// * `Int8` / `Int16` / `Int32` / `Int64` when the position is negative.
/// * `Decimal128` with a non-negative scale.
/// * `Float32` / `Float64`, which are delegated to DataFusion's `round`.
///
/// # Errors
///
/// Returns an internal error when fewer than two arguments are supplied or when the
/// position is not a non-null `Int64` scalar, and an execution error for any other
/// input type (including integer inputs with a non-negative position).
pub fn spark_round(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    // Checked access: a malformed call is reported as an error instead of panicking.
    let (Some(value), Some(point_arg)) = (args.first(), args.get(1)) else {
        return internal_err!(
            "Round() expects a value and a point argument, got {} argument(s)",
            args.len()
        );
    };
    let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point_arg else {
        return internal_err!("Invalid point argument for Round(): {:#?}", point_arg);
    };
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
            DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
            DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
            DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
            DataType::Decimal128(_, scale) if *scale >= 0 => {
                // The rounding closure works on the input scale; the output precision
                // and scale come from the declared result type.
                let f = decimal_round_f(scale, point);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_array(array, precision, scale, &f)
            }
            DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
                Arc::clone(array),
                point_arg.to_array(array.len())?,
            ])?)),
            dt => exec_err!("Not supported datatype for ROUND: {dt}"),
        },
        ColumnarValue::Scalar(a) => match a {
            ScalarValue::Int64(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int64, i64)
            }
            ScalarValue::Int32(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int32, i32)
            }
            ScalarValue::Int16(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int16, i16)
            }
            ScalarValue::Int8(a) if *point < 0 => {
                round_integer_scalar!(a, point, ScalarValue::Int8, i8)
            }
            ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
                let f = decimal_round_f(scale, point);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_scalar(a, precision, scale, &f)
            }
            // Float scalars go through the array kernel as one-element arrays and are
            // converted back to a scalar afterwards.
            ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
                ScalarValue::try_from_array(&round(&[a.to_array()?, point_arg.to_array(1)?])?, 0)?,
            )),
            dt => exec_err!("Not supported datatype for ROUND: {dt}"),
        },
    }
}
/// Builds the closure that rounds an unscaled decimal value (half away from zero).
///
/// Spark uses BigDecimal (see the `RoundBase` implementation in Spark). Here the same
/// result is obtained on the raw `i128` by 1) adding half of the divisor, 2) rounding
/// down by division, and 3) for negative positions, adjusting by multiplication.
///
/// * `scale` - scale of the input decimal; both call sites only pass `scale >= 0`
/// * `point` - decimal position to round to
///
/// For `point < 0` the closure returns the rounded value expressed at scale 0, or `0`
/// for every input when the required power of ten does not fit in `i128`. For
/// `point >= 0` it returns the value expressed at scale `min(scale, point)`.
#[inline]
fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
    let scale = *scale as u32;
    if *point < 0 {
        // `unsigned_abs` handles `i64::MIN`, and saturating keeps an enormous magnitude
        // from being truncated into a small exponent by an `as u32` cast.
        let shift = min((*point).unsigned_abs(), u64::from(u32::MAX)) as u32;
        let divisor = shift
            .checked_add(scale)
            .and_then(|exp| 10_i128.checked_pow(exp));
        if let Some(div) = divisor {
            let half = div / 2;
            // `shift` is no larger than the exponent that just fit, so this cannot wrap.
            let mul = 10_i128.wrapping_pow(shift);
            // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow
            Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
        } else {
            Box::new(move |_: i128| 0)
        }
    } else {
        // `point` is non-negative here, so widening to u64 is lossless; positions at or
        // beyond the input scale leave the value untouched (divisor of 1).
        let kept = min(u64::from(scale), *point as u64) as u32;
        let div = 10_i128.wrapping_pow(scale - kept);
        let half = div / 2;
        Box::new(move |x: i128| (x + x.signum() * half) / div)
    }
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::spark_round;

    use arrow::array::{Float32Array, Float64Array};
    use arrow_schema::DataType;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    /// The rounding position shared by every test: two decimal places.
    fn two_places() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(Some(2)))
    }

    /// Rounding a `Float32` array to two places yields the expected array.
    #[test]
    fn test_round_f32_array() -> Result<()> {
        let input = Float32Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = vec![ColumnarValue::Array(Arc::new(input)), two_places()];
        match spark_round(&args, &DataType::Float32)? {
            ColumnarValue::Array(result) => {
                let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
                assert_eq!(as_float32_array(&result)?, &expected);
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    /// Rounding a `Float64` array to two places yields the expected array.
    #[test]
    fn test_round_f64_array() -> Result<()> {
        let input = Float64Array::from(vec![125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123]);
        let args = vec![ColumnarValue::Array(Arc::new(input)), two_places()];
        match spark_round(&args, &DataType::Float64)? {
            ColumnarValue::Array(result) => {
                let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
                assert_eq!(as_float64_array(&result)?, &expected);
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    /// Rounding a `Float32` scalar to two places yields a `Float32` scalar.
    #[test]
    fn test_round_f32_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
            two_places(),
        ];
        match spark_round(&args, &DataType::Float32)? {
            ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) => {
                assert_eq!(result, 125.23);
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    /// Rounding a `Float64` scalar to two places yields a `Float64` scalar.
    #[test]
    fn test_round_f64_scalar() -> Result<()> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
            two_places(),
        ];
        match spark_round(&args, &DataType::Float64)? {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) => {
                assert_eq!(result, 125.23);
            }
            _ => unreachable!(),
        }
        Ok(())
    }
}