blob: 743d5b99cde950325ad10f30cb2887083fcaae32 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Declaration of built-in (scalar) functions.
//! This module contains built-in functions' enumeration and metadata.
//!
//! Generally, a function has:
//! * a signature
//! * a return type, which is a function of the incoming arguments' types
//! * the computation, that must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
//!
//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
//! to a function that supports f64, it is coerced to f64.
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::expressions::Literal;
use crate::PhysicalExpr;
use arrow::array::{Array, RecordBatch};
use arrow::datatypes::{DataType, FieldRef, Schema};
use datafusion_common::config::{ConfigEntry, ConfigOptions};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{
expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
Volatility,
};
/// Physical expression of a scalar function
///
/// Pairs a [`ScalarUDF`] with the physical expressions that produce its
/// arguments and with the field describing its output.
pub struct ScalarFunctionExpr {
/// The function implementation invoked when the expression is evaluated.
fun: Arc<ScalarUDF>,
/// Name used when displaying the expression and in error messages.
name: String,
/// Argument expressions; each one is evaluated against the input batch.
args: Vec<Arc<dyn PhysicalExpr>>,
/// Output field; the expression's data type and nullability are read from it.
return_field: FieldRef,
/// Configuration handed to the function on every invocation.
config_options: Arc<ConfigOptions>,
}
impl Debug for ScalarFunctionExpr {
    /// Formats the expression for diagnostics.
    ///
    /// The function implementation is rendered as the placeholder `<FUNC>`,
    /// and the configuration options are left out of the output.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("ScalarFunctionExpr");
        builder.field("fun", &"<FUNC>");
        builder.field("name", &self.name);
        builder.field("args", &self.args);
        builder.field("return_field", &self.return_field);
        builder.finish()
    }
}
impl ScalarFunctionExpr {
    /// Builds an expression from already-resolved parts.
    ///
    /// Nothing is validated here: `return_field` is stored exactly as given
    /// rather than being derived from `fun` and `args`. Use
    /// [`Self::try_new`] to have the argument types checked and the return
    /// field computed.
    pub fn new(
        name: &str,
        fun: Arc<ScalarUDF>,
        args: Vec<Arc<dyn PhysicalExpr>>,
        return_field: FieldRef,
        config_options: Arc<ConfigOptions>,
    ) -> Self {
        let name = String::from(name);
        Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        }
    }

    /// Builds an expression applying `fun` to `args`, resolving the return
    /// field against `schema`.
    ///
    /// The expression is named after the function.
    ///
    /// # Errors
    ///
    /// Fails if an argument's field cannot be resolved from `schema`, if the
    /// argument types are rejected by `data_types_with_scalar_udf`, or if the
    /// function cannot produce a return field for these arguments.
    pub fn try_new(
        fun: Arc<ScalarUDF>,
        args: Vec<Arc<dyn PhysicalExpr>>,
        schema: &Schema,
        config_options: Arc<ConfigOptions>,
    ) -> Result<Self> {
        let name = fun.name().to_owned();

        let arg_fields = args
            .iter()
            .map(|arg| arg.return_field(schema))
            .collect::<Result<Vec<_>>>()?;

        // Verify that the input data types are consistent with the
        // function's `TypeSignature`.
        let arg_types: Vec<DataType> = arg_fields
            .iter()
            .map(|field| field.data_type().clone())
            .collect();
        data_types_with_scalar_udf(&arg_types, &fun)?;

        // Literal arguments are exposed by value; every other argument
        // contributes `None`.
        let scalar_arguments: Vec<_> = args
            .iter()
            .map(|arg| {
                arg.as_any()
                    .downcast_ref::<Literal>()
                    .map(|lit| lit.value())
            })
            .collect();

        let return_field = fun.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &scalar_arguments,
        })?;

        Ok(Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        })
    }

    /// The scalar function implementation backing this expression.
    pub fn fun(&self) -> &ScalarUDF {
        &self.fun
    }

    /// The name of this expression.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The argument expressions, in call order.
    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
        self.args.as_slice()
    }

    /// The data type produced by this expression, as recorded in its return
    /// field.
    pub fn return_type(&self) -> &DataType {
        self.return_field.data_type()
    }

    /// Returns the expression with the nullability of its return field
    /// replaced by `nullable`.
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        // The field sits behind an `Arc`, so adjust a copy and swap it in.
        let updated = self.return_field.as_ref().clone().with_nullable(nullable);
        self.return_field = Arc::new(updated);
        self
    }

    /// Whether the return field is marked nullable.
    pub fn nullable(&self) -> bool {
        self.return_field.is_nullable()
    }

    /// The configuration options passed to the function when it is invoked.
    pub fn config_options(&self) -> &ConfigOptions {
        &self.config_options
    }

    /// Attempts to view an arbitrary [`PhysicalExpr`] as a
    /// `ScalarFunctionExpr` whose inner function is of type `T`.
    ///
    /// Returns `None` when `expr` is not a `ScalarFunctionExpr` or when its
    /// inner function is not a `T`; otherwise returns the downcast
    /// expression.
    pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
    where
        T: 'static,
    {
        let scalar_expr = expr.as_any().downcast_ref::<ScalarFunctionExpr>()?;
        let inner_is_t = scalar_expr.fun().inner().as_any().is::<T>();
        inner_is_t.then_some(scalar_expr)
    }
}
impl fmt::Display for ScalarFunctionExpr {
    /// Renders the expression as `name(args)`, with the arguments formatted
    /// by `expr_vec_fmt!`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let rendered_args = expr_vec_fmt!(self.args);
        write!(f, "{}({})", self.name, rendered_args)
    }
}
impl PartialEq for ScalarFunctionExpr {
    /// Two expressions are equal when their function, name, arguments and
    /// return field are equal and their configuration options hold the same
    /// entries.
    fn eq(&self, o: &Self) -> bool {
        // The field-by-field comparison is somewhat expensive, so identical
        // references are answered immediately.
        if std::ptr::eq(self, o) {
            return true;
        }

        // Destructure without `..` so that adding a field to the struct is a
        // compile error here until the comparison accounts for it.
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        } = self;

        let same_expression = fun.eq(&o.fun)
            && name.eq(&o.name)
            && args.eq(&o.args)
            && return_field.eq(&o.return_field);
        if !same_expression {
            return false;
        }

        // Sharing the same allocation settles the question without having to
        // collect and sort the entries of both sides.
        Arc::ptr_eq(config_options, &o.config_options)
            || sorted_config_entries(config_options)
                == sorted_config_entries(&o.config_options)
    }
}

impl Eq for ScalarFunctionExpr {}
impl Hash for ScalarFunctionExpr {
    /// Feeds the function, name, arguments and return field into `state`.
    ///
    /// The configuration options are deliberately not hashed.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Destructure without `..` so that a newly added field must be
        // handled (or explicitly ignored) here.
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options: _, // expensive to hash, and often equal
        } = self;

        Hash::hash(fun, state);
        Hash::hash(name, state);
        Hash::hash(args, state);
        Hash::hash(return_field, state);
    }
}
/// Returns the entries of `config_options` ordered by their key, so that two
/// sets of options can be compared entry by entry.
fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
    let mut by_key = config_options.entries();
    by_key.sort_by(|lhs, rhs| Ord::cmp(&lhs.key, &rhs.key));
    by_key
}
impl PhysicalExpr for ScalarFunctionExpr {
    /// Exposes `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output data type, taken from the stored return field; the input
    /// schema is not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        let data_type = self.return_field.data_type();
        Ok(data_type.clone())
    }

    /// Output nullability, taken from the stored return field; the input
    /// schema is not consulted.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        let is_nullable = self.return_field.is_nullable();
        Ok(is_nullable)
    }

    /// Evaluates every argument against `batch` and invokes the function on
    /// the results.
    ///
    /// # Errors
    ///
    /// Propagates errors from evaluating the arguments, resolving their
    /// fields, or invoking the function. An internal error is returned when
    /// the function yields an array whose length differs from the number of
    /// rows in `batch`, except for the single-element case described below.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let expected_rows = batch.num_rows();

        let args = self
            .args
            .iter()
            .map(|arg| arg.evaluate(batch))
            .collect::<Result<Vec<_>>>()?;
        let arg_fields = self
            .args
            .iter()
            .map(|arg| arg.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

        // Captured before `args` is moved into the invocation below.
        let input_empty = args.is_empty();
        let input_all_scalar = args
            .iter()
            .all(|value| matches!(value, ColumnarValue::Scalar(_)));

        // Evaluate the function.
        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: expected_rows,
            return_field: Arc::clone(&self.return_field),
            config_options: Arc::clone(&self.config_options),
        })?;

        if let ColumnarValue::Array(array) = &output {
            let produced_rows = array.len();
            if produced_rows != expected_rows {
                // If the arguments are a non-empty slice of scalar values, a
                // one-element array is treated as equivalent to a scalar.
                let scalar_inputs_only = !input_empty && input_all_scalar;
                if produced_rows == 1 && scalar_inputs_only {
                    return ScalarValue::try_from_array(array, 0)
                        .map(ColumnarValue::Scalar);
                }
                return internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
                    self.name, expected_rows, produced_rows);
            }
        }

        Ok(output)
    }

    /// The stored return field; the input schema is not consulted.
    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    /// The argument expressions, in call order.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.args.iter().collect()
    }

    /// Rebuilds the expression around `children`, keeping the name,
    /// function, return field and configuration options unchanged.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let rebuilt = ScalarFunctionExpr::new(
            &self.name,
            Arc::clone(&self.fun),
            children,
            Arc::clone(&self.return_field),
            Arc::clone(&self.config_options),
        );
        Ok(Arc::new(rebuilt))
    }

    /// Delegates bound evaluation to the wrapped function.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        self.fun.evaluate_bounds(children)
    }

    /// Delegates constraint propagation to the wrapped function.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.fun.propagate_constraints(interval, children)
    }

    /// Combines the ordering, lexicographical-ordering preservation and
    /// value range reported by the wrapped function for the given child
    /// properties.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let sort_properties = self.fun.output_ordering(children)?;
        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;

        let children_range: Vec<&Interval> =
            children.iter().map(|child| &child.range).collect();
        let range = self.fun.evaluate_bounds(&children_range)?;

        Ok(ExprProperties {
            sort_properties,
            range,
            preserves_lex_ordering,
        })
    }

    /// Writes `name(arg, arg, ...)`, formatting each argument with its own
    /// `fmt_sql`.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}(", self.name)?;

        let mut remaining = self.args.iter();
        if let Some(first) = remaining.next() {
            first.fmt_sql(f)?;
            for arg in remaining {
                f.write_str(", ")?;
                arg.fmt_sql(f)?;
            }
        }

        f.write_str(")")
    }

    /// Whether the wrapped function's signature is declared volatile.
    fn is_volatile_node(&self) -> bool {
        matches!(self.fun.signature().volatility, Volatility::Volatile)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Mock UDF whose volatility comes from the signature it is built with.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Wraps a mock UDF taking one `Float32` argument with the requested
    /// volatility.
    fn mock_udf(volatility: Volatility) -> Arc<ScalarUDF> {
        let signature = Signature::uniform(1, vec![DataType::Float32], volatility);
        Arc::new(ScalarUDF::from(MockScalarUDF { signature }))
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let config_options = Arc::new(ConfigOptions::new());

        // (volatility of the UDF, whether the expression must report volatile)
        let cases = [(Volatility::Volatile, true), (Volatility::Stable, false)];

        for (volatility, expect_volatile) in cases {
            let expr = ScalarFunctionExpr::try_new(
                mock_udf(volatility),
                vec![Arc::clone(&column)],
                &schema,
                Arc::clone(&config_options),
            )
            .unwrap();
            assert_eq!(expr.is_volatile_node(), expect_volatile);

            // The tree-level helper must agree with the node-level answer.
            let expr: Arc<dyn PhysicalExpr> = Arc::new(expr);
            assert_eq!(is_volatile(&expr), expect_volatile);
        }
    }
}