blob: ae920cb870f78f872262c7cfd213b29d5fbfcd7a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Type coercion rules for functions with multiple valid signatures
//!
//! Coercion is performed automatically by DataFusion when the types
//! of arguments passed to a function do not exactly match the types
//! required by that function. In this case, DataFusion will attempt to
//! *coerce* the arguments to types accepted by the function by
//! inserting CAST operations.
//!
//! CAST operations added by coercion are intended to be lossless and
//! not discard information. For example, coercion from i32 -> i64 might
//! be performed because all valid i32 values can be represented using an
//! i64. However, i64 -> i32 is never performed, as there are i64 values
//! which cannot be represented by i32 values. Note that some permitted
//! integer -> floating point coercions (for example i64 -> f64) can lose
//! precision for values of large magnitude.
use std::{sync::Arc, vec};
use arrow::datatypes::{DataType, Schema, TimeUnit};
use super::{functions::Signature, PhysicalExpr};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::cast;
/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
/// The current type of every expression is resolved against `schema`,
/// the target types are chosen by [`data_types`], and each expression is
/// then passed through `cast` with its target type.
///
/// # Errors
///
/// Returns an error if the type of any expression cannot be resolved
/// against `schema`, if no type list accepted by `signature` is
/// compatible with the argument types, or if `cast` fails.
///
/// See the module level documentation for more detail on coercion.
pub fn coerce(
    expressions: &[Arc<dyn PhysicalExpr>],
    schema: &Schema,
    signature: &Signature,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    let current_types = expressions
        .iter()
        .map(|e| e.data_type(schema))
        .collect::<Result<Vec<_>>>()?;

    let new_types = data_types(&current_types, signature)?;

    // `data_types` yields exactly one target type per argument, so the
    // expressions and their target types can be paired up directly.
    expressions
        .iter()
        .zip(new_types)
        .map(|(expr, new_type)| cast(Arc::clone(expr), schema, new_type))
        .collect()
}
/// Returns the data types that each argument must be coerced to in
/// order to match `signature`.
///
/// If `current_types` exactly equals one of the type lists accepted by
/// `signature`, it is returned unchanged (no casts are needed).
/// Otherwise the accepted type lists are tried in order, and the first
/// one that every argument can be coerced into is returned.
///
/// # Errors
///
/// Returns [`DataFusionError::Plan`] if no accepted type list is
/// compatible with `current_types`, or if `signature` rejects the number
/// of arguments supplied.
///
/// See the module level documentation for more detail on coercion.
pub fn data_types(
    current_types: &[DataType],
    signature: &Signature,
) -> Result<Vec<DataType>> {
    let valid_types = get_valid_types(signature, current_types)?;

    // An exact match requires no casts, so it is preferred over any
    // coercion, even one offered by an earlier candidate.
    if valid_types.iter().any(|candidate| candidate == current_types) {
        return Ok(current_types.to_vec());
    }

    valid_types
        .iter()
        .find_map(|candidate| maybe_data_types(candidate, current_types))
        .ok_or_else(|| {
            // none possible -> Error
            DataFusionError::Plan(format!(
                "Coercion from {:?} to the signature {:?} failed.",
                current_types, signature
            ))
        })
}
/// Expands `signature` into the candidate argument type lists it accepts
/// for a call whose arguments currently have `current_types`.
///
/// Every inner `Vec` is one complete candidate list of argument types.
/// Candidates appear in the order the signature declares them; for
/// `OneOf`, the candidates of each sub-signature are concatenated in
/// order.
///
/// # Errors
///
/// Returns [`DataFusionError::Plan`] when an `Any(n)` signature (possibly
/// nested in `OneOf`) receives a number of arguments other than `n`.
fn get_valid_types(
    signature: &Signature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    let arg_count = current_types.len();

    let candidates: Vec<Vec<DataType>> = match signature {
        // one candidate per allowed type, repeated for every actual argument
        Signature::Variadic(allowed) => allowed
            .iter()
            .map(|t| vec![t.clone(); arg_count])
            .collect(),
        // one candidate per allowed type, repeated a fixed number of times
        Signature::Uniform(count, allowed) => allowed
            .iter()
            .map(|t| vec![t.clone(); *count])
            .collect(),
        // a single candidate: every argument takes the first argument's type
        Signature::VariadicEqual => {
            let first = current_types.first().cloned();
            vec![first.map_or_else(Vec::new, |t| vec![t; arg_count])]
        }
        Signature::Exact(exact) => vec![exact.clone()],
        // any types are fine, as long as the arity is right
        Signature::Any(expected) => {
            if arg_count != *expected {
                return Err(DataFusionError::Plan(format!(
                    "The function expected {} arguments but received {}",
                    expected, arg_count
                )));
            }
            vec![current_types.to_vec()]
        }
        // union of all sub-signatures; the first error aborts the expansion
        Signature::OneOf(options) => options
            .iter()
            .map(|option| get_valid_types(option, current_types))
            .collect::<Result<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect(),
    };

    Ok(candidates)
}
/// Tries to coerce `current_types` into `valid_types`, argument by
/// argument.
///
/// Returns `Some` when both lists have the same length and every current
/// type either equals the corresponding valid type or can be coerced into
/// it (see [`can_coerce_from`]). The returned list is then equal to
/// `valid_types`. Returns `None` otherwise.
fn maybe_data_types(
    valid_types: &[DataType],
    current_types: &[DataType],
) -> Option<Vec<DataType>> {
    if valid_types.len() != current_types.len() {
        return None;
    }

    valid_types
        .iter()
        .zip(current_types)
        .map(|(valid_type, current_type)| {
            // Identical types need no cast; otherwise the argument must be
            // coercible. Either way the argument ends up as `valid_type`.
            if current_type == valid_type || can_coerce_from(valid_type, current_type) {
                Some(valid_type.clone())
            } else {
                // not possible
                None
            }
        })
        // collecting into `Option` stops at the first `None`
        .collect()
}
/// Returns true if a value of type `type_from` can be coerced into a
/// value of type `type_into`.
///
/// Integer widenings accepted here (for example `Int8 -> Int32` or
/// `UInt16 -> Int32`) are lossless. Not every accepted coercion is,
/// though:
///
/// * integer -> floating point coercions such as `Int32 -> Float32` or
///   `Int64 -> Float64` can lose precision for values of large magnitude;
/// * every type is accepted as coercible into `Utf8`;
/// * any `Timestamp` without a timezone is accepted as coercible into
///   `Timestamp(Nanosecond, None)`, whatever its unit.
///
/// Only the target types listed in the `match` below accept coercions.
/// Every other `type_into` returns `false`, including when it equals
/// `type_from`. Callers are expected to check for equal types first.
///
/// See the module level documentation for more detail on coercion.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
    use self::DataType::*;
    match type_into {
        Int8 => matches!(type_from, Int8),
        Int16 => matches!(type_from, Int8 | Int16 | UInt8),
        Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16),
        Int64 => matches!(
            type_from,
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
        ),
        UInt8 => matches!(type_from, UInt8),
        UInt16 => matches!(type_from, UInt8 | UInt16),
        UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32),
        UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64),
        // NOTE: Int32/Int64/UInt32/UInt64 -> Float32 may lose precision
        Float32 => matches!(
            type_from,
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32
        ),
        // NOTE: Int64/UInt64 -> Float64 may lose precision
        Float64 => matches!(
            type_from,
            Int8 | Int16
                | Int32
                | Int64
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Float32
                | Float64
        ),
        Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
        Utf8 => true,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physical_plan::expressions::col;
    use arrow::datatypes::{DataType, Field, Schema};

    #[test]
    fn test_maybe_data_types() {
        // each case is: (valid_types, current_types, expected result)
        let cases = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // lists of different lengths never match
            (
                vec![DataType::UInt32],
                vec![DataType::UInt32, DataType::UInt32],
                None,
            ),
        ];
        for case in cases {
            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
        }
    }

    #[test]
    fn test_can_coerce_from() {
        // lossless integer widening
        assert!(can_coerce_from(&DataType::Int64, &DataType::UInt32));
        assert!(can_coerce_from(&DataType::Int16, &DataType::UInt8));
        // not every u64 value fits in an i64
        assert!(!can_coerce_from(&DataType::Int64, &DataType::UInt64));
        // signed values never coerce into unsigned types
        assert!(!can_coerce_from(&DataType::UInt8, &DataType::Int8));
        // floats only widen
        assert!(can_coerce_from(&DataType::Float64, &DataType::Float32));
        assert!(!can_coerce_from(&DataType::Float32, &DataType::Float64));
        // anything coerces into Utf8
        assert!(can_coerce_from(&DataType::Utf8, &DataType::Boolean));
        // timestamps without a timezone coerce to nanosecond precision
        assert!(can_coerce_from(
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
            &DataType::Timestamp(TimeUnit::Second, None),
        ));
        // types without a coercion rule reject even themselves
        assert!(!can_coerce_from(&DataType::Boolean, &DataType::Boolean));
    }

    #[test]
    fn test_coerce() -> Result<()> {
        // create a schema
        let schema = |t: Vec<DataType>| {
            Schema::new(
                t.iter()
                    .enumerate()
                    .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(), true))
                    .collect(),
            )
        };

        // create a vector of expressions
        let expressions = |t: Vec<DataType>, schema| -> Result<Vec<_>> {
            t.iter()
                .enumerate()
                .map(|(i, t)| cast(col(&format!("c{}", i)), &schema, t.clone()))
                .collect::<Result<Vec<_>>>()
        };

        // create a case: input + expected result
        let case =
            |observed: Vec<DataType>, valid, expected: Vec<DataType>| -> Result<_> {
                let schema = schema(observed.clone());
                let expr = expressions(observed, schema.clone())?;
                let expected = expressions(expected, schema.clone())?;
                Ok((expr, schema, valid, expected))
            };

        let cases = vec![
            // u16 -> u32
            case(
                vec![DataType::UInt16],
                Signature::Uniform(1, vec![DataType::UInt32]),
                vec![DataType::UInt32],
            )?,
            // same type
            case(
                vec![DataType::UInt32, DataType::UInt32],
                Signature::Uniform(2, vec![DataType::UInt32]),
                vec![DataType::UInt32, DataType::UInt32],
            )?,
            // u32 -> f32: the first compatible candidate wins
            case(
                vec![DataType::UInt32],
                Signature::Uniform(1, vec![DataType::Float32, DataType::Float64]),
                vec![DataType::Float32],
            )?,
            // u32 -> f32
            case(
                vec![DataType::UInt32, DataType::UInt32],
                Signature::Variadic(vec![DataType::Float32]),
                vec![DataType::Float32, DataType::Float32],
            )?,
            // u32 -> f32 (the type of the first argument)
            case(
                vec![DataType::Float32, DataType::UInt32],
                Signature::VariadicEqual,
                vec![DataType::Float32, DataType::Float32],
            )?,
            // common type is u64
            case(
                vec![DataType::UInt32, DataType::UInt64],
                Signature::Variadic(vec![DataType::UInt32, DataType::UInt64]),
                vec![DataType::UInt64, DataType::UInt64],
            )?,
            // f32 -> f32
            case(
                vec![DataType::Float32],
                Signature::Any(1),
                vec![DataType::Float32],
            )?,
        ];

        for case in cases {
            let observed = format!("{:?}", coerce(&case.0, &case.1, &case.2)?);
            let expected = format!("{:?}", case.3);
            assert_eq!(observed, expected);
        }

        // now cases that are expected to fail
        let cases = vec![
            // we do not know how to cast bool to UInt16 => fail
            case(
                vec![DataType::Boolean],
                Signature::Uniform(1, vec![DataType::UInt16]),
                vec![],
            )?,
            // u32 and bool are not uniform
            case(
                vec![DataType::UInt32, DataType::Boolean],
                Signature::VariadicEqual,
                vec![],
            )?,
            // bool is not castable to u32
            case(
                vec![DataType::Boolean, DataType::Boolean],
                Signature::Variadic(vec![DataType::UInt32]),
                vec![],
            )?,
            // expected two arguments
            case(vec![DataType::UInt32], Signature::Any(2), vec![])?,
        ];

        for case in cases {
            if coerce(&case.0, &case.1, &case.2).is_ok() {
                return Err(DataFusionError::Plan(format!(
                    "Error was expected in {:?}",
                    case
                )));
            }
        }

        Ok(())
    }
}