blob: fc3f49d76fae6108938915c7b38fe07fc1aab2bc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir_text_printer.cc
* \brief Printer to print out the IR text format
* that can be parsed by a parser.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/type.h>
#include <tvm/ir/type_functor.h>
#include <tvm/node/serialization.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <algorithm>
#include <string>
#include "../tir/transforms/ir_utils.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
namespace tir {
Doc TIRTextPrinter::Print(const ObjectRef& node) {
if (!node.defined()) return Doc::Text("(nullptr)");
if (node->IsInstance<StmtNode>()) {
return VisitStmt(Downcast<Stmt>(node));
} else if (node->IsInstance<AnyNode>()) {
return Doc::Text("?");
} else if (node->IsInstance<PrimExprNode>()) {
return VisitExpr(Downcast<PrimExpr>(node));
} else if (node->IsInstance<TypeNode>()) {
return VisitType(Downcast<Type>(node));
} else if (node->IsInstance<PrimFuncNode>()) {
return PrintPrimFunc(Downcast<PrimFunc>(node));
} else if (node->IsInstance<IRModuleNode>()) {
return PrintIRModule(Downcast<IRModule>(node));
} else if (node->IsInstance<ArrayNode>()) {
return PrintArray(node.as<ArrayNode>());
} else if (node->IsInstance<IterVarNode>()) {
return PrintIterVar(node.as<IterVarNode>());
} else if (node->IsInstance<RangeNode>()) {
return PrintRange(node.as<RangeNode>());
} else if (node->IsInstance<BufferNode>()) {
return PrintBuffer(node.as<BufferNode>());
} else if (node->IsInstance<DataProducerNode>()) {
return PrintProducer(node.as<DataProducerNode>());
} else if (node->IsInstance<StringObj>()) {
return PrintString(node.as<StringObj>());
} else if (node->IsInstance<BufferRegionNode>()) {
return PrintBufferRegion(node.as<BufferRegionNode>());
} else if (node->IsInstance<TargetNode>()) {
return Doc::Text(node.as<TargetNode>()->ToDebugString());
} else {
return this->meta_->GetMetaNode(node);
}
}
Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
const auto* op = prim_func.operator->();
const auto& signature = op->func_type_annotation();
// collect Meta in DictAttr
if (prim_func->attrs.defined()) {
for (const auto& it : prim_func->attrs->dict) {
meta_collector_.Collect(it.second);
}
}
// collect buffers in buffer_map
memo_var_.clear();
memo_buf_.clear();
// ordered vars associated with buffers, for consistent printing
std::vector<Var> buffer_vars_ordered;
for (Var v : op->params) {
auto buffer_map_find = op->buffer_map.find(v);
if (buffer_map_find != op->buffer_map.end()) {
auto map_data = *buffer_map_find;
buffer_vars_ordered.push_back(map_data.first);
memo_buf_[map_data.second] = AllocBuf(map_data.second);
}
}
// print PrimFunc
Doc doc;
doc << "primfn"
<< "(";
// print params and its type annotation
std::vector<Doc> params;
for (const auto& param : op->params) {
params.push_back(Print(param));
}
Doc sep;
doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
// print return type
doc << " -> " << Print(signature->ret_type);
// print attr
Doc attr_doc;
std::vector<Doc> attr_docs;
if (prim_func->attrs.defined()) {
for (const auto& it : op->attrs->dict) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
doc << Doc::Indent(2, attr_doc);
}
// print all the buffers in the tree
if (memo_buf_.size() != 0) {
Doc buffer_doc;
std::vector<Doc> buffer_docs;
for (const Var& v : buffer_vars_ordered) {
const Buffer buf = op->buffer_map[v];
buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
}
buffer_doc << Doc::NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
doc << Doc::Indent(2, buffer_doc) << "}";
}
if (op->buffer_map.size() != 0) {
// print buffer_map
std::vector<Doc> buffer_map_doc;
for (const Var& v : buffer_vars_ordered) {
const Buffer buf = op->buffer_map[v];
buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
}
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}
doc << PrintBody(op->body);
return doc;
}
Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
const auto* op = module.operator->();
Doc doc;
Doc body;
body << Doc::NewLine();
std::vector<Doc> functions;
for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
if ((*it).second.as<PrimFuncNode>()) {
functions.push_back(Print((*it).second));
}
}
body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
doc << Doc::Indent(0, body);
return doc;
}
Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
Doc doc;
doc << '[';
for (size_t i = 0; i < op->size(); ++i) {
if (i != 0) {
doc << ", ";
}
doc << Print(op->at(i));
}
doc << ']';
return doc;
}
Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
Doc doc;
doc << "IterVar(" << Print(op->var);
if (op->dom.defined()) {
doc << ", [" << Print(op->dom) << "], ";
} else {
doc << ", " << Print(op->dom) << ", ";
}
doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
doc << Doc::StrLiteral(op->thread_tag) << ")";
return doc;
}
Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
return Print(op->min) << ":" << Print(op->min + op->extent);
}
Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
const Buffer& buffer = GetRef<Buffer>(op);
if (meta_->InMeta(buffer)) {
return meta_->GetMetaNode(buffer);
} else if (memo_buf_.count(buffer)) {
return memo_buf_[buffer];
} else {
memo_buf_[buffer] = AllocBuf(buffer);
return BufferNode2Doc(op, memo_buf_[buffer]);
}
}
Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) {
const DataProducer& prod = GetRef<DataProducer>(op);
if (meta_->InMeta(prod)) {
return meta_->GetMetaNode(prod);
} else if (memo_producer_.count(prod)) {
return memo_producer_[prod];
} else {
memo_producer_[prod] = AllocProducer(prod);
return DataProducerNode2Doc(op, memo_producer_[prod]);
}
}
Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", "
<< Print(buf->shape) << ", " << Print(buf->strides);
if (!is_zero(buf->elem_offset)) {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->axis_separators.size()) {
doc << ", axis_separators=" << Print(buf->axis_separators);
}
if (GetRef<Buffer>(buf).scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
}
if (buf->data_alignment != runtime::kAllocAlignment) {
doc << ", align=" << buf->data_alignment;
}
if (buf->offset_factor != 1) {
doc << ", offset_factor=" << buf->offset_factor;
}
if (buf->buffer_type != 1) {
doc << ", type=" << Doc::StrLiteral("auto");
}
return doc << ")";
}
Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) {
return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", "
<< PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")";
}
Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
doc << Print(op->buffer) << "[";
for (size_t i = 0; i < op->region.size(); ++i) {
if (i != 0) {
doc << ", ";
}
const auto& range = op->region[i];
if (!is_one(range->extent)) {
doc << Print(range->min) << ":" << Print(range->min + range->extent);
} else {
doc << Print(range->min);
}
}
doc << "]";
return doc;
}
Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
}
Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
}
Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
return PrintConstScalar<int64_t>(op->dtype, op->value);
}
Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
return PrintConstScalar<double>(op->dtype, op->value);
}
Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
Doc doc;
doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
const Var& var = GetRef<Var>(op);
return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
}
#define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \
Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
Doc doc; \
doc << "(" << Print(op->a) << OpString; \
doc << Print(op->b) << ")"; \
return doc; \
}
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AddNode, " + ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(SubNode, " - ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(MulNode, "*")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(DivNode, " / ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(ModNode, " % ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(EQNode, " == ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(NENode, " != ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LTNode, " < ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LENode, " <= ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GTNode, " > ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GENode, " >= ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AndNode, " && ")
TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OrNode, " || ")
Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
Doc doc;
doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
Doc doc;
doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
Doc doc;
doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
Doc doc;
doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
Doc doc;
doc << "!" << Print(op->a);
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
Doc doc;
doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
<< Print(op->false_value) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices);
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) {
// TODO(tvm-team): consider make a better text format for producer.
Doc doc;
doc << op->producer->GetNameHint() << Print(op->indices);
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
Doc doc;
doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index)
<< "]";
if (!is_one(op->predicate)) {
doc << " if " << Print(op->predicate);
}
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
Doc doc;
doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
Doc doc;
doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
std::vector<Doc> func_args;
if (auto* ptr_op = op->op.as<OpNode>()) {
doc << "@" << Doc::Text(ptr_op->name) << "(";
if (ptr_op->name == "tir.call_llvm_pure_intrin") {
auto f = tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name");
ICHECK(f != nullptr)
<< "Cannot find target.llvm_get_intrinsic_name. Compile with USE_LLVM=On";
func_args.push_back(Print((*f)(Downcast<IntImm>(op->args[0])->value)));
for (size_t i = 1; i < op->args.size(); i++) {
func_args.push_back(Print(op->args[i]));
}
} else {
for (const auto& arg : op->args) {
func_args.push_back(Print(arg));
}
}
} else {
// TODO(bohan): Print out the name by he global var in the module.
auto* op_gvar = op->op.as<GlobalVarNode>();
ICHECK(op_gvar != nullptr);
doc << "@" << Doc::Text(op_gvar->name_hint) << "(";
for (const auto& arg : op->args) {
func_args.push_back(Print(arg));
}
}
doc << PrintSep(func_args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
Doc doc;
doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
Doc doc;
doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
<< ", " << op->value_index << ", " << Print(op->init) << ")";
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
Doc doc;
meta_collector_.Collect(op->node);
doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = "
<< Print(op->value);
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine()
<< Print(op->body);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) {
Doc doc;
doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
if (!is_one(op->predicate)) {
doc << " if " << Print(op->predicate);
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) {
Doc doc;
doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
Doc doc;
doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << PrintBody(op->body) << ")";
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) {
Doc doc;
doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << ", " << PrintBody(op->body) << ")";
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
Doc doc;
auto scope = GetPtrStorageScope(op->buffer_var);
doc << "allocate(" << Print(op->buffer_var) << ", ";
doc << PrintDType(op->dtype) << ", ";
doc << Print(op->extents) << "), storage_scope = " << scope;
if (!op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
}
if (!is_one(op->condition)) {
doc << " if " << Print(op->condition);
}
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
Doc doc;
doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
<< Print(op->extents) << ")";
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc doc;
doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
if (!is_one(op->condition) && op->else_case) {
doc << " else" << PrintBody(op->else_case.value());
}
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
std::vector<Doc> stmts;
Doc seq_doc, doc;
for (Stmt stmt : op->seq) {
seq_doc << Doc::NewLine() << Print(stmt);
}
doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
Doc doc;
doc << Print(op->value);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
<< Print(op->min + op->extent) << ")";
if (op->kind != ForKind::kSerial) {
doc << " " << Doc::StrLiteral(ForKind2String(op->kind));
}
doc << PrintBody(op->body);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) {
Doc doc;
doc << "while (" << Print(op->condition) << ")";
doc << PrintBody(op->body);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
Doc doc;
doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
const auto* block_op = op->block.as<BlockNode>();
// print block name and block vars
Doc doc;
doc << "block([";
std::vector<Doc> block_var_docs;
for (const auto& iter_var : block_op->iter_vars) {
Doc block_var_doc;
if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
block_var_doc << Print(iter_var->dom->extent);
} else {
block_var_doc << "tir.";
switch (iter_var->iter_type) {
case kDataPar:
block_var_doc << "range";
break;
case kCommReduce:
block_var_doc << "reduce_axis";
break;
case kOrdered:
block_var_doc << "scan_axis";
break;
case kOpaque:
block_var_doc << "opaque_axis";
break;
default:
LOG(FATAL) << "Unknown block var iter type";
break;
}
block_var_doc << "(" << Print(iter_var->dom->min) << ", "
<< Print(iter_var->dom->min + iter_var->dom->extent) << ")";
}
block_var_docs.push_back(block_var_doc);
}
doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], ";
doc << Doc::StrLiteral(block_op->name_hint) << ")";
std::vector<Doc> block_var_names;
for (const auto& iter_var : block_op->iter_vars) {
Doc block_var_name;
AllocVar(iter_var->var);
block_var_names.push_back(Print(iter_var->var));
}
if (!block_var_names.empty()) {
doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]";
}
doc << " {";
Doc block_attr_doc;
// print predicate, binding, read/write tensor region, annotations
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")";
}
for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
<< Print(op->iter_values[i]) << ")";
block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
if (!block_op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : block_op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", "))
<< "})";
}
// print body
Doc body;
body << Doc::NewLine();
for (const auto& alloc_buf : block_op->alloc_buffers) {
body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype)
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
init_block << "with init()";
init_block << PrintBody(block_op->init.value());
body << init_block << Doc::NewLine();
}
body << Print(block_op->body);
doc << Doc::Indent(2, block_attr_doc << body);
return doc;
}
Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
doc << PrintDType(node->dtype);
return doc;
}
Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "Pointer(";
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
}
doc << Print(node->element_type) << ")";
return doc;
}
Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) {
std::vector<Doc> fields;
for (Type field : node->fields) {
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << Doc::Concat(fields);
// conform to python tuple format (1,)
if (node->fields.size() == 1) {
doc << ",";
}
return doc << ")";
}
Doc TIRTextPrinter::PrintDType(DataType dtype) {
return Doc::Text(runtime::DLDataType2String(dtype));
}
template <typename T>
Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) {
Doc doc;
std::ostringstream os;
os << data;
if (dtype == DataType::Int(32)) {
doc << Doc::Text(os.str());
} else {
if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) {
doc << ((data == 1) ? "True" : "False");
return doc;
}
doc << Doc::Text(os.str());
switch (dtype.code()) {
case kDLInt:
doc << "i";
break;
case kDLUInt:
doc << "u";
break;
case kDLFloat:
doc << "f";
break;
}
doc << Doc::Text(std::to_string(dtype.bits()));
if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes()));
}
return doc;
}
Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
// std::replace(prefix.begin(), prefix.end(), '.', '_');
std::string unique_prefix = prefix;
auto it = name_alloc_map_.find(prefix);
if (it != name_alloc_map_.end()) {
while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {
}
}
name_alloc_map_[unique_prefix] = 0;
return Doc::Text(unique_prefix);
}
Doc TIRTextPrinter::AllocVar(const Var& var) {
const auto& it = memo_var_.find(var);
if (it != memo_var_.end()) {
return it->second;
}
std::string name = var->name_hint.operator std::string();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName(name);
memo_var_[var] = val;
return val << ": " << Print(GetType(var));
}
Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
const auto& it = memo_buf_.find(buffer);
if (it != memo_buf_.end()) {
return it->second;
}
std::string name = buffer->name;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "buf_" + name;
}
Doc val = GetUniqueName(name);
memo_buf_[buffer] = val;
return val;
}
Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) {
const auto& it = memo_producer_.find(producer);
if (it != memo_producer_.end()) {
return it->second;
}
std::string name = producer->GetNameHint();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "tensor_" + name;
}
Doc val = GetUniqueName(name);
memo_producer_[producer] = val;
return val;
}
Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
seq = vec[0];
for (size_t i = 1; i < vec.size(); i++) {
seq << sep << vec[i];
}
}
return seq;
}
Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
Doc doc;
if (body->IsInstance<SeqStmtNode>()) return Print(body);
doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
return doc;
}
bool TIRTextPrinter::GetVarName(Var v, std::string* s) {
auto it = memo_var_.find(v);
if (it == memo_var_.end()) {
return false;
}
*s = it->second.str();
return true;
}
} // namespace tir
} // namespace tvm