/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file src/contrib/msc/plugin/torch_codegen.cc
 * \brief Codegen for torch plugins.
 */
#include "torch_codegen.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace contrib {
namespace msc {
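
// Declare the (de)serialize helpers for the plugin's meta attribute struct.
// For a hypothetical plugin whose MetaAttrCls resolves to "MyOpAttr", the
// emitted declarations look roughly like:
//   std::vector<std::string> MyOpAttr_serialize(const MyOpAttr& meta_attr);
//   void MyOpAttr_deserialize(const std::vector<std::string>& attrs, MyOpAttr& meta_attr);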
void TorchPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) {
BasePluginCodeGen<TorchPluginCodeGenConfig>::CodeGenAttrDeclare(plugin);
const auto& attr_name = MetaAttrCls(plugin);
// serialize method for attr
stack_.comment("serialize method")
.func_def(attr_name + "_serialize", "std::vector<std::string>")
.func_arg("meta_attr", "const " + attr_name + "&");
// deserialize method for attr
stack_.comment("deserialize method")
.func_def(attr_name + "_deserialize")
.func_arg("attrs", "const std::vector<std::string>&")
.func_arg("meta_attr", attr_name + "&");
}
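
// Define the helper bodies: serialize flattens each attribute to a string via
// SerializeUtils::ToString; deserialize restores the fields by index via
// SerializeUtils::FromString.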
void TorchPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) {
const auto& attr_name = MetaAttrCls(plugin);
// serialize method for attr
stack_.func_def(attr_name + "_serialize", "std::vector<std::string>")
.func_arg("meta_attr", "const " + attr_name + "&")
.func_start()
.declare("std::vector<std::string>", "attrs");
for (const auto& a : plugin->attrs) {
stack_.func_call("push_back", "", "attrs")
.inplace_start("SerializeUtils::ToString")
.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name))
.inplace_end();
}
stack_.func_end("attrs");
// deserialize method for attr
stack_.func_def(attr_name + "_deserialize")
.func_arg("attrs", "const std::vector<std::string>&")
.func_arg("meta_attr", attr_name + "&")
.func_start();
for (size_t i = 0; i < plugin->attrs.size(); i++) {
stack_.func_call("SerializeUtils::FromString")
.call_arg(DocUtils::ToIndex("attrs", i))
.call_arg(DocUtils::ToAttrAccess("meta_attr", plugin->attrs[i]->name));
}
stack_.func_end();
}
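
// Declare the torch custom class for the plugin (constructor from serialized
// attrs, serialize() and compute() methods, meta attr/layout/name members)
// and the free entry function that forwards inputs and attributes to it.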
void TorchPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) {
stack_.struct_start(plugin->name + " : torch::CustomClassHolder");
// constructor
stack_.constructor_def(plugin->name).constructor_arg("attrs", "const std::vector<std::string>&");
// serialize method
stack_.comment("serialize method").func_def("serialize", "const std::vector<std::string>");
// compute method
stack_.comment("main compute")
.func_def("compute", "std::vector<torch::Tensor>")
.func_arg("input_tensors", "const std::vector<torch::Tensor>&");
// members
stack_.comment("members")
.declare(MetaAttrCls(plugin), "meta_attr_")
.declare("std::vector<MetaLayout>", "layouts_")
.declare("std::string", "name_");
stack_.struct_end();
// entry method
stack_.comment("Entry method for plugin " + plugin->name)
.func_def(EntryName(plugin), "std::vector<torch::Tensor>")
.func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&");
for (const auto& input : plugin->inputs) {
stack_.func_arg(input->name, "const torch::Tensor&");
}
for (const auto& a : plugin->attrs) {
stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&");
}
stack_.func_arg("name", "const std::string&");
}
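
// Define the custom class methods and bind them through TORCH_LIBRARY. The
// serialized state layout is: one string per attribute, then the plugin name,
// then one layout string per input. For a hypothetical plugin "MyOp" with
// entry "my_op_entry", the emitted binding looks roughly like:
//   TORCH_LIBRARY(MyOp, m) {
//     auto serialize = [](const c10::intrusive_ptr<MyOp>& op) { return op->serialize(); };
//     auto deserialize = [](std::vector<std::string> state) {
//       return c10::make_intrusive<MyOp>(std::move(state));
//     };
//     m.class_<MyOp>("MyOp")
//         .def(torch::init<const std::vector<std::string>>())
//         .def("compute", &MyOp::compute)
//         .def_pickle(serialize, deserialize);
//     m.def("my_op_entry", my_op_entry);
//   }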
void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
const auto& attr_name = MetaAttrCls(plugin);
// define constructor
stack_.constructor_def(plugin->name + "::" + plugin->name)
.constructor_arg("attrs", "const std::vector<std::string>&")
.constructor_start()
.comment("get attributes")
.func_call(attr_name + "_deserialize")
.call_arg("attrs")
.call_arg("meta_attr_")
.comment("get extra info")
.assign("name_", DocUtils::ToIndex("attrs", plugin->attrs.size()))
.for_start("i", 1 + plugin->attrs.size(), 1 + plugin->attrs.size() + plugin->inputs.size())
.func_call("push_back", "", "layouts_")
.inplace_start("MetaLayout")
.call_arg(DocUtils::ToIndex("attrs", "i"))
.inplace_end()
.for_end()
.constructor_end();
// define serialize
stack_.func_def(plugin->name + "::serialize", "const std::vector<std::string>")
.func_start()
.assign("attrs", attr_name + "_serialize(meta_attr_)", "std::vector<std::string>")
.func_call("push_back", "", "attrs")
.call_arg("name_")
.for_start("i", 0, plugin->inputs.size())
.func_call("push_back", "", "attrs")
.call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("layouts_", "i"), "name()"))
.for_end()
.func_end("attrs");
// compute method
stack_.func_def(plugin->name + "::compute", "std::vector<torch::Tensor>")
.func_arg("input_tensors", "const std::vector<torch::Tensor>&")
.func_start()
.declare("std::vector<torch::Tensor>", "output_tensors");
if (plugin->externs.count("infer_buffer")) {
stack_.declare("std::vector<torch::Tensor>", "buffer_tensors");
}
stack_.line()
.comment("extract meta inputs")
.declare("std::vector<MetaTensor>", "input_metas")
.for_start("i", 0, plugin->inputs.size())
.func_call("push_back", "", "input_metas")
.inplace_start("TorchUtils::ToMetaTensor")
.call_arg(DocUtils::ToIndex("input_tensors", "i"))
.call_arg(DocUtils::ToIndex("layouts_", "i"))
.inplace_end()
.for_end();
// malloc outputs and buffers
  ICHECK(plugin->externs.count("infer_output")) << "Cannot find extern infer_output";
CodeGenMalloc(plugin, plugin->outputs, "output");
if (plugin->externs.count("infer_buffer")) {
CodeGenMalloc(plugin, plugin->buffers, "buffer");
}
// do the compute
ffi::String device_cond = "";
for (size_t i = 0; i < plugin->inputs.size(); i++) {
if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") {
device_cond = device_cond + "input_tensors[" + std::to_string(i) + "].is_cuda()";
} else {
device_cond = device_cond + "!input_tensors[" + std::to_string(i) + "].is_cuda()";
}
device_cond = device_cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
}
stack_.line().comment("do the compute").cond_if(device_cond);
CodeGenCompute(plugin, "cuda");
stack_.cond_else();
CodeGenCompute(plugin, "cpu");
stack_.cond_end();
stack_.func_end("output_tensors");
  // define the entry method and bind the op
const auto& entry_name = EntryName(plugin);
stack_.func_def(entry_name, "std::vector<torch::Tensor>")
.func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&");
for (const auto& input : plugin->inputs) {
stack_.func_arg(input->name, "const torch::Tensor&");
}
for (const auto& a : plugin->attrs) {
stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&");
}
stack_.func_arg("name", "const std::string&");
stack_.func_start().declare("std::vector<torch::Tensor>", "inputs", 0, false);
for (const auto& input : plugin->inputs) {
stack_.declare_arg(input->name);
}
const auto& outputs_doc = DocUtils::ToDeclare("std::vector<torch::Tensor>", "outputs");
stack_.func_call("compute", outputs_doc, DocUtils::ToPtr("instance")).call_arg("inputs");
stack_.func_end("outputs");
stack_.comment("Bind plugin " + plugin->name + " to python")
.func_def("TORCH_LIBRARY", DocSymbol::Empty())
.func_arg(plugin->name, DocSymbol::Empty())
.func_arg("m", DocSymbol::Empty())
.func_start()
.lambda_def("serialize")
.lambda_arg("op", "const c10::intrusive_ptr<" + plugin->name + ">&")
.lambda_start()
.lambda_end(DocUtils::ToAttrAccess(DocUtils::ToPtr("op"), "serialize()"))
.lambda_def("deserialize")
.lambda_arg("state", "std::vector<std::string>")
.lambda_start()
.lambda_end("c10::make_intrusive<" + plugin->name + ">(std::move(state))")
.func_call("class_<" + plugin->name + ">", "", "m")
.call_arg(DocUtils::ToStr(plugin->name))
.method_call("def", true)
.call_arg("torch::init<const std::vector<std::string>>()")
.method_call("def", true)
.call_arg(DocUtils::ToStr("compute"))
.call_arg("&" + plugin->name + "::compute")
.method_call("def_pickle", true)
.call_arg("serialize")
.call_arg("deserialize")
.func_call("def", "", "m")
.call_arg(DocUtils::ToStr(entry_name))
.call_arg(entry_name)
.func_end();
}
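
// Generate CMake glue: set PLUGIN_SUPPORT_TORCH, locate libtorch through the
// configured torch_prefix and link against ${TORCH_LIBRARIES}.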
void TorchPluginCodeGen::CodeGenCmake(const std::set<ffi::String>& devices) {
ffi::Map<ffi::String, ffi::String> flags;
flags.Set("PLUGIN_SUPPORT_TORCH", "");
CodeGenPreCmake(devices, flags);
stack_.line()
.line("set(CMAKE_CXX_STANDARD 17)")
.line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")")
.line("find_package(Torch REQUIRED)");
ffi::Array<ffi::String> includes, libs;
libs.push_back("${TORCH_LIBRARIES}");
CodeGenPostCmake(devices, includes, libs);
}
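
// Emit python dependencies for the plugin manager, including a helper that
// flattens attribute values to strings. The generated helper is roughly:
//   def to_string(value: Any) -> str:
//       if isinstance(value, (list, tuple)):
//           str_value = ",".join([str(len(value))] + [to_string(v) for v in value])
//       elif isinstance(value, bool):
//           str_value = "1" if value else "0"
//       else:
//           str_value = str(value)
//       return str_value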
void TorchPluginCodeGen::CodeGenManagerDepends() {
BasePluginCodeGen<TorchPluginCodeGenConfig>::CodeGenManagerDepends();
stack_.line("import torch")
.line()
.func_def("to_string", "str")
.func_arg("value", "Any")
.func_start()
.switch_start("isinstance(value, (list, tuple))")
.assign("str_value", "\",\".join([str(len(value))] + [to_string(v) for v in value])")
.switch_case("isinstance(value, bool)")
.assign("str_value", "\"1\" if value else \"0\"")
.switch_case()
.assign("str_value", "str(value)")
.switch_end()
.func_end("str_value");
}
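
// Emit manager methods: libs_loaded() reports whether any library in the lib
// folder is already registered with torch.classes; setup() loads libraries
// whose names contain the project name via torch.classes.load_library and the
// remaining ones with ctypes.CDLL.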
void TorchPluginCodeGen::CodeGenManagerMethods() {
BasePluginCodeGen<TorchPluginCodeGenConfig>::CodeGenManagerMethods();
// libs_loaded method
stack_.func_def("libs_loaded")
.func_arg("self", "object")
.func_start()
.assign("loaded_libs", "set()")
.assign("loaded", DocUtils::ToDoc(false))
.for_start("lib", "torch.classes.loaded_libraries")
.func_call("add", "", "loaded_libs")
.inplace_start("os.path.basename")
.call_arg("lib")
.inplace_end()
.for_end()
.for_start("lib", "os.listdir(self._lib_folder)")
.cond_if("lib in loaded_libs")
.assign("loaded", DocUtils::ToDoc(true))
.line("break")
.cond_end()
.for_end()
.func_end("loaded");
// setup method
stack_.func_def("setup")
.func_arg("self", "object")
.func_start()
.for_start("lib", "os.listdir(self._lib_folder)")
.assign("lib_file", "os.path.join(self._lib_folder, lib)")
.cond_if("\"" + config()->project_name + "\" in lib")
.func_call("load_library", "", "torch.classes")
.call_arg("lib_file")
.cond_else()
.func_call("CDLL", "", "ctypes")
.call_arg("lib_file")
.cond_end()
.for_end()
.func_end();
}
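
// Emit the python builder: a factory that wraps the bound custom class in a
// torch.nn.Module whose forward() dispatches to the registered entry op.
// A hypothetical generated usage (with an attribute "alpha"):
//   op = MyOp(alpha=1.0, name="my_op", layouts=None)
//   outputs = op(input_tensor)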
void TorchPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) {
const auto& entry_name = EntryName(plugin);
stack_.func_def(plugin->name).func_arg("self", "object");
for (const auto& attr : plugin->attrs) {
stack_.func_arg(attr->name, attr->type, attr->default_value);
}
stack_.func_arg("name", "str", "\"" + plugin->name + "\"")
.func_arg("layouts", "List[str]", "None")
.func_start()
.class_def(plugin->name + "(torch.nn.Module)")
.class_start();
// init method
stack_.func_def("__init__").func_arg("self", "torch.nn.Module");
for (const auto& attr : plugin->attrs) {
stack_.func_arg(attr->name, attr->type, attr->default_value);
}
stack_.func_arg("name", "str", "\"" + plugin->name + "\"")
.func_arg("layouts", "List[str]", "None")
.func_start()
.func_call("__init__", "", "super()");
for (const auto& attr : plugin->attrs) {
stack_.assign(DocUtils::ToAttrAccess("self", attr->name), attr->name);
}
stack_.assign(DocUtils::ToAttrAccess("self", "name"), "name")
.cond_if("layouts is None")
.assign(DocUtils::ToAttrAccess("self", "layouts"),
"[\"\"] * " + std::to_string(plugin->inputs.size()))
.cond_else()
.assign(DocUtils::ToAttrAccess("self", "layouts"), "layouts")
.cond_end()
.line()
.assign("attr_strs", "[]");
for (const auto& attr : plugin->attrs) {
stack_.func_call("append", "", "attr_strs")
.inplace_start("to_string")
.call_arg(attr->name)
.inplace_end();
}
stack_.func_call("append", "", "attr_strs")
.call_arg("name")
.func_call("extend", "", "attr_strs")
.call_arg(DocUtils::ToAttrAccess("self", "layouts"))
.line()
.func_call(plugin->name + "." + plugin->name, "self._inner_class", "torch.classes")
.call_arg("attr_strs")
.func_end();
// forward method
stack_.func_def("forward", "List[torch.Tensor]").func_arg("self", "torch.nn.Module");
for (const auto& t : plugin->inputs) {
stack_.func_arg(t->name, "torch.Tensor");
}
stack_.func_start()
.func_call(plugin->name + "." + entry_name, "outputs", "torch.ops")
.call_arg("self._inner_class");
for (const auto& t : plugin->inputs) {
stack_.call_arg(t->name);
}
for (const auto& a : plugin->attrs) {
stack_.call_arg(DocUtils::ToAttrAccess("self", a->name));
}
stack_.call_arg(DocUtils::ToAttrAccess("self", "name"));
if (plugin->outputs.size() == 1) {
stack_.func_end(DocUtils::ToIndex("outputs", 0));
} else {
stack_.func_end("outputs");
}
// end of inner class
stack_.class_end();
stack_.func_call(plugin->name, "op");
for (const auto& attr : plugin->attrs) {
stack_.call_arg(attr->name);
}
stack_.call_arg("name").call_arg("layouts").func_end("op").comment(GetPyComment(plugin), true);
}
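
// Emit imports required by the generated relax converters.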
void TorchPluginCodeGen::CodeGenConvertDepends() {
BasePluginCodeGen<TorchPluginCodeGenConfig>::CodeGenConvertDepends();
stack_.line("from torch import fx")
.line("from tvm.relax.frontend.torch.fx_translator import TorchFXImporter")
.line();
}
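
// Emit the fx-to-relax converter: it retrieves the fx node args, splits them
// into input exprs, attribute exprs and the op name, queries the _plugin_api
// InferStructInfo helper for the output struct info, and lowers the call to
// relax call_dps_packed, returning the emitted output vars.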
const ffi::String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) {
stack_.func_def(ConverterName(plugin), "relax.Var")
.func_arg("node", "fx.node.Node")
.func_arg("ctx", "TorchFXImporter")
.func_start()
.func_call("retrieve_args", "args", "ctx")
.call_arg("node");
ffi::Array<ffi::String> args;
for (size_t i = 0; i < plugin->inputs.size(); i++) {
const auto& tensor = plugin->inputs[i];
stack_.assign(tensor->name, DocUtils::ToIndex("args", i + 1));
args.push_back(tensor->name);
}
for (size_t i = 0; i < plugin->attrs.size(); i++) {
const auto& attr = plugin->attrs[i];
stack_.func_call("plugin_utils.to_expr", attr->name)
.call_arg(DocUtils::ToIndex("args", i + plugin->inputs.size() + 1));
args.push_back(attr->name);
}
stack_.assign("name",
DocUtils::ToIndex("args", 1 + plugin->inputs.size() + plugin->attrs.size()));
stack_.func_call("relax.Tuple", "args")
.call_arg(DocUtils::ToList(args))
.func_call("InferStructInfo" + plugin->name, "out_sinfo", "_plugin_api");
for (const auto& t : plugin->inputs) {
stack_.call_arg(t->name);
}
for (const auto& a : plugin->attrs) {
stack_.call_arg(a->name);
}
stack_.func_call("call_dps_packed", "op")
.call_arg(DocUtils::ToStr(plugin->name))
.call_arg("args", "args")
.call_arg("list(out_sinfo)", "out_sinfo")
.func_call("msc_utils.set_expr_name", "op")
.call_arg("op")
.call_arg("name")
.func_call("emit", "var", "ctx.block_builder")
.call_arg("op")
.call_arg("name");
if (plugin->outputs.size() == 1) {
stack_.func_end(DocUtils::ToList(ffi::Array<ffi::String>{"var"}));
} else {
ffi::Array<ffi::String> outputs;
for (size_t i = 0; i < plugin->outputs.size(); i++) {
const auto& tensor = plugin->outputs[i];
stack_.func_call("relax.TupleGetItem", tensor->name).call_arg("var").call_arg(i);
outputs.push_back(tensor->name);
}
stack_.func_end(DocUtils::ToList(outputs));
}
return EntryName(plugin);
}
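
// Emit allocation code for output or buffer tensors: shapes come from the
// extern infer_<collect> callback, and each tensor is placed either on the
// device of a referenced input or on the device annotated on the tensor.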
void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin,
const ffi::Array<PluginTensor>& tensors,
const ffi::String& collect) {
ffi::Array<ffi::String> call_args{"input_metas", "meta_attr_", "true"};
stack_.line().comment("malloc " + collect).declare("std::vector<MetaTensor>", collect + "_metas");
CodeGenSafeCall(plugin->externs["infer_" + collect], call_args, collect + "_metas");
for (size_t i = 0; i < tensors.size(); i++) {
stack_.func_call("push_back", "", collect + "_tensors")
.inplace_start("TorchUtils::MallocTorchTensor")
.call_arg(DocUtils::ToIndex(collect + "_metas", i));
int device_idx = plugin->FindDeviceRefIdx(tensors[i]);
if (device_idx >= 0) {
const auto& input_doc = DocUtils::ToIndex("input_tensors", device_idx);
stack_.inplace_start("device", std::nullopt, input_doc).inplace_end();
} else {
stack_.inplace_start("TorchUtils::ToTorchDevice")
.call_arg(DocUtils::ToStr(tensors[i]->device))
.inplace_end();
}
stack_.inplace_end();
}
}
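
// Emit the device-specific compute body: for each supported dtype combination
// a guarded branch converts the torch tensors to DataTensor<T> wrappers and
// invokes the extern <device>_compute kernel, passing the current CUDA stream
// when targeting cuda.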
void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) {
auto prepare_tensor = [this](const PluginTensor& tensor,
const ffi::Map<ffi::String, ffi::String>& dtypes, size_t idx,
const ffi::String& collect) {
const ffi::String& t_name = "d_" + tensor->name;
const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype;
const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">";
const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type;
stack_.func_call("TorchUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name))
.call_arg(DocUtils::ToIndex(collect + "_tensors", idx))
.call_arg(DocUtils::ToIndex(collect + "_metas", idx))
.call_arg(collect == "input");
return t_name;
};
if (plugin->externs.count(device + "_compute")) {
for (const auto& dtypes : GetDtypeMatrix(plugin)) {
const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes);
ffi::Array<ffi::String> compute_args;
ffi::String dtype_cond = "";
for (size_t i = 0; i < plugin->inputs.size(); i++) {
dtype_cond = dtype_cond + "input_metas[" + std::to_string(i) +
"].data_type() == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")";
dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
}
      // prepare compute data
      stack_.cond_if(dtype_cond).comment("prepare compute data");
for (size_t i = 0; i < plugin->inputs.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input");
compute_args.push_back(t_name);
}
for (size_t i = 0; i < plugin->outputs.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output");
compute_args.push_back(t_name);
}
for (size_t i = 0; i < plugin->buffers.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer");
compute_args.push_back(t_name);
}
compute_args.push_back("meta_attr_");
if (device == "cuda") {
stack_.func_call("at::cuda::getCurrentCUDAStream",
DocUtils::ToDeclare("cudaStream_t", "stream"));
compute_args.push_back("stream");
}
CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args);
stack_.cond_end();
}
} else {
stack_.comment("Skip compute on " + device);
}
}
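
// Register the codegen entry with the ffi registry. A hypothetical python
// usage via the standard ffi lookup:
//   get_sources = tvm.get_global_func("msc.plugin.GetTorchPluginSources")
//   sources = get_sources(codegen_config, print_config, "build")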
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("msc.plugin.GetTorchPluginSources",
[](const ffi::String& codegen_config, const ffi::String& print_config,
const ffi::String& codegen_type) -> ffi::Map<ffi::String, ffi::String> {
TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config);
if (codegen_type == "build") {
return codegen.GetBuildSources(print_config);
}
if (codegen_type == "manager") {
return codegen.GetManagerSources(print_config);
}
return ffi::Map<ffi::String, ffi::String>();
});
}
} // namespace msc
} // namespace contrib
} // namespace tvm