fix(9870): common subexpression elimination optimization should always re-find the correct expression during rewrite (#9871)
* test(9870): add a reproducer for the error caused by jumping traversal patterns in common-expr-elimination
* refactor: remove the ordered IdArray idx, since its ordering does not always stay in sync with the updated TreeNode traversal
* refactor: use the only reproducible key (expr_identifier) for expr_set, while keeping the (stack-popped) symbol for the alias.
* refactor: encapsulate most of the logic within ExprSet, and distinguish the expr_identifier from the alias symbol
* test(9870): demonstrate that the sqllogictests are now passing
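
As a minimal standalone sketch (toy types, not DataFusion's actual `Expr`) of the idea behind the reproducible key: the identifier is derived purely from the expression itself via `format!("{expr}")`, exactly as the new `ExprSet::expr_identifier` does in this patch, so any traversal re-derives the same map key. `Expr`, `count`, and `main` below are illustrative stand-ins.

```rust
use std::collections::HashMap;

/// Toy expression tree standing in for DataFusion's `Expr`.
#[derive(Clone, Debug)]
enum Expr {
    Col(String),
    Add(Box<Expr>, Box<Expr>),
}

impl std::fmt::Display for Expr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Expr::Col(name) => write!(f, "{name}"),
            Expr::Add(l, r) => write!(f, "{l} + {r}"),
        }
    }
}

/// The reproducible key: a function of the expression alone, so any
/// traversal (jumping or not) re-derives the same map key.
fn expr_identifier(expr: &Expr) -> String {
    format!("{expr}")
}

/// Count occurrences keyed by identifier (the role of `ExprSet::map`).
fn count(expr: &Expr, set: &mut HashMap<String, usize>) {
    *set.entry(expr_identifier(expr)).or_insert(0) += 1;
    if let Expr::Add(l, r) = expr {
        count(l, set);
        count(r, set);
    }
}

fn main() {
    let a_plus_b = Expr::Add(
        Box::new(Expr::Col("a".into())),
        Box::new(Expr::Col("b".into())),
    );
    let mut set = HashMap::new();
    count(&a_plus_b, &mut set);
    count(&a_plus_b, &mut set);
    // "a + b" was seen twice, so it is a candidate for elimination.
    assert_eq!(set[&expr_identifier(&a_plus_b)], 2);
}
```

Because the key is a pure function of the expression rather than a position in a pre-computed IdArray, the CommonSubexprRewriter can look an expression up directly in ExprSet even after a `TreeNodeRecursion::Jump`, which is the guarantee the removed series_number/idx bookkeeping could not provide.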
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 05d7ac5..a1dc90d 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2368,7 +2368,7 @@
/// Aggregates its input based on a set of grouping and aggregate
/// expressions (e.g. SUM).
-#[derive(Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// mark non_exhaustive to encourage use of try_new/new()
#[non_exhaustive]
pub struct Aggregate {
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0c9064d..25c25c6 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,6 +17,7 @@
//! Eliminate common sub-expression.
+use std::collections::hash_map::Entry;
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
@@ -35,37 +36,75 @@
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
use datafusion_expr::{col, Expr, ExprSchemable};
-/// A map from expression's identifier to tuple including
-/// - the expression itself (cloned)
-/// - counter
-/// - DataType of this expression.
-type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
+/// Set of expressions generated by the [`ExprIdentifierVisitor`]
+/// and consumed by the [`CommonSubexprRewriter`].
+#[derive(Default)]
+struct ExprSet {
+    /// A map from an expression's identifier (stringified expr) to a tuple containing:
+    /// - the expression itself (cloned)
+    /// - occurrence counter
+    /// - DataType of this expression
+    /// - symbol used as the identifier in the alias
+ map: HashMap<Identifier, (Expr, usize, DataType, Identifier)>,
+}
-/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
-/// initial expression walk.
-///
-/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
-/// common subexpressions.
-///
-/// Elements in this array are created on the walk down the expression tree
-/// during `f_down`. Thus element 0 is the root of the expression tree. The
-/// tuple contains:
-/// - series_number.
-/// - Incremented during `f_up`, start from 1.
-/// - Thus, items with higher idx have the lower series_number.
-/// - [`Identifier`]
-/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination.
-///
-/// # Example
-/// An expression like `(a + b)` would have the following `IdArray`:
-/// ```text
-/// [
-/// (3, "a + b"),
-/// (2, "a"),
-/// (1, "b")
-/// ]
-/// ```
-type IdArray = Vec<(usize, Identifier)>;
+impl ExprSet {
+ fn expr_identifier(expr: &Expr) -> Identifier {
+ format!("{expr}")
+ }
+
+ fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> {
+ self.map.get(key)
+ }
+
+ fn entry(
+ &mut self,
+ key: Identifier,
+ ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> {
+ self.map.entry(key)
+ }
+
+ fn populate_expr_set(
+ &mut self,
+ expr: &[Expr],
+ input_schema: DFSchemaRef,
+ expr_mask: ExprMask,
+ ) -> Result<()> {
+ expr.iter().try_for_each(|e| {
+ self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?;
+
+ Ok(())
+ })
+ }
+
+    /// Go through an expression tree and generate an identifier for every node in the tree.
+ fn expr_to_identifier(
+ &mut self,
+ expr: &Expr,
+ input_schema: DFSchemaRef,
+ expr_mask: ExprMask,
+ ) -> Result<()> {
+ expr.visit(&mut ExprIdentifierVisitor {
+ expr_set: self,
+ input_schema,
+ visit_stack: vec![],
+ node_count: 0,
+ expr_mask,
+ })?;
+
+ Ok(())
+ }
+}
+
+impl From<Vec<(Identifier, (Expr, usize, DataType, Identifier))>> for ExprSet {
+ fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self {
+ let mut expr_set = Self::default();
+ entries.into_iter().for_each(|(k, v)| {
+ expr_set.map.insert(k, v);
+ });
+ expr_set
+ }
+}
/// Identifier for each subexpression.
///
@@ -112,21 +151,16 @@
fn rewrite_exprs_list(
&self,
exprs_list: &[&[Expr]],
- arrays_list: &[&[Vec<(usize, String)>]],
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Vec<Vec<Expr>>> {
exprs_list
.iter()
- .zip(arrays_list.iter())
- .map(|(exprs, arrays)| {
+ .map(|exprs| {
exprs
.iter()
.cloned()
- .zip(arrays.iter())
- .map(|(expr, id_array)| {
- replace_common_expr(expr, id_array, expr_set, affected_id)
- })
+ .map(|expr| replace_common_expr(expr, expr_set, affected_id))
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()
@@ -135,7 +169,6 @@
fn rewrite_expr(
&self,
exprs_list: &[&[Expr]],
- arrays_list: &[&[Vec<(usize, String)>]],
input: &LogicalPlan,
expr_set: &ExprSet,
config: &dyn OptimizerConfig,
@@ -143,7 +176,7 @@
let mut affected_id = BTreeSet::<Identifier>::new();
let rewrite_exprs =
- self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;
+ self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;
let mut new_input = self
.try_optimize(input, config)?
@@ -161,8 +194,7 @@
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
let mut window_exprs = vec![];
- let mut arrays_per_window = vec![];
- let mut expr_set = ExprSet::new();
+ let mut expr_set = ExprSet::default();
// Get all window expressions inside the consecutive window operators.
        // Consecutive window expressions may refer to the same complex expression.
@@ -181,30 +213,18 @@
plan = input.as_ref().clone();
let input_schema = Arc::clone(input.schema());
- let arrays =
- to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+ expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?;
window_exprs.push(window_expr);
- arrays_per_window.push(arrays);
}
let mut window_exprs = window_exprs
.iter()
.map(|expr| expr.as_slice())
.collect::<Vec<_>>();
- let arrays_per_window = arrays_per_window
- .iter()
- .map(|arrays| arrays.as_slice())
- .collect::<Vec<_>>();
- assert_eq!(window_exprs.len(), arrays_per_window.len());
- let (mut new_expr, new_input) = self.rewrite_expr(
- &window_exprs,
- &arrays_per_window,
- &plan,
- &expr_set,
- config,
- )?;
+ let (mut new_expr, new_input) =
+ self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?;
assert_eq!(window_exprs.len(), new_expr.len());
// Construct consecutive window operator, with their corresponding new window expressions.
@@ -241,46 +261,36 @@
input,
..
} = aggregate;
- let mut expr_set = ExprSet::new();
+ let mut expr_set = ExprSet::default();
- // rewrite inputs
+        // build expr_set from the group by and aggregate expressions
let input_schema = Arc::clone(input.schema());
- let group_arrays = to_arrays(
+ expr_set.populate_expr_set(
group_expr,
Arc::clone(&input_schema),
- &mut expr_set,
ExprMask::Normal,
)?;
- let aggr_arrays =
- to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+ expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?;
- let (mut new_expr, new_input) = self.rewrite_expr(
- &[group_expr, aggr_expr],
- &[&group_arrays, &aggr_arrays],
- input,
- &expr_set,
- config,
- )?;
+ // rewrite inputs
+ let (mut new_expr, new_input) =
+ self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?;
// note the reversed pop order.
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;
// create potential projection on top
- let mut expr_set = ExprSet::new();
+ let mut expr_set = ExprSet::default();
let new_input_schema = Arc::clone(new_input.schema());
- let aggr_arrays = to_arrays(
+ expr_set.populate_expr_set(
&new_aggr_expr,
new_input_schema.clone(),
- &mut expr_set,
ExprMask::NormalAndAggregates,
)?;
+
let mut affected_id = BTreeSet::<Identifier>::new();
- let mut rewritten = self.rewrite_exprs_list(
- &[&new_aggr_expr],
- &[&aggr_arrays],
- &expr_set,
- &mut affected_id,
- )?;
+ let mut rewritten =
+ self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?;
let rewritten = pop_expr(&mut rewritten)?;
if affected_id.is_empty() {
@@ -300,9 +310,9 @@
for id in affected_id {
match expr_set.get(&id) {
- Some((expr, _, _)) => {
+ Some((expr, _, _, symbol)) => {
// todo: check `nullable`
- agg_exprs.push(expr.clone().alias(&id));
+ agg_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
@@ -320,9 +330,7 @@
agg_exprs.push(expr.alias(&name));
proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
- let id = ExprIdentifierVisitor::<'static>::expr_identifier(
- &expr_rewritten,
- );
+ let id = ExprSet::expr_identifier(&expr_rewritten);
let out_name =
expr_rewritten.to_field(&new_input_schema)?.qualified_name();
agg_exprs.push(expr_rewritten.alias(&id));
@@ -356,13 +364,13 @@
let inputs = plan.inputs();
let input = inputs[0];
let input_schema = Arc::clone(input.schema());
- let mut expr_set = ExprSet::new();
+ let mut expr_set = ExprSet::default();
        // Visit expr list and build expr identifier to occurring count map (`expr_set`).
- let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+ expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?;
let (mut new_expr, new_input) =
- self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?;
+ self.rewrite_expr(&[&expr], input, &expr_set, config)?;
plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
}
@@ -448,28 +456,6 @@
.ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
}
-fn to_arrays(
- expr: &[Expr],
- input_schema: DFSchemaRef,
- expr_set: &mut ExprSet,
- expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, String)>>> {
- expr.iter()
- .map(|e| {
- let mut id_array = vec![];
- expr_to_identifier(
- e,
- expr_set,
- &mut id_array,
- Arc::clone(&input_schema),
- expr_mask,
- )?;
-
- Ok(id_array)
- })
- .collect::<Result<Vec<_>>>()
-}
-
/// Build the "intermediate" projection plan that evaluates the extracted common expressions.
fn build_common_expr_project_plan(
input: LogicalPlan,
@@ -481,11 +467,11 @@
for id in affected_id {
match expr_set.get(&id) {
- Some((expr, _, data_type)) => {
+ Some((expr, _, data_type, symbol)) => {
// todo: check `nullable`
let field = DFField::new_unqualified(&id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
- project_exprs.push(expr.clone().alias(&id));
+ project_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
@@ -601,8 +587,6 @@
struct ExprIdentifierVisitor<'a> {
// param
expr_set: &'a mut ExprSet,
- /// series number (usize) and identifier.
- id_array: &'a mut IdArray,
/// input schema for the node that we're optimizing, so we can determine the correct datatype
/// for each subexpression
input_schema: DFSchemaRef,
@@ -610,8 +594,6 @@
visit_stack: Vec<VisitRecord>,
/// increased in fn_down, start from 0.
node_count: usize,
- /// increased in fn_up, start from 1.
- series_number: usize,
/// which expression should be skipped?
expr_mask: ExprMask,
}
@@ -628,10 +610,6 @@
}
impl ExprIdentifierVisitor<'_> {
- fn expr_identifier(expr: &Expr) -> Identifier {
- format!("{expr}")
- }
-
    /// Find the first `EnterMark` in the stack, and accumulate every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
@@ -655,9 +633,6 @@
type Node = Expr;
fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
- // put placeholder, sets the proper array length
- self.id_array.push((0, "".to_string()));
-
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
@@ -674,70 +649,38 @@
}
fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
- self.series_number += 1;
-
- let (idx, sub_expr_identifier) = self.pop_enter_mark();
+ let (_idx, sub_expr_identifier) = self.pop_enter_mark();
        // skip exprs that should not be recognized.
if self.expr_mask.ignores(expr) {
- let curr_expr_identifier = Self::expr_identifier(expr);
+ let curr_expr_identifier = ExprSet::expr_identifier(expr);
self.visit_stack
.push(VisitRecord::ExprItem(curr_expr_identifier));
- self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr
return Ok(TreeNodeRecursion::Continue);
}
- let mut desc = Self::expr_identifier(expr);
- desc.push_str(&sub_expr_identifier);
+ let curr_expr_identifier = ExprSet::expr_identifier(expr);
+ let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");
- self.id_array[idx] = (self.series_number, desc.clone());
- self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
+ self.visit_stack
+ .push(VisitRecord::ExprItem(alias_symbol.clone()));
let data_type = expr.get_type(&self.input_schema)?;
self.expr_set
- .entry(desc)
- .or_insert_with(|| (expr.clone(), 0, data_type))
+ .entry(curr_expr_identifier)
+ .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
.1 += 1;
Ok(TreeNodeRecursion::Continue)
}
}
-/// Go through an expression tree and generate identifier for every node in this tree.
-fn expr_to_identifier(
- expr: &Expr,
- expr_set: &mut ExprSet,
- id_array: &mut Vec<(usize, Identifier)>,
- input_schema: DFSchemaRef,
- expr_mask: ExprMask,
-) -> Result<()> {
- expr.visit(&mut ExprIdentifierVisitor {
- expr_set,
- id_array,
- input_schema,
- visit_stack: vec![],
- node_count: 0,
- series_number: 0,
- expr_mask,
- })?;
-
- Ok(())
-}
-
/// Rewrite expression by replacing detected common sub-expression with
/// the corresponding temporary column name. That column contains the
/// evaluated result of the replaced expression.
struct CommonSubexprRewriter<'a> {
expr_set: &'a ExprSet,
- id_array: &'a IdArray,
/// Which identifier is replaced.
affected_id: &'a mut BTreeSet<Identifier>,
-
- /// the max series number we have rewritten. Other expression nodes
- /// with smaller series number is already replaced and shouldn't
- /// do anything with them.
- max_series_number: usize,
- /// current node's information's index in `id_array`.
- curr_index: usize,
}
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
@@ -751,80 +694,41 @@
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}
- let (series_number, curr_id) = &self.id_array[self.curr_index];
-
- // halting conditions
- if self.curr_index >= self.id_array.len()
- || self.max_series_number > *series_number
- {
- return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
- }
-
- // skip `Expr`s without identifier (empty identifier).
- if curr_id.is_empty() {
- self.curr_index += 1; // incr idx for id_array, when not jumping
- return Ok(Transformed::no(expr));
- }
+ let curr_id = &ExprSet::expr_identifier(&expr);
// lookup previously visited expression
match self.expr_set.get(curr_id) {
- Some((_, counter, _)) => {
+ Some((_, counter, _, symbol)) => {
// if has a commonly used (a.k.a. 1+ use) expr
if *counter > 1 {
self.affected_id.insert(curr_id.clone());
- // This expr tree is finished.
- if self.curr_index >= self.id_array.len() {
- return Ok(Transformed::new(
- expr,
- false,
- TreeNodeRecursion::Jump,
- ));
- }
-
- // incr idx for id_array, when not jumping
- self.curr_index += 1;
-
- // series_number was the inverse number ordering (when doing f_up)
- self.max_series_number = *series_number;
- // step index to skip all sub-node (which has smaller series number).
- while self.curr_index < self.id_array.len()
- && *series_number > self.id_array[self.curr_index].0
- {
- self.curr_index += 1;
- }
-
let expr_name = expr.display_name()?;
                    // Alias this `Column` expr to its original "expr name";
                    // the `projection_push_down` optimizer uses "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
- col(curr_id).alias(expr_name),
+ col(symbol).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
} else {
- self.curr_index += 1;
Ok(Transformed::no(expr))
}
}
- _ => internal_err!("expr_set invalid state"),
+ None => Ok(Transformed::no(expr)),
}
}
}
fn replace_common_expr(
expr: Expr,
- id_array: &IdArray,
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_set,
- id_array,
affected_id,
- max_series_number: 0,
- curr_index: 0,
})
.data()
}
@@ -861,73 +765,6 @@
}
#[test]
- fn id_array_visitor() -> Result<()> {
- let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2);
-
- let schema = Arc::new(DFSchema::new_with_metadata(
- vec![
- DFField::new_unqualified("a", DataType::Int64, false),
- DFField::new_unqualified("c", DataType::Int64, false),
- ],
- Default::default(),
- )?);
-
- // skip aggregates
- let mut id_array = vec![];
- expr_to_identifier(
- &expr,
- &mut HashMap::new(),
- &mut id_array,
- Arc::clone(&schema),
- ExprMask::Normal,
- )?;
-
- let expected = vec![
- (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
- (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
- (4, ""),
- (3, "a + Int32(1)Int32(1)a"),
- (1, ""),
- (2, ""),
- (6, ""),
- (5, ""),
- (8, "")
- ]
- .into_iter()
- .map(|(number, id)| (number, id.into()))
- .collect::<Vec<_>>();
- assert_eq!(expected, id_array);
-
- // include aggregates
- let mut id_array = vec![];
- expr_to_identifier(
- &expr,
- &mut HashMap::new(),
- &mut id_array,
- Arc::clone(&schema),
- ExprMask::NormalAndAggregates,
- )?;
-
- let expected = vec![
- (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
- (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
- (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
- (3, "a + Int32(1)Int32(1)a"),
- (1, ""),
- (2, ""),
- (6, "AVG(c)c"),
- (5, ""),
- (8, "")
- ]
- .into_iter()
- .map(|(number, id)| (number, id.into()))
- .collect::<Vec<_>>();
- assert_eq!(expected, id_array);
-
- Ok(())
- }
-
- #[test]
fn tpch_q1_simplified() -> Result<()> {
// SQL:
// select
@@ -1171,24 +1008,28 @@
let table_scan = test_table_scan().unwrap();
let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "b+a".to_string()].into_iter().collect();
- let expr_set_1 = [
+ let expr_set_1 = vec![
(
"c+a".to_string(),
- (col("c") + col("a"), 1, DataType::UInt32),
+ (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
),
(
"b+a".to_string(),
- (col("b") + col("a"), 1, DataType::UInt32),
+ (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
),
]
- .into_iter()
- .collect();
- let expr_set_2 = [
- ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
- ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
+ .into();
+ let expr_set_2 = vec![
+ (
+ "c+a".to_string(),
+ (col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
+ ),
+ (
+ "b+a".to_string(),
+ (col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
+ ),
]
- .into_iter()
- .collect();
+ .into();
let project =
build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1)
.unwrap();
@@ -1214,30 +1055,48 @@
["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
.into_iter()
.collect();
- let expr_set_1 = [
+ let expr_set_1 = vec![
(
"test1.c+test1.a".to_string(),
- (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+ (
+ col("test1.c") + col("test1.a"),
+ 1,
+ DataType::UInt32,
+ "test1.c+test1.a".to_string(),
+ ),
),
(
"test1.b+test1.a".to_string(),
- (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+ (
+ col("test1.b") + col("test1.a"),
+ 1,
+ DataType::UInt32,
+ "test1.b+test1.a".to_string(),
+ ),
),
]
- .into_iter()
- .collect();
- let expr_set_2 = [
+ .into();
+ let expr_set_2 = vec![
(
"test1.c+test1.a".to_string(),
- (col("test1.c+test1.a"), 1, DataType::UInt32),
+ (
+ col("test1.c+test1.a"),
+ 1,
+ DataType::UInt32,
+ "test1.c+test1.a".to_string(),
+ ),
),
(
"test1.b+test1.a".to_string(),
- (col("test1.b+test1.a"), 1, DataType::UInt32),
+ (
+ col("test1.b+test1.a"),
+ 1,
+ DataType::UInt32,
+ "test1.b+test1.a".to_string(),
+ ),
),
]
- .into_iter()
- .collect();
+ .into();
let project =
build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1)
.unwrap();
diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt
index 75bcbc0..2e0cbf5 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -2262,3 +2262,66 @@
select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
----
10.1 0.09900990099 1.471623942989
+
+
+statement ok
+CREATE TABLE t1(
+ time TIMESTAMP,
+ load1 DOUBLE,
+ load2 DOUBLE,
+ host VARCHAR
+) AS VALUES
+ (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
+ (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
+ (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
+ (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
+;
+
+# struct scalar function with columns
+query ?
+select struct(time,load1,load2,host) from t1;
+----
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1}
+{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2}
+{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3}
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: }
+
+# can have an aggregate function with an inner coalesce
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1
+host2 2.2
+host3 3.3
+
+# can have an aggregate function with an inner CASE WHEN
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 101
+host2 202
+host3 303
+
+# can have 2 projections with aggr(short_circuited), with different short-circuited exprs
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303