| /*! |
| * Copyright (c) 2018 by Contributors |
| * \file bounds_checker.cc |
| */ |
| // Instrument checkers for out of the bounds access. |
| |
| #include <tvm/ir.h> |
| #include <tvm/ir_mutator.h> |
| #include <tvm/ir_pass.h> |
| #include <tvm/ir_visitor.h> |
| #include <vector> |
| #include <unordered_map> |
| #include <utility> |
| |
| namespace tvm { |
| namespace ir { |
| |
| class BoundCollector : public IRVisitor { |
| public: |
| BoundCollector() {} |
| |
| void Visit_(const AttrStmt *op) { |
| if (op->attr_key == ir::attr::buffer_bound) { |
| if (const Variable *key = op->node.as<Variable>()) { |
| mem_to_shape[key] = op->value; |
| } |
| } |
| IRVisitor::Visit_(op); |
| } |
| // Hashtable which maps buffer_var to shape. |
| std::unordered_map<const Variable *, Expr> mem_to_shape; |
| }; |
| |
| class BoundChecker : public IRMutator { |
| public: |
| explicit BoundChecker( |
| const std::unordered_map<const Variable *, Expr> &mem_to_shape) |
| : mem_to_shape_(mem_to_shape) {} |
| |
| Stmt Mutate_(const Allocate *op, const Stmt &s) final { |
| // If the shape was updated we should update the hashtable. |
| if (UpdateIsNeeded(op->buffer_var)) { |
| Update(op->buffer_var, op->extents, op->type); |
| } |
| return IRMutator::Mutate_(op, s); |
| } |
| |
| Expr Mutate_(const Call *op, const Expr &ex) final { |
| if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { |
| unsafe_rewritten_ = true; |
| } |
| return IRMutator::Mutate_(op, ex); |
| } |
| |
| Stmt Mutate_(const Store *op, const Stmt &s) final { |
| store_scope_bound_collector_.clear(); |
| process_store_ = true; |
| unsafe_rewritten_ = false; |
| IRMutator::Mutate_(op, s); |
| process_store_ = false; |
| if (CanInstrument(op->index, op->buffer_var)) { |
| Collect(op->index, op->buffer_var); |
| } |
| // The collector should has at least one item. |
| if (store_scope_bound_collector_.size()) { |
| Expr condition = MakeCondition(); |
| if (!condition.as<StringImm>()) { |
| Stmt nop = Evaluate::make(1); |
| Stmt then_case = |
| Store::make(op->buffer_var, op->value, op->index, op->predicate); |
| Stmt else_case = |
| AssertStmt::make(condition, StringImm::make(error_message_), nop); |
| Stmt body = IfThenElse::make(condition, then_case, else_case); |
| return body; |
| } |
| } |
| return s; |
| } |
| |
| Expr Mutate_(const Load *op, const Expr &ex) final { |
| if (CanInstrument(op->index, op->buffer_var)) { |
| Collect(op->index, op->buffer_var); |
| } |
| return IRMutator::Mutate_(op, ex); |
| } |
| |
| private: |
| bool UpdateIsNeeded(const VarExpr &buffer_var) const { |
| return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); |
| } |
| |
| void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape, |
| const Type &type) { |
| // Sanity check at first. |
| if (!new_shape.size()) { |
| return; |
| } |
| |
| for (size_t i = 0; i < new_shape.size(); ++i) { |
| if (!new_shape[0].defined() || !new_shape[i].type().is_scalar() || |
| is_negative_const(new_shape[i])) { |
| return; |
| } |
| } |
| |
| // Scalarize the shape. |
| Expr shape = Mul::make(make_const(UInt(64), type.lanes()), |
| Cast::make(UInt(64), new_shape[0])); |
| for (size_t i = 1; i < new_shape.size(); ++i) { |
| // Cast to unsigned to avoid integer overlow at frist. |
| shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()), |
| Cast::make(UInt(64), new_shape[i]))); |
| } |
| mem_to_shape_[buffer_var.get()] = shape; |
| } |
| |
| bool IndexIsValid(const Expr &index) const { |
| if (!index.defined()) { |
| return false; |
| } |
| |
| if (const Ramp *ramp_index = index.as<Ramp>()) { |
| return ramp_index->base.defined() && |
| ramp_index->base.type().is_scalar() && |
| ramp_index->stride.defined() && |
| ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0); |
| } |
| return true; |
| } |
| |
| bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const { |
| return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && |
| IndexIsValid(index) && !unsafe_rewritten_; |
| } |
| |
| void Collect(Expr index, VarExpr buffer_var) { |
| store_scope_bound_collector_.push_back( |
| std::make_pair(index, mem_to_shape_[buffer_var.get()])); |
| } |
| |
| Expr MakeCondition() { |
| Expr condition; |
| for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) { |
| std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i]; |
| Expr index = buffer_to_mem.first; |
| Expr upper_bound = buffer_to_mem.second; |
| |
| if (const Ramp *ramp_index = index.as<Ramp>()) { |
| // In case index is base + stride * i. |
| // Non inclusive range. |
| index = Add::make( |
| ramp_index->base, |
| Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(), |
| ramp_index->lanes - 1))); |
| } |
| |
| // Try to simplify index and bound. |
| index = ir::Simplify(index); |
| upper_bound = ir::Simplify(upper_bound); |
| |
| // Cast to the same type - signed, to be able to check lower bound. |
| index = Cast::make(Int(64), index); |
| upper_bound = Cast::make(Int(64), upper_bound); |
| |
| // Looks like a lower bound should always be zero after normalization. |
| Expr lower_bound = make_zero(Int(64)); |
| |
| Expr current_condition = |
| And::make(GE::make(index, lower_bound), LT::make(index, upper_bound)); |
| condition = |
| !i ? current_condition : And::make(condition, current_condition); |
| } |
| return condition; |
| } |
| |
| // Whether we process store value recursively. |
| bool process_store_{false}; |
| // Whether we face tvm_if_then_else intrinsic. |
| bool unsafe_rewritten_{false}; |
| // Pool which collects the pair of index and shape for specific store/load. |
| std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_; |
| // Error message. |
| const char *const error_message_ = "OUT OF THE BOUNDS"; |
| // Hashtable which maps buffer_var to shape. |
| std::unordered_map<const Variable *, Expr> mem_to_shape_; |
| }; |
| |
| Stmt InstrumentBoundCheckers(Stmt stmt) { |
| BoundCollector bound_collector; |
| // At first walk recursively and collect bound attributes. |
| bound_collector.Visit(stmt); |
| return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt); |
| } |
| } // namespace ir |
| } // namespace tvm |