| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file schedule_postproc_to_primfunc.cc |
| * |
 * \brief Translate the function body generated by ScheduleOps, which still
 *  uses TE-specific constructs that embed te::Tensor references in its
 *  Stmts, into a PrimFunc.
 *
 *  Perform this translation before running any TIR optimizations.
 *
 *  Rationale: The body generated by ScheduleOps is not
 *  a formal PrimFunc and cannot be used for further optimization.
 *  This function canonicalizes that body and creates a formal PrimFunc.
 *
 *  List of actions taken by the function:
 *  - Remove occurrences of te::Tensor and te::Operation in the IR
 *    and replace them with the corresponding IR nodes via tir::Buffer.
 *  - Annotate extern buffers using the buffer_map field
 *    of the PrimFunc.
 *  - Move layout_transforms annotations into the PrimFunc's
 *    "layout_transform_map" attribute, and axis_separators annotations
 *    onto the corresponding tir::Buffer objects.
| */ |
| #include <tvm/runtime/registry.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/function.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <functional> |
| #include <unordered_map> |
| #include <utility> |
| |
| namespace tvm { |
| namespace te { |
| |
| // create a buffer for tensor. |
| Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { |
| std::string name = tensor->op->name; |
| if (tensor->op->num_outputs() != 1) { |
| name += ".v" + std::to_string(tensor->value_index); |
| } |
| Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); |
| |
| return buffer; |
| } |
| |
| // A remapper that maps tensor to buffer |
| class TensorToBufferMapper : public StmtExprMutator { |
| public: |
| explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map) |
| : buffer_map_(buffer_map) {} |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| auto ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<AttrStmtNode>(); |
| if (op->attr_key == tir::attr::double_buffer_scope || |
| op->attr_key == tir::attr::rolling_buffer_scope) { |
| Stmt body = op->body; |
| Operation operation = Downcast<Operation>(op->node); |
| for (int i = operation->num_outputs(); i != 0; --i) { |
| Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); |
| body = AttrStmt(buffer, op->attr_key, op->value, body); |
| } |
| return body; |
| } else if (op->attr_key == tir::attr::buffer_bind_scope) { |
| Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node); |
| Tensor tensor = Downcast<Tensor>(tuple[1]); |
| return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value, |
| op->body); |
| } else if (op->attr_key == tir::attr::buffer_dim_align || |
| op->attr_key == tir::attr::prefetch_scope) { |
| Tensor tensor = Downcast<Tensor>(op->node); |
| Buffer buffer = GetOrAllocBuffer(tensor); |
| return AttrStmt(buffer, op->attr_key, op->value, op->body); |
| } else if (op->attr_key == tir::attr::layout_transforms || |
| op->attr_key == tir::attr::axis_separators) { |
| auto arr = Downcast<Array<ObjectRef>>(op->node); |
| ICHECK_EQ(arr.size(), 2); |
| |
| Stmt body = op->body; |
| |
| Tensor tensor = Downcast<Tensor>(arr[0]); |
| Buffer buffer = GetBuffer(tensor); |
| |
| return AttrStmt(Array<ObjectRef>{buffer, arr[1]}, op->attr_key, 1, body); |
| } else { |
| return ret; |
| } |
| } |
| |
| Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
| Tensor tensor = Downcast<Tensor>(op->producer); |
| Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); |
| |
| auto ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<ProducerRealizeNode>(); |
| |
| return BufferRealize(buffer, op->bounds, op->condition, op->body); |
| } |
| |
| Stmt VisitStmt_(const ProducerStoreNode* op) final { |
| Tensor tensor = Downcast<Tensor>(op->producer); |
| Buffer buffer = GetBuffer(tensor); |
| |
| auto ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<ProducerStoreNode>(); |
| |
| return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape)); |
| } |
| |
| PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
| auto ret = StmtExprMutator::VisitExpr_(op); |
| op = ret.as<ProducerLoadNode>(); |
| Tensor tensor = Downcast<Tensor>(op->producer); |
| Buffer buffer = GetBuffer(tensor); |
| return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape)); |
| } |
| |
| private: |
| Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { |
| return GetBuffer(tensor, storage_scope, true); |
| } |
| |
| Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { |
| auto it = buffer_map_.find(tensor); |
| if (it != buffer_map_.end()) return it->second; |
| ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; |
| |
| auto buffer = CreateBufferFor(tensor, storage_scope); |
| buffer_map_[tensor] = buffer; |
| return buffer; |
| } |
| |
| Array<PrimExpr> GetIndices(const Array<PrimExpr>& tensor_indices, |
| const Array<PrimExpr>& buffer_shape) { |
| if (tensor_indices.size() == buffer_shape.size()) { |
| return tensor_indices; |
| } else if (tensor_indices.size() == 1) { |
| // Workaround to support previous behavior of tensor indexing by |
| // a single index, treating the tensor as if were already |
| // flattened by a row-major traversal. |
| PrimExpr unravel = tensor_indices[0]; |
| Array<PrimExpr> rev_indices; |
| for (size_t i = buffer_shape.size(); i > 0; i--) { |
| PrimExpr dim = buffer_shape[i - 1]; |
| rev_indices.push_back(indexmod(unravel, dim)); |
| unravel = indexdiv(unravel, dim); |
| } |
| return Array<PrimExpr>(rev_indices.rbegin(), rev_indices.rend()); |
| } else { |
| LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size() |
| << "-dimensional TIR buffer using " << tensor_indices.size() |
| << "-dimensional tensor indices."; |
| return {}; |
| } |
| } |
| |
| // Maps tensor to buffer. |
| std::unordered_map<Tensor, Buffer> buffer_map_; |
| }; |
| |
| /*! Collect the physical layout map of all tensors in the statement. */ |
| class LayoutTransformAttrUnwrapper : StmtExprMutator { |
| public: |
| static tir::PrimFunc Apply(tir::PrimFunc func) { |
| // Collect the physical layout annotations in the body, which may |
| // refer to input arguments. |
| auto layout_map = Collector::Collect(func->body); |
| |
| if (layout_map.size()) { |
| func = WithAttr(std::move(func), "layout_transform_map", layout_map); |
| |
| auto write_ptr = func.CopyOnWrite(); |
| write_ptr->body = LayoutTransformAttrUnwrapper()(func->body); |
| } |
| |
| return func; |
| } |
| |
| LayoutTransformAttrUnwrapper() {} |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| auto ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<AttrStmtNode>(); |
| |
| if (op->attr_key == tir::attr::layout_transforms) { |
| return op->body; |
| } else { |
| return ret; |
| } |
| } |
| |
| private: |
| /*! Collect the physical layout information of all tensors in the statement. |
| * |
| * Must be done before constructing the buffers, since the |
| * attributes could either apply to the external buffers or to |
| * internal allocations. |
| */ |
| class Collector : StmtExprVisitor { |
| public: |
| static Map<Buffer, Array<IndexMap>> Collect(Stmt stmt) { |
| Collector collector; |
| collector(std::move(stmt)); |
| return std::move(collector.layout_map_); |
| } |
| |
| Collector() {} |
| |
| void VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == tir::attr::layout_transforms) { |
| auto arr = Downcast<Array<ObjectRef>>(op->node); |
| ICHECK_EQ(arr.size(), 2); |
| |
| auto buffer = Downcast<Buffer>(arr[0]); |
| auto layout_transforms = Downcast<Array<IndexMap>>(arr[1]); |
| layout_map_.Set(buffer, layout_transforms); |
| } |
| StmtExprVisitor::VisitStmt_(op); |
| } |
| |
| Map<Buffer, Array<IndexMap>> layout_map_; |
| }; |
| |
| std::unordered_map<const BufferNode*, Buffer> buffer_remap_; |
| |
| Map<Buffer, Array<IndexMap>> layout_map_; |
| }; |
| |
| /*! Move axis_separators from an attribute to a buffer property. */ |
| class AxisSeparatorsAttrUnwrapper : StmtExprMutator { |
| public: |
| static tir::PrimFunc Apply(tir::PrimFunc func) { |
| // Collect the physical layout annotations in the body, which may |
| // refer to input arguments. |
| auto axis_separators_map = Collector::Collect(func->body); |
| |
| if (axis_separators_map.size()) { |
| auto write_ptr = func.CopyOnWrite(); |
| auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map); |
| write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map); |
| write_ptr->body = pass(func->body); |
| if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) { |
| func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value())); |
| } |
| } |
| |
| return func; |
| } |
| |
| explicit AxisSeparatorsAttrUnwrapper(Map<Buffer, Array<IntImm>> axis_separators_map) |
| : axis_separators_map_(axis_separators_map) {} |
| |
| Map<Var, Buffer> UpdateExternBufferMap(const Map<Var, Buffer>& orig) { |
| Map<Var, Buffer> output; |
| for (const auto& kv : orig) { |
| output.Set(kv.first, GetRemappedBuffer(kv.second)); |
| } |
| return output; |
| } |
| |
| Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) { |
| Map<Buffer, Array<IndexMap>> output; |
| for (const auto& kv : orig) { |
| output.Set(GetRemappedBuffer(kv.first), kv.second); |
| } |
| return output; |
| } |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| auto ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<AttrStmtNode>(); |
| |
| if (op->attr_key == tir::attr::axis_separators) { |
| return op->body; |
| } else if (op->attr_key == tir::attr::buffer_bind_scope) { |
| Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node); |
| Buffer view_buffer = Downcast<Buffer>(tuple[0]); |
| Buffer source_buffer = Downcast<Buffer>(tuple[1]); |
| return AttrStmt( |
| Array<ObjectRef>{GetRemappedBuffer(view_buffer), GetRemappedBuffer(source_buffer)}, |
| op->attr_key, op->value, op->body); |
| } else { |
| return ret; |
| } |
| } |
| |
| Stmt VisitStmt_(const BufferRealizeNode* op) final { |
| auto node = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| private: |
| template <typename Node> |
| Node VisitBufferAccess(Node node) { |
| Buffer new_buf = GetRemappedBuffer(node->buffer); |
| if (!node->buffer.same_as(new_buf)) { |
| auto writer = node.CopyOnWrite(); |
| writer->buffer = new_buf; |
| } |
| return node; |
| } |
| |
| Buffer GetRemappedBuffer(Buffer buf) { |
| // If this buffer has already been remapped, then return the |
| // previous value. |
| auto key = buf.get(); |
| { |
| auto it = buffer_remap_.find(key); |
| if (it != buffer_remap_.end()) { |
| return it->second; |
| } |
| } |
| |
| // Otherwise, check if we need to add axis_separators to this |
| // buffer. |
| auto lookup = axis_separators_map_.Get(buf); |
| if (lookup) { |
| Array<IntImm> axis_separators = lookup.value(); |
| if (axis_separators.size()) { |
| auto write_ptr = buf.CopyOnWrite(); |
| write_ptr->axis_separators = axis_separators; |
| } |
| } |
| |
| // And cache the result for next time. |
| buffer_remap_[key] = buf; |
| |
| return buf; |
| } |
| |
| /*! Collect the axis separator information of all tensors in the statement. |
| * |
| * Must be done before constructing the buffers, since the |
| * attributes could either apply to the external buffers or to |
| * internal allocations. |
| */ |
| class Collector : StmtExprVisitor { |
| public: |
| static Map<Buffer, Array<IntImm>> Collect(Stmt stmt) { |
| Collector collector; |
| collector(std::move(stmt)); |
| return std::move(collector.axis_separators_map_); |
| } |
| |
| Collector() {} |
| |
| void VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == tir::attr::axis_separators) { |
| auto arr = Downcast<Array<ObjectRef>>(op->node); |
| ICHECK_EQ(arr.size(), 2); |
| |
| auto buffer = Downcast<Buffer>(arr[0]); |
| auto axis_separators = Downcast<Array<IntImm>>(arr[1]); |
| axis_separators_map_.Set(buffer, axis_separators); |
| } |
| StmtExprVisitor::VisitStmt_(op); |
| } |
| |
| Map<Buffer, Array<IntImm>> axis_separators_map_; |
| }; |
| |
| std::unordered_map<const BufferNode*, Buffer> buffer_remap_; |
| |
| Map<Buffer, Array<IntImm>> axis_separators_map_; |
| }; |
| |
| PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body, |
| Optional<Map<Tensor, Buffer>> extern_buffer_opt) { |
| std::unordered_map<Tensor, Buffer> extern_tensor_map; |
| |
| if (extern_buffer_opt.defined()) { |
| auto v = extern_buffer_opt.value(); |
| extern_tensor_map = std::unordered_map<Tensor, Buffer>(v.begin(), v.end()); |
| } |
| |
| Array<tir::Var> params; |
| Map<tir::Var, tir::Buffer> buffer_map; |
| |
| for (auto arg : arg_list) { |
| if (auto* n = arg.as<tir::VarNode>()) { |
| tir::Var var = GetRef<tir::Var>(n); |
| params.push_back(GetRef<tir::Var>(n)); |
| } else if (auto* n = arg.as<te::TensorNode>()) { |
| te::Tensor tensor = GetRef<te::Tensor>(n); |
| ICHECK(!extern_tensor_map.count(tensor)); |
| |
| tir::Buffer buffer = CreateBufferFor(tensor); |
| tir::Var bptr(buffer->name, PrimType(DataType::Handle())); |
| params.push_back(bptr); |
| buffer_map.Set(bptr, buffer); |
| extern_tensor_map[tensor] = buffer; |
| } else if (auto* n = arg.as<tir::BufferNode>()) { |
| tir::Buffer buffer = GetRef<tir::Buffer>(n); |
| tir::Var bptr(buffer->name, PrimType(DataType::Handle())); |
| params.push_back(bptr); |
| buffer_map.Set(bptr, buffer); |
| } else { |
| LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received " |
| << arg->GetTypeKey(); |
| } |
| } |
| |
| body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body)); |
| |
| PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map); |
| |
| func = LayoutTransformAttrUnwrapper::Apply(std::move(func)); |
| func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func)); |
| |
| // We mark this PrimFunc as coming from a TE schedule |
| func = WithAttr(func, "from_legacy_te_schedule", Bool(true)); |
| |
| return func; |
| } |
| |
// Register SchedulePostProcToPrimFunc as a global function under the name
// "schedule.SchedulePostProcToPrimFunc".
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
    .set_body_typed(SchedulePostProcToPrimFunc);
| |
| } // namespace te |
| } // namespace tvm |