| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #include "./utils.h" |
| |
| namespace tvm { |
| namespace script { |
| namespace printer { |
| |
| bool AtTopLevelFunction(const IRDocsifier& d) { |
| // fewer than 2 frames: not in a function at all |
| if (d->frames.size() < 2) { |
| return false; |
| } |
| // if the first frame is a RelaxFrame, then this is not inside a module. |
| // 2 frames => we are at a function (more than 2 => nested function) |
| if (d->frames[0]->IsInstance<RelaxFrameNode>()) { |
| return d->frames.size() == 2; |
| } |
| // otherwise the first two frames pertain to an IR module, |
| // so 3 frames => we are at a top-level function (more than 3 => nested function) |
| return d->frames.size() == 3; |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); } |
| |
| TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
| .set_dispatch<relax::Function>("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { |
| std::unordered_set<const tir::VarNode*> func_vars; |
| With<RelaxFrame> f(d); |
| |
| IdDoc func_name(""); |
| // if we are binding a local definition, then calling d->Define |
| // will result in a repeated definition and an incorrect displayed name |
| if (ffi::Optional<ffi::String> name = GetBindingName(d)) { |
| func_name = IdDoc(name.value()); |
| } else { |
| func_name = IdDoc(FindFunctionName(d, n).value_or("main")); |
| } |
| (*f)->AddDispatchToken(d, "relax"); |
| (*f)->is_func = true; |
| (*f)->func_vars = &func_vars; |
| // Step 1. Print the return type |
| ffi::Optional<ExprDoc> ret_type = std::nullopt; |
| if (const auto& func_sinfo = relax::MatchStructInfo<relax::FuncStructInfo>(n)) { |
| ret_type = d->AsDoc<ExprDoc>(func_sinfo.value()->ret, // |
| n_p->Attr("struct_info_")->Attr("ret")); |
| } |
| // Step 2. Print params |
| ffi::Array<AssignDoc> params; |
| { |
| AccessPath params_p = n_p->Attr("params"); |
| for (int i = 0, l = n->params.size(); i < l; ++i) { |
| params.push_back(AssignDoc( |
| /*lhs=*/DefineVar(n->params[i], *f, d), |
| /*rhs=*/std::nullopt, |
| StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); |
| } |
| } |
| // Step 3. Clean up func variables |
| (*f)->func_vars = nullptr; |
| // Step 4. Print attributes |
| if (n->attrs.defined() && !n->attrs->dict.empty()) { |
| // If the function is a global function and has a global symbol, |
| // then don't print the global symbol (it will be implicit from not being private). |
| // For a function without an IR module whose global symbol |
| // doesn't match the function name, we should still print the global symbol attribute. |
| if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && |
| Downcast<ffi::String>(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { |
| ffi::Map<ffi::String, Any> new_attrs; |
| for (auto kv : n->attrs->dict) { |
| if (kv.first != tvm::attr::kGlobalSymbol) { |
| new_attrs.Set(kv.first, kv.second); |
| } |
| } |
| if (!new_attrs.empty()) { |
| (*f)->stmts.push_back(ExprStmtDoc( |
| Relax(d, "func_attr") // |
| ->Call({d->AsDoc<ExprDoc>(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); |
| } |
| } else { |
| (*f)->stmts.push_back( |
| ExprStmtDoc(Relax(d, "func_attr") // |
| ->Call({d->AsDoc<ExprDoc>(n->attrs, n_p->Attr("attrs"))}))); |
| } |
| } |
| // Step 5. Prepare the decorator (include purity if it's impure) |
| ExprDoc decorator = Relax(d, "function"); |
| ffi::Array<ExprDoc, void> pos_args = {}; |
| ffi::Array<ffi::String, void> dec_keys; |
| ffi::Array<ExprDoc, void> dec_values; |
| if (!n->is_pure) { |
| dec_keys.push_back("pure"); |
| dec_values.push_back(LiteralDoc::Boolean(false, ffi::Optional<AccessPath>())); |
| } |
| // if the function is global or is not in a module and does not have a global symbol, |
| // indicate that it's private |
| if (AtTopLevelFunction(d) && |
| (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { |
| dec_keys.push_back("private"); |
| dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional<AccessPath>())); |
| } |
| if (dec_keys.size()) { |
| decorator = decorator->Call(pos_args, dec_keys, dec_values); |
| } |
| |
| // Step 6. Print body |
| ffi::Array<StmtDoc> body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); |
| (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); |
| return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
| .set_dispatch<relax::ExternFunc>( // |
| "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { |
| // TODO(@junrushao): print more information out of extern function. |
| return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)}); |
| }); |
| |
| TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax); |
| TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax); |
| |
| } // namespace printer |
| } // namespace script |
| } // namespace tvm |