| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/arith/const_int_bound.cc |
| */ |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr_functor.h> |
| |
| #include <algorithm> |
| #include <optional> |
| |
| #include "constraint_extract.h" |
| #include "int_operator.h" |
| #include "pattern_match.h" |
| #include "scalable_expression.h" |
| |
| namespace tvm { |
| namespace arith { |
| |
| using namespace tir; |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { ConstIntBoundNode::RegisterReflection(); } |
| |
| ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { |
| auto node = ffi::make_object<ConstIntBoundNode>(); |
| node->min_value = min_value; |
| node->max_value = max_value; |
| data_ = std::move(node); |
| } |
| |
| ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { |
| return ConstIntBound(min_value, max_value); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("arith.ConstIntBound", MakeConstIntBound); |
| } |
| |
| inline void PrintBoundValue(std::ostream& os, int64_t val) { |
| if (val == ConstIntBound::kPosInf) { |
| os << "pos_inf"; |
| } else if (val == ConstIntBound::kNegInf) { |
| os << "neg_inf"; |
| } else { |
| os << val; |
| } |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const ConstIntBoundNode*>(node.get()); |
| p->stream << "ConstIntBound["; |
| PrintBoundValue(p->stream, op->min_value); |
| p->stream << ','; |
| PrintBoundValue(p->stream, op->max_value); |
| p->stream << ']'; |
| }); |
| |
| // internal entry for const int bound |
| struct ConstIntBoundAnalyzer::Entry { |
| int64_t min_value; |
| int64_t max_value; |
| |
| bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } |
| |
| bool operator==(const Entry& other) const { |
| return min_value == other.min_value && max_value == other.max_value; |
| } |
| |
| friend std::ostream& operator<<(std::ostream& os, const Entry& entry) { |
| os << "Entry["; |
| PrintBoundValue(os, entry.min_value); |
| os << ", "; |
| PrintBoundValue(os, entry.max_value); |
| os << "]"; |
| |
| return os; |
| } |
| }; |
| |
| class ConstIntBoundAnalyzer::Impl |
| : public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> { |
| public: |
| /*! \brief additional bound info about expr in bound */ |
| struct BoundInfo { |
| /*! \brief The expr */ |
| PrimExpr expr; |
| /*! \brief The additional bound */ |
| Entry bound; |
| |
| BoundInfo() {} |
| BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} |
| }; |
| |
| bool IsBound(const Var& var) const { return var_map_.find(var) != var_map_.end(); } |
| |
| void Bind(const Var& var, const Range& range, bool allow_override) { |
| Entry a = VisitExpr(range->min); |
| Entry b = VisitExpr(range->extent); |
| Entry ret; |
| ret.min_value = a.min_value; |
| ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); |
| Update(var, ret, allow_override); |
| } |
| |
| void Update(const Var& var, const Entry& info, bool allow_override) { |
| if (!allow_override) { |
| auto it = var_map_.find(var); |
| if (it != var_map_.end()) { |
| ICHECK(it->second == info) |
| << "Trying to update var \'" << var << "\'" |
| << " with a different const bound: " |
| << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) |
| << ", new=" << ConstIntBound(info.min_value, info.max_value); |
| } |
| } |
| var_map_[var] = info; |
| } |
| |
| Entry VisitExpr_(const LetNode* op) final { |
| auto it = var_map_.find(op->var); |
| // if the var has not been binded, update the info. |
| if (it == var_map_.end()) { |
| var_map_[op->var] = this->VisitExpr(op->value); |
| Entry ret = VisitExpr(op->body); |
| var_map_.erase(op->var); |
| return ret; |
| } else { |
| return VisitExpr(op->body); |
| } |
| } |
| |
| void Update(const Var& var, const ConstIntBound& info, bool allow_override) { |
| Update(var, MakeBound(info->min_value, info->max_value), allow_override); |
| } |
| |
| // Override visitor behaviors |
| Entry VisitExprDefault_(const Object* op) final { |
| return Everything(static_cast<const PrimExprNode*>(op)->dtype); |
| } |
| |
| Entry VisitExpr(const PrimExpr& expr) final { |
| Entry res = ExprFunctor::VisitExpr(expr); |
| tir::ExprDeepEqual equal; |
| // a linear search over additional info |
| // assume we won't have a lot of conditions |
| for (const BoundInfo& info : additional_info_) { |
| if (equal(expr, info.expr)) { |
| res = Intersect(res, info.bound); |
| } |
| } |
| if (bound_) { |
| auto val = bound_->find(expr); |
| if (val != bound_->end()) { |
| auto everything = Everything(expr->dtype); |
| ICHECK( |
| (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || |
| (val->second->min_value == everything.min_value && |
| val->second->max_value == everything.max_value)) |
| << "Detected bound for " << expr << "conflicts with memorization"; |
| } |
| (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); |
| } |
| return res; |
| } |
| |
| Entry VisitExpr_(const RampNode* op) final { |
| // op = {base + i * stride | 0 <= i < lanes} |
| // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes) |
| // Note that `base + i * stride` is linear w.r.t. `i` |
| // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1) |
| Entry a = VisitExpr(op->base); |
| Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride); |
| return Union(a, b); |
| } |
| |
| Entry VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } |
| |
| Entry VisitExpr_(const CastNode* op) final { |
| Entry a; |
| |
| // int(ceil(log2(cast(n,"float64")))) is used as the |
| // implementation of topi.math.ceil_log2, and appears in iteration |
| // bounds. |
| if (auto opt = FindCeilLog2Arg(op)) { |
| a = CeilLog2Bounds(opt.value()); |
| } else { |
| a = VisitExpr(op->value); |
| } |
| |
| Entry b = Everything(op->dtype); |
| return Intersect(a, b); |
| } |
| |
| /*! |
| * \brief Process the divisor by making assumption that divide by zero |
| * won't happen in a valid program. |
| * |
| * This is important for us to get a lot of symbolic shape bound right |
| * now that the shape n >= 0, but in cases |
| * when mod or divide of n occur, the intention is actually n > 0 |
| * |
| * \param divisor The input divsor entry |
| * \return The processed entry |
| */ |
| Entry AssumeNoZeroDivisor(Entry divisor) { |
| ICHECK(!divisor.is_const(0)) << "Find divide by zero"; |
| // NOTE: here we make the assumption that |
| // divide by zero won't happen in a valid program |
| // this is important for us to get a lot of symbolic shape bound right |
| // where most conditions know that the shape n >= 0, but in cases |
| // when mod or divide of n occur, the intention is actually n > 0 |
| if (divisor.min_value == 0) { |
| divisor.min_value = 1; |
| ICHECK_GE(divisor.max_value, 1); |
| } |
| return divisor; |
| } |
| |
| Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } |
| |
| Entry VisitExpr_(const AddNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = VisitExpr(op->b); |
| Entry ret; |
| ret.min_value = InfAwareAdd(a.min_value, b.min_value); |
| ret.max_value = InfAwareAdd(a.max_value, b.max_value); |
| |
| return ret; |
| } |
| |
| Entry VisitExpr_(const SubNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = VisitExpr(op->b); |
| Entry ret; |
| ret.min_value = InfAwareAdd(a.min_value, -b.max_value); |
| ret.max_value = InfAwareAdd(a.max_value, -b.min_value); |
| |
| return ret; |
| } |
| |
| Entry VisitExpr_(const MulNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = VisitExpr(op->b); |
| return BinaryOpBoundary(a, b, InfAwareMul); |
| } |
| |
| Entry VisitExpr_(const DivNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); |
| return HandleDivision(a, b, op->dtype, InfAwareDiv); |
| } |
| |
| Entry VisitExpr_(const ModNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); |
| |
| if (b.min_value > 0) { |
| int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
| if (a.min_value >= 0) { |
| // 0 <= [a_min, a_max] < b_min |
| if (a.max_value < b.min_value) return a; |
| // other case, we can get close to 0 |
| return MakeBound(0, std::min(a.max_value, b_max_cap)); |
| } else { |
| return MakeBound(std::max(a.min_value, -b_max_cap), |
| std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); |
| } |
| } else { |
| ICHECK(!b.is_const(0)) << "mod by zero"; |
| // mod by negative value is rare, |
| // and we just use the simpliest rule. |
| return Everything(op->dtype); |
| } |
| } |
| |
| Entry VisitExpr_(const FloorDivNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); |
| return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); |
| } |
| |
| Entry VisitExpr_(const FloorModNode* op) final { |
| /* let a / b = x + y, where x is integer, y \in [0, 1) |
| * floormod(a, b) = a - floordiv(a, b) * b |
| * floordiv(a, b) = x |
| * floormod(a, b) = a - floordiv(a, b) * b |
| * = a - x * b |
| * = a - (a / b - y) * b |
| * = a - a + y * b |
| * = y * b |
| * note that 0 <= y < 1 |
| * when b > 0, 0 <= b * y < b |
| * 0 <= b * y <= b - 1 |
| * when b < 0, b < b * y <= 0 |
| * b + 1 <= b * y <= 0 |
| * In all cases, min(0, b + 1) <= b * y <= max(0, b - 1) |
| * min(0, b_min + 1) <= b * y <= max(0, b_max - 1) |
| * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1) |
| */ |
| Entry a = VisitExpr(op->a); |
| Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); |
| |
| if (b.min_value > 0) { |
| int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
| if (a.min_value >= 0) { |
| // 0 <= [a_min, a_max] < b_min |
| if (a.max_value < b.min_value) return a; |
| // other case, we can get close to 0 |
| return MakeBound(0, std::min(a.max_value, b_max_cap)); |
| } else { |
| return MakeBound(0, b_max_cap); |
| } |
| } else { |
| ICHECK(!b.is_const(0)) << "floormod by zero"; |
| int64_t b_min_cap = InfAwareAdd(b.min_value, 1); |
| int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
| return Intersect(MakeBound(std::min(static_cast<int64_t>(0), b_min_cap), |
| std::max(static_cast<int64_t>(0), b_max_cap)), |
| Everything(op->dtype)); |
| } |
| } |
| |
| Entry VisitExpr_(const MinNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = VisitExpr(op->b); |
| Entry ret; |
| ret.min_value = std::min(a.min_value, b.min_value); |
| ret.max_value = std::min(a.max_value, b.max_value); |
| return ret; |
| } |
| |
| Entry VisitExpr_(const MaxNode* op) final { |
| Entry a = VisitExpr(op->a); |
| Entry b = VisitExpr(op->b); |
| Entry ret; |
| ret.min_value = std::max(a.min_value, b.min_value); |
| ret.max_value = std::max(a.max_value, b.max_value); |
| return ret; |
| } |
| |
| Entry VisitExpr_(const SelectNode* op) final { |
| Entry a = VisitExpr(op->true_value); |
| Entry b = VisitExpr(op->false_value); |
| return Union(a, b); |
| } |
| |
| Entry VisitExpr_(const CallNode* op) final { |
| // only special handle >> and & which can be |
| // used for index calculation. |
| |
| auto curr_target = Target::Current(); |
| if (op->op.same_as(tir::builtin::shift_right())) { |
| return VisitRightShift(op); |
| } else if (op->op.same_as(tir::builtin::shift_left())) { |
| return VisitLeftShift(op); |
| } else if (op->op.same_as(tir::builtin::bitwise_and())) { |
| return VisitBitwiseAnd(op); |
| } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) { |
| auto kVScaleValues = GetVScaleValues(curr_target); |
| unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); |
| return MakeBound(1, max_val); |
| } else { |
| return Everything(op->dtype); |
| } |
| } |
| |
| Entry VisitExpr_(const VarNode* op) final { |
| Var v = ffi::GetRef<Var>(op); |
| auto it = var_map_.find(v); |
| if (it != var_map_.end()) { |
| return it->second; |
| } else { |
| return Everything(op->dtype); |
| } |
| } |
| |
| Entry VisitExpr_(const SizeVarNode* op) final { |
| SizeVar v = ffi::GetRef<SizeVar>(op); |
| auto it = var_map_.find(v); |
| if (it != var_map_.end()) { |
| return it->second; |
| } else { |
| return MakeBound(0, kPosInf); |
| } |
| } |
| |
| Entry VisitLeftShift(const CallNode* op) { |
| Entry a = VisitExpr(op->args[0]); |
| Entry b = VisitExpr(op->args[1]); |
| |
| if (a.min_value < 0 || b.min_value < 0) { |
| // If either operand can negative, we may run into undefined |
| // behavior for some targets. In these cases, avoid making any |
| // assumptions about the result. |
| return Everything(op->dtype); |
| } |
| |
| return BinaryOpBoundary(a, b, InfAwareLeftShift); |
| } |
| |
| Entry VisitRightShift(const CallNode* op) { |
| Entry a = VisitExpr(op->args[0]); |
| Entry b = VisitExpr(op->args[1]); |
| return BinaryOpBoundary(a, b, InfAwareRightShift); |
| } |
| |
| Entry VisitBitwiseAnd(const CallNode* op) { |
| Entry a = VisitExpr(op->args[0]); |
| Entry b = VisitExpr(op->args[1]); |
| // handle positive index case. |
| if (a.min_value >= 0 && b.min_value >= 0) { |
| return MakeBound(0, std::min(a.max_value, b.max_value)); |
| } else { |
| if (b.min_value >= 0) { |
| return MakeBound(0, b.max_value); |
| } |
| if (a.min_value >= 0) { |
| return MakeBound(0, a.max_value); |
| } |
| return Everything(op->dtype); |
| } |
| } |
| |
| std::function<void()> EnterConstraint(const PrimExpr& constraint) { |
| std::vector<BoundInfo> info = DetectBoundInfo(constraint); |
| if (info.size() == 0) return nullptr; |
| size_t old_size = additional_info_.size(); |
| additional_info_.insert(additional_info_.end(), info.begin(), info.end()); |
| size_t new_size = old_size + info.size(); |
| auto frecover = [old_size, new_size, this]() { |
| ICHECK_EQ(additional_info_.size(), new_size); |
| additional_info_.resize(old_size); |
| }; |
| return frecover; |
| } |
| |
| private: |
| friend class ConstIntBoundAnalyzer; |
| // internal variable map |
| std::unordered_map<Var, Entry> var_map_; |
| // additional bound info |
| std::vector<BoundInfo> additional_info_; |
| // look up table for memorization |
| BoundMapType* bound_{nullptr}; |
| // constants: the limit value means umlimited |
| // NOTE: kNegInf/kPosInf are used to represent infinity. |
| static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; |
| static const constexpr int64_t kPosInf = ConstIntBound::kPosInf; |
| static_assert(-kNegInf == kPosInf, "invariant of inf"); |
| // internal helper functions |
| /*! |
| * \brief Get boundary of binary op who are monotonic wrt to one argument. |
| * \param a The entry of the left operand. |
| * \param b The entry of the right operand. |
| * \param op The operator. |
| * \tparam F the operator function type. |
| * \return The result. |
| */ |
| template <typename F> |
| static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) { |
| Entry ret; |
| // The boundary point must be shihft of the original boundary. |
| int64_t v1 = op(a.min_value, b.min_value); |
| int64_t v2 = op(a.max_value, b.max_value); |
| int64_t v3 = op(a.min_value, b.max_value); |
| int64_t v4 = op(a.max_value, b.min_value); |
| ret.min_value = std::min(std::min(std::min(v1, v2), v3), v4); |
| ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4); |
| return ret; |
| } |
| /*! |
| * \brief Get value boundaries of division (e.g. Div or FloorDiv). |
| * \param a The entry of the left operand. |
| * \param b The entry of the right operand. |
| * \param dt The data type of the division operator. |
| * \param op The division operator. |
| * \tparam F the operator function type. |
| * \return The result. |
| */ |
| template <typename F> |
| static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) { |
| // Here we have a / b. |
| // The largest value of the division will be for the smallest (with |
| // respect to the absolute value) value of b. If the range of b starts |
| // at a negative value and ends at a positive one, narrow it down to |
| // be closer to 0, because BinaryOpBoundary only checks end-points of |
| // the domain ranges. |
| // If the range of b contains 0, then some infinity will be involved |
| if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) { |
| Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); |
| Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt); |
| |
| Entry e_neg = BinaryOpBoundary(a, b_neg, op); |
| Entry e_pos = BinaryOpBoundary(a, b_pos, op); |
| |
| return MakeBound(std::min(e_neg.min_value, e_pos.min_value), |
| std::max(e_neg.max_value, e_pos.max_value)); |
| } else if (b.min_value == 0 && dt.is_uint()) { |
| // uints only have one sided bounds |
| Entry assumed_b = MakeBound(1, b.max_value); |
| return BinaryOpBoundary(a, assumed_b, op); |
| } |
| // If the range of b does not have 0, use BinaryOpBoundary. |
| return BinaryOpBoundary(a, b, op); |
| } |
| /*! |
| * \brief Compute x + y, aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareAdd(int64_t x, int64_t y) { |
| if (x == kPosInf) { |
| ICHECK(y != kNegInf); |
| return kPosInf; |
| } |
| if (x == kNegInf) { |
| ICHECK(y != kPosInf); |
| return kNegInf; |
| } |
| if (y == kPosInf || y == kNegInf) return y; |
| if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) { |
| if (x > 0) return kPosInf; |
| return kNegInf; |
| } |
| return x + y; |
| } |
| /*! |
| * \brief Compute x * y, aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareMul(int64_t x, int64_t y) { |
| if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y; |
| if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; |
| return kNegInf; |
| } |
| /*! |
| * \brief Compute x / y, aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareDiv(int64_t x, int64_t y) { |
| ICHECK_NE(y, 0); |
| if (x == kPosInf || x == kNegInf) { |
| if (y > 0) return x; |
| return -x; |
| } |
| return x / y; |
| } |
| /*! |
| * \brief Compute floodiv(x, y), aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareFloorDiv(int64_t x, int64_t y) { |
| ICHECK_NE(y, 0); |
| if (x == kPosInf || x == kNegInf) { |
| if (y > 0) return x; |
| return -x; |
| } |
| return floordiv(x, y); |
| } |
| /*! |
| * \brief Compute x << y, aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareLeftShift(int64_t x, int64_t y) { |
| if (x == kPosInf || x == kNegInf) return x; |
| |
| // Can be replaced with std::bit_width in C++20 |
| auto bit_width = [](int64_t as_signed) { |
| uint64_t val = std::abs(as_signed); |
| int num_bits = 0; |
| while (val) { |
| ++num_bits; |
| val >>= 1; |
| } |
| return num_bits; |
| }; |
| int x_bits = bit_width(x); |
| if (x_bits + y < 64) { |
| return x << y; |
| } else { |
| return kPosInf; |
| } |
| } |
| /*! |
| * \brief Compute x >> y, aware of inf. |
| * \param x The left operand. |
| * \param y The right operand. |
| * \return the result. |
| */ |
| static int64_t InfAwareRightShift(int64_t x, int64_t y) { |
| if (x == kPosInf || x == kNegInf) return x; |
| return x >> y; |
| } |
| /*! |
| * \brief Make a new bound entry. |
| */ |
| static Entry MakeBound(int64_t min_value, int64_t max_value) { |
| Entry e; |
| e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value; |
| e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value; |
| return e; |
| } |
| /*! |
| * \brief Create union of two sets. |
| * \param a The left operand. |
| * \param b the right operand. |
| */ |
| static Entry Union(Entry a, Entry b) { |
| Entry ret; |
| ret.min_value = std::min(a.min_value, b.min_value); |
| ret.max_value = std::max(a.max_value, b.max_value); |
| return ret; |
| } |
| /*! |
| * \brief Create intersect of two sets. |
| * \param a The left operand. |
| * \param b the right operand. |
| */ |
| static Entry Intersect(Entry a, Entry b) { |
| Entry ret; |
| ret.min_value = std::max(a.min_value, b.min_value); |
| ret.max_value = std::min(a.max_value, b.max_value); |
| return ret; |
| } |
| /*! |
| * \brief Flip the sign of a set. |
| * \param entry The set of values |
| */ |
| static Entry Negative(Entry entry) { |
| Entry ret; |
| if (entry.max_value == kPosInf) { |
| ret.min_value = kNegInf; |
| } else { |
| ret.min_value = -entry.max_value; |
| } |
| if (entry.min_value == kNegInf) { |
| ret.max_value = kPosInf; |
| } else { |
| ret.max_value = -entry.min_value; |
| } |
| |
| return ret; |
| } |
| /*! |
| * \brief return everything dtype can represent. |
| * \param dtype The data type. |
| * \return Bound that represent everything dtype can represent. |
| */ |
| static Entry Everything(DataType dtype) { |
| if (!dtype.is_int() && !dtype.is_uint()) { |
| return MakeBound(kNegInf, kPosInf); |
| } |
| Entry ret; |
| int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int()); |
| if (dtype.is_uint()) { |
| ret.min_value = 0; |
| } else { |
| if (vbits >= 63) { |
| ret.min_value = kNegInf; |
| } else { |
| ret.min_value = -(static_cast<int64_t>(1) << vbits); |
| } |
| } |
| if (vbits >= 63) { |
| ret.max_value = kPosInf; |
| } else { |
| ret.max_value = (static_cast<int64_t>(1) << vbits) - 1; |
| } |
| return ret; |
| } |
| |
| /*! |
| * \brief Detect additional constant bound from cond, if any |
| * \param cond The constraint condition. |
| * \return List of detected bounds. |
| */ |
| static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) { |
| PVar<PrimExpr> x, y; |
| PVar<IntImm> c; |
| |
| std::vector<BoundInfo> info; |
| auto add_info = [&](const PrimExpr& expr, int64_t min_value, int64_t max_value) { |
| // If the conditional is comparing two integers, do not assign a |
| // value to them. |
| if (!expr->IsInstance<IntImmNode>()) { |
| info.push_back(BoundInfo(expr, MakeBound(min_value, max_value))); |
| } |
| }; |
| |
| for (const auto& subexpr : ExtractConstraints(cond)) { |
| // NOTE: The canonical form always uses <= or <, but a |
| // user-supplied constraint from the python API might not be |
| // canonicalized. |
| if ((c <= x).Match(subexpr) || (x >= c).Match(subexpr)) { |
| add_info(x.Eval(), c.Eval()->value, kPosInf); |
| } else if ((c < x).Match(subexpr) || (x > c).Match(subexpr)) { |
| add_info(x.Eval(), c.Eval()->value + 1, kPosInf); |
| } else if ((x <= c).Match(subexpr) || (x >= c).Match(subexpr)) { |
| add_info(x.Eval(), kNegInf, c.Eval()->value); |
| } else if ((x < c).Match(subexpr) || (c > x).Match(subexpr)) { |
| add_info(x.Eval(), kNegInf, c.Eval()->value - 1); |
| } else if ((x == c).Match(subexpr) || (c == x).Match(subexpr)) { |
| add_info(x.Eval(), c.Eval()->value, c.Eval()->value); |
| } |
| } |
| |
| return info; |
| } |
| |
| /*! |
| * \brief Extract the argument from int(ceil(log2(arg))) |
| * |
| * This expression is used as the implementation of |
| * topi.math.ceil_log2, and can appear in iteration bounds. |
| */ |
| static ffi::Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) { |
| if (op->dtype.is_int()) { |
| if (auto as_call = op->value.as<CallNode>()) { |
| if (as_call->op.same_as(Op::Get("tir.ceil"))) { |
| PrimExpr ceil_arg = as_call->args[0]; |
| if (auto arg_call = ceil_arg.as<CallNode>()) { |
| if (arg_call->op.same_as(Op::Get("tir.log2"))) { |
| PrimExpr log_arg = arg_call->args[0]; |
| return log_arg; |
| } |
| } |
| } |
| } |
| } |
| return std::nullopt; |
| } |
| |
| /*! \brief Propagate constraints through ceil(log2(arg)) |
| * |
| * Helper function for CastNode visitor |
| */ |
| Entry CeilLog2Bounds(PrimExpr arg) { |
| if (auto as_float = arg.as<FloatImmNode>()) { |
| // A cast from int to float may have already been simplified |
| // out. Normally we don't inspect floating-point arguments, but here we can |
| int64_t val = std::ceil(std::log2(as_float->value)); |
| return MakeBound(val, val); |
| } else { |
| Entry arg_bounds = VisitExpr(arg); |
| return MakeBound(std::ceil(std::log2(arg_bounds.min_value)), |
| std::ceil(std::log2(arg_bounds.max_value))); |
| } |
| } |
| }; |
| |
| ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { |
| Entry ret = impl_->VisitExpr(expr); |
| return ConstIntBound(ret.min_value, ret.max_value); |
| } |
| |
| ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) { |
| impl_->bound_ = bound; |
| Entry ret = impl_->VisitExpr(expr); |
| impl_->bound_ = nullptr; |
| return ConstIntBound(ret.min_value, ret.max_value); |
| } |
| |
| void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) { |
| impl_->Update(var, info, allow_override); |
| } |
| |
| void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { |
| impl_->Bind(var, range, allow_override); |
| } |
| |
| bool ConstIntBoundAnalyzer::IsBound(const Var& var) const { return impl_->IsBound(var); } |
| |
| std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { |
| return impl_->EnterConstraint(constraint); |
| } |
| |
| ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} |
| |
| ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } |
| |
| } // namespace arith |
| } // namespace tvm |