blob: 258f833a7b21b7b44481eba4edc061bac6cac312 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rewrite_simplify.h
* \brief Rewrite-rule based simplification.
*/
#ifndef TVM_ARITH_REWRITE_SIMPLIFY_H_
#define TVM_ARITH_REWRITE_SIMPLIFY_H_
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <functional>
#include <unordered_map>
#include <vector>
#include "const_fold.h"
#include "ir_mutator_with_analyzer.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace tir;
/*!
 * \brief Rewrite-based simplifier.
 *
 * This class can be inherited by other simplifiers.
 */
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
 public:
  using IRMutatorWithAnalyzer::VisitExpr_;

  /*!
   * \brief Construct the simplifier.
   * \param parent The parent analyzer, forwarded to IRMutatorWithAnalyzer.
   */
  explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}

  /*!
   * \brief Record information about a variable.
   * \param var The variable being updated.
   * \param info The expression associated with \p var.
   * \param override_info Whether existing information may be replaced.
   * NOTE(review): exact semantics live in the .cc implementation; presumably
   * this populates var_map_ — confirm.
   */
  void Update(const Var& var, const PrimExpr& info, bool override_info);

  PrimExpr VisitExpr_(const AddNode* op) override;
  PrimExpr VisitExpr_(const SubNode* op) override;
  PrimExpr VisitExpr_(const MulNode* op) override;
  PrimExpr VisitExpr_(const DivNode* op) override;
  PrimExpr VisitExpr_(const ModNode* op) override;
  PrimExpr VisitExpr_(const FloorDivNode* op) override;
  PrimExpr VisitExpr_(const FloorModNode* op) override;
  PrimExpr VisitExpr_(const MinNode* op) override;
  PrimExpr VisitExpr_(const MaxNode* op) override;
  PrimExpr VisitExpr_(const EQNode* op) override;
  PrimExpr VisitExpr_(const NENode* op) override;
  PrimExpr VisitExpr_(const LTNode* op) override;
  PrimExpr VisitExpr_(const LENode* op) override;
  PrimExpr VisitExpr_(const GTNode* op) override;
  PrimExpr VisitExpr_(const GENode* op) override;
  PrimExpr VisitExpr_(const AndNode* op) override;
  PrimExpr VisitExpr_(const OrNode* op) override;
  PrimExpr VisitExpr_(const NotNode* op) override;
  PrimExpr VisitExpr_(const SelectNode* op) override;
  PrimExpr VisitExpr_(const CallNode* op) override;
  PrimExpr VisitExpr_(const VarNode* op) override;
  PrimExpr VisitExpr_(const CastNode* op) override;
  PrimExpr VisitExpr_(const LetNode* op) override;

  /*!
   * \brief Enter a scope in which \p constraint is assumed.
   * \param constraint The constraint expression.
   * \return A callback. NOTE(review): presumably invoking it leaves the
   *         constraint scope (e.g. pops literal_constraints_) — confirm
   *         against the implementation.
   */
  std::function<void()> EnterConstraint(const PrimExpr& constraint);

 protected:
  /*! \brief internal structure for comparison. */
  enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE };

  // Counter to record the current recursive rewrite depth.
  int recur_depth_{0};
  // Internal variable map, keyed by Var object identity.
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
  // Constraint expressions currently in effect.
  // NOTE(review): presumably maintained by EnterConstraint — confirm.
  std::vector<PrimExpr> literal_constraints_;
  // Maximum number of nested RecursiveRewrite calls allowed during a single pass.
  static constexpr int kMaxRecurDepth = 5;

  /*!
   * \brief try to compare x against val.
   * \param x The expression to be evaluated.
   * \param val The constant value.
   * \return comparison result.
   */
  CompareResult TryCompare(const PrimExpr& x, int64_t val);

  /*!
   * \brief Internal function to check whether or not to inline let.
   * \param op The let expr.
   * \return The inline decision.
   */
  bool CanInlineLet(const LetNode* op);

 private:
  // Whether x >= val, delegated to the parent analyzer.
  bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
    return analyzer_->CanProveGreaterEqual(x, val);
  }

  // Whether x == val, decided locally via TryCompare.
  bool CanProveEqual(const PrimExpr& x, int64_t val) {
    // TODO(tqchen) refer back to super-analyzer.
    return TryCompare(x, val) == kEQ;
  }

  // Recursively rewrite x.
  // The depth of nested recursive rewrites is capped at kMaxRecurDepth to
  // avoid infinite loops; past the cap, x is returned unchanged.
  PrimExpr RecursiveRewrite(const PrimExpr& x) {
    if (recur_depth_ >= kMaxRecurDepth) return x;
    // Scope guard: the depth counter is restored even if VisitExpr throws.
    // Without it, a caught exception would leave recur_depth_ elevated and
    // permanently shrink the recursion budget of this simplifier.
    struct DepthGuard {
      explicit DepthGuard(int& depth) : depth_(depth) { ++depth_; }
      ~DepthGuard() { --depth_; }
      DepthGuard(const DepthGuard&) = delete;
      DepthGuard& operator=(const DepthGuard&) = delete;
      int& depth_;
    };
    DepthGuard guard(recur_depth_);
    return this->VisitExpr(x);
  }

  // Pattern constant 0 whose type follows the type of `pattern`.
  template <typename TA>
  PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
    return PConstWithTypeLike<TA>(pattern.derived(), 0);
  }

  // Pattern constant 1 whose type follows the type of `pattern`.
  template <typename TA>
  PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
    return PConstWithTypeLike<TA>(pattern.derived(), 1);
  }
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_REWRITE_SIMPLIFY_H_