blob: e541970a2717895e53692f1e3dd291029f459e9c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rewrite_simplify.h
* \brief Rewrite-rule based simplification.
*/
#ifndef TVM_ARITH_REWRITE_SIMPLIFY_H_
#define TVM_ARITH_REWRITE_SIMPLIFY_H_
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/op.h>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "const_fold.h"
#include "ir_mutator_with_analyzer.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace tir;
/* \brief Usage counters for RewriteSimplifier
*
* These are intended for debug and testing purposes, to ensure that
* PrimExpr simplifications and TIR passes do not require an excessive
*/
struct RewriteSimplifierStatsNode : Object {
int64_t nodes_visited{0};
int64_t constraints_entered{0};
int64_t rewrites_attempted{0};
int64_t rewrites_performed{0};
int64_t max_recursive_depth{0};
int64_t num_recursive_rewrites{0};
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RewriteSimplifierStatsNode>()
.def_ro("nodes_visited", &RewriteSimplifierStatsNode::nodes_visited)
.def_ro("constraints_entered", &RewriteSimplifierStatsNode::constraints_entered)
.def_ro("rewrites_attempted", &RewriteSimplifierStatsNode::rewrites_attempted)
.def_ro("rewrites_performed", &RewriteSimplifierStatsNode::rewrites_performed)
.def_ro("max_recursive_depth", &RewriteSimplifierStatsNode::max_recursive_depth)
.def_ro("num_recursive_rewrites", &RewriteSimplifierStatsNode::num_recursive_rewrites);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.RewriteSimplifierStats", RewriteSimplifierStatsNode,
Object);
};
struct RewriteSimplifierStats : ObjectRef {
explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) {
data_ = ffi::make_object<RewriteSimplifierStatsNode>(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RewriteSimplifierStats, ObjectRef,
RewriteSimplifierStatsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode);
};
/*!
* \brief Rewrite-based simplifier.
*
* This class can be inheritated for other simplifiers.
*/
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}
PrimExpr VisitExpr(const PrimExpr& e) override;
void Update(const Var& var, const PrimExpr& info, bool override_info);
PrimExpr VisitExpr_(const AddNode* op) override;
PrimExpr VisitExpr_(const SubNode* op) override;
PrimExpr VisitExpr_(const MulNode* op) override;
PrimExpr VisitExpr_(const DivNode* op) override;
PrimExpr VisitExpr_(const ModNode* op) override;
PrimExpr VisitExpr_(const FloorDivNode* op) override;
PrimExpr VisitExpr_(const FloorModNode* op) override;
PrimExpr VisitExpr_(const MinNode* op) override;
PrimExpr VisitExpr_(const MaxNode* op) override;
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
PrimExpr VisitExpr_(const LENode* op) override;
PrimExpr VisitExpr_(const GTNode* op) override;
PrimExpr VisitExpr_(const GENode* op) override;
PrimExpr VisitExpr_(const AndNode* op) override;
PrimExpr VisitExpr_(const OrNode* op) override;
PrimExpr VisitExpr_(const NotNode* op) override;
PrimExpr VisitExpr_(const SelectNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const CastNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
std::function<void()> EnterConstraint(const PrimExpr& constraint);
/*! \brief Enable an optional extension or extensions
*
* \param flags A bitwise OR of all optional extensions that should
* be enabled.
*/
void SetEnabledExtensions(Extension flags);
/*! \brief Return the currently enabled extensions */
Extension GetEnabledExtensions() const;
RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); }
void ResetStatsCounters() { stats_ = {}; }
void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; }
protected:
int64_t maximum_rewrite_steps_{0};
RewriteSimplifierStatsNode stats_;
void RecordAttemptedRewrite() { stats_.rewrites_attempted++; }
void RecordRewrite() {
stats_.rewrites_performed++;
ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_)
<< "RewriteSimplifier exceeded maximum number of rewrites allowed ("
<< maximum_rewrite_steps_ << ")";
}
// counter to record recursive rewrite depth.
int64_t recur_depth_{0};
// internal variable map
std::unordered_map<Var, PrimExpr> var_map_;
std::vector<PrimExpr> literal_constraints_;
// Optionally enabled extensions
Extension enabled_extensions_{kNone};
/*! Whether the simplifier is current
*/
bool recursively_visiting_boolean_{false};
// maximum number of recursion allowed during a single pass.
static const constexpr int64_t kMaxRecurDepth = 5;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
* \param val The constant value.
* \return comparison result.
*/
CompareResult TryCompare(const PrimExpr& x, int64_t val);
/*! Try to compare x against y
*
* \param x The lhs of the comparison
* \param y The rhs of the comparison
* \return comparison result.
*/
CompareResult TryCompare(const PrimExpr& x, const PrimExpr& y);
/*!
* \brief Internal function to check whether or not to inline let.
* \param op The let expr.
* \return The inline decision.
*/
bool CanInlineLet(const LetNode* op);
/*! \brief Internal function to apply constraints
*
* Tests whether the expression is known to be true or false based
* on existing constraints. If the expression or its negation
* matches a constraint, return the boolean it should be replaced
* with. Otherwise, return false.
*/
ffi::Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;
/*! \brief Rewrite rules for Less Than comparisons
*
* These are separate from the VisitExpr_(const LTNode*) method, as
* they may required from rewrites of LT or LE.
*/
PrimExpr ApplyRewriteRules(LT node);
/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LE or NE.
*/
PrimExpr ApplyRewriteRules(EQ node);
/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LT, LE, or NE.
*/
PrimExpr ApplyRewriteRules(Not node);
private:
CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y);
CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y);
CompareResult TryComparisonOfProductAndSum(const PrimExpr& x, const PrimExpr& y);
// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
return analyzer_->CanProveGreaterEqual(x, val);
}
// Whether x < val
bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); }
// Whether x == val
bool CanProveEqual(const PrimExpr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == CompareResult::kEQ;
}
// Whether x is true
bool CanProve(const PrimExpr& x) { return analyzer_->CanProve(x); }
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
PrimExpr RecursiveRewrite(const PrimExpr& x) {
stats_.num_recursive_rewrites++;
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth);
PrimExpr res = this->VisitExpr(x);
--recur_depth_;
return res;
}
template <typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}
template <typename TA>
PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 1);
}
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_REWRITE_SIMPLIFY_H_