| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use crate::core::greatest_least_utils::GreatestLeastOperator; |
| use arrow::array::{make_comparator, Array, BooleanArray}; |
| use arrow::buffer::BooleanBuffer; |
| use arrow::compute::kernels::cmp; |
| use arrow::compute::SortOptions; |
| use arrow::datatypes::DataType; |
| use datafusion_common::{internal_err, Result, ScalarValue}; |
| use datafusion_doc::Documentation; |
| use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; |
| use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; |
| use datafusion_macros::user_doc; |
| use std::any::Any; |
| |
| const SORT_OPTIONS: SortOptions = SortOptions { |
| // We want greatest first |
| descending: false, |
| |
| // NULL will be less than any other value |
| nulls_first: true, |
| }; |
| |
| #[user_doc( |
| doc_section(label = "Conditional Functions"), |
| description = "Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.", |
| syntax_example = "greatest(expression1[, ..., expression_n])", |
| sql_example = r#"```sql |
| > select greatest(4, 7, 5); |
| +---------------------------+ |
| | greatest(4,7,5) | |
| +---------------------------+ |
| | 7 | |
| +---------------------------+ |
| ```"#, |
| argument( |
| name = "expression1, expression_n", |
| description = "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." |
| ) |
| )] |
| #[derive(Debug, PartialEq, Eq, Hash)] |
| pub struct GreatestFunc { |
| signature: Signature, |
| } |
| |
| impl Default for GreatestFunc { |
| fn default() -> Self { |
| GreatestFunc::new() |
| } |
| } |
| |
| impl GreatestFunc { |
| pub fn new() -> Self { |
| Self { |
| signature: Signature::user_defined(Volatility::Immutable), |
| } |
| } |
| } |
| |
| impl GreatestLeastOperator for GreatestFunc { |
| const NAME: &'static str = "greatest"; |
| |
| fn keep_scalar<'a>( |
| lhs: &'a ScalarValue, |
| rhs: &'a ScalarValue, |
| ) -> Result<&'a ScalarValue> { |
| if !lhs.data_type().is_nested() { |
| return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; |
| } |
| |
| // If complex type we can't compare directly as we want null values to be smaller |
| let cmp = make_comparator( |
| lhs.to_array()?.as_ref(), |
| rhs.to_array()?.as_ref(), |
| SORT_OPTIONS, |
| )?; |
| |
| if cmp(0, 0).is_ge() { |
| Ok(lhs) |
| } else { |
| Ok(rhs) |
| } |
| } |
| |
| /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array |
| /// Nulls are always considered smaller than any other value |
| fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> { |
| // Fast path: |
| // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel |
| // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. |
| // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case |
| if !lhs.data_type().is_nested() |
| && lhs.logical_null_count() == 0 |
| && rhs.logical_null_count() == 0 |
| { |
| return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); |
| } |
| |
| let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; |
| |
| if lhs.len() != rhs.len() { |
| return internal_err!( |
| "All arrays should have the same length for greatest comparison" |
| ); |
| } |
| |
| let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); |
| |
| // No nulls as we only want to keep the values that are larger, its either true or false |
| Ok(BooleanArray::new(values, None)) |
| } |
| } |
| |
| impl ScalarUDFImpl for GreatestFunc { |
| fn as_any(&self) -> &dyn Any { |
| self |
| } |
| |
| fn name(&self) -> &str { |
| "greatest" |
| } |
| |
| fn signature(&self) -> &Signature { |
| &self.signature |
| } |
| |
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
| Ok(arg_types[0].clone()) |
| } |
| |
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| super::greatest_least_utils::execute_conditional::<Self>(&args.args) |
| } |
| |
| fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
| let coerced_type = |
| super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?; |
| |
| Ok(vec![coerced_type; arg_types.len()]) |
| } |
| |
| fn documentation(&self) -> Option<&Documentation> { |
| self.doc() |
| } |
| } |
| |
#[cfg(test)]
mod test {
    use crate::core;
    use arrow::datatypes::DataType;
    use datafusion_expr::ScalarUDFImpl;

    /// Neither input decimal type can hold the other, so both arguments are
    /// expected to be coerced to `Decimal128(11, 4)`.
    #[test]
    fn test_greatest_return_types_without_common_supertype_in_arg_type() {
        let input_types = [DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)];
        let expected = vec![DataType::Decimal128(11, 4); 2];

        let coerced = core::greatest::GreatestFunc::new()
            .coerce_types(&input_types)
            .unwrap();

        assert_eq!(coerced, expected);
    }
}