blob: ca6b71b4d8a292fed3f3069b76a885a45009699d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
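/*!
 * \file codegen_c_host.cc
 * \brief Code generator that emits C source for host-side PrimFuncs
 *        (registered as the "target.build.c" build function).
 */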
#include "codegen_c_host.h"
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tvm {
namespace codegen {
/*!
 * \brief Construct the host C code generator.
 *
 * Reserves a fresh identifier, derived from the FFI library-context symbol, that is
 * later used as the module handle name in the generated C code.
 */
CodeGenCHost::CodeGenCHost() {
  this->module_name_ = this->name_supply_->FreshName(ffi::symbol::tvm_ffi_library_ctx);
}
/*!
 * \brief Reset generator state and write the C preamble to decl_stream.
 * \param output_ssa Forwarded to CodeGenC::Init.
 * \param emit_asserts Whether AssertStmt nodes are lowered to runtime checks.
 * \param emit_fwd_func_decl Whether forward declarations are emitted for external callees.
 * \param target_str Target description, recorded in a leading comment of the output.
 * \param devices Currently unused by this generator.
 */
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
                        std::string target_str, const std::unordered_set<std::string>& devices) {
  // Per-module configuration and bookkeeping.
  emit_asserts_ = emit_asserts;
  emit_fwd_func_decl_ = emit_fwd_func_decl;
  declared_globals_.clear();

  // Preamble: target banner, export macro and the headers the generated code relies on.
  decl_stream << "// tvm target: " << target_str << "\n"
              << "#define TVM_EXPORTS\n"
              << "#include \"tvm/runtime/base.h\"\n"
              << "#include \"tvm/runtime/c_backend_api.h\"\n"
              << "#include \"tvm/ffi/c_api.h\"\n"
              << "#include <math.h>\n"
              << "#include <stdbool.h>\n";

  // Qualified call on purpose: always use this class's global-context definition.
  CodeGenCHost::InitGlobalContext();
  CodeGenC::Init(output_ssa);
}
/*!
 * \brief Declare the FFI library-context global, initialized to NULL, in decl_stream.
 */
void CodeGenCHost::InitGlobalContext() {
  const auto& library_ctx_symbol = ffi::symbol::tvm_ffi_library_ctx;
  decl_stream << "void* " << library_ctx_symbol << " = NULL;\n";
}
/*!
 * \brief Declare the module-handle global (named by module_name_), initialized to NULL.
 */
void CodeGenCHost::DefineModuleName() {
  const std::string& handle_name = module_name_;
  decl_stream << "void* " << handle_name << " = NULL;\n";
}
void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false);
}
/*!
 * \brief Generate code for \p func and record its exported symbol name.
 *
 * If \p func carries the entry-function attribute, an additional `__tvm_ffi_main`
 * style wrapper (ffi::symbol::tvm_ffi_main) forwarding to the function's global
 * symbol is emitted. At most one such wrapper is emitted per generated file:
 * a second definition of the same C symbol would fail to compile.
 *
 * \param gvar The global var naming the function.
 * \param func The function to generate.
 * \param emit_fwd_func_decl Whether forward declarations are emitted for external callees.
 */
void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
                               bool emit_fwd_func_decl) {
  auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
  if (global_symbol) {
    function_names_.push_back(global_symbol.value());
  }
  emit_fwd_func_decl_ = emit_fwd_func_decl;
  CodeGenC::AddFunction(gvar, func);
  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) {
    TVM_FFI_ICHECK(global_symbol.has_value())
        << "CodeGenCHost: The entry func must have the global_symbol attribute, "
        << "but function " << gvar << " only has attributes " << func->attrs;
    function_names_.push_back(ffi::symbol::tvm_ffi_main);
    stream << "// CodegenC: NOTE: Auto-generated entry function\n";
    PrintFuncPrefix(stream);
    PrintType(func->ret_type, stream);
    stream << " " << ffi::symbol::tvm_ffi_main
           << "(void* self, void* args, int num_args, void* result) {\n";
    stream << "  return " << global_symbol.value() << "(self, args, num_args, result);\n";
    stream << "}\n";
    // Record that the wrapper exists so a later entry function cannot emit a
    // duplicate (and therefore uncompilable) definition of the same symbol.
    has_tvm_ffi_main_func_ = true;
  }
}
/*!
 * \brief Append a C forward declaration for an externally called function to fwd_decl_stream.
 *
 * Nothing is emitted when forward declarations are disabled, or when \p global_symbol is
 * one of the names already recorded by GetFunctionNames().
 *
 * \param global_symbol The callee's symbol name.
 * \param arg_types Parameter types of the declaration, in order.
 * \param ret_type Return type of the declaration.
 */
void CodeGenCHost::GenerateForwardFunctionDeclarations(ffi::String global_symbol,
                                                       const ffi::Array<Type>& arg_types,
                                                       const Type& ret_type) {
  if (!emit_fwd_func_decl_) return;

  const auto& known_names = GetFunctionNames();
  const bool already_known =
      std::any_of(known_names.begin(), known_names.end(),
                  [&global_symbol](const auto& name) { return global_symbol == name; });
  if (already_known) return;

  this->PrintFuncPrefix(fwd_decl_stream);
  this->PrintType(ret_type, fwd_decl_stream);
  fwd_decl_stream << " " << global_symbol << "(";
  const char* separator = "";
  for (const auto& arg_type : arg_types) {
    fwd_decl_stream << separator;
    CodeGenSourceBase::PrintType(arg_type, fwd_decl_stream);
    separator = ", ";
  }
  fwd_decl_stream << ");\n";
}
/*!
 * \brief Print the linkage/export prefix placed before every generated function signature.
 *
 * The prefix gives the function C linkage when compiled as C++ and marks it TVM_DLL
 * (presumably defined by the tvm/runtime/base.h header emitted in the preamble).
 */
void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // NOLINT(*)
  os << "#ifdef __cplusplus\n";
  os << "extern \"C\"\n";
  os << "#endif\n";
  os << "TVM_DLL ";
}
/*!
 * \brief Print the C spelling of a scalar or short-vector DataType.
 *
 * Handles map to `void*`, void to `void`, scalar bool to `bool`. Float (16/32/64 bits)
 * and int/uint (1/8/16/32/64 bits, with 1-bit widened to 32) map to their C names;
 * vectors of 2..16 lanes append the lane count to the scalar name.
 * Any other type raises InternalError (after any partial prefix already printed).
 */
void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  const int lanes = t.lanes();
  if (t.is_handle()) {
    TVM_FFI_ICHECK_EQ(lanes, 1) << "does not support vector types";
    os << "void*";
    return;
  }
  if (t.is_void()) {
    os << "void";
    return;
  }
  if (t == DataType::Bool()) {
    os << "bool";
    return;
  }

  // Resolve the scalar spelling; nullptr means "unsupported".
  const char* scalar_name = nullptr;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16: scalar_name = "half"; break;
      case 32: scalar_name = "float"; break;
      case 64: scalar_name = "double"; break;
      default: break;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) os << 'u';
    switch (t.bits()) {
      case 8: scalar_name = "int8_t"; break;
      case 16: scalar_name = "int16_t"; break;
      case 1:  // 1-bit integers are widened to 32 bits
      case 32: scalar_name = "int32_t"; break;
      case 64: scalar_name = "int64_t"; break;
      default: break;
    }
  }

  if (scalar_name != nullptr) {
    os << scalar_name;
    if (lanes == 1) return;
    if (lanes >= 2 && lanes <= 16) {
      os << lanes;
      return;
    }
  }
  TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to C type";
}
/*!
 * \brief Emit a broadcast as a casted, parenthesized list repeating the scalar once per lane,
 *        i.e. `((T)(v, v, ..., v))`.
 */
void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  const std::string scalar = PrintExpr(op->value);
  const int num_lanes = op->dtype.lanes();
  os << "((";
  PrintType(op->dtype, os);
  os << ")(";
  const char* separator = "";
  for (int lane = 0; lane < num_lanes; ++lane) {
    os << separator << scalar;
    separator = ", ";
  }
  os << "))";
}
void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
const std::string& packed_func_name) {
this->PrintIndent();
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\""
<< ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(get_func_env_scope);
this->PrintIndent();
this->stream << "}\n";
this->EndScope(packed_func_if_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenCHost::PrintCallPacked(const CallNode* op) {
const StringImmNode* func_name = op->args[0].as<StringImmNode>();
TVM_FFI_ICHECK(func_name != nullptr)
<< "tvm_call_[c]packed_lowered expects first argument as function name";
int64_t begin = op->args[2].as<IntImmNode>()->value;
int64_t end = op->args[3].as<IntImmNode>()->value;
int64_t num_args = end - begin;
TVM_FFI_ICHECK_GE(num_args, 0);
std::string packed_func_name;
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
packed_func_name = GetPackedName(op);
this->PrintGetFuncFromBackend(func_name->value, packed_func_name);
} else {
// directly use the original symbol
TVM_FFI_ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
}
std::string args_stack = PrintExpr(op->args[1]);
this->PrintIndent();
std::string result = name_supply_->FreshName("result");
this->stream << "TVMFFIAny " << result << ";\n";
this->PrintIndent();
// must make sure type_index is set to none
this->stream << result << ".type_index = kTVMFFINone;\n";
this->PrintIndent();
this->stream << result << ".zero_padding = 0;\n";
this->PrintIndent();
this->stream << result << ".v_int64 = 0;\n";
this->PrintIndent();
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", ";
} else {
this->stream << "if (" << packed_func_name << "(NULL, ";
}
this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", "
<< "&" << result << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
std::string CodeGenCHost::GetPackedName(const CallNode* op) {
const StringImmNode* s = op->args[0].as<StringImmNode>();
TVM_FFI_ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
std::string func_name = s->value;
std::string packed_func_name = func_name + "_packed";
std::string unique_name;
auto it = declared_globals_.find(packed_func_name);
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
unique_name = name_supply_->FreshName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
return unique_name;
}
void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = name_supply_->FreshName("stack");
const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>();
TVM_FFI_ICHECK(num != nullptr);
static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMFFIAny);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit;
} else if (type == "tvm_ffi_any") {
size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else {
TVM_FFI_THROW(InternalError) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
<< op->message.as<StringImmNode>()->value << "\", NULL);\n";
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
PrintIndent();
stream << "}\n";
}
}
void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare,
std::ostream& os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
std::ostringstream temp_b;
VisitExpr(op->b, temp_b);
std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
os << "((" << a_id << ") " << compare << " (" << b_id << ") "
<< "? (" << a_id << ") : (" << b_id << "))";
}
/*!
 * \brief Generate a C source module from an IRModule that contains only PrimFuncs.
 *
 * Functions are ordered so that those marked "runner_function" come last. Within each
 * group they are ordered by name. All functions are declared first, then defined, with
 * forward declarations enabled for external callees.
 *
 * \param mod The module to compile. Every function must be a PrimFunc.
 * \param target The target; its string form is recorded in the output, and its
 *        "constants-byte-alignment" attribute (default 16) is forwarded to the generator.
 * \return A C source module holding the generated code and its exported function names.
 */
ffi::Module BuildCHost(IRModule mod, Target target) {
  const bool output_ssa = false;
  const bool emit_asserts = false;
  const bool emit_fwd_func_decl = true;

  // Distinct device names from the optional "device_contexts" module attribute
  // (single attribute lookup instead of a check followed by a second fetch).
  std::unordered_set<std::string> devices;
  if (auto device_contexts = mod->GetAttr<ffi::Map<GlobalVar, ffi::String>>("device_contexts")) {
    for (const auto& context : device_contexts.value()) {
      devices.insert(context.second.data());
    }
  }

  CodeGenCHost cg;
  cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
  cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));

  auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
    return func->GetAttr<Bool>("runner_function", Bool(false)).value();
  };

  std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
  funcs.reserve(mod->functions.size());
  for (auto [gvar, base_func] : mod->functions) {
    TVM_FFI_ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
    funcs.emplace_back(gvar, Downcast<PrimFunc>(base_func));
  }

  // Sort functions: non-runner functions first, then by name, for deterministic output.
  auto sort_key = [&is_aot_executor_fn](const auto& kv) {
    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
  };
  std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) {
    return sort_key(kv_a) < sort_key(kv_b);
  });

  for (const auto& [gvar, prim_func] : funcs) {
    cg.DeclareFunction(gvar, prim_func);
  }
  // Codegen all functions. Passing emit_fwd_func_decl=true adds a
  // forward declaration for any `builtin::call_extern`, based on the
  // arguments provided to it.
  for (const auto& [gvar, prim_func] : funcs) {
    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
  }

  std::string code = cg.Finish();
  return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Register BuildCHost as the global function "target.build.c".
  tvm::ffi::reflection::GlobalDef().def("target.build.c", BuildCHost);
}
} // namespace codegen
} // namespace tvm