| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file canonical_simplify.cc |
| * \brief Canonical form based simplification. |
| */ |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/op.h> |
| |
| #include "const_fold.h" |
| #include "pattern_match.h" |
| #include "rewrite_simplify.h" |
| |
| namespace tvm { |
| namespace arith { |
| |
| using namespace tir; |
| |
| class SumExpr; |
| class SplitExpr; |
| |
| /*! |
| * \brief Base class of all temporary expression introduced |
| * for canonicalization. |
| */ |
| class CanonicalExprNode : public PrimExprNode { |
| public: |
| virtual ~CanonicalExprNode() {} |
| /*! |
| * \brief Return the normal Expr that is equivalent to self. |
| * \note Can mutate the internal data structure. |
| * \return The normal expression. |
| */ |
| virtual PrimExpr Normalize() const = 0; |
| |
| // overrides |
| void VisitAttrs(tvm::AttrVisitor* v) {} |
| |
| static constexpr const char* _type_key = "arith.CanonicalExpr"; |
| static constexpr const uint32_t _type_child_slots = 2; |
| TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); |
| }; |
| |
| inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { |
| if (mode == kTruncDiv) { |
| return truncmod(a, b); |
| } else { |
| CHECK_EQ(mode, kFloorDiv); |
| return floormod(a, b); |
| } |
| } |
| |
| inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { |
| if (mode == kTruncDiv) { |
| return truncdiv(a, b); |
| } else { |
| CHECK_EQ(mode, kFloorDiv); |
| return floordiv(a, b); |
| } |
| } |
| |
| /*! |
| * \brief Internal "Split normal form" of expression. |
| * |
| * This is a special expression that represents |
| * a scaled value derived from a split of an index. |
| * |
| * result = ((index % upper_factor) / lower_factor) * scale |
| */ |
| class SplitExprNode : public CanonicalExprNode { |
| public: |
| /*! \brief The base index expression. */ |
| PrimExpr index; |
| /*! \brief The division factor ratio. */ |
| int64_t lower_factor{1}; |
| /*! |
| * \brief The upper factor. |
| * invariance: (upper_factor == kPosInf || upper_factor % lower_factor == 0) |
| */ |
| int64_t upper_factor{kPosInf}; |
| /*! \brief scale to the expression. */ |
| int64_t scale{1}; |
| /*! \brief Division mode. */ |
| DivMode div_mode{kTruncDiv}; |
| |
| /*! \brief verify that this is a valid entry. */ |
| void Verify() const { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } |
| |
| PrimExpr NormalizeWithScale(int64_t sscale) const { |
| PrimExpr res = this->index; |
| DataType dtype = this->dtype; |
| if (this->scale == 0) { |
| return make_const(dtype, 0); |
| } |
| if (this->upper_factor != SplitExprNode::kPosInf) { |
| res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode); |
| } |
| if (this->lower_factor != 1) { |
| res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode); |
| } |
| sscale *= this->scale; |
| if (sscale != 1) { |
| CHECK(!dtype.is_uint() || sscale > 0); |
| res = res * make_const(dtype, sscale); |
| } |
| return res; |
| } |
| |
| PrimExpr Normalize() const final { return NormalizeWithScale(1); } |
| |
| void MulToSelf(int64_t scale) { this->scale *= scale; } |
| |
| inline bool IndexEqual(const SplitExpr& other) const; |
| inline bool DivModeCompatibleTo(DivMode mode) const; |
| |
| /*! \brief positive infty */ |
| static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; |
| static constexpr const char* _type_key = "arith.SplitExpr"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); |
| }; |
| |
/*!
 * \brief Reference class for SplitExprNode.
 *
 *  The members are generated by the two macros below; the rest of this
 *  file accesses the node through operator-> and CopyOnWrite().
 */
class SplitExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};
| |
| inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { |
| if (index.same_as(other->index)) return true; |
| return tir::ExprDeepEqual()(index, other->index); |
| } |
| |
| inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { |
| if (this->div_mode == mode) return true; |
| if (lower_factor == 1 && upper_factor == kPosInf) return true; |
| return false; |
| } |
| |
| /*! |
| * \brief Normal form that represents sum of expressions. |
| * |
| * result = sum(args) + base. |
| */ |
| class SumExprNode : public CanonicalExprNode { |
| public: |
| /*! |
| * \brief arguments to be summed up. |
| * |
| * args are divided into segments with the same index. |
| * within each segment, the SplitExpr is ordered in descending order of lower_factor. |
| */ |
| std::vector<SplitExpr> args; |
| /*! \brief Base value in the summation. */ |
| int64_t base{0}; |
| /*! \brief The expression equals zero. */ |
| bool IsZero() const { return base == 0 && args.size() == 0; } |
| /*! |
| * \brief Return the normal Expr that is equivalent to self. |
| * \return The normal expression. |
| */ |
| PrimExpr Normalize() const final { |
| // quick path 1. |
| if (this->args.size() == 0) { |
| return make_const(this->dtype, this->base); |
| } |
| return Normalize_(this->dtype, SimplifySplitExprs(args), base); |
| } |
| /*! |
| * \brief Whether self is divisible by scale. |
| * \param scale The scale to be applied. |
| */ |
| bool DivisibleBy(int64_t scale) { |
| if (base % scale != 0) return false; |
| for (size_t i = 0; i < this->args.size(); ++i) { |
| if (args[i]->scale % scale != 0) return false; |
| } |
| return true; |
| } |
| /*! |
| * \brief mul scale to self. |
| * \param scale The scale to be applied. |
| */ |
| void MulToSelf(int64_t scale) { |
| this->base *= scale; |
| for (size_t i = 0; i < this->args.size(); ++i) { |
| args[i].CopyOnWrite()->scale *= scale; |
| } |
| } |
| /*! |
| * \brief divide by scale. |
| * \param scale The scale to be applied. |
| */ |
| void DivideBy(int64_t scale) { |
| CHECK_EQ(this->base % scale, 0); |
| this->base /= scale; |
| for (size_t i = 0; i < this->args.size(); ++i) { |
| CHECK_EQ(args[i]->scale % scale, 0); |
| args[i].CopyOnWrite()->scale /= scale; |
| } |
| } |
| /*! |
| * \brief add constant value to self. |
| * \param value to be added. |
| */ |
| void AddToSelf(int64_t value) { this->base += value; } |
| /*! |
| * \brief self += other * scale; |
| * \param other The expression to be added. |
| * \param scale The additional scale on value. |
| */ |
| void AddToSelf(SplitExpr other, int64_t scale) { |
| if (other->scale == 0) return; |
| // We need to maintain the segment invariance: |
| // Same index are stored close to each other. |
| // sorted from big lower_factor to small one. |
| size_t start = 0; |
| for (; start < args.size(); ++start) { |
| if (args[start]->IndexEqual(other)) break; |
| } |
| for (size_t j = start; j < args.size(); ++j) { |
| if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) { |
| other.CopyOnWrite()->scale *= scale; |
| this->args.insert(this->args.begin() + j, other); |
| return; |
| } |
| if (other->lower_factor == args[j]->lower_factor && |
| other->upper_factor == args[j]->upper_factor && |
| other->DivModeCompatibleTo(args[j]->div_mode)) { |
| args[j].CopyOnWrite()->scale += other->scale * scale; |
| return; |
| } |
| } |
| // Insert other in the end. |
| other.CopyOnWrite()->scale *= scale; |
| this->args.emplace_back(std::move(other)); |
| } |
| |
| void AddToSelf(const SumExpr& other, int64_t scale); |
| |
| static constexpr const char* _type_key = "arith.SumExpr"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); |
| |
| private: |
| /*! |
| * \brief Simplify the args by merging SplitExprs |
| * \param args The original list of arguments. |
| * \return simplified version. |
| */ |
| static std::vector<SplitExpr> SimplifySplitExprs(std::vector<SplitExpr> args) { |
| // NOTE: This algorithm relies on the factor that args are divided into segments |
| // and each segment is sorted in descending order of lower_factor. |
| for (size_t i = 0; i < args.size(); ++i) { |
| if (args[i]->scale == 0) continue; |
| for (size_t j = i + 1; j < args.size(); ++j) { |
| SplitExpr& lhs = args[i]; |
| SplitExpr& rhs = args[j]; |
| if (!lhs->IndexEqual(rhs)) break; |
| if (lhs->upper_factor < rhs->lower_factor) break; |
| if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor && |
| lhs->DivModeCompatibleTo(rhs->div_mode)) { |
| // folding same co-efficient. |
| rhs.CopyOnWrite()->scale += lhs->scale; |
| lhs.CopyOnWrite()->scale = 0; |
| } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 && |
| lhs->scale % rhs->scale == 0 && |
| lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor && |
| lhs->DivModeCompatibleTo(rhs->div_mode)) { |
| // Rules used in the proof: |
| // |
| // Rule 1: (x % (c * s)) / c = (x / c) % s |
| // Proof: |
| // x can always be decomposed into p * c * s + q * c + r |
| // where 0 <= q * c + r < c * s and 0 <= r < c. |
| // Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q |
| // rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q |
| // Thus, lhs = rhs |
| // |
| // The above proof is for the floordiv. |
| // The same rule also holds for truncdiv(division rule in C). |
| // Because both sides only involve mul, div and mod, |
| // we can take abs of x, c and s, apply the floordiv proof, |
| // and finally add the sign back. |
| // |
| // Rule 2: (x / s) * s + x % s = x (true for both trunc and floor div) |
| // |
| // General merge condition and proof: |
| // - x = lhs->index % lhs->upper_factor |
| // - s = lhs->scale / rhs->scale |
| // - c = rhs->lower_factor |
| // |
| // (x / (c * s)) * s + (x % (c * s)) / c |
| // => ((x / c) / s) * s + ((x / c) % s) |
| // => (x / c) |
| // |
| // Examples: |
| // |
| // (z / 6) * 6 + ((z % 6) / 3) * 3 |
| // => ((z / 6) * 2 + (z % 6) / 3) * 3 |
| // => (z / 3) * 3 |
| // note: x = z, c = 3, s = 2 |
| // |
| // ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3 |
| // => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3 |
| // => ((z % 12) / 3) * 3 |
| // note: x = z % 12, c = 3, s = 2 |
| // note also the invariance lhs->upper_factor % lhs->lower_factor == 0 |
| // |
| SplitExprNode* merged = rhs.CopyOnWrite(); |
| merged->upper_factor = lhs->upper_factor; |
| // reset args[i] to be zero. |
| lhs.CopyOnWrite()->scale = 0; |
| break; |
| } |
| } |
| } |
| // sort by the entry |
| // Here we simply sort by descending order of scales. |
| // For now, we do not compare by index because that comparison |
| // can be runtime dependent and create inderminism. |
| // we do not sort by index for now because it can be costly |
| // to deep compare Exprs, and address of Vars can be runtime dependent. |
| // |
| auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) { |
| // order by scale first |
| if (lhs->scale > rhs->scale) return true; |
| if (lhs->scale < rhs->scale) return false; |
| // then order by factor |
| if (lhs->lower_factor > rhs->lower_factor) return true; |
| if (lhs->lower_factor < rhs->lower_factor) return false; |
| // then order by upper factor |
| if (lhs->upper_factor > rhs->upper_factor) return true; |
| if (lhs->upper_factor < rhs->upper_factor) return false; |
| // then order by div mode |
| if (lhs->div_mode > rhs->div_mode) return true; |
| if (lhs->div_mode < rhs->div_mode) return false; |
| // tie. |
| // TODO(tvm-team) We might consider index as the last comparison point, |
| // after we make deep comparator more derministic. |
| // Specifically, we can consider comparing names of vars and break ties with address. |
| return false; |
| }; |
| std::stable_sort(args.begin(), args.end(), fcompare); |
| return args; |
| } |
| static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) { |
| // Positive scales first |
| PrimExpr res = make_const(dtype, 0); |
| for (size_t i = 0; i < args.size(); ++i) { |
| if (args[i]->scale > 0) { |
| res = res + args[i]->Normalize(); |
| } |
| } |
| if (base > 0) { |
| res = res + make_const(dtype, base); |
| } |
| // negative scales follows using sub. |
| for (size_t i = 0; i < args.size(); ++i) { |
| if (args[i]->scale < 0) { |
| res = res - args[i]->NormalizeWithScale(-1); |
| } |
| } |
| if (base < 0) { |
| res = res - make_const(dtype, -base); |
| } |
| return res; |
| } |
| }; |
| |
/*!
 * \brief Reference class for SumExprNode.
 *
 *  The members are generated by the two macros below; the rest of this
 *  file accesses the node through operator-> and CopyOnWrite().
 */
class SumExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};
| |
| void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { |
| // NOTE: it is rare to have a balanced long expression, |
| // linear scan is fine for our case. |
| for (size_t i = 0; i < other->args.size(); ++i) { |
| this->AddToSelf(other->args[i], scale); |
| } |
| this->AddToSelf(other->base * scale); |
| } |
| |
| // Sub-class RewriteSimplifier::Impl to take benefit of |
| // rewriter for condition simplification etc. |
| class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { |
| public: |
| using Rewriter = RewriteSimplifier::Impl; |
| |
| explicit Impl(Analyzer* parent) : Rewriter(parent) {} |
| |
| PrimExpr CanonicalSimplify(PrimExpr expr) { |
| expr = operator()(expr); |
| return expr; |
| } |
| |
| // override the original mutate function. |
| PrimExpr VisitExpr(const PrimExpr& input_expr) final { |
| auto expr = Rewriter::VisitExpr(input_expr); |
| return Normalize(expr); |
| } |
| |
| // Normal mutation without normalization. |
| PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } |
| |
| using Rewriter::VisitExpr_; |
| PrimExpr VisitExpr_(const AddNode* op) final; |
| PrimExpr VisitExpr_(const SubNode* op) final; |
| PrimExpr VisitExpr_(const MulNode* op) final; |
| PrimExpr VisitExpr_(const DivNode* op) final; |
| PrimExpr VisitExpr_(const ModNode* op) final; |
| PrimExpr VisitExpr_(const FloorDivNode* op) final; |
| PrimExpr VisitExpr_(const FloorModNode* op) final; |
| PrimExpr VisitExpr_(const ReduceNode* op) final; |
| |
| private: |
| /*! |
| * \brief compute lhs / cval |
| * \param lhs The left operand. |
| * \param cval The constant value. |
| * \param div_mode The division mode. |
| * \return The result expression; |
| */ |
| SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode); |
| /*! |
| * \brief compute lhs % cval |
| * \param lhs The left operand. |
| * \param cval The constant value. |
| * \param div_mode The division mode. |
| * \return The result expression; |
| */ |
| SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode); |
| /*! |
| * \brief Separate psum into divisible and non-divisible parts. |
| * \param psum The sum expression. |
| * \param coeff The co-efficient. |
| * \param out_divisible The result divisible component. |
| * \param out_non_divisible The non-divisible component. |
| */ |
| void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, |
| SumExpr* out_non_divisible); |
| /*! |
| * \brief Normalize expr to normal expr. |
| * \param expr The input expression. |
| * \return Normalized expr. |
| */ |
| PrimExpr Normalize(PrimExpr expr) { |
| if (const auto* op = expr.as<CanonicalExprNode>()) { |
| return op->Normalize(); |
| } else { |
| return expr; |
| } |
| } |
| /*! |
| * \brief Create a SplitExpr from expr. |
| * \param expr The input expr. |
| * \return The transformed SplitExpr. |
| */ |
| SplitExpr ToSplitExpr(PrimExpr expr) { |
| if (const auto* op = expr.as<SplitExprNode>()) { |
| return GetRef<SplitExpr>(op); |
| } |
| if (const auto* op = expr.as<SumExprNode>()) { |
| if (op->base == 0 && op->args.size() == 1) return op->args[0]; |
| } |
| if (const auto* op = expr.as<CanonicalExprNode>()) { |
| expr = op->Normalize(); |
| } |
| ObjectPtr<SplitExprNode> n = make_object<SplitExprNode>(); |
| n->dtype = expr.dtype(); |
| n->index = std::move(expr); |
| n->div_mode = kTruncDiv; |
| return SplitExpr(n); |
| } |
| /*! |
| * \brief Convert expr to an equivalent SplitExpr |
| * that has the specified div_mode. |
| * |
| * This function will return the same expr if its |
| * div_mode already satisfies the need. |
| * |
| * \param expr The input expr. |
| * \param div_mode The new div_mode. |
| * \return The transformed SplitExpr. |
| */ |
| SplitExpr ConvertDivMode(SplitExpr expr, DivMode div_mode) { |
| if (expr->div_mode == div_mode) return expr; |
| if (expr->DivModeCompatibleTo(div_mode)) { |
| expr.CopyOnWrite()->div_mode = div_mode; |
| return expr; |
| } |
| expr = ToSplitExpr(Normalize(expr)); |
| CHECK(expr->DivModeCompatibleTo(div_mode)); |
| expr.CopyOnWrite()->div_mode = div_mode; |
| return expr; |
| } |
| /*! |
| * \brief Create a SumExpr from expr. |
| * \param expr The input expr. |
| * \return The transformed SumExpr. |
| */ |
| SumExpr ToSumExpr(PrimExpr expr) { |
| if (const auto* op = expr.as<SumExprNode>()) { |
| return GetRef<SumExpr>(op); |
| } |
| ObjectPtr<SumExprNode> n = make_object<SumExprNode>(); |
| n->dtype = expr.dtype(); |
| if (const auto* op = expr.as<IntImmNode>()) { |
| n->base = op->value; |
| return SumExpr(n); |
| } else { |
| n->args.emplace_back(ToSplitExpr(expr)); |
| return SumExpr(n); |
| } |
| } |
| // Simplify the combiner used in reduce. |
| PrimExpr SimplifyReduceCombiner(const ReduceNode* op); |
| }; |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| // normalize |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<Add>(a, b); |
| if (const_res.defined()) return const_res; |
| |
| // canonical form simplification. |
| SumExpr ret = ToSumExpr(std::move(a)); |
| |
| if (const auto* op = b.as<IntImmNode>()) { |
| ret.CopyOnWrite()->AddToSelf(op->value); |
| } else if (const auto* op = b.as<SumExprNode>()) { |
| ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1); |
| } else { |
| ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1); |
| } |
| return std::move(ret); |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| // normalize |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<Sub>(a, b); |
| if (const_res.defined()) return const_res; |
| |
| // canonical form simplification. |
| SumExpr ret = ToSumExpr(std::move(a)); |
| |
| if (const auto* op = b.as<IntImmNode>()) { |
| ret.CopyOnWrite()->AddToSelf(-op->value); |
| } else if (const auto* op = b.as<SumExprNode>()) { |
| ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1); |
| } else { |
| ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1); |
| } |
| return std::move(ret); |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| // normalize |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<Mul>(a, b); |
| if (const_res.defined()) return const_res; |
| |
| // x * c |
| if (a.as<IntImmNode>()) { |
| std::swap(a, b); |
| } |
| if (const auto* bconst = b.as<IntImmNode>()) { |
| if (a.as<SumExprNode>()) { |
| SumExpr ret = Downcast<SumExpr>(std::move(a)); |
| ret.CopyOnWrite()->MulToSelf(bconst->value); |
| return std::move(ret); |
| } else { |
| SplitExpr ret = ToSplitExpr(std::move(a)); |
| ret.CopyOnWrite()->MulToSelf(bconst->value); |
| return std::move(ret); |
| } |
| } |
| |
| // normal path. |
| a = Normalize(a); |
| b = Normalize(b); |
| if (op->a.same_as(a) && op->b.same_as(b)) { |
| return GetRef<PrimExpr>(op); |
| } else { |
| return Mul(a, b); |
| } |
| } |
| |
| void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, |
| SumExpr* out_divisible, |
| SumExpr* out_non_divisible) { |
| auto divisible = make_object<SumExprNode>(); |
| auto non_divisible = make_object<SumExprNode>(); |
| divisible->dtype = psum->dtype; |
| non_divisible->dtype = psum->dtype; |
| |
| if (psum->base % coeff == 0) { |
| divisible->base = psum->base; |
| } else { |
| non_divisible->base = psum->base; |
| } |
| for (const auto& e : psum->args) { |
| if (e->scale % coeff == 0) { |
| divisible->args.push_back(e); |
| } else { |
| non_divisible->args.push_back(e); |
| } |
| } |
| *out_divisible = SumExpr(divisible); |
| *out_non_divisible = SumExpr(non_divisible); |
| } |
| |
| SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { |
| CHECK_GT(cval, 0); |
| lhs = ConvertDivMode(lhs, div_mode); |
| |
| // the following rule works for both floordiv and truncdiv |
| if (lhs->scale % cval == 0) { |
| lhs.CopyOnWrite()->scale /= cval; |
| return lhs; |
| } |
| |
| if (cval % lhs->scale == 0) { |
| int64_t scaled_cval = cval / lhs->scale; |
| if (lhs->upper_factor == SplitExprNode::kPosInf || |
| lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) { |
| // directly fold division. |
| lhs.CopyOnWrite()->scale = 1; |
| lhs.CopyOnWrite()->lower_factor *= scaled_cval; |
| lhs->Verify(); |
| return lhs; |
| } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { |
| // (x % c1) / c2 => 0 when c2 >= c1 |
| return ToSplitExpr(make_zero(lhs.dtype())); |
| } else { |
| // move the upper_factor modular into index. |
| lhs.CopyOnWrite()->index = |
| ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode); |
| lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; |
| lhs.CopyOnWrite()->scale = 1; |
| lhs.CopyOnWrite()->lower_factor *= scaled_cval; |
| lhs->Verify(); |
| return lhs; |
| } |
| } |
| // directly return the split with cval == 1 |
| lhs = ToSplitExpr(Normalize(lhs)); |
| CHECK(lhs->DivModeCompatibleTo(div_mode)); |
| CHECK_EQ(lhs->scale, 1); |
| lhs.CopyOnWrite()->lower_factor *= cval; |
| lhs.CopyOnWrite()->div_mode = div_mode; |
| return lhs; |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<Div>(a, b); |
| if (const_res.defined()) return const_res; |
| PVar<IntImm> c1; |
| // x / c1 |
| if (c1.Match(b) && c1.Eval()->value > 0) { |
| int64_t cval = c1.Eval()->value; |
| if (cval == 1) return a; |
| |
| if (const auto* psum = a.as<SumExprNode>()) { |
| SumExpr lhs, extra; |
| SeparateDivisibleParts(psum, cval, &lhs, &extra); |
| // can be divided by cval |
| if (extra->IsZero()) { |
| lhs.CopyOnWrite()->DivideBy(cval); |
| return std::move(lhs); |
| } |
| // both lhs and extra are non-negative |
| if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && |
| analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { |
| lhs.CopyOnWrite()->DivideBy(cval); |
| PrimExpr temp = Normalize(extra); |
| if (const auto* pconst = temp.as<IntImmNode>()) { |
| lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); |
| } else { |
| // if 0 <= extra < cval, it means the extra can be eliminated. |
| if (TryCompare(temp, cval) != kLT) { |
| lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); |
| } |
| } |
| return std::move(lhs); |
| } |
| } else { |
| // if a >= 0 && a < cval, then result == 0 |
| auto cbound = analyzer_->const_int_bound(Normalize(a)); |
| if (cbound->min_value >= 0 && cbound->max_value < cval) { |
| return make_zero(a.dtype()); |
| } |
| } |
| return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); |
| } |
| // normal path |
| a = Normalize(a); |
| b = Normalize(b); |
| if (op->a.same_as(a) && op->b.same_as(b)) { |
| return GetRef<PrimExpr>(op); |
| } else { |
| return Div(a, b); |
| } |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<FloorDiv>(a, b); |
| if (const_res.defined()) return const_res; |
| PVar<IntImm> c1; |
| // x / c1 |
| if (c1.Match(b) && c1.Eval()->value > 0) { |
| int64_t cval = c1.Eval()->value; |
| if (cval == 1) return a; |
| |
| if (const auto* psum = a.as<SumExprNode>()) { |
| SumExpr lhs, extra; |
| SeparateDivisibleParts(psum, cval, &lhs, &extra); |
| if (extra->IsZero()) { |
| lhs.CopyOnWrite()->DivideBy(cval); |
| return std::move(lhs); |
| } |
| // continue simplification. |
| lhs.CopyOnWrite()->DivideBy(cval); |
| PrimExpr temp = Normalize(extra); |
| if (const auto* pconst = temp.as<IntImmNode>()) { |
| lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); |
| } else { |
| // if 0 <= extra < cval, it means the extra can be eliminated. |
| if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) { |
| lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); |
| } |
| } |
| return std::move(lhs); |
| } else { |
| // if a >= 0 && a < cval, then result == 0 |
| auto cbound = analyzer_->const_int_bound(Normalize(a)); |
| if (cbound->min_value >= 0 && cbound->max_value < cval) { |
| return make_zero(a.dtype()); |
| } |
| } |
| return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv); |
| } |
| // normal path |
| a = Normalize(a); |
| b = Normalize(b); |
| if (op->a.same_as(a) && op->b.same_as(b)) { |
| return GetRef<PrimExpr>(op); |
| } else { |
| return FloorDiv(a, b); |
| } |
| } |
| |
| SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { |
| CHECK_GT(cval, 0); |
| lhs = ConvertDivMode(lhs, div_mode); |
| |
| if (lhs->scale % cval == 0) { |
| lhs.CopyOnWrite()->scale = 0; |
| return lhs; |
| } |
| if (cval % lhs->scale == 0) { |
| // (x * c1) % (c2 * c1) => (x % c2) * c1 |
| int64_t scaled_cval = cval / lhs->scale; |
| // (x / c1) % c2 => (x % (c1 * c2)) / c2 |
| int64_t new_upper_factor = lhs->lower_factor * scaled_cval; |
| // try to see if we can reduce the existing upper modular. |
| if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) { |
| // we gained a new upper factor that is smaller |
| // than the original one |
| // Perhaps there are more chances in simplifying the index |
| // Do a recursive call to simplify the mod with the new factor. |
| if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { |
| auto updated = ToSplitExpr(this->VisitExpr( |
| ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); |
| updated.CopyOnWrite()->scale = lhs->scale; |
| // re-apply the lower_factor |
| if (lhs->lower_factor != 1) { |
| return SplitDivConst(updated, lhs->lower_factor, div_mode); |
| } else { |
| return updated; |
| } |
| } else { |
| lhs.CopyOnWrite()->upper_factor = new_upper_factor; |
| return lhs; |
| } |
| } else if (new_upper_factor % lhs->upper_factor == 0) { |
| // (x % 2) % 4 => x % 2 |
| return lhs; |
| } |
| } |
| // Normalize the value. |
| lhs = ToSplitExpr(Normalize(lhs)); |
| CHECK(lhs->DivModeCompatibleTo(div_mode)); |
| CHECK_EQ(lhs->scale, 1); |
| CHECK_EQ(lhs->lower_factor, 1); |
| lhs.CopyOnWrite()->div_mode = div_mode; |
| lhs.CopyOnWrite()->upper_factor = cval; |
| return lhs; |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| // normalize |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<Mod>(a, b); |
| if (const_res.defined()) return const_res; |
| |
| PVar<IntImm> c1; |
| // x % c1 |
| if (c1.Match(b) && c1.Eval()->value > 0) { |
| int64_t cval = c1.Eval()->value; |
| if (const auto* psum = a.as<SumExprNode>()) { |
| SumExpr lhs, extra; |
| SeparateDivisibleParts(psum, cval, &lhs, &extra); |
| if (extra->IsZero()) { |
| return make_zero(a.dtype()); |
| } |
| // both lhs and extra are non-negative |
| if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && |
| analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { |
| PrimExpr temp = Normalize(extra); |
| if (temp.as<IntImmNode>()) { |
| return truncmod(temp, c1.Eval()); |
| } else { |
| // If temp < cval && temp >=0 then can remove the mod. |
| if (TryCompare(temp, cval) == kLT) { |
| return temp; |
| } else { |
| // contonue to use logic below. |
| a = extra; |
| psum = a.as<SumExprNode>(); |
| CHECK(psum != nullptr); |
| } |
| } |
| } |
| // Simplify the offset constant if necessary. |
| // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 |
| auto cbound = analyzer_->const_int_bound(Normalize(a)); |
| int64_t new_base = psum->base % cval; |
| if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { |
| SumExpr sum_expr = Downcast<SumExpr>(a); |
| sum_expr.CopyOnWrite()->base = new_base; |
| return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); |
| } |
| } else { |
| // if a >= 0 && a < cval, then result == 0 |
| auto cbound = analyzer_->const_int_bound(Normalize(a)); |
| if (cbound->min_value >= 0 && cbound->max_value < cval) { |
| return a; |
| } |
| } |
| return SplitModConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); |
| } |
| // normal path |
| a = Normalize(a); |
| b = Normalize(b); |
| if (op->a.same_as(a) && op->b.same_as(b)) { |
| return GetRef<PrimExpr>(op); |
| } else { |
| return Mod(a, b); |
| } |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { |
| if (!IsIndexType(op->dtype)) { |
| return Rewriter::VisitExpr_(op); |
| } |
| // normalize |
| PrimExpr a = this->CanonicalMutate(op->a); |
| PrimExpr b = this->CanonicalMutate(op->b); |
| |
| // const folding |
| PrimExpr const_res = TryConstFold<FloorMod>(a, b); |
| if (const_res.defined()) return const_res; |
| |
| PVar<IntImm> c1; |
| // x % c1 |
| if (c1.Match(b) && c1.Eval()->value > 0) { |
| int64_t cval = c1.Eval()->value; |
| if (const auto* psum = a.as<SumExprNode>()) { |
| SumExpr lhs, extra; |
| SeparateDivisibleParts(psum, cval, &lhs, &extra); |
| PrimExpr temp = Normalize(extra); |
| if (temp.as<IntImmNode>()) { |
| return floormod(temp, c1.Eval()); |
| } else { |
| // If temp < cval && temp >=0 then can remove the mod. |
| if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) { |
| return temp; |
| } else { |
| // contonue to use logic below. |
| a = extra; |
| psum = a.as<SumExprNode>(); |
| CHECK(psum != nullptr); |
| } |
| } |
| // Simplify the offset constant if necessary. |
| // floormod(x - 5, 3) => floormod(x + 1, 3) |
| int64_t new_base = floormod(psum->base, cval); |
| SumExpr sum_expr = Downcast<SumExpr>(std::move(a)); |
| sum_expr.CopyOnWrite()->base = new_base; |
| return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); |
| } else { |
| // if a >= 0 && a < cval, then result == a |
| auto cbound = analyzer_->const_int_bound(Normalize(a)); |
| if (cbound->min_value >= 0 && cbound->max_value < cval) { |
| return a; |
| } |
| } |
| return SplitModConst(ToSplitExpr(std::move(a)), cval, kFloorDiv); |
| } |
| // normal path |
| a = Normalize(a); |
| b = Normalize(b); |
| if (op->a.same_as(a) && op->b.same_as(b)) { |
| return GetRef<PrimExpr>(op); |
| } else { |
| return FloorMod(a, b); |
| } |
| } |
| |
| // Simplify reduce expression. |
| PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { |
| // First simplify the results |
| Array<PrimExpr> simplified_result; |
| for (const auto& res : op->combiner->result) { |
| PrimExpr new_res = this->VisitExpr(res); |
| simplified_result.push_back(new_res); |
| } |
| |
| // Which components to keep |
| std::vector<int> used(op->combiner->result.size(), false); |
| |
| // This function recursively marks the used components starting from |
| // the index idx |
| std::function<void(int)> mark_used; |
| mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { |
| // if the idx-th component was marked as used before, do nothing |
| if (used[idx]) return; |
| used[idx] = true; |
| |
| // check if the idx-th result expr uses some lhs or rhs variables |
| // and recursively mark the corresponding components |
| for (size_t i = 0; i < simplified_result.size(); ++i) |
| if (!used[i]) { |
| if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || |
| ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) |
| mark_used(i); |
| } |
| }; |
| |
| // mark all used components starting from the value_index |
| mark_used(op->value_index); |
| |
| // components which have side effects should also be preserved |
| for (size_t i = 0; i < used.size(); ++i) { |
| if (SideEffect(op->source[i]) > CallEffectKind::kReadState || |
| SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState || |
| SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState || |
| (!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) { |
| mark_used(i); |
| } |
| } |
| |
| int new_value_index = op->value_index; |
| Array<PrimExpr> new_result; |
| Array<PrimExpr> new_identity; |
| Array<Var> new_lhs; |
| Array<Var> new_rhs; |
| Array<PrimExpr> new_source; |
| Array<PrimExpr> new_init; |
| |
| // new stuff is old stuff which is used |
| for (size_t i = 0; i < used.size(); ++i) { |
| if (used[i]) { |
| // We simplify the result and identity, but not the source |
| new_result.push_back(simplified_result[i]); |
| new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i])); |
| new_lhs.push_back(op->combiner->lhs[i]); |
| new_rhs.push_back(op->combiner->rhs[i]); |
| new_source.push_back(op->source[i]); |
| if (!op->init.empty()) new_init.push_back(op->init[i]); |
| } else if (static_cast<int>(i) < op->value_index) { |
| // value_index should also be adjusted |
| new_value_index--; |
| } |
| } |
| |
| CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); |
| return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init); |
| } |
| |
| PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { |
| // Recursively call simplification when necessary. |
| PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); |
| op = ret.as<ReduceNode>(); |
| // already been simplified by const reduction axis removal |
| if (op == nullptr) return ret; |
| if (op->axis.empty()) { |
| if (!op->init.empty()) { |
| return this->VisitExpr(Select(op->condition, |
| (*op->combiner.get())(op->init, op->source)[op->value_index], |
| op->init[op->value_index])); |
| } |
| // Note that here we assume that the identity element is indeed identity. Without this |
| // assumption we would have to perform a single iteration of the loop, i.e. use |
| // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` |
| // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. |
| return this->VisitExpr(Select(op->condition, op->source[op->value_index], |
| op->combiner->identity_element[op->value_index])); |
| } |
| // combiner simplification. |
| ret = SimplifyReduceCombiner(op); |
| return ret; |
| } |
| |
| PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { |
| return impl_->CanonicalSimplify(expr); |
| } |
| |
| void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { |
| impl_->Update(var, info, override); |
| } |
| |
| CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} |
| |
| CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } |
| |
| } // namespace arith |
| } // namespace tvm |