/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file bounds_checker.cc
 * \brief Instrument checks for out-of-bounds buffer accesses.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace tir {
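/*!
 * \brief Visitor that records, for every buffer variable, the shape
 *        expression attached via the tir::attr::buffer_bound attribute.
 */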
class BoundCollector : public StmtVisitor {
public:
BoundCollector() {}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::buffer_bound) {
if (const VarNode* key = op->node.as<VarNode>()) {
mem_to_shape[key] = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const VarNode*, PrimExpr> mem_to_shape;
};
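/*!
 * \brief Mutator that guards instrumentable loads and stores with runtime
 *        bounds checks against the shapes recorded by BoundCollector.
 */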
class BoundChecker : public StmtExprMutator {
public:
explicit BoundChecker(const std::unordered_map<const VarNode*, PrimExpr>& mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt VisitStmt_(const AllocateNode* op) final {
// If the buffer is tracked, recompute its flattened shape from the allocation extents.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->dtype);
}
return StmtExprMutator::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (process_store_ && op->op.same_as(builtin::if_then_else())) {
unsafe_rewritten_ = true;
}
return StmtExprMutator::VisitExpr_(op);
}
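// Rewrite an instrumented store into a guarded form. Conceptually (a sketch,
// not the exact generated IR):
//   if (all collected indices are in bounds) {
//     buf[index] = value;
//   } else {
//     assert(<same condition>, error_message_);  // fails here, reporting the violation
//   }
// Stores whose value contains a tvm_if_then_else call are treated as unsafe
// and left untouched.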
Stmt VisitStmt_(const StoreNode* op) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
StmtExprMutator::VisitStmt_(op);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
// The collector should have at least one item.
if (!store_scope_bound_collector_.empty()) {
PrimExpr condition = MakeCondition();
if (!condition.as<StringImmNode>()) {
Stmt nop = Evaluate(1);
Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate);
Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
Stmt body = IfThenElse(condition, then_case, else_case);
return body;
}
}
return GetRef<Stmt>(op);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return StmtExprMutator::VisitExpr_(op);
}
private:
bool UpdateIsNeeded(const Var& buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
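// Recompute the flattened upper bound for the buffer: the product of the
// allocation extents, each scaled by the vector lane count and cast to uint64
// to avoid overflow. For example, with lanes == 1 and extents {2, 3}, the
// recorded bound is 2 * 3 = 6 elements.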
void Update(const Var& buffer_var, const Array<PrimExpr>& new_shape, const DataType& type) {
// Sanity checks first.
if (new_shape.empty()) {
return;
}
for (size_t i = 0; i < new_shape.size(); ++i) {
if (!new_shape[i].defined() || !new_shape[i].dtype().is_scalar() ||
is_negative_const(new_shape[i])) {
return;
}
}
// Scalarize the shape.
PrimExpr shape =
Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned first to avoid integer overflow.
shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()),
Cast(DataType::UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
bool IndexIsValid(const PrimExpr& index) const {
if (!index.defined()) {
return false;
}
if (const RampNode* ramp_index = index.as<RampNode>()) {
return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() &&
ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() &&
(ramp_index->lanes > 0);
}
return true;
}
bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) &&
!unsafe_rewritten_;
}
void Collect(PrimExpr index, Var buffer_var) {
store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}
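// Build a single condition that holds only when every collected access is in
// range. For each (index, upper bound) pair this produces, in int64:
//   index >= 0 && index < upper_bound
// and all such terms are AND-ed together. Ramp indices are reduced to their
// last lane, base + stride * (lanes - 1), before the comparison.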
PrimExpr MakeCondition() {
PrimExpr condition;
for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
std::pair<PrimExpr, PrimExpr> buffer_to_mem = store_scope_bound_collector_[i];
PrimExpr index = buffer_to_mem.first;
PrimExpr upper_bound = buffer_to_mem.second;
if (const RampNode* ramp_index = index.as<RampNode>()) {
// The index is a ramp (base + stride * i); check its last lane,
// base + stride * (lanes - 1), against the non-inclusive upper bound.
index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(),
ramp_index->lanes - 1)));
}
// Try to simplify index and bound.
index = analyzer_.Simplify(index);
upper_bound = analyzer_.Simplify(upper_bound);
// Cast both to the same signed type so the lower bound can also be checked.
index = Cast(DataType::Int(64), index);
upper_bound = Cast(DataType::Int(64), upper_bound);
// The lower bound should always be zero after normalization.
PrimExpr lower_bound = make_zero(DataType::Int(64));
PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound));
condition = !i ? current_condition : And(condition, current_condition);
}
return condition;
}
// Whether we are currently processing the value of a store.
bool process_store_{false};
// Whether a tvm_if_then_else intrinsic was encountered while processing the store value.
bool unsafe_rewritten_{false};
// Collects the (index, upper bound) pairs for the current store and the loads inside it.
std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
// Error message.
const char* const error_message_ = "OUT OF BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const VarNode*, PrimExpr> mem_to_shape_;
// Internal analyzer used to simplify indices and bounds.
arith::Analyzer analyzer_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
// First walk the statement recursively and collect the bound attributes.
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
namespace transform {
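// Pass wrapper so the bounds checker can run in a pass pipeline. A minimal
// usage sketch from Python (assuming the binding generated from the
// registration below):
//   mod = tvm.tir.transform.InstrumentBoundCheckers()(mod)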
Pass InstrumentBoundCheckers() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
BoundCollector bound_collector;
// First walk the function body recursively and collect the bound attributes.
bound_collector(n->body);
n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
.set_body_typed(InstrumentBoundCheckers);
} // namespace transform
} // namespace tir
} // namespace tvm