blob: daf61441b46602b39896edf732d2b63c1becd324 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/analyzer.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this),
int_set(this) {}
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override);
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
}
void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
}
void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second, allow_override);
}
}
void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
// entering the scope.
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1, f2]() {
if (f2 != nullptr) f2();
if (f1 != nullptr) f1();
if (f0 != nullptr) f0();
};
}
void ConstraintContext::ExitWithScope() {
CHECK(exit_ != nullptr);
exit_();
}
bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
return false;
}
bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value < upper_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->max_value < upper_bound) return true;
return false;
}
bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
return false;
}
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
if (tir::is_const_int(expr)) return expr;
PrimExpr res = expr;
for (int i = 0; i < steps; ++i) {
res = this->rewrite_simplify(res);
if (tir::is_const_int(res) || ++i == steps) return res;
res = this->canonical_simplify(res);
if (tir::is_const_int(res)) return res;
}
return res;
}
TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> PackedFunc {
if (name == "const_int_bound") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); });
} else if (name == "modular_set") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); });
} else if (name == "const_int_bound_update") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "Simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = self->Simplify(args[0]);
} else if (args.size() == 2) {
*ret = self->Simplify(args[0], args[1]);
} else {
LOG(FATAL) << "Invalid size of argument (" << args.size() << ")";
}
});
} else if (name == "rewrite_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
} else if (name == "canonical_simplify") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });
} else if (name == "int_set") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); });
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Range>()) {
self->Bind(args[0], args[1].operator Range());
} else {
self->Bind(args[0], args[1].operator PrimExpr());
}
});
} else if (name == "enter_constraint_context") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext> >(
new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); };
*ret = PackedFunc(fexit);
});
}
return PackedFunc();
};
*ret = TypedPackedFunc<PackedFunc(std::string)>(f);
});
} // namespace arith
} // namespace tvm