blob: 7a109a147280472c82309bcc27ee21547a3a99eb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/plugin/tvm_codegen.cc
*/
#include "tvm_codegen.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace contrib {
namespace msc {
void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) {
BasePluginCodeGen<TVMPluginCodeGenConfig>::CodeGenAttrDeclare(plugin);
const auto& attr_name = MetaAttrCls(plugin);
// exprs to meta_attr
stack_.comment("convert exprs to meta attrs method")
.func_def(attr_name + "_from_exprs", "const " + attr_name);
for (const auto& a : plugin->attrs) {
const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue";
stack_.func_arg(a->name, "const " + anno + "&");
}
// args to meta_attr
stack_.comment("convert args to meta attrs method")
.func_def(attr_name + "_from_args", "const " + attr_name)
.func_arg("args", "ffi::PackedArgs")
.func_arg("pos", "size_t&");
}
/*!
 * \brief Define the meta-attr conversion helpers declared by CodeGenAttrDeclare.
 *
 * Emits `<meta_attr>_from_exprs`, which fills a meta attr from relax exprs via
 * TVMUtils::AttrFromPrim / TVMUtils::AttrFromPrims, and `<meta_attr>_from_args`, which
 * fills it from packed args starting at `pos` and advances `pos` past every consumed arg.
 * List-typed attrs are rejected (LOG_FATAL) when generating `_from_args`.
 * \param plugin The plugin whose attrs are converted.
 */
void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) {
  const auto& attr_name = MetaAttrCls(plugin);
  // exprs to meta_attr
  stack_.func_def(attr_name + "_from_exprs", "const " + attr_name);
  for (const auto& a : plugin->attrs) {
    const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue";
    stack_.func_arg(a->name, "const " + anno + "&");
  }
  stack_.func_start().declare(attr_name, "meta_attr");
  for (const auto& a : plugin->attrs) {
    const ffi::String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim";
    stack_.func_call("TVMUtils::" + convert)
        .call_arg(a->name)
        .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
  }
  stack_.func_end("meta_attr");
  // args to meta_attr
  stack_.comment("convert args to meta attrs method")
      .func_def(attr_name + "_from_args", "const " + attr_name)
      .func_arg("args", "ffi::PackedArgs")
      .func_arg("pos", "size_t&")
      .func_start()
      .declare(attr_name, "meta_attr");
  for (const auto& a : plugin->attrs) {
    if (IsListType(a->type)) {
      // TODO(meng.tong): support list attribute. The intended codegen reads the list size
      // from args[pos] (TVMUtils::AttrFromArg), fills the attr from the following `size`
      // args (TVMUtils::AttrFromArgs), then advances pos by 1 + size.
      LOG_FATAL << "ListType argument is not supported for tvm runtime";
    } else {
      // Scalar attrs occupy exactly one packed arg.
      stack_.func_call("TVMUtils::AttrFromArg")
          .call_arg(DocUtils::ToIndex("args", "pos"))
          .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name))
          .assign("pos", "pos + 1");
    }
  }
  stack_.func_end("meta_attr");
}
void TVMPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) {
// infer struct info
stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array<TensorStructInfo>");
for (const auto& t : plugin->inputs) {
stack_.func_arg(t->name, "const Expr&");
}
for (const auto& a : plugin->attrs) {
const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue";
stack_.func_arg(a->name, "const " + anno + "&");
}
// infer layout
stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput")
.func_arg("inputs", "const ffi::Array<Expr>&")
.func_arg("var_layout_map", "const VarLayoutMap&");
}
/*!
 * \brief Define the struct-info and layout inference functions of the plugin op and
 *  register them as global functions.
 *
 * Emits two functions:
 *  - `InferStructInfo<name>` converts the relax inputs to meta tensors and the attrs to a
 *    meta attr. It runs the plugin's `infer_output` extern and converts each resulting
 *    meta to a TensorStructInfo.
 *  - `InferLayout<name>` derives input layouts from `var_layout_map` and reruns
 *    `infer_output` to obtain output layouts.
 * Both are registered as `msc.plugin.op.InferStructInfo<name>` and
 * `msc.plugin.op.InferLayout<name>` via TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF.
 * \param plugin The plugin to define inference functions for.
 */
void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
const auto& attr_name = MetaAttrCls(plugin);
// infer struct info
// Arguments passed to the infer_output extern in both generated functions.
// NOTE(review): the meaning of the trailing "false" comes from the extern's signature; confirm.
ffi::Array<ffi::String> infer_args{"input_metas", "meta_attr", "false"};
stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array<TensorStructInfo>");
for (const auto& t : plugin->inputs) {
stack_.func_arg(t->name, "const Expr&");
}
for (const auto& a : plugin->attrs) {
const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue";
stack_.func_arg(a->name, "const " + anno + "&");
}
stack_.func_start()
.comment("extract meta attrs")
.func_call(attr_name + "_from_exprs", DocUtils::ToDeclare("const auto&", "meta_attr"));
for (const auto& a : plugin->attrs) {
stack_.call_arg(a->name);
}
stack_.comment("extract meta inputs").declare("std::vector<MetaTensor>", "input_metas");
for (const auto& t : plugin->inputs) {
stack_.func_call("push_back", "", "input_metas")
.inplace_start("TVMUtils::ToMetaTensor")
.call_arg(t->name)
.inplace_end();
}
stack_.declare("std::vector<MetaTensor>", "output_metas");
CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas");
stack_.declare("ffi::Array<TensorStructInfo>", "output_sinfo");
for (size_t i = 0; i < plugin->outputs.size(); i++) {
stack_.func_call("push_back", "", "output_sinfo")
.inplace_start("TVMUtils::ToTensorStructInfo")
.call_arg(DocUtils::ToIndex("output_metas", i));
// Output device: reuse the referenced input when FindDeviceRefIdx finds one, otherwise
// convert the output's declared device string with TVMUtils::ToTVMDevice.
int device_idx = plugin->FindDeviceRefIdx(plugin->outputs[i]);
if (device_idx >= 0) {
stack_.call_arg(plugin->inputs[device_idx]->name);
} else {
stack_.inplace_start("TVMUtils::ToTVMDevice")
.call_arg(plugin->outputs[i]->device)
.inplace_end();
}
stack_.inplace_end();
}
stack_.func_end("output_sinfo");
// infer layout
stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput")
.func_arg("inputs", "const ffi::Array<Expr>&")
.func_arg("var_layout_map", "const VarLayoutMap&")
.func_start()
.comment("define attrs");
// Attr i is read from inputs[num_inputs + i], i.e. attrs follow the plugin inputs.
for (size_t i = 0; i < plugin->attrs.size(); i++) {
const auto& attr = plugin->attrs[i];
const ffi::String& anno = IsListType(attr->type) ? "Tuple" : "PrimValue";
stack_
.func_call("Downcast<" + anno + ">",
DocUtils::ToDeclare("const auto&", "attr_" + attr->name))
.call_arg(DocUtils::ToIndex("inputs", i + plugin->inputs.size()));
}
stack_.declare("ffi::Array<NLayout>", "arg_layouts")
.declare("ffi::Array<NLayout>", "output_layouts")
.comment("extract meta attrs")
.func_call(attr_name + "_from_exprs", "const " + attr_name + "& meta_attr");
for (const auto& a : plugin->attrs) {
stack_.call_arg("attr_" + a->name);
}
// For each plugin input: decide its layout from var_layout_map, record it in arg_layouts
// and build a meta tensor carrying that layout.
stack_.comment("extract meta inputs")
.declare("std::vector<MetaTensor>", "input_metas")
.for_start("i", 0, plugin->inputs.size())
.func_call("LayoutUtils::InferLayoutDecision",
DocUtils::ToDeclare("const auto&", "in_layout"))
.call_arg(DocUtils::ToIndex("inputs", "i"))
.call_arg("var_layout_map")
.func_call("push_back", "", "arg_layouts")
.call_arg("in_layout")
.func_call("push_back", "", "input_metas")
.inplace_start("TVMUtils::ToMetaTensor")
.call_arg(DocUtils::ToIndex("inputs", "i"))
.call_arg("in_layout")
.inplace_end()
.for_end()
// Attrs carry no layout: give each an empty LayoutDecision so arg_layouts holds one
// entry per input and per attr.
.comment("add fake layout for attrs")
.for_start("i", 0, plugin->attrs.size())
.func_call("push_back", "", "arg_layouts")
.inplace_start("LayoutDecision")
.call_arg(DocUtils::ToStr(""))
.inplace_end()
.for_end();
stack_.declare("std::vector<MetaTensor>", "output_metas");
CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas");
// Output layouts come from the layout_name() of each meta returned by infer_output.
// NOTE(review): the final input layouts are an empty layout followed by the nested
// arg_layouts. This presumably mirrors the (func, args tuple) operands of the
// call_dps_packed emitted by CodeGenOpBuilder; confirm.
stack_.for_start("i", 0, plugin->outputs.size())
.func_call("push_back", "", "output_layouts")
.inplace_start("LayoutDecision")
.call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("output_metas", "i"), "layout_name()"))
.inplace_end()
.for_end()
.declare("ffi::Array<NLayout>", "input_layouts")
.func_call("push_back", "", "input_layouts")
.inplace_start("LayoutDecision")
.call_arg(DocUtils::ToStr(""))
.inplace_end()
.func_call("push_back", "", "input_layouts")
.call_arg("arg_layouts")
.func_call("InferLayoutOutput", DocUtils::ToDeclare("const auto&", "infer_output"))
.call_arg("input_layouts")
.call_arg("output_layouts")
.call_arg("Attrs()");
stack_.func_end("infer_output");
// register funcs
stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name))
.call_arg("InferStructInfo" + plugin->name)
.line()
.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name))
.call_arg("InferLayout" + plugin->name)
.line();
}
void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) {
ICHECK(!plugin->externs.count("infer_buffer")) << "infer_buffer is not supported for tvm runtime";
const auto& attr_name = MetaAttrCls(plugin);
const auto& func_name = ComputeName(plugin);
ffi::String device_cond = "";
ffi::String device_index = "";
for (size_t i = 0; i < plugin->inputs.size(); i++) {
ffi::String device_type = "";
if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") {
device_type = "DLDeviceType::kDLCUDA";
} else {
device_type = "DLDeviceType::kDLCPU";
}
device_cond = device_cond + "TVMUtils::OnDevice(" + plugin->inputs[i]->name + ", " +
device_type + ")" + (i == plugin->inputs.size() - 1 ? "" : " && ");
}
stack_.func_def(func_name).func_arg("args", "ffi::PackedArgs").func_arg("ret", "ffi::Any*");
stack_.func_start().comment("define tensors");
for (size_t i = 0; i < plugin->inputs.size(); i++) {
stack_.assign(plugin->inputs[i]->name, DocUtils::ToIndex("args", i), "DLTensor*");
}
stack_.comment("extract meta attrs")
.assign("pos", plugin->inputs.size(), "size_t")
.func_call(attr_name + "_from_args", "const " + attr_name + "& meta_attr")
.call_arg("args")
.call_arg("pos");
for (size_t i = 0; i < plugin->outputs.size(); i++) {
stack_.assign(plugin->outputs[i]->name, DocUtils::ToIndex("args", "pos + " + std::to_string(i)),
"DLTensor*");
}
stack_.comment("do the compute").cond_if(device_cond);
CodeGenCompute(plugin, "cuda");
stack_.cond_else();
CodeGenCompute(plugin, "cpu");
stack_.cond_end().func_end();
// register the compute
stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED")
.call_arg(DocUtils::ToStr(plugin->name))
.call_arg(func_name)
.line();
}
/*!
 * \brief Emit the CMake body that builds the plugin against the TVM tree at `tvm_root`.
 * \param devices The devices the plugin is built for.
 */
void TVMPluginCodeGen::CodeGenCmake(const std::set<ffi::String>& devices) {
  ffi::Map<ffi::String, ffi::String> flags;
  flags.Set("PLUGIN_SUPPORT_TVM", "");
  CodeGenPreCmake(devices, flags);
  stack_.line("set(CMAKE_CXX_STANDARD 17)");
  stack_.line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-macro-redefined\")");
  stack_.line();
  stack_.line("set(TVM_ROOT " + config()->tvm_root + ")");
  stack_.line("find_library(TVM_LIB NAMES tvm HINTS ${TVM_ROOT}/build NO_DEFAULT_PATH)");
  // Header search paths: TVM itself plus the bundled 3rdparty headers.
  ffi::Array<ffi::String> includes{"${TVM_ROOT}/include", "${TVM_ROOT}/3rdparty/dmlc-core/include",
                                   "${TVM_ROOT}/3rdparty/dlpack/include",
                                   "${TVM_ROOT}/3rdparty/compiler-rt"};
  ffi::Array<ffi::String> libs;
  libs.push_back("${TVM_LIB}");
  CodeGenPostCmake(devices, includes, libs);
}
/*!
 * \brief Emit the python imports used by the generated manager, after the base imports.
 */
void TVMPluginCodeGen::CodeGenManagerDepends() {
  BasePluginCodeGen<TVMPluginCodeGenConfig>::CodeGenManagerDepends();
  const char* import_lines[] = {"from tvm import relax", "from tvm.relax import call_dps_packed",
                                "from tvm.contrib.msc.plugin import utils as plugin_utils",
                                "from tvm.contrib.msc.core import utils as msc_utils"};
  for (const char* import_line : import_lines) {
    stack_.line(import_line);
  }
  // Blank line separating the imports from the following code.
  stack_.line();
}
/*!
 * \brief Emit the manager's `setup` method.
 *
 * The method loads every file in `self._lib_folder` with ctypes.CDLL, then imports the
 * plugin op `_ffi_api` module and stores it on `self._ffi_api`.
 */
void TVMPluginCodeGen::CodeGenManagerMethods() {
  BasePluginCodeGen<TVMPluginCodeGenConfig>::CodeGenManagerMethods();
  stack_.func_def("setup").func_arg("self", "object").func_start();
  // Load each library found in the lib folder.
  stack_.for_start("lib", "os.listdir(self._lib_folder)")
      .assign("lib_file", "os.path.join(self._lib_folder, lib)")
      .func_call("CDLL", "", "ctypes")
      .call_arg("lib_file")
      .for_end();
  // The import is emitted after the libraries are loaded, presumably so functions they
  // register are visible through _ffi_api.
  stack_.line("from tvm.contrib.msc.plugin.op import _ffi_api")
      .assign(DocUtils::ToAttrAccess("self", "_ffi_api"), "_ffi_api")
      .func_end();
}
void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) {
stack_.func_def(plugin->name).func_arg("self", "object");
for (const auto& t : plugin->inputs) {
stack_.func_arg(t->name, "relax.Expr");
}
for (const auto& attr : plugin->attrs) {
stack_.func_arg(attr->name, ToPyType(attr->type), attr->default_value);
}
stack_.func_arg("name", "str", "\"" + plugin->name + "\"").func_start();
ffi::Array<ffi::String> args;
for (const auto& t : plugin->inputs) {
args.push_back(t->name);
}
for (const auto& a : plugin->attrs) {
stack_.func_call("plugin_utils.to_expr", a->name).call_arg(a->name);
args.push_back(a->name);
}
stack_.func_call("relax.Tuple", "args")
.call_arg(DocUtils::ToList(args))
.func_call("InferStructInfo" + plugin->name, "out_sinfo", "self._ffi_api");
for (const auto& t : plugin->inputs) {
stack_.call_arg(t->name);
}
for (const auto& a : plugin->attrs) {
stack_.call_arg(a->name);
}
stack_.func_call("call_dps_packed", "op")
.call_arg(DocUtils::ToStr(plugin->name))
.call_arg("args", "args")
.call_arg("list(out_sinfo)", "out_sinfo")
.func_call("msc_utils.set_expr_name", "op")
.call_arg("op")
.call_arg("name");
stack_.func_end("op").comment(GetPyComment(plugin), true);
}
/*!
 * \brief Emit the compute dispatch for one device.
 *
 * For every dtype combination of the plugin, emits an `if` branch that:
 *  - checks the runtime dtypes of the inputs;
 *  - wraps the inputs and outputs as DataTensor<dtype>;
 *  - calls the `<device>_compute` extern with (inputs..., outputs..., meta_attr), plus
 *    `stream` on cuda.
 * When the plugin has no `<device>_compute` extern, only a comment is emitted.
 * \param plugin The plugin to generate compute for.
 * \param device The device name, e.g. "cuda" or "cpu".
 */
void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) {
  if (!plugin->externs.count(device + "_compute")) {
    stack_.comment("Skip compute on " + device);
    return;
  }
  // Buffers are never forwarded to the compute extern, so reject them once, up front,
  // instead of re-checking for every dtype combination.
  ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm";
  // Declare `d_<name>` as DataTensor<dtype> converted from the DLTensor `<name>`. Inputs
  // are bound as const references. Returns the declared variable name.
  auto prepare_tensor = [this](const PluginTensor& tensor,
                               const ffi::Map<ffi::String, ffi::String>& dtypes, bool is_input) {
    const ffi::String& t_name = "d_" + tensor->name;
    const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype;
    const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">";
    const ffi::String& anno = is_input ? "const " + tensor_type + "&" : tensor_type;
    stack_.func_call("TVMUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name))
        .call_arg(tensor->name)
        .call_arg(is_input);
    return t_name;
  };
  for (const auto& dtypes : GetDtypeMatrix(plugin)) {
    const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes);
    // Guard: every input's runtime dtype matches this combination.
    ffi::String dtype_cond = "";
    for (size_t i = 0; i < plugin->inputs.size(); i++) {
      const auto& t_name = plugin->inputs[i]->name;
      dtype_cond = dtype_cond + "TVMUtils::ToMetaType(" + t_name +
                   "->dtype) == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")";
      dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
    }
    // prepare compute datas
    stack_.cond_if(dtype_cond).comment("prepare compute datas");
    ffi::Array<ffi::String> compute_args;
    for (const auto& input : plugin->inputs) {
      compute_args.push_back(prepare_tensor(input, tensor_dtypes, true));
    }
    for (const auto& output : plugin->outputs) {
      compute_args.push_back(prepare_tensor(output, tensor_dtypes, false));
    }
    compute_args.push_back("meta_attr");
    if (device == "cuda") {
      // TODO(tvm-team): update to support get stream from device id
      stack_.assign("stream", "TVMFFIEnvGetStream(kDLCUDA, 0)", "auto");
      compute_args.push_back("stream");
    }
    CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args);
    stack_.cond_end();
  }
}
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // msc.plugin.GetTVMPluginSources(codegen_config, print_config, codegen_type):
  //   "build"   -> sources for building the plugin,
  //   "manager" -> sources of the python manager,
  //   otherwise -> an empty map.
  refl::GlobalDef().def(
      "msc.plugin.GetTVMPluginSources",
      [](const ffi::String& codegen_config, const ffi::String& print_config,
         const ffi::String& codegen_type) -> ffi::Map<ffi::String, ffi::String> {
        TVMPluginCodeGen codegen(codegen_config);
        ffi::Map<ffi::String, ffi::String> sources;
        if (codegen_type == "build") {
          sources = codegen.GetBuildSources(print_config);
        } else if (codegen_type == "manager") {
          sources = codegen.GetManagerSources(print_config);
        }
        return sources;
      });
}
} // namespace msc
} // namespace contrib
} // namespace tvm