blob: 59aa730fdea635dbc7241fff4c8f7d6c1208b723 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Declaration of built-in (aggregate) functions.
//! This module contains built-in aggregates' enumeration and metadata.
//!
//! Generally, an aggregate has:
//! * a signature
//! * a return type, that is a function of the incoming arguments' types
//! * the computation, that must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
use super::{
functions::Signature,
type_coercion::{coerce, data_types},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
/// Closure implementing an aggregate function: each call yields a boxed
/// [`Accumulator`], or an error if one cannot be created.
pub type AccumulatorFunctionImplementation =
    Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Sync + Send>;

/// Given an aggregator's return [`DataType`], yields the list of data types
/// in which that aggregator serializes its state.
pub type StateTypeFunction =
    Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Sync + Send>;
/// Enum of all built-in aggregate functions.
///
/// The SQL name of each variant is its identifier in upper case (e.g. `COUNT`);
/// see the `Display` and `FromStr` implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `AVG`
    Avg,
}
impl fmt::Display for AggregateFunction {
    /// Writes the SQL-style name of the function: its `Debug` name in upper case
    /// (so `Count` is rendered as `COUNT`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let debug_name = format!("{:?}", self);
        f.write_str(&debug_name.to_uppercase())
    }
}
impl FromStr for AggregateFunction {
    type Err = DataFusionError;

    /// Parses an aggregate function from its name, ignoring case
    /// (`"sum"`, `"Sum"` and `"SUM"` all map to [`AggregateFunction::Sum`]).
    ///
    /// Returns a `DataFusionError::Plan` when the name is not a known aggregate.
    fn from_str(name: &str) -> Result<AggregateFunction> {
        let normalized = name.to_uppercase();
        let parsed = match normalized.as_str() {
            "COUNT" => AggregateFunction::Count,
            "SUM" => AggregateFunction::Sum,
            "MIN" => AggregateFunction::Min,
            "MAX" => AggregateFunction::Max,
            "AVG" => AggregateFunction::Avg,
            _ => {
                let message = format!("There is no built-in function named {}", name);
                return Err(DataFusionError::Plan(message));
            }
        };
        Ok(parsed)
    }
}
/// Returns the data type produced by the aggregate function `fun` when it is
/// applied to arguments of types `arg_types`.
///
/// # Errors
///
/// Returns an error when `arg_types` cannot be coerced to a signature accepted
/// by `fun`, or when the SUM/AVG return-type helpers reject the coerced type.
pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
    // Note that this function *must* return the same type that the respective physical expression returns
    // or the execution panics.
    // `create_aggregate_expr` builds the physical aggregate on the *coerced* argument,
    // so the result type must be derived from the coerced types rather than the raw
    // input types. For inputs that are already valid, coercion leaves them unchanged.
    let coerced_types = data_types(arg_types, &signature(fun))?;
    match fun {
        AggregateFunction::Count => Ok(DataType::UInt64),
        // Indexing is safe: every non-COUNT signature takes exactly one argument,
        // and `data_types` has just validated the arity.
        AggregateFunction::Max | AggregateFunction::Min => Ok(coerced_types[0].clone()),
        AggregateFunction::Sum => sum_return_type(&coerced_types[0]),
        AggregateFunction::Avg => avg_return_type(&coerced_types[0]),
    }
}
/// Create a physical (function) expression.
///
/// The argument in `args` is first coerced to a type accepted by the function's
/// signature; the physical aggregate is built on that coerced expression, and its
/// return type is derived from the coerced type so the two always agree.
///
/// # Errors
///
/// Returns an error when `args` can't be coerced to a valid argument type of the
/// function, when an argument's type cannot be resolved against `input_schema`,
/// or for `SUM(DISTINCT)` / `AVG(DISTINCT)`, which are not implemented.
pub fn create_aggregate_expr(
    fun: &AggregateFunction,
    distinct: bool,
    args: &[Arc<dyn PhysicalExpr>],
    input_schema: &Schema,
    name: String,
) -> Result<Arc<dyn AggregateExpr>> {
    // Cast the arguments to types accepted by the signature (errors if impossible).
    let coerced_args = coerce(args, input_schema, &signature(fun))?;
    // Every signature in this module takes exactly one argument, so a successful
    // coercion yields exactly one expression.
    let arg = coerced_args[0].clone();

    // Types of the arguments as supplied; DISTINCT COUNT receives the raw arguments.
    let arg_types = args
        .iter()
        .map(|e| e.data_type(input_schema))
        .collect::<Result<Vec<_>>>()?;

    // The physical aggregates below consume `arg` (the coerced expression), so the
    // declared return type must be computed from the coerced types.
    let coerced_types = coerced_args
        .iter()
        .map(|e| e.data_type(input_schema))
        .collect::<Result<Vec<_>>>()?;
    let return_type = return_type(fun, &coerced_types)?;

    Ok(match (fun, distinct) {
        (AggregateFunction::Count, false) => {
            Arc::new(expressions::Count::new(arg, name, return_type))
        }
        (AggregateFunction::Count, true) => {
            Arc::new(distinct_expressions::DistinctCount::new(
                arg_types,
                args.to_vec(),
                name,
                return_type,
            ))
        }
        (AggregateFunction::Sum, false) => {
            Arc::new(expressions::Sum::new(arg, name, return_type))
        }
        (AggregateFunction::Sum, true) => {
            return Err(DataFusionError::NotImplemented(
                "SUM(DISTINCT) aggregations are not available".to_string(),
            ));
        }
        // DISTINCT does not change the result of MIN/MAX, so the flag is ignored.
        (AggregateFunction::Min, _) => {
            Arc::new(expressions::Min::new(arg, name, return_type))
        }
        (AggregateFunction::Max, _) => {
            Arc::new(expressions::Max::new(arg, name, return_type))
        }
        (AggregateFunction::Avg, false) => {
            Arc::new(expressions::Avg::new(arg, name, return_type))
        }
        (AggregateFunction::Avg, true) => {
            return Err(DataFusionError::NotImplemented(
                "AVG(DISTINCT) aggregations are not available".to_string(),
            ));
        }
    })
}
/// Numeric argument types accepted by `SUM` and `AVG`, and (after the string
/// types) by `MIN` and `MAX`; see `signature`.
// NOTE(review): if type coercion tries candidate types in list order, the order of
// this list is significant — confirm in `type_coercion::data_types` before reordering.
static NUMERICS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
    DataType::Float32,
    DataType::Float64,
];
/// the signatures supported by the function `fun`.
fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match fun {
AggregateFunction::Count => Signature::Any(1),
AggregateFunction::Min | AggregateFunction::Max => {
let mut valid = vec![DataType::Utf8, DataType::LargeUtf8];
valid.extend_from_slice(NUMERICS);
Signature::Uniform(1, valid)
}
AggregateFunction::Avg | AggregateFunction::Sum => {
Signature::Uniform(1, NUMERICS.to_vec())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    #[test]
    fn test_min_max() -> Result<()> {
        let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
        assert_eq!(DataType::Utf8, observed);

        let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
        assert_eq!(DataType::Int32, observed);
        Ok(())
    }

    #[test]
    fn test_sum_no_utf8() {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    #[test]
    fn test_sum_upcasts() -> Result<()> {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
        assert_eq!(DataType::UInt64, observed);

        let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?;
        assert_eq!(DataType::Float64, observed);

        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
        assert_eq!(DataType::Float64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_no_utf8() {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    /// Parsing ignores the case of the function name.
    #[test]
    fn test_from_str_is_case_insensitive() -> Result<()> {
        assert_eq!(AggregateFunction::Min, "min".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Count, "CoUnT".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Avg, "AVG".parse::<AggregateFunction>()?);
        Ok(())
    }

    /// Unknown names are rejected instead of silently mapping to some aggregate.
    #[test]
    fn test_from_str_unknown_name() {
        assert!("median".parse::<AggregateFunction>().is_err());
        assert!("".parse::<AggregateFunction>().is_err());
    }

    /// `Display` emits an upper-case name that `FromStr` parses back to the same variant.
    #[test]
    fn test_display_round_trips_through_from_str() -> Result<()> {
        let all = [
            AggregateFunction::Count,
            AggregateFunction::Sum,
            AggregateFunction::Min,
            AggregateFunction::Max,
            AggregateFunction::Avg,
        ];
        for fun in all.iter() {
            let name = fun.to_string();
            assert_eq!(name, name.to_uppercase());
            assert_eq!(*fun, name.parse::<AggregateFunction>()?);
        }
        Ok(())
    }
}