blob: f6f0b9f4d8df7266537cc8fe353bf0e345397287 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/analyzer.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include "./scalable_expression.h"
#include "const_fold.h"
#include "product_normal_form.h"
namespace tvm {
namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this),
int_set(this) {}
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override);
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
this->transitive_comparisons.Bind(var, expr, allow_override);
}
/*!
 * \brief Bind `var` to the half-open iteration range described by `range`.
 *
 * A range whose extent is the constant one pins `var` to `range->min`, and is
 * forwarded to the expression overload of Bind. Otherwise only const_int_bound,
 * int_set and transitive_comparisons learn about the range. modular_set and the
 * rewrite/canonical simplifiers are intentionally not given range information here.
 *
 * \param var The variable being bound.
 * \param range The range of the variable; must be defined.
 * \param allow_override Whether an existing binding of `var` may be replaced.
 */
void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
  ICHECK(range.defined());
  if (tir::is_one(range->extent)) {
    // Single-value range: treat it as an exact binding.
    this->Bind(var, range->min, allow_override);
    return;
  }
  this->const_int_bound.Bind(var, range, allow_override);
  this->int_set.Bind(var, range, allow_override);
  this->transitive_comparisons.Bind(var, range, allow_override);
}
/*!
 * \brief Record a global hint that `value` is non-negative.
 *
 * `value` is decomposed as `symbol * cscale + offset`:
 *  - `offset` is the signed sum of the integer-constant summands;
 *  - `cscale` is the product of the integer-constant factors of the remaining part;
 *  - `symbol` is the product of the non-constant factors.
 *
 * When `cscale > 0` and `symbol` is a single index-typed variable, the hint
 * `symbol * cscale + offset >= 0` is equivalent, over the integers, to
 * `symbol >= ceil(-offset / cscale)`. That lower bound is written into const_int_bound,
 * overriding any existing bound on the variable. Every other shape of `value` is
 * ignored.
 *
 * \param value An expression known to be non-negative.
 */
void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
  // decompose value as symbol * scale + offset
  int64_t offset = 0;
  PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);
  auto fcollect_sum = [&](PrimExpr val, int sign) {
    if (const auto* intimm = val.as<IntImmNode>()) {
      offset += intimm->value * sign;
    } else {
      if (sign > 0) {
        symbol_scale = symbol_scale + val;
      } else {
        symbol_scale = symbol_scale - val;
      }
    }
  };
  UnpackSum(value, fcollect_sum);
  // split out the symbol and non-symbolic part
  int64_t cscale = 1;
  PrimExpr symbol = tir::make_const(value.dtype(), 1);
  auto fcollect_prod = [&](PrimExpr val) {
    if (const auto* intimm = val.as<IntImmNode>()) {
      cscale *= intimm->value;
    } else {
      symbol = symbol * val;
    }
  };
  UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
  // A non-positive scale gives no lower bound on the symbol.
  if (cscale <= 0) return;
  // Override the constant int bound by marking it as non-negative.
  // NOTE: there might be future opportunities for more bound hints;
  // this is a simple step and covers all the current needs.
  //
  // We may consider enhancing the sub-analyzers to directly take
  // MarkPositiveVar so their bounds do not overlap.
  if (const auto* var_ptr = symbol.as<VarNode>()) {
    Var var = ffi::GetRef<Var>(var_ptr);
    // skip non-index type, keep it to be compatible
    // with any_dim that do not represent any value
    if (!IsIndexType(var.dtype())) return;
    // symbol * cscale + offset >= 0  <=>  symbol >= -offset / cscale.
    // symbol is an integer, so the rational bound is rounded up. Using -offset
    // directly is only valid for cscale == 1: e.g. 2 * n - 4 >= 0 implies n >= 2,
    // not n >= 4.
    int64_t numerator = -offset;
    // Integer division truncates toward zero, which equals ceil for numerator <= 0.
    int64_t min_value = numerator / cscale;
    // For a positive numerator with a remainder, round up.
    if (numerator % cscale > 0) ++min_value;
    bool allow_override = true;
    // mark the constant bound is sufficient
    // we cannot mark interval set as that will cause relaxation of the var
    // during bound proof which is not our intention
    this->const_int_bound.Update(var, ConstIntBound(min_value, ConstIntBound::kPosInf),
                                 allow_override);
  }
}
/*!
 * \brief Bind each variable in `variables` to its associated range.
 *
 * Every entry is forwarded to Bind(const Var&, const Range&, bool) with the same
 * `allow_override` flag.
 *
 * \param variables Map from variable to its range.
 * \param allow_override Whether existing bindings may be replaced.
 */
void Analyzer::Bind(const ffi::Map<Var, Range>& variables, bool allow_override) {
  for (const auto& [var, range] : variables) {
    this->Bind(var, range, allow_override);
  }
}
/*!
 * \brief Enter the constraint scope by registering `constraint_` with sub-analyzers.
 *
 * const_int_bound, modular_set, rewrite_simplify, int_set and transitive_comparisons
 * each return a recovery callback. These callbacks are stored in registration order,
 * so ExitWithScope can undo them in reverse. canonical_simplify is not notified.
 * Must be called with no pending recovery callbacks.
 */
void ConstraintContext::EnterWithScope() {
  ICHECK(recovery_functions_.size() == 0);
  auto& parent = *analyzer_;
  recovery_functions_.push_back(parent.const_int_bound.EnterConstraint(constraint_));
  recovery_functions_.push_back(parent.modular_set.EnterConstraint(constraint_));
  recovery_functions_.push_back(parent.rewrite_simplify.EnterConstraint(constraint_));
  recovery_functions_.push_back(parent.int_set.EnterConstraint(constraint_));
  recovery_functions_.push_back(parent.transitive_comparisons.EnterConstraint(constraint_));
}
/*!
 * \brief Leave the constraint scope, undoing EnterWithScope.
 *
 * Recovery callbacks are invoked last-registered-first, and each is removed after it
 * runs. Empty callbacks are skipped.
 */
void ConstraintContext::ExitWithScope() {
  for (; !recovery_functions_.empty(); recovery_functions_.pop_back()) {
    auto& restore = recovery_functions_.back();
    if (!restore) continue;
    restore();
  }
}
bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
return false;
}
bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value < upper_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->max_value < upper_bound) return true;
return false;
}
/*!
 * \brief Check whether `lhs == rhs` can be proven.
 *
 * - Two integer literals are compared by value.
 * - If either side is handle-typed, only object identity (same_as) counts as equal.
 * - Otherwise `lhs - rhs == 0` is handed to CanProve.
 *
 * \return true if equality is proven; false means "not proven".
 */
bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
  const auto* lhs_literal = lhs.as<IntImmNode>();
  const auto* rhs_literal = rhs.as<IntImmNode>();
  if (lhs_literal != nullptr && rhs_literal != nullptr) {
    return lhs_literal->value == rhs_literal->value;
  }
  // Arithmetic is not meaningful on handles, so fall back to identity.
  const bool has_handle = lhs->dtype.is_handle() || rhs->dtype.is_handle();
  if (has_handle) return lhs.same_as(rhs);
  return CanProve(lhs - rhs == 0);
}
bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const PrimExpr& shape) {
if (this->CanProve(lhs <= shape, ProofStrength::kSymbolicBound)) return true;
// no need to do further attempt if shape is already a constant.
if (tir::is_const_int(shape)) return false;
// collect constant scale and ignore symbolic part
// so 32 * n => cscale = 32
int64_t cscale = 1;
auto fcollect = [&](const PrimExpr& expr) {
if (auto* ptr = expr.as<IntImmNode>()) {
cscale *= ptr->value;
}
};
UnpackReduction<tir::MulNode>(shape, fcollect);
PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale));
if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound)) return true;
return false;
}
/*!
 * \brief Try to prove that `expr` holds (evaluates to a non-zero value).
 *
 * Stages, cheapest first:
 *  1. An integer literal is decided directly, without simplification.
 *  2. The expression is simplified; a non-zero constant result proves it.
 *  3. With `strength >= kSymbolicBound`, a top-level `<`, `<=`, `>` or `>=` on the
 *     ORIGINAL expression is rewritten as `difference >= lower_bound`. The
 *     difference is relaxed to an IntSet, and the constant lower bound of that set's
 *     minimum is checked.
 *  4. If the simplified expression contains vscale and the current target is a VLA
 *     target, the expression is checked against the target's known vscale values.
 *     For non-VLA targets a warning is logged instead.
 *
 * \param expr The expression to prove.
 * \param strength How much effort to spend on the proof.
 * \return true if the expression was proven; false means "could not prove".
 */
bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
  // Avoid potentially expensive simplification unless required.
  if (const auto* ptr = expr.as<IntImmNode>()) {
    return ptr->value != 0;
  }
  PrimExpr simplified = Simplify(expr);
  const int64_t* as_int = tir::as_const_int(simplified);
  if (as_int && *as_int) return true;
  if (strength >= ProofStrength::kSymbolicBound) {
    // NOTE: we intentionally only pattern match common bound predicate i < bound
    // and put this implementation at the top-level.
    // This is to avoid repetitive calling of this function
    // that causes speed issues.
    // This strategy can only be called from top-level and not from sub-analyzers.
    ffi::Optional<PrimExpr> pos_diff;
    int lower_bound = 0;
    // The comparison kinds are mutually exclusive, so at most one branch applies.
    if (const auto* ptr_lt = expr.as<tir::LTNode>()) {
      // a < b  <=>  b - a >= 1
      pos_diff = ptr_lt->b - ptr_lt->a;
      lower_bound = 1;
    } else if (const auto* ptr_le = expr.as<tir::LENode>()) {
      // a <= b  <=>  b - a >= 0
      pos_diff = ptr_le->b - ptr_le->a;
      lower_bound = 0;
    } else if (const auto* ptr_gt = expr.as<tir::GTNode>()) {
      // a > b  <=>  a - b >= 1
      pos_diff = ptr_gt->a - ptr_gt->b;
      lower_bound = 1;
    } else if (const auto* ptr_ge = expr.as<tir::GENode>()) {
      // a >= b  <=>  a - b >= 0
      pos_diff = ptr_ge->a - ptr_ge->b;
      lower_bound = 0;
    }
    if (pos_diff) {
      IntSet iset = this->int_set(this->Simplify(pos_diff.value()));
      if (iset.HasLowerBound()) {
        ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min()));
        if (relaxed_lower_bound->min_value >= lower_bound) return true;
      }
    }
  }
  // Current analysis may not be powerful enough to prove expressions containing
  // the same symbolic value multiple times. However, when the symbolic values are
  // "T.vscale" and the compile target uses a scalable architecture extension like
  // VLA, we can make some assumptions about the value of vscale and iterate over a
  // space of pre-defined values to attempt to prove the expression.
  if (ContainsVscaleCall(simplified)) {
    // Look up the current target only when it is actually needed, keeping the
    // common (no vscale) path free of the lookup.
    Target curr_target = Target::Current();
    if (TargetHasVLA(curr_target)) {
      auto kVScaleValues = GetVScaleValues(curr_target);
      return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
    }
    LOG(WARNING)
        << "The expression contains scalable values. An attempt to prove by substituting "
           "with known values of vscale was not performed. This proof currently only supports "
           "VLA targets, but the target was "
        << curr_target;
  }
  return false;
}
/*!
 * \brief Simplify `expr` by alternating rewrite and canonical simplification.
 *
 * A canonical simplification always runs first, because some structural properties
 * of an expression might be destroyed by rewrite simplification. After that, up to
 * `steps` further passes run: rewrite on even passes, canonical on odd passes.
 * Iteration stops early once the result is an integer constant.
 *
 * \param expr The expression to simplify.
 * \param steps Maximum number of passes after the initial canonical pass.
 * \return The simplified expression.
 */
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
  PrimExpr current = this->canonical_simplify(expr);
  for (int pass = 0; pass < steps && !tir::is_const_int(current); ++pass) {
    const bool use_rewrite = (pass % 2 == 0);
    current = use_rewrite ? this->rewrite_simplify(current) : this->canonical_simplify(current);
  }
  return current;
}
// Registers the global packed function "arith.CreateAnalyzer".
// Each call creates a fresh Analyzer held by a shared_ptr (`self`). It returns a
// function that maps a method name to a packed function; every returned function
// captures `self`, so the Analyzer stays alive while any of them is referenced.
// An unrecognized name yields an empty (undefined) ffi::Function.
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) {
// NOTE(review): these using-declarations are not referenced below; everything is
// spelled with the ffi:: prefix.
using ffi::Function;
using ffi::TypedFunction;
auto self = std::make_shared<Analyzer>();
// Dispatcher: method name -> packed function bound to `self`.
auto f = [self](std::string name) -> ffi::Function {
if (name == "const_int_bound") {
// (expr) -> ConstIntBound
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound(args[0].cast<PrimExpr>());
});
} else if (name == "modular_set") {
// (expr) -> ModularSet
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->modular_set(args[0].cast<PrimExpr>());
});
} else if (name == "const_int_bound_update") {
// (var, bound, allow_override) -> nothing
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->const_int_bound.Update(args[0].cast<Var>(), args[1].cast<ConstIntBound>(),
args[2].cast<bool>());
});
} else if (name == "const_int_bound_is_bound") {
// (var) -> whether const_int_bound has a bound recorded for var
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound.IsBound(args[0].cast<Var>());
});
} else if (name == "Simplify") {
// (expr) or (expr, steps) -> simplified expr; any other arity is fatal.
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (args.size() == 1) {
*ret = self->Simplify(args[0].cast<PrimExpr>());
} else if (args.size() == 2) {
*ret = self->Simplify(args[0].cast<PrimExpr>(), args[1].cast<int>());
} else {
LOG(FATAL) << "Invalid size of argument (" << args.size() << ")";
}
});
} else if (name == "rewrite_simplify") {
// (expr) -> rewrite-simplified expr
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "get_rewrite_simplify_stats") {
// () -> rewrite simplifier statistics counters
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->rewrite_simplify.GetStatsCounters();
});
} else if (name == "reset_rewrite_simplify_stats") {
// () -> nothing; resets rewrite simplifier statistics counters
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
self->rewrite_simplify.ResetStatsCounters();
});
} else if (name == "canonical_simplify") {
// (expr) -> canonical-simplified expr
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->canonical_simplify(args[0].cast<PrimExpr>());
});
} else if (name == "int_set") {
// (expr, dom_map) -> IntSet; the domain map argument is required here.
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->int_set(args[0].cast<PrimExpr>(), args[1].cast<ffi::Map<Var, IntSet>>());
});
} else if (name == "bind") {
// (var, Range | PrimExpr) -> nothing. allow_override is not exposed, so the
// Bind overloads' default value (declared in the header) is used.
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (auto opt_range = args[1].try_cast<Range>()) {
self->Bind(args[0].cast<Var>(), opt_range.value());
} else {
self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>());
}
});
} else if (name == "can_prove") {
// (expr, strength) -> bool; strength is an int cast to ProofStrength.
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int strength = args[1].cast<int>();
*ret = self->CanProve(args[0].cast<PrimExpr>(), static_cast<ProofStrength>(strength));
});
} else if (name == "enter_constraint_context") {
// (constraint) -> exit function. Constructing With<ConstraintContext> enters the
// scope. Calling the returned function drops its reference to the With object;
// presumably the scope is exited when that object is destroyed (TODO confirm in
// With's destructor).
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx = std::shared_ptr<With<ConstraintContext>>(
new With<ConstraintContext>(self.get(), args[0].cast<PrimExpr>()));
auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); };
*ret = ffi::Function::FromPacked(fexit);
});
} else if (name == "can_prove_equal") {
// (lhs, rhs) -> bool
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->CanProveEqual(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>());
});
} else if (name == "get_enabled_extensions") {
// () -> enabled rewrite-simplifier extension flags as int64
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
});
} else if (name == "set_enabled_extensions") {
// (flags: int64) -> nothing; flags are cast to RewriteSimplifier::Extension.
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
int64_t flags = args[0].cast<int64_t>();
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
}
// Unknown method name: return an undefined function.
return ffi::Function();
};
*ret = ffi::TypedFunction<ffi::Function(std::string)>(f);
});
}
} // namespace arith
} // namespace tvm