| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use crate::downcast_compute_op; |
| use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar}; |
| use arrow::array::{Float32Array, Float64Array, Int64Array}; |
| use arrow_array::{Array, ArrowNativeTypeOp}; |
| use arrow_schema::DataType; |
| use datafusion::physical_plan::ColumnarValue; |
| use datafusion_common::{DataFusionError, ScalarValue}; |
| use num::integer::div_ceil; |
| use std::sync::Arc; |
| |
| /// `ceil` function that simulates Spark `ceil` expression |
| pub fn spark_ceil( |
| args: &[ColumnarValue], |
| data_type: &DataType, |
| ) -> Result<ColumnarValue, DataFusionError> { |
| let value = &args[0]; |
| match value { |
| ColumnarValue::Array(array) => match array.data_type() { |
| DataType::Float32 => { |
| let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array); |
| Ok(ColumnarValue::Array(result?)) |
| } |
| DataType::Float64 => { |
| let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array); |
| Ok(ColumnarValue::Array(result?)) |
| } |
| DataType::Int64 => { |
| let result = array.as_any().downcast_ref::<Int64Array>().unwrap(); |
| Ok(ColumnarValue::Array(Arc::new(result.clone()))) |
| } |
| DataType::Decimal128(_, scale) if *scale > 0 => { |
| let f = decimal_ceil_f(scale); |
| let (precision, scale) = get_precision_scale(data_type); |
| make_decimal_array(array, precision, scale, &f) |
| } |
| other => Err(DataFusionError::Internal(format!( |
| "Unsupported data type {:?} for function ceil", |
| other, |
| ))), |
| }, |
| ColumnarValue::Scalar(a) => match a { |
| ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( |
| a.map(|x| x.ceil() as i64), |
| ))), |
| ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( |
| a.map(|x| x.ceil() as i64), |
| ))), |
| ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), |
| ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { |
| let f = decimal_ceil_f(scale); |
| let (precision, scale) = get_precision_scale(data_type); |
| make_decimal_scalar(a, precision, scale, &f) |
| } |
| _ => Err(DataFusionError::Internal(format!( |
| "Unsupported data type {:?} for function ceil", |
| value.data_type(), |
| ))), |
| }, |
| } |
| } |
| |
| #[inline] |
| fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 { |
| let div = 10_i128.pow_wrapping(*scale as u32); |
| move |x: i128| div_ceil(x, div) |
| } |