blob: d4a3101378b0b2ccf0197e299639f49df92fcf9f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <unordered_map>
#include <unordered_set>
#include "interval_set.h"
namespace tvm {
namespace arith {
using namespace tir;
// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder : public ExprVisitor {
public:
explicit VariablePathFinder(PrimExpr target) : target_(target) {}
void VisitExpr(const PrimExpr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
ExprVisitor::VisitExpr(node);
if (!found_) path_.pop_back();
}
std::vector<const Object*> path_;
private:
bool found_{false};
PrimExpr target_;
std::unordered_set<const Object*> visited_;
};
// Collect the nodes on the path from the root of expr down to target.
// An empty vector means that target does not occur in expr.
std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
  VariablePathFinder finder(target);
  finder(expr);
  return finder.path_;
}
// Relation between the expression being solved for and the deduced bound:
//   kGreater: expression >= bound (the bound becomes the interval minimum)
//   kLess:    expression <= bound (the bound becomes the interval maximum)
//   kEqual:   expression == bound (the bound is both minimum and maximum)
enum CompareOp { kGreater, kLess, kEqual };
/*!
 * \brief Visitor that deduces a bound on a target sub-expression from a
 *  comparison expression.
 *
 *  Deduce() rewrites the comparison into the form `expr_ <comp_op> result_`
 *  and then walks expr_ along the path towards the target, moving every other
 *  operand over to result_. On success, result_ and comp_op describe the bound
 *  of the target; on any failure success_ is set to false.
 *
 *  hint_map and relax_map are stored by reference, so both maps must outlive
 *  the deducer.
 */
class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
 public:
  friend class BoundDeduceInputChecker;
  friend class Converter;

  /*!
   * \param target The sub-expression whose bound is wanted.
   * \param expr The comparison expression to deduce from.
   * \param hint_map Variable domains used to determine operand signs.
   * \param relax_map Variable domains used to relax both sides of the comparison.
   */
  BoundDeducer(PrimExpr target, PrimExpr expr,
               const std::unordered_map<const VarNode*, IntSet>& hint_map,
               const std::unordered_map<const VarNode*, IntSet>& relax_map)
      : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}

  /*! \brief Run the deduction; results are left in result_, comp_op and success_. */
  void Deduce();

  // Only descend into e when it is the next node on the path to the target;
  // anything else is a failure.
  void VisitExpr(const PrimExpr& e) final {
    if (!success_) return;
    if (iter_ < path_.size() && e.get() == path_[iter_++]) {
      ExprFunctor::VisitExpr(e);
    } else {
      success_ = false;
      return;
    }
  }

  // Node kinds without a dedicated handler can not be inverted.
  void VisitExprDefault_(const Object* op) final { success_ = false; }

  // Reaching a variable needs no further rewriting.
  void VisitExpr_(const VarNode* op) final {}

  // (a + b) cmp r  ->  a cmp r - b   (or  b cmp r - a)
  void VisitExpr_(const AddNode* op) final {
    if (!HasNextOnPath()) return;
    bool left = op->a.get() == path_[iter_];
    result_ -= left ? op->b : op->a;
    this->VisitExpr(left ? op->a : op->b);
  }

  // (a - b) cmp r  ->  a cmp r + b
  // (a - b) cmp r  ->  b rev(cmp) a - r   when the target is on the right
  void VisitExpr_(const SubNode* op) final {
    if (!HasNextOnPath()) return;
    bool left = op->a.get() == path_[iter_];
    if (left) {
      result_ += op->b;
    } else {
      result_ -= op->a;
      result_ = -result_;
      comp_op = ReverseOp(comp_op);
    }
    this->VisitExpr(left ? op->a : op->b);
  }

  // (a * b) cmp r  ->  a cmp floordiv(r, b), with the comparison reversed
  // when b is negative.
  void VisitExpr_(const MulNode* op) final {
    if (!HasNextOnPath()) return;
    bool left = op->a.get() == path_[iter_];
    PrimExpr operand = left ? op->b : op->a;

    SignType sign_operand;
    if (operand.dtype().is_uint()) {
      sign_operand = SignType::kPositive;
    } else {
      sign_operand = expr_map_[operand].GetSignType();
    }

    if (sign_operand == SignType::kNegative) {
      comp_op = ReverseOp(comp_op);
    } else if (sign_operand != SignType::kPositive) {
      // The sign of the operand is not known to be strictly positive or
      // strictly negative, so dividing by it is not a valid rewrite.
      // NOTE(review): besides kUnknown this also rejects any further SignType
      // value (e.g. a zero sign, if the enum defines one) - confirm against
      // the IntSet header.
      success_ = false;
      return;
    }

    // always use relax bound
    bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);

    result_ = floordiv(result_, operand);  // rounding down here

    if (!divided) {
      if (comp_op == kGreater) {
        // System will round down in all the cases, so add one for result_ for kGreater
        // (x >= 3/2 --> x >= 2)
        // (x >= -3/2 --> x >= -1)
        // (x >= 3/-2 --> x >= -1)
        // (x >= -3/-2 --> x >= 2)
        result_ += 1;
      } else if (comp_op == kEqual) {
        // condition unsatisfiable as with floor div, it will change the expression
        success_ = false;
        return;
      } else {
        // System rounds down in all cases, do nothing for kLess.
        // ( x <= 3/2 --> x <= 1)
        // ( x <= -3/2 --> x <= -2)
        // ( x <= 3/-2 --> x <= -2)
        // ( x <= -3/-2 --> x <= 1)
      }
    }
    this->VisitExpr(left ? op->a : op->b);
  }

  // The deduced bound (right-hand side of the rewritten comparison).
  PrimExpr result_;
  // Relation between the target and result_.
  CompareOp comp_op{kGreater};
  // Cleared as soon as any step of the deduction fails.
  bool success_{true};

 private:
  void Init();
  void Transform();
  void Relax();
  CompareOp ReverseOp(CompareOp comp_op);

  // True when path_ still has a node at index iter_. The binary-operator
  // visitors index path_[iter_] after VisitExpr has already advanced iter_;
  // when the current node is the last one on the path there is no operand to
  // descend into, so the deduction is marked as failed instead of reading
  // past the end of path_.
  bool HasNextOnPath() {
    if (iter_ < path_.size()) return true;
    success_ = false;
    return false;
  }

  PrimExpr target_;
  PrimExpr expr_;
  const std::unordered_map<const VarNode*, IntSet>& hint_map_;
  const std::unordered_map<const VarNode*, IntSet>& relax_map_;
  // Integer set of every sub-expression of expr_, evaluated with hint_map_.
  ExprIntSetMap expr_map_;
  // Nodes from expr_ down to target_.
  std::vector<const Object*> path_;
  // Index into path_ of the next node expected by VisitExpr.
  size_t iter_{0};
  // internal analyzer
  Analyzer analyzer_;
};
class BoundDeduceInputChecker : public ExprVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
this->VisitExpr(deducer_->expr_);
return target_count == 1;
}
void VisitExpr(const PrimExpr& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
ExprVisitor::VisitExpr(e);
}
private:
BoundDeducer* deducer_;
size_t target_count{0};
};
// Validate the input and normalise the comparison.
void BoundDeducer::Init() {
  BoundDeduceInputChecker checker;
  const bool target_ok = checker.Check(this);
  if (!target_ok) {
    success_ = false;
  }
  // Transform still runs after a failed check.
  Transform();
}
// Return the comparison obtained when both sides of comp_op are negated
// (or divided by a negative value): kGreater and kLess swap, kEqual stays.
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
  switch (comp_op) {
    case kEqual:
      return kEqual;  // IntSet can not represent range for `NE
    case kGreater:
      return kLess;
    case kLess:
      return kGreater;
    default:
      LOG(FATAL) << "Not a valid compare op";
      // Explicit return so that control never flows off the end of a
      // non-void function for an out-of-range value.
      return kGreater;
  }
}
void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
if (const LTNode* op = expr_.as<LTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
comp_op = kGreater;
expr_ = op->b;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
comp_op = kLess;
expr_ = op->a;
result_ = op->b - 1;
}
} else if (const LENode* op = expr_.as<LENode>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
comp_op = kGreater;
expr_ = op->b;
result_ = op->a;
} else {
comp_op = kLess;
expr_ = op->a;
result_ = op->b;
}
} else if (const GTNode* op = expr_.as<GTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
comp_op = kLess;
expr_ = op->b;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
comp_op = kGreater;
expr_ = op->a;
result_ = op->b + 1;
}
} else if (const GENode* op = expr_.as<GENode>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
comp_op = kLess;
expr_ = op->b;
result_ = op->a;
} else {
comp_op = kGreater;
expr_ = op->a;
result_ = op->b;
}
} else if (const EQNode* op = expr_.as<EQNode>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
expr_ = op->b;
result_ = op->a;
} else {
expr_ = op->a;
result_ = op->b;
}
} else {
success_ = false;
}
}
// Entry point of the deduction: validate and normalise the input, relax both
// sides, then walk expr_ along the path to the target. Every stage may clear
// success_, in which case the remaining stages are skipped.
void BoundDeducer::Deduce() {
  Init();
  if (success_) Relax();
  if (!success_) return;

  // get the path
  path_ = GetPath(target_, expr_);
  if (path_.empty()) {
    // The target does not occur in the relaxed expression.
    success_ = false;
    return;
  }

  expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
  this->VisitExpr(expr_);
}
void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.IsEverything() || b.IsEverything()) {
success_ = false;
return;
}
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
// can not be resolved when either `i` or `j` or both are variables with
// some Range OR `i` and `j` both should be a single point in IntSet
if (comp_op == kEqual &&
(!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) {
success_ = false;
return;
}
expr_ = (comp_op == kGreater) ? a.min() : a.max();
result_ = (comp_op == kGreater) ? b.max() : b.min();
}
// Deduce the interval of v implied by the comparison e.
// Returns IntSet::Nothing() when the deduction fails. Otherwise returns an
// interval with the deduced bound on the constrained side(s) and an infinite
// bound on the other side.
IntSet DeduceBound(PrimExpr v, PrimExpr e,
                   const std::unordered_map<const VarNode*, IntSet>& hint_map,
                   const std::unordered_map<const VarNode*, IntSet>& relax_map) {
  BoundDeducer d(v, e, hint_map, relax_map);
  d.Deduce();
  if (!d.success_) return IntSet::Nothing();

  // Start from an unbounded interval and tighten the side(s) the result constrains.
  PrimExpr lower = neg_inf();
  PrimExpr upper = pos_inf();
  switch (d.comp_op) {
    case kEqual:
      lower = d.result_;
      upper = d.result_;
      break;
    case kGreater:
      lower = d.result_;
      break;
    default:
      upper = d.result_;
      break;
  }
  return IntSet::Interval(lower, upper);
}
// Overload taking Map arguments: re-keys both maps by the underlying VarNode
// pointer and forwards to the unordered_map overload.
// Returns an empty set to represent deduce failure.
IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map) {
  // Convert a Map keyed by Var into one keyed by the raw node pointer.
  auto to_node_map = [](const Map<Var, IntSet>& src) {
    std::unordered_map<const VarNode*, IntSet> dst;
    for (auto kv : src) {
      dst[kv.first.get()] = kv.second;
    }
    return dst;
  };

  const std::unordered_map<const VarNode*, IntSet> hints = to_node_map(hint_map);
  const std::unordered_map<const VarNode*, IntSet> relaxed = to_node_map(relax_map);
  return DeduceBound(v, e, hints, relaxed);
}
// Register the global function "arith.DeduceBound". The registered body
// takes (v, cond, hint_map, relax_map) and forwards them unchanged to the
// Map-based DeduceBound overload, returning its result.
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map) {
return DeduceBound(v, cond, hint_map, relax_map);
});
} // namespace arith
} // namespace tvm