blob: aac5656f91467ea9847c43ba891e88f3664e369e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir/type.h>
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
// Calls IRFrameNode::RegisterReflection() from a static-init block (tvm-ffi macro),
// presumably so IRFrameNode's reflected fields are registered before any printing runs.
TVM_FFI_STATIC_INIT_BLOCK() { IRFrameNode::RegisterReflection(); }
struct SortableFunction {
int priority;
GlobalVar gv;
BaseFunc func;
explicit SortableFunction(const std::pair<GlobalVar, BaseFunc>& obj)
: priority(0), gv(obj.first), func(obj.second) {
if (gv->name_hint == "main") {
priority = 1000;
} else if (obj.second->GetTypeKey() == "tir.PrimFunc") {
priority = 1;
} else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") {
priority = 2;
} else if (obj.second->GetTypeKey() == "relax.expr.Function") {
priority = 3;
} else {
LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: "
<< obj.second->GetTypeKey();
}
}
bool operator<(const SortableFunction& other) const {
if (this->priority != other.priority) {
return this->priority < other.priority;
}
return this->gv->name_hint < other.gv->name_hint;
}
};
/*!
 * \brief Print an IRModule as a class decorated with IR(d, "ir_module").
 *
 * The class body is emitted in this order:
 *  1. a module_attrs(...) call, if the module has a non-empty attribute dict;
 *  2. a module_global_infos(...) call, if the module has non-empty global infos;
 *  3. one statement per function, in SortableFunction order.
 *
 * Every GlobalVar is defined in the module frame before any function body is
 * printed. Its doc is built lazily from the doc of `mod` with `.name_hint`
 * attribute access.
 */
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<IRModule>("", [](IRModule mod, AccessPath p, IRDocsifier d) -> Doc {
      std::vector<SortableFunction> functions;
      functions.reserve(mod->functions.size());
      for (const auto& kv : mod->functions) {
        functions.emplace_back(kv);
      }
      std::sort(functions.begin(), functions.end());
      With<IRFrame> f(d);
      (*f)->AddDispatchToken(d, "ir");
      IdDoc module_doc = d->Define(mod, f(), GetBindingName(d).value_or("Module"));
      (*f)->global_infos = &mod->global_infos;
      if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
        (*f)->stmts.push_back(
            ExprStmtDoc(IR(d, "module_attrs")  //
                            ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
      }
      if (mod->global_infos.defined() && !mod->global_infos.empty()) {
        (*f)->stmts.push_back(ExprStmtDoc(
            IR(d, "module_global_infos")  //
                ->Call({d->AsDoc<ExprDoc>(mod->global_infos, p->Attr("global_infos"))})));
      }
      // Declare all GlobalVars before printing any function body.
      for (const auto& entry : functions) {
        const GlobalVar& gv = entry.gv;
        d->Define(gv, f(), [=]() {
          return d->AsDoc<ExprDoc>(mod, p->Attr("global_vars"))->Attr(gv->name_hint);
        });
      }
      // Print functions.
      for (const auto& entry : functions) {
        const GlobalVar& gv = entry.gv;
        const BaseFunc& base_func = entry.func;
        // Push the function's name as the current binding name while it is printed
        // (presumably consumed through GetBindingName by the function printer).
        d->cfg->binding_names.push_back(gv->name_hint);
        Doc doc = d->AsDoc(base_func, p->Attr("functions")->MapItem(gv));
        d->cfg->binding_names.pop_back();
        if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
          // Only the final statement of the block is kept; it takes over the
          // block's source paths. Guard against calling back() on an empty block.
          ICHECK(!stmt_block->stmts.empty())
              << "Printing mod[" << gv->name_hint << "] produced an empty StmtBlockDoc";
          (*f)->stmts.push_back(stmt_block->stmts.back());
          (*f)->stmts.back()->source_paths = std::move(doc->source_paths);
        } else if (auto stmt = doc.as<StmtDoc>()) {
          (*f)->stmts.push_back(stmt.value());
        } else if (auto func = doc.as<FunctionDoc>()) {
          (*f)->stmts.push_back(func.value());
        } else if (auto expr = doc.as<ExprDoc>()) {
          // Expression-valued functions are bound to their global name.
          ExprDoc lhs = IdDoc(gv->name_hint);
          AssignDoc assignment(lhs, expr.value(), std::nullopt);
          (*f)->stmts.push_back(assignment);
        } else {
          LOG(FATAL) << "TypeError: "
                     << "Expected IRModule to only contain functions, "
                     << "but mod[" << gv->name_hint << "] with type " << base_func->GetTypeKey()
                     << " produced Doc type of " << doc->GetTypeKey();
        }
      }
      return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
    });
// DictAttrs has no surface syntax of its own: it is printed as its underlying
// dictionary, with the access path extended by "dict".
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<DictAttrs>("", [](DictAttrs dict_attrs, AccessPath p, IRDocsifier d) -> Doc {
      const auto dict_path = p->Attr("dict");
      return d->AsDoc(dict_attrs->dict, dict_path);
    });
// A free-standing GlobalVar prints as GlobalVar("<name_hint>").
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<GlobalVar>("", [](GlobalVar gv, AccessPath p, IRDocsifier d) -> Doc {
      ExprDoc callee = IR(d, "GlobalVar");
      ExprDoc name_doc = LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"));
      return callee->Call({name_doc});
    });
// DummyGlobalInfo prints as an argument-free dummy_global_info() call;
// none of the object's state is read.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<DummyGlobalInfo>("", [](GlobalInfo /*ginfo*/, AccessPath /*p*/,
                                          IRDocsifier d) -> Doc {
      ExprDoc callee = IR(d, "dummy_global_info");
      return callee->Call({});
    });
// A VDevice is recorded in the docsifier under the "vdevice" global-info key
// and printed as vdevice(<exported target config>, vdevice_id, memory_scope).
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<VDevice>("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> Doc {
      d->AddGlobalInfo("vdevice", vdev);
      ffi::Map<ffi::String, ffi::Any> target_config = vdev->target->Export();
      // Docs are built in the same order as the original single-expression form.
      ExprDoc callee = IR(d, "vdevice");
      ExprDoc config_doc = d->AsDoc<ExprDoc>(target_config, p);
      ExprDoc id_doc = LiteralDoc::Int(vdev->vdevice_id, p->Attr("vdevice_id"));
      ExprDoc scope_doc = LiteralDoc::Str(vdev->memory_scope, p->Attr("memory_scope"));
      return callee->Call({config_doc, id_doc, scope_doc});
    });
// An Op prints as Op("<registered op name>").
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<Op>("", [](Op op, AccessPath p, IRDocsifier d) -> Doc {
      ExprDoc callee = IR(d, "Op");
      ExprDoc op_name = LiteralDoc::Str(op->name, p->Attr("name"));
      return callee->Call({op_name});
    });
// A FuncType prints as FuncType(<arg_types>, <ret_type>).
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<FuncType>("", [](FuncType func_type, AccessPath p, IRDocsifier d) -> Doc {
      ExprDoc callee = IR(d, "FuncType");
      ExprDoc params_doc = d->AsDoc<ExprDoc>(func_type->arg_types, p->Attr("arg_types"));
      ExprDoc result_doc = d->AsDoc<ExprDoc>(func_type->ret_type, p->Attr("ret_type"));
      return callee->Call({params_doc, result_doc});
    });
// Under the "ir" dispatch token a Range prints in begin/end form:
// Range(min, extent + min). The end expression is tagged with the "extent" path.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<Range>("ir", [](Range range, AccessPath p, IRDocsifier d) -> Doc {
      ExprDoc callee = IR(d, "Range");
      ExprDoc begin_doc = d->AsDoc<ExprDoc>(range->min, p->Attr("min"));
      PrimExpr end = range->extent + range->min;
      ExprDoc end_doc = d->AsDoc<ExprDoc>(end, p->Attr("extent"));
      return callee->Call({begin_doc, end_doc});
    });
/*!
 * \brief Repr entry point used for IRModuleNode.
 *
 * Currently it simply delegates to ReprPrintIR with the same config.
 * NOTE(review): presumably kept separate so IRModule-specific repr behavior
 * can be added here later; confirm before merging it away.
 */
std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
  std::string script = ReprPrintIR(mod, cfg);
  return script;
}
// Register the repr printers for the IR node types handled in this file.
// IRModuleNode uses ReprPrintIRModule; the other node types use ReprPrintIR directly.
TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR);
TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(RangeNode, ReprPrintIR);
TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule);
} // namespace printer
} // namespace script
} // namespace tvm