blob: 43fb98e54545c5db5efbd0142e457b8811ed19c1 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! SQL Utility Functions
use std::vec;
use arrow::datatypes::{
DECIMAL_DEFAULT_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType,
};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{
Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue,
assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{
Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, col, expr_vec_fmt,
};
use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};
/// Make a best-effort attempt at resolving all columns in the expression tree
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up(|nested_expr| {
            if let Expr::Column(col) = nested_expr {
                // Look the column up in the plan's schema to recover its
                // qualifier, then rebuild a fully qualified column expr.
                let (qualifier, field) =
                    plan.schema().qualified_field_from_column(&col)?;
                Ok(Transformed::yes(Expr::Column(Column::from((
                    qualifier, field,
                )))))
            } else {
                // Not a column reference: leave untouched and keep recursing
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}
/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
///
/// For example, the expression `a + b < 1` would require, as input, the 2
/// individual columns, `a` and `b`. But, if the base expressions already
/// contain the `a + b` result, then that may be used in lieu of the `a` and
/// `b` columns.
///
/// This is useful in the context of a query like:
///
/// SELECT a + b < 1 ... GROUP BY a + b
///
/// where post-aggregation, `a + b` need not be a projection against the
/// individual columns `a` and `b`, but rather it is a projection against the
/// `a + b` found in the GROUP BY.
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    // Walk top-down so the widest matching sub-expression wins: once a
    // subtree matches one of `base_exprs` it is replaced wholesale by a
    // column reference to that base expression.
    expr.clone()
        .transform_down(|nested| match base_exprs.contains(&nested) {
            true => Ok(Transformed::yes(expr_as_column_expr(&nested, plan)?)),
            false => Ok(Transformed::no(nested)),
        })
        .data()
}
/// Identifies which SQL clause triggered an aggregate-validation check,
/// so error messages can name the offending clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    /// Checking expressions in the SELECT projection list
    Projection,
    /// Checking expressions in the HAVING clause
    Having,
    /// Checking expressions in the QUALIFY clause
    Qualify,
    /// Checking expressions in the ORDER BY clause
    OrderBy,
}
/// The overall validation being performed when a set of columns is checked
/// against a set of expressions (currently only aggregate validation).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    /// Columns must appear in GROUP BY or inside an aggregate function; the
    /// inner value records which clause is being validated.
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}
impl CheckColumnsSatisfyExprsPurpose {
fn message_prefix(&self) -> &'static str {
match self {
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
"Column in SELECT must be in GROUP BY or an aggregate function"
}
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
"Column in HAVING must be in GROUP BY or an aggregate function"
}
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
"Column in QUALIFY must be in GROUP BY or an aggregate function"
}
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
"Column in ORDER BY must be in GROUP BY or an aggregate function"
}
}
}
fn diagnostic_message(&self, expr: &Expr) -> String {
format!(
"'{expr}' must appear in GROUP BY clause because it's not an aggregate expression"
)
}
}
/// Determines if the set of `Expr`'s are a valid projection on the input
/// `Expr::Column`'s.
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    // Every entry in `columns` must itself be a bare column reference.
    for c in columns {
        if !matches!(c, Expr::Column(_)) {
            return internal_err!("Expr::Column are required");
        }
    }
    // Validate every column reference found inside `exprs`, descending into
    // grouping sets so each member expression is checked individually.
    for e in &find_column_exprs(exprs) {
        match e {
            Expr::GroupingSet(GroupingSet::Rollup(exprs))
            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for e in lists_of_exprs.iter().flatten() {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            _ => check_column_satisfies_expr(columns, e, purpose)?,
        }
    }
    Ok(())
}
/// Errors (with a diagnostic attached) unless `expr` is one of `columns`.
fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if columns.contains(expr) {
        return Ok(());
    }
    // Build a diagnostic pointing at the offending expression's span, with a
    // hint on how the user can fix the query.
    let diagnostic = Diagnostic::new_error(
        purpose.diagnostic_message(expr),
        expr.spans().and_then(|spans| spans.first()),
    )
    .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);
    plan_err!(
        "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
        purpose.message_prefix(),
        expr,
        expr_vec_fmt!(columns);
        diagnostic=diagnostic
    )
}
/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
/// aliasing.
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
exprs
.iter()
.filter_map(|expr| match expr {
Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
_ => None,
})
.collect::<HashMap<String, Expr>>()
}
/// Given an expression that's literal int encoding position, lookup the corresponding expression
/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal,
/// otherwise, returns planning error.
/// If input expression is not an int literal, returns expression as-is.
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    // Only Int64 literals are treated as positions; anything else passes
    // through unchanged. sql_expr_to_logical_expr maps numbers to i64:
    // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
    let position = match &expr {
        Expr::Literal(ScalarValue::Int64(Some(p)), _) => *p,
        _ => return Ok(expr),
    };
    // Positions are 1-based and must fall within the SELECT list.
    if position < 1 || position > select_exprs.len() as i64 {
        return plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position,
            select_exprs.len()
        );
    }
    // Resolve through any alias so callers see the underlying expression.
    match &select_exprs[(position - 1) as usize] {
        Expr::Alias(Alias { expr, .. }) => Ok(*expr.clone()),
        other => Ok(other.clone()),
    }
}
/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
/// alias' underlying `Expr`.
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    // Bottom-up walk: replace unqualified column references whose name
    // matches a known alias with the aliased expression itself.
    expr.transform_up(|nested| {
        let Expr::Column(c) = &nested else {
            return Ok(Transformed::no(nested));
        };
        if c.relation.is_some() {
            // Qualified columns can never refer to a SELECT-list alias.
            return Ok(Transformed::no(nested));
        }
        match aliases.get(&c.name) {
            Some(aliased) => Ok(Transformed::yes(aliased.clone())),
            None => Ok(Transformed::no(nested)),
        }
    })
    .data()
}
/// Given a slice of window expressions sharing the same sort key, find their common partition
/// keys.
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    // Extract the PARTITION BY list from a window-function expression;
    // anything that is not a window function is an execution error.
    fn partition_keys(expr: &Expr) -> Result<&Vec<Expr>> {
        match expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunction {
                    params: WindowFunctionParams { partition_by, .. },
                    ..
                } = window_fun.as_ref();
                Ok(partition_by)
            }
            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
        }
    }
    let all_partition_keys = window_exprs
        .iter()
        .map(|expr| match expr {
            // A single level of aliasing is unwrapped before extraction.
            Expr::Alias(Alias { expr, .. }) => partition_keys(expr.as_ref()),
            _ => partition_keys(expr),
        })
        .collect::<Result<Vec<_>>>()?;
    // The expression with the fewest partition keys bounds the common set.
    all_partition_keys
        .into_iter()
        .min_by_key(|keys| keys.len())
        .map(|keys| keys.as_slice())
        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))
}
/// Returns a validated `DataType` for the specified precision and
/// scale
/// Returns a validated `DataType` for the specified precision and scale.
///
/// Validation happens in the `u64` domain BEFORE narrowing to `u8`/`i8`:
/// the previous code cast first (`p as u8`, `s as i8`), so out-of-range
/// inputs like `DECIMAL(300, 2)` silently truncated (300 -> 44) and were
/// accepted, and a huge scale could wrap to a negative `i8` that passed the
/// `unsigned_abs() <= precision` check. Such inputs now produce a plan error.
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // postgres like behavior
    let (precision, scale) = match (precision, scale) {
        (Some(p), Some(s)) => (p, s),
        (Some(p), None) => (p, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type");
        }
        (None, None) => {
            // Defaults are known-valid constants; no validation needed.
            return Ok(DataType::Decimal128(
                DECIMAL128_MAX_PRECISION,
                DECIMAL_DEFAULT_SCALE,
            ));
        }
    };
    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION as u64
        || scale > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION as u64 {
        // 38 < precision <= 76 requires the wider Decimal256 representation.
        // Casts are now provably lossless: precision <= 76, scale <= precision.
        Ok(DataType::Decimal256(precision as u8, scale as i8))
    } else {
        Ok(DataType::Decimal128(precision as u8, scale as i8))
    }
}
/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
/// Normalize an owned identifier: quoted identifiers keep their exact
/// spelling, unquoted ones are lowercased.
pub(crate) fn normalize_ident(id: Ident) -> String {
    if id.quote_style.is_some() {
        id.value
    } else {
        id.value.to_ascii_lowercase()
    }
}
/// Converts a SQL literal `Value` to its string payload, or `None` for
/// literal kinds that have no canonical unquoted string form.
pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        // Variants that carry a plain String payload.
        Value::SingleQuotedString(s)
        | Value::UnicodeStringLiteral(s)
        | Value::EscapedStringLiteral(s) => Some(s.clone()),
        // Dollar-quoted strings carry a struct; use its Display impl.
        Value::DollarQuotedString(s) => Some(s.to_string()),
        // Numbers and booleans render via the Value Display impl.
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}
/// Applies [`rewrite_recursive_unnest_bottom_up`] to each expression in
/// `original_exprs`, concatenating the results (a single struct unnest may
/// expand into several output expressions). The first error short-circuits.
pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    let mut rewritten = vec![];
    for expr in original_exprs {
        rewritten.extend(rewrite_recursive_unnest_bottom_up(
            input,
            unnest_placeholder_columns,
            inner_projection_exprs,
            expr,
        )?);
    }
    Ok(rewritten)
}
/// Prefix used for the synthetic columns produced by the inner projection
/// that feeds an `Unnest` node.
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
/*
This rewriter is only useful when driven by a full down-then-up tree
traversal (`TreeNode::rewrite`): `f_up` relies on state accumulated by
`f_down` (the consecutive-unnest history and the top-most unnest marker).
*/
struct RecursiveUnnestRewriter<'a> {
    // Schema of the plan feeding the unnest; used to resolve expr types.
    input_schema: &'a DFSchemaRef,
    // The root expression currently being rewritten.
    root_expr: &'a Expr,
    // Useful to detect which child expr is a part of/ not a part of unnest operation
    top_most_unnest: Option<Unnest>,
    // Traversal history: one `Some(unnest)` per unnest visited; a `None`
    // breaks a consecutive-unnest run (see `get_latest_consecutive_unnest`).
    consecutive_unnest: Vec<Option<Unnest>>,
    // Accumulates the exprs for the inner projection beneath the Unnest node.
    inner_projection_exprs: &'a mut Vec<Expr>,
    // Maps each placeholder column to its list-unnesting specs
    // (`None` marks a struct unnest).
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    // Set only for struct unnest at root, which expands into multiple exprs.
    transformed_root_exprs: Option<Vec<Expr>>,
}
impl RecursiveUnnestRewriter<'_> {
    /// Returns the latest run of consecutive `Unnest` exprs recorded in the
    /// traversal history, where `None` entries act as run separators.
    ///
    /// For a history of
    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\]
    /// this function returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
    ///
    /// The first item will be the inner most expr
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            // skip any trailing separators pushed after the run ended
            .skip_while(|item| item.is_none())
            // then clone entries until the run is broken by a separator
            // (replaces the previous `.to_owned().cloned().map(unwrap)`
            // chain, which needlessly cloned the iterator and unwrapped)
            .map_while(|item| item.clone())
            .collect()
    }

    /// Check if the current expression is at the root level for struct unnest purposes.
    /// This is true if:
    /// 1. The expression IS the root expression, OR
    /// 2. The root expression is an Alias wrapping this expression
    ///
    /// This allows `unnest(struct_col) AS alias` to work, where the alias is simply
    /// ignored for struct unnest (matching DuckDB behavior).
    fn is_at_struct_allowed_root(&self, expr: &Expr) -> bool {
        if expr == self.root_expr {
            return true;
        }
        // Allow struct unnest when root is an alias wrapping the unnest
        if let Expr::Alias(Alias { expr: inner, .. }) = self.root_expr {
            return inner.as_ref() == expr;
        }
        false
    }

    /// Rewrites one (possibly recursive, `level` deep) unnest of
    /// `expr_in_unnest`: registers the needed inner-projection expression and
    /// placeholder column, and returns the expression(s) to be used in the
    /// outer projection.
    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();
        // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection
        // inside unnest execution, each column inside the inner projection
        // will be transformed into new columns. Thus we need to keep track of these placeholding column names
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
        // This is due to the fact that unnest transformation should keep the original
        // column name as is, to comply with group by and order by
        let placeholder_column = Column::from_name(placeholder_name.clone());
        let field = expr_in_unnest.to_field(self.input_schema)?.1;
        let data_type = field.data_type();
        match data_type {
            DataType::Struct(inner_fields) => {
                assert_or_internal_err!(
                    struct_allowed,
                    "unnest on struct can only be applied at the root level of select expression"
                );
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                // `None` marks this placeholder as a struct unnest.
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                // A struct unnest fans out into one column per struct field.
                Ok(get_struct_unnested_columns(&placeholder_name, inner_fields)
                    .into_iter()
                    .map(Expr::Column)
                    .collect())
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                // List entries are always `Some`; only structs insert `None`,
                // and a column cannot be both, so this unwrap cannot fail.
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}
impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top util the current node
    /// - If some unnest expr has been visited, maintain a stack of such information, this
    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let field = unnest_expr.expr.to_field(self.input_schema)?.1;
            let data_type = field.data_type();
            // Record this unnest in the traversal history.
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // if expr inside unnest is a struct, do not consider
            // the next unnest as consecutive unnest (if any)
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plan instead
            // a.k.a:
            // - unnest(struct_col)
            //   unnest(struct_arr_col) as struct_col
            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            // Remember the first (outermost) unnest seen on this path.
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }
            Ok(Transformed::no(expr))
        } else {
            // Non-unnest node: break any consecutive-unnest run.
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs node
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///             f_down    / /            │ │
    ///                      / / f_up        │ │
    ///                     / /     f_down   │ │ f_up
    ///                 unnest               │ │
    ///                                      │ │
    ///       f_down  / / f_up(rewriting)    │ │
    ///              / /
    ///             / /                    unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    ///   f_down / /f_up                 / /
    ///         / /                     / /
    ///        / /                    unnest
    ///     column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            // f_down always sets top_most_unnest before any unnest is
            // revisited on the way up, so the unwrap cannot fail here.
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find inside consecutive_unnest, the sequence of continuous unnest exprs
            // Get the latest consecutive unnest exprs
            // and check if current upward traversal is the returning to the root expr
            // for example given a expr `unnest(unnest(col))` then the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // the result of such traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();
            // This traversal has reached the top most unnest again
            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allow to be done recursively
                // it needs to be split into multiple unnest logical plan
                // unnest(struct_arr)
                //   unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)
                let unnest_recursion = unnest_stack.len();
                let struct_allowed =
                    self.is_at_struct_allowed_root(&expr) && unnest_recursion == 1;
                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                // Only set transformed_root_exprs for struct unnest (which returns multiple expressions).
                // For list unnest (single expression), we let the normal rewrite handle the alias.
                if struct_allowed && transformed_exprs.len() > 1 {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }
        // For column exprs that are not descendants of any unnest node
        // retain their projection
        // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b
        // this condition can be checked by maintaining an Option<top most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }
        Ok(Transformed::no(expr))
    }
}
/// Appends `expr` to `projection` unless an expression with the same schema
/// name is already present, keeping the inner projection duplicate-free.
fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let name = expr.schema_name().to_string();
    let already_present = projection
        .iter()
        .any(|existing| existing.schema_name().to_string() == name);
    if !already_present {
        projection.push(expr);
    }
}
/// The context is we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection
/// Given an expression which contains unnest expr as one of its children,
/// Try transform depends on unnest type
/// - For list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection
/// If along the path from root to bottom, there are multiple unnest expressions, the transformation
/// is done only for the bottom expression
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        input_schema: input.schema(),
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };
    // This transformation is only done for list unnest
    // struct unnest is done at the root level, and at the later stage
    // because the syntax of TreeNode only support transform into 1 Expr, while
    // Unnest struct will be transformed into multiple Exprs
    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite(&mut rewriter)?;
    if !transformed {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            // Plain columns (and wildcards) pass through the inner projection
            // unchanged; the outer projection references them directly.
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection,
            // outer projection just select its name
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        // A struct unnest at the root expands into multiple output exprs;
        // prefer those over the single rewritten expr.
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}
#[cfg(test)]
mod tests {
    use std::{ops::Add, sync::Arc};
    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
    use datafusion_common::{Column, DFSchema, Result};
    use datafusion_expr::{
        ColumnUnnestList, EmptyRelation, LogicalPlan, col, lit, unnest,
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;
    use datafusion_functions_aggregate::expr_fn::count;
    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
    use indexmap::IndexMap;

    /// Asserts that the placeholder map `r` formats to exactly the entries in
    /// `l` (order-sensitive, since `IndexMap` preserves insertion order).
    fn column_unnests_eq(
        l: Vec<&str>,
        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    ) {
        let r_formatted: Vec<String> = r
            .iter()
            .map(|i| match i.1 {
                // struct unnest: just the placeholder column name
                None => format!("{}", i.0),
                // list unnest: "placeholder=>[unnesting, ...]"
                Some(vec) => format!(
                    "{}=>[{}]",
                    i.0,
                    vec.iter()
                        .map(|i| format!("{i}"))
                        .collect::<Vec<String>>()
                        .join(", ")
                ),
            })
            .collect();
        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
        assert_eq!(l_formatted, r_formatted);
    }

    /// Recursive unnest (unnest(unnest(..))) must be collapsed into a single
    /// depth-2 unnest of the placeholder column.
    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);
        let dfschema = DFSchema::try_from(schema)?;
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];
        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(3d_col,depth=2)")
                    .alias("UNNEST(UNNEST(3d_col))")
                    .add(
                        col("__unnest_placeholder(3d_col,depth=2)")
                            .alias("UNNEST(UNNEST(3d_col))")
                    )
                    .add(col("i64_col"))
            ]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );
        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        // The same placeholder now tracks both depth-2 and depth-1 unnests.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );
        Ok(())
    }

    /// Struct unnest expands into one column per field; list unnest rewrites
    /// in place and accumulates into the shared placeholder/projection state.
    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);
        let dfschema = DFSchema::try_from(schema)?;
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];
        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );
        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(array_col,depth=1)")
                    .alias("UNNEST(array_col)")
                    .add(lit(1i64))
            ]
        );
        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );
        Ok(())
    }

    // Unnest -> field access -> unnest
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);
        let dfschema = DFSchema::try_from(schema)?;
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];
        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the inner most/ bottom most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );
        // continue rewrite another expr in select
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the inner most/ bottom most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );
        // unnest place holder columns remain the same
        // because expr1 and expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );
        Ok(())
    }

    /// Position literals resolve to SELECT-list entries (unaliased); out-of-
    /// range positions error; non-position exprs pass through unchanged.
    #[test]
    fn test_resolve_positions_to_exprs() -> Result<()> {
        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
        // Assert 1 resolved as first column in select list
        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
        assert_eq!(resolved, col("c1"));
        // Assert error if index out of select clause bounds
        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
        )));
        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
        )));
        // Assert expression returned as-is
        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
        assert_eq!(resolved, lit("text"));
        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
        assert_eq!(resolved, col("fake"));
        Ok(())
    }
}