/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/plugin/tensorrt_codegen.cc
*/
#include "tensorrt_codegen.h"
#include <tvm/ffi/reflection/registry.h>
#include <set>
namespace tvm {
namespace contrib {
namespace msc {
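// Declare helper functions for the plugin's meta attribute struct: serialize size,
// (de)serialization to a raw buffer, and conversion to/from PluginField entries.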
void TensorRTPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) {
BasePluginCodeGen<TensorRTPluginCodeGenConfig>::CodeGenAttrDeclare(plugin);
const auto& attr_name = MetaAttrCls(plugin);
// serialize size for attr
stack_.comment("serialize size").func_def(attr_name + "_serialize_size", "size_t");
// serialize method for attr
stack_.comment("serialize method")
.func_def(attr_name + "_serialize", "char*")
.func_arg("meta_attr", "const " + attr_name + "&")
.func_arg("buffer", "char*");
// deserialize method for attr
stack_.comment("deserialize method")
.func_def(attr_name + "_deserialize", "const char*")
.func_arg("meta_attr", attr_name + "&")
.func_arg("buffer", "const char*");
// attr to field
stack_.comment("meta attr to field")
.func_def(attr_name + "_to_fields")
.func_arg("fields", "std::vector<PluginField>&");
// attr from field
stack_.comment("meta attr from field")
.func_def(attr_name + "_from_fields", "const " + attr_name)
.func_arg("fields", "const PluginField*");
}
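// Define the meta attribute helpers declared above.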
void TensorRTPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) {
const auto& attr_name = MetaAttrCls(plugin);
// serialize size for attr
stack_.func_def(attr_name + "_serialize_size", "size_t").func_start().assign("size", 0, "size_t");
for (const auto& a : plugin->attrs) {
stack_.comment("attr " + a->name + "(" + a->type + ")");
if (IsListType(a->type)) {
LOG_FATAL << "attribute type " << a->type << " is not supported";
const auto& ele_type = GetEleType(a->type);
stack_.assign("size", "size + sizeof(size_t)")
.for_start("a", DocUtils::ToAttrAccess("meta_attr", a->name))
.assign("size", "size + sizeof(" + ToCppType(ele_type) + ")")
.for_end();
} else {
stack_.assign("size", "size + sizeof(" + ToCppType(a->type) + ")");
}
}
stack_.func_end("size");
// serialize method for attr
stack_.func_def(attr_name + "_serialize", "char*")
.func_arg("meta_attr", "const " + attr_name + "&")
.func_arg("buffer", "char*")
.func_start()
.assign("start", "buffer", "const char*");
for (const auto& a : plugin->attrs) {
stack_.func_call("TRTUtils::ValToBuffer")
.call_arg("buffer")
.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
}
stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected"))
.line("assert(buffer == start + expected);")
.func_end("buffer");
// deserialize method for attr
stack_.func_def(attr_name + "_deserialize", "const char*")
.func_arg("meta_attr", attr_name + "&")
.func_arg("buffer", "const char*")
.func_start()
.assign("start", "buffer", "const char*");
for (const auto& a : plugin->attrs) {
stack_.func_call("TRTUtils::ValFromBuffer")
.call_arg("buffer")
.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
}
stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected"))
.line("assert(buffer == start + expected);")
.func_end("buffer");
// attr to field
stack_.func_def(attr_name + "_to_fields")
.func_arg("fields", "std::vector<PluginField>&")
.func_start();
for (const auto& a : plugin->attrs) {
stack_.func_call("emplace_back", "", "fields")
.inplace_start("TRTUtils::ToField")
.call_arg(DocUtils::ToStr(a->name))
.call_arg(DocUtils::ToStr(a->type))
.inplace_end();
}
stack_.func_end();
// attr from field
stack_.func_def(attr_name + "_from_fields", "const " + attr_name)
.func_arg("fields", "const PluginField*")
.func_start()
.declare(attr_name, "meta_attr")
.for_start("i", 0, plugin->attrs.size());
for (size_t i = 0; i < plugin->attrs.size(); i++) {
const auto& attr = plugin->attrs[i];
const ffi::String& cond = "strcmp(fields[i].name, \"" + attr->name + "\") == 0";
if (i == 0) {
stack_.switch_start(cond);
} else {
stack_.switch_case(cond);
}
stack_.func_call("TRTUtils::FromField")
.call_arg(DocUtils::ToIndex("fields", "i"))
.call_arg(DocUtils::ToAttrAccess("meta_attr", attr->name));
}
stack_.switch_end().for_end().func_end("meta_attr");
}
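// Emit the header section of the generated op source and pull in the nvinfer1 namespace.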
void TensorRTPluginCodeGen::CodeGenOpHeader(const Plugin& plugin) {
BasePluginCodeGen<TensorRTPluginCodeGenConfig>::CodeGenOpHeader(plugin);
stack_.line("using namespace nvinfer1;").line();
}
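// Declare the plugin op classes: a static IPluginV2 variant (only when the plugin is not
// mixed precision) and a dynamic IPluginV2DynamicExt variant, each with its creator.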
void TensorRTPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) {
if (!IsMixPrecision(plugin)) {
// static plugin op
const auto& op_static = OpCls(plugin, false);
stack_.class_def(op_static + " : public IPluginV2").class_start().scope_start("public:");
CodegenOpCommonMethods(plugin, false, true);
stack_.comment("special methods for " + op_static)
.func_def("getOutputDimensions", "Dims")
.func_decorator("noexcept override")
.func_arg("index", "int")
.func_arg("in_dims", "const Dims*")
.func_arg("n_inputs", "int")
.func_def("configureWithFormat")
.func_decorator("noexcept override")
.func_arg("in_dims", "const Dims*")
.func_arg("n_inputs", "int")
.func_arg("out_dims", "const Dims*")
.func_arg("n_outputs", "int")
.func_arg("dtype", "DataType")
.func_arg("format", "PluginFormat")
.func_arg("max_batch", "int")
.func_def("supportsFormat", "bool")
.func_decorator("const noexcept override")
.func_arg("dtype", "DataType")
.func_arg("format", "PluginFormat")
.func_def("getWorkspaceSize", "size_t")
.func_decorator("const noexcept override")
.func_arg("max_batch", "int")
.func_def("enqueue", "int")
.func_decorator("noexcept override")
.func_arg("batch_size", "int")
.func_arg("inputs", "const void* const*")
.func_arg("outputs", "void* const*")
.func_arg("workspace", "void*")
.func_arg("stream", "cudaStream_t")
.scope_end();
CodegenOpMembers(plugin, false);
stack_.class_end();
// static plugin creator
CodegenCreator(plugin, false, true);
}
// dynamic plugin op
const auto& op_dynamic = OpCls(plugin, true);
stack_.class_def(op_dynamic + " : public IPluginV2DynamicExt")
.class_start()
.scope_start("public:");
CodegenOpCommonMethods(plugin, true, true);
stack_.comment("special methods for " + op_dynamic)
.func_def("getOutputDataType", "DataType")
.func_decorator("const noexcept override")
.func_arg("index", "int")
.func_arg("in_types", "const DataType*")
.func_arg("n_inputs", "int")
.func_def("getOutputDimensions", "DimsExprs")
.func_decorator("noexcept override")
.func_arg("index", "int")
.func_arg("in_dims", "const DimsExprs*")
.func_arg("n_inputs", "int")
.func_arg("builder", "IExprBuilder&")
.func_def("configurePlugin")
.func_decorator("noexcept override")
.func_arg("in_descs", "const DynamicPluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("out_descs", "const DynamicPluginTensorDesc*")
.func_arg("n_outputs", "int")
.func_def("supportsFormatCombination", "bool")
.func_decorator("noexcept override")
.func_arg("pos", "int")
.func_arg("io_desc", "const PluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("n_outputs", "int")
.func_def("getWorkspaceSize", "size_t")
.func_decorator("const noexcept override")
.func_arg("in_descs", "const PluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("out_descs", "const PluginTensorDesc*")
.func_arg("n_outputs", "int")
.func_def("enqueue", "int")
.func_decorator("noexcept override")
.func_arg("input_descs", "const PluginTensorDesc*")
.func_arg("output_descs", "const PluginTensorDesc*")
.func_arg("inputs", "const void* const*")
.func_arg("outputs", "void* const*")
.func_arg("workspace", "void*")
.func_arg("stream", "cudaStream_t")
.scope_end();
CodegenOpMembers(plugin, true);
stack_.class_end();
// dynamic plugin creator
CodegenCreator(plugin, true, true);
}
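// Define the methods of the static and dynamic plugin op classes declared above.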
void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
if (!IsMixPrecision(plugin)) {
// static op
const auto& op_static = OpCls(plugin, false);
CodegenOpCommonMethods(plugin, false, false);
// getOutputDimensions
stack_.func_def(op_static + "::getOutputDimensions", "Dims")
.func_decorator("noexcept")
.func_arg("index", "int")
.func_arg("in_dims", "const Dims*")
.func_arg("n_inputs", "int")
.func_start();
CodegenOutputInfer(plugin, false);
stack_
.func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
DocUtils::ToIndex("output_metas_", "index"))
.func_call("TRTUtils::ToDims", DocUtils::ToDeclare("Dims", "out_dims"))
.call_arg("out_shape")
.func_end("out_dims");
// configureWithFormat
stack_.func_def(op_static + "::configureWithFormat")
.func_decorator("noexcept")
.func_arg("in_dims", "const Dims*")
.func_arg("n_inputs", "int")
.func_arg("out_dims", "const Dims*")
.func_arg("n_outputs", "int")
.func_arg("dtype", "DataType")
.func_arg("format", "PluginFormat")
.func_arg("max_batch", "int")
.func_start()
.assign("dtype_", "dtype")
.line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
CodegenOutputInfer(plugin, false);
stack_.func_end();
// supportsFormat
stack_.func_def(op_static + "::supportsFormat", "bool")
.func_decorator("const noexcept")
.func_arg("dtype", "DataType")
.func_arg("format", "PluginFormat")
.func_start()
.declare("bool", "support");
size_t cnt = 0;
for (const auto& dtypes : GetDtypeMatrix(plugin)) {
const ffi::String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")";
if (cnt == 0) {
stack_.switch_start(cond);
} else {
stack_.switch_case(cond);
}
stack_.assign("support", true);
cnt++;
}
stack_.switch_case().assign("support", false).switch_end().func_end("support");
// getWorkspaceSize
stack_.func_def(op_static + "::getWorkspaceSize", "size_t")
.func_decorator("const noexcept")
.func_arg("max_batch", "int")
.func_start()
.assign("size", 0, "size_t");
if (plugin->externs.count("infer_buffer")) {
CodegenBufferInfer(plugin);
}
stack_.func_end("size");
// enqueue
stack_.func_def(op_static + "::enqueue", "int")
.func_decorator("noexcept")
.func_arg("batch_size", "int")
.func_arg("inputs", "const void* const*")
.func_arg("outputs", "void* const*")
.func_arg("workspace", "void*")
.func_arg("stream", "cudaStream_t")
.func_start();
CodegenEnqueue(plugin, false);
stack_.func_end(0);
// static creator
CodegenCreator(plugin, false, false);
}
// dynamic op
const auto& op_dynamic = OpCls(plugin, true);
CodegenOpCommonMethods(plugin, true, false);
// getOutputDataType
stack_.func_def(op_dynamic + "::getOutputDataType", "DataType")
.func_decorator("const noexcept")
.func_arg("index", "int")
.func_arg("in_types", "const DataType*")
.func_arg("n_inputs", "int")
.func_start()
.declare("DataType", "dtype");
for (size_t i = 0; i < plugin->outputs.size(); i++) {
if (i == 0) {
stack_.switch_start("index == " + std::to_string(i));
} else {
stack_.switch_case("index == " + std::to_string(i));
}
int ref = plugin->FindDtypeRefIdx(plugin->outputs[i]);
if (ref >= 0) {
stack_.assign("dtype", DocUtils::ToIndex("in_types", ref));
} else {
stack_.func_call("TRTUtils::ToDataType", "dtype")
.call_arg(DocUtils::ToStr(plugin->outputs[i]->dtype));
}
}
stack_.switch_end().func_end("dtype");
// getOutputDimensions
stack_.func_def(op_dynamic + "::getOutputDimensions", "DimsExprs")
.func_decorator("noexcept")
.func_arg("index", "int")
.func_arg("in_dims", "const DimsExprs*")
.func_arg("n_inputs", "int")
.func_arg("builder", "IExprBuilder&")
.func_start();
CodegenOutputInfer(plugin, false);
stack_
.func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
DocUtils::ToIndex("output_metas_", "index"))
.func_call("TRTUtils::ToDimsExprs", DocUtils::ToDeclare("DimsExprs", "out_dims"))
.call_arg("out_shape")
.call_arg("builder")
.func_end("out_dims");
// configurePlugin
stack_.func_def(op_dynamic + "::configurePlugin")
.func_decorator("noexcept")
.func_arg("in_descs", "const DynamicPluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("out_descs", "const DynamicPluginTensorDesc*")
.func_arg("n_outputs", "int")
.func_start()
.line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
CodegenOutputInfer(plugin, true);
stack_.func_end();
// supportsFormatCombination
stack_.func_def(op_dynamic + "::supportsFormatCombination", "bool")
.func_decorator("noexcept")
.func_arg("pos", "int")
.func_arg("io_desc", "const PluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("n_outputs", "int")
.func_start()
.declare("bool", "support");
size_t cnt = 0;
for (const auto& dtypes : GetDtypeMatrix(plugin)) {
ffi::String cond;
for (size_t i = 0; i < plugin->inputs.size(); i++) {
cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" +
dtypes.at(i) + "\")";
cond = cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
}
if (cnt == 0) {
stack_.switch_start(cond);
} else {
stack_.switch_case(cond);
}
stack_.assign("support", true);
cnt++;
}
stack_.switch_case().assign("support", false).switch_end().func_end("support");
// getWorkspaceSize
stack_.func_def(op_dynamic + "::getWorkspaceSize", "size_t")
.func_decorator("const noexcept")
.func_arg("in_descs", "const PluginTensorDesc*")
.func_arg("n_inputs", "int")
.func_arg("out_descs", "const PluginTensorDesc*")
.func_arg("n_outputs", "int")
.func_start()
.assign("size", 0, "size_t");
if (plugin->externs.count("infer_buffer")) {
CodegenBufferInfer(plugin);
}
stack_.func_end("size");
// enqueue
stack_.func_def(op_dynamic + "::enqueue", "int")
.func_decorator("noexcept")
.func_arg("input_descs", "const PluginTensorDesc*")
.func_arg("output_descs", "const PluginTensorDesc*")
.func_arg("inputs", "const void* const*")
.func_arg("outputs", "void* const*")
.func_arg("workspace", "void*")
.func_arg("stream", "cudaStream_t")
.func_start();
CodegenEnqueue(plugin, true);
stack_.func_end(0);
// dynamic creator
CodegenCreator(plugin, true, false);
}
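// Emit CMake rules that set the TensorRT version flags and locate the TensorRT headers and libraries.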
void TensorRTPluginCodeGen::CodeGenCmake(const std::set<ffi::String>& devices) {
ffi::Map<ffi::String, ffi::String> flags;
flags.Set("PLUGIN_SUPPORT_TENSORRT", "");
flags.Set("TRT_MAJOR", std::to_string(config()->version[0]));
flags.Set("TRT_MINOR", std::to_string(config()->version[1]));
flags.Set("TRT_PATCH", std::to_string(config()->version[2]));
CodeGenPreCmake(devices, flags);
stack_
.line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root +
" PATH_SUFFIXES include)")
.line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root +
" PATH_SUFFIXES lib)")
.line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-terminate\")");
ffi::Array<ffi::String> includes, libs;
includes.push_back("${TRT_INCLUDE_DIR}");
libs.push_back("${TRT_LIBS}");
CodeGenPostCmake(devices, includes, libs);
}
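// Emit the python manager methods; setup loads every library under the lib folder via ctypes.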
void TensorRTPluginCodeGen::CodeGenManagerMethods() {
BasePluginCodeGen<TensorRTPluginCodeGenConfig>::CodeGenManagerMethods();
stack_.func_def("setup")
.func_arg("self", "object")
.func_start()
.for_start("lib", "os.listdir(self._lib_folder)")
.assign("lib_file", "os.path.join(self._lib_folder, lib)")
.func_call("CDLL", "", "ctypes")
.call_arg("lib_file")
.for_end()
.func_end();
}
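// Declare (in_declare) or define the methods shared by the static and dynamic plugin classes:
// constructors, (de)serialization, type/version/namespace queries, initialize/terminate,
// destroy and clone.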
void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic,
bool in_declare) {
const auto& op_cls = OpCls(plugin, dynamic);
const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
if (in_declare) {
stack_.comment("common methods for " + op_cls);
stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&");
for (const auto& a : plugin->attrs) {
stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
}
stack_.constructor_arg("layouts", "const std::vector<std::string>&")
.constructor_def(op_cls)
.constructor_arg("name", "const std::string&")
.constructor_arg("buffer", "const void*")
.constructor_arg("length", "size_t")
.assign(op_cls + "()", "delete")
.line()
.constructor_def("~" + op_cls)
.func_def("getSerializationSize", "size_t")
.func_decorator("const noexcept override")
.func_def("serialize")
.func_decorator("const noexcept override")
.func_arg("buffer", "void*")
.func_def("getPluginType", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginVersion", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginNamespace", "const char*")
.func_decorator("const noexcept override")
.func_def("getNbOutputs", "int")
.func_decorator("const noexcept override")
.func_def("setPluginNamespace")
.func_decorator("noexcept override")
.func_arg("name_space", "const char*")
.func_def("initialize", "int")
.func_decorator("noexcept override")
.func_def("terminate")
.func_decorator("noexcept override")
.func_def("destroy")
.func_decorator("noexcept override")
.func_def("clone", plugin_cls + "*")
.func_decorator("const noexcept override");
} else {
const auto& attr_name = MetaAttrCls(plugin);
// constructor from attrs
stack_.constructor_def(op_cls + "::" + op_cls).constructor_arg("name", "const std::string&");
for (const auto& a : plugin->attrs) {
stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
}
stack_.constructor_arg("layouts", "const std::vector<std::string>&")
.constructor_start()
.assign("name_", "name");
for (const auto& a : plugin->attrs) {
stack_.assign(DocUtils::ToAttrAccess("meta_attr_", a->name), a->name);
}
stack_.line("assert(layouts.size() == " + std::to_string(plugin->inputs.size()) + ");")
.assign("layouts_", "layouts");
stack_.constructor_end();
// constructor from data
stack_.constructor_def(op_cls + "::" + op_cls)
.constructor_arg("name", "const std::string&")
.constructor_arg("buffer", "const void*")
.constructor_arg("length", "size_t")
.constructor_start()
.assign("name_", "name")
.func_call("static_cast<const char*>", DocUtils::ToDeclare("const char*", "char_buf"))
.call_arg("buffer")
.assign("start_buf", "char_buf", "const char*")
.func_call(attr_name + "_deserialize", "char_buf")
.call_arg("meta_attr_")
.call_arg("char_buf")
.func_call("TRTUtils::ValFromBuffer")
.call_arg("char_buf")
.call_arg("dtype_")
.func_call("TRTUtils::ValFromBuffer")
.call_arg("char_buf")
.call_arg("layouts_")
.line("assert(layouts_.size() == " + std::to_string(plugin->inputs.size()) + ");")
.line("assert(char_buf == (start_buf + length));")
.constructor_end();
// destructor
stack_.constructor_def(op_cls + "::~" + op_cls)
.constructor_start()
.comment("ignore destruction of " + op_cls)
.constructor_end();
// getSerializationSize
stack_.func_def(op_cls + "::getSerializationSize", "size_t")
.func_decorator("const noexcept")
.func_start()
.assign("size", attr_name + "_serialize_size()", "size_t")
.assign("size", "size + sizeof(dtype_)")
.assign("size", "size + sizeof(size_t)")
.for_start("layout", "layouts_")
.assign("size", "size + sizeof(size_t) + layout.size() * sizeof(char)")
.for_end()
.func_end("size");
// serialize
stack_.func_def(op_cls + "::serialize")
.func_decorator("const noexcept")
.func_arg("buffer", "void*")
.func_start()
.func_call("static_cast<char*>", DocUtils::ToDeclare("char*", "char_buf"))
.call_arg("buffer")
.assign("start_buf", "char_buf", "const char*")
.func_call(attr_name + "_serialize", "char_buf")
.call_arg("meta_attr_")
.call_arg("char_buf")
.func_call("TRTUtils::ValToBuffer")
.call_arg("char_buf")
.call_arg("dtype_")
.func_call("TRTUtils::ValToBuffer")
.call_arg("char_buf")
.call_arg("layouts_")
.line("assert(char_buf == (start_buf + getSerializationSize()));")
.func_end();
// getPluginType
const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
stack_.func_def(op_cls + "::getPluginType", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr(plugin_type));
// getPluginVersion
stack_.func_def(op_cls + "::getPluginVersion", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr("1"));
// getPluginNamespace
stack_.func_def(op_cls + "::getPluginNamespace", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
DocUtils::ToDoc("name_space_"))
.func_end("name");
// getNbOutputs
stack_.func_def(op_cls + "::getNbOutputs", "int")
.func_decorator("const noexcept")
.func_start()
.func_end(plugin->outputs.size());
// setPluginNamespace
stack_.func_def(op_cls + "::setPluginNamespace")
.func_decorator("noexcept")
.func_arg("name_space", "const char*")
.func_start()
.assign("name_space_", "name_space")
.func_end();
// initialize
stack_.func_def(op_cls + "::initialize", "int")
.func_decorator("noexcept")
.func_start()
.func_end(0);
// terminate
stack_.func_def(op_cls + "::terminate")
.func_decorator("noexcept")
.func_start()
.comment("Ignore teminate for " + plugin->name)
.func_end();
// destroy
stack_.func_def(op_cls + "::destroy")
.func_decorator("noexcept")
.func_start()
.line("delete this;")
.func_end();
// clone
stack_.func_def(op_cls + "::clone", plugin_cls + "*")
.func_decorator("const noexcept")
.func_start()
.func_call("new " + op_cls, DocUtils::ToDeclare(plugin_cls + "*", "plugin"))
.call_arg("name_");
for (const auto& a : plugin->attrs) {
stack_.call_arg(DocUtils::ToAttrAccess("meta_attr_", a->name));
}
stack_.call_arg("layouts_").func_end("plugin");
}
}
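// Declare the private data members of a plugin class.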
void TensorRTPluginCodeGen::CodegenOpMembers(const Plugin& plugin, bool dynamic) {
stack_.scope_start("private:")
.declare("std::string", "name_")
.declare("std::string", "name_space_")
.declare("DataType", "dtype_", 0, false)
.declare_arg("DataType::kFLOAT")
.declare(MetaAttrCls(plugin), "meta_attr_")
.declare("std::vector<std::string>", "layouts_")
.declare("std::vector<MetaTensor>", "input_metas_")
.declare("std::vector<MetaTensor>", "output_metas_");
if (plugin->externs.count("infer_buffer")) {
stack_.declare("std::vector<MetaTensor>", "buffer_metas_");
}
stack_.scope_end().line();
}
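// Declare or define the IPluginCreator class that creates the plugin from PluginFields or
// deserializes it from raw data, and register it with TensorRT.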
void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) {
const auto& creator_cls = CreatorCls(plugin, dynamic);
const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
if (in_declare) {
stack_.class_def(creator_cls + " : public IPluginCreator")
.class_start()
.scope_start("public:")
.constructor_def(creator_cls)
.func_def("getPluginName", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginVersion", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginNamespace", "const char*")
.func_decorator("const noexcept override")
.func_def("getFieldNames", "const PluginFieldCollection*")
.func_decorator("noexcept override")
.func_def("setPluginNamespace")
.func_decorator("noexcept override")
.func_arg("name_space", "const char*")
.func_def("createPlugin", plugin_cls + "*")
.func_decorator("noexcept override")
.func_arg("name", "const char*")
.func_arg("collection", "const PluginFieldCollection*")
.func_def("deserializePlugin", plugin_cls + "*")
.func_decorator("noexcept override")
.func_arg("name", "const char*")
.func_arg("data", "const void*")
.func_arg("length", "size_t")
.scope_end()
.scope_start("private:")
.declare("static PluginFieldCollection", "collection_")
.declare("static std::vector<PluginField>", "fields_")
.declare("std::string", "name_space_")
.scope_end()
.line()
.class_end();
} else {
const ffi::String& attr_name = MetaAttrCls(plugin);
// static members
stack_.comment("static members and register for " + plugin->name)
.declare("PluginFieldCollection", creator_cls + "::collection_")
.declare("std::vector<PluginField>", creator_cls + "::fields_")
.func_call("REGISTER_TENSORRT_PLUGIN")
.call_arg(creator_cls)
.line();
// constructor
stack_.constructor_def(creator_cls + "::" + creator_cls)
.constructor_start()
.func_call(attr_name + "_to_fields")
.call_arg("fields_");
for (const auto& t : plugin->inputs) {
stack_.func_call("emplace_back", "", "fields_")
.inplace_start("TRTUtils::ToField")
.call_arg(DocUtils::ToStr("layout_" + t->name))
.call_arg(DocUtils::ToStr("string"))
.inplace_end();
}
const auto& nb_fields_doc = DocUtils::ToAttrAccess("collection_", "nbFields");
const auto& fields_doc = DocUtils::ToAttrAccess("collection_", "fields");
stack_.func_call("size", nb_fields_doc, DocUtils::ToDoc("fields_"))
.func_call("data", fields_doc, DocUtils::ToDoc("fields_"))
.constructor_end();
// getPluginName
const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
stack_.func_def(creator_cls + "::getPluginName", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr(plugin_type));
// getPluginVersion
stack_.func_def(creator_cls + "::getPluginVersion", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr("1"));
// getPluginNamespace
stack_.func_def(creator_cls + "::getPluginNamespace", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
DocUtils::ToDoc("name_space_"))
.func_end("name");
// getFieldNames
stack_.func_def(creator_cls + "::getFieldNames", "const PluginFieldCollection*")
.func_decorator("noexcept")
.func_start()
.func_end("&collection_");
// setPluginNamespace
stack_.func_def(creator_cls + "::setPluginNamespace")
.func_decorator("noexcept")
.func_arg("name_space", "const char*")
.func_start()
.assign("name_space_", "name_space")
.func_end();
// createPlugin
size_t fields_size = plugin->attrs.size() + plugin->inputs.size();
const auto& op_cls = OpCls(plugin, dynamic);
stack_.func_def(creator_cls + "::createPlugin", plugin_cls + "*")
.func_decorator("noexcept")
.func_arg("name", "const char*")
.func_arg("collection", "const PluginFieldCollection*")
.func_start()
.line("assert(collection->nbFields == " + std::to_string(fields_size) + ");")
.assign("fields", DocUtils::ToAttrAccess(DocUtils::ToPtr("collection"), "fields"),
"const PluginField*")
.func_call(attr_name + "_from_fields", DocUtils::ToDeclare("const auto&", "meta_attr"))
.call_arg("fields")
.declare("std::vector<std::string>", "layouts")
.func_call("resize", "", "layouts")
.call_arg(plugin->inputs.size())
.for_start("i", plugin->attrs.size(), fields_size);
for (size_t i = 0; i < plugin->inputs.size(); i++) {
const auto& tensor = plugin->inputs[i];
const ffi::String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0";
if (i == 0) {
stack_.switch_start(cond);
} else {
stack_.switch_case(cond);
}
stack_.func_call("TRTUtils::FromField")
.call_arg(DocUtils::ToIndex("fields", "i"))
.call_arg(DocUtils::ToIndex("layouts", i));
}
stack_.switch_end()
.for_end()
.func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
.call_arg("name");
for (const auto& a : plugin->attrs) {
stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
}
stack_.call_arg("layouts")
.func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin"))
.inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_"))
.inplace_end()
.func_end("plugin");
// deserializePlugin
stack_.func_def(creator_cls + "::deserializePlugin", plugin_cls + "*")
.func_decorator("noexcept")
.func_arg("name", "const char*")
.func_arg("data", "const void*")
.func_arg("length", "size_t")
.func_start()
.func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
.call_arg("name")
.call_arg("data")
.call_arg("length")
.func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin"))
.inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_"))
.inplace_end()
.func_end("plugin");
}
}
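// Emit code that wraps the TensorRT inputs (dims or tensor descs) into MetaTensors and calls
// the extern infer_output to fill output_metas_.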
void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_desc) {
ffi::Array<ffi::String> infer_args{"input_metas_", "meta_attr_", "false"};
stack_.line("assert(n_inputs == " + std::to_string(plugin->inputs.size()) + ");")
.func_call("resize", "", "input_metas_")
.call_arg(plugin->inputs.size())
.for_start("i", 0, plugin->inputs.size())
.func_call("TRTUtils::ToMetaTensor", DocUtils::ToIndex("input_metas_", "i"));
if (as_desc) {
stack_.call_arg(DocUtils::ToIndex("in_descs", "i"));
} else {
stack_.call_arg(DocUtils::ToIndex("in_dims", "i")).call_arg("dtype_");
}
stack_.call_arg(DocUtils::ToIndex("layouts_", "i")).for_end();
CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas_");
}
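// Emit code that calls the extern infer_buffer and adds the buffer sizes to the workspace size.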
void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) {
ffi::Array<ffi::String> infer_args{"input_metas_", "meta_attr_", "false"};
CodeGenSafeCall(plugin->externs["infer_buffer"], infer_args, "buffer_metas_");
stack_.for_start("b", "buffer_metas_")
.assign("size", "size + max_batch * b.size(false)")
.for_end();
}
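// Emit the enqueue body: for each supported dtype combination, wrap the raw input/output/workspace
// pointers into DataTensors and dispatch to the extern cuda_compute.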
void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) {
ICHECK(plugin->externs.count("cuda_compute")) << "cuda_compute is needed for TensorRT plugin";
auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor,
const ffi::Map<ffi::String, ffi::String>& dtypes,
size_t idx, const ffi::String& collect) {
const ffi::String& t_name = "d_" + tensor->name;
const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype;
const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">";
const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type;
stack_.func_call("TRTUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name));
const auto& t_meta = DocUtils::ToIndex(collect + "_metas_", idx);
if (dynamic) {
stack_.call_arg(t_meta).call_arg(DocUtils::ToIndex(collect + "_descs", idx));
} else {
stack_.call_arg(t_meta).call_arg("batch_size");
}
if (collect == "input") {
stack_.call_arg(DocUtils::ToIndex("inputs", idx));
} else if (collect == "output") {
stack_.call_arg(DocUtils::ToIndex("outputs", idx));
} else {
stack_.call_arg("workspace + offset");
}
return t_name;
};
for (const auto& dtypes : GetDtypeMatrix(plugin)) {
const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes);
ffi::Array<ffi::String> compute_args;
ffi::String dtype_cond = "";
if (dynamic) {
for (size_t i = 0; i < plugin->inputs.size(); i++) {
dtype_cond = dtype_cond + "input_descs[" + std::to_string(i) +
"].type == TRTUtils::ToDataType(\"" + dtypes.at(i) + "\")";
dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
}
} else {
dtype_cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")";
}
// prepare compute data
stack_.cond_if(dtype_cond).comment("prepare compute data");
for (size_t i = 0; i < plugin->inputs.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input");
compute_args.push_back(t_name);
}
for (size_t i = 0; i < plugin->outputs.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output");
compute_args.push_back(t_name);
}
if (plugin->buffers.size() > 0) {
stack_.assign("offset", 0, "size_t");
for (size_t i = 0; i < plugin->buffers.size(); i++) {
const ffi::String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer");
compute_args.push_back(t_name);
const ffi::String& size_name = "size_" + plugin->buffers[i]->name;
stack_
.func_call("size", DocUtils::ToDeclare("size_t", size_name),
DocUtils::ToIndex("buffer_metas_", i))
.call_arg(false)
.assign("offset", "offset + batch_size * " + size_name);
}
}
compute_args.push_back("meta_attr_");
compute_args.push_back("stream");
CodeGenSafeCall(plugin->externs["cuda_compute"], compute_args);
stack_.cond_end();
}
}
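// Expose the codegen through the FFI so python can request the build or manager sources.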
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("msc.plugin.GetTensorRTPluginSources",
[](const ffi::String& codegen_config, const ffi::String& print_config,
const ffi::String& codegen_type) -> ffi::Map<ffi::String, ffi::String> {
TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config);
if (codegen_type == "build") {
return codegen.GetBuildSources(print_config);
}
if (codegen_type == "manager") {
return codegen.GetManagerSources(print_config);
}
return ffi::Map<ffi::String, ffi::String>();
});
}
} // namespace msc
} // namespace contrib
} // namespace tvm