// blob: 5add7c17b04c70a2d942a9242f3f96fd20b9230d
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file printer/tvmscript_printer.cc
* \brief Printer class to print Tensor IR to python syntax script
*/
#include <tvm/ir/module.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <algorithm>
#include <utility>
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
namespace tir {
/*!
 * \brief Printer that renders TIR (PrimFunc / IRModule together with the
 *  statements, expressions and types they contain) as TVM script text.
 */
class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
                         public ExprFunctor<Doc(const PrimExpr&)>,
                         public TypeFunctor<Doc(const Type&)> {
 public:
  explicit TVMScriptPrinter(bool show_meta,
                            runtime::TypedPackedFunc<std::string(Stmt)> annotate = nullptr)
      : show_meta_(show_meta), annotate_(std::move(annotate)), meta_collector_(&meta_) {}
  /*! \brief Print the node */
  TVM_DLL Doc Print(const ObjectRef& node);

 private:
  /*! \brief whether show meta data */
  bool show_meta_;
  /*! \brief additional comment function */
  runtime::TypedPackedFunc<std::string(Stmt)> annotate_;
  /*! \brief meta data context */
  TextMetaDataContext meta_;
  /*! \brief meta collector */
  MetaCollector meta_collector_;
  /*! \brief map from Function to GlobalVar */
  std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
  /*! \brief var collector (var defined by For/Loop/Block) */
  std::unordered_set<const VarNode*> var_not_in_headers;
  /*! \brief buffer collector (buffer defined in BufferMap and BufferAllocation)*/
  std::unordered_set<const BufferNode*> buf_not_in_headers;
  /*! \brief Map from Var to thread env name */
  std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
  /*! \brief Map from Var to Doc */
  std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
  /*! \brief Map from Buffer to Doc */
  std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
  /*! \brief Map from Buffer to Declaration Doc */
  std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_decl_;
  /*! \brief Map from CommReducer to Doc */
  std::unordered_map<const CommReducerNode*, Doc> memo_reducer_;
  /*! \brief name allocation map */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*!
   * \brief number of children of current node's parent.
   *  Zero-initialized: statement visitors read it even when a statement is
   *  printed outside of PrintBody, and PrintBody saves it on entry.
   */
  int num_child_{0};
  /*! \brief the number of current node (its index among its siblings) */
  int current_num_{0};
  Doc VisitExpr_(const CastNode* op) override;
  Doc VisitExpr_(const VarNode* op) override;
  Doc VisitExpr_(const AddNode* op) override;
  Doc VisitExpr_(const SubNode* op) override;
  Doc VisitExpr_(const MulNode* op) override;
  Doc VisitExpr_(const DivNode* op) override;
  Doc VisitExpr_(const ModNode* op) override;
  Doc VisitExpr_(const FloorDivNode* op) override;
  Doc VisitExpr_(const FloorModNode* op) override;
  Doc VisitExpr_(const MinNode* op) override;
  Doc VisitExpr_(const MaxNode* op) override;
  Doc VisitExpr_(const EQNode* op) override;
  Doc VisitExpr_(const NENode* op) override;
  Doc VisitExpr_(const LTNode* op) override;
  Doc VisitExpr_(const LENode* op) override;
  Doc VisitExpr_(const GTNode* op) override;
  Doc VisitExpr_(const GENode* op) override;
  Doc VisitExpr_(const AndNode* op) override;
  Doc VisitExpr_(const OrNode* op) override;
  Doc VisitExpr_(const NotNode* op) override;
  Doc VisitExpr_(const SelectNode* op) override;
  Doc VisitExpr_(const IntImmNode* op) override;
  Doc VisitExpr_(const FloatImmNode* op) override;
  Doc VisitExpr_(const StringImmNode* op) override;
  Doc VisitExpr_(const BufferLoadNode* op) override;
  Doc VisitExpr_(const LoadNode* op) override;
  Doc VisitExpr_(const RampNode* op) override;
  Doc VisitExpr_(const BroadcastNode* op) override;
  Doc VisitExpr_(const LetNode* op) override;
  Doc VisitExpr_(const CallNode* op) override;
  Doc VisitExpr_(const ShuffleNode* op) override;
  Doc VisitExpr_(const ReduceNode* op) override;
  Doc VisitExprDefault_(const Object* op) override;
  Doc VisitStmt_(const LetStmtNode* op) override;
  Doc VisitStmt_(const AttrStmtNode* op) override;
  Doc VisitStmt_(const AssertStmtNode* op) override;
  Doc VisitStmt_(const StoreNode* op) override;
  Doc VisitStmt_(const BufferStoreNode* op) override;
  Doc VisitStmt_(const BufferRealizeNode* op) override;
  Doc VisitStmt_(const AllocateNode* op) override;
  Doc VisitStmt_(const IfThenElseNode* op) override;
  Doc VisitStmt_(const SeqStmtNode* op) override;
  Doc VisitStmt_(const ForNode* op) override;
  Doc VisitStmt_(const PrefetchNode* op) override;
  Doc VisitStmt_(const EvaluateNode* op) override;
  Doc VisitStmtDefault_(const Object* op) override;
  Doc VisitType_(const PrimTypeNode* node) override;
  Doc VisitType_(const PointerTypeNode* node) override;
  Doc VisitType_(const TupleTypeNode* node) override;
  Doc PrintBody(const Stmt& body);
  Doc PrintIRModule(const IRModule& module);
  Doc PrintPrimFunc(const PrimFunc& primFunc);
  Doc PrintIterVar(const IterVarNode* op);
  Doc PrintRange(const RangeNode* op);
  Doc PrintArray(const ArrayNode* op);
  Doc PrintBuffer(const BufferNode* op);
  Doc AllocBufferDeclaration(const Buffer& buf);
  static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
  Doc GetUniqueName(std::string prefix);
  Doc AllocVar(const Var& var);
  Doc AllocBuf(const Buffer& buffer);
  /*!
   * \brief Print additional info about a statement as a comment line.
   * \param stmt The statement to annotate.
   * \return "# <annotation>" followed by a newline when an annotate function
   *  was supplied and returns a non-empty string for \p stmt; an empty Doc
   *  otherwise.
   */
  Doc PrintOptionalInfo(const Stmt& stmt) {
    Doc doc;
    // default annotations
    if (annotate_ != nullptr) {
      std::string annotated_stmt = annotate_(stmt);
      if (!annotated_stmt.empty()) {
        doc << "# " << annotated_stmt << Doc::NewLine();
      }
    }
    return doc;
  }
  /*!
   * \brief special method to render vectors of docs with a separator
   * \param vec vector of docs
   * \param sep separator
   * \return The docs joined by \p sep; an empty Doc when \p vec is empty.
   */
  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
    Doc seq;
    if (vec.size() != 0) {
      seq = vec[0];
      for (size_t i = 1; i < vec.size(); i++) {
        seq << sep << vec[i];
      }
    }
    return seq;
  }
  /*!
   * \brief dump meta info
   * \return "__tvm_meta__ = <meta section or None>" when show_meta_ is set,
   *  otherwise an empty text Doc.
   */
  Doc DumpMeta() {
    if (show_meta_) {
      return Doc::Text("__tvm_meta__ = ")
             << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
    } else {
      return Doc::Text("");
    }
  }
  /*!
   * \brief special method to print out data type
   * \param dtype The data type
   * \return The dtype string as a quoted string literal.
   */
  static Doc PrintDType(DataType dtype) {
    return Doc::StrLiteral(runtime::DLDataType2String(dtype));
  }
  /*!
   * \brief special method to print out const scalar
   * \param dtype The data type
   * \param data The pointer to hold the data.
   * \return int32 values as a bare literal, bool as True/False, and any other
   *  dtype wrapped as tir.<dtype>(<value>).
   */
  template <typename T>
  static Doc PrintConstScalar(DataType dtype, const T* data) {
    Doc doc;
    std::ostringstream os;
    os << data[0];
    if (dtype == DataType::Int(32)) {
      doc << Doc::Text(os.str());
    } else if (dtype == DataType::Bool()) {
      doc << Doc::Text(data[0] ? "True" : "False");
    } else {
      doc << "tir." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) << ")";
    }
    return doc;
  }
};
/*!
 * \brief Produce a printable name derived from \p prefix that has not been
 *  handed out before.
 *
 * Every '.' in the prefix becomes '_'. If the sanitized prefix was already
 * allocated, the counter stored for it is bumped to build "prefix_<k>"
 * candidates until one is found that is not in name_alloc_map_. The chosen
 * name is then registered with a fresh counter of 0.
 */
Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
  for (char& ch : prefix) {
    if (ch == '.') {
      ch = '_';
    }
  }
  std::string candidate = prefix;
  auto found = name_alloc_map_.find(prefix);
  if (found != name_alloc_map_.end()) {
    int& counter = found->second;
    do {
      candidate = prefix + "_" + std::to_string(++counter);
    } while (name_alloc_map_.count(candidate) > 0);
  }
  name_alloc_map_[candidate] = 0;
  return Doc::Text(candidate);
}
/*!
 * \brief Return the printed name of \p var, allocating one on first use.
 *
 * The name comes from the var's name_hint. A hint that is empty or does not
 * start with a letter gets a "v" prefix. The result is made unique through
 * GetUniqueName and memoized in memo_var_.
 */
Doc TVMScriptPrinter::AllocVar(const Var& var) {
  const auto it = memo_var_.find(var);
  if (it != memo_var_.end()) {
    return it->second;
  }
  std::string name = var->name_hint.operator std::string();
  // std::isalpha is undefined for negative char values (e.g. UTF-8 lead
  // bytes on signed-char platforms), so widen through unsigned char first.
  if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
    name = "v" + name;
  }
  Doc val = GetUniqueName(name);
  memo_var_[var] = val;
  return val;
}
/*!
 * \brief Build the argument list used to declare \p buf (the text placed
 *  inside tir.match_buffer(...) / tir.buffer_decl(...)).
 *
 * Starts with the shape and appends only the keyword arguments that differ
 * from the values this printer treats as defaults (dtype float32, no strides,
 * scope "global", align -1, offset_factor 0, buffer_type 1).
 *
 * Side effects: if buf->data (or a Var elem_offset when offset_factor != 0)
 * has not been named yet, it is not printed as an argument; instead it is
 * memoized as "<buffer name>.data" / "<buffer name>.elem_offset" and excluded
 * from the function header. This reads memo_buf_[buf], so the buffer's own
 * name must already be allocated (AllocBuf does that before calling here).
 */
Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
Doc doc = Print(buf->shape);
if (!runtime::TypeEqual(buf->dtype, DataType::Float(32))) {
doc << ", dtype=" << PrintDType(buf->dtype);
}
if (memo_var_.find(buf->data) != memo_var_.end()) {
doc << ", data=" << Print(buf->data);
} else {
// implicitly define data
memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data");
var_not_in_headers.insert(buf->data.get());
}
if (!buf->strides.empty()) {
doc << ", strides=" << Print(buf->strides);
}
// A Var elem_offset is only implicit when offset_factor is non-zero;
// every other elem_offset is printed explicitly.
if (buf->offset_factor != 0 && buf->elem_offset->IsInstance<VarNode>()) {
Var elem_offset = Downcast<Var>(buf->elem_offset);
if (memo_var_.find(elem_offset) != memo_var_.end()) {
doc << ", elem_offset=" << Print(buf->elem_offset);
} else {
// implicitly define elem_offset
memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
var_not_in_headers.insert(elem_offset.get());
}
} else {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->scope != "global") {
doc << ", scope=" << Doc::StrLiteral(buf->scope);
}
if (buf->data_alignment != -1) {
doc << ", align=" << buf->data_alignment;
}
if (buf->offset_factor != 0) {
doc << ", offset_factor=" << buf->offset_factor;
}
// NOTE(review): any buffer_type other than 1 is printed as "auto" —
// presumably 1 is the default buffer type; confirm against BufferType.
if (buf->buffer_type != 1) {
doc << ", type=" << Doc::StrLiteral("auto");
}
return doc;
}
/*!
 * \brief Return the printed name of \p buffer, allocating one on first use.
 *
 * The name comes from buffer->name. A name that is empty or does not start
 * with a letter gets a "buf_" prefix. On first use the name is memoized in
 * memo_buf_ and the declaration arguments are built and stored in
 * memo_buf_decl_.
 */
Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
  const auto it = memo_buf_.find(buffer);
  if (it != memo_buf_.end()) {
    return it->second;
  }
  std::string name = buffer->name;
  // std::isalpha is undefined for negative char values, so widen through
  // unsigned char first.
  if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
    name = "buf_" + name;
  }
  Doc val = GetUniqueName(name);
  // Must be memoized before AllocBufferDeclaration, which reads memo_buf_[buffer]
  // to name implicitly defined data / elem_offset vars.
  memo_buf_[buffer] = val;
  memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer);
  return val;
}
/*!
 * \brief Print any supported node by dispatching on its runtime type.
 *
 * An undefined node prints as None. Statements are preceded by the optional
 * annotation comment. Nodes of any type without a dedicated printer are
 * handed to the meta collector and printed as a reference into the meta data
 * context.
 */
Doc TVMScriptPrinter::Print(const ObjectRef& node) {
  if (!node.defined()) {
    return Doc::Text("None");
  }
  if (node->IsInstance<StmtNode>()) {
    const Stmt stmt = Downcast<Stmt>(node);
    return PrintOptionalInfo(stmt) << VisitStmt(stmt);
  }
  if (node->IsInstance<PrimExprNode>()) {
    return VisitExpr(Downcast<PrimExpr>(node));
  }
  if (node->IsInstance<TypeNode>()) {
    return VisitType(Downcast<Type>(node));
  }
  if (node->IsInstance<PrimFuncNode>()) {
    return PrintPrimFunc(Downcast<PrimFunc>(node));
  }
  if (node->IsInstance<IRModuleNode>()) {
    return PrintIRModule(Downcast<IRModule>(node));
  }
  if (const auto* array = node.as<ArrayNode>()) {
    return PrintArray(array);
  }
  if (const auto* buffer = node.as<BufferNode>()) {
    return PrintBuffer(buffer);
  }
  if (const auto* str = node.as<StringObj>()) {
    return PrintString(str);
  }
  if (const auto* iter_var = node.as<IterVarNode>()) {
    return PrintIterVar(iter_var);
  }
  if (const auto* range = node.as<RangeNode>()) {
    return PrintRange(range);
  }
  // fallback: reference the node through the meta data section
  meta_collector_.Collect(node);
  return this->meta_.GetMetaNode(node);
}
Doc TVMScriptPrinter::VisitExprDefault_(const Object* op) {
  // Expressions without a dedicated visitor are collected into the meta data
  // context and printed as a reference to it.
  const ObjectRef ref = GetRef<ObjectRef>(op);
  meta_collector_.Collect(ref);
  return this->meta_.GetMetaNode(ref);
}
Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
  // Statements without a dedicated visitor are collected into the meta data
  // context and printed as a reference to it.
  const ObjectRef ref = GetRef<ObjectRef>(op);
  meta_collector_.Collect(ref);
  return this->meta_.GetMetaNode(ref);
}
Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op) {
  // integer constant: formatting depends on its dtype (see PrintConstScalar)
  const int64_t* value_ptr = &op->value;
  return PrintConstScalar<int64_t>(op->dtype, value_ptr);
}
Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op) {
  // floating-point constant: formatting depends on its dtype (see PrintConstScalar)
  const double* value_ptr = &op->value;
  return PrintConstScalar<double>(op->dtype, value_ptr);
}
Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op) {
  // string constants print as quoted string literals
  const auto& text = op->value;
  return Doc::StrLiteral(text);
}
Doc TVMScriptPrinter::VisitExpr_(const CastNode* op) {
  // When re-applying tir's cast() to the operand yields a Cast node, use the
  // method form `value.astype(dtype)`; otherwise emit an explicit tir.cast.
  const bool as_method = cast(op->dtype, op->value)->IsInstance<CastNode>();
  Doc value_doc = Print(op->value);
  Doc doc;
  if (as_method) {
    doc << value_doc << ".astype(" << PrintDType(op->dtype) << ")";
  } else {
    doc << "tir.cast(" << value_doc << ", " << PrintDType(op->dtype) << ")";
  }
  return doc;
}
Doc TVMScriptPrinter::VisitExpr_(const VarNode* op) {
  // Vars registered in the meta data context are referenced through it;
  // all others get (or reuse) an allocated name.
  const Var var = GetRef<Var>(op);
  if (meta_.InMeta(var)) {
    return meta_.GetMetaNode(var);
  }
  return AllocVar(var);
}
/*
 * Defines the visitor for a binary node that maps onto a Python infix
 * operator. The result is always parenthesized: "(" a OpString b ")".
 */
#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString) \
  Doc TVMScriptPrinter::VisitExpr_(const OpName* op) {        \
    Doc lhs = Print(op->a);                                   \
    Doc rhs = Print(op->b);                                   \
    return Doc::Text("(") << lhs << OpString << rhs << ")";   \
  }
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, "*")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ")
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ")
Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op) {
  // printed as a call: tir.floordiv(a, b)
  Doc lhs = Print(op->a);
  Doc rhs = Print(op->b);
  return Doc::Text("tir.floordiv(") << lhs << ", " << rhs << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op) {
  // printed as a call: tir.floormod(a, b)
  Doc lhs = Print(op->a);
  Doc rhs = Print(op->b);
  return Doc::Text("tir.floormod(") << lhs << ", " << rhs << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const MinNode* op) {
  // printed as a call: tir.min(a, b)
  Doc lhs = Print(op->a);
  Doc rhs = Print(op->b);
  return Doc::Text("tir.min(") << lhs << ", " << rhs << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op) {
  // printed as a call: tir.max(a, b)
  Doc lhs = Print(op->a);
  Doc rhs = Print(op->b);
  return Doc::Text("tir.max(") << lhs << ", " << rhs << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const NotNode* op) {
  // logical negation with the operand parenthesized: not (a)
  Doc operand = Print(op->a);
  return Doc::Text("not (") << operand << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op) {
  // tir.select(condition, true_value, false_value)
  Doc cond = Print(op->condition);
  Doc then_val = Print(op->true_value);
  Doc else_val = Print(op->false_value);
  return Doc::Text("tir.select(") << cond << ", " << then_val << ", " << else_val << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op) {
  // buffer subscript: the index array prints as "[i, j, ...]"
  Doc buffer = Print(op->buffer);
  Doc indices = Print(op->indices);
  buffer << indices;
  return buffer;
}
Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op) {
  // Subscript shorthand only for an unpredicated float32 load whose
  // buffer_var dtype is also float32; everything else uses tir.load(...).
  const bool use_subscript = op->dtype == DataType::Float(32) && is_one(op->predicate) &&
                             op->buffer_var->dtype == DataType::Float(32);
  if (use_subscript) {
    Doc doc;
    doc << Print(op->buffer_var) << "[" << Print(op->index) << "]";
    return doc;
  }
  Doc doc;
  doc << "tir.load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", "
      << Print(op->index);
  // the predicate is spelled out when it is not constant one or the load is vectorized
  const bool print_predicate = !is_one(op->predicate) || op->dtype.lanes() != 1;
  if (print_predicate) {
    doc << ", " << Print(op->predicate);
  }
  doc << ")";
  return doc;
}
Doc TVMScriptPrinter::VisitExpr_(const RampNode* op) {
  // tir.ramp(base, stride, lanes)
  Doc base = Print(op->base);
  Doc stride = Print(op->stride);
  return Doc::Text("tir.ramp(") << base << ", " << stride << ", " << op->lanes << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op) {
  // tir.broadcast(value, lanes)
  Doc value = Print(op->value);
  return Doc::Text("tir.broadcast(") << value << ", " << op->lanes << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const LetNode* op) {
  // expression-level let: tir.let(var, value, body)
  Doc var = Print(op->var);
  Doc value = Print(op->value);
  Doc body = Print(op->body);
  return Doc::Text("tir.let(") << var << ", " << value << ", " << body << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const CallNode* op) {
  // Callee is either an Op (printed by its registered name) or a GlobalVar
  // (printed by its name_hint). The result dtype is always passed last as
  // a keyword argument.
  Doc callee;
  if (const auto* ptr_op = op->op.as<OpNode>()) {
    callee = Doc::Text(ptr_op->name);
  } else {
    auto* op_gvar = op->op.as<GlobalVarNode>();
    CHECK(op_gvar != nullptr);
    callee = Doc::Text(op_gvar->name_hint);
  }
  std::vector<Doc> call_args;
  for (const PrimExpr& arg : op->args) {
    call_args.push_back(Print(arg));
  }
  call_args.push_back(Doc::Text("dtype=") << PrintDType(op->dtype));
  callee << "(" << PrintSep(call_args, Doc::Text(", ")) << ")";
  return callee;
}
Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op) {
  // tir.shuffle(vectors, indices)
  Doc vectors = Print(op->vectors);
  Doc indices = Print(op->indices);
  return Doc::Text("tir.shuffle(") << vectors << ", " << indices << ")";
}
Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op) {
  // tir.reduce(combiner, source, axis, value_index)
  Doc combiner = Print(op->combiner);
  Doc source = Print(op->source);
  Doc axis = Print(op->axis);
  return Doc::Text("tir.reduce(") << combiner << ", " << source << ", " << axis << ", "
                                  << op->value_index << ")";
}
Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
  Doc doc;
  const bool is_last_child = current_num_ == num_child_ - 1;
  if (!is_last_child) {
    // Not the last statement of the enclosing body: emit a `with tir.let(...)`
    // block whose indented body holds the let body.
    doc << "with tir.let(" << Print(op->var) << ", " << Print(op->value) << "):";
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
    return doc;
  }
  // Last statement: flat annotated assignment, body continues at the same
  // indentation. A var first named here is bound by this assignment and so
  // is kept out of the function header.
  if (memo_var_.find(op->var) == memo_var_.end()) {
    var_not_in_headers.insert(op->var.get());
  }
  doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
      << Doc::NewLine() << PrintBody(op->body);
  return doc;
}
/*!
 * \brief Print an AttrStmt.
 *
 * Three patterns are folded into a single script construct:
 *  - a "storage_scope" attr on a Var whose body is the Allocate of that same
 *    var  -> tir.allocate(extents, dtype, scope[, condition])
 *  - a "realize_scope" attr on a Buffer whose body is the BufferRealize of
 *    that same buffer  -> tir.realize(buffer[bounds], scope[, condition])
 *  - a "thread_extent" attr on an IterVar  -> tir.launch_thread(var, extent)
 * Anything else is printed as a generic tir.attr(node, key, value).
 * Each construct uses the scoped `with ...:` form unless the statement is the
 * last child of its enclosing body, in which case the flat form is printed and
 * the body follows on the next line at the same indentation.
 */
Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
Doc doc;
// merge attr with allocate when possible
if (op->node->IsInstance<VarNode>() && op->attr_key == "storage_scope" &&
op->body->IsInstance<AllocateNode>()) {
const auto* alloc = Downcast<Allocate>(op->body).get();
// only fold when the Allocate is for the attributed var; otherwise fall
// through to the generic forms below
if (alloc->buffer_var.same_as(op->node)) {
// the allocated var is bound by tir.allocate, so keep it out of the header
var_not_in_headers.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype)
<< ", " << Print(op->value);
if (!is_one(alloc->condition)) {
doc << ", " << Print(alloc->condition);
}
doc << ") as " << Print(op->node) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
} else {
doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", "
<< PrintDType(alloc->dtype) << ", " << Print(op->value);
if (!is_one(alloc->condition)) {
doc << ", " << Print(alloc->condition);
}
doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
}
return doc;
}
}
// merge attr with realize when possible
if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope" &&
op->body->IsInstance<BufferRealizeNode>()) {
const auto* realize = Downcast<BufferRealize>(op->body).get();
if (realize->buffer.same_as(op->node)) {
if (current_num_ != num_child_ - 1) {
doc << "with tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", "
<< Print(op->value);
if (!is_one(realize->condition)) {
doc << ", " << Print(realize->condition);
}
doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body));
} else {
doc << "tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", "
<< Print(op->value);
if (!is_one(realize->condition)) {
doc << ", " << Print(realize->condition);
}
doc << ")" << Doc::NewLine() << PrintBody(realize->body);
}
return doc;
}
}
// concise thread env
if (op->node->IsInstance<IterVarNode>() && op->attr_key == "thread_extent") {
const auto* iter_var = Downcast<IterVar>(op->node).get();
// thread_extent IterVars are expected to carry no dom
CHECK(!iter_var->dom.defined());
// the thread var is declared in the function header via tir.env_thread
// (from var_env_map_), not via tir.var
var_not_in_headers.insert(iter_var->var.get());
var_env_map_[iter_var->var] = iter_var->thread_tag;
if (current_num_ != num_child_ - 1) {
doc << "with tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << "):";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
} else {
doc << "tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << ")";
doc << Doc::NewLine() << PrintBody(op->body);
}
return doc;
}
// default
if (current_num_ != num_child_ - 1) {
doc << "with tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", "
<< Print(op->value) << "):";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
} else {
doc << "tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", "
<< Print(op->value) << ")";
doc << Doc::NewLine() << PrintBody(op->body);
}
return doc;
}
Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) {
  Doc doc;
  if (current_num_ == num_child_ - 1) {
    // last statement: plain Python assert, body continues at the same level
    doc << "assert " << Print(op->condition) << ", " << Print(op->message);
    doc << Doc::NewLine() << PrintBody(op->body);
    return doc;
  }
  // otherwise scope the body under `with tir.Assert(...)`
  doc << "with tir.Assert(" << Print(op->condition) << ", " << Print(op->message) << "):";
  doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  return doc;
}
Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) {
  // Scalar, unpredicated stores use subscript assignment; vectorized or
  // predicated stores spell out tir.store(buffer_var, index, value, predicate).
  const bool plain_assignment = is_one(op->predicate) && op->value.dtype().lanes() == 1;
  Doc doc;
  if (plain_assignment) {
    doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
  } else {
    doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", "
        << Print(op->value) << ", " << Print(op->predicate) << ")";
  }
  return doc;
}
Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
  // The printer expects every BufferRealize to be folded into its enclosing
  // "realize_scope" AttrStmt; reaching here is reported as an internal error.
  LOG(FATAL)
      << "TVM Script Printer Internal Error: All the BufferRealize should be folded with Attr";
  return Doc{};
}
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
  // The printer expects every Allocate to be folded into its enclosing
  // "storage_scope" AttrStmt; reaching here is reported as an internal error.
  LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr";
  return Doc{};
}
/*!
 * \brief Print an if / else statement.
 *
 * The else branch is omitted when there is no else_case, and also when the
 * condition is the constant one (the else branch would be dead).
 */
Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
  Doc doc;
  doc << "if " << Print(op->condition) << ":";
  doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
  if (!is_one(op->condition) && op->else_case.defined()) {
    // The indented then-body does not end with a line break, so start a new
    // line explicitly; otherwise "else:" is glued onto the last then-statement.
    doc << Doc::NewLine();
    doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case));
  }
  return doc;
}
Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
  // one statement per line, in sequence order
  std::vector<Doc> lines;
  for (const Stmt& stmt : op->seq) {
    lines.push_back(Print(stmt));
  }
  return PrintSep(lines, Doc::NewLine());
}
Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
  // expression statement: tir.evaluate(value)
  Doc value = Print(op->value);
  return Doc::Text("tir.evaluate(") << value << ")";
}
/*!
 * \brief Name of the tir loop-kind helper (tir.<name>) for a ForType.
 *  An unknown value is reported with LOG(FATAL).
 */
inline const char* ForType2String(ForType t) {
  const char* kind = nullptr;
  switch (t) {
    case ForType::Serial:
      kind = "serial";
      break;
    case ForType::Parallel:
      kind = "parallel";
      break;
    case ForType::Vectorized:
      kind = "vectorized";
      break;
    case ForType::Unrolled:
      kind = "unroll";
      break;
  }
  if (kind != nullptr) {
    return kind;
  }
  LOG(FATAL) << "Unknown ForType";
  return "Unknown";
}
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
  // The loop variable is bound by the for statement itself, so it is kept
  // out of the function header declarations.
  var_not_in_headers.insert(op->loop_var.get());
  const std::string loop_kind = ForType2String(op->for_type);
  Doc doc;
  // for <var> in tir.<kind>(min, min + extent):
  doc << "for " << Print(op->loop_var) << " in tir." + loop_kind + "(" << Print(op->min) << ", "
      << Print(op->min + op->extent) << "):";
  doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  return doc;
}
Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
  // tir.prefetch(buffer, bounds)
  Doc buffer = Print(op->buffer);
  Doc bounds = Print(op->bounds);
  return Doc::Text("tir.prefetch(") << buffer << ", " << bounds << ")";
}
Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
  // primitive types print as ty.<dtype string>
  const std::string dtype_name = runtime::DLDataType2String(node->dtype);
  return Doc::Text("ty.") << dtype_name;
}
Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
  // pointer types print as ty.Ptr[<element type>]
  Doc element = Print(node->element_type);
  return Doc::Text("ty.Ptr[") << element << "]";
}
Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
  // An empty tuple type prints as None; otherwise ty.Tuple[<fields>].
  if (node->fields.empty()) {
    return Doc::Text("None");
  }
  std::vector<Doc> field_docs;
  for (const Type& field : node->fields) {
    field_docs.push_back(Print(field));
  }
  return Doc::Text("ty.Tuple[") << Doc::Concat(field_docs) << "]";
}
Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
  // subscript assignment: buffer[indices] = value
  Doc target = Print(op->buffer);
  target << Print(op->indices) << " = " << Print(op->value);
  return target;
}
/*!
 * \brief Print a statement body while tracking each child's sibling position.
 *
 * For a SeqStmt, every child is printed with num_child_ set to the sequence
 * length and current_num_ set to its index, and the results are joined by
 * newlines. Any other body is printed as a single child (num_child_ = 1,
 * current_num_ = 0). Statement visitors compare current_num_ with
 * num_child_ - 1 to choose between the scoped `with ...:` form and the flat
 * last-statement form. The caller's counters are restored before returning.
 */
Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
  // Save the enclosing counters explicitly. The previous swap-based code
  // exchanged them with uninitialized locals, which reads indeterminate values.
  const int saved_num_child = num_child_;
  const int saved_current_num = current_num_;
  Doc doc;
  if (const auto* seq = body.as<SeqStmtNode>()) {
    num_child_ = static_cast<int>(seq->seq.size());
    current_num_ = 0;
    std::vector<Doc> stmts;
    for (const Stmt& stmt : seq->seq) {
      stmts.push_back(Print(stmt));
      ++current_num_;
    }
    doc = PrintSep(stmts, Doc::NewLine());
  } else {
    num_child_ = 1;
    current_num_ = 0;
    doc = Print(body);
  }
  num_child_ = saved_num_child;
  current_num_ = saved_current_num;
  return doc;
}
/*!
 * \brief Print an IRModule as `class Module:` containing its PrimFuncs.
 *
 * Every function's GlobalVar is recorded in func2var_ first, so that
 * PrintPrimFunc prints each definition under its global name. Only PrimFuncs
 * are emitted; they are separated by a blank line and followed by the meta
 * section from DumpMeta().
 */
Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) {
  const IRModuleNode* mod = module.operator->();
  for (const auto& kv : mod->functions) {
    func2var_[kv.second.operator->()] = kv.first;
  }
  std::vector<Doc> prim_funcs;
  for (const auto& kv : mod->functions) {
    if (kv.second.as<PrimFuncNode>()) {
      prim_funcs.push_back(Print(kv.second));
    }
  }
  Doc body = Doc::NewLine();
  body << TVMScriptPrinter::PrintSep(prim_funcs, Doc::NewLine() << Doc::NewLine());
  body << Doc::NewLine() << DumpMeta();
  Doc doc;
  doc << "class Module:" << Doc::Indent(4, body);
  return doc;
}
/*!
 * \brief Print a PrimFunc as a `def` with header declarations and body.
 *
 * Order inside the function: attribute dict, thread env / free var
 * declarations, buffer declarations, match_buffer bindings, comm_reducer
 * definitions, then the statement body. The body is printed before the header
 * declarations are assembled, because printing it is what allocates names
 * (memo_var_ / memo_buf_) for the vars and buffers the headers declare.
 */
Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
  auto* op = primFunc.operator->();
  // clear renaming map and all other per-function state
  memo_var_.clear();
  memo_buf_.clear();
  memo_buf_decl_.clear();
  memo_reducer_.clear();
  var_not_in_headers.clear();
  buf_not_in_headers.clear();
  // Env threads recorded while printing an earlier function of the same
  // IRModule belong to that function; keeping them would re-declare foreign
  // vars in this function's header.
  var_env_map_.clear();
  // print signature
  Doc doc;
  doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
      << "(";
  std::vector<Doc> params;
  for (const auto& param : op->params) {
    // parameters are bound by the signature, not declared in the header
    var_not_in_headers.insert(param.get());
    params.push_back(Print(param) << ": " << Print(GetType(param)));
  }
  doc << PrintSep(params, Doc::Text(", ")) << ") -> " << Print(primFunc->ret_type) << ":";
  Doc body = Doc::NewLine();
  // print buffer_bind: buffers in buffer_map are bound with tir.match_buffer
  for (const auto& it : op->buffer_map) {
    buf_not_in_headers.insert(it.second.get());
    body << Print(it.second) << " = tir.match_buffer(";
    body << Print(it.first) << ", " << memo_buf_decl_[it.second];
    body << ")" << Doc::NewLine();
  }
  // print comm_reducer
  for (const auto& it : memo_reducer_) {
    body << it.second << " = tir.comm_reducer(";
    var_not_in_headers.insert(it.first->lhs[0].get());
    var_not_in_headers.insert(it.first->rhs[0].get());
    body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": "
         << Print(it.first->result[0]) << ", " << Print(it.first->identity_element[0]);
    body << ")" << Doc::NewLine();
  }
  // print body
  body << "# body" << Doc::NewLine() << PrintBody(op->body);
  // print func attrs
  Doc header_attr;
  if (primFunc->attrs.defined()) {
    header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << "tir.func_attr({";
    std::vector<Doc> attrs;
    for (const auto& it : op->attrs->dict) {
      attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
    }
    header_attr << PrintSep(attrs, Doc::Text(", ")) << "})";
  }
  // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate)
  Doc header_buf;
  std::vector<const BufferNode*> bufs;
  for (const auto& it : memo_buf_) {
    if (buf_not_in_headers.find(it.first.get()) == buf_not_in_headers.end()) {
      bufs.push_back(it.first.get());
    }
  }
  if (!bufs.empty()) {
    header_buf << Doc::NewLine() << "# buffer definition";
    // sort by printed name so output does not depend on hash-map order
    std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) {
      return memo_buf_[GetRef<Buffer>(a)].str() < memo_buf_[GetRef<Buffer>(b)].str();
    });
    for (const auto& buf : bufs) {
      header_buf << Doc::NewLine() << Print(GetRef<Buffer>(buf)) << " = tir.buffer_decl(";
      header_buf << memo_buf_decl_[GetRef<Buffer>(buf)] << ")";
    }
  }
  // print var declaration
  Doc header_var;
  std::vector<const VarNode*> vars;
  for (const auto& it : memo_var_) {
    if (var_not_in_headers.find(it.first.get()) == var_not_in_headers.end()) {
      vars.push_back(it.first.get());
    }
  }
  if (!var_env_map_.empty()) {
    header_var << Doc::NewLine() << "# var definition";
    // Print every env var first (same allocation order as before), then sort
    // by printed name: var_env_map_ is keyed by ObjectPtrHash, so its
    // iteration order would otherwise vary from run to run.
    std::vector<std::pair<std::string, String>> env_vars;
    for (const auto& it : var_env_map_) {
      env_vars.emplace_back(Print(it.first).str(), it.second);
    }
    std::sort(env_vars.begin(), env_vars.end(),
              [](const std::pair<std::string, String>& a,
                 const std::pair<std::string, String>& b) { return a.first < b.first; });
    for (const auto& it : env_vars) {
      header_var << Doc::NewLine() << Doc::Text(it.first) << " = tir.env_thread("
                 << Doc::StrLiteral(it.second) << ")";
    }
  }
  if (!vars.empty()) {
    std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) {
      return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
    });
    for (const auto& var : vars) {
      header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.var(";
      header_var << PrintDType(var->dtype) << ")";
    }
  }
  doc << Doc::Indent(4, header_attr << header_var << header_buf << body);
  return doc;
}
/*!
 * \brief Print an array as a Python list: "[e0, e1, ...]" ("[]" when empty).
 */
Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) {
  std::vector<Doc> elems;
  for (size_t idx = 0; idx < op->size(); ++idx) {
    elems.push_back(Print(op->at(idx)));
  }
  return Doc::Text("[") << PrintSep(elems, Doc::Text(", ")) << "]";
}
/*!
 * \brief Print an IterVar as tir.iter_var(var, [dom] | None, "<type>", "<tag>").
 */
Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) {
  // print the var before the domain so names are allocated in the same order
  Doc var_doc = Print(op->var);
  Doc dom_doc = Doc::Text("None");
  if (op->dom.defined()) {
    dom_doc = Doc::Text("[") << Print(op->dom) << "]";
  }
  return Doc::Text("tir.iter_var(") << var_doc << ", " << dom_doc << ", "
                                    << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "
                                    << Doc::StrLiteral(op->thread_tag) << ")";
}
Doc TVMScriptPrinter::PrintRange(const RangeNode* op) {
  // slice notation: min:(min + extent)
  Doc lower = Print(op->min);
  Doc upper = Print(op->min + op->extent);
  lower << ":" << upper;
  return lower;
}
Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
  // Buffers in the meta data context are referenced through it; all others
  // get (or reuse) an allocated name.
  const Buffer buffer = GetRef<Buffer>(op);
  if (meta_.InMeta(buffer)) {
    return meta_.GetMetaNode(buffer);
  }
  return AllocBuf(buffer);
}
// Global "script.AsTVMScript"(functions, show_meta) -> str.
// Accepts a PrimFunc or an IRModule (anything else fails the CHECK) and
// returns its script text, prefixed with the "@tvm.script.tir" decorator line
// and terminated by a newline. When show_meta is true, the output also
// includes the __tvm_meta__ section produced by DumpMeta; as visible in this
// file, DumpMeta is only invoked when printing an IRModule.
TVM_REGISTER_GLOBAL("script.AsTVMScript")
.set_body_typed<std::string(const ObjectRef&, bool)>([](const ObjectRef& functions,
bool show_meta) {
CHECK(functions.as<PrimFuncNode>() != nullptr || functions.as<IRModuleNode>() != nullptr);
return "@tvm.script.tir\n" + TVMScriptPrinter(show_meta).Print(functions).str() + "\n";
});
} // namespace tir
} // namespace tvm