blob: a86ad76b0eb9a01678fd1f1fedeebf7be4070caa [file] [log] [blame]
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* \file
* \brief Translate the function body generated by ScheduleOps,
* which uses te-related dialects that embed Tensor objects
* into the Stmts, into a PrimFunc.
* Perform this translation before running any TIR optimizations.
* Rationale: The body generated by ScheduleOps is not
* a formal PrimFunc and cannot be used for further optimization.
* This function canonicalizes that body and creates a formal PrimFunc.
* List of actions taken by the function:
* - Remove occurrences of te::Tensor, te::Operation in the IR
* and replace them by corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace te {
/*!
 * \brief Declare a fresh buffer with the same shape and dtype as \p tensor.
 *
 * The buffer is named after the producing operation. When that operation has
 * more than one output, the name gets a ".v<value_index>" suffix.
 *
 * \param tensor The tensor to back with a buffer.
 * \return The newly declared buffer.
 */
Buffer CreateBufferFor(const Tensor& tensor) {
  const Operation& producer = tensor->op;
  std::string buffer_name = producer->name;
  if (producer->num_outputs() == 1) {
    // Single-output op: the op name is used as-is.
    return decl_buffer(tensor->shape, tensor->dtype, buffer_name);
  }
  // Multi-output op: tag the name with this tensor's output index.
  buffer_name.append(".v").append(std::to_string(tensor->value_index));
  return decl_buffer(tensor->shape, tensor->dtype, buffer_name);
}
// A remapper that maps tensor to buffer
class TensorToBufferMapper : public StmtExprMutator {
explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
: buffer_map_(buffer_map) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
auto ret = StmtExprMutator::VisitStmt_(op);
op =<AttrStmtNode>();
// TODO(tvm-team): remove realize_scope, turn the info into
// Buffer's scope field in this pass.
if (op->attr_key == tir::attr::realize_scope ||
op->attr_key == tir::attr::double_buffer_scope) {
Stmt body = op->body;
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
body = AttrStmt(buffer, op->attr_key, op->value, body);
return body;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
Tensor tensor = Downcast<Tensor>(tuple[1]);
return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value,
} else if (op->attr_key == tir::attr::buffer_dim_align ||
op->attr_key == tir::attr::prefetch_scope) {
Tensor tensor = Downcast<Tensor>(op->node);
Buffer buffer = GetOrAllocBuffer(tensor);
return AttrStmt(buffer, op->attr_key, op->value, op->body);
} else {
return ret;
Stmt VisitStmt_(const ProducerRealizeNode* op) final {
Tensor tensor = Downcast<Tensor>(op->producer);
Buffer buffer = GetOrAllocBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op =<ProducerRealizeNode>();
return BufferRealize(buffer, op->bounds, op->condition, op->body);
Stmt VisitStmt_(const ProducerStoreNode* op) final {
Tensor tensor = Downcast<Tensor>(op->producer);
Buffer buffer = GetBuffer(tensor);
auto ret = StmtExprMutator::VisitStmt_(op);
op =<ProducerStoreNode>();
return BufferStore(buffer, op->value, op->indices);
PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
auto ret = StmtExprMutator::VisitExpr_(op);
op =<ProducerLoadNode>();
Tensor tensor = Downcast<Tensor>(op->producer);
Buffer buffer = GetBuffer(tensor);
return tir::BufferLoad(buffer, op->indices);
Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); }
Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
auto it = buffer_map_.find(tensor);
if (it != buffer_map_.end()) return it->second;
CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
auto buffer = CreateBufferFor(tensor);
buffer_map_[tensor] = buffer;
return buffer;
// maps tensor to buffer.
std::unordered_map<Tensor, Buffer> buffer_map_;
/*!
 * \brief Convert a post-ScheduleOps body into a formal PrimFunc.
 *
 * Each argument in \p arg_list becomes a PrimFunc parameter:
 *  - a tir::Var is used directly as a parameter;
 *  - a te::Tensor gets a fresh buffer (CreateBufferFor) plus a handle Var that
 *    is bound to it in buffer_map, and the tensor is recorded so the body's
 *    references to it are rewritten to that same buffer;
 *  - anything else must be a tir::Buffer; it gets a handle Var bound in buffer_map.
 *
 * \param arg_list The function arguments (Var, Tensor, or Buffer).
 * \param body The statement produced by ScheduleOps.
 * \param extern_buffer_opt Optional pre-existing tensor -> buffer bindings.
 * \return The PrimFunc with tensor references replaced by buffer accesses.
 */
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
                                    Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
  std::unordered_map<Tensor, Buffer> extern_buffer;

  if (extern_buffer_opt.defined()) {
    auto v = extern_buffer_opt.value();
    extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
  }

  Array<tir::Var> params;
  Map<tir::Var, tir::Buffer> buffer_map;

  for (auto var : arg_list) {
    if (auto* n = var.as<tir::VarNode>()) {
      // Scalar argument: it is its own parameter.
      params.push_back(GetRef<tir::Var>(n));
    } else if (auto* n = var.as<te::TensorNode>()) {
      te::Tensor tensor = GetRef<te::Tensor>(n);
      tir::Buffer buffer = CreateBufferFor(tensor);
      tir::Var bptr(buffer->name, DataType::Handle());
      // The handle must be a parameter, otherwise the buffer_map entry
      // refers to a Var the function never receives.
      params.push_back(bptr);
      buffer_map.Set(bptr, buffer);
      extern_buffer[tensor] = buffer;
    } else {
      tir::Buffer buffer = Downcast<tir::Buffer>(var);
      tir::Var bptr(buffer->name, DataType::Handle());
      params.push_back(bptr);
      buffer_map.Set(bptr, buffer);
    }
  }

  body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
  return tir::PrimFunc(params, body, VoidType(), buffer_map);
}
} // namespace te
} // namespace tvm