blob: 9826769a5c6542fc30aa0240788d7515ec744b3d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file modular_set.cc
* \brief Modular set analysis
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <limits>
#include <unordered_map>
#include <utility>
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace tir;
// Node-type registration for ModularSetNode (see TVM_REGISTER_NODE_TYPE for what this enables).
TVM_REGISTER_NODE_TYPE(ModularSetNode);
/*!
 * \brief Build a ModularSet reference backed by a freshly allocated ModularSetNode.
 * \param coeff Value stored in the node's coeff field.
 * \param base Value stored in the node's base field.
 */
ModularSet::ModularSet(int64_t coeff, int64_t base) {
  auto n = make_object<ModularSetNode>();
  // Populate both fields before handing ownership to the reference.
  n->base = base;
  n->coeff = coeff;
  data_ = std::move(n);
}
// Printer hook: writes a ModularSetNode as "ModularSet(coeff=<coeff>, base=<base>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ModularSetNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* mset = static_cast<const ModularSetNode*>(ref.get());
      auto& os = printer->stream;
      os << "ModularSet("
         << "coeff=" << mset->coeff;
      os << ", base=" << mset->base << ')';
    });
/*!
 * \brief Free-function factory wrapping the ModularSet constructor.
 * \param coeff Forwarded to ModularSet as the coefficient.
 * \param base Forwarded to ModularSet as the base.
 * \return The newly constructed ModularSet.
 */
ModularSet MakeModularSet(int64_t coeff, int64_t base) {
  ModularSet result(coeff, base);
  return result;
}

// Expose the factory under the global name "arith.ModularSet".
TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet);
// internal entry for modular set: values of the form coeff * k + base
/*!
 * \brief Internal (coeff, base) pair used by the modular set analysis.
 *
 *  A zero coeff marks a single constant (see is_const()); the default
 *  constructed entry is (1, 0).
 */
struct ModularSetAnalyzer::Entry {
  /*! \brief Coefficient; never negative when built through the two-argument constructor. */
  int64_t coeff{1};
  /*! \brief Base; kept in [0, coeff) whenever coeff is non-zero. */
  int64_t base{0};

  Entry() = default;

  /*!
   * \brief Construct a normalized entry.
   * \param coeff The coefficient; checked to be >= 0.
   * \param base The base; reduced modulo coeff into [0, coeff) when coeff != 0,
   *  stored unchanged when coeff == 0.
   */
  Entry(int64_t coeff, int64_t base) {
    CHECK_GE(coeff, 0);
    this->coeff = coeff;
    if (coeff == 0) {
      // Constant entry: keep the value as given.
      this->base = base;
      return;
    }
    // C++ remainder takes the sign of the dividend; shift negatives up by one period.
    const int64_t rem = base % coeff;
    this->base = rem < 0 ? rem + coeff : rem;
  }

  /*! \return Whether the entry denotes a single constant (coeff == 0). */
  bool is_const() const { return coeff == 0; }

  /*! \brief Field-wise equality with another entry. */
  bool operator==(const Entry& other) const {
    if (coeff != other.coeff) return false;
    return base == other.base;
  }

  /*! \brief Equality with a ModularSet; an undefined reference never compares equal. */
  bool operator==(const ModularSet& other) const {
    if (!other.defined()) return false;
    return coeff == other->coeff && base == other->base;
  }
};
/*!
 * \brief Expression functor that maps a PrimExpr to an Entry (coeff, base),
 *  describing values of the form coeff * k + base.
 *
 *  Expression kinds without a dedicated rule, and variables without a recorded
 *  entry, are mapped to Everything().
 */
class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
 public:
  explicit Impl(Analyzer* parent) : parent_(parent) {}

  /*!
   * \brief Record the modular set of a variable.
   * \param var The variable.
   * \param info The modular set to associate with var.
   * \param allow_override When false and var already has an entry, the new info
   *  must equal the recorded one; otherwise the CHECK below fails.
   */
  void Update(const Var& var, const ModularSet& info, bool allow_override) {
    if (!allow_override) {
      auto it = var_map_.find(var);
      if (it != var_map_.end()) {
        CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
                                  << " with a different modular set: "
                                  << "original=" << ModularSet(it->second.coeff, it->second.base)
                                  << ", new=" << info;
      }
    }
    var_map_[var] = Entry(info->coeff, info->base);
  }

  /*!
   * \brief Detect useful constraints and use them in the analysis scope.
   * \param constraint The constraint expression.
   * \return A function restoring the previous state of the constrained variable,
   *  or nullptr when the constraint does not match a recognized pattern.
   */
  std::function<void()> EnterConstraint(const PrimExpr& constraint) {
    PVar<Var> var;
    PVar<IntImm> coeff, base;
    // pattern match interesting constraints:
    //   truncmod(var, coeff) == base   or   floormod(var, coeff) == base
    if ((truncmod(var, coeff) == base).Match(constraint) ||
        (floormod(var, coeff) == base).Match(constraint)) {
      Entry entry(coeff.Eval()->value, base.Eval()->value);
      return UpdateByIntersect(var.Eval(), entry);
    }
    return nullptr;
  }

  // Override visitor behaviors

  // Any expression kind without a dedicated rule is unconstrained.
  Entry VisitExprDefault_(const Object* op) final { return Everything(); }

  Entry VisitExpr_(const LetNode* op) final {
    auto it = var_map_.find(op->var);
    // if the var has not been bound, bind it to the set of its value while
    // visiting the body, then drop the binding again.
    if (it == var_map_.end()) {
      var_map_[op->var] = this->VisitExpr(op->value);
      Entry ret = VisitExpr(op->body);
      var_map_.erase(op->var);
      return ret;
    } else {
      return VisitExpr(op->body);
    }
  }

  // A cast is analyzed as its operand.
  Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }

  // An integer immediate is a constant entry (coeff == 0).
  Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }

  Entry VisitExpr_(const AddNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    return Entry(coeff, a.base + b.base);
  }

  Entry VisitExpr_(const SubNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    return Entry(coeff, a.base - b.base);
  }

  Entry VisitExpr_(const MulNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    // Simplification rule, x, y, z are in Z
    // (p x + n) (q y + m)
    // -> pq xy + pm x + qn y + mn
    // -> pq z + pm x + qn y + mn
    int64_t pq = a.coeff * b.coeff;
    int64_t pm = a.coeff * b.base;
    int64_t qn = a.base * b.coeff;
    int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
    return Entry(coeff, a.base * b.base);
  }

  /*!
   * \brief Analyze lhs divided by a non-zero constant.
   * \param lhs The dividend expression.
   * \param val The constant divisor; checked to be non-zero.
   * \param round_down Whether the division is known to round down
   *  (passed as true for floordiv and right shift, false for DivNode).
   * \return The resulting entry, or Everything() when no rule applies.
   */
  Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) {
    Entry a = VisitExpr(lhs);
    CHECK_NE(val, 0);
    if (a.coeff % val == 0) {
      if (a.base == 0) {
        // a c x  / c -> a x
        return Entry(std::abs(a.coeff / val), 0);
      }
      // positive division have a clear rounding mode.
      // Only handle case where we clearly know we need to round down.
      if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
        return Entry(a.coeff / val, a.base / val);
      }
    }
    return Everything();
  }

  Entry VisitExpr_(const DivNode* op) final {
    Entry b = VisitExpr(op->b);
    if (b.is_const()) {
      return DivByConst(op->a, b.base, false);
    }
    return Everything();
  }

  Entry VisitExpr_(const FloorDivNode* op) final {
    Entry b = VisitExpr(op->b);
    if (b.is_const()) {
      return DivByConst(op->a, b.base, true);
    }
    return Everything();
  }

  Entry VisitExpr_(const MinNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  Entry VisitExpr_(const MaxNode* op) final {
    Entry a = VisitExpr(op->a);
    Entry b = VisitExpr(op->b);
    return Union(a, b);
  }

  // The condition is ignored; the result is either branch.
  Entry VisitExpr_(const SelectNode* op) final {
    Entry a = VisitExpr(op->true_value);
    Entry b = VisitExpr(op->false_value);
    return Union(a, b);
  }

  Entry VisitExpr_(const CallNode* op) final {
    // only special handle >> which can be
    // used for index calculation.
    if (op->op.same_as(tir::builtin::shift_right())) {
      return VisitRightShift(op);
    } else {
      return Everything();
    }
  }

  Entry VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    auto it = var_map_.find(v);
    if (it != var_map_.end()) {
      return it->second;
    } else {
      return Everything();
    }
  }

  /*!
   * \brief Analyze a shift_right call as a rounding-down division by 2^shift.
   * \param op The call node; args[0] is the shifted value, args[1] the shift amount.
   * \return The resulting entry, or Everything() when the shift amount is not a
   *  constant in [0, 62].
   */
  Entry VisitRightShift(const CallNode* op) {
    Entry b = VisitExpr(op->args[1]);
    // a c x  / c -> a x
    // Only shift amounts in [0, 62] keep 2^shift representable as a positive
    // int64_t; a negative or larger amount would make the left shift below
    // undefined / overflowing, so those cases are treated as unknown.
    if (b.is_const() && b.base >= 0 && b.base < 63) {
      return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true);
    }
    return Everything();
  }

 private:
  /*! \brief pointer to parent. */
  Analyzer* parent_{nullptr};
  // internal variable map
  std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_;

  /*!
   * \brief Update var by intersecting entry with var's current set.
   * \param var The variable.
   * \param entry The entry to be updated.
   * \return The recovery function of the scope.
   */
  std::function<void()> UpdateByIntersect(const Var& var, Entry entry) {
    Entry old = Everything();
    auto it = var_map_.find(var);
    if (it != var_map_.end()) {
      old = it->second;
    }
    var_map_[var] = Intersect(old, entry);
    // recover function: puts the pre-constraint entry back
    // (Everything() when the variable had no entry before).
    return [this, old, var]() { var_map_[var] = old; };
  }

  /*!
   * \brief Create union of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return An entry covering both operands.
   */
  static Entry Union(Entry a, Entry b) {
    // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
    int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
    if (coeff == 0) {
      // Both operands are constants: keep the constant if they agree.
      if (a.base == b.base) return a;
      return Everything();
    }
    int64_t base0 = a.base % coeff;
    int64_t base1 = b.base % coeff;
    if (base0 == base1) {
      return Entry(coeff, base0);
    } else {
      // Different residues: fall back to a coefficient dividing both residues
      // and the common coefficient.
      return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
    }
  }

  /*!
   * \brief Create intersect of two sets.
   * \param a The left operand.
   * \param b the right operand.
   * \return The intersection, or Nothing() when the two sets share no value.
   */
  static Entry Intersect(Entry a, Entry b) {
    int64_t x, y;
    int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
    // z = c1 * p + b1
    // z = c2 * q + b2
    // c1 * x + c2 * y = gcd(c1, c2)
    // -> c1 * p - c2 * q = b2 - b1
    // -> p = (b2 - b1) / gcd * x
    // -> q = (b2 - b1) / gcd * (-y)
    // -> z = LCM(c1, c2) * k + (c1 * p + b1)
    int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y);
    int64_t v = b2 - b1;
    if (gcd == 0) {
      // A zero gcd means both coefficients are zero, i.e. both operands are
      // single constants. Handle this before the remainder below, which would
      // otherwise divide by zero.
      return v == 0 ? a : Nothing();
    }
    if (v % gcd == 0) {
      x = v / gcd * x;
      int64_t coeff = c1 / gcd * c2;
      return Entry(coeff, x * c1 + b1);
    } else {
      return Nothing();
    }
  }

  /*!
   * \brief return everything dtype can represent.
   * \return Bound that represent everything dtype can represent.
   */
  static Entry Everything() { return Entry(1, 0); }

  /*!
   * \brief return the entry used for an empty set.
   * \return Entry(0, 1).
   * \note NOTE(review): this encoding is identical to the constant 1, so an
   *  empty intersection cannot be told apart from that constant - confirm
   *  this is what users of the result expect.
   */
  static Entry Nothing() { return Entry(0, 1); }
};
/*!
 * \brief Analyze an expression and package the result as a ModularSet.
 * \param expr The expression to analyze.
 * \return A ModularSet carrying the coeff and base computed by the implementation.
 */
ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
  const Entry result = impl_->VisitExpr(expr);
  return ModularSet(result.coeff, result.base);
}
/*!
 * \brief Forward a variable update to the implementation object.
 * \param var The variable being updated.
 * \param info The modular set to record for var.
 * \param allow_override Passed through unchanged to the implementation.
 */
void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) {
  // Pure delegation; all bookkeeping lives in Impl.
  impl_->Update(var,
                info,
                allow_override);
}
/*!
 * \brief Forward a constraint to the implementation object.
 * \param constraint The constraint expression.
 * \return Whatever the implementation returns for this constraint.
 */
std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
  std::function<void()> recover = impl_->EnterConstraint(constraint);
  return recover;
}
/*!
 * \brief Create the analyzer together with its heap-allocated implementation.
 * \param parent Analyzer pointer handed to the implementation.
 */
ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent)
    : impl_(new Impl(parent)) {
}
/*! \brief Release the implementation object allocated in the constructor. */
ModularSetAnalyzer::~ModularSetAnalyzer() {
  delete impl_;
}
} // namespace arith
} // namespace tvm