blob: 189869bd64e7630bb9c838b9aeab2ed4a887a435 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_constraints.cc
* \brief The integer constraints data structures.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <algorithm>
#include <unordered_map>
#include <utility>
#include "../tir/transforms/ir_util.h"
namespace tvm {
namespace arith {
/*!
 * \brief Flatten per-variable bounds plus extra relations into a list of boolean conditions.
 *
 * For every variable v (in the order given by \p variables) the bounds stored in
 * \p bounds[v] are emitted as `coef * v == e`, `coef * v >= e` and `coef * v <= e`
 * for each equal, lower and upper bound respectively. The extra \p relations are
 * appended afterwards, unchanged.
 *
 * \param variables The variables, defining the output order.
 * \param bounds Exactly one IntGroupBounds per variable in \p variables.
 * \param relations Additional conditions appended at the end.
 * \return The conjunction terms as a flat array.
 */
Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
                             const Array<PrimExpr>& relations) {
  Array<PrimExpr> res;
  // use variables to keep the order of iteration
  // so as to get rid of any non-determinism.
  CHECK_EQ(variables.size(), bounds.size())
      << "AsConditions expects exactly one IntGroupBounds entry per variable";
  // Iterate by const reference: avoids a refcount inc/dec per variable.
  for (const Var& v : variables) {
    CHECK(bounds.count(v)) << "AsConditions: no IntGroupBounds found for variable " << v;
    const IntGroupBounds& bnds = bounds[v];
    // All stored bounds constrain coef * v, not v itself.
    PrimExpr lhs = bnds->coef * v;
    for (const PrimExpr& rhs : bnds->equal) {
      res.push_back(tir::EQ(lhs, rhs));
    }
    for (const PrimExpr& rhs : bnds->lower) {
      res.push_back(tir::GE(lhs, rhs));
    }
    for (const PrimExpr& rhs : bnds->upper) {
      res.push_back(tir::LE(lhs, rhs));
    }
  }
  for (const PrimExpr& e : relations) {
    res.push_back(e);
  }
  return res;
}
/*!
 * \brief Build an IntGroupBounds node from an integer coefficient and its bound lists.
 * \param coef Integer (signed or unsigned) coefficient; the bounds describe coef * v.
 * \param lower Lower bounds on coef * v.
 * \param equal Exact values of coef * v.
 * \param upper Upper bounds on coef * v.
 */
IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
                               Array<PrimExpr> upper) {
  // Validate first: coef is inspected here and moved into the node below.
  CHECK(coef.dtype().is_int() || coef.dtype().is_uint())
      << "Coefficient in IntGroupBounds must be integers";
  auto n = make_object<IntGroupBoundsNode>();
  n->upper = std::move(upper);
  n->equal = std::move(equal);
  n->lower = std::move(lower);
  n->coef = std::move(coef);
  data_ = std::move(n);
}
/*!
 * \brief Convert a Range into IntGroupBounds with a unit coefficient.
 *
 * A range of extent exactly one becomes a single equality `v == min`; any other range
 * becomes the pair `v >= min` and `v <= min + extent - 1` (upper bound simplified).
 */
IntGroupBounds IntGroupBounds::FromRange(const Range& r) {
  Analyzer analyzer;
  const PrimExpr& start = r->min;
  Array<PrimExpr> lower;
  Array<PrimExpr> equal;
  Array<PrimExpr> upper;
  if (!tir::is_one(r->extent)) {
    lower.push_back(start);
    upper.push_back(analyzer.Simplify(start + r->extent - 1));
  } else {
    equal.push_back(start);
  }
  // The coefficient carries the same dtype as the range's min.
  PrimExpr unit_coef = tir::make_const(start.dtype(), 1);
  return IntGroupBounds(unit_coef, lower, equal, upper);
}
/*!
 * \brief Return a copy of these bounds with the extra constraint `v in r` added.
 *
 * Stored bounds describe coef * v, so the ends of \p r are multiplied by coef before
 * being recorded. A unit-extent range becomes an equality. The new bounds are placed
 * before the existing ones in each list. *this is not modified.
 * NOTE(review): scaling min/max by coef keeps lower/upper in order only when coef >= 0.
 */
IntGroupBounds IntGroupBounds::operator+(const Range& r) {
  const IntGroupBoundsNode* self = operator->();
  const PrimExpr& coef = self->coef;
  Analyzer analyzer;
  Array<PrimExpr> equal;
  Array<PrimExpr> lower;
  Array<PrimExpr> upper;
  if (!tir::is_one(r->extent)) {
    lower.push_back(analyzer.Simplify(r->min * coef));
    upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef));
  } else {
    equal.push_back(analyzer.Simplify(r->min * coef));
  }
  // Keep all previously known bounds after the newly derived ones.
  for (const PrimExpr& e : self->equal) {
    equal.push_back(e);
  }
  for (const PrimExpr& e : self->lower) {
    lower.push_back(e);
  }
  for (const PrimExpr& e : self->upper) {
    upper.push_back(e);
  }
  return IntGroupBounds(coef, lower, equal, upper);
}
/*!
 * \brief Apply a variable substitution to the coefficient and to every bound expression.
 * \param subst Mapping from variables to their replacement expressions.
 * \return A new IntGroupBounds; *this is left unchanged.
 */
IntGroupBounds IntGroupBounds::Substitute(const Map<Var, PrimExpr>& subst) const {
  const IntGroupBoundsNode* self = operator->();
  auto rewrite = [&subst](const PrimExpr& expr) { return tir::Substitute(expr, subst); };
  PrimExpr new_coef = tir::Substitute(self->coef, subst);
  Array<PrimExpr> new_lower = tir::UpdateArray(self->lower, rewrite);
  Array<PrimExpr> new_equal = tir::UpdateArray(self->equal, rewrite);
  Array<PrimExpr> new_upper = tir::UpdateArray(self->upper, rewrite);
  return IntGroupBounds(new_coef, new_lower, new_equal, new_upper);
}
/*!
 * \brief Choose a single Range for the variable v described by these bounds.
 *
 * Stored bounds constrain coef * v. Equalities act as both a lower and an upper
 * candidate. Every (lower, upper) candidate pair is turned into a range for v. A later
 * pair replaces the current best only when its extent is provably strictly smaller, so
 * earlier pairs win ties and comparisons the analyzer cannot decide.
 *
 * \param vranges_addl Ranges of other variables that may occur in the bound expressions.
 *        They are bound into the analyzer and used to over-approximate extents via EvalSet.
 * \return The chosen range, or an undefined Range() when there is no lower or no upper
 *         candidate at all (the pair loops below never run in that case).
 * NOTE(review): the floordiv-based rounding below is only correct for coef > 0 — confirm.
 */
Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const {
  Analyzer analyzer;
  analyzer.Bind(vranges_addl);
  // Interval of every additional variable, used for extent over-approximation below.
  std::unordered_map<const VarNode*, IntSet> var_intsets;
  for (auto kv : vranges_addl) {
    var_intsets[kv.first.get()] = IntSet::FromRange(kv.second);
  }
  const Array<PrimExpr>& equal = operator->()->equal;
  const PrimExpr& coef = operator->()->coef;
  // coef*v == e is both coef*v >= e and coef*v <= e, so equalities seed both candidate lists.
  std::vector<PrimExpr> lowers(equal.begin(), equal.end());
  std::vector<PrimExpr> uppers(equal.begin(), equal.end());
  for (const auto& expr : operator->()->lower) {
    lowers.push_back(expr);
  }
  for (const auto& expr : operator->()->upper) {
    uppers.push_back(expr);
  }
  // Fast path: one candidate per side and coef == 1, so the bounds are already bounds on v.
  // Range(begin, end) is the half-open interval [lower, upper + 1).
  if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) {
    return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1));
  }
  // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the
  // pair with the minimal difference between the upper and the lower.
  // Note that the bounds are for v, not for v*coef
  // The lower bound of the best pair so far
  PrimExpr best_lower;
  // The difference between the upper and the lower of the best pair, maybe overapproximation
  PrimExpr best_diff_over;
  for (const PrimExpr& low : lowers) {
    for (const PrimExpr& upp : uppers) {
      // First estimate of (max v) - (min v): floordiv(upp - low, coef).
      PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
      // Since diff may depend on some other variables, we compute its overapproximation
      PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3);
      // low is the lower bound for v*coef, but we need the lower bound for v.
      // We use rounding-up division, ceil(low / coef), to compute it. It is written as
      // floordiv(low + coef - 1, coef) so that a single floordiv-based formula is correct
      // for any sign of low (the identity holds when coef > 0).
      PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3);
      // Compute another difference which may be more precise (or not).
      // This one is floor(upp / coef) - ceil(low / coef), the exact width of the v interval.
      PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3);
      PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3);
      // Prefer the second estimate only when it is provably tighter.
      PrimExpr diff_over =
          analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1;
      // If it is provable that the new one is strictly better than the current best one,
      // then replace it. Note that we are biased towards earlier pairs which should be simpler.
      if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) {
        best_lower = low_divided;
        best_diff_over = diff_over;
      }
    }
  }
  // No pair was examined: either the lower or the upper candidate list was empty.
  if (!best_lower.defined()) {
    CHECK(!best_diff_over.defined());
    return Range();
  }
  // The difference is (max - min), so the extent is one larger.
  return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1));
}
TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode);

// FFI constructor: forwards the four fields positionally to IntGroupBounds.
TVM_REGISTER_GLOBAL("arith.IntGroupBounds")
    .set_body_typed([](PrimExpr coefficient, Array<PrimExpr> lower_bounds,
                       Array<PrimExpr> equal_bounds, Array<PrimExpr> upper_bounds) {
      return IntGroupBounds(coefficient, lower_bounds, equal_bounds, upper_bounds);
    });

// FFI access to IntGroupBounds::FromRange.
TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange);
// FFI wrapper for FindBestRange. It takes the bounds and, optionally, a Map<Var, Range>
// of extra variable ranges; with one argument, the default of FindBestRange is used.
TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      CHECK(args.size() == 1 || args.size() == 2);
      IntGroupBounds group_bounds = args[0];
      if (args.size() == 2) {
        *ret = group_bounds.FindBestRange(args[1]);
        return;
      }
      *ret = group_bounds.FindBestRange();
    });
// Textual form: IntGroupBounds(coef=..., lower=..., equal=..., upper=...)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntGroupBoundsNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* bounds = static_cast<const IntGroupBoundsNode*>(ref.get());
      std::ostream& os = printer->stream;
      os << "IntGroupBounds(coef=" << bounds->coef;
      os << ", lower=" << bounds->lower;
      os << ", equal=" << bounds->equal;
      os << ", upper=" << bounds->upper << ")";
    });
/*!
 * \brief Build an IntConstraints node.
 *
 * An undefined (null) \p variables or \p ranges is replaced by an empty container.
 * \p relations must be defined, and every variable must have an integer dtype.
 */
IntConstraints::IntConstraints(Array<Var> variables, Map<Var, Range> ranges,
                               Array<PrimExpr> relations) {
  if (!variables.defined()) variables = Array<Var>();
  if (!ranges.defined()) ranges = Map<Var, Range>();
  CHECK(relations.defined());
  for (const Var& var : variables) {
    CHECK(var.dtype().is_int() || var.dtype().is_uint())
        << "Variables in IntConstraints must be integers";
  }
  // All inputs validated; move them into a fresh node.
  auto n = make_object<IntConstraintsNode>();
  n->relations = std::move(relations);
  n->ranges = std::move(ranges);
  n->variables = std::move(variables);
  data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsNode);

// FFI constructor: forwards its arguments positionally to IntConstraints.
TVM_REGISTER_GLOBAL("arith.IntConstraints")
    .set_body_typed([](Array<Var> vars, Map<Var, Range> var_ranges, Array<PrimExpr> conds) {
      return IntConstraints(vars, var_ranges, conds);
    });
// Textual form: IntConstraints(<variables>, <ranges>, <relations>)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntConstraintsNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* constraints = static_cast<const IntConstraintsNode*>(ref.get());
      std::ostream& os = printer->stream;
      os << "IntConstraints(" << constraints->variables;
      os << ", " << constraints->ranges;
      os << ", " << constraints->relations << ")";
    });
/*!
 * \brief Build a transform between two constraint systems.
 * \param src Source constraints.
 * \param dst Destination constraints.
 * \param src_to_dst Expresses each source variable in terms of destination variables.
 * \param dst_to_src Expresses each destination variable in terms of source variables.
 */
IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst,
                                                 Map<Var, PrimExpr> src_to_dst,
                                                 Map<Var, PrimExpr> dst_to_src) {
  auto n = make_object<IntConstraintsTransformNode>();
  // Parameters are taken by value and moved into the node.
  n->src_to_dst = std::move(src_to_dst);
  n->dst_to_src = std::move(dst_to_src);
  n->src = std::move(src);
  n->dst = std::move(dst);
  data_ = std::move(n);
}
/*!
 * \brief Compose two transforms: (*this: A -> B) + (other: B -> C) yields A -> C.
 *
 * Requires other's source to be the same object as this transform's destination.
 * Each composed mapping is simplified using the ranges of the constraint system its
 * result expressions are written in.
 */
IntConstraintsTransform IntConstraintsTransform::operator+(
    const IntConstraintsTransform& other) const {
  CHECK(other->src.same_as(operator->()->dst));
  const IntConstraintsTransformNode* self = operator->();

  // C -> A: rewrite other's (C -> B) expressions using our (B -> A) map; simplify over A.
  Analyzer src_analyzer;
  src_analyzer.Bind(self->src->ranges);
  Map<Var, PrimExpr> dst_to_src;
  for (const auto& kv : other->dst_to_src) {
    dst_to_src.Set(kv.first,
                   src_analyzer.Simplify(tir::Substitute(kv.second, self->dst_to_src)));
  }

  // A -> C: rewrite our (A -> B) expressions using other's (B -> C) map; simplify over C.
  Analyzer dst_analyzer;
  dst_analyzer.Bind(other->dst->ranges);
  Map<Var, PrimExpr> src_to_dst;
  for (const auto& kv : self->src_to_dst) {
    src_to_dst.Set(kv.first,
                   dst_analyzer.Simplify(tir::Substitute(kv.second, other->src_to_dst)));
  }

  return IntConstraintsTransform(self->src, other->dst, src_to_dst, dst_to_src);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);

// FFI constructor: forwards its arguments positionally to IntConstraintsTransform.
TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform")
    .set_body_typed([](IntConstraints source, IntConstraints target,
                       Map<Var, PrimExpr> forward_map, Map<Var, PrimExpr> backward_map) {
      return IntConstraintsTransform(source, target, forward_map, backward_map);
    });
// Textual form: each of the four fields on its own tab-indented line.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* xform = static_cast<const IntConstraintsTransformNode*>(ref.get());
      std::ostream& os = printer->stream;
      os << "IntConstraintsTransform(";
      os << "\n\t" << xform->src;
      os << "\n\t" << xform->dst;
      os << "\n\t" << xform->src_to_dst;
      os << "\n\t" << xform->dst_to_src;
      os << "\n)";
    });
} // namespace arith
} // namespace tvm