| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file storage_flatten.cc |
| * \brief Flattens storage from multi-dimensional array to 1D buffer access |
| */ |
| // The pass definition originates from Halide pipeline. |
| |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/runtime/device_api.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/target/target_info.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/tir/buffer.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/op.h> |
| #include <tvm/tir/stmt.h> |
| #include <tvm/tir/stmt_functor.h> |
| #include <tvm/tir/transform.h> |
| |
| #include <unordered_map> |
| |
| #include "../../arith/ir_visitor_with_analyzer.h" |
| #include "../../runtime/thread_storage_scope.h" |
| #include "arg_binder.h" |
| #include "ir_util.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| using runtime::StorageRank; |
| using runtime::StorageScope; |
| using runtime::ThreadScope; |
| |
| class StorageFlattener : public StmtExprMutator { |
| public: |
| explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map, int cache_line_size, |
| bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) |
| : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { |
| for (auto kv : extern_buffer_map) { |
| BufferEntry e; |
| e.buffer = kv.second; |
| e.external = true; |
| buf_map_[kv.second] = e; |
| } |
| cache_line_size_ = cache_line_size; |
| } |
| |
| Stmt VisitStmt_(const StoreNode* op) final { |
| Stmt stmt = StmtExprMutator::VisitStmt_(op); |
| op = stmt.as<StoreNode>(); |
| auto it = var_remap_.find(op->buffer_var.get()); |
| if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { |
| CHECK(it->second.as<VarNode>()); |
| Var buf_var = Downcast<Var>(it->second); |
| return Store(buf_var, op->value, op->index, op->predicate); |
| } else { |
| return stmt; |
| } |
| } |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == attr::realize_scope) { |
| storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value; |
| return this->VisitStmt(op->body); |
| } else if (op->attr_key == attr::double_buffer_scope && |
| op->node->IsInstance<tir::BufferNode>()) { |
| auto buffer = Downcast<tir::Buffer>(op->node); |
| Stmt body = this->VisitStmt(op->body); |
| auto it = buf_map_.find(buffer); |
| CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; |
| body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); |
| return body; |
| } else if (op->attr_key == attr::thread_extent) { |
| IterVar iv = Downcast<IterVar>(op->node); |
| ThreadScope ts = ThreadScope::Create(iv->thread_tag); |
| curr_thread_scope_.push_back(ts); |
| Stmt stmt = StmtExprMutator::VisitStmt_(op); |
| curr_thread_scope_.pop_back(); |
| return stmt; |
| } else if (op->attr_key == attr::buffer_bind_scope) { |
| return HandleBufferBindScope(op); |
| } else if (op->attr_key == attr::buffer_dim_align) { |
| auto buffer = Downcast<tir::Buffer>(op->node); |
| const CallNode* tuple = op->value.as<CallNode>(); |
| CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); |
| auto& vinfo = dim_align_[buffer]; |
| int dim = tuple->args[0].as<IntImmNode>()->value; |
| if (static_cast<size_t>(dim) >= vinfo.size()) { |
| vinfo.resize(dim + 1); |
| } |
| vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value; |
| vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value; |
| return this->VisitStmt(op->body); |
| } |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| if (create_bound_attributes_) shape_collector_.clear(); |
| Stmt stmt = StmtExprMutator::VisitStmt_(op); |
| op = stmt.as<BufferStoreNode>(); |
| |
| const auto& key = op->buffer; |
| |
| auto it = buf_map_.find(key); |
| CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; |
| |
| const BufferEntry& e = it->second; |
| CHECK(!e.released) << "Read a buffer that is already out of scope"; |
| |
| Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); |
| if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { |
| shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); |
| } |
| // To create bound attribute collector should has at least one item. |
| if (create_bound_attributes_ && shape_collector_.size()) { |
| for (size_t i = 0; i < shape_collector_.size(); ++i) { |
| body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound, |
| MakeBound(e.buffer->dtype, shape_collector_[i].second), body); |
| } |
| } |
| return body; |
| } |
| |
| Stmt VisitStmt_(const BufferRealizeNode* op) final { |
| const auto& key = op->buffer; |
| |
| if (buf_map_.count(key)) { |
| CHECK(buf_map_.at(key).external); |
| return this->VisitStmt(op->body); |
| } else { |
| // create a buffer entry |
| BufferEntry e; |
| e.bounds = op->bounds; |
| Array<PrimExpr> shape; |
| for (auto r : e.bounds) { |
| shape.push_back(r->extent); |
| } |
| // deduce current storage scope. |
| auto it = storage_scope_.find(op->buffer.get()); |
| CHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; |
| StorageScope skey; |
| const std::string& strkey = it->second; |
| if (strkey.length() == 0) { |
| if (curr_thread_scope_.size() != 0) { |
| skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); |
| } |
| } else { |
| skey = StorageScope::Create(strkey); |
| } |
| |
| // use small alignment for small arrays |
| auto dtype = op->buffer->dtype; |
| int32_t const_size = AllocateNode::constant_allocation_size(shape); |
| int align = GetTempAllocaAlignment(dtype, const_size); |
| if (skey.tag.length() != 0) { |
| MemoryInfo info = GetMemoryInfo(skey.to_string()); |
| if (info.defined()) { |
| align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); |
| CHECK_LE(const_size * dtype.bits(), info->max_num_bits) |
| << "Allocation exceed bound of memory tag " << skey.to_string(); |
| } |
| } |
| Array<PrimExpr> strides; |
| if (dim_align_.count(key) != 0 && shape.size() != 0) { |
| std::vector<PrimExpr> rstrides; |
| const std::vector<DimAlignInfo>& avec = dim_align_[key]; |
| int first_dim = 0; |
| PrimExpr stride = make_const(shape[first_dim].dtype(), 1); |
| for (size_t i = shape.size(); i != 0; --i) { |
| size_t dim = i - 1; |
| if (dim < avec.size() && avec[dim].align_factor != 0) { |
| PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); |
| PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); |
| stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); |
| stride = bound_analyzer_->Simplify(stride); |
| } |
| rstrides.push_back(stride); |
| stride = stride * shape[dim]; |
| } |
| strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend()); |
| } |
| |
| e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), |
| op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, |
| skey.to_string(), align, 0, kDefault); |
| |
| buf_map_[key] = e; |
| Stmt body = this->VisitStmt(op->body); |
| buf_map_[key].released = true; |
| Stmt ret; |
| |
| DataType storage_type = e.buffer->dtype; |
| // specially handle bool, lower its storage |
| // type to beDataType::Int(8)(byte) |
| if (storage_type == DataType::Bool()) { |
| storage_type = DataType::Int(8); |
| } |
| if (strides.size() != 0) { |
| int first_dim = 0; |
| ret = Allocate(e.buffer->data, storage_type, |
| {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, |
| make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); |
| } else { |
| shape = e.buffer->shape; |
| if (shape.size() == 0) { |
| shape.push_back(make_const(DataType::Int(32), 1)); |
| } |
| ret = Allocate(e.buffer->data, storage_type, shape, |
| make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); |
| } |
| ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); |
| |
| if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { |
| ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, |
| MakeBound(e.buffer->dtype, e.buffer->shape), ret); |
| } |
| return ret; |
| } |
| } |
| |
| PrimExpr VisitExpr_(const LoadNode* op) final { |
| PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
| op = expr.as<LoadNode>(); |
| auto it = var_remap_.find(op->buffer_var.get()); |
| if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { |
| CHECK(it->second.as<VarNode>()); |
| Var buf_var = Downcast<Var>(it->second); |
| return Load(op->dtype, buf_var, op->index, op->predicate); |
| } else { |
| return expr; |
| } |
| } |
| |
| PrimExpr VisitExpr_(const VarNode* op) final { |
| auto it = var_remap_.find(op); |
| if (it != var_remap_.end()) { |
| return it->second; |
| } else { |
| return GetRef<PrimExpr>(op); |
| } |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
| op = expr.as<BufferLoadNode>(); |
| |
| const auto& key = op->buffer; |
| |
| auto it = buf_map_.find(key); |
| CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; |
| const BufferEntry& e = it->second; |
| CHECK(!e.released) << "Read a buffer that is already out of scope"; |
| |
| if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { |
| shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); |
| } |
| return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); |
| } |
| |
| Stmt VisitStmt_(const PrefetchNode* op) final { |
| Stmt stmt = StmtExprMutator::VisitStmt_(op); |
| op = stmt.as<PrefetchNode>(); |
| CHECK(op != nullptr); |
| |
| const auto& key = op->buffer; |
| auto it = buf_map_.find(key); |
| CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; |
| const BufferEntry& e = it->second; |
| |
| CHECK(!e.released) << "Read a buffer that is already out of scope"; |
| CHECK_EQ(e.buffer->shape.size(), op->bounds.size()) |
| << "Prefetch dim should be the same as buffer dim"; |
| |
| int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); |
| |
| int starts = op->bounds.size() - 1; |
| |
| while (starts > 0) { |
| auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>(); |
| if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break; |
| block_size *= static_cast<int>(shape_as_int->value); |
| starts--; |
| } |
| PrimExpr stride(elem_cnt / block_size); |
| |
| Array<PrimExpr> args; |
| std::vector<Var> vars; |
| |
| for (int i = op->bounds.size() - 1; i > starts; --i) { |
| args.push_back(op->bounds[i]->min); |
| } |
| auto& func_name = op->buffer->name; |
| vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); |
| args.push_back(op->bounds[starts]->min + stride * vars.back()); |
| for (int i = starts - 1; i >= 0; --i) { |
| vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); |
| args.push_back(vars.back() + op->bounds[i]->min); |
| } |
| for (int i = starts; i >= 0; --i) { |
| if (i < starts) { |
| stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); |
| } else { |
| PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); |
| PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); |
| PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); |
| stmt = Evaluate(prefetch); |
| PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; |
| stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); |
| } |
| } |
| return stmt; |
| } |
| |
| PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
| LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc."; |
| return PrimExpr(); |
| } |
| |
| Stmt VisitStmt_(const ProducerStoreNode* op) final { |
| LOG(FATAL) << "Cannot handle Provide " |
| << " please run SchedulePostProcToPrimFunc first"; |
| return Stmt(); |
| } |
| |
| Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
| LOG(FATAL) << "Cannot handle Realize " |
| << " please run SchedulePostProcToPrimFunc first"; |
| return Stmt(); |
| } |
| |
| private: |
| // The specific tensor data layout is not determined before |
| // StorageFlatten pass. We use buffer_bind_scope |
| // to specify before hand we want to bind a subregion |
| // of tensor to a symbolic buffer, which get used in extern. |
| // |
| // Example: |
| // |
| // realize A in range [i*4, extent=10) { |
| // bind Ab to A in [i*4+1, extent=4) { |
| // call_func(Ab.ptr, Ab.shape[0]) |
| // } |
| // } |
| // |
| // After StorageFlatten |
| // |
| // alloc A[10] |
| // call(A + 1, 4) |
| // |
| // Buffer is a protocol to declare specific |
| // data layout and shape we expect. |
| // So this function need to check: |
| // - If the bind range is within the realize range |
| // - If we can match the requirement of buffer |
| // - Remap variables such as Ab.ptr to the actual value. |
| // |
| // Here are a few possible failure cases: |
| // - Buffer is declared to have constant shape, |
| // but we try to bind it to a different one. |
| // - Buffer is declared to be compact(no strides) |
| // but this binded region is a subregion of |
| // a matrix(tensor), which means it requires strides. |
| // |
| // We do support a few relaxed case, such as bindingx |
| // region with shape [1, 1, n, m] to buffer with shape [n, m] |
| Stmt HandleBufferBindScope(const AttrStmtNode* op) { |
| Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node); |
| CHECK_EQ(arr.size(), 2U); |
| const BufferNode* buffer = arr[0].as<BufferNode>(); |
| const BufferNode* target = arr[1].as<BufferNode>(); |
| const CallNode* tuple = op->value.as<CallNode>(); |
| CHECK(buffer && target); |
| CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); |
| auto key = GetRef<Buffer>(target); |
| |
| auto it = buf_map_.find(key); |
| CHECK(it != buf_map_.end()) << "Cannot find buffer of " << key; |
| const BufferEntry& be = it->second; |
| CHECK(!be.released); |
| CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); |
| Array<PrimExpr> begins, extents; |
| if (be.bounds.size() != 0) { |
| CHECK_EQ(tuple->args.size(), be.bounds.size() * 2); |
| for (size_t i = 0; i < be.buffer->shape.size(); ++i) { |
| begins.push_back(tuple->args[2 * i] - be.bounds[i]->min); |
| extents.push_back(tuple->args[2 * i + 1]); |
| } |
| } else { |
| for (size_t i = 0; i < tuple->args.size(); i += 2) { |
| begins.push_back(tuple->args[i]); |
| auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]); |
| extents.push_back(new_extent); |
| } |
| } |
| Buffer slice = be.buffer.MakeSlice(begins, extents); |
| if (buffer->strides.size() == 0) { |
| CHECK_EQ(slice->strides.size(), 0U) |
| << "Trying to bind compact buffer to strided one strides=" << slice->strides; |
| } else { |
| slice = slice.MakeStrideView(); |
| } |
| // start binding |
| ArgBinder binder(&var_remap_); |
| binder.BindBuffer(Downcast<Buffer>(arr[0]), slice, buffer->name, true); |
| // Apply the remaps |
| Stmt body = MergeNest(binder.asserts(), op->body); |
| body = MergeNest(binder.init_nest(), body); |
| body = this->VisitStmt(body); |
| // remove the binds |
| for (const Var& v : binder.defs()) { |
| var_remap_.erase(v.get()); |
| } |
| return body; |
| } |
| |
// Alignment request for one buffer dimension, filled in from buffer_dim_align.
// A zero align_factor means no alignment was requested for that dimension.
struct DimAlignInfo {
  int align_factor = 0;
  int align_offset = 0;
};
| // The buffer entry in the flatten map |
| struct BufferEntry { |
| // the buffer of storage |
| Buffer buffer; |
| // the bounds of realization, can be null, means everything |
| Region bounds; |
| // Whether the buffer is external |
| bool external{false}; |
| // Whether we are out of allocation bounds and buffer get released. |
| bool released{false}; |
| // relative index |
| inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const { |
| if (bounds.size() != 0) { |
| Array<PrimExpr> index; |
| CHECK_EQ(bounds.size(), args.size()); |
| for (size_t i = 0; i < bounds.size(); ++i) { |
| index.push_back(args[i] - bounds[i]->min); |
| } |
| return index; |
| } else { |
| return args; |
| } |
| } |
| }; |
| |
| bool ShapeIsValid(const Array<PrimExpr>& shape) { |
| // Zero-dimensional tensor does not need boundary check. |
| if (!shape.size()) return false; |
| |
| for (size_t i = 0; i < shape.size(); ++i) { |
| if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| PrimExpr MakeBound(const DataType& type, const Array<PrimExpr>& shape) { |
| // We have already checked the shape size to be greater then 0. |
| PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]); |
| for (size_t i = 1; i < shape.size(); ++i) { |
| bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); |
| } |
| return bound; |
| } |
| |
| // The buffer assignment map |
| // Variable remap |
| std::unordered_map<const VarNode*, PrimExpr> var_remap_; |
| // Buffer map |
| std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_; |
| // Dimension alignment |
| std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_; |
| // Storage scope |
| std::unordered_map<const Object*, std::string> storage_scope_; |
| // The current thread scope. |
| std::vector<ThreadScope> curr_thread_scope_; |
| // Collects shapes. |
| std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_; |
| // bounds populator. We really need the analyzer from it. |
| // However |
| IRVisitorWithAnalyzer* bound_analyzer_; |
| // The size of cacheline |
| int cache_line_size_; |
| // Whether to mark load/store with theirs bounds. |
| bool create_bound_attributes_{false}; |
| }; |
| |
| PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { |
| auto fptr = func.CopyOnWrite(); |
| |
| IRVisitorWithAnalyzer bound_analyzer; |
| bound_analyzer(fptr->body); |
| fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, |
| &bound_analyzer)(std::move(fptr->body)); |
| return func; |
| } |
| |
| namespace transform { |
| |
| // TODO(tvm-team): consolidate configs to the PassContext |
| Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { |
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
| return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); |
| }; |
| return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten); |
| |
| } // namespace transform |
| |
| } // namespace tir |
| } // namespace tvm |