blob: fe0e73d99e9986596f561bcd67cdc916de705b4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/transform/run_codegen.cc
* \brief Run codegen for annotated relax functions.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "../../support/ordered_set.h"
#include "utils.h"
namespace tvm {
namespace relax {
/*!
 * \brief Mutator that turns calls to codegen-annotated relax functions into
 *        `relax.call_dps_packed` calls of an ExternFunc, and runs the matching
 *        external codegens over every annotated function.
 */
class CodeGenRunner : ExprMutator {
 public:
  using OptionMap = Map<String, ObjectRef>;

  explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}

  /*!
   * \brief Rewrite the module and run external codegen.
   *
   * Every entry function is visited so that calls to functions carrying
   * `attr::kCodegen` are redirected to the generated external symbol. The
   * runtime modules returned by the `relax.ext.<target>` codegens are attached
   * as `tvm::attr::kExternalMods`. Constants found inside codegen functions are
   * attached as `tvm::attr::kConstNameToConstant`.
   *
   * \param target_options Optional per-target option maps, keyed by codegen name.
   * \param entry_function_names Names of functions to treat as entry points, in
   *        addition to every externally-linked relax function without `kCodegen`.
   * \return The rewritten module after dead-code elimination.
   */
  IRModule Run(Optional<Map<String, OptionMap>> target_options,
               Array<String> entry_function_names) {
    IRModule mod = builder_->GetContextIRModule();
    support::OrderedSet<GlobalVar> entry_functions;
    // Any user-provided functions are treated as entry functions.
    for (const auto& name : entry_function_names) {
      entry_functions.insert(mod->GetGlobalVar(name));
    }
    // In addition, any externally-exposed function that does not
    // belong to a specific codegen may be an entry function. These
    // are added in alphabetical order, to ensure consistent order of
    // evaluation for debug/test purposes.
    {
      std::vector<GlobalVar> attr_entry_functions;
      for (const auto& [gv, func] : mod->functions) {
        if (func->GetLinkageType() == LinkageType::kExternal &&
            !func->GetAttr<String>(attr::kCodegen) && func->IsInstance<relax::FunctionNode>()) {
          attr_entry_functions.push_back(gv);
        }
      }
      // Ascending (alphabetical) order by name, as documented above.
      std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
                [](const GlobalVar& gvar_a, const GlobalVar& gvar_b) {
                  return gvar_a->name_hint < gvar_b->name_hint;
                });
      for (const auto& gvar : attr_entry_functions) {
        entry_functions.insert(gvar);
      }
    }
    for (const auto& gvar : entry_functions) {
      builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
    }
    auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
    auto out_mod = builder_->GetContextIRModule();
    if (ext_mods.size()) {
      out_mod = WithAttr(out_mod, tvm::attr::kExternalMods, std::move(ext_mods));
    }
    if (constant_names_.size()) {
      // Some backends (e.g. TensorRT) expect constants to be passed when they are instantiated
      Map<String, runtime::NDArray> constants;
      for (const auto& [constant, name] : constant_names_) {
        ICHECK(!constants.count(name)) << "More than one constant with the name " << name;
        constants.Set(name, constant->data);
      }
      out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant, std::move(constants));
    }
    // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
    return DeadCodeElimination(out_mod, entry_function_names);
  }

  using ExprMutator::VisitExpr_;

  /*!
   * \brief Redirect calls to codegen-annotated global functions.
   *
   * A call whose callee global var resolves to a function that lowers to an
   * ExternFunc becomes `relax.call_dps_packed(extern_func, (args...))`. That
   * callee is then stripped of its global-symbol and codegen attributes, so
   * that later dead-code elimination may remove it. Any other call is returned
   * as mutated by the base class.
   */
  Expr VisitExpr_(const CallNode* call_node) override {
    // Visits the callee, the struct-info arguments and the call arguments exactly once.
    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
    if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
      const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
      // Build `call_dps_packed(extern_func, (args...))`. It reuses the arguments
      // already mutated above instead of visiting them a second time.
      auto create_call_dps_packed = [&call](Expr extern_func, StructInfo ret_struct_info) {
        Array<Expr> new_args({extern_func});
        new_args.push_back(Tuple(call->args));
        static const Op& call_op = Op::Get("relax.call_dps_packed");
        return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
      };
      auto ret_sinfo = GetStructInfo(call);
      if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
        return create_call_dps_packed(it->second, ret_sinfo);
      } else if (auto opt_func = builder_->GetContextIRModule()->Lookup(gvar).as<Function>()) {
        // TODO(@sunggg): Is there any better way to get this func?
        Function func = opt_func.value();
        Expr new_func = VisitExpr(func);
        if (new_func->IsInstance<ExternFuncNode>()) {
          extern_funcs_[gvar_node] = new_func;
          // Remove the global symbol and codegen attributes from the function so that it can be
          // removed from the module.
          static const runtime::PackedFunc* RemoveFuncAttrFunc =
              runtime::Registry::Get("ir.BaseFuncWithoutAttr");
          ICHECK(RemoveFuncAttrFunc);
          func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
          func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
          builder_->UpdateFunction(gvar, func);
          return create_call_dps_packed(new_func, ret_sinfo);
        }
      }
    }
    // Not a call into external codegen: keep the already-mutated call. Rebuilding it from
    // `call_node->op` / `call_node->sinfo_args` would discard any remapping the mutator applied.
    return call;
  }

  /*!
   * \brief Replace a codegen-annotated function with an ExternFunc.
   *
   * Every constant in the function body is assigned the name
   * `<ext_symbol>_<codegen>_const_<i>`, which is unique within that function,
   * and recorded in `constant_names_`. Functions without `attr::kCodegen` are
   * mutated normally.
   */
  Expr VisitExpr_(const FunctionNode* func_node) override {
    Function func = GetRef<Function>(func_node);
    auto opt_codegen = func->GetAttr<String>(attr::kCodegen);
    if (!opt_codegen) {
      return ExprMutator::VisitExpr_(func_node);
    }
    auto ext_symbol = GetExtSymbol(func);
    size_t count = 0;
    PostOrderVisit(func->body, [this, &ext_symbol, &opt_codegen, &count](Expr e) {
      if (e->IsInstance<ConstantNode>()) {
        // Make sure to pick a unique name
        auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + std::to_string(count++);
        auto constant = Downcast<Constant>(e);
        constant_names_.Set(constant, name);
      }
    });
    return ExternFunc(ext_symbol);
  }

 private:
  /*!
   * \brief Group codegen-annotated relax functions by target and run each target's codegen.
   *
   * \param mod Module to scan. PrimFuncs are skipped; nested relax functions are included.
   * \param target_options Per-target options. A target without an entry gets an empty map.
   * \return Runtime modules from all codegens, appended in ascending target-name order.
   */
  Array<runtime::Module> InvokeCodegen(IRModule mod, Map<String, OptionMap> target_options) {
    // An ordered map makes codegen invocation, and therefore the order of the returned
    // modules, deterministic regardless of the standard-library hash implementation.
    std::map<std::string, Array<Function>> target_functions;
    for (const auto& entry : mod->functions) {
      if (entry.second->IsInstance<tir::PrimFuncNode>()) {
        continue;
      }
      PostOrderVisit(entry.second, [&target_functions](Expr e) {
        if (e->IsInstance<FunctionNode>()) {
          auto f = Downcast<Function>(e);
          if (auto target_opt = f->GetAttr<String>(attr::kCodegen)) {
            String target = target_opt.value();
            target_functions[target].push_back(f);
          }
        }
      });
    }
    Array<runtime::Module> ext_mods;
    for (const auto& [target, functions] : target_functions) {
      OptionMap options = target_options.Get(target).value_or({});
      // Start the codegen process.
      // Get the codegen with its ffi key.
      String codegen_name = "relax.ext." + target;
      auto codegen = runtime::Registry::Get(codegen_name);
      ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n";
      Array<runtime::Module> compiled_functions = (*codegen)(functions, options, constant_names_);
      ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end());
    }
    return ext_mods;
  }

  /*!
   * \brief Unique names for the constants found in the bodies of visited codegen functions.
   *        The map is handed to every codegen.
   */
  Map<Constant, String> constant_names_;
  /*! \brief Extern funcs for each global variable. */
  std::unordered_map<const GlobalVarNode*, Expr> extern_funcs_;
};
} // namespace relax
namespace transform {
/*!
 * \brief Build the "RunCodegen" module pass.
 *
 * Each time the pass runs, it creates a fresh relax::CodeGenRunner over the
 * incoming module and runs it with the options supplied here.
 *
 * \param target_options Optional per-target codegen options, keyed by target name.
 * \param entry_functions Names of functions to treat as entry points.
 */
Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_options,
                Array<String> entry_functions) {
  // Copies of both arguments are kept inside the closure so the pass can be run later.
  auto run_on_module = [target_options, entry_functions](IRModule mod, PassContext ctx) {
    relax::CodeGenRunner runner(mod);
    return runner.Run(target_options, entry_functions);
  };
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(run_on_module);
  return CreateModulePass(pass_func, 0, "RunCodegen", {});
}
TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen);
} // namespace transform
} // namespace tvm