// blob: aa8775db531170674bcf7e40ffcb59a93b3c00d0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file relay_text_printer.cc
* \brief Printer to print out the IR text format
* that can be parsed by a parser.
*
* Supports ANF, GNF in relay and metadata.
*
 * Inlining heuristics:
 *  - Always inline:
 *    - GlobalVar
 *    - Constant
 *    - Op
 *    - Var
 *    - Constructor
 *  - Otherwise, inline if the node is used at most once and is printed as the
 *    result of a scope or as the value of a let binding.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/tir/function.h>
#include "../ir/attr_functor.h"
#include "../parser/meta_ref.h"
#include "../relay/analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
namespace relay {
/*!
 * \brief Render the trailing annotation appended after a printed expression.
 *
 * When no annotate_ callback is installed, constants and calls whose checked
 * type is already known receive a ` ty=<type>` block comment. When a callback
 * is installed, its result is appended verbatim instead, unless it is empty.
 *
 * \param expr The expression being annotated.
 * \return The annotation doc; may be empty.
 */
Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
  Doc info;
  if (annotate_ == nullptr) {
    const bool has_type_slot = expr.as<ConstantNode>() != nullptr || expr.as<CallNode>() != nullptr;
    if (has_type_slot && expr->checked_type_.defined()) {
      info << " /* ty=" << Print(expr->checked_type()) << " */";
    }
    return info;
  }
  // A user-supplied annotator fully replaces the default type comment.
  const std::string custom = annotate_(expr);
  if (!custom.empty()) {
    info << custom;
  }
  return info;
}
// Print `node` as a brace-delimited block. The contents are printed in a fresh
// scope (PrintScope) and indented by `indent` columns.
Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) {
  Doc inner;
  inner << Doc::NewLine() << PrintScope(node);
  Doc block;
  block << "{" << Doc::Indent(indent, inner) << Doc::NewLine() << "}";
  return block;
}
// Print `node` in its own scope. A fresh doc is pushed onto doc_stack_, so any
// bindings hoisted while printing `node` (temporaries, free_var declarations)
// are collected there. They are emitted in front of the printed node rather
// than leaking into an enclosing scope.
Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
  doc_stack_.push_back(Doc());
  // Print before reading doc_stack_.back(): printing may push further docs,
  // which would invalidate a reference taken earlier.
  Doc body = Print(node, false, true);
  Doc scope = doc_stack_.back();
  doc_stack_.pop_back();
  scope << body;
  return scope;
}
// Top-level entry: build the dependency graph used by the inlining heuristic
// (for relay expressions only), then print `node` in a fresh scope.
Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
  const bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() &&
                                 !node->IsInstance<relay::FunctionNode>();
  // Temporarily skip non-relay functions: no dependency graph is built for them.
  // TODO(tvm-team) enhance the code to work for all functions
  if (!is_non_relay_func && node.as<ExprNode>()) {
    dg_ = DependencyGraph::Create(&arena_, Downcast<Expr>(node));
  }
  return PrintScope(node);
}
// Dispatch on the node category: relay expressions, types, patterns and
// modules get dedicated printers. Everything else, including non-relay
// functions, falls back to the object's stream representation.
Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
  const bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() &&
                                 !node->IsInstance<relay::FunctionNode>();
  if (!is_non_relay_func && node.as<ExprNode>()) {
    return PrintExpr(Downcast<Expr>(node), meta, try_inline);
  }
  if (node.as<TypeNode>()) {
    return PrintType(Downcast<Type>(node), meta);
  }
  if (node.as<PatternNode>()) {
    return PrintPattern(Downcast<Pattern>(node), meta);
  }
  if (node.as<IRModuleNode>()) {
    return PrintMod(Downcast<IRModule>(node));
  }
  // default module.
  std::ostringstream stream;
  stream << node;
  return Doc::RawText(stream.str());
}
// Name of the n-th temporary: "%0", "%1", ...
Doc RelayTextPrinter::TempVar(int n) {
  Doc temp;
  temp << "%" << n;
  return temp;
}
// Hand out the next unused temporary name.
Doc RelayTextPrinter::AllocTemp() {
  const int id = temp_var_counter_++;
  return TempVar(id);
}
/*!
 * \brief Reserve a name that has not been handed out before.
 *
 * If `prefix` itself is unused it is returned unchanged. Otherwise the counter
 * stored for `prefix` is bumped until `prefix` followed by the counter is an
 * unused name. The chosen name is recorded so later requests avoid it.
 *
 * \param prefix The desired name / prefix.
 * \return The reserved name.
 */
Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) {
  std::string chosen = prefix;
  auto entry = name_alloc_map_.find(prefix);
  if (entry != name_alloc_map_.end()) {
    auto& counter = entry->second;
    do {
      std::ostringstream candidate;
      candidate << prefix << (++counter);
      chosen = candidate.str();
    } while (name_alloc_map_.count(chosen) != 0);
  }
  name_alloc_map_[chosen] = 0;
  return Doc::Text(chosen);
}
/*!
 * \brief Print the textual name of a type-variable kind.
 * \param k The kind to print.
 * \return The kind's name, e.g. "Type" for kType.
 */
Doc RelayTextPrinter::Print(Kind k) {
  switch (k) {
    case kType:
      return Doc::Text("Type");
    case kShapeVar:
      return Doc::Text("Shape");
    case kBaseType:
      return Doc::Text("BaseType");
    case kConstraint:
      return Doc::Text("Constraint");
    case kAdtHandle:
      return Doc::Text("AdtHandle");
    case kTypeData:
      return Doc::Text("TypeData");
    default:
      // A bare `throw;` here would call std::terminate because no exception is
      // in flight. Report through the same fatal-error path used elsewhere in
      // this file, and include the offending value.
      LOG(FATAL) << "Unknown Kind: " << static_cast<int>(k);
  }
  // Not reached after LOG(FATAL); keeps every control path returning a value.
  return Doc();
}
/*!
 * \brief Allocate a unique printed name for a type variable.
 *
 * A name hint that is empty or does not start with a letter is prefixed with
 * "t". If the variable already has a name (it is bound twice, i.e. the IR is
 * malformed), the previous name is returned with a "-malformed-ir" suffix, so
 * printing still succeeds while the problem stays visible.
 *
 * \param var The input type variable.
 * \return The allocated name, followed by ": <kind>" when the kind is not kType.
 */
Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) {
  auto it = memo_type_.find(var);
  if (it != memo_type_.end()) {
    Doc val = it->second;
    val << "-malformed-ir";
    return val;
  }
  std::string name = var->name_hint;
  // std::isalpha is undefined for negative char values (e.g. bytes of
  // non-ASCII UTF-8 names), so widen through unsigned char first.
  if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
    name = "t" + name;
  }
  Doc val = GetUniqueName(name);
  // memo_type_ keeps the bare name; only the returned binding-site doc gets the kind.
  memo_type_[var] = val;
  if (var->kind != kType) {
    val << ": " << Print(var->kind);
  }
  return val;
}
/*!
 * \brief Allocate a unique printed name for a variable.
 *
 * Names are prefixed with "%". A name hint that is empty or does not start with
 * a letter is additionally prefixed with "v". If the variable already has a
 * name (it is bound twice, i.e. the IR is malformed), the previous name is
 * returned with a "-malformed-ir" suffix, so printing still succeeds while the
 * problem stays visible.
 *
 * \param var The input variable.
 * \return The allocated name, followed by ": <type>" when the variable has a
 *         type annotation.
 */
Doc RelayTextPrinter::AllocVar(const Var& var) {
  // still print if ir is malformed, but show the error.
  auto it = memo_.find(var);
  if (it != memo_.end()) {
    Doc val = it->second;
    val << "-malformed-ir";
    return val;
  }
  std::string name = var->name_hint();
  // Always make sure the first character is a letter. std::isalpha is undefined
  // for negative char values (e.g. bytes of non-ASCII UTF-8 names), so widen
  // through unsigned char first.
  if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
    name = "v" + name;
  }
  Doc val = GetUniqueName("%" + name);
  // memo_ keeps the bare name so use sites print just "%x"; only the returned
  // binding-site doc carries the type annotation.
  memo_[var] = val;
  if (var->type_annotation.defined()) {
    val << ": " << Print(var->type_annotation);
  }
  return val;
}
// An expression is "unique" (a candidate for inlining) when it is not in the
// dependency graph at all, or when its graph node has at most one parent.
bool RelayTextPrinter::IsUnique(const Expr& expr) {
  auto node_it = dg_.expr_node.find(expr);
  if (node_it == dg_.expr_node.end()) {
    return true;
  }
  const auto* first_parent = node_it->second->parents.head;
  return first_parent == nullptr || first_parent->next == nullptr;
}
// Global vars, constants, operators, variables and constructors are always
// printed inline instead of being bound to a temporary.
bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
  if (expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>()) {
    return true;
  }
  return expr.as<VarNode>() != nullptr || expr.as<ConstructorNode>() != nullptr;
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
/*!
 * \brief Print an expression, hoisting it into a temporary unless it is inlined.
 *
 * Results are memoized in memo_. An expression reached a second time prints
 * exactly as recorded the first time (its temporary or its inline text), which
 * is how shared subexpressions are rendered.
 *
 * \param expr The expression to print.
 * \param meta If true, print the expression as a metadata reference.
 * \param try_inline If true, also inline the expression when IsUnique(expr) holds.
 * \param optional_info If true, append PrintOptionalInfo(expr) to the printed text.
 * \return The doc to splice in at the use site.
 */
Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info) {
  // Exploit memoization to print GNF.
  // The first time we visit an expression, we need to allocate a temp var
  // for it. Every subsequent time we can just use its assigned variable.
  // This works since hashing uses pointer equality.
  // determine whether to inline
  bool inline_expr = AlwaysInline(expr);
  if (try_inline) {
    inline_expr |= IsUnique(expr);
  }
  // Already printed: reuse the recorded temporary / inline text.
  auto it = memo_.find(expr);
  if (it != memo_.end()) return it->second;
  Doc printed_expr;
  if (meta) {
    printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
  } else if (!inline_expr && expr.as<LetNode>()) {
    // wrap GNFed let in brackets
    Doc body;
    printed_expr << "(";
    printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
    printed_expr << ")";
  } else {
    printed_expr = VisitExpr(expr);
  }
  if (optional_info) {
    printed_expr << PrintOptionalInfo(expr);
  }
  // add expr to doc
  if (expr.as<VarNode>()) {
    // This is our first time visiting the var and we hit the VarNode case
    // in the visitor. Thus the variable is free.
    // printed_expr is AllocVar's result (name plus any type annotation), so the
    // declaration shows the type.
    doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
    // Memoization is done in AllocVar, which stored the bare name used here.
    return memo_[expr];
  } else if (inline_expr) {
    memo_[expr] = printed_expr;
    return printed_expr;
  } else {
    // Not inlined: bind the expression to a fresh temporary in the current
    // scope and refer to it by that temporary from now on.
    Doc temp_var = AllocTemp();
    memo_[expr] = temp_var;
    doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
    return temp_var;
  }
}
// Reached only for a variable that is not yet in memo_. Bound variables are
// named at their binding site (let, function parameter, pattern), so this is a
// free variable seen for the first time.
Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {
  const Var free_var = GetRef<Var>(op);
  return AllocVar(free_var);
}
/*!
 * \brief Print a scalar value as a literal.
 *
 * bool values are rendered through Doc::PyBoolLiteral. float32 values get an
 * 'f' suffix. Every other dtype is streamed as-is.
 *
 * \param dtype The data type
 * \param value The value to be printed.
 */
template <typename T>
Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
  if (dtype == DataType::Bool()) {
    return Doc::PyBoolLiteral(value != 0);
  }
  std::ostringstream literal;
  literal << value;
  if (dtype == DataType::Float(32)) {
    literal << 'f';
  }
  return Doc::Text(literal.str());
}
/*!
 * \brief Print a constant.
 *
 * Scalars of dtype int32, int64, float32, float64 and bool are printed
 * directly as literals. Every other constant (non-scalar tensors, other
 * dtypes) falls back to a metadata reference.
 */
Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
  // Print out simple scalars directly.
  if (op->is_scalar()) {
    const DataType dtype = DataType(op->data->dtype);
    // The value is read through the raw data pointer below, so it must be on CPU.
    CHECK_EQ(op->data->ctx.device_type, kDLCPU);
    const void* data = op->data->data;
    if (dtype == DataType::Int(32)) {
      return ScalarLiteral(dtype, static_cast<const int32_t*>(data)[0]);
    } else if (dtype == DataType::Int(64)) {
      return ScalarLiteral(dtype, static_cast<const int64_t*>(data)[0]);
    } else if (dtype == DataType::Float(32)) {
      return ScalarLiteral(dtype, static_cast<const float*>(data)[0]);
    } else if (dtype == DataType::Float(64)) {
      return ScalarLiteral(dtype, static_cast<const double*>(data)[0]);
    } else if (dtype == DataType::Bool()) {
      return ScalarLiteral(dtype, static_cast<const uint8_t*>(data)[0]);
    }
  }
  // default fall-back, record it as meta node.
  // Don't append optional_info. Because the entry function is Print,
  // and it will append the optional_info afterwards.
  return PrintExpr(GetRef<Expr>(op), true, false, false);
}
// Print a tuple as "(a, b, ...)". A single-element tuple gets a trailing comma,
// following Python's "(x,)" convention.
Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
  std::vector<Doc> field_docs;
  field_docs.reserve(op->fields.size());
  for (const Expr& field : op->fields) {
    field_docs.push_back(Print(field));
  }
  Doc doc;
  doc << "(" << Doc::Concat(field_docs);
  if (field_docs.size() == 1) {
    doc << ",";
  }
  doc << ")";
  return doc;
}
// Print a tuple projection as "<tuple>.<index>".
Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
  Doc projection = Print(op->tuple);
  projection << "." << op->index;
  return projection;
}
// Print "if (<cond>) { ... } else { ... }". Each branch is printed in its own scope.
Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
  Doc cond = Print(op->cond);
  Doc then_body = PrintBody(op->true_branch);
  Doc else_body = PrintBody(op->false_branch);
  Doc doc;
  doc << "if (" << cond << ") " << then_body << " else " << else_body;
  return doc;
}
/*!
 * \brief Print a chain of directly nested lets followed by the final body.
 *
 * Each "let %x = <value>;" line is built and pushed onto doc_stack_ before the
 * next binding is visited. Temporaries hoisted while printing a later value
 * therefore land in that doc, after the preceding binding. The innermost
 * non-let body is printed in its own scope (PrintScope). The n pushed docs are
 * then popped and prepended to it, innermost first.
 */
Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
  // Number of binding docs pushed onto doc_stack_ by the loop below.
  int n = 0;
  Expr let = GetRef<Let>(op);
  while (auto let_node = let.as<LetNode>()) {
    Doc doc;
    // try_inline=true: a value used at most once may be printed inline here.
    doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";"
        << Doc::NewLine();
    doc_stack_.push_back(doc);
    let = let_node->body;
    ++n;
  }
  Doc doc = PrintScope(let);
  // Unwind: prepend each binding doc, most recently pushed first.
  for (int i = 0; i < n; ++i) {
    doc = doc_stack_.back() << doc;
    doc_stack_.pop_back();
  }
  return doc;
}
/*!
 * \brief Print a relay function as
 *        "<prefix>[type params](params, attrs) -> <ret> { body }".
 *
 * The type-parameter list is omitted when empty, and "-> <ret>" is omitted
 * when no return type is set. Function attributes are appended to the
 * parameter list as key=value entries.
 */
Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
  Doc doc;
  doc << prefix;
  if (!fn->type_params.empty()) {
    std::vector<Doc> type_param_docs;
    for (const TypeVar& tv : fn->type_params) {
      type_param_docs.push_back(Doc::Text(tv->name_hint));
    }
    doc << "[" << Doc::Concat(type_param_docs) << "]";
  }
  std::vector<Doc> param_docs;
  for (const Var& param : fn->params) {
    param_docs.push_back(AllocVar(param));
  }
  const std::vector<Doc> attr_docs = PrintFuncAttrs(fn->attrs);
  param_docs.insert(param_docs.end(), attr_docs.begin(), attr_docs.end());
  doc << "(" << Doc::Concat(param_docs) << ") ";
  if (fn->ret_type.defined()) {
    doc << "-> " << Print(fn->ret_type) << " ";
  }
  doc << PrintBody(fn->body);
  return doc;
}
// Dispatch on the concrete function kind:
//  - relay functions are printed in full;
//  - TIR PrimFuncs use their own stream printer (the prefix is not used);
//  - anything else is emitted as "<prefix> = <metadata reference>".
Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
  if (const auto* relay_fn = base_func.as<relay::FunctionNode>()) {
    return PrintFunc(prefix, GetRef<relay::Function>(relay_fn));
  }
  if (const auto* prim_fn = base_func.as<tir::PrimFuncNode>()) {
    std::ostringstream stream;
    stream << GetRef<tir::PrimFunc>(prim_fn);
    return Doc::RawText(stream.str());
  }
  // def @xyz = meta['ExternalFunc'][id]
  Doc doc;
  doc << prefix << " = " << meta_->GetMetaNode(base_func);
  return doc;
}
/*!
 * \brief Print a module: every type definition, then every function as
 *        "def @name ...". Consecutive entries are separated by an extra newline.
 */
Doc RelayTextPrinter::PrintMod(const IRModule& mod) {
  Doc doc;
  bool first_entry = true;
  // Emit a separator before every entry except the first one.
  auto separate = [&doc, &first_entry]() {
    if (!first_entry) {
      doc << Doc::NewLine();
    }
    first_entry = false;
  };
  for (const auto& kv : mod->type_definitions) {
    separate();
    doc << Print(kv.second) << Doc::NewLine();
  }
  for (const auto& kv : mod->functions) {
    // Rebuild the dependency graph so the inlining heuristic sees this function.
    if (kv.second.as<relay::FunctionNode>()) {
      dg_ = DependencyGraph::Create(&arena_, kv.second);
    }
    separate();
    std::ostringstream header;
    header << "def @" << kv.first->name_hint;
    doc << PrintFunc(Doc::Text(header.str()), kv.second) << Doc::NewLine();
  }
  return doc;
}
// Anonymous functions are printed with the "fn " prefix.
Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
  const Function fn = GetRef<Function>(op);
  return PrintFunc(Doc::Text("fn "), fn);
}
// Global variables are printed as "@name".
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
  const std::string name = op->name_hint;
  return Doc::Text("@" + name);
}
// Operators are printed by their registered name.
Doc RelayTextPrinter::VisitExpr_(const OpNode* op) {
  const std::string op_name = op->name;
  return Doc::Text(op_name);
}
/*!
 * \brief Print a call as "<callee>(<args>, <attrs>)".
 *
 * Arguments are printed before the callee, so any temporaries they hoist come
 * first and the callee stays close to its call site. A constructor callee is
 * printed by name. A call to a constructor with no inputs is printed as the
 * bare constructor name.
 */
Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
  std::vector<Doc> arg_docs;
  for (const Expr& arg : op->args) {
    arg_docs.push_back(Print(arg));
  }
  const std::vector<Doc> attr_docs = PrintCallAttrs(op->attrs, op->op);
  arg_docs.insert(arg_docs.end(), attr_docs.begin(), attr_docs.end());

  Doc doc;
  const auto* ctor = op->op.as<ConstructorNode>();
  if (ctor != nullptr) {
    doc << ctor->name_hint;
    // don't print as a call if it's a 0-arity cons
    if (ctor->inputs.size() == 0) {
      return doc;
    }
  } else {
    doc << Print(op->op);
  }
  doc << "(" << Doc::Concat(arg_docs) << ")";
  return doc;
}
// Reference creation: "ref(<value>)".
Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) {
  Doc doc;
  doc << "ref(" << Print(op->value) << ")";
  return doc;
}
// Reference read: "<ref>^".
Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) {
  Doc doc = Print(op->ref);
  doc << "^";
  return doc;
}
// Reference write: "(<ref> := <value>)".
Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) {
  Doc target = Print(op->ref);
  Doc value = Print(op->value);
  Doc doc;
  doc << "(" << target << " := " << value << ")";
  return doc;
}
/*!
 * \brief Print a match expression.
 *
 * Emits "match (<data>) {" ("match?" when the match is not complete), then one
 * "<pattern> => <rhs>," line per clause indented by 2, then "}". Each rhs is
 * printed in its own scope and wrapped in braces when it is a let.
 */
Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) {
  // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
  Doc doc;
  doc << "match";
  if (!op->complete) doc << "?";
  doc << " (" << Print(op->data) << ") {";

  std::vector<Doc> clause_docs;
  for (const auto& clause : op->clauses) {
    // The pattern is printed first so its variables are allocated before the rhs uses them.
    Doc lhs_doc = PrintPattern(clause->lhs, false);
    Doc rhs_doc = PrintScope(clause->rhs);
    // only add braces if there are multiple lines on the rhs
    if (clause->rhs.as<LetNode>()) rhs_doc = Doc::Brace("{", rhs_doc, "}");
    Doc clause_doc;
    clause_doc << lhs_doc << " => " << rhs_doc << ",";
    clause_docs.push_back(clause_doc);
  }

  Doc clauses;
  clauses << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine());
  doc << Doc::Indent(2, clauses) << Doc::NewLine() << "}";
  return doc;
}
// Print a pattern, memoized per pattern object. When `meta` is set the pattern
// is emitted as a metadata reference instead of being visited.
Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) {
  auto cached = memo_pattern_.find(pattern);
  if (cached != memo_pattern_.end()) {
    return cached->second;
  }
  Doc printed = meta ? meta_->GetMetaNode(GetRef<ObjectRef>(pattern.get())) : VisitPattern(pattern);
  memo_pattern_[pattern] = printed;
  return printed;
}
// Print a constructor pattern as "Ctor", or "Ctor(p1, p2, ...)" when it has sub-patterns.
Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) {
  Doc doc;
  doc << p->constructor->name_hint;
  if (p->patterns.empty()) {
    return doc;
  }
  std::vector<Doc> sub_docs;
  for (const auto& sub : p->patterns) {
    sub_docs.push_back(Print(sub));
  }
  doc << "(" << Doc::Concat(sub_docs) << ")";
  return doc;
}
// Print a tuple pattern as "(p1, p2, ...)".
Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) {
  std::vector<Doc> element_docs;
  for (const auto& element : pt->patterns) {
    element_docs.push_back(Print(element));
  }
  Doc doc;
  doc << "(" << Doc::Concat(element_docs) << ")";
  return doc;
}
// The wildcard pattern is always printed as "_".
Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) {
  return Doc::Text("_");
}
// A variable pattern binds a variable, so its name is allocated here.
Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) {
  const Var& bound = pv->var;
  return AllocVar(bound);
}
// Constructors are printed by name. Inside an ADT definition (in_adt_def_), a
// constructor with inputs also lists its input types: "Ctor(T1, T2, ...)".
Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) {
  Doc doc;
  doc << n->name_hint;
  if (!in_adt_def_ || n->inputs.size() == 0) {
    return doc;
  }
  std::vector<Doc> input_docs;
  for (const Type& input : n->inputs) {
    input_docs.push_back(Print(input));
  }
  doc << "(" << Doc::Concat(input_docs) << ")";
  return doc;
}
//------------------------------------
// Overload of Type printing functions
//------------------------------------
// Print a type, memoized per type object. When `meta` is set the type is
// emitted as a metadata reference instead of being visited.
Doc RelayTextPrinter::PrintType(const Type& type, bool meta) {
  auto cached = memo_type_.find(type);
  if (cached != memo_type_.end()) {
    return cached->second;
  }
  Doc printed = meta ? meta_->GetMetaNode(GetRef<ObjectRef>(type.get())) : VisitType(type);
  memo_type_[type] = printed;
  return printed;
}
// Types without a dedicated visitor are emitted as metadata references.
Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) {
  const ObjectRef type_ref = GetRef<ObjectRef>(node);
  return Print(type_ref, /*meta=*/true);
}
// Type variables are printed as their name hint.
Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) {
  return Doc::Text(node->name_hint);
}
// Global type variables are printed as their name hint as well.
Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) {
  return Doc::Text(node->name_hint);
}
// Print a type application as "<func>[arg1, arg2, ...]".
Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) {
  Doc applied = PrintType(node->func, false);
  std::vector<Doc> arg_docs;
  arg_docs.reserve(node->args.size());
  for (const Type& arg : node->args) {
    arg_docs.push_back(PrintType(arg, false));
  }
  applied << "[" << Doc::Concat(arg_docs) << "]";
  return applied;
}
// Data types are printed as the string produced by runtime::DLDataType2String.
Doc RelayTextPrinter::PrintDType(DataType dtype) {
  const std::string dtype_str = runtime::DLDataType2String(dtype);
  return Doc::Text(dtype_str);
}
// A tensor type with an empty shape is printed as its bare dtype. Otherwise it
// is printed as "Tensor[(d0, d1, ...), dtype]".
Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
  if (node->shape.size() == 0) {
    return PrintDType(node->dtype);
  }
  std::vector<Doc> dim_docs;
  for (const auto& dim : node->shape) {
    dim_docs.push_back(PrintAttr(dim));
  }
  Doc doc;
  doc << "Tensor[(" << Doc::Concat(dim_docs) << "), " << PrintDType(node->dtype) << "]";
  return doc;
}
// Print a tuple type as "(T1, T2, ...)". A single field gets a trailing comma,
// following Python's "(x,)" convention.
Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) {
  std::vector<Doc> field_docs;
  field_docs.reserve(node->fields.size());
  for (const Type& field : node->fields) {
    field_docs.push_back(Print(field));
  }
  Doc doc;
  doc << "(" << Doc::Concat(field_docs);
  if (field_docs.size() == 1) {
    doc << ",";
  }
  doc << ")";
  return doc;
}
// Print a function type as "fn [T1, ...](A1, ...) -> R". The bracketed
// type-parameter list is omitted when there are no type parameters.
Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) {
  Doc doc;
  doc << "fn ";
  if (!node->type_params.empty()) {
    std::vector<Doc> type_param_docs;
    for (const auto& type_param : node->type_params) {
      type_param_docs.push_back(Print(type_param));
    }
    doc << "[" << Doc::Concat(type_param_docs) << "]";
  }
  std::vector<Doc> arg_docs;
  for (const Type& arg_type : node->arg_types) {
    arg_docs.push_back(Print(arg_type));
  }
  Doc ret_doc = Print(node->ret_type);
  doc << "(" << Doc::Concat(arg_docs) << ") -> " << ret_doc;
  return doc;
}
// Reference types are printed as "ref(<value type>)".
Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) {
  Doc value_doc = Print(node->value);
  Doc doc;
  doc << "ref(" << value_doc << ")";
  return doc;
}
/*!
 * \brief Print an ADT definition: "type Name[T1, ...] { Ctor1(...), Ctor2, ... }".
 *
 * Constructors are separated by "," plus a newline, and a trailing comma
 * follows the last one. in_adt_def_ is raised while this definition is printed,
 * so VisitExpr_(ConstructorNode) also prints the constructors' input types.
 */
Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
  in_adt_def_ = true;
  Doc doc;
  doc << "type " << Print(node->header);
  if (!node->type_vars.empty()) {
    std::vector<Doc> type_var_docs;
    for (const auto& type_var : node->type_vars) {
      type_var_docs.push_back(Print(type_var));
    }
    doc << "[" << Doc::Concat(type_var_docs) << "]";
  }
  doc << " ";

  std::vector<Doc> ctor_docs;
  for (const Constructor& ctor : node->constructors) {
    ctor_docs.push_back(Print(ctor, /* meta */ false, /* try_inline */ true));
  }
  Doc separator;
  separator << "," << Doc::NewLine();
  Doc adt_body;
  adt_body << Doc::Concat(ctor_docs, separator);
  // add trailing comma if there are any constructors
  if (!ctor_docs.empty()) {
    adt_body << ",";
  }
  doc << Doc::Brace("{", adt_body, "}");
  in_adt_def_ = false;
  return doc;
}
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
/*!
 * \brief Print an attribute value.
 *
 * An undefined value is printed as "None", tir::Any as "?", and a string as a
 * string literal. Any other value becomes a metadata reference when `meta` is
 * set; otherwise it goes through the attribute visitor.
 */
Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
  if (!value.defined()) {
    return Doc::Text("None");
  }
  if (value.as<tvm::tir::AnyNode>()) {
    Doc any_doc;
    any_doc << "?";
    return any_doc;
  }
  if (const auto* str_obj = value.as<tvm::StringObj>()) {
    Doc str_doc;
    str_doc << Doc::StrLiteral(GetRef<String>(str_obj));
    return str_doc;
  }
  return meta ? meta_->GetMetaNode(value) : VisitAttr(value);
}
// Attribute values without a dedicated visitor are emitted as metadata references.
Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
  const ObjectRef attr_ref = GetRef<ObjectRef>(op);
  return PrintAttr(attr_ref, /*meta=*/true);
}
// Print an array attribute as "[v1, v2, ...]".
Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
  std::vector<Doc> element_docs;
  for (const auto& element : *op) {
    element_docs.push_back(PrintAttr(element));
  }
  Doc doc;
  doc << "[" << Doc::Concat(element_docs) << "]";
  return doc;
}
// Integer immediates reuse the scalar-literal formatting for their dtype.
Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
  const DataType dtype = op->dtype;
  return ScalarLiteral(dtype, op->value);
}
// Float immediates reuse the scalar-literal formatting for their dtype.
Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
  const DataType dtype = op->dtype;
  return ScalarLiteral(dtype, op->value);
}
// String immediates are printed as string literals.
Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
  const auto& text = op->value;
  return Doc::StrLiteral(text);
}
/*!
 * \brief AttrVisitor that renders each visited attribute as a "key=value" doc
 *        and appends it to the vector supplied at construction.
 */
class RelayTextPrinter::AttrPrinter : public AttrVisitor {
 public:
  AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent) : docs_(doc), parent_(parent) {}

  // Append a generic "key=value" entry.
  template <typename T>
  void PrintKV(const char* key, const T& value) {
    Doc entry;
    entry << key << "=" << value;
    docs_->push_back(entry);
  }

  // Doubles are printed with an "f" suffix.
  void Visit(const char* key, double* value) final {
    Doc entry;
    entry << key << "=" << *value << "f";
    docs_->push_back(entry);
  }
  void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
  void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
  void Visit(const char* key, int* value) final { PrintKV(key, *value); }
  void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); }
  void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); }
  void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; }
  void Visit(const char* key, DataType* value) final {
    const std::string dtype_str = runtime::DLDataType2String(*value);
    PrintKV(key, Doc::StrLiteral(dtype_str));
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    LOG(FATAL) << "do not allow NDarray as argument";
  }
  // Object-valued attributes are delegated back to the owning printer.
  void Visit(const char* key, runtime::ObjectRef* obj) final {
    PrintKV(key, parent_->PrintAttr(*obj));
  }

 private:
  std::vector<Doc>* docs_;
  RelayTextPrinter* parent_;
};
/*!
 * \brief Print the attributes of a call as a list of "key=value" docs.
 *
 * If the callee is an operator whose registered attrs type index differs from
 * that of the attached attrs, a single metadata reference is returned instead.
 */
std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
  std::vector<Doc> docs;
  if (!attrs.defined()) {
    return docs;
  }
  const auto* op_node = op.as<OpNode>();
  const bool type_mismatch =
      op_node != nullptr && attrs->type_index() != op_node->attrs_type_index;
  if (type_mismatch) {
    // fallback
    docs.push_back(meta_->GetMetaNode(attrs));
    return docs;
  }
  AttrPrinter printer(&docs, this);
  const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
  return docs;
}
std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* dict_attrs = attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
for (const auto& k : dict_attrs->dict) {
Doc doc;
doc << k.first << "=" << Print(k.second);
docs.push_back(doc);
}
return docs;
}
// FFI entry point "ir.TextPrinter": calls AsText with `false` and a null
// annotator and returns its result.
TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
  return AsText(node, false, nullptr);
});
} // namespace relay
} // namespace tvm