fix(9870): common expression elimination optimization should always re-find the correct expression during rewrite (#9871)

* test(9870): reproducer of error with jumping traversal patterns in common-expr-elimination traversals

* refactor: remove the IdArray ordered idx, since the idx ordering does not always stay in sync with the updated TreeNode traversal

* refactor: use the only reproducible key (expr_identifier) for expr_set, while keeping the (stack-popped) symbol used for alias.

* refactor: encapsulate most of the logic within ExprSet, and delineate the expr_identifier from the alias symbol

* test(9870): demonstrate that the sqllogictests are now passing
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 05d7ac5..a1dc90d 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2368,7 +2368,7 @@
 
 /// Aggregates its input based on a set of grouping and aggregate
 /// expressions (e.g. SUM).
-#[derive(Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 // mark non_exhaustive to encourage use of try_new/new()
 #[non_exhaustive]
 pub struct Aggregate {
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0c9064d..25c25c6 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,6 +17,7 @@
 
 //! Eliminate common sub-expression.
 
+use std::collections::hash_map::Entry;
 use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
@@ -35,37 +36,75 @@
 use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
 use datafusion_expr::{col, Expr, ExprSchemable};
 
-/// A map from expression's identifier to tuple including
-/// - the expression itself (cloned)
-/// - counter
-/// - DataType of this expression.
-type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
+/// Set of expressions generated by the [`ExprIdentifierVisitor`]
+/// and consumed by the [`CommonSubexprRewriter`].
+#[derive(Default)]
+struct ExprSet {
+    /// A map from expression's identifier (stringified expr) to tuple including:
+    /// - the expression itself (cloned)
+    /// - counter
+    /// - DataType of this expression.
+    /// - symbol used as the identifier in the alias.
+    map: HashMap<Identifier, (Expr, usize, DataType, Identifier)>,
+}
 
-/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
-/// initial expression  walk.
-///
-/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
-/// common subexpressions.
-///
-/// Elements in this array are created on the walk down the expression tree
-/// during `f_down`. Thus element 0 is the root of the expression tree. The
-/// tuple contains:
-/// - series_number.
-///     - Incremented during `f_up`, start from 1.
-///     - Thus, items with higher idx have the lower series_number.
-/// - [`Identifier`]
-///     - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination.
-///
-/// # Example
-/// An expression like `(a + b)` would have the following `IdArray`:
-/// ```text
-/// [
-///   (3, "a + b"),
-///   (2, "a"),
-///   (1, "b")
-/// ]
-/// ```
-type IdArray = Vec<(usize, Identifier)>;
+impl ExprSet {
+    fn expr_identifier(expr: &Expr) -> Identifier {
+        format!("{expr}")
+    }
+
+    fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> {
+        self.map.get(key)
+    }
+
+    fn entry(
+        &mut self,
+        key: Identifier,
+    ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> {
+        self.map.entry(key)
+    }
+
+    fn populate_expr_set(
+        &mut self,
+        expr: &[Expr],
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.iter().try_for_each(|e| {
+            self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?;
+
+            Ok(())
+        })
+    }
+
+    /// Go through an expression tree and generate identifier for every node in this tree.
+    fn expr_to_identifier(
+        &mut self,
+        expr: &Expr,
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.visit(&mut ExprIdentifierVisitor {
+            expr_set: self,
+            input_schema,
+            visit_stack: vec![],
+            node_count: 0,
+            expr_mask,
+        })?;
+
+        Ok(())
+    }
+}
+
+impl From<Vec<(Identifier, (Expr, usize, DataType, Identifier))>> for ExprSet {
+    fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self {
+        let mut expr_set = Self::default();
+        entries.into_iter().for_each(|(k, v)| {
+            expr_set.map.insert(k, v);
+        });
+        expr_set
+    }
+}
 
 /// Identifier for each subexpression.
 ///
@@ -112,21 +151,16 @@
     fn rewrite_exprs_list(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         expr_set: &ExprSet,
         affected_id: &mut BTreeSet<Identifier>,
     ) -> Result<Vec<Vec<Expr>>> {
         exprs_list
             .iter()
-            .zip(arrays_list.iter())
-            .map(|(exprs, arrays)| {
+            .map(|exprs| {
                 exprs
                     .iter()
                     .cloned()
-                    .zip(arrays.iter())
-                    .map(|(expr, id_array)| {
-                        replace_common_expr(expr, id_array, expr_set, affected_id)
-                    })
+                    .map(|expr| replace_common_expr(expr, expr_set, affected_id))
                     .collect::<Result<Vec<_>>>()
             })
             .collect::<Result<Vec<_>>>()
@@ -135,7 +169,6 @@
     fn rewrite_expr(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         input: &LogicalPlan,
         expr_set: &ExprSet,
         config: &dyn OptimizerConfig,
@@ -143,7 +176,7 @@
         let mut affected_id = BTreeSet::<Identifier>::new();
 
         let rewrite_exprs =
-            self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;
+            self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;
 
         let mut new_input = self
             .try_optimize(input, config)?
@@ -161,8 +194,7 @@
         config: &dyn OptimizerConfig,
     ) -> Result<LogicalPlan> {
         let mut window_exprs = vec![];
-        let mut arrays_per_window = vec![];
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Get all window expressions inside the consecutive window operators.
         // Consecutive window expressions may refer to same complex expression.
@@ -181,30 +213,18 @@
             plan = input.as_ref().clone();
 
             let input_schema = Arc::clone(input.schema());
-            let arrays =
-                to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+            expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?;
 
             window_exprs.push(window_expr);
-            arrays_per_window.push(arrays);
         }
 
         let mut window_exprs = window_exprs
             .iter()
             .map(|expr| expr.as_slice())
             .collect::<Vec<_>>();
-        let arrays_per_window = arrays_per_window
-            .iter()
-            .map(|arrays| arrays.as_slice())
-            .collect::<Vec<_>>();
 
-        assert_eq!(window_exprs.len(), arrays_per_window.len());
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &window_exprs,
-            &arrays_per_window,
-            &plan,
-            &expr_set,
-            config,
-        )?;
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?;
         assert_eq!(window_exprs.len(), new_expr.len());
 
         // Construct consecutive window operator, with their corresponding new window expressions.
@@ -241,46 +261,36 @@
             input,
             ..
         } = aggregate;
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
-        // rewrite inputs
+        // build expr_set, with groupby and aggr
         let input_schema = Arc::clone(input.schema());
-        let group_arrays = to_arrays(
+        expr_set.populate_expr_set(
             group_expr,
             Arc::clone(&input_schema),
-            &mut expr_set,
             ExprMask::Normal,
         )?;
-        let aggr_arrays =
-            to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?;
 
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &[group_expr, aggr_expr],
-            &[&group_arrays, &aggr_arrays],
-            input,
-            &expr_set,
-            config,
-        )?;
+        // rewrite inputs
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?;
         // note the reversed pop order.
         let new_aggr_expr = pop_expr(&mut new_expr)?;
         let new_group_expr = pop_expr(&mut new_expr)?;
 
         // create potential projection on top
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
         let new_input_schema = Arc::clone(new_input.schema());
-        let aggr_arrays = to_arrays(
+        expr_set.populate_expr_set(
             &new_aggr_expr,
             new_input_schema.clone(),
-            &mut expr_set,
             ExprMask::NormalAndAggregates,
         )?;
+
         let mut affected_id = BTreeSet::<Identifier>::new();
-        let mut rewritten = self.rewrite_exprs_list(
-            &[&new_aggr_expr],
-            &[&aggr_arrays],
-            &expr_set,
-            &mut affected_id,
-        )?;
+        let mut rewritten =
+            self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?;
         let rewritten = pop_expr(&mut rewritten)?;
 
         if affected_id.is_empty() {
@@ -300,9 +310,9 @@
 
             for id in affected_id {
                 match expr_set.get(&id) {
-                    Some((expr, _, _)) => {
+                    Some((expr, _, _, symbol)) => {
                         // todo: check `nullable`
-                        agg_exprs.push(expr.clone().alias(&id));
+                        agg_exprs.push(expr.clone().alias(symbol.as_str()));
                     }
                     _ => {
                         return internal_err!("expr_set invalid state");
@@ -320,9 +330,7 @@
                         agg_exprs.push(expr.alias(&name));
                         proj_exprs.push(Expr::Column(Column::from_name(name)));
                     } else {
-                        let id = ExprIdentifierVisitor::<'static>::expr_identifier(
-                            &expr_rewritten,
-                        );
+                        let id = ExprSet::expr_identifier(&expr_rewritten);
                         let out_name =
                             expr_rewritten.to_field(&new_input_schema)?.qualified_name();
                         agg_exprs.push(expr_rewritten.alias(&id));
@@ -356,13 +364,13 @@
         let inputs = plan.inputs();
         let input = inputs[0];
         let input_schema = Arc::clone(input.schema());
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Visit expr list and build expr identifier to occuring count map (`expr_set`).
-        let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?;
 
         let (mut new_expr, new_input) =
-            self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?;
+            self.rewrite_expr(&[&expr], input, &expr_set, config)?;
 
         plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
     }
@@ -448,28 +456,6 @@
         .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
 }
 
-fn to_arrays(
-    expr: &[Expr],
-    input_schema: DFSchemaRef,
-    expr_set: &mut ExprSet,
-    expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, String)>>> {
-    expr.iter()
-        .map(|e| {
-            let mut id_array = vec![];
-            expr_to_identifier(
-                e,
-                expr_set,
-                &mut id_array,
-                Arc::clone(&input_schema),
-                expr_mask,
-            )?;
-
-            Ok(id_array)
-        })
-        .collect::<Result<Vec<_>>>()
-}
-
 /// Build the "intermediate" projection plan that evaluates the extracted common expressions.
 fn build_common_expr_project_plan(
     input: LogicalPlan,
@@ -481,11 +467,11 @@
 
     for id in affected_id {
         match expr_set.get(&id) {
-            Some((expr, _, data_type)) => {
+            Some((expr, _, data_type, symbol)) => {
                 // todo: check `nullable`
                 let field = DFField::new_unqualified(&id, data_type.clone(), true);
                 fields_set.insert(field.name().to_owned());
-                project_exprs.push(expr.clone().alias(&id));
+                project_exprs.push(expr.clone().alias(symbol.as_str()));
             }
             _ => {
                 return internal_err!("expr_set invalid state");
@@ -601,8 +587,6 @@
 struct ExprIdentifierVisitor<'a> {
     // param
     expr_set: &'a mut ExprSet,
-    /// series number (usize) and identifier.
-    id_array: &'a mut IdArray,
     /// input schema for the node that we're optimizing, so we can determine the correct datatype
     /// for each subexpression
     input_schema: DFSchemaRef,
@@ -610,8 +594,6 @@
     visit_stack: Vec<VisitRecord>,
     /// increased in fn_down, start from 0.
     node_count: usize,
-    /// increased in fn_up, start from 1.
-    series_number: usize,
     /// which expression should be skipped?
     expr_mask: ExprMask,
 }
@@ -628,10 +610,6 @@
 }
 
 impl ExprIdentifierVisitor<'_> {
-    fn expr_identifier(expr: &Expr) -> Identifier {
-        format!("{expr}")
-    }
-
     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
     /// before it.
     fn pop_enter_mark(&mut self) -> (usize, Identifier) {
@@ -655,9 +633,6 @@
     type Node = Expr;
 
     fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
-        // put placeholder, sets the proper array length
-        self.id_array.push((0, "".to_string()));
-
         // related to https://github.com/apache/arrow-datafusion/issues/8814
         // If the expr contain volatile expression or is a short-circuit expression, skip it.
         if expr.short_circuits() || is_volatile_expression(expr)? {
@@ -674,70 +649,38 @@
     }
 
     fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
-        self.series_number += 1;
-
-        let (idx, sub_expr_identifier) = self.pop_enter_mark();
+        let (_idx, sub_expr_identifier) = self.pop_enter_mark();
 
         // skip exprs should not be recognize.
         if self.expr_mask.ignores(expr) {
-            let curr_expr_identifier = Self::expr_identifier(expr);
+            let curr_expr_identifier = ExprSet::expr_identifier(expr);
             self.visit_stack
                 .push(VisitRecord::ExprItem(curr_expr_identifier));
-            self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr
             return Ok(TreeNodeRecursion::Continue);
         }
-        let mut desc = Self::expr_identifier(expr);
-        desc.push_str(&sub_expr_identifier);
+        let curr_expr_identifier = ExprSet::expr_identifier(expr);
+        let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");
 
-        self.id_array[idx] = (self.series_number, desc.clone());
-        self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
+        self.visit_stack
+            .push(VisitRecord::ExprItem(alias_symbol.clone()));
 
         let data_type = expr.get_type(&self.input_schema)?;
 
         self.expr_set
-            .entry(desc)
-            .or_insert_with(|| (expr.clone(), 0, data_type))
+            .entry(curr_expr_identifier)
+            .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
             .1 += 1;
         Ok(TreeNodeRecursion::Continue)
     }
 }
 
-/// Go through an expression tree and generate identifier for every node in this tree.
-fn expr_to_identifier(
-    expr: &Expr,
-    expr_set: &mut ExprSet,
-    id_array: &mut Vec<(usize, Identifier)>,
-    input_schema: DFSchemaRef,
-    expr_mask: ExprMask,
-) -> Result<()> {
-    expr.visit(&mut ExprIdentifierVisitor {
-        expr_set,
-        id_array,
-        input_schema,
-        visit_stack: vec![],
-        node_count: 0,
-        series_number: 0,
-        expr_mask,
-    })?;
-
-    Ok(())
-}
-
 /// Rewrite expression by replacing detected common sub-expression with
 /// the corresponding temporary column name. That column contains the
 /// evaluate result of replaced expression.
 struct CommonSubexprRewriter<'a> {
     expr_set: &'a ExprSet,
-    id_array: &'a IdArray,
     /// Which identifier is replaced.
     affected_id: &'a mut BTreeSet<Identifier>,
-
-    /// the max series number we have rewritten. Other expression nodes
-    /// with smaller series number is already replaced and shouldn't
-    /// do anything with them.
-    max_series_number: usize,
-    /// current node's information's index in `id_array`.
-    curr_index: usize,
 }
 
 impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
@@ -751,80 +694,41 @@
             return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
         }
 
-        let (series_number, curr_id) = &self.id_array[self.curr_index];
-
-        // halting conditions
-        if self.curr_index >= self.id_array.len()
-            || self.max_series_number > *series_number
-        {
-            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
-        }
-
-        // skip `Expr`s without identifier (empty identifier).
-        if curr_id.is_empty() {
-            self.curr_index += 1; // incr idx for id_array, when not jumping
-            return Ok(Transformed::no(expr));
-        }
+        let curr_id = &ExprSet::expr_identifier(&expr);
 
         // lookup previously visited expression
         match self.expr_set.get(curr_id) {
-            Some((_, counter, _)) => {
+            Some((_, counter, _, symbol)) => {
                 // if has a commonly used (a.k.a. 1+ use) expr
                 if *counter > 1 {
                     self.affected_id.insert(curr_id.clone());
 
-                    // This expr tree is finished.
-                    if self.curr_index >= self.id_array.len() {
-                        return Ok(Transformed::new(
-                            expr,
-                            false,
-                            TreeNodeRecursion::Jump,
-                        ));
-                    }
-
-                    // incr idx for id_array, when not jumping
-                    self.curr_index += 1;
-
-                    // series_number was the inverse number ordering (when doing f_up)
-                    self.max_series_number = *series_number;
-                    // step index to skip all sub-node (which has smaller series number).
-                    while self.curr_index < self.id_array.len()
-                        && *series_number > self.id_array[self.curr_index].0
-                    {
-                        self.curr_index += 1;
-                    }
-
                     let expr_name = expr.display_name()?;
                     // Alias this `Column` expr to it original "expr name",
                     // `projection_push_down` optimizer use "expr name" to eliminate useless
                     // projections.
                     Ok(Transformed::new(
-                        col(curr_id).alias(expr_name),
+                        col(symbol).alias(expr_name),
                         true,
                         TreeNodeRecursion::Jump,
                     ))
                 } else {
-                    self.curr_index += 1;
                     Ok(Transformed::no(expr))
                 }
             }
-            _ => internal_err!("expr_set invalid state"),
+            None => Ok(Transformed::no(expr)),
         }
     }
 }
 
 fn replace_common_expr(
     expr: Expr,
-    id_array: &IdArray,
     expr_set: &ExprSet,
     affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
     expr.rewrite(&mut CommonSubexprRewriter {
         expr_set,
-        id_array,
         affected_id,
-        max_series_number: 0,
-        curr_index: 0,
     })
     .data()
 }
@@ -861,73 +765,6 @@
     }
 
     #[test]
-    fn id_array_visitor() -> Result<()> {
-        let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2);
-
-        let schema = Arc::new(DFSchema::new_with_metadata(
-            vec![
-                DFField::new_unqualified("a", DataType::Int64, false),
-                DFField::new_unqualified("c", DataType::Int64, false),
-            ],
-            Default::default(),
-        )?);
-
-        // skip aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::Normal,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (4, ""),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, ""),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        // include aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::NormalAndAggregates,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, "AVG(c)c"),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        Ok(())
-    }
-
-    #[test]
     fn tpch_q1_simplified() -> Result<()> {
         // SQL:
         //  select
@@ -1171,24 +1008,28 @@
         let table_scan = test_table_scan().unwrap();
         let affected_id: BTreeSet<Identifier> =
             ["c+a".to_string(), "b+a".to_string()].into_iter().collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "c+a".to_string(),
-                (col("c") + col("a"), 1, DataType::UInt32),
+                (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
             ),
             (
                 "b+a".to_string(),
-                (col("b") + col("a"), 1, DataType::UInt32),
+                (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
             ),
         ]
-        .into_iter()
-        .collect();
-        let expr_set_2 = [
-            ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
-            ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
+        .into();
+        let expr_set_2 = vec![
+            (
+                "c+a".to_string(),
+                (col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
+            ),
+            (
+                "b+a".to_string(),
+                (col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
+            ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1)
                 .unwrap();
@@ -1214,30 +1055,48 @@
             ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
                 .into_iter()
                 .collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
-        let expr_set_2 = [
+        .into();
+        let expr_set_2 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1)
                 .unwrap();
diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt
index 75bcbc0..2e0cbf5 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -2262,3 +2262,66 @@
 select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
 ----
 10.1 0.09900990099 1.471623942989
+
+
+statement ok
+CREATE TABLE t1(
+    time TIMESTAMP,
+    load1 DOUBLE,
+    load2 DOUBLE,
+    host VARCHAR
+) AS VALUES
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
+  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
+  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
+;
+
+# struct scalar function with columns
+query ?
+select struct(time,load1,load2,host) from t1;
+----
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1}
+{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2}
+{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3}
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: }
+
+# can have an aggregate function with an inner coalesce
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1
+host2 2.2
+host3 3.3
+
+# can have an aggregate function with an inner CASE WHEN
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 101
+host2 202
+host3 303
+
+# can have 2 projections with aggr(short_circuited), with different short-circuited expr
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303