blob: b19299981cef39f1cebe5534619fe0c4e25c8d86 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Expression utilities
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;
use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
use crate::expr_rewriter::strip_outer_reference;
use crate::{
BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, and,
};
use datafusion_expr_common::signature::{Signature, TypeSignature};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err,
plan_err,
};
#[cfg(not(feature = "sql"))]
use crate::expr::{ExceptSelectItem, ExcludeSelectItem};
use indexmap::IndexSet;
#[cfg(feature = "sql")]
use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
/// The value to which `COUNT(*)` is expanded in
/// `COUNT(<constant>)` expressions
pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
/// Count the number of distinct expressions in a list of GROUP BY expressions.
///
/// If the first element is an [`Expr::GroupingSet`], it must be the only
/// expression. In that case the count is the number of distinct expressions
/// across the grouping set, plus one for the extra grouping id column.
///
/// # Errors
/// Returns a planning error if a `GroupingSet` is combined with other
/// expressions.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
    let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() else {
        return grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len());
    };
    if group_expr.len() > 1 {
        return plan_err!(
            "Invalid group by expressions, GroupingSet must be the only expression"
        );
    }
    // One column per distinct grouping expression, plus the grouping id column.
    let distinct_columns = grouping_set.distinct_expr().len();
    Ok(distinct_columns + 1)
}
/// Internal helper that yields, for every subset of `0..len`, the indices of
/// the elements belonging to that subset.
///
/// Subsets are enumerated by counting a `u64` bitmask from `0` up to
/// `2^len - 1`. Bit `i` of the mask selects index `i`, and each yielded vector
/// lists the selected indices in ascending order.
///
/// Callers must ensure `len < 64` so that the mask fits in a `u64`.
fn powerset_indices(len: usize) -> impl Iterator<Item = Vec<usize>> {
    let subset_count: u64 = 1 << len;
    (0..subset_count).map(|mask: u64| {
        // Walk the set bits from lowest to highest by repeatedly clearing the
        // lowest set bit; each step's trailing-zero count is one index.
        std::iter::successors((mask != 0).then_some(mask), |&bits| {
            let rest = bits & (bits - 1);
            (rest != 0).then_some(rest)
        })
        .map(|bits| bits.trailing_zeros() as usize)
        .collect()
    })
}
/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
/// including the empty set and S itself.
///
/// Example:
///
/// If S is the set {x, y, z}, then all the subsets of S are \
/// {} \
/// {x} \
/// {y} \
/// {z} \
/// {x, y} \
/// {x, z} \
/// {y, z} \
/// {x, y, z} \
/// and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
///
/// The subsets are returned in bitmask order, i.e. for `[x, y, z]`:
/// `[], [x], [y], [x, y], [z], [x, z], [y, z], [x, y, z]`.
///
/// # Errors
/// Returns a planning error if `slice` has 64 or more elements.
///
/// [power set]: https://en.wikipedia.org/wiki/Power_set
pub fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>> {
    // Subsets are enumerated through a `u64` bitmask, so at most 63 elements fit.
    if slice.len() < 64 {
        let subsets = powerset_indices(slice.len())
            .map(|indices| {
                indices
                    .into_iter()
                    .map(|idx| &slice[idx])
                    .collect::<Vec<&T>>()
            })
            .collect();
        Ok(subsets)
    } else {
        plan_err!("The size of the set must be less than 64")
    }
}
/// Check that a single grouping set holds at most 65535 group expressions.
///
/// # Errors
/// Returns a planning error when `size` exceeds the limit.
fn check_grouping_set_size_limit(size: usize) -> Result<()> {
    let max_grouping_set_size: usize = 65535;
    if size <= max_grouping_set_size {
        return Ok(());
    }
    plan_err!(
        "The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"
    )
}
/// Check that a `GROUPING SETS` list holds at most 4096 grouping sets.
///
/// # Errors
/// Returns a planning error when `size` exceeds the limit.
fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
    let max_grouping_sets_size: usize = 4096;
    if size <= max_grouping_sets_size {
        return Ok(());
    }
    plan_err!(
        "The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"
    )
}
/// Concatenate two grouping sets into one, `left` first.
///
/// # Example
/// ```text
/// (A, B), (C, D) -> (A, B, C, D)
/// ```
///
/// # Error
/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
///
/// [`DataFusionError`]: datafusion_common::DataFusionError
fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
    let total = left.len() + right.len();
    check_grouping_set_size_limit(total)?;
    let mut merged = Vec::with_capacity(total);
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    Ok(merged)
}
/// Compute the cross product of two lists of grouping sets.
///
/// Every set in `left` is merged with every set in `right`, in row-major
/// order (all of `left[0]`'s combinations first).
///
/// # Example
/// ```text
/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
/// ```
///
/// # Error
/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
///
/// [`DataFusionError`]: datafusion_common::DataFusionError
fn cross_join_grouping_sets<T: Clone>(
    left: &[Vec<T>],
    right: &[Vec<T>],
) -> Result<Vec<Vec<T>>> {
    check_grouping_sets_size_limit(left.len() * right.len())?;
    // Collecting into `Result` stops at the first merge that fails.
    left.iter()
        .flat_map(|l| right.iter().map(move |r| merge_grouping_set(l, r)))
        .collect()
}
/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`],\
/// if the grouping expression does not contain [`Expr::GroupingSet`] or only has one expression,\
/// no conversion will be performed.
///
/// e.g.
///
/// person.id,\
/// GROUPING SETS ((person.age, person.salary),(person.age)),\
/// ROLLUP(person.state, person.birth_date)
///
/// =>
///
/// GROUPING SETS (\
/// (person.id, person.age, person.salary),\
/// (person.id, person.age, person.salary, person.state),\
/// (person.id, person.age, person.salary, person.state, person.birth_date),\
/// (person.id, person.age),\
/// (person.id, person.age, person.state),\
/// (person.id, person.age, person.state, person.birth_date)\
/// )
///
/// # Errors
/// Returns a planning error when:
/// - a `GROUPING SETS`, `CUBE` or `ROLLUP` expands to more grouping sets than
///   the grouping-sets limit allows, or the cross product of them does;
/// - a merged grouping set exceeds the per-set expression limit;
/// - a `CUBE` has 64 or more expressions.
pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
    let has_grouping_set = group_expr
        .iter()
        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
    if !has_grouping_set || group_expr.len() == 1 {
        return Ok(group_expr);
    }
    // Only process mix grouping sets
    let partial_sets = group_expr
        .iter()
        .map(|expr| {
            let exprs = match expr {
                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
                    check_grouping_sets_size_limit(grouping_sets.len())?;
                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
                }
                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
                    // A CUBE over n expressions expands to 2^n grouping sets.
                    // Check that count *before* materializing the power set so
                    // an oversized CUBE fails fast instead of allocating an
                    // exponential number of vectors first. Inputs with 64 or
                    // more expressions are left to `powerset`, which reports
                    // its own error for them.
                    let n = group_exprs.len();
                    if n < 64 {
                        let num_sets = usize::try_from(1u64 << n).unwrap_or(usize::MAX);
                        check_grouping_sets_size_limit(num_sets)?;
                    }
                    powerset(group_exprs)?
                }
                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
                    let size = group_exprs.len();
                    let slice = group_exprs.as_slice();
                    // NOTE(review): ROLLUP yields `size + 1` sets (all prefixes),
                    // but the limit is checked against `size * (size + 1) / 2 + 1`,
                    // which is stricter. Kept as-is to preserve existing limits.
                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
                    (0..(size + 1))
                        .map(|i| slice[0..i].iter().collect())
                        .collect()
                }
                expr => vec![vec![expr]],
            };
            Ok(exprs)
        })
        .collect::<Result<Vec<_>>>()?;
    // Cross Join
    let grouping_sets = partial_sets
        .into_iter()
        .map(Ok)
        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
        .transpose()?
        .map(|e| {
            e.into_iter()
                .map(|e| e.into_iter().cloned().collect())
                .collect()
        })
        .unwrap_or_default();
    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
        grouping_sets,
    ))])
}
/// Find all distinct expressions in a list of GROUP BY expressions.
///
/// If the first element is an [`Expr::GroupingSet`], it must be the only
/// expression, and its distinct expressions are returned. Otherwise the
/// expressions are deduplicated, keeping the first occurrence of each in its
/// original position.
///
/// # Errors
/// Returns a planning error if a `GroupingSet` is combined with other
/// expressions.
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
    match group_expr.first() {
        Some(Expr::GroupingSet(grouping_set)) => {
            if group_expr.len() > 1 {
                return plan_err!(
                    "Invalid group by expressions, GroupingSet must be the only expression"
                );
            }
            Ok(grouping_set.distinct_expr())
        }
        _ => {
            // `IndexSet` drops duplicates while preserving insertion order.
            let unique: IndexSet<&Expr> = group_expr.iter().collect();
            Ok(unique.into_iter().collect())
        }
    }
}
/// Recursively walk an expression tree, collecting the unique set of columns
/// referenced in the expression
///
/// Every [`Expr::Column`] found anywhere in `expr`, including nested
/// sub-expressions, is cloned into `accum`. Entries already in `accum` are
/// kept, so the function can be called repeatedly to gather columns across
/// several expressions.
///
/// Note: [`Expr::OuterReferenceColumn`] nodes are matched but not inserted.
///
/// # Errors
/// Propagates any error from the tree traversal. The visitor closure itself
/// always returns `Ok`.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
expr.apply(|expr| {
match expr {
Expr::Column(qc) => {
accum.insert(qc.clone());
}
// Use explicit pattern match instead of a default
// implementation, so that in the future if someone adds
// new Expr types, they will check here as well
// TODO: remove the next line after `Expr::Wildcard` is removed
#[expect(deprecated)]
Expr::Unnest(_)
| Expr::ScalarVariable(_, _)
| Expr::Alias(_)
| Expr::Literal(_, _)
| Expr::BinaryExpr { .. }
| Expr::Like { .. }
| Expr::SimilarTo { .. }
| Expr::Not(_)
| Expr::IsNotNull(_)
| Expr::IsNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Negative(_)
| Expr::Between { .. }
| Expr::Case { .. }
| Expr::Cast { .. }
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
| Expr::InList { .. }
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::SetComparison(_)
| Expr::ScalarSubquery(_)
| Expr::Wildcard { .. }
| Expr::Placeholder(_)
| Expr::OuterReferenceColumn { .. } => {}
}
// Always keep descending so nested columns are found too.
Ok(TreeNodeRecursion::Continue)
})
.map(|_| ())
}
/// Find excluded columns in the schema, if any.
///
/// `SELECT * EXCLUDE(col1, col2)` would return `vec![col1, col2]`.
///
/// Identifiers are gathered from `EXCEPT` (first element, then the additional
/// elements) followed by `EXCLUDE`, checked for duplicates, and resolved
/// against `schema` (optionally restricted to `qualifier`). Columns are
/// resolved and returned in the order the identifiers were written. This makes
/// the result, and the error reported for an unknown identifier, deterministic.
///
/// # Errors
/// - a planning error if the same identifier appears more than once
/// - any error returned by `DFSchema::qualified_field_with_name` for an
///   identifier that does not resolve in `schema`
fn get_excluded_columns(
    opt_exclude: Option<&ExcludeSelectItem>,
    opt_except: Option<&ExceptSelectItem>,
    schema: &DFSchema,
    qualifier: Option<&TableReference>,
) -> Result<Vec<Column>> {
    let mut idents = vec![];
    if let Some(excepts) = opt_except {
        idents.push(&excepts.first_element);
        idents.extend(&excepts.additional_elements);
    }
    if let Some(exclude) = opt_exclude {
        match exclude {
            ExcludeSelectItem::Single(ident) => idents.push(ident),
            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
        }
    }
    // Excluded columns should be unique: `HashSet::insert` returns `false`
    // for an identifier that has already been seen.
    let mut seen = HashSet::with_capacity(idents.len());
    if !idents.iter().all(|ident| seen.insert(*ident)) {
        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
    }
    // Resolve in declaration order. Iterating the `HashSet` instead would make
    // the output order and the first reported error depend on hash order.
    idents
        .into_iter()
        .map(|ident| {
            let (field_qualifier, field) =
                schema.qualified_field_with_name(qualifier, ident.value.as_str())?;
            Ok(Column::from((field_qualifier, field)))
        })
        .collect()
}
/// Returns a column expression for every field of `schema`, leaving out any
/// column contained in `columns_to_skip`. Schema order is preserved.
fn get_exprs_except_skipped(
    schema: &DFSchema,
    columns_to_skip: &HashSet<Column>,
) -> Vec<Expr> {
    if !columns_to_skip.is_empty() {
        return schema
            .columns()
            .into_iter()
            .filter(|column| !columns_to_skip.contains(column))
            .map(Expr::Column)
            .collect();
    }
    // Nothing to skip: build the column expressions straight from the fields.
    schema.iter().map(Expr::from).collect()
}
/// For each column named in a `USING` join condition, the join plan outputs it
/// twice (once per join side), but an unqualified wildcard should include it
/// only once.
///
/// Returns the duplicate columns to exclude. Within each `USING` column group,
/// the columns are sorted, the first column of each name is kept, and every
/// later column with the same name is returned.
fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
    let mut excluded = HashSet::new();
    for using_set in plan.using_columns()? {
        // Sort so the same qualified column is kept consistently.
        let mut columns: Vec<Column> = using_set.into_iter().collect();
        columns.sort();
        let mut kept_names: HashSet<String> = HashSet::new();
        for column in columns {
            if kept_names.contains(&column.name) {
                excluded.insert(column);
            } else {
                kept_names.insert(column.name);
            }
        }
    }
    Ok(excluded)
}
/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`s.
///
/// Duplicate `USING` join columns are dropped first. Then any columns named in
/// the wildcard's `EXCLUDE` / `EXCEPT` options are removed.
pub fn expand_wildcard(
    schema: &DFSchema,
    plan: &LogicalPlan,
    wildcard_options: Option<&WildcardOptions>,
) -> Result<Vec<Expr>> {
    let mut columns_to_skip = exclude_using_columns(plan)?;
    if let Some(options) = wildcard_options {
        let excluded = get_excluded_columns(
            options.exclude.as_ref(),
            options.except.as_ref(),
            schema,
            None,
        )?;
        columns_to_skip.extend(excluded);
    }
    Ok(get_exprs_except_skipped(schema, &columns_to_skip))
}
/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
pub fn expand_qualified_wildcard(
qualifier: &TableReference,
schema: &DFSchema,
wildcard_options: Option<&WildcardOptions>,
) -> Result<Vec<Expr>> {
let qualified_indices = schema.fields_indices_with_qualified(qualifier);
let projected_func_dependencies = schema
.functional_dependencies()
.project_functional_dependencies(&qualified_indices, qualified_indices.len());
let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
if fields_with_qualified.is_empty() {
return plan_err!("Invalid qualifier {qualifier}");
}
let qualified_schema = Arc::new(Schema::new_with_metadata(
fields_with_qualified,
schema.metadata().clone(),
));
let qualified_dfschema =
DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
.with_functional_dependencies(projected_func_dependencies)?;
let excluded_columns = if let Some(WildcardOptions {
exclude: opt_exclude,
except: opt_except,
..
}) = wildcard_options
{
get_excluded_columns(
opt_exclude.as_ref(),
opt_except.as_ref(),
schema,
Some(qualifier),
)?
} else {
vec![]
};
// Add each excluded `Column` to columns_to_skip
let mut columns_to_skip = HashSet::new();
columns_to_skip.extend(excluded_columns);
Ok(get_exprs_except_skipped(
&qualified_dfschema,
&columns_to_skip,
))
}
/// The ordered sort keys for a window, each paired with a flag recording where
/// the key came from.
///
/// The `bool` is `true` when the [`Sort`] was produced from a `PARTITION BY`
/// expression, and `false` when it comes from an `ORDER BY` expression.
type WindowSortKey = Vec<(Sort, bool)>;
/// Generate a sort key for a window expression from its `PARTITION BY` and
/// `ORDER BY` expressions.
///
/// `PARTITION BY` expressions come first, each flagged `true`. A partition
/// expression that also appears in `ORDER BY` reuses that `ORDER BY` key,
/// keeping its direction. Otherwise it sorts `ASC NULLS LAST`. The remaining
/// `ORDER BY` keys follow, flagged `false`. A key that is already present is
/// never added twice.
pub fn generate_sort_key(
    partition_by: &[Expr],
    order_by: &[Sort],
) -> Result<WindowSortKey> {
    // ORDER BY keys rewritten to ASC NULLS LAST, so that a PARTITION BY
    // expression can be matched against them whatever direction was written.
    let order_by_defaults: Vec<Sort> = order_by
        .iter()
        .map(|sort| Sort::new(sort.expr.clone(), true, false))
        .collect();

    let mut sort_key: WindowSortKey = Vec::new();
    let already_present = |keys: &WindowSortKey, candidate: &Sort| {
        keys.iter().any(|(existing, _)| existing == candidate)
    };

    for partition_expr in partition_by {
        // By default, create sort key with ASC is true and NULLS LAST to be consistent with
        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
        let default_sort = partition_expr.clone().sort(true, false);
        let candidate = match order_by_defaults
            .iter()
            .position(|key| *key == default_sort)
        {
            Some(pos) => order_by[pos].clone(),
            None => default_sort,
        };
        if !already_present(&sort_key, &candidate) {
            sort_key.push((candidate, true));
        }
    }

    for order_expr in order_by {
        if !already_present(&sort_key, order_expr) {
            sort_key.push((order_expr.clone(), false));
        }
    }
    Ok(sort_key)
}
/// Compare two sort expressions in the style of PostgreSQL's
/// `common_prefix_cmp()`:
/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
///
/// The stages are applied in order, stopping at the first difference:
/// 1. the column indexes referenced by each expression (resolved in
///    `schema`, with literals recorded as `usize::MAX`) are compared pairwise
///    over their common prefix;
/// 2. if that prefix is equal, the expression referencing *more* indexes
///    orders first;
/// 3. a descending sort orders before an ascending one;
/// 4. a `NULLS FIRST` sort orders before a `NULLS LAST` one.
pub fn compare_sort_expr(
    sort_expr_a: &Sort,
    sort_expr_b: &Sort,
    schema: &DFSchemaRef,
) -> Ordering {
    let indexes_a = find_column_indexes_referenced_by_expr(&sort_expr_a.expr, schema);
    let indexes_b = find_column_indexes_referenced_by_expr(&sort_expr_b.expr, schema);

    let common_prefix = indexes_a
        .iter()
        .zip(&indexes_b)
        .map(|(a, b)| a.cmp(b))
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal);

    common_prefix
        // More referenced indexes sorts first, hence the swapped operands.
        .then_with(|| indexes_b.len().cmp(&indexes_a.len()))
        // `false < true`, so DESC (asc == false) sorts before ASC.
        .then_with(|| sort_expr_a.asc.cmp(&sort_expr_b.asc))
        // NULLS FIRST sorts before NULLS LAST, hence the swapped operands.
        .then_with(|| sort_expr_b.nulls_first.cmp(&sort_expr_a.nulls_first))
}
/// Group window expressions by the sort key produced from their
/// `PARTITION BY` / `ORDER BY` clauses.
///
/// Groups appear in order of their first member, and members keep their input
/// order within a group.
///
/// # Errors
/// Returns an internal error if any input is not an `Expr::WindowFunction`,
/// and propagates errors from [`generate_sort_key`].
pub fn group_window_expr_by_sort_keys(
    window_expr: impl IntoIterator<Item = Expr>,
) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
    let mut groups: Vec<(WindowSortKey, Vec<Expr>)> = Vec::new();
    for expr in window_expr {
        let sort_key = match &expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunctionParams {
                    partition_by,
                    order_by,
                    ..
                } = &window_fun.as_ref().params;
                generate_sort_key(partition_by, order_by)?
            }
            other => return internal_err!("Impossibly got non-window expr {other:?}"),
        };
        match groups.iter_mut().find(|(key, _)| *key == sort_key) {
            Some((_, members)) => members.push(expr),
            None => groups.push((sort_key, vec![expr])),
        }
    }
    Ok(groups)
}
/// Collect all deeply nested [`Expr::AggregateFunction`]s.
///
/// They are returned in order of occurrence (depth first), with duplicates
/// omitted. The search does not descend into an aggregate once it matches.
pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
    let is_aggregate =
        |candidate: &Expr| matches!(candidate, Expr::AggregateFunction { .. });
    find_exprs_in_exprs(exprs, &is_aggregate)
}
/// Collect all deeply nested [`Expr::WindowFunction`]s.
///
/// They are returned in order of occurrence (depth first), with duplicates
/// omitted. The search does not descend into a window function once it matches.
pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
    let is_window = |candidate: &Expr| matches!(candidate, Expr::WindowFunction { .. });
    find_exprs_in_exprs(exprs, &is_window)
}
/// Collect all deeply nested [`Expr::OuterReferenceColumn`]s in `expr`.
///
/// They are returned in order of occurrence (depth first), with duplicates
/// omitted.
pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
    let is_outer_ref =
        |candidate: &Expr| matches!(candidate, Expr::OuterReferenceColumn { .. });
    find_exprs_in_expr(expr, &is_outer_ref)
}
/// Search each provided `Expr`, and all of its nested `Expr`s, for those that
/// pass `test_fn`.
///
/// Matches are deduplicated across all inputs and returned in order of first
/// appearance (depth first).
fn find_exprs_in_exprs<'a, F>(
    exprs: impl IntoIterator<Item = &'a Expr>,
    test_fn: &F,
) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut unique: Vec<Expr> = Vec::new();
    for root in exprs {
        for found in find_exprs_in_expr(root, test_fn) {
            if !unique.contains(&found) {
                unique.push(found);
            }
        }
    }
    unique
}
/// Search an `Expr`, and all of its nested `Expr`s, for those that pass
/// `test_fn`.
///
/// Matches are deduplicated and returned in order of appearance (depth
/// first). Once a node matches, its children are not searched.
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut found: Vec<Expr> = Vec::new();
    expr.apply(|node| {
        if !test_fn(node) {
            return Ok(TreeNodeRecursion::Continue);
        }
        if !found.contains(node) {
            found.push(node.clone());
        }
        // The match is collected whole; skip its subtree.
        Ok(TreeNodeRecursion::Jump)
    })
    // The closure never returns an error, so neither can the traversal.
    .expect("no way to return error during recursion");
    found
}
/// Recursively inspect an [`Expr`] and all its children in pre-order, calling
/// `f` on each node.
///
/// Traversal stops at the first node for which `f` returns `Err`, and that
/// error is returned. The error type `E` is chosen by the caller and need not
/// be a `DataFusionError`.
pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
where
    F: FnMut(&Expr) -> Result<(), E>,
{
    let mut first_error: Option<E> = None;
    expr.apply(|node| match f(node) {
        Ok(()) => Ok(TreeNodeRecursion::Continue),
        Err(e) => {
            // Keep the caller's error and halt the traversal.
            first_error = Some(e);
            Ok(TreeNodeRecursion::Stop)
        }
    })
    // The closure always returns OK, so this will always too
    .expect("no way to return error during recursion");
    first_error.map_or(Ok(()), Err)
}
/// Create schema fields from an expression list, for use in result set schema construction
///
/// Each expression is resolved against the output schema of `plan` by calling
/// [`Expr::to_field`], which determines the field's:
/// - **Data type**: resolved from the expression type and input schema context
/// - **Nullability**: determined by expression-specific nullability rules
/// - **Metadata**: computed per expression type (preserving, merging, or generating new metadata)
/// - **Table reference scoping**: the qualifier under which the field is exposed
///
/// # Returns
///
/// One `(Option<TableReference>, Arc<Field>)` pair per input expression, in
/// input order.
///
/// # Errors
/// Stops at, and returns, the first error produced by `Expr::to_field`.
pub fn exprlist_to_fields<'a>(
    exprs: impl IntoIterator<Item = &'a Expr>,
    plan: &LogicalPlan,
) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
    let input_schema = plan.schema();
    let mut fields = Vec::new();
    for expr in exprs {
        fields.push(expr.to_field(input_schema)?);
    }
    Ok(fields)
}
/// Convert an expression into a Column expression if it is already provided
/// by the input plan.
///
/// Every sub-expression of `e` that matches one of `input`'s columnized output
/// expressions is replaced by the corresponding column, and the replaced
/// sub-tree is not visited further. If `input` provides no such expressions,
/// `e` is returned unchanged.
///
/// For example, it rewrites:
///
/// ```text
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
/// .project(vec![col("c1"), sum(col("c2"))])?
/// ```
///
/// Into:
///
/// ```text
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
/// .project(vec![col("c1"), col("SUM(c2)")])?
/// ```
pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
    let Ok(output_exprs) = input.columnized_output_exprs() else {
        return Ok(e);
    };
    if output_exprs.is_empty() {
        return Ok(e);
    }
    let column_for: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
    e.transform_down(|node: Expr| {
        let Some(column) = column_for.get(&node) else {
            return Ok(Transformed::no(node));
        };
        Ok(Transformed::new(
            Expr::Column(column.clone()),
            true,
            TreeNodeRecursion::Jump,
        ))
    })
    .data()
}
/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
/// appearance (depth first), and may contain duplicates.
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
exprs
.iter()
.flat_map(find_columns_referenced_by_expr)
.map(Expr::Column)
.collect()
}
/// Collect every [`Column`] referenced anywhere in `e`, in depth-first order,
/// including duplicates.
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
    let mut columns: Vec<Column> = Vec::new();
    e.apply(|node| {
        if let Expr::Column(column) = node {
            columns.push(column.clone());
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // As the closure always returns Ok, this "can't" error
    .expect("Unexpected error");
    columns
}
/// Convert any `Expr` to an `Expr::Column`.
///
/// A column is re-resolved against `plan`'s schema so that it carries the
/// field's qualifier. Any other expression becomes an unqualified column named
/// after the expression's schema name.
///
/// # Errors
/// Returns an error if a column cannot be found in `plan`'s schema.
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    if let Expr::Column(col) = expr {
        let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
        return Ok(Expr::from(Column::from((qualifier, field))));
    }
    let name = expr.schema_name().to_string();
    Ok(Expr::Column(Column::from_name(name)))
}
/// Recursively walk an expression tree, collecting the column indexes
/// referenced in the expression, in depth-first order.
///
/// - Each [`Expr::Column`] that resolves in `schema` contributes its index.
///   Columns that do not resolve are skipped.
/// - Each [`Expr::Literal`] contributes the sentinel `usize::MAX`, which
///   compares after any real column index.
pub(crate) fn find_column_indexes_referenced_by_expr(
    e: &Expr,
    schema: &DFSchemaRef,
) -> Vec<usize> {
    let mut indexes = vec![];
    e.apply(|expr| {
        match expr {
            Expr::Column(qc) => {
                if let Ok(idx) = schema.index_of_column(qc) {
                    indexes.push(idx);
                }
            }
            Expr::Literal(_, _) => {
                indexes.push(usize::MAX);
            }
            _ => {}
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // The closure never returns `Err`, so the traversal cannot fail.
    .expect("closure passed to apply always returns Ok");
    indexes
}
/// Can this data type be used in hash join equality conditions?
///
/// The supported types mirror what the hashing code can handle (the original
/// note refers to `equal_rows` / `create_hashes`). Extend this list when more
/// types are supported there. Nested types are hashable when all of their
/// children are.
pub fn can_hash(data_type: &DataType) -> bool {
    match data_type {
        DataType::Null
        | DataType::Boolean
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Float16
        | DataType::Float32
        | DataType::Float64
        | DataType::Decimal32(_, _)
        | DataType::Decimal64(_, _)
        | DataType::Decimal128(_, _)
        | DataType::Decimal256(_, _)
        | DataType::Timestamp(_, _)
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Utf8View
        | DataType::Binary
        | DataType::LargeBinary
        | DataType::BinaryView
        | DataType::Date32
        | DataType::Date64
        | DataType::Time32(_)
        | DataType::Time64(_)
        | DataType::Duration(_)
        | DataType::Interval(_)
        | DataType::FixedSizeBinary(_) => true,
        // Dictionaries need a valid key type and a hashable value type.
        DataType::Dictionary(key_type, value_type) => {
            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
        }
        DataType::List(field)
        | DataType::LargeList(field)
        | DataType::FixedSizeList(field, _) => can_hash(field.data_type()),
        DataType::Map(entries, _) => can_hash(entries.data_type()),
        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
        DataType::ListView(_)
        | DataType::LargeListView(_)
        | DataType::Union(_, _)
        | DataType::RunEndEncoded(_, _) => false,
    }
}
/// Check whether every column in `columns` belongs to `schema`.
///
/// Returns `Ok(true)` for an empty set.
pub fn check_all_columns_from_schema(
    columns: &HashSet<&Column>,
    schema: &DFSchema,
) -> Result<bool> {
    let all_present = columns
        .iter()
        .all(|column| schema.is_column_from_schema(column));
    Ok(all_present)
}
/// Given two sides of an equijoin predicate, return a valid join key pair
/// ordered as `(left side key, right side key)`, or `None` if there is none.
///
/// A pair is valid when:
/// 1. all columns referenced by `left_key` come from `left_schema` and all
///    columns referenced by `right_key` come from `right_schema` (returned as is), or
/// 2. the opposite holds, in which case the keys are returned swapped.
///
/// A side that references no columns (e.g. `a = 10`) never forms a key pair.
pub fn find_valid_equijoin_key_pair(
    left_key: &Expr,
    right_key: &Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<Option<(Expr, Expr)>> {
    let left_columns = left_key.column_refs();
    let right_columns = right_key.column_refs();
    // Conditions like a = 10, will be added to non-equijoin.
    if left_columns.is_empty() || right_columns.is_empty() {
        return Ok(None);
    }

    let keys_in_order = check_all_columns_from_schema(&left_columns, left_schema)?
        && check_all_columns_from_schema(&right_columns, right_schema)?;
    if keys_in_order {
        return Ok(Some((left_key.clone(), right_key.clone())));
    }

    let keys_swapped = check_all_columns_from_schema(&right_columns, left_schema)?
        && check_all_columns_from_schema(&left_columns, right_schema)?;
    Ok(keys_swapped.then(|| (right_key.clone(), left_key.clone())))
}
/// Creates a detailed error message for a function with wrong signature.
///
/// The message names the function with the argument types it was called with,
/// then lists every candidate signature on its own line, each prefixed with a
/// tab. Only the message text is returned. For example, a query like
/// `select round(3.14, 1.1);` produces a message shaped like this (the exact
/// candidate lines depend on the function's [`Signature`]):
/// ```text
/// No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts.
///     Candidate functions:
///     round(Float64, Int64)
///     round(Float32, Int64)
///     round(Float64)
///     round(Float32)
/// ```
#[expect(clippy::needless_pass_by_value)]
#[deprecated(since = "53.0.0", note = "Internal function")]
pub fn generate_signature_error_msg(
    func_name: &str,
    func_signature: Signature,
    input_expr_types: &[DataType],
) -> String {
    let candidate_signatures = func_signature
        .type_signature
        .to_string_repr_with_names(func_signature.parameter_names.as_deref())
        .iter()
        .map(|args_str| format!("\t{func_name}({args_str})"))
        .collect::<Vec<String>>()
        .join("\n");
    format!(
        "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
        func_name,
        TypeSignature::join_types(input_expr_types, ", "),
        candidate_signatures
    )
}
/// Creates a detailed error message for a function with wrong signature.
///
/// Crate-internal, non-deprecated entry point that borrows the [`Signature`]
/// and delegates to the deprecated `generate_signature_error_msg`. For
/// example, a query like `select round(3.14, 1.1);` produces a message shaped
/// like this (candidate lines are tab-prefixed, and the exact list depends on
/// the signature):
/// ```text
/// No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts.
///     Candidate functions:
///     round(Float64, Int64)
///     round(Float32, Int64)
///     round(Float64)
///     round(Float32)
/// ```
pub(crate) fn generate_signature_error_message(
    func_name: &str,
    func_signature: &Signature,
    input_expr_types: &[DataType],
) -> String {
    #[expect(deprecated)]
    generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types)
}
/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// Aliases wrapping any part of the conjunction are looked through, and the
/// parts are returned left to right.
///
/// See [`split_conjunction_owned`] for more details and an example.
pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    let parts = Vec::new();
    split_conjunction_impl(expr, parts)
}
/// Appends the `AND`-separated parts of `expr` (looking through aliases) to
/// `exprs`, left to right, and returns the extended vector.
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
    // Explicit depth-first stack. The right child is pushed before the left so
    // that the left side is fully emitted first.
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// Parts are yielded lazily, left to right, looking through aliases.
///
/// See [`split_conjunction_owned`] for more details and an example.
pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
    // Depth-first stack; right is pushed before left so left is visited first.
    let mut pending: Vec<&Expr> = vec![expr];
    std::iter::from_fn(move || loop {
        match pending.pop()? {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(inner),
            leaf => return Some(leaf),
        }
    })
}
/// Iterate parts in an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// Parts are yielded lazily, left to right, looking through aliases.
///
/// See [`split_conjunction_owned`] for more details and an example.
pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
    // Depth-first stack; right is pushed before left so left is visited first.
    let mut pending: Vec<Expr> = vec![expr];
    std::iter::from_fn(move || loop {
        match pending.pop()? {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(*right);
                pending.push(*left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(*inner),
            leaf => return Some(leaf),
        }
    })
}
/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// This is often used to "split" filter expressions such as `col1 = 5
/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
///
/// # Example
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::split_conjunction_owned;
/// // a=1 AND b=2
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
///
/// // [a=1, b=2]
/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
///
/// // use split_conjunction_owned to split them
/// assert_eq!(split_conjunction_owned(expr), split);
/// ```
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
split_binary_owned(expr, Operator::And)
}
/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
///
/// Only nodes whose operator equals `op` are split, and aliases are looked
/// through. Parts are returned left to right.
///
/// This is often used to "split" expressions such as `col1 = 5
/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
///
/// # Example
/// ```
/// # use datafusion_expr::{col, lit, Operator};
/// # use datafusion_expr::utils::split_binary_owned;
/// # use std::ops::Add;
/// // a=1 + b=2
/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
///
/// // [a=1, b=2]
/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
///
/// // use split_binary_owned to split them
/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
/// ```
pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
    let parts = Vec::new();
    split_binary_owned_impl(expr, op, parts)
}
/// Appends the operands of the `operator` tree rooted at `expr` to `exprs`
/// (left to right, aliases stripped) and returns the extended vector.
fn split_binary_owned_impl(
    expr: Expr,
    operator: Operator,
    mut exprs: Vec<Expr>,
) -> Vec<Expr> {
    // Depth-first walk with an explicit stack; pushing `right` before `left`
    // reproduces the left-then-right order of a recursive descent.
    let mut pending = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == operator => {
                pending.push(*right);
                pending.push(*left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(*inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Splits a borrowed tree of a single binary operator, such as
/// `A <OP> B <OP> C` => `[A, B, C]`, returning references into `expr`.
///
/// See [`split_binary_owned`] for more details and an example.
pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
    let parts = Vec::new();
    split_binary_impl(expr, op, parts)
}
/// Appends references to the operands of the `operator` tree rooted at `expr`
/// to `exprs` (left to right, aliases stripped) and returns the extended vector.
fn split_binary_impl<'a>(
    expr: &'a Expr,
    operator: Operator,
    mut exprs: Vec<&'a Expr>,
) -> Vec<&'a Expr> {
    // Explicit stack; `right` is pushed before `left` so the left operand is
    // visited first, matching a recursive left-then-right descent.
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == operator => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(Alias { expr: inner, .. }) => pending.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Combines an array of filter expressions into a single filter
/// expression consisting of the input filter expressions joined with
/// logical AND.
///
/// Returns None if the filters array is empty.
///
/// # Example
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::conjunction;
/// // a=1 AND b=2
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
///
/// // [a=1, b=2]
/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
///
/// // use conjunction to join them together with `AND`
/// assert_eq!(conjunction(split), Some(expr));
/// ```
pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(Expr::and)
}
/// Combines an array of filter expressions into a single filter
/// expression consisting of the input filter expressions joined with
/// logical OR.
///
/// Returns None if the filters array is empty.
///
/// # Example
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::disjunction;
/// // a=1 OR b=2
/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
///
/// // [a=1, b=2]
/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
///
/// // use disjunction to join them together with `OR`
/// assert_eq!(disjunction(split), Some(expr));
/// ```
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(Expr::or)
}
/// Returns a new [LogicalPlan] that filters the output of `plan` with a
/// [LogicalPlan::Filter] with all `predicates` ANDed.
///
/// # Example
/// Before:
/// ```text
/// plan
/// ```
///
/// After:
/// ```text
/// Filter(predicate)
/// plan
/// ```
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
// reduce filters to a single filter with an AND
let predicate = predicates
.iter()
.skip(1)
.fold(predicates[0].clone(), |acc, predicate| {
and(acc, (*predicate).to_owned())
});
Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(plan),
)?))
}
/// Partitions filter expressions into correlated (join) predicates and
/// ordinary ones.
///
/// A correlating expression references a column from the outer scope, for
/// example a binary expression with one field from the subquery and one that
/// is not in the subquery.
///
/// # Arguments
///
/// * `exprs` - List of expressions that may or may not be joins
///
/// # Return value
///
/// Tuple of (expressions containing joins, remaining non-join expressions).
/// Join expressions have their outer references stripped. A correlated
/// predicate of the form `x = x` (both sides structurally equal) is dropped
/// entirely and appears in neither list.
pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let mut joins = vec![];
    let mut others = vec![];
    for filter in exprs {
        if !filter.contains_outer() {
            others.push(filter.clone());
            continue;
        }
        // A correlated predicate whose two sides are identical carries no join
        // information, so it is skipped.
        let is_self_equality = matches!(
            filter,
            Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) if left == right
        );
        if !is_self_equality {
            joins.push(strip_outer_reference(filter.clone()));
        }
    }
    Ok((joins, others))
}
/// Returns the single element of `slice`, or a plan error if the slice does
/// not contain exactly one element.
///
/// # Arguments
///
/// * `slice` - The slice to extract from
///
/// # Errors
///
/// * empty slice: "No items found!"
/// * more than one element: "More than one item found!"
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    let Some((only, rest)) = slice.split_first() else {
        return plan_err!("No items found!");
    };
    if rest.is_empty() {
        Ok(only)
    } else {
        plan_err!("More than one item found!")
    }
}
/// Merges the schemas of several logical plan inputs into a single schema.
///
/// A single input's schema is cloned as-is. Otherwise, starting from an empty
/// schema, each input schema is folded in with [`DFSchema::merge`]; refer to
/// that documentation for details on precedence and metadata handling. An
/// empty `inputs` slice yields an empty schema.
pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
    if let [single] = inputs {
        return single.schema().as_ref().clone();
    }

    let mut merged = DFSchema::empty();
    for input in inputs {
        merged.merge(input.schema());
    }
    merged
}
/// Builds the name of an aggregate's intermediate state field as
/// `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    // Two extra bytes for the surrounding brackets.
    let mut label = String::with_capacity(name.len() + state_name.len() + 2);
    label.push_str(name);
    label.push('[');
    label.push_str(state_name);
    label.push(']');
    label
}
/// Determines the set of [`Column`]s referenced by `exprs` that are produced
/// by the subquery, i.e. those for which `subquery_schema` has a column.
///
/// The result is ordered (a [`BTreeSet`]) and deduplicated. This function
/// currently never returns an error.
pub fn collect_subquery_cols(
    exprs: &[Expr],
    subquery_schema: &DFSchema,
) -> Result<BTreeSet<Column>> {
    let cols = exprs
        .iter()
        .flat_map(|expr| expr.column_refs())
        .filter(|col| subquery_schema.has_column(col))
        .cloned()
        .collect::<BTreeSet<_>>();
    Ok(cols)
}
// Unit tests for the expression utilities defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Cast, ExprFunctionExt, WindowFunctionDefinition, col, cube,
        expr::WindowFunction,
        expr_vec_fmt, grouping_set, lit, rollup,
        test::function_stub::{max_udaf, min_udaf, sum_udaf},
    };
    use arrow::datatypes::{UnionFields, UnionMode};
    use datafusion_expr_common::signature::{TypeSignature, Volatility};

    // No window expressions in, no groups out.
    #[test]
    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
        let result = group_window_expr_by_sort_keys(vec![])?;
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
        assert_eq!(expected, result);
        Ok(())
    }

    // Window functions with no ORDER BY all share the empty sort key, so they
    // land in a single group that keeps the input order.
    #[test]
    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
        let max1 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let max2 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let min3 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
            vec![col("name")],
        ));
        let sum4 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
            vec![col("age")],
        ));
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
        let key = vec![];
        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
            vec![(key, vec![max1, max2, min3, sum4])];
        assert_eq!(expected, result);
        Ok(())
    }

    // Expressions with identical ORDER BY clauses share a group. Groups are
    // listed in order of first appearance in the input.
    #[test]
    fn test_group_window_expr_by_sort_keys() -> Result<()> {
        let age_asc = Sort::new(col("age"), true, true);
        let name_desc = Sort::new(col("name"), false, true);
        let created_at_desc = Sort::new(col("created_at"), false, true);
        let max1 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ))
        .order_by(vec![age_asc.clone(), name_desc.clone()])
        .build()
        .unwrap();
        let max2 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(max_udaf()),
            vec![col("name")],
        ));
        let min3 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(min_udaf()),
            vec![col("name")],
        ))
        .order_by(vec![age_asc.clone(), name_desc.clone()])
        .build()
        .unwrap();
        let sum4 = Expr::from(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
            vec![col("age")],
        ))
        .order_by(vec![
            name_desc.clone(),
            age_asc.clone(),
            created_at_desc.clone(),
        ])
        .build()
        .unwrap();
        // FIXME use as_ref
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
        // Each key entry pairs a sort expression with a flag; for these
        // ORDER BY-only window functions the expected flag is `false`.
        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
        let key2 = vec![];
        let key3 = vec![
            (name_desc, false),
            (age_asc, false),
            (created_at_desc, false),
        ];
        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
            (key1, vec![max1, min3]),
            (key2, vec![max2]),
            (key3, vec![sum4]),
        ];
        assert_eq!(expected, result);
        Ok(())
    }

    // PARTITION BY columns that already appear in ORDER BY must not be
    // emitted twice. The remaining partition column (`created_at`) is
    // appended as ascending with nulls last, and every entry's flag is `true`.
    // This is checked for all four asc/nulls_first combinations.
    #[test]
    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
        let asc_or_desc = [true, false];
        let nulls_first_or_last = [true, false];
        let partition_by = &[col("age"), col("name"), col("created_at")];
        for asc_ in asc_or_desc {
            for nulls_first_ in nulls_first_or_last {
                let order_by = &[
                    Sort {
                        expr: col("age"),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    },
                    Sort {
                        expr: col("name"),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    },
                ];
                let expected = vec![
                    (
                        Sort {
                            expr: col("age"),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        },
                        true,
                    ),
                    (
                        Sort {
                            expr: col("name"),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        },
                        true,
                    ),
                    (
                        Sort {
                            expr: col("created_at"),
                            asc: true,
                            nulls_first: false,
                        },
                        true,
                    ),
                ];
                let result = generate_sort_key(partition_by, order_by)?;
                assert_eq!(expected, result);
            }
        }
        Ok(())
    }

    // Covers enumeration of plain columns, CUBE, ROLLUP and GROUPING SETS,
    // alone and in combination. Combinations expand to the cross product of
    // their parts. Repeated columns (cases 7 and 8) are kept, not deduplicated.
    #[test]
    fn test_enumerate_grouping_sets() -> Result<()> {
        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
        let simple_col = col("simple_col");
        let cube = cube(multi_cols.clone());
        let rollup = rollup(multi_cols.clone());
        let grouping_set = grouping_set(vec![multi_cols]);
        // 1. col
        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[simple_col]", &result);
        // 2. cube
        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[CUBE (col1, col2, col3)]", &result);
        // 3. rollup
        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
        // 4. col + cube
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col2), \
            (simple_col, col1, col2), \
            (simple_col, col3), \
            (simple_col, col1, col3), \
            (simple_col, col2, col3), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        // 5. col + rollup
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        // 6. col + grouping_set
        let sets =
            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3))]",
            &result
        );
        // 7. col + grouping_set + rollup
        let sets = enumerate_grouping_sets(vec![
            simple_col.clone(),
            grouping_set,
            rollup.clone(),
        ])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        // 8. col + cube + rollup
        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1), \
            (simple_col, col1, col1), \
            (simple_col, col1, col1, col2), \
            (simple_col, col1, col1, col2, col3), \
            (simple_col, col2), \
            (simple_col, col2, col1), \
            (simple_col, col2, col1, col2), \
            (simple_col, col2, col1, col2, col3), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col1), \
            (simple_col, col1, col2, col1, col2), \
            (simple_col, col1, col2, col1, col2, col3), \
            (simple_col, col3), \
            (simple_col, col3, col1), \
            (simple_col, col3, col1, col2), \
            (simple_col, col3, col1, col2, col3), \
            (simple_col, col1, col3), \
            (simple_col, col1, col3, col1), \
            (simple_col, col1, col3, col1, col2), \
            (simple_col, col1, col3, col1, col2, col3), \
            (simple_col, col2, col3), \
            (simple_col, col2, col3, col1), \
            (simple_col, col2, col3, col1, col2), \
            (simple_col, col2, col3, col1, col2, col3), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        Ok(())
    }

    // A non-AND expression is returned unchanged as the only conjunct.
    #[test]
    fn test_split_conjunction() {
        let expr = col("a");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_conjunction_two() {
        let expr = col("a").eq(lit(5)).and(col("b"));
        let expr1 = col("a").eq(lit(5));
        let expr2 = col("b");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    // Aliases on conjuncts are stripped during splitting.
    #[test]
    fn test_split_conjunction_alias() {
        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
        let expr1 = col("a").eq(lit(5));
        let expr2 = col("b"); // has no alias
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    // OR is not split by a conjunction split.
    #[test]
    fn test_split_conjunction_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_binary_owned() {
        let expr = col("a");
        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
    }

    #[test]
    fn test_split_binary_owned_two() {
        assert_eq!(
            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_binary_owned_different_op() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(
            // expr is connected by OR, but pass in AND
            split_binary_owned(expr.clone(), Operator::And),
            vec![expr]
        );
    }

    #[test]
    fn test_split_conjunction_owned() {
        let expr = col("a");
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    #[test]
    fn test_split_conjunction_owned_two() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_conjunction_owned_alias() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
            vec![
                col("a").eq(lit(5)),
                // no alias on b
                col("b"),
            ]
        );
    }

    #[test]
    fn test_conjunction_empty() {
        assert_eq!(conjunction(vec![]), None);
    }

    // conjunction builds a left-associative tree; structural equality makes
    // the grouping observable.
    #[test]
    fn test_conjunction() {
        // `[A, B, C]`
        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
        // --> `(A AND B) AND C`
        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
        // which is different than `A AND (B AND C)`
        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
    }

    #[test]
    fn test_disjunction_empty() {
        assert_eq!(disjunction(vec![]), None);
    }

    // disjunction builds a left-associative tree, mirroring conjunction.
    #[test]
    fn test_disjunction() {
        // `[A, B, C]`
        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
        // --> `(A OR B) OR C`
        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
        // which is different than `A OR (B OR C)`
        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
    }

    #[test]
    fn test_split_conjunction_owned_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    // Collecting the same column twice leaves a single entry in the set.
    #[test]
    fn test_collect_expr() -> Result<()> {
        let mut accum: HashSet<Column> = HashSet::new();
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        assert_eq!(1, accum.len());
        assert!(accum.contains(&Column::from_name("a")));
        Ok(())
    }

    // Union types, and lists of unions, are reported as not hashable.
    #[test]
    fn test_can_hash() {
        let union_fields: UnionFields = [
            (0, Arc::new(Field::new("A", DataType::Int32, true))),
            (1, Arc::new(Field::new("B", DataType::Float64, true))),
        ]
        .into_iter()
        .collect();

        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
        assert!(!can_hash(&union_type));

        let list_union_type =
            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
        assert!(!can_hash(&list_union_type));
    }

    // When a signature declares parameter names, the error message lists each
    // candidate signature as `name: Type` pairs.
    #[test]
    fn test_generate_signature_error_msg_with_parameter_names() {
        let sig = Signature::one_of(
            vec![
                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
                TypeSignature::Exact(vec![
                    DataType::Utf8,
                    DataType::Int64,
                    DataType::Int64,
                ]),
            ],
            Volatility::Immutable,
        )
        .with_parameter_names(vec![
            "str".to_string(),
            "start_pos".to_string(),
            "length".to_string(),
        ])
        .expect("valid parameter names");

        // Generate error message with only 1 argument provided
        let error_msg =
            generate_signature_error_message("substr", &sig, &[DataType::Utf8]);

        assert!(
            error_msg.contains("str: Utf8, start_pos: Int64"),
            "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
        );
        assert!(
            error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
            "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
        );
    }

    // Without parameter names, the candidate signatures are rendered from the
    // type signature alone (here `Any, Any`).
    #[test]
    fn test_generate_signature_error_msg_without_parameter_names() {
        let sig = Signature::one_of(
            vec![TypeSignature::Any(2), TypeSignature::Any(3)],
            Volatility::Immutable,
        );

        let error_msg =
            generate_signature_error_message("my_func", &sig, &[DataType::Int32]);

        assert!(
            error_msg.contains("Any, Any"),
            "Expected 'Any, Any' without parameter names, got: {error_msg}"
        );
    }
}