blob: 674d3386171f841c8f699a654f9f82a2d28495df [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
//! a [`CSEController`], that defines how to eliminate common subtrees from a particular
//! [`TreeNode`] tree.
use crate::hash_utils::combine_hashes;
use crate::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
};
use crate::Result;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
use std::sync::Arc;
/// Hashes the direct content of a [`TreeNode`] without recursing into its children.
///
/// This method is useful to incrementally compute hashes, such as in [`CSE`], which
/// builds a deep hash of a node and its descendants during the bottom-up phase of the
/// first traversal, and so avoids computing the hash of the node and then the hash of
/// its descendants separately.
///
/// If a node doesn't have any children then the value returned by `hash_node()` is
/// similar to `.hash()`, but it is not necessarily the same value.
pub trait HashNode {
    /// Feeds the node's own content, excluding its children, into `state`.
    fn hash_node<H: Hasher>(&self, state: &mut H);
}
/// An `Arc<T>` hashes exactly like the `T` it points to.
impl<T: HashNode + ?Sized> HashNode for Arc<T> {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Deref-coerce the shared pointer to the pointee and delegate to it.
        let inner: &T = self;
        inner.hash_node(state);
    }
}
/// Reports whether a node admits a canonical (normalized) form.
///
/// Normalization rewrites a node into a canonical shape so that nodes can be compared for
/// equality independently of superficial differences. Optimizations like Common
/// Subexpression Elimination (CSE) rely on this so that semantically equivalent nodes,
/// such as `a + b` and `b + a`, are treated as equal.
pub trait Normalizeable {
    /// Returns `true` if this node can be normalized.
    fn can_normalize(&self) -> bool;
}
/// Equality that looks through normalization.
///
/// `NormalizeEq` builds on [`Eq`] and [`Normalizeable`]. `normalize_eq` is meant to return
/// `true` for two nodes that are semantically equivalent once normalized, even when their
/// original forms (and therefore their regular `==`) differ. Common Subexpression
/// Elimination (CSE) uses it so that such nodes are recognized as the same subtree,
/// regardless of their internal representation.
pub trait NormalizeEq: Eq + Normalizeable {
    /// Returns `true` if `self` and `other` are equal after normalization.
    fn normalize_eq(&self, other: &Self) -> bool;
}
/// Identifier of a [`TreeNode`] subtree.
///
/// It is cheap to copy, hash and combine: it only stores a precomputed hash of the whole
/// subtree and a reference to the subtree's root node. Two identifiers are equal when
/// their hashes match and their nodes are equal after normalization (see
/// [`NormalizeEq`]), so a hash collision alone never makes two distinct subtrees equal.
#[derive(Debug, Eq)]
struct Identifier<'n, N: NormalizeEq> {
    // Hash of the subtree rooted at `node`, accumulated bottom-up during the first,
    // visiting traversal. It is generally not the value the node's own `Hash`
    // implementation would produce (e.g. it differs from `expr.hash()` for an `Expr`).
    hash: u64,
    // Root of the identified subtree.
    node: &'n N,
}

impl<N: NormalizeEq> Copy for Identifier<'_, N> {}

impl<N: NormalizeEq> Clone for Identifier<'_, N> {
    fn clone(&self) -> Self {
        // `Identifier` is `Copy`, so cloning is a plain copy.
        *self
    }
}

impl<N: NormalizeEq> Hash for Identifier<'_, N> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Only the precomputed subtree hash is fed in. Equal identifiers always have
        // equal `hash` fields, so this agrees with `PartialEq` below.
        state.write_u64(self.hash);
    }
}

impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
    fn eq(&self, other: &Self) -> bool {
        // Cheap hash comparison first; the (potentially more expensive) normalized node
        // comparison only runs when the hashes agree.
        if self.hash != other.hash {
            return false;
        }
        self.node.normalize_eq(other.node)
    }
}
impl<'n, N> Identifier<'n, N>
where
    N: HashNode + NormalizeEq,
{
    /// Creates the identifier of `node` from the node's own content only, using a hasher
    /// built from `random_state`. Children can be mixed in afterwards with
    /// [`Identifier::combine`].
    fn new(node: &'n N, random_state: &RandomState) -> Self {
        let hash = {
            let mut state = random_state.build_hasher();
            node.hash_node(&mut state);
            state.finish()
        };
        Self { hash, node }
    }

    /// Mixes the hash of `other`, if present, into this identifier's hash.
    /// The referenced node stays `self.node`.
    fn combine(mut self, other: Option<Self>) -> Self {
        if let Some(other_id) = other {
            self.hash = combine_hashes(self.hash, other_id.hash);
        }
        self
    }
}
/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
/// preorder index of the nodes.
///
/// This cache is filled by [`CSEVisitor`] during the first traversal and is
/// used by [`CSERewriter`] during the second traversal.
///
/// The purpose of this cache is to quickly find the identifier of a node during the
/// second traversal.
///
/// Elements in this array are added during `f_down`, so the indexes represent the
/// preorder index of nodes and thus element 0 belongs to the root of the tree.
///
/// The elements of the array are tuples that contain:
/// - The postorder index that belongs to the preorder index. Assigned during `f_up`,
///   starting from 0.
/// - The optional [`Identifier`] of the node. If `None`, the node should not be
///   considered for CSE.
///
/// # Example
/// An expression tree like `(a + b)` would have the following `IdArray`:
/// ```text
/// [
///   (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
///   (0, Some(Identifier(hash_of("a"), &"a"))),
///   (1, Some(Identifier(hash_of("b"), &"b")))
/// ]
/// ```
/// `a + b` is entered first (preorder index 0) but left last (postorder index 2), while
/// `a` is both entered and left before `b`.
type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
/// How many times a node is evaluated.
///
/// A node is considered common once it is evaluated surely at least 2 times, or surely
/// once and also conditionally.
#[derive(PartialEq, Eq)]
enum NodeEvaluation {
    /// Seen exactly once so far, outside of any conditionally evaluated subtree.
    SurelyOnce,
    /// Seen so far only inside conditionally evaluated subtrees.
    ConditionallyAtLeastOnce,
    /// Seen often enough to be extracted as a common node.
    Common,
}
/// Evaluation statistics of [`TreeNode`]s, keyed by their identifiers.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;

/// The common [`TreeNode`]s found during the second, rewriting traversal, each stored
/// with its generated alias and keyed by its identifier. `IndexMap` keeps the entries in
/// insertion order.
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;

/// A pair of child lists: the always evaluated children first, then the conditionally
/// evaluated ones (see [`CSEController::conditional_children`]).
type ChildrenList<N> = (Vec<N>, Vec<N>);
/// The [`TreeNode`] specific definition of elimination.
///
/// An implementation tells [`CSE`] which nodes may be extracted, how to name the
/// extracted subtrees and how to replace them.
pub trait CSEController {
    /// The type of the tree nodes.
    type Node;

    /// Splits the children into normal and conditionally evaluated ones, or returns `None`
    /// if all of them are always evaluated.
    ///
    /// The first, visiting traversal records nodes by their preorder position while it
    /// walks the normal children followed by the conditional ones. The second, rewriting
    /// traversal walks the regular [`TreeNode`] children of the node. Therefore `normal`
    /// followed by `conditional` is expected to contain every child exactly once, in the
    /// same order as the regular traversal visits them.
    fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;

    /// Returns true if a node is valid. If a node is invalid then it can't be eliminated.
    /// Validity is propagated up, which means no subtree can be eliminated that contains
    /// an invalid node.
    /// (E.g. volatile expressions are not valid and subtrees containing such a node can't
    /// be extracted.)
    fn is_valid(node: &Self::Node) -> bool;

    /// Returns true if a node should be ignored during CSE. Contrary to the validity of a
    /// node, this is not propagated up.
    fn is_ignored(&self, node: &Self::Node) -> bool;

    /// Generates a new name for the extracted subtree.
    fn generate_alias(&self) -> String;

    /// Replaces a node with a reference to the generated alias.
    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;

    /// A helper method called on each visited node during the top-down phase of the
    /// second, rewriting traversal of CSE.
    fn rewrite_f_down(&mut self, _node: &Self::Node) {}

    /// A helper method called on each visited node during the bottom-up phase of the
    /// second, rewriting traversal of CSE.
    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}
/// Outcome of trying to eliminate common subtrees from lists of [`TreeNode`]s.
#[derive(Debug)]
pub enum FoundCommonNodes<N> {
    /// No common [`TreeNode`]s were found.
    No {
        /// the input [`TreeNode`]s, returned as they were given
        original_nodes_list: Vec<Vec<N>>,
    },
    /// Common [`TreeNode`]s were found and extracted.
    Yes {
        /// the extracted common [`TreeNode`]s, each paired with its alias
        common_nodes: Vec<(N, String)>,
        /// the input [`TreeNode`]s with the common subtrees replaced
        new_nodes_list: Vec<Vec<N>>,
        /// the input [`TreeNode`]s as they were before rewriting
        original_nodes_list: Vec<Vec<N>>,
    },
}
/// Goes through a [`TreeNode`] tree and generates identifiers for each subtree.
///
/// An identifier contains information about the [`TreeNode`] itself and its subtrees.
/// This visitor uses a stack, `visit_stack`, to track the traversal, which lets it know
/// when the visit of a subtree is finished:
/// - When `f_down()` is called (entering a node), an `EnterMark` is pushed to the stack.
/// - When `f_up()` is called (leaving a node), every `NodeItem` above the topmost
///   `EnterMark` is popped. These belong to the direct children of the leaving node. The
///   `EnterMark` itself is popped too, and then a `NodeItem` for the leaving node is
///   pushed.
///
/// This visitor also records the identifier of each node in `id_array`, so the following
/// rewriting traversal can get the identifier of a node without recomputing it. Each node
/// is assigned two indexes, both starting from 0:
/// - `down_index`: the preorder index, i.e. the order in which nodes are entered
///   (`f_down()`). It is the position of the node's entry in `id_array`.
/// - `up_index`: the postorder index, i.e. the order in which nodes are left (`f_up()`).
///   It is stored in the node's `id_array` entry. A child's postorder index is always
///   smaller than its parent's.
///
/// Only nodes that are valid (see [`CSEController::is_valid`]) and not ignored (see
/// [`CSEController::is_ignored`]) get an identifier in `id_array` and are counted in
/// `node_stats`. For example, a controller can ignore leaf nodes (columns, literals etc.)
/// so that they are not recognized as common subtrees.
struct CSEVisitor<'a, 'n, N, C>
where
    N: NormalizeEq,
    C: CSEController<Node = N>,
{
    /// evaluation statistics of [`TreeNode`]s, by identifier
    node_stats: &'a mut NodeStats<'n, N>,
    /// cache to speed up the second traversal, indexed by preorder index
    id_array: &'a mut IdArray<'n, N>,
    /// `EnterMark`s and `NodeItem`s of the nodes currently being visited
    visit_stack: Vec<VisitRecord<'n, N>>,
    /// preorder index of the next node to enter, starts from 0
    down_index: usize,
    /// postorder index of the next node to leave, starts from 0
    up_index: usize,
    /// a [`RandomState`] to generate hashes during the first traversal
    random_state: &'a RandomState,
    /// a flag to indicate that common [`TreeNode`]s were found
    found_common: bool,
    /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`]
    /// might not be executed depending on the runtime values of other [`TreeNode`]s. Its
    /// occurrences alone never make a subtree common; they only count together with a
    /// surely evaluated occurrence of the same subtree.
    conditional: bool,
    /// decides which nodes are valid and which are ignored
    controller: &'a C,
}
/// Record item used when traversing a [`TreeNode`] tree.
enum VisitRecord<'n, N>
where
    N: NormalizeEq,
{
    /// Marks the beginning of a [`TreeNode`]. It contains:
    /// - The preorder index assigned during the first, visiting traversal. It is the
    ///   position of the node's entry in the `IdArray`.
    EnterMark(usize),
    /// Marks an accumulated subtree. It contains:
    /// - The accumulated identifier of the subtree.
    /// - An accumulated boolean flag telling whether the subtree is valid for CSE.
    ///   The flag is propagated up from children to parent. (E.g. volatile expressions
    ///   are not valid and can't be extracted, but non-volatile children of volatile
    ///   expressions can be extracted.)
    NodeItem(Identifier<'n, N>, bool),
}
impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
where
    N: TreeNode + HashNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    /// Pops the records belonging to the node that is being left from `visit_stack`.
    ///
    /// Every `NodeItem` above the topmost `EnterMark` belongs to a child of that node.
    /// They are popped and accumulated until the `EnterMark` itself is reached. When
    /// `can_normalize` is set, the children's identifiers are sorted by hash before being
    /// combined, so the combined hash does not depend on the order of the children.
    ///
    /// Returns a tuple of:
    /// - the preorder index stored in the `EnterMark`,
    /// - the combined identifier of the children (`None` if there are no children),
    /// - whether all children are valid for CSE. Each child's flag already covers its own
    ///   subtree, so invalidity (e.g. volatility) is propagated from any descendant up to
    ///   its ancestors during the first traversal, without an extra validity pass.
    ///
    /// # Panics
    /// Panics if the stack runs out before an `EnterMark` is found.
    fn pop_enter_mark(
        &mut self,
        can_normalize: bool,
    ) -> (usize, Option<Identifier<'n, N>>, bool) {
        let mut children: Vec<Identifier<'n, N>> = Vec::new();
        let mut all_children_valid = true;
        loop {
            match self.visit_stack.pop() {
                Some(VisitRecord::NodeItem(child_id, child_is_valid)) => {
                    children.push(child_id);
                    all_children_valid &= child_is_valid;
                }
                Some(VisitRecord::EnterMark(down_index)) => {
                    if can_normalize {
                        children.sort_by_key(|id| id.hash);
                    }
                    let mut combined = None;
                    for child in children {
                        combined = Some(child.combine(combined));
                    }
                    return (down_index, combined, all_children_valid);
                }
                None => unreachable!("EnterMark should paired with NodeItem"),
            }
        }
    }
}
/// The first, visiting traversal of CSE.
///
/// It fills `id_array` and `node_stats`, and sets `found_common` once some identifier
/// becomes common.
impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
where
    N: TreeNode + HashNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    type Node = N;

    /// Entering a node: reserves its slot in `id_array` (filled in by `f_up`) and pushes
    /// an `EnterMark` carrying its preorder index. If the node has conditionally evaluated
    /// children and we are not already inside a conditional subtree, its children are
    /// visited here, with `self.conditional` set while the conditional ones are visited.
    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        // Placeholder entry; its postorder index and identifier are set in `f_up`.
        self.id_array.push((0, None));
        self.visit_stack
            .push(VisitRecord::EnterMark(self.down_index));
        self.down_index += 1;
        // If a node can short-circuit then some of its children might not be executed so
        // count the occurrence either normal or conditional.
        Ok(if self.conditional {
            // If we are already in a conditionally evaluated subtree then continue
            // traversal.
            TreeNodeRecursion::Continue
        } else {
            // If this node can short-circuit then start new traversals on its normal and
            // conditional children, the latter with `self.conditional` set.
            match C::conditional_children(node) {
                Some((normal, conditional)) => {
                    normal
                        .into_iter()
                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
                    self.conditional = true;
                    conditional
                        .into_iter()
                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
                    self.conditional = false;
                    // NOTE(review): this assumes `Jump` skips the (already visited)
                    // children in the regular traversal but still calls `f_up` on this
                    // node, which pops the `EnterMark` pushed above. Confirm against the
                    // `TreeNodeRecursion::Jump` documentation.
                    TreeNodeRecursion::Jump
                }
                // In case of non-short-circuit node continue the traversal.
                _ => TreeNodeRecursion::Continue,
            }
        })
    }

    /// Leaving a node: combines the identifiers of its children with its own, records the
    /// node's postorder index and (if it is valid and not ignored) its identifier, updates
    /// its evaluation statistics, and pushes a `NodeItem` for the parent to consume.
    fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        // For normalizable nodes the children's identifiers are combined in hash order,
        // so reordered children can yield the same combined hash.
        let (down_index, sub_node_id, sub_node_is_valid) =
            self.pop_enter_mark(node.can_normalize());
        let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
        // A node is valid only if it and all of its descendants are valid.
        let is_valid = C::is_valid(node) && sub_node_is_valid;
        self.id_array[down_index].0 = self.up_index;
        if is_valid && !self.controller.is_ignored(node) {
            self.id_array[down_index].1 = Some(node_id);
            self.node_stats
                .entry(node_id)
                .and_modify(|evaluation| {
                    // `&&` binds tighter than `||`: a node becomes common when it was
                    // already seen surely once (and is now seen again, conditionally or
                    // not), or when it was only seen conditionally so far and is now seen
                    // unconditionally.
                    if *evaluation == NodeEvaluation::SurelyOnce
                        || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
                            && !self.conditional
                    {
                        *evaluation = NodeEvaluation::Common;
                        self.found_common = true;
                    }
                })
                .or_insert_with(|| {
                    if self.conditional {
                        NodeEvaluation::ConditionallyAtLeastOnce
                    } else {
                        NodeEvaluation::SurelyOnce
                    }
                });
        }
        // Pushed even for invalid or ignored nodes: the node's identifier and validity
        // still contribute to its parent's.
        self.visit_stack
            .push(VisitRecord::NodeItem(node_id, is_valid));
        self.up_index += 1;
        Ok(TreeNodeRecursion::Continue)
    }
}
/// Rewrites a [`TreeNode`] tree by replacing each detected common subtree with the
/// temporary [`TreeNode`] that [`CSEController::rewrite`] builds from the subtree and its
/// alias.
struct CSERewriter<'a, 'n, N, C>
where
    N: NormalizeEq,
    C: CSEController<Node = N>,
{
    /// the controller that generates aliases and builds the replacement nodes
    controller: &'a mut C,
    /// evaluation statistics of [`TreeNode`]s collected during the first traversal
    node_stats: &'a NodeStats<'n, N>,
    /// postorder indexes and identifiers of the nodes, by preorder index
    id_array: &'a IdArray<'n, N>,
    /// preorder index of the next node to visit, starts from 0
    down_index: usize,
    /// the common [`TreeNode`]s replaced during this traversal, with their aliases
    common_nodes: &'a mut CommonNodes<'n, N>,
}
/// The second, rewriting traversal of CSE.
impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
where
    N: TreeNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    type Node = N;

    /// Replaces `node` with an alias reference if its identifier was found to be common
    /// during the first traversal; in that case its whole subtree is skipped.
    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        self.controller.rewrite_f_down(&node);

        // `id_array` entries are consumed in the same preorder the first traversal
        // produced them in.
        let (up_index, maybe_id) = self.id_array[self.down_index];
        self.down_index += 1;

        // Only nodes that got an identifier and were counted as common are replaced.
        let Some(node_id) = maybe_id else {
            return Ok(Transformed::no(node));
        };
        if *self.node_stats.get(&node_id).unwrap() != NodeEvaluation::Common {
            return Ok(Transformed::no(node));
        }

        // The descendants of this node directly follow it in preorder and have smaller
        // postorder indexes; skip their entries since the subtree is replaced as a whole.
        let descendants = self.id_array[self.down_index..]
            .iter()
            .take_while(|(entry_up_index, _)| *entry_up_index < up_index)
            .count();
        self.down_index += descendants;

        // Every occurrence with this identifier is replaced, not only the first one that
        // gets stored in `common_nodes`. Nodes with equal identifiers are semantically
        // equivalent but not necessarily identical (e.g. `a + 1` and `1 + a`), and all of
        // them must refer to the same alias (e.g. `__common_cse_1`).
        let rewritten = match self.common_nodes.get(&node_id) {
            Some((_, alias)) => self.controller.rewrite(&node, alias),
            None => {
                let alias = self.controller.generate_alias();
                let replacement = self.controller.rewrite(&node, &alias);
                self.common_nodes.insert(node_id, (node, alias));
                replacement
            }
        };

        Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
    }

    /// Notifies the controller when leaving a node; the node itself is kept as is.
    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        self.controller.rewrite_f_up(&node);
        Ok(Transformed::no(node))
    }
}
/// Entry point of Common Subexpression Elimination.
///
/// A [`CSE`] is driven by a [`CSEController`], which defines how common subtrees of a
/// particular [`TreeNode`] tree can be eliminated. Start the elimination with
/// [`CSE::extract_common_nodes()`].
pub struct CSE<N, C: CSEController<Node = N>> {
    /// the node type specific rules of the elimination
    controller: C,
    /// builds the hashers used to compute node identifiers
    random_state: RandomState,
    /// ties the node type `N` to this struct without storing a value of it
    phantom_data: PhantomData<N>,
}
impl<N, C> CSE<N, C>
where
    N: TreeNode + HashNode + Clone + NormalizeEq,
    C: CSEController<Node = N>,
{
    /// Creates a new [`CSE`] that eliminates common subtrees as defined by `controller`.
    pub fn new(controller: C) -> Self {
        Self {
            random_state: RandomState::new(),
            phantom_data: PhantomData,
            controller,
        }
    }

    /// Visits the tree of `node`, appending one entry to `id_array` for every
    /// [`TreeNode`] in the tree (in preorder) and updating `node_stats`.
    ///
    /// Only valid and not ignored nodes get an identifier in their entry; the others are
    /// recorded with `None`. Returns `true` if any node became common during this visit.
    fn node_to_id_array<'n>(
        &self,
        node: &'n N,
        node_stats: &mut NodeStats<'n, N>,
        id_array: &mut IdArray<'n, N>,
    ) -> Result<bool> {
        let mut visitor = CSEVisitor {
            node_stats,
            id_array,
            visit_stack: vec![],
            down_index: 0,
            up_index: 0,
            random_state: &self.random_state,
            found_common: false,
            conditional: false,
            controller: &self.controller,
        };
        node.visit(&mut visitor)?;
        Ok(visitor.found_common)
    }

    /// Builds the identifier arrays of every node in `nodes`, sharing `node_stats`
    /// between them.
    ///
    /// Returns a flag that indicates if common nodes were found (i.e. if the rewrite
    /// phase of CSE makes sense) and an array with 1 element for each input node in
    /// `nodes`. Each element is itself the result of [`CSE::node_to_id_array`] for that
    /// node (i.e. the identifiers of the nodes in its tree).
    fn to_arrays<'n>(
        &self,
        nodes: &'n [N],
        node_stats: &mut NodeStats<'n, N>,
    ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
        let mut found_common = false;
        nodes
            .iter()
            .map(|n| {
                let mut id_array = vec![];
                self.node_to_id_array(n, node_stats, &mut id_array)
                    .map(|fc| {
                        found_common |= fc;
                        id_array
                    })
            })
            .collect::<Result<Vec<_>>>()
            .map(|id_arrays| (found_common, id_arrays))
    }

    /// Replaces common subtrees in `node` with the corresponding temporary
    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
    ///
    /// `id_array` must be the array built for this same `node` by
    /// [`CSE::node_to_id_array`], because it is consumed by preorder position.
    fn replace_common_node<'n>(
        &mut self,
        node: N,
        id_array: &IdArray<'n, N>,
        node_stats: &NodeStats<'n, N>,
        common_nodes: &mut CommonNodes<'n, N>,
    ) -> Result<N> {
        if id_array.is_empty() {
            Ok(Transformed::no(node))
        } else {
            node.rewrite(&mut CSERewriter {
                node_stats,
                id_array,
                common_nodes,
                down_index: 0,
                controller: &mut self.controller,
            })
        }
        .data()
    }

    /// Replaces common subtrees in `nodes_list` with the corresponding temporary
    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
    ///
    /// `arrays_list` must line up with `nodes_list`: one list of id arrays per list of
    /// nodes, and one id array per node.
    fn rewrite_nodes_list<'n>(
        &mut self,
        nodes_list: Vec<Vec<N>>,
        arrays_list: &[Vec<IdArray<'n, N>>],
        node_stats: &NodeStats<'n, N>,
        common_nodes: &mut CommonNodes<'n, N>,
    ) -> Result<Vec<Vec<N>>> {
        nodes_list
            .into_iter()
            .zip(arrays_list.iter())
            .map(|(nodes, arrays)| {
                nodes
                    .into_iter()
                    .zip(arrays.iter())
                    .map(|(node, id_array)| {
                        self.replace_common_node(node, id_array, node_stats, common_nodes)
                    })
                    .collect::<Result<Vec<_>>>()
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
    ///
    /// Returns [`FoundCommonNodes::No`] with the untouched input if no subtree occurs
    /// often enough to be extracted. Otherwise returns [`FoundCommonNodes::Yes`] with the
    /// extracted nodes, their aliases, the rewritten lists and the original lists.
    ///
    /// # Errors
    /// Returns any error raised while visiting or rewriting the trees.
    ///
    /// # Panics
    /// Panics if the first traversal found common nodes but the rewriting traversal
    /// extracted none, which indicates an internal inconsistency.
    pub fn extract_common_nodes(
        &mut self,
        nodes_list: Vec<Vec<N>>,
    ) -> Result<FoundCommonNodes<N>> {
        let mut found_common = false;
        let mut node_stats = NodeStats::new();
        let id_arrays_list = nodes_list
            .iter()
            .map(|nodes| {
                self.to_arrays(nodes, &mut node_stats)
                    .map(|(fc, id_arrays)| {
                        found_common |= fc;
                        id_arrays
                    })
            })
            .collect::<Result<Vec<_>>>()?;
        if found_common {
            let mut common_nodes = CommonNodes::new();
            let new_nodes_list = self.rewrite_nodes_list(
                // Must clone the list of nodes as Identifiers use references to original
                // nodes so we have to keep them intact.
                nodes_list.clone(),
                &id_arrays_list,
                &node_stats,
                &mut common_nodes,
            )?;
            assert!(!common_nodes.is_empty());
            Ok(FoundCommonNodes::Yes {
                common_nodes: common_nodes.into_values().collect(),
                new_nodes_list,
                original_nodes_list: nodes_list,
            })
        } else {
            Ok(FoundCommonNodes::No {
                original_nodes_list: nodes_list,
            })
        }
    }
}
// Tests for the first, visiting traversal of CSE on `TestTreeNode<String>` trees.
#[cfg(test)]
mod test {
    use crate::alias::AliasGenerator;
    use crate::cse::{
        CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
        Normalizeable, CSE,
    };
    use crate::tree_node::tests::TestTreeNode;
    use crate::Result;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};

    // Prefix of the aliases generated for extracted common nodes.
    const CSE_PREFIX: &str = "__common_node";

    /// Selects which nodes `TestTreeNodeCSEController` ignores.
    #[derive(Clone, Copy)]
    pub enum TestTreeNodeMask {
        /// Ignore leaves and the `avg` / `sum` aggregates.
        Normal,
        /// Ignore only leaves, so aggregates are considered too.
        NormalAndAggregates,
    }

    /// A `CSEController` for `TestTreeNode<String>` used by the tests below.
    pub struct TestTreeNodeCSEController<'a> {
        alias_generator: &'a AliasGenerator,
        mask: TestTreeNodeMask,
    }

    impl<'a> TestTreeNodeCSEController<'a> {
        fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
            Self {
                alias_generator,
                mask,
            }
        }
    }

    impl CSEController for TestTreeNodeCSEController<'_> {
        type Node = TestTreeNode<String>;

        // No test node short-circuits, so all children are always evaluated.
        fn conditional_children(
            _: &Self::Node,
        ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
            None
        }

        // Every test node is considered valid.
        fn is_valid(_node: &Self::Node) -> bool {
            true
        }

        // Leaves are always ignored; `avg` / `sum` are ignored only with the `Normal` mask.
        fn is_ignored(&self, node: &Self::Node) -> bool {
            let is_leaf = node.is_leaf();
            let is_aggr = node.data == "avg" || node.data == "sum";
            match self.mask {
                TestTreeNodeMask::Normal => is_leaf || is_aggr,
                TestTreeNodeMask::NormalAndAggregates => is_leaf,
            }
        }

        fn generate_alias(&self) -> String {
            self.alias_generator.next(CSE_PREFIX)
        }

        // Replaces a node with a leaf labelled `alias(<data>, <alias>)`.
        fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
            TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
        }
    }

    // Only the node's own label is hashed; the children are combined in by the visitor.
    impl HashNode for TestTreeNode<String> {
        fn hash_node<H: Hasher>(&self, state: &mut H) {
            self.data.hash(state);
        }
    }

    // Test nodes are never normalized...
    impl Normalizeable for TestTreeNode<String> {
        fn can_normalize(&self) -> bool {
            false
        }
    }

    // ...so normalized equality is plain equality.
    impl NormalizeEq for TestTreeNode<String> {
        fn normalize_eq(&self, other: &Self) -> bool {
            self == other
        }
    }

    #[test]
    fn id_array_visitor() -> Result<()> {
        let alias_generator = AliasGenerator::new();
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::Normal,
        ));
        // The tree under test is `(sum(a + 1) - avg(c)) * 2`.
        let a_plus_1 = TestTreeNode::new(
            vec![
                TestTreeNode::new_leaf("a".to_string()),
                TestTreeNode::new_leaf("1".to_string()),
            ],
            "+".to_string(),
        );
        let avg_c = TestTreeNode::new(
            vec![TestTreeNode::new_leaf("c".to_string())],
            "avg".to_string(),
        );
        let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string());
        let sum_a_plus_1_minus_avg_c =
            TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string());
        let root = TestTreeNode::new(
            vec![
                sum_a_plus_1_minus_avg_c,
                TestTreeNode::new_leaf("2".to_string()),
            ],
            "*".to_string(),
        );
        // Borrow the subtrees back out of `root`: identifiers hold references into the
        // tree that was visited, so the expected values must point to the same nodes.
        let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [a_plus_1] = sum_a_plus_1.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        // skip aggregates
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
        // Collect distinct hashes and set them to 0 in `id_array`, so that it can be
        // compared against expected identifiers whose hashes are not known in advance.
        fn collect_hashes(
            id_array: &mut IdArray<'_, TestTreeNode<String>>,
        ) -> HashSet<u64> {
            id_array
                .iter_mut()
                .flat_map(|(_, id_option)| {
                    id_option.as_mut().map(|node_id| {
                        let hash = node_id.hash;
                        node_id.hash = 0;
                        hash
                    })
                })
                .collect::<HashSet<_>>()
        }
        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 3);
        // Entries are listed in preorder: `*`, `-`, `sum`, `+`, `a`, `1`, `avg`, `c`, `2`;
        // the first tuple field is each node's postorder index.
        let expected = vec![
            (
                8,
                Some(Identifier {
                    hash: 0,
                    node: &root,
                }),
            ),
            (
                6,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1_minus_avg_c,
                }),
            ),
            (3, None),
            (
                2,
                Some(Identifier {
                    hash: 0,
                    node: a_plus_1,
                }),
            ),
            (0, None),
            (1, None),
            (5, None),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);
        // include aggregates
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::NormalAndAggregates,
        ));
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 5);
        // Same preorder layout as above, but `sum` and `avg` now get identifiers too.
        let expected = vec![
            (
                8,
                Some(Identifier {
                    hash: 0,
                    node: &root,
                }),
            ),
            (
                6,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1_minus_avg_c,
                }),
            ),
            (
                3,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1,
                }),
            ),
            (
                2,
                Some(Identifier {
                    hash: 0,
                    node: a_plus_1,
                }),
            ),
            (0, None),
            (1, None),
            (
                5,
                Some(Identifier {
                    hash: 0,
                    node: avg_c,
                }),
            ),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);
        Ok(())
    }
}