blob: 5c14caf1a6e36c008502d7f33aa967ade24217fa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file common_subexpr_elim.h
* \brief Interface of the Common Subexpression Elimination (CSE) pass, which rewrites statements
and expressions in order to eliminate redundant computations. To achieve that, common
(sub-)expressions are bound to fresh variables with let-in bindings, and every place where
such an expression was used is replaced with the freshly introduced variable.
*/
#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_
#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h> // For the class StmtExprMutator
#include <tvm/tir/var.h>
#include <utility> // For std::pair
#include <vector>
#include "common_subexpr_elim_tools.h" // For the class MaybeValue
namespace tvm {
namespace tir {
/*!
* \brief A context is a vector of (Var, MaybeValue) pairs, associating each variable with a
MaybeValue (which is either an expression or nothing).
*/
using Context = std::vector<std::pair<Var, MaybeValue>>;
/*!
 * \brief Mutator implementing Common Subexpression Elimination (CSE) on the body of a PrimFunc.
 *        It rewrites both the expressions and the statements of that body.
 *
 * The constructor is protected, so client code runs the pass through the static entry point
 * PerformCSE() rather than by building an instance itself.
 */
class CommonSubexpressionEliminator : public StmtExprMutator {
 public:
  /*!
   * \brief Toplevel (static) function of the pass.
   * \param stmt The statement on which CSE is performed
   * \param context_init The initial context (variables paired with their optional definition)
   * \param identify_equiv_terms Flag forwarded to the eliminator. NOTE(review): presumably it
   *        selects whether equivalent (and not only identical) terms are treated as common;
   *        confirm against the implementation.
   * \return The statement resulting from the pass
   */
  static Stmt PerformCSE(const Stmt& stmt, const Context& context_init,
                         bool identify_equiv_terms);

  /*! \brief Override of the expression visiting entry point of the base mutator. */
  PrimExpr VisitExpr(const PrimExpr& expr) override;

  /*! \brief Override of the statement visiting entry point of the base mutator. */
  Stmt VisitStmt(const Stmt& stmt) override;

  /*!
   * \brief Accessor for the count of generated variables.
   *        NOTE(review): presumably returns nb_var_; confirm against the implementation.
   */
  int GetNbVarGenerated();

 protected:
  /*!
   * \brief Constructor. Takes the same arguments as PerformCSE().
   * \param stmt The statement the eliminator is built for
   * \param context_init The initial context
   * \param identify_equiv_terms See PerformCSE()
   */
  CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
                                bool identify_equiv_terms);

  /*! \brief Specialised handling of let-in expressions. */
  PrimExpr VisitExpr_(const LetNode* op) override;

  /*! \brief Specialised handling of let-in statements. */
  Stmt VisitStmt_(const LetStmtNode* op) override;

  /*! \brief Specialised handling of for loops. */
  Stmt VisitStmt_(const ForNode* op) override;

 private:
  /*! \brief Initial body, kept for checking if names of new variables already exist. */
  Stmt initial_body_;
  /*! \brief Context associating variables to (maybe) definitions. */
  Context context_;
  /*! \brief Number of the last variable tried. */
  int num_last_try_{0};
  /*! \brief Number of variables introduced by the CSE pass. */
  int nb_var_{0};
  /*! \brief Value of the identify_equiv_terms flag (false unless set otherwise). */
  bool identify_equiv_terms_{false};

  // Static predicates over expressions. Their exact criteria live in the implementation file;
  // the names suggest they classify which computations the pass may or may not factor out.
  static bool ForbiddenComputation(const PrimExpr& expr);
  static bool IsEligibleComputation(const PrimExpr& expr);
  static bool CanContainEligibleComputations(const PrimExpr& expr);

  /*!
   * \brief Binary predicate over (expression, count) pairs.
   *        NOTE(review): presumably used as a sorting order; confirm against the implementation.
   */
  static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
                                      std::pair<PrimExpr, size_t> b);

  /*!
   * \brief Produces a new variable of the given type.
   * \param type_annotation The data type of the variable to create
   * \return The newly generated variable
   */
  Var GenerateNewVar(DataType type_annotation);
};
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_