// blob: 447d85370ca86bbd8cafc2880f03334a5d6a436a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common_subexpr_elim.cc
* \brief Implementation of the Common Subexpressions Elimination (CSE) pass
which rewrites statements and expressions in order to eliminate
redundant computations. In order to achieve that, common (sub-)
expressions are introduced into variables with let-in bindings,
and the places where the expression was used are replaced with
the freshly introduced variable.
*/
#include "common_subexpr_elim.h"
#include <tvm/ir/transform.h> // For the class Pass and the class PassContext
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/tir/analysis.h> // For the analysis which gives the size of an expr
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h> // For the class PrimFunc
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> // For the decl of the function returning the pass
#include <algorithm> // For the algorithm std::find
#include <iostream>
#include <string>
#include <unordered_map> // For the hashtable datatype
#include <utility> // For std::pair and std::move
#include <vector>
#include "../analysis/check_contains.h" // For the visitor CheckContains
#include "common_subexpr_elim_tools.h" // For the auxiliary analysis (visitors) and tools
#include "replace_selected_expr.h" // For the mutator ReplaceSelectedExpr
namespace tvm {
namespace tir {
/*!
 * \brief Tell whether the root node of an expression is a computation that the CSE pass
 *        must never introduce into a variable.
 *        Forbidden computations are excluded twice over: they are not collected themselves,
 *        and expressions that contain one are not collected either (that second check is
 *        done by IsEligibleComputation()). Reusing such computations would change the
 *        semantics of the program.
 * \param expr The expression to check
 * \return true if `expr` is a function call, a load or a buffer load, false otherwise
 */
bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
  // Function calls are forbidden
  if (expr.as<CallNode>() != nullptr) {
    return true;
  }
  // and so are loads
  if (expr.as<LoadNode>() != nullptr) {
    return true;
  }
  // and buffer loads. Anything else is allowed at this level.
  return expr.as<BufferLoadNode>() != nullptr;
}
/*!
 * \brief Predicate deciding whether a computation may be treated by the CSE pass, i.e.
 *        whether it may be introduced into a variable / replaced by a variable.
 *        An expression is eligible when none of the following holds:
 *          - it is a constant (integer, float or string immediate),
 *          - it is a variable,
 *          - it is a forbidden computation (see ForbiddenComputation()),
 *          - it contains a forbidden computation,
 *          - it is a ramp node or a broadcast node.
 * \param expr The expression to check
 * \return Whether `expr` is an eligible computation or not
 */
bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
  // Constants are atoms: there is nothing to gain by putting them into a variable
  const bool is_constant = (expr.as<IntImmNode>() != nullptr) ||
                           (expr.as<FloatImmNode>() != nullptr) ||
                           (expr.as<StringImmNode>() != nullptr);
  if (is_constant) {
    return false;
  }

  // Variables are atoms too
  if (expr.as<VarNode>() != nullptr) {
    return false;
  }

  // Forbidden computations (function calls and loads) are rejected
  if (ForbiddenComputation(expr)) {
    return false;
  }

  // Expressions that merely contain a forbidden computation are rejected as well: we don't
  // want to register expressions like (x + f(y)) or (x + Mem[i]), as introducing them into
  // variables could change the semantics
  if (CheckContains::ExprContains(expr, ForbiddenComputation)) {
    return false;
  }

  // Ramp and broadcast nodes are rejected due to some internal TVM constraints (which check
  // for these nodes explicitly without performing any evaluation first, so if they have been
  // put into variables it fails)
  const bool is_ramp_or_broadcast =
      (expr.as<RampNode>() != nullptr) || (expr.as<BroadcastNode>() != nullptr);
  return !is_ramp_or_broadcast;
}
/*!
 * \brief Predicate telling (while looking for eligible computations) whether it is allowed
 *        to look inside a given expression. This is the customization point to use if
 *        rewriting inside some specific kind of node (a Load node for instance) should be
 *        forbidden.
 * \param expr The expression to check (currently ignored, as every expression is accepted)
 * \return Whether `expr` can contain some eligible computations or not, and therefore
 *         whether recursing inside `expr` is necessary. Currently always true.
 */
bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
  // To stop collecting and replacing eligible computations inside the index of Load nodes,
  // replace the return statement below with:
  //   return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
  // This restriction was initially considered in order to not harm the indexing mode of the
  // CPU, but as we are still far from ASM code at this stage, such simplifications are
  // wanted, and they tend to happen fairly frequently.
  return true;
}
/*!
 * \brief Ordering on pairs (expression, frequency) used for sorting candidate computations.
 *        Pairs are compared on the complexity of their expression first (the more complex
 *        one comes first). When both have the same complexity, the printed forms of the
 *        two expressions are compared lexicographically. The frequency is not used.
 * \param a The first pair
 * \param b The second pair
 * \return A boolean telling if the first pair `a` comes before the second pair `b`
 * \note This order has to be deterministic in order to have a fully deterministic pass:
 *       the elements being sorted come from a hashtable, and the order in which they
 *       appeared in the hashtable was based on some runtime addresses, so it can
 *       potentially change with every execution.
 */
bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
                                                            std::pair<PrimExpr, size_t> b) {
  const size_t complexity_a = CalculateExprComplexity(a.first);
  const size_t complexity_b = CalculateExprComplexity(b.first);

  // Criterion 1 - the bigger expression comes first
  if (complexity_a != complexity_b) {
    return complexity_a > complexity_b;
  }

  // Criterion 2 - same size: fall back to the lexicographic order of the printed
  // expressions, as a last resort for getting a deterministic order
  std::stringstream printed_a;
  printed_a << a.first;
  std::stringstream printed_b;
  printed_b << b.first;
  return printed_a.str() < printed_b.str();
}
/*!
 * \brief Generates a new fresh variable, whose name will be cse_var_i.
 * \param type_annotation The type of the new variable to generate
 * \return A new variable of type `type_annotation` called cse_var_i, where i is the first
 *         value reached by the counter `num_last_try_` for which no variable of that name
 *         is used in the initial body.
 * \note Both `num_last_try_` (incremented on every attempt) and `nb_var_` (incremented once
 *       per generated variable) are updated.
 */
Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
  // Keep trying successive indices until a name that is not already in use is found
  while (true) {
    // Count this new attempt
    ++num_last_try_;

    // Candidate name: cse_var_i, with i the current attempt number
    const std::string candidate = "cse_var_" + std::to_string(num_last_try_);
    String candidate_name(candidate);

    // Names don't really have to be unique as they are just hints (two variables with the
    // same name are still different variables), but avoiding clashes makes dumps clearer
    if (UsesVarName::StmtUsesVarName(initial_body_, candidate_name)) {
      // Already used: move on to the next index
      continue;
    }

    // Count the variable that is about to be generated, and build it
    ++nb_var_;
    return Var(candidate_name, type_annotation);
  }
}
/*!
 * \brief Getter for `nb_var_`, the counter that GenerateNewVar() increments each time it
 *        generates a variable.
 * \return A copy of `nb_var_`
 */
int CommonSubexpressionEliminator::GetNbVarGenerated() {
  const int generated_so_far = nb_var_;
  return generated_so_far;
}
/*!
 * \brief Toplevel (static) method that performs Common Subexpression Elimination on
 *        a given statement (which should be the body of a PrimFunc). This method should be
 *        called for each PrimFunc definition.
 * \param stmt The statement of the function being analyzed, on which we want to perform CSE
 * \param context_init The initial context, which should contain the formal parameters
 *        of the function being analyzed
 * \param identify_equiv_terms Flag handed over to the eliminator, which passes it on to the
 *        equivalence checks between computations
 * \return A new statement where CSE has been performed
 */
Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
                                               bool identify_equiv_terms) {
  // One fresh eliminator per call, hence one per PrimFunc definition, so that no state is
  // carried over from a function to the next one
  CommonSubexpressionEliminator eliminator(stmt, context_init, identify_equiv_terms);
  Stmt stmt_after_cse = eliminator.VisitStmt(stmt);
  return stmt_after_cse;
}
/*!
 * \brief Protected constructor of CommonSubexpressionEliminator.
 * \param stmt The statement on which the pass is going to run. It is kept as the initial
 *        body, which GenerateNewVar() consults for avoiding names that are already used
 * \param context_init The context at the beginning of the CSE pass. It should contain the
 *        formal parameters of the function that will be analyzed
 * \param identify_equiv_terms Flag stored in the eliminator and later passed on to the
 *        equivalence checks between computations
 */
CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
                                                             const Context& context_init,
                                                             bool identify_equiv_terms)
    : initial_body_(stmt),
      context_(context_init),
      identify_equiv_terms_(identify_equiv_terms) {
  // Nothing else to do: the whole state is set up by the initializer list
}
/*!
 * \brief The method which overrides the generic dispatcher of StmtExprMutator.
 *        Entry point to the common subexpression elimination mutator for expressions.
 *        It collects the eligible computations done by `expr`, and considers them from the
 *        biggest to the smallest. Each one is either replaced by a variable of the context
 *        that already holds an equivalent value, or introduced into a fresh variable with a
 *        let-in, or (when neither is possible) broken down into its direct subexpressions,
 *        which then become candidates themselves.
 *        If at least one variable has been introduced, the method is run again on the
 *        result. Otherwise the children are visited through the dispatcher of the base
 *        class.
 * \param expr The expression to mutate
 * \return The expression obtained after performing CSE on `expr`
 */
PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
  bool variables_created = false;  // Will be needed for knowing if the CSE has created new vars
  PrimExpr result = expr;

  // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
  // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
  // number of time this exact syntactic computation is being computed.
  ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
      expr, IsEligibleComputation, CanContainEligibleComputations);

  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
  // containing *semantic* entities, i.e. where equivalent computations are merged.
  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);

  // Sort the vector of semantic entities by decreasing size
  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
            OrderOnExprAndFrequency);

  // For each computation done (considering them from biggest to smallest).
  // Note : the vector can grow during the iteration, hence the index-based loop.
  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];

    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"

    // The predicate later used (when doing replacements) to select expressions that are
    // equivalent to the current computation (`computation_and_nb.first`)
    std::function<bool(const PrimExpr&)> predicate_selector =
        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
          // that `current_expr` is an eligible computation even if we know that
          // `computation_and_nb.first` is eligible by construction, in case that one day the
          // equivalence relation would not preserve the eligibility any more (even though that
          // would probably be a very weird equivalence).
          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                  IsEligibleComputation(current_expr));
        };

    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
    // equivalent to `computation_and_nb.first`
    auto it_on_var = std::find_if(
        context_.begin(), context_.end(),
        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
          // Note : safe to call value() as we check has_value() just before
          return (var_and_value.second.has_value() &&
                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
                                  ident_equiv_terms));
        });

    // Case where we have a perfectly equivalent computation already available in a variable
    // introduced (i.e, present in context_).
    // Note that this case is needed when the user has written something like
    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
    // an already existing variable holding A, when such a variable happens to exist.
    if (it_on_var != context_.end()) {
      // Replace in the current `result` everything that is selected by the selector with
      // the existing variable, without diving into expressions in which we don't have the
      // right to dive.
      result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
    } else {
      // The current computation is not equivalent to a computation already done. We will
      // need to see if we want to introduce it.

      // --- Chunk needed for reusing the UndefinedVars() analysis ---
      // 1 - Wraps the computation into a statement
      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
      // 2.1 - Transform the context into a vector of variables instead of pairs
      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
      // 2.2 - Transform the std::vector into an Array
      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
      // --- End of chunk needed for reusing the UndefinedVars() analysis ---

      // We use the UndefinedVars() analysis to get the undefined vars of the computation
      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);

      // Check if we can introduce it : if it contains no undefined variables and if we want
      // to introduce it according to the predicate
      if (vars_undefined.empty() &&
          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
        // Create a new variable for this computation
        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
        // Record that a variable has been introduced, so that the pass is run again on
        // `result` once the loop is over. Without this, the re-run branch below could never
        // be taken.
        variables_created = true;
        // Replace in the current `result` everything that is selected by the selector with
        // the new variable, without diving into expressions in which we don't have the
        // right to dive.
        result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var,
                                                                CanContainEligibleComputations);
        // Build a let-in that introduces the new variable in the current `result`
        result = Let(new_var, computation_and_nb.first, result);
        // We don't add the variable to the context because the invariant is that the
        // context is the context in which 'result' makes sense, and we've just updated it.
      } else {
        // Here it's not doable to introduce (via a let-in) the computation at this level
        // as it contains variables that are not yet declared, and/or because the predicate
        // did not select it.
        // Either way, we will simply add to the vector of computations the direct subexprs
        // of the current computation, as these ones might be good candidates
        // for being introduced into variables.
        // Note that we don't need to add all of its subexpressions, but only its *direct*
        // subexpressions as we consider them from biggest to smallest, and if they were
        // all added at once, then there could be dependencies between them, as commoning
        // one of them could remove some other possibilities.

        // Computing the direct subexpressions will return a small number of direct
        // subexpressions (typically 0 to 3)
        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
        // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
        // decreasing size/complexity), and it will only insert at locations > i as the
        // direct subexprs are necessarily smaller than the current computation.
        // `computation_and_nb` must not be used after this call, as it refers to an element
        // of the vector being modified.
        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
                                                 identify_equiv_terms_);
      }
    }
    // Note : we do not remove the current element, as we never look back in the local vector
  }  // End of for loop

  // If the CSE pass has created some variables, then we run it again as more commoning could
  // potentially happen using the new variables introduced
  if (variables_created) {
    result = VisitExpr(result);
  } else {
    // But if no changes were performed, we recurse inside the children by calling the dispatcher.
    // Calling the dispatcher to the specific treatments, which will update the context
    // appropriately before doing the recursive calls on the children nodes
    result = StmtExprMutator::VisitExpr(result);
  }

  return result;
}
/*!
 * \brief The method which overrides the specific treatment for a LetNode.
 *        The `value` is visited in the current context, and the `body` is visited in a
 *        context extended with the variable bound by the let-in. The context is put back
 *        to its initial content before returning.
 * \param op The let-in node to mutate
 * \return The same node if neither the `value` nor the `body` has been rewritten, and a
 *         new let-in built from the rewritten parts otherwise
 */
PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
  // The generic treatment (VisitExpr()) has already introduced, via let-in, everything that
  // could be introduced at the toplevel of this let-in.

  // Remember the context as it is when entering the node
  const Context saved_context = context_;

  // The `value` is rewritten first, in the context of the let-in itself
  PrimExpr rewritten_value = VisitExpr(op->value);

  // The `body` sees one more binding: (`var`, `value`). Having `var` in the context is what
  // allows new simplifications inside the body, as VisitExpr() performs no introduction for
  // computations that use an undefined variable (that would lead to ill-formed code).
  context_.push_back({op->var, MaybeValue(op->value)});
  PrimExpr rewritten_body = VisitExpr(op->body);

  // Restore the context: the variable introduced by the let-in is not in scope outside of
  // its body, so its declaration must not be carried out of it
  context_ = saved_context;

  // A new node is only built when something has actually been rewritten
  const bool value_changed = !rewritten_value.same_as(op->value);
  const bool body_changed = !rewritten_body.same_as(op->body);
  if (value_changed || body_changed) {
    return Let(op->var, rewritten_value, rewritten_body, op->span);
  }
  // Nothing changed: return a reference to the same node
  return GetRef<PrimExpr>(op);
}
/*!
 * \brief The method which overrides the generic dispatcher of StmtExprMutator.
 *        Entry point to the common subexpression elimination mutator for statements.
 *        It collects the eligible computations done by `stmt`, and considers them from the
 *        biggest to the smallest. Each one is either replaced by a variable of the context
 *        that already holds an equivalent value, or introduced into a fresh variable with a
 *        let statement, or (when neither is possible) broken down into its direct
 *        subexpressions, which then become candidates themselves.
 *        If at least one variable has been introduced, the method is run again on the
 *        result. Otherwise the children are visited through the dispatcher of the base
 *        class.
 * \param stmt The statement to mutate.
 * \return The statement obtained after performing CSE on `stmt`
 */
Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
  // Set as soon as a let statement has been introduced at this level
  bool introduced_new_var = false;
  Stmt rewritten = stmt;

  // Syntactic view: mapping from each eligible computation done by `stmt` to the number of
  // times this exact syntactic computation is being computed
  ComputationTable syntactic_table = ComputationsDoneBy::GetComputationsDoneBy(
      stmt, IsEligibleComputation, CanContainEligibleComputations);

  // Semantic view: vector of pairs in which equivalent computations have been merged,
  // sorted by decreasing size
  std::vector<std::pair<PrimExpr, size_t>> candidates =
      SyntacticToSemanticComputations(syntactic_table, identify_equiv_terms_);
  std::sort(candidates.begin(), candidates.end(), OrderOnExprAndFrequency);

  // Local copy of the flag, so that the lambdas below don't need to capture "this"
  const bool equiv_mode = identify_equiv_terms_;

  // Candidates are considered from the biggest to the smallest. The vector can grow while
  // it is being walked (see the end of the loop body), hence the index-based loop.
  for (size_t idx = 0; idx < candidates.size(); ++idx) {
    // Work on copies of the current entry, as the vector itself may be modified below
    const PrimExpr computation = candidates[idx].first;
    const size_t nb_occurrences = candidates[idx].second;

    // Selector used when doing replacements: picks the expressions that are equivalent to
    // the current computation. The eligibility is checked too, even though `computation` is
    // eligible by construction, in case that one day the equivalence relation would not
    // preserve the eligibility any more (that would probably be a very weird equivalence).
    std::function<bool(const PrimExpr&)> selector = [computation,
                                                     equiv_mode](const PrimExpr& candidate) {
      return (EquivalentTerms(candidate, computation, equiv_mode) &&
              IsEligibleComputation(candidate));
    };

    // Look in the context for a pair (`var`, `value`) whose `value` is semantically
    // equivalent to the current computation
    auto known_binding = std::find_if(
        context_.begin(), context_.end(),
        [computation, equiv_mode](const std::pair<Var, MaybeValue>& binding) {
          // value() is only reached when has_value() holds
          return (binding.second.has_value() &&
                  EquivalentTerms(binding.second.value(), computation, equiv_mode));
        });

    // Case 1: an equivalent computation is already available in a variable of the context.
    // This is needed when the user has written something like [let x = A in ....A...A...] :
    // all the occurrences of A have to be replaced by the existing variable holding A.
    if (known_binding != context_.end()) {
      // Replace what the selector picks with the existing variable, without diving into
      // expressions in which we don't have the right to dive
      rewritten = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
          rewritten, selector, known_binding->first, CanContainEligibleComputations);
      continue;
    }

    // From here on, the computation is not equivalent to one already held by a variable, and
    // the question is whether it should be introduced.

    // The UndefinedVars() analysis works on a statement and on an Array of known variables,
    // so the computation gets wrapped into a statement, and the variables of the context
    // (without their values) are gathered into an Array
    Stmt wrapped_computation = Evaluate(computation);
    std::vector<Var> known_vars;
    for (const std::pair<Var, MaybeValue>& binding : context_) {
      known_vars.push_back(binding.first);
    }
    Array<Var> known_vars_array = Array<Var>(known_vars);
    Array<Var> undefined_vars = UndefinedVars(wrapped_computation, known_vars_array);

    // The introduction is possible if the computation uses no undefined variable, and
    // wanted if the predicate selects it
    const bool should_introduce =
        undefined_vars.empty() && PredicateIntroVarForComputation(computation, nb_occurrences);

    if (!should_introduce) {
      // Case 2: the computation can't be introduced (via a let-in) at this level, because it
      // contains variables that are not yet declared, and/or because the predicate did not
      // select it. Its direct subexpressions might still be good candidates, so they get
      // added to the vector.
      // Only the *direct* subexpressions are added, and not all the subexpressions: as
      // candidates are considered from biggest to smallest, adding them all at once could
      // create dependencies between them, as commoning one of them could remove some other
      // possibilities.
      // There is typically a small number of direct subexpressions (0 to 3).
      std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
          computation, IsEligibleComputation, CanContainEligibleComputations);
      // The insertion keeps the vector sorted (by decreasing size/complexity), and it only
      // inserts at locations > idx, as the direct subexprs are necessarily smaller than the
      // current computation
      InsertVectorToSortedSemanticComputations(&candidates, direct_subexprs,
                                               identify_equiv_terms_);
      continue;
    }

    // Case 3: introduce the computation into a fresh variable
    Var fresh_var = GenerateNewVar(computation.dtype());
    introduced_new_var = true;
    // Replace what the selector picks with the new variable, without diving into
    // expressions in which we don't have the right to dive
    rewritten = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(rewritten, selector, fresh_var,
                                                               CanContainEligibleComputations);
    // Then bind the new variable around the statement obtained.
    // The variable is not added to the context, because the invariant is that the context
    // is the one in which `rewritten` makes sense, and `rewritten` has just been updated.
    rewritten = LetStmt(fresh_var, computation, rewritten);

    // Note : the current element is not removed, as we never look back in the local vector
  }

  if (introduced_new_var) {
    // Some variables have been created: run the pass again, as more commoning could
    // potentially happen using the new variables introduced
    return VisitStmt(rewritten);
  }
  // No variable has been created: recurse inside the children by calling the dispatcher,
  // whose specific treatments update the context appropriately before doing the recursive
  // calls on the children nodes
  return StmtExprMutator::VisitStmt(rewritten);
}
/*!
 * \brief The method which overrides the specific treatment for a LetStmtNode.
 *        The `value` is visited in the current context, and the `body` is visited in a
 *        context extended with the variable bound by the let statement. The context is put
 *        back to its initial content before returning.
 * \param op The let statement node to mutate
 * \return The same node if neither the `value` nor the `body` has been rewritten, and a
 *         new let statement built from the rewritten parts otherwise
 */
Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) {
  // The generic treatment (VisitStmt()) has already introduced, via let-in, everything that
  // could be introduced at the toplevel of this let-in.

  // Remember the context as it is when entering the node
  const Context saved_context = context_;

  // The `value` is rewritten first, in the context of the let statement itself
  PrimExpr rewritten_value = VisitExpr(op->value);

  // The `body` sees one more binding: (`var`, `value`). Having `var` in the context is what
  // allows new simplifications inside the body, as VisitStmt() performs no introduction for
  // computations that use an undefined variable (that would lead to ill-formed code).
  context_.push_back({op->var, MaybeValue(op->value)});
  Stmt rewritten_body = VisitStmt(op->body);

  // Restore the context: the variable introduced by the let statement is not in scope
  // outside of its body, so its declaration must not be carried out of it
  context_ = saved_context;

  // A new node is only built when something has actually been rewritten
  const bool value_changed = !rewritten_value.same_as(op->value);
  const bool body_changed = !rewritten_body.same_as(op->body);
  if (value_changed || body_changed) {
    return LetStmt(op->var, rewritten_value, rewritten_body, op->span);
  }
  // Nothing changed: return a reference to the same node
  return GetRef<Stmt>(op);
}
/*!
 * \brief The method which overrides the specific treatment for a ForNode.
 *        The `min` and the `extent` are visited in the current context, and the `body` is
 *        visited in a context extended with the loop variable (associated with no value).
 *        The context is put back to its initial content before returning.
 * \param op The for loop node to mutate
 * \return The same node if none of `min`, `extent` and `body` has been rewritten, and a new
 *         for loop built from the rewritten parts otherwise
 */
Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) {
  // The generic treatment (VisitStmt()) has already introduced, via let-in, everything that
  // could be introduced at the toplevel of this for loop.

  // Remember the context as it is when entering the node
  const Context saved_context = context_;

  // The bounds are rewritten first, before the loop variable gets added to the context
  PrimExpr rewritten_min = VisitExpr(op->min);
  PrimExpr rewritten_extent = VisitExpr(op->extent);

  // The `body` sees the loop variable, which is associated with no value as its value will
  // change during the execution of the loop
  context_.push_back({op->loop_var, MaybeValue()});
  Stmt rewritten_body = VisitStmt(op->body);

  // Restore the context: the variable introduced by the for loop is not in scope outside of
  // its body, so its declaration must not be carried out of it
  context_ = saved_context;

  // A new node is only built when something has actually been rewritten
  const bool nothing_changed = rewritten_min.same_as(op->min) &&
                               rewritten_extent.same_as(op->extent) &&
                               rewritten_body.same_as(op->body);
  if (nothing_changed) {
    // Return a reference to the same node
    return GetRef<Stmt>(op);
  }
  // Every other field of the loop is taken as is from the original node
  return For(op->loop_var, rewritten_min, rewritten_extent, op->kind, rewritten_body,
             op->thread_binding, op->annotations, op->span);
}
namespace transform {
/*!
 * \brief The function which returns the pass for the Common Subexpression Elimination.
 * \param enable_cse_tir Whether the pass should do anything at all: when false, the pass
 *        returns each function unchanged
 * \param identify_equiv_terms Flag handed over to CommonSubexpressionEliminator::PerformCSE()
 * \return The pass for performing CSE.
 */
Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
  auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
    // Pass disabled: hand the function back untouched
    if (!enable_cse_tir) {
      return f;
    }

    auto* func_node = f.CopyOnWrite();

    // The initial context is made of all the parameters of the function, each one being a
    // variable associated with no value. They are needed for doing commoning on terms that
    // use these parameters, as a term can only be introduced into a new variable at a given
    // point of the program if all the variables that it uses have already been declared at
    // this point.
    Context initial_context;
    for (const Var& param : f->params) {
      initial_context.push_back({param, MaybeValue()});
    }

    // Perform the Common Subexpression Elimination on the body of the function, starting
    // from the context that has just been prepared
    func_node->body =
        CommonSubexpressionEliminator::PerformCSE(f->body, initial_context, identify_equiv_terms);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.CommonSubexprElimTIR", {});
}
// The pass can already be invoked via the pass infrastructure. In addition, the function that
// builds it is registered as a global under the name "tir.transform.CommonSubexprElimTIR",
// which serves as the Python binding for it.
TVM_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR);
} // namespace transform
} // namespace tir
} // namespace tvm