| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file int_set.cc |
| * \brief The integer set functions |
| */ |
| #include <tvm/arith/int_set.h> |
| #include <tvm/arith/iter_affine_map.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/expr_functor.h> |
| |
| #include <algorithm> |
| #include <unordered_map> |
| #include <utility> |
| |
| #include "constraint_extract.h" |
| #include "interval_set.h" |
| #include "pattern_match.h" |
| |
| namespace tvm { |
| namespace arith { |
| |
| using tir::is_one; |
| using tir::is_zero; |
| using tir::make_const; |
| using tir::make_zero; |
| |
| PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); |
| PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); |
| |
| IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { |
| auto node = make_object<IntervalSetNode>(); |
| node->min_value = std::move(min_value); |
| node->max_value = std::move(max_value); |
| data_ = std::move(node); |
| } |
| |
| IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { |
| return IntervalSet(min_value, max_value); |
| } |
| |
| TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); |
| |
| IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { |
| PrimExpr max_value = min(a->max_value, b->max_value); |
| PrimExpr min_value = max(a->min_value, b->min_value); |
| if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && |
| (min_value.dtype().is_int() || min_value.dtype().is_uint()) && |
| analyzer->CanProve(max_value < min_value)) { |
| return IntervalSet::Empty(); |
| } else { |
| return IntervalSet(min_value, max_value); |
| } |
| } |
| |
| IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { |
| if (a->IsEmpty()) return b; |
| if (b->IsEmpty()) return a; |
| PrimExpr max_value = max(a->max_value, b->max_value); |
| PrimExpr min_value = min(a->min_value, b->min_value); |
| return IntervalSet(min_value, max_value); |
| } |
| |
| // type traits |
| template <typename OP> |
| struct is_logical_op { |
| static const bool value = false; |
| }; |
| |
| #define TVM_DECLARE_LOGICAL_OP(OP) \ |
| template <> \ |
| struct is_logical_op<tir::OP> { \ |
| static const bool value = true; \ |
| }; |
| |
| TVM_DECLARE_LOGICAL_OP(And); |
| TVM_DECLARE_LOGICAL_OP(Or); |
| TVM_DECLARE_LOGICAL_OP(EQ); |
| TVM_DECLARE_LOGICAL_OP(NE); |
| TVM_DECLARE_LOGICAL_OP(GE); |
| TVM_DECLARE_LOGICAL_OP(GT); |
| TVM_DECLARE_LOGICAL_OP(LE); |
| TVM_DECLARE_LOGICAL_OP(LT); |
| TVM_DECLARE_LOGICAL_OP(Not); |
| |
| /*! |
| * \brief Combine two interval set under arithmetic operations. |
| * \note this can possibly relax the set. |
| */ |
| template <typename Op> |
| inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| PrimExpr expr; |
| if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) { |
| expr = res.value(); |
| } else { |
| expr = Op(a->min_value, b->min_value); |
| } |
| return IntervalSet::SinglePoint(expr); |
| } |
| if (is_logical_op<Op>::value) { |
| return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| if (a->IsEverything()) return a; |
| if (b->IsEverything()) return b; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(a->min_value + b->min_value); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| PrimExpr min_value = |
| a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); |
| PrimExpr max_value = |
| a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(a->min_value - b->min_value); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| PrimExpr min_value = |
| a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); |
| PrimExpr max_value = |
| a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(a->min_value * b->min_value); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| if (a->IsSinglePoint()) { |
| std::swap(a, b); |
| } |
| if (b->IsSinglePoint()) { |
| if (is_zero(b->min_value)) return b; |
| if (is_one(b->min_value)) return a; |
| if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
| PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); |
| PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
| PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); |
| PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (a->HasUpperBound() && a->HasLowerBound()) { |
| using tir::Select; |
| PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
| PrimExpr e1 = a->min_value * b->min_value; |
| PrimExpr e2 = a->max_value * b->min_value; |
| return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
| } |
| } |
| DLOG(WARNING) << "Return Everything in CombineInterval Mul"; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(a->min_value / b->min_value); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| if (b->IsSinglePoint()) { |
| if (is_zero(b->min_value)) { |
| LOG(FATAL) << "Divide by zero in CombineInterval Div"; |
| } |
| if (is_one(b->min_value)) return a; |
| // no relaxation is needed in here due to set is inclusive |
| if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
| PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); |
| PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
| PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); |
| PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (a->HasUpperBound() && a->HasLowerBound()) { |
| using tir::Select; |
| PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
| PrimExpr e1 = a->min_value / b->min_value; |
| PrimExpr e2 = a->max_value / b->min_value; |
| return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
| } |
| } |
| DLOG(WARNING) << "Return Everything in CombineInterval Div"; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| |
| if (b->IsSinglePoint()) { |
| const PrimExpr& divisor = b->min_value; |
| if (is_zero(divisor)) { |
| LOG(FATAL) << "Modular by zero in CombineInterval Mod"; |
| } |
| // We need to add more bound constraints throughout the code. |
| // The logic below assumes a is non-negative, which usually |
| // is the case of our application. |
| // TODO(tqchen): add bound constraints for a. |
| if (analyzer->CanProveGreaterEqual(divisor, 0)) { |
| return IntervalSet(make_zero(divisor.dtype()), divisor - 1); |
| } else { |
| PrimExpr bound = abs(divisor) - 1; |
| return IntervalSet(-bound, bound); |
| } |
| } |
| DLOG(WARNING) << "Return Everything in CombineInterval Mod"; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| if (b->IsSinglePoint()) { |
| if (is_zero(b->min_value)) { |
| LOG(FATAL) << "Divide by zero in CombineInterval Div"; |
| } |
| if (is_one(b->min_value)) return a; |
| // no relaxation is needed in here due to set is inclusive |
| if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
| PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf(); |
| PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
| PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf(); |
| PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } else if (a->HasUpperBound() && a->HasLowerBound()) { |
| using tir::Select; |
| PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
| PrimExpr e1 = floordiv(a->min_value, b->min_value); |
| PrimExpr e2 = floordiv(a->max_value, b->min_value); |
| return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
| } |
| } |
| DLOG(WARNING) << "Return Everything in CombineInterval Div"; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| |
| if (b->IsSinglePoint()) { |
| const PrimExpr& divisor = b->min_value; |
| if (is_zero(divisor)) { |
| LOG(FATAL) << "Modular by zero in CombineInterval Mod"; |
| } |
| if (analyzer->CanProveGreaterEqual(divisor, 0)) { |
| if (divisor.as<tir::IntImmNode>()) { |
| // a mod b = a - (a / b) * b if a_max / b == a_min / b |
| auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); |
| auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); |
| // We can compare +/- inf against each other, but cannot use |
| // operator== between the symbolic limits and an integer. |
| bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); |
| if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { |
| auto tmax = a->max_value - divisor * qmin; |
| auto tmin = a->min_value - divisor * qmin; |
| return IntervalSet(tmin, tmax); |
| } |
| } |
| return IntervalSet(make_zero(divisor.dtype()), divisor - 1); |
| } else { |
| PrimExpr bound = abs(divisor) - 1; |
| return IntervalSet(-bound, bound); |
| } |
| } |
| DLOG(WARNING) << "Return Everything in CombineInterval Mod"; |
| return IntervalSet::Everything(); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value)); |
| } |
| |
| template <> |
| inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b, |
| DataType /* dtype */) { |
| if (a->IsSinglePoint() && b->IsSinglePoint()) { |
| return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); |
| } |
| if (a->IsEmpty()) return a; |
| if (b->IsEmpty()) return b; |
| return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value)); |
| } |
| |
| // internal helper function to get an interval set |
| IntervalSet ToIntervalSet(IntSet set) { |
| if (auto* node = set.as<IntervalSetNode>()) { |
| return GetRef<IntervalSet>(node); |
| } |
| DLOG(INFO) << "cannot resolve int set " << set; |
| return IntervalSet::Everything(); |
| } |
| |
| using namespace tir; |
| |
| // Simplified version of int set evaluator that operates on IntervalSet |
| // We might use better set analysis in the future to replace the intervalset. |
| class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> { |
| public: |
| IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, |
| const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr, |
| bool eval_vec = false) |
| : analyzer_(analyzer), |
| dom_map_(dom_map), |
| dom_constraints_(dom_constraints), |
| eval_vec_(eval_vec) {} |
| |
| IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } |
| // evaluate and relax the set |
| IntervalSet Eval(IntervalSet val) { |
| // avoid recursive indefinite recursive expansion. |
| if (static_cast<size_t>(recur_depth_) >= dom_map_.size()) return val; |
| ++recur_depth_; |
| IntervalSet min_set = this->Eval(val->min_value); |
| IntervalSet max_set = this->Eval(val->max_value); |
| --recur_depth_; |
| return IntervalSet(min_set->min_value, max_set->max_value); |
| } |
| |
| IntervalSet VisitExpr_(const IntImmNode* op) final { |
| return IntervalSet::SinglePoint(GetRef<PrimExpr>(op)); |
| } |
| |
| IntervalSet VisitExpr_(const VarNode* op) final { |
| Var var = GetRef<Var>(op); |
| |
| Array<IntSet> values; |
| if (dom_constraints_) { |
| for (const auto& constraint : *dom_constraints_) { |
| if (var.same_as(constraint.first)) { |
| values.push_back(constraint.second); |
| } |
| } |
| } |
| |
| auto it = dom_map_.find(var); |
| if (it != dom_map_.end()) { |
| values.push_back((*it).second); |
| } |
| |
| if (values.empty()) { |
| return IntervalSet::SinglePoint(var); |
| } |
| |
| IntSet intersection = [&]() { |
| if (values.size() == 1) { |
| return values.front(); |
| } else { |
| return Intersect(values); |
| } |
| }(); |
| |
| IntervalSet res = ToIntervalSet(intersection); |
| if (res->min_value.same_as(var) && res->max_value.same_as(var)) { |
| return res; |
| } |
| // recursively evaluate mapped result |
| // in case the domain contains variables to be relaxed. |
| return Eval(res); |
| } |
| |
| IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); } |
| |
| IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_<Sub>(op); } |
| |
| IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_<Mul>(op); } |
| |
| IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_<Div>(op); } |
| |
| IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_<Mod>(op); } |
| |
| IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_<FloorDiv>(op); } |
| |
| IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_<FloorMod>(op); } |
| |
| IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_<Min>(op); } |
| |
| IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_<Max>(op); } |
| |
| IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_<EQ>(op); } |
| |
| IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_<NE>(op); } |
| |
| IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_<LT>(op); } |
| |
| IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_<LE>(op); } |
| |
| IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_<GT>(op); } |
| |
| IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_<GE>(op); } |
| |
| IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_<And>(op); } |
| |
| IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_<Or>(op); } |
| |
| IntervalSet VisitExpr_(const RampNode* op) final { |
| ICHECK(eval_vec_); |
| IntervalSet base = Eval(op->base); |
| PVar<IntImm> stride; |
| if (stride.Match(op->stride)) { |
| DataType t = op->base.dtype(); |
| int64_t vstride = stride.Eval()->value; |
| if (vstride > 0) { |
| return Combine<Add>(analyzer_, base, |
| IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))), |
| op->dtype); |
| } else { |
| return Combine<Add>(analyzer_, base, |
| IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)), |
| op->dtype); |
| } |
| } |
| DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op); |
| return IntervalSet::Everything(); |
| } |
| |
| IntervalSet VisitExpr_(const BroadcastNode* op) final { |
| ICHECK(eval_vec_); |
| return VisitExpr(op->value); |
| } |
| |
| IntervalSet VisitExpr_(const SelectNode* op) final { |
| IntervalSet true_set = this->Eval(op->true_value); |
| IntervalSet false_set = this->Eval(op->false_value); |
| return Union(analyzer_, false_set, true_set); |
| } |
| |
| IntervalSet VisitExpr_(const CastNode* op) final { |
| IntervalSet value_set = this->Eval(op->value); |
| PrimExpr min_value = |
| value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); |
| PrimExpr max_value = |
| value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); |
| return IntervalSet(min_value, max_value); |
| } |
| |
| IntervalSet VisitExpr_(const BufferLoadNode* op) final { |
| if (!(op->dtype.is_int() || op->dtype.is_uint())) { |
| DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op->dtype |
| << " buffer"; |
| return IntervalSet::Everything(); |
| } |
| // If the indices do not contain any variables to be relaxed, return the BufferLoad itself. |
| // Otherwise return `IntervalSet::everything()` since we have no knowledge on the buffer data. |
| for (const PrimExpr& index : op->indices) { |
| if (UsesVar(index, [dom_map = &this->dom_map_](const VarNode* var) { |
| return dom_map->find(GetRef<Var>(var)) != dom_map->end(); |
| })) { |
| return IntervalSet::Everything(); |
| } |
| } |
| return IntervalSet::SinglePoint(GetRef<PrimExpr>(op)); |
| } |
| |
| IntervalSet VisitExprDefault_(const Object* op) final { |
| DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); |
| return IntervalSet::Everything(); |
| } |
| |
| private: |
| // whether set is exactly single point that equals value. |
| bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { |
| return set->min_value.same_as(value) && set->max_value.same_as(value); |
| } |
| |
| template <typename TOp, typename T> |
| inline IntervalSet VisitBinaryExpr_(const T* op) { |
| static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint"); |
| IntervalSet a = this->Eval(op->a); |
| IntervalSet b = this->Eval(op->b); |
| if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { |
| return IntervalSet::SinglePoint(GetRef<PrimExpr>(op)); |
| } |
| return Combine<TOp>(analyzer_, a, b, op->dtype); |
| } |
| |
| // recursive depth |
| int recur_depth_{0}; |
| // analyzer |
| Analyzer* analyzer_; |
| const Map<Var, IntSet>& dom_map_; |
| const std::vector<std::pair<Var, IntSet>>* dom_constraints_; |
| bool eval_vec_{false}; |
| }; |
| |
| class IntSetAnalyzer::Impl { |
| public: |
| explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} |
| |
| IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const { |
| return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); |
| } |
| |
| IntSet Eval(const PrimExpr& expr) const { |
| return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr); |
| } |
| |
| void Bind(const Var& var, const Range& range, bool allow_override) { |
| Update(var, IntSet::FromRange(range), allow_override); |
| } |
| |
| void Update(const Var& var, const IntSet& info, bool override_info); |
| void Bind(const Var& var, const PrimExpr& expr, bool override_info); |
| std::function<void()> EnterConstraint(const PrimExpr& constraint); |
| |
| private: |
| // Utility function to split a boolean condition into the domain |
| // bounds implied by that condition. |
| static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond); |
| |
| // The parent arith::Analyzer |
| Analyzer* analyzer_; |
| |
| // Map of variables to global variable bounds (e.g. loop iterator |
| // ranges) |
| Map<Var, IntSet> dom_map_; |
| |
| // List of implicit scope-dependent bounds (e.g. inside the body of |
| // an if-statement). Maintained as a list of constraints, rather |
| // than as a `Map<Var,IntSet>`, to avoid computing an Intersection |
| // until required. |
| std::vector<std::pair<Var, IntSet>> dom_constraints_; |
| }; |
| |
| IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} |
| |
| IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } |
| |
| IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) { |
| return impl_->Eval(expr, dom_map); |
| } |
| |
| IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); } |
| |
| void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) { |
| impl_->Update(var, info, allow_override); |
| } |
| |
| void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { |
| impl_->Bind(var, range, allow_override); |
| } |
| |
| void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_override) { |
| if (!can_override) { |
| auto it = dom_map_.find(var); |
| if (it != dom_map_.end()) { |
| const IntSet& old_info = (*it).second; |
| |
| ICHECK(ExprDeepEqual()(old_info.min(), info.min())) |
| << "Trying to update var \'" << var << "\'" |
| << " with a different minimum value: " |
| << "original=" << old_info.min() << ", new=" << info.min(); |
| |
| ICHECK(ExprDeepEqual()(old_info.max(), info.max())) |
| << "Trying to update var \'" << var << "\'" |
| << " with a different maximum value: " |
| << "original=" << old_info.max() << ", new=" << info.max(); |
| } |
| } |
| dom_map_.Set(var, info); |
| } |
| |
| void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) { |
| Update(var, Eval(expr), can_override); |
| } |
| |
| std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo( |
| const PrimExpr& constraint) { |
| PVar<Var> x; |
| PVar<PrimExpr> limit; |
| |
| std::vector<std::pair<Var, IntSet>> bounds; |
| for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) { |
| if ((x <= limit).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); |
| } else if ((x < limit).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); |
| } else if ((x >= limit).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); |
| } else if ((x > limit).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); |
| } else if ((x == limit).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); |
| } |
| |
| if ((limit >= x).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); |
| } else if ((limit > x).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); |
| } else if ((limit <= x).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); |
| } else if ((limit < x).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); |
| } else if ((limit == x).Match(subconstraint)) { |
| bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); |
| } |
| } |
| return bounds; |
| } |
| |
| std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { |
| return impl_->EnterConstraint(constraint); |
| } |
| |
| std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { |
| auto bounds = DetectBoundInfo(constraint); |
| |
| if (bounds.size() == 0) return nullptr; |
| |
| size_t old_size = dom_constraints_.size(); |
| dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end()); |
| size_t new_size = dom_constraints_.size(); |
| auto frecover = [old_size, new_size, this]() { |
| ICHECK_EQ(dom_constraints_.size(), new_size); |
| dom_constraints_.resize(old_size); |
| }; |
| return frecover; |
| } |
| |
| // Quickly adapt to IntSet interface |
| // TODO(tqchen): revisit IntSet interface as well. |
| Range IntSet::CoverRange(Range max_range) const { |
| IntSet temp; |
| Analyzer analyzer; |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| ICHECK(s_int != nullptr); |
| if (s_int->HasUpperBound() && s_int->HasLowerBound()) { |
| return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), |
| analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); |
| } |
| return max_range; |
| } |
| |
| PrimExpr IntSet::min() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| ICHECK(s_int); |
| return s_int->min_value; |
| } |
| |
| PrimExpr IntSet::max() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| ICHECK(s_int); |
| return s_int->max_value; |
| } |
| |
| bool IntSet::IsNothing() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| return (s_int && s_int->IsEmpty()); |
| } |
| |
| bool IntSet::IsEverything() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| return (s_int && s_int->IsEverything()); |
| } |
| |
| bool IntSet::IsSinglePoint() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| return (s_int && s_int->IsSinglePoint()); |
| } |
| |
| bool IntSet::CanProvePositive() const { |
| Analyzer analyzer; |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); |
| } |
| |
| bool IntSet::CanProveNegative() const { |
| Analyzer analyzer; |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); |
| } |
| |
| bool IntSet::CanProveNonPositive() const { |
| Analyzer analyzer; |
| if (const auto* s_int = (*this).as<IntervalSetNode>()) { |
| auto max = analyzer.Simplify(s_int->max_value); |
| return is_zero(max) || is_negative_const(max); |
| } |
| return false; |
| } |
| |
| bool IntSet::CanProveNonNegative() const { |
| Analyzer analyzer; |
| if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
| auto min = analyzer.Simplify(s_int->min_value); |
| return is_zero(min) || is_positive_const(min); |
| } |
| return false; |
| } |
| |
| bool IntSet::HasLowerBound() const { |
| if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
| return s_int->HasLowerBound(); |
| } |
| return false; |
| } |
| |
| bool IntSet::HasUpperBound() const { |
| if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
| return s_int->HasUpperBound(); |
| } |
| return false; |
| } |
| |
| SignType IntSet::GetSignType() const { |
| if (CanProvePositive()) { |
| return kPositive; |
| } else if (CanProveNegative()) { |
| return kNegative; |
| } else if (IsSinglePoint() && is_zero(PointValue())) { |
| return kZero; |
| } else { |
| return kUnknown; |
| } |
| } |
| PrimExpr IntSet::PointValue() const { |
| const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
| ICHECK(s_int && s_int->IsSinglePoint()); |
| return s_int->min_value; |
| } |
| |
| IntSet IntSet::Nothing() { return IntervalSet::Empty(); } |
| |
| IntSet IntSet::Everything() { return IntervalSet::Everything(); } |
| |
| IntSet IntSet::SinglePoint(PrimExpr x) { return IntervalSet::SinglePoint(x); } |
| |
| IntSet IntSet::Interval(PrimExpr min, PrimExpr max) { |
| if (min.same_as(max)) { |
| return IntSet::SinglePoint(min); |
| } |
| return IntervalSet(min, max); |
| } |
| |
| // Range related code |
| inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { |
| return is_zero(analyzer->Simplify(lhs - rhs)); |
| } |
| |
| IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) { |
| if (is_one(extent)) { |
| return IntSet::SinglePoint(min); |
| } |
| return IntervalSet(min, extent + min - 1); |
| } |
| |
| IntSet IntSet::FromRange(Range r) { |
| // must make sure it can be matched back by MatchRange. |
| if (is_one(r->extent)) { |
| return IntSet::SinglePoint(r->min); |
| } |
| return IntervalSet(r->min, r->extent + r->min - 1); |
| } |
| |
| bool IntSet::MatchRange(const Range& b) const { |
| const IntSet& a = *this; |
| const IntervalSetNode* a_int = a.as<IntervalSetNode>(); |
| if (!a_int) return false; |
| if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false; |
| Analyzer ana; |
| return ProveEqual(&ana, a_int->min_value, b->min) && |
| ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); |
| } |
| |
| IntSet Union(const Array<IntSet>& sets) { |
| if (sets.size() == 0) return IntSet::Nothing(); |
| if (sets.size() == 1) return sets[0]; |
| Analyzer ana; |
| IntervalSet x = ToIntervalSet(sets[0]); |
| for (size_t i = 1; i < sets.size(); ++i) { |
| x = Union(&ana, x, ToIntervalSet(sets[i])); |
| } |
| return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); |
| } |
| |
| Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets) { |
| if (nd_int_sets.empty()) { |
| return {}; |
| } |
| int n = nd_int_sets.size(); |
| int ndim = nd_int_sets[0].size(); |
| Array<IntSet> result; |
| result.reserve(ndim); |
| for (int i = 0; i < ndim; ++i) { |
| Array<IntSet> candidates; |
| candidates.reserve(n); |
| for (int j = 0; j < n; ++j) { |
| candidates.push_back(nd_int_sets[j][i]); |
| } |
| result.push_back(Union(candidates)); |
| } |
| return result; |
| } |
| |
| IntSet UnionLowerBound(const Array<IntSet>& sets) { |
| if (sets.size() == 0) return IntSet::Nothing(); |
| if (sets.size() == 1) return sets[0]; |
| Analyzer analyzer; |
| bool is_first_interval = true; |
| PrimExpr min_inclusive{nullptr}; |
| PrimExpr max_inclusive(nullptr); |
| for (const IntSet& int_set : sets) { |
| if (const auto* interval_set = int_set.as<IntervalSetNode>()) { |
| PrimExpr new_min_inclusive = interval_set->min_value; |
| PrimExpr new_max_inclusive = interval_set->max_value; |
| if (is_first_interval) { |
| is_first_interval = false; |
| min_inclusive = std::move(new_min_inclusive); |
| max_inclusive = std::move(new_max_inclusive); |
| continue; |
| } |
| bool bound_1 = is_neg_inf(new_min_inclusive) || is_pos_inf(max_inclusive) || |
| analyzer.CanProve(new_min_inclusive <= max_inclusive + 1); |
| bool bound_2 = is_neg_inf(min_inclusive) || is_pos_inf(new_max_inclusive) || |
| analyzer.CanProve(min_inclusive <= new_max_inclusive + 1); |
| if (bound_1 && bound_2) { |
| min_inclusive = min(min_inclusive, new_min_inclusive); |
| max_inclusive = max(max_inclusive, new_max_inclusive); |
| } |
| } |
| } |
| if (is_first_interval) { |
| return IntSet::Nothing(); |
| } |
| return IntSet::Interval(min_inclusive, max_inclusive); |
| } |
| |
| Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets) { |
| if (nd_int_sets.empty()) { |
| return {}; |
| } |
| int n = nd_int_sets.size(); |
| int ndim = nd_int_sets[0].size(); |
| Array<IntSet> result; |
| result.reserve(ndim); |
| for (int i = 0; i < ndim; ++i) { |
| Array<IntSet> candidates; |
| candidates.reserve(n); |
| for (int j = 0; j < n; ++j) { |
| candidates.push_back(nd_int_sets[j][i]); |
| } |
| result.push_back(UnionLowerBound(candidates)); |
| } |
| return result; |
| } |
| |
| IntSet Intersect(const Array<IntSet>& sets) { |
| if (sets.size() == 0) return IntSet::Nothing(); |
| if (sets.size() == 1) return sets[0]; |
| Analyzer ana; |
| IntervalSet x = ToIntervalSet(sets[0]); |
| for (size_t i = 1; i < sets.size(); ++i) { |
| x = Intersect(&ana, x, ToIntervalSet(sets[i])); |
| } |
| return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); |
| } |
| |
| Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) { |
| Map<Var, IntSet> dmap; |
| for (auto kv : dom_map) { |
| dmap.Set(kv.first->var, kv.second); |
| } |
| return dmap; |
| } |
| |
| Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
| Map<Var, IntSet> dmap; |
| for (auto kv : dom_map) { |
| dmap.Set(GetRef<Var>(kv.first), kv.second); |
| } |
| return dmap; |
| } |
| |
| IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) { |
| Analyzer ana; |
| return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); |
| } |
| |
| IntSet IntSet::Vector(PrimExpr x) { |
| Analyzer ana; |
| Map<Var, IntSet> dmap; |
| return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); |
| } |
| |
| IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) { |
| return EvalSet(e, ConvertDomMap(dom_map)); |
| } |
| |
| IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
| return EvalSet(e, ConvertDomMap(dom_map)); |
| } |
| |
| IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) { |
| Analyzer ana; |
| if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { |
| return EvalSet(r->min, dom_map); |
| } |
| IntervalSetEvaluator m(&ana, dom_map); |
| // Simplifying first can give tighter bounds if r->min and r->extent share variables |
| PrimExpr sum = r->min + r->extent - 1; |
| auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); |
| return std::move(res); |
| } |
| |
| IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
| return EvalSet(r, ConvertDomMap(dom_map)); |
| } |
| |
| Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map) { |
| Analyzer ana; |
| IntervalSetEvaluator m(&ana, dom_map); |
| Array<IntSet> result; |
| result.reserve(region.size()); |
| for (const Range& r : region) { |
| PrimExpr sum = r->min + (r->extent - 1); |
| result.push_back(m.Eval(IntervalSet(r->min, ana.Simplify(sum)))); |
| } |
| return result; |
| } |
| |
| IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
| Analyzer ana; |
| auto dmap = ConvertDomMap(dom_map); |
| IntervalSetEvaluator m(&ana, dmap); |
| const IntervalSetNode* s_int = s.as<IntervalSetNode>(); |
| PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; |
| PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; |
| return IntervalSet(vmin, vmax); |
| } |
| |
| class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { |
| public: |
| explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map) |
| : IntervalSetEvaluator(analyzer, dom_map) {} |
| |
| IntervalSet VisitExpr(const PrimExpr& n) final { |
| IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); |
| expr_map[n] = ret; |
| return ret; |
| } |
| |
| ExprIntSetMap expr_map; |
| }; |
| |
| ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, |
| const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
| Analyzer ana; |
| auto dmap = ConvertDomMap(dom_map); |
| SubExprIntervalSetEvaluator m(&ana, dmap); |
| m.Eval(e); |
| return m.expr_map; |
| } |
| |
| IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) { |
| return EvalSet(r, ConvertDomMap(dom_map)); |
| } |
| |
| Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) { |
| Map<Var, arith::IntSet> result; |
| for (auto kv : var_dom) { |
| const Var& var = kv.first; |
| const Range& range = kv.second; |
| result.Set(var, arith::IntSet::FromRange(range)); |
| } |
| return result; |
| } |
| |
| /*! \brief Helper function to convert IterSumExpr to the actual touched range. */ |
| static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, |
| Analyzer* analyzer) { |
| if (iter_min->args.empty()) { |
| return IntSet::FromMinExtent(iter_min->base, extent); |
| } |
| ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr"; |
| const IterSplitExpr& split = iter_min->args[0]; |
| if (!analyzer->CanProve(extent >= split->scale)) { |
| return NullOpt; |
| } |
| |
| const PrimExpr& base = iter_min->base; |
| // IterSplitExpr: (source // lower_factor) % extent * scale |
| // where `(source // lower_factor) % extent` is within [0, extent - 1] |
| if (analyzer->CanProve(split->scale < 0)) { |
| // If scale is negative, the var dom is [(extent - 1) * scale, 0] |
| // The total base is `base + (extent - 1) * scale`, |
| // while total extent is `dom_extent + (extent - 1) * (-scale)` |
| const PrimExpr& var_extent = (split->extent - 1) * split->scale; |
| return IntSet::FromMinExtent(base + var_extent, extent - var_extent); |
| } else { |
| // If scale is positive, the var dom is [0, (extent - 1) * scale] |
| // The total dom is [base, dom_extent + (extent - 1) * scale] |
| return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale); |
| } |
| } |
| |
| Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region, |
| const Map<Var, Range>& var_dom, |
| const PrimExpr& predicate, Analyzer* analyzer) { |
| int ndim = region.size(); |
| Array<IterSumExpr> iter_sum_exprs{nullptr}; |
| { |
| Array<PrimExpr> affine_indices; |
| affine_indices.reserve(ndim); |
| for (const Range& range : region) { |
| if (!is_const_number(range->extent)) { |
| // dynamic extent is not supported yet. |
| return NullOpt; |
| } |
| affine_indices.push_back(range->min); |
| } |
| auto res = DetectIterMap( |
| /*indices=*/affine_indices, /*input_iters=*/var_dom, |
| /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); |
| iter_sum_exprs = res->indices; |
| } |
| if (iter_sum_exprs.empty()) { |
| return NullOpt; |
| } |
| ICHECK_EQ(iter_sum_exprs.size(), ndim); |
| Array<IntSet> result; |
| result.reserve(ndim); |
| for (int i = 0; i < ndim; ++i) { |
| const IterSumExpr& sum_expr = iter_sum_exprs[i]; |
| const Range& range = region[i]; |
| Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer); |
| if (int_set.defined()) { |
| result.push_back(int_set.value()); |
| } else { |
| return NullOpt; |
| } |
| } |
| return result; |
| } |
| |
| Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region, |
| const Map<Var, Range>& var_dom, |
| const PrimExpr& predicate, |
| arith::Analyzer* analyzer) { |
| return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); |
| } |
| |
| Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region, const Map<Var, Range>& var_dom, |
| const PrimExpr& predicate, Analyzer* analyzer) { |
| if (Optional<Array<arith::IntSet>> result = EstimateRegionStrictBound( |
| /*region=*/region, |
| /*var_dom=*/var_dom, |
| /*predicate=*/predicate, /*analyzer=*/analyzer)) { |
| return result.value(); |
| } |
| Array<IntSet> result; |
| result.reserve(region.size()); |
| // try estimate each dimension independently |
| for (const Range& range : region) { |
| auto res = DetectIterMap( |
| /*indices=*/{range->min}, /*input_iters=*/var_dom, |
| /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); |
| if (!res->indices.empty()) { |
| ICHECK_EQ(res->indices.size(), 1U); |
| IterSumExpr sum_expr = res->indices[0]; |
| |
| // dynamic extent is not supported yet. |
| PrimExpr extent = range->extent; |
| if (!is_const_number(extent)) { |
| IntSet relaxed = EvalSet(extent, AsIntSet(var_dom)); |
| ICHECK(relaxed.HasUpperBound()); |
| extent = relaxed.max(); |
| } |
| |
| if (Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { |
| result.push_back(int_set.value()); |
| continue; |
| } |
| } |
| // fallback to coarse grained evalset |
| result.push_back(EvalSet(range, AsIntSet(var_dom))); |
| } |
| return result; |
| } |
| |
| TVM_REGISTER_NODE_TYPE(IntervalSetNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const IntervalSetNode*>(node.get()); |
| p->stream << "IntervalSet" |
| << "[" << op->min_value << ", " << op->max_value << ']'; |
| }); |
| |
// Global function registrations exposing IntSet constructors and accessors by name.

// Construct a set holding a single value.
TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint);

// Evaluate an expression via IntSet::Vector.
TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector);

// Construct an interval via IntSet::Interval.
TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval);

// Accessor bound to IntSet::min.
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min);

// Accessor bound to IntSet::max.
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max);

// Predicate bound to IntSet::IsNothing.
TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing);

// Predicate bound to IntSet::IsEverything.
TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything);
| |
// Global function registrations for the region-estimation routines. Each wrapper
// creates its own Analyzer, because the registered signature does not take one.

// Wrapper around EstimateRegionLowerBound.
TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer);
    });
// Wrapper around EstimateRegionStrictBound.
TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer);
    });
// Wrapper around EstimateRegionUpperBound. The callee returns Array<IntSet>,
// which is converted to the Optional declared as the lambda's return type.
TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer);
    });
| |
// Accessors returning the symbolic limit expressions SymbolicLimits::pos_inf_ and
// SymbolicLimits::neg_inf_.
TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; });
// Registration of the free function UnionLowerBound.
TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound);
| |
| } // namespace arith |
| } // namespace tvm |