| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; |
| use arrow::compute::try_binary; |
| use arrow::datatypes::DataType; |
| use arrow::error::ArrowError; |
| use datafusion_common::{DataFusionError, Result, ScalarValue}; |
| use datafusion_expr::function::Hint; |
| use datafusion_expr::ColumnarValue; |
| use std::sync::Arc; |
| |
| /// Creates a function to identify the optimal return type of a string function given |
| /// the type of its first argument. |
| /// |
| /// If the input type is `LargeUtf8` or `LargeBinary` the return type is |
| /// `$largeUtf8Type`, |
| /// |
| /// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, |
| /// |
| /// If the input type is `Utf8View` the return type is $utf8Type, |
| macro_rules! get_optimal_return_type { |
| ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { |
| pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> { |
| Ok(match arg_type { |
| // LargeBinary inputs are automatically coerced to Utf8 |
| DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, |
| // Binary inputs are automatically coerced to Utf8 |
| DataType::Utf8 | DataType::Binary => $utf8Type, |
| // Utf8View max offset size is u32::MAX, the same as UTF8 |
| DataType::Utf8View | DataType::BinaryView => $utf8Type, |
| DataType::Null => DataType::Null, |
| DataType::Dictionary(_, value_type) => match **value_type { |
| DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, |
| DataType::Utf8 | DataType::Binary => $utf8Type, |
| DataType::Null => DataType::Null, |
| _ => { |
| return datafusion_common::exec_err!( |
| "The {} function can only accept strings, but got {:?}.", |
| name.to_uppercase(), |
| **value_type |
| ); |
| } |
| }, |
| data_type => { |
| return datafusion_common::exec_err!( |
| "The {} function can only accept strings, but got {:?}.", |
| name.to_uppercase(), |
| data_type |
| ); |
| } |
| }) |
| } |
| }; |
| } |
| |
| // `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. |
| get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); |
| |
| // `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. |
| get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); |
| |
| /// Creates a scalar function implementation for the given function. |
| /// * `inner` - the function to be executed |
| /// * `hints` - hints to be used when expanding scalars to arrays |
| pub fn make_scalar_function<F>( |
| inner: F, |
| hints: Vec<Hint>, |
| ) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue> |
| where |
| F: Fn(&[ArrayRef]) -> Result<ArrayRef>, |
| { |
| move |args: &[ColumnarValue]| { |
| // first, identify if any of the arguments is an Array. If yes, store its `len`, |
| // as any scalar will need to be converted to an array of len `len`. |
| let len = args |
| .iter() |
| .fold(Option::<usize>::None, |acc, arg| match arg { |
| ColumnarValue::Scalar(_) => acc, |
| ColumnarValue::Array(a) => Some(a.len()), |
| }); |
| |
| let is_scalar = len.is_none(); |
| |
| let inferred_length = len.unwrap_or(1); |
| let args = args |
| .iter() |
| .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) |
| .map(|(arg, hint)| { |
| // Decide on the length to expand this scalar to depending |
| // on the given hints. |
| let expansion_len = match hint { |
| Hint::AcceptsSingular => 1, |
| Hint::Pad => inferred_length, |
| }; |
| arg.to_array(expansion_len) |
| }) |
| .collect::<Result<Vec<_>>>()?; |
| |
| let result = (inner)(&args); |
| if is_scalar { |
| // If all inputs are scalar, keeps output as scalar |
| let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); |
| result.map(ColumnarValue::Scalar) |
| } else { |
| result.map(ColumnarValue::Array) |
| } |
| } |
| } |
| |
| /// Computes a binary math function for input arrays using a specified function. |
| /// Generic types: |
| /// - `L`: Left array primitive type |
| /// - `R`: Right array primitive type |
| /// - `O`: Output array primitive type |
| /// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>` |
| pub fn calculate_binary_math<L, R, O, F>( |
| left: &dyn Array, |
| right: &ColumnarValue, |
| fun: F, |
| ) -> Result<Arc<PrimitiveArray<O>>> |
| where |
| R: ArrowPrimitiveType, |
| L: ArrowPrimitiveType, |
| O: ArrowPrimitiveType, |
| F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>, |
| R::Native: TryFrom<ScalarValue>, |
| { |
| let left = left.as_primitive::<L>(); |
| let right = right.cast_to(&R::DATA_TYPE, None)?; |
| let result = match right { |
| ColumnarValue::Scalar(scalar) => { |
| let right = R::Native::try_from(scalar.clone()).map_err(|_| { |
| DataFusionError::NotImplemented(format!( |
| "Cannot convert scalar value {} to {}", |
| &scalar, |
| R::DATA_TYPE |
| )) |
| })?; |
| left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))? |
| } |
| ColumnarValue::Array(right) => { |
| let right = right.as_primitive::<R>(); |
| try_binary::<_, _, _, O>(left, right, &fun)? |
| } |
| }; |
| Ok(Arc::new(result) as _) |
| } |
| |
| /// Converts Decimal128 components (value and scale) to an unscaled i128 |
| pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> { |
| if scale < 0 { |
| Err(ArrowError::ComputeError( |
| "Negative scale is not supported".into(), |
| )) |
| } else if scale == 0 { |
| Ok(value) |
| } else { |
| match i128::from(10).checked_pow(scale as u32) { |
| Some(divisor) => Ok(value / divisor), |
| None => Err(ArrowError::ComputeError(format!( |
| "Cannot get a power of {scale}" |
| ))), |
| } |
| } |
| } |
| |
| #[cfg(test)] |
| pub mod test { |
| /// $FUNC ScalarUDFImpl to test |
| /// $ARGS arguments (vec) to pass to function |
| /// $EXPECTED a Result<ColumnarValue> |
| /// $EXPECTED_TYPE is the expected value type |
| /// $EXPECTED_DATA_TYPE is the expected result type |
| /// $ARRAY_TYPE is the column type after function applied |
| /// $CONFIG_OPTIONS config options to pass to function |
| macro_rules! test_function { |
| ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => { |
| let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED; |
| let func = $FUNC; |
| |
| let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>(); |
| let cardinality = $ARGS |
| .iter() |
| .fold(Option::<usize>::None, |acc, arg| match arg { |
| ColumnarValue::Scalar(_) => acc, |
| ColumnarValue::Array(a) => Some(a.len()), |
| }) |
| .unwrap_or(1); |
| |
| let scalar_arguments = $ARGS.iter().map(|arg| match arg { |
| ColumnarValue::Scalar(scalar) => Some(scalar.clone()), |
| ColumnarValue::Array(_) => None, |
| }).collect::<Vec<_>>(); |
| let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>(); |
| |
| let nullables = $ARGS.iter().map(|arg| match arg { |
| ColumnarValue::Scalar(scalar) => scalar.is_null(), |
| ColumnarValue::Array(a) => a.null_count() > 0, |
| }).collect::<Vec<_>>(); |
| |
| let field_array = data_array.into_iter().zip(nullables).enumerate() |
| .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) |
| .map(std::sync::Arc::new) |
| .collect::<Vec<_>>(); |
| |
| let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { |
| arg_fields: &field_array, |
| scalar_arguments: &scalar_arguments_refs, |
| }); |
| let arg_fields = $ARGS.iter() |
| .enumerate() |
| .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) |
| .collect::<Vec<_>>(); |
| |
| match expected { |
| Ok(expected) => { |
| assert_eq!(return_field.is_ok(), true); |
| let return_field = return_field.unwrap(); |
| let return_type = return_field.data_type(); |
| assert_eq!(return_type, &$EXPECTED_DATA_TYPE); |
| |
| let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ |
| args: $ARGS, |
| arg_fields, |
| number_rows: cardinality, |
| return_field, |
| config_options: $CONFIG_OPTIONS |
| }); |
| assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); |
| |
| let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); |
| let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); |
| assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); |
| |
| // value is correct |
| match expected { |
| Some(v) => assert_eq!(result.value(0), v), |
| None => assert!(result.is_null(0)), |
| }; |
| } |
| Err(expected_error) => { |
| if let Ok(return_field) = return_field { |
| // invoke is expected error - cannot use .expect_err() due to Debug not being implemented |
| match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs { |
| args: $ARGS, |
| arg_fields, |
| number_rows: cardinality, |
| return_field, |
| config_options: $CONFIG_OPTIONS, |
| }) { |
| Ok(_) => assert!(false, "expected error"), |
| Err(error) => { |
| assert!(expected_error |
| .strip_backtrace() |
| .starts_with(&error.strip_backtrace())); |
| } |
| } |
| } else if let Err(error) = return_field { |
| datafusion_common::assert_contains!( |
| expected_error.strip_backtrace(), |
| error.strip_backtrace() |
| ); |
| } |
| } |
| }; |
| }; |
| |
| ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { |
| test_function!( |
| $FUNC, |
| $ARGS, |
| $EXPECTED, |
| $EXPECTED_TYPE, |
| $EXPECTED_DATA_TYPE, |
| $ARRAY_TYPE, |
| std::sync::Arc::new(datafusion_common::config::ConfigOptions::default()) |
| ) |
| }; |
| } |
| |
| use arrow::datatypes::DataType; |
| #[allow(unused_imports)] |
| pub(crate) use test_function; |
| |
| use super::*; |
| |
| #[test] |
| fn string_to_int_type() { |
| let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); |
| assert_eq!(v, DataType::Int32); |
| |
| let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap(); |
| assert_eq!(v, DataType::Int32); |
| |
| let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap(); |
| assert_eq!(v, DataType::Int64); |
| } |
| |
| #[test] |
| fn test_decimal128_to_i128() { |
| let cases = [ |
| (123, 0, Some(123)), |
| (1230, 1, Some(123)), |
| (123000, 3, Some(123)), |
| (1, 0, Some(1)), |
| (123, -3, None), |
| (123, i8::MAX, None), |
| (i128::MAX, 0, Some(i128::MAX)), |
| (i128::MAX, 3, Some(i128::MAX / 1000)), |
| ]; |
| |
| for (value, scale, expected) in cases { |
| match decimal128_to_i128(value, scale) { |
| Ok(actual) => { |
| assert_eq!( |
| actual, |
| expected.expect("Got value but expected none"), |
| "{value} and {scale} vs {expected:?}" |
| ); |
| } |
| Err(_) => assert!(expected.is_none()), |
| } |
| } |
| } |
| } |