blob: fcb1f3982f792223c5eacd2560d6e1080cb7de25 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/plugin/base_codegen.h
* \brief The codegen for Plugin.
*/
#ifndef TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_
#define TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_
#include <dmlc/json.h>
#include <tvm/script/printer/doc.h>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "../core/codegen/code_stack.h"
#include "../core/ir/plugin.h"
#include "../core/printer/cpp_printer.h"
#include "../core/printer/python_printer.h"
namespace tvm {
namespace contrib {
namespace msc {
using namespace tvm::script::printer;
/*!
* \brief CodeGen for Plugin
*/
template <typename ConfigType>
class BasePluginCodeGen {
public:
/*!
* \brief The constructor of BasePluginCodeGen
* \param config the options for codegen.
*/
explicit BasePluginCodeGen(const std::string& config = "") {
config_.reset(new ConfigType());
if (config.size() > 0) {
std::istringstream is(config);
dmlc::JSONReader reader(&is);
reader.Read(config_.get());
}
}
virtual ~BasePluginCodeGen() = default;
/*! \brief Get plugin sources*/
virtual const ffi::Map<ffi::String, ffi::String> GetBuildSources(
const std::string& print_options = "") {
ffi::Map<ffi::String, ffi::String> sources;
// plugin sources
for (const auto& name : ListPluginNames()) {
const auto& plugin = GetPlugin(name);
// attr declare
const ffi::String& attr_macro =
"TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_";
this->stack_.line("#ifndef " + attr_macro)
.line("#define " + attr_macro)
.line()
.line("#include \"plugin_utils.h\"")
.line();
StartNamespace();
CodeGenAttrDeclare(plugin);
EndNamespace();
this->stack_.line("#endif // " + attr_macro);
sources.Set(plugin->name + "_attr.h", ToCppSource(print_options));
// attr define
this->stack_.line("#include \"" + plugin->name + "_attr.h\"").line();
StartNamespace();
CodeGenAttrDefine(plugin);
EndNamespace();
sources.Set(plugin->name + "_attr.cc", ToCppSource(print_options));
// op decalre
const ffi::String& op_macro =
"TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_";
this->stack_.line("#ifndef " + op_macro).line("#define " + op_macro).line();
CodeGenOpHeader(plugin);
StartNamespace();
CodeGenOpDeclare(plugin);
EndNamespace();
this->stack_.line("#endif // " + op_macro);
sources.Set(plugin->name + "_op.h", ToCppSource(print_options));
// op define
this->stack_.line("#include \"" + plugin->name + "_op.h\"").line();
StartNamespace();
CodeGenOpDefine(plugin);
EndNamespace();
sources.Set(plugin->name + "_op.cc", ToCppSource(print_options));
// op runtime
if (this->config()->with_runtime) {
CodeGenOpHeader(plugin);
StartNamespace();
CodeGenOpRuntime(plugin);
EndNamespace();
sources.Set(plugin->name + "_runtime.cc", ToCppSource(print_options));
}
}
// cmakelists
std::set<ffi::String> devices;
for (const auto& name : ListPluginNames()) {
const auto& plugin = GetPlugin(name);
for (const auto& pair : plugin->externs) {
if (StringUtils::EndsWith(pair.first, "_compute")) {
devices.insert(StringUtils::Replace(pair.first, "_compute", ""));
}
}
}
CodeGenCmake(devices);
sources.Set("CMakeLists.txt", ToCppSource(print_options));
return sources;
}
/*! \brief Get manager sources*/
virtual const ffi::Map<ffi::String, ffi::String> GetManagerSources(
const std::string& print_options = "") {
ffi::Map<ffi::String, ffi::String> sources;
CodeGenManagerDepends();
this->stack_.class_def("PluginManager(object)").class_start();
CodeGenManagerMethods();
for (const auto& name : ListPluginNames()) {
CodeGenOpBuilder(GetPlugin(name));
}
if (this->config()->need_convert) {
ffi::Map<Plugin, ffi::String> symbols;
this->stack_.func_def("get_convert_map")
.func_decorator("classmethod")
.func_arg("cls", "object")
.func_start();
CodeGenConvertDepends();
for (const auto& name : ListPluginNames()) {
const auto& plugin = GetPlugin(name);
const auto& symbol = CodeGenOpConvert(plugin);
symbols.Set(plugin, symbol);
}
this->stack_.assign("converters", "{}");
for (const auto& pair : symbols) {
this->stack_.assign(DocUtils::ToIndex("converters", DocUtils::ToStr(pair.second)),
ConverterName(pair.first));
}
this->stack_.func_end("converters");
}
this->stack_.class_end();
sources.Set("manager.py", ToPySource(print_options));
return sources;
}
protected:
/*! \brief Header of plugin files*/
virtual void CodeGenOpHeader(const Plugin& plugin) {
this->stack_.line("#include \"" + plugin->name + "_attr.h\"");
std::set<ffi::String> include_headers;
for (const auto& pair : plugin->externs) {
if (pair.second->header.size() > 0 && !include_headers.count(pair.second->header)) {
this->stack_.line("#include \"" + pair.second->header + "\"");
include_headers.insert(pair.second->header);
}
}
this->stack_.line();
}
/*! \brief Start the namespace*/
void StartNamespace() {
this->stack_.line("namespace tvm {")
.line("namespace contrib {")
.line("namespace msc {")
.line("namespace plugin {")
.line();
}
/*! \brief End the namespace*/
void EndNamespace() {
this->stack_.line("} // namespace plugin")
.line("} // namespace msc")
.line("} // namespace contrib")
.line("} // namespace tvm");
}
/*! \brief Codegen safe call extern*/
void CodeGenSafeCall(const PluginExtern& extern_func,
const ffi::Array<ffi::String>& call_args = ffi::Array<ffi::String>(),
const ffi::String& ret = "") {
this->stack_.scope_start("try {").func_call(extern_func->name, ret);
for (const auto& arg : call_args) {
this->stack_.call_arg(arg);
}
this->stack_.scope_end()
.scope_start("} catch (const std::exception& exc) {")
.line("std::cerr << \"Failed to run extern " + extern_func->name +
" : \" << exc.what() << std::endl;")
.line("throw std::runtime_error(\"Failed to run extern " + extern_func->name + "\");")
.scope_end()
.line("}");
}
/*! \brief Codegen plugin attr declare*/
virtual void CodeGenAttrDeclare(const Plugin& plugin) {
this->stack_.struct_start(MetaAttrCls(plugin)).comment("define attributes");
for (const auto& attr : plugin->attrs) {
this->stack_.declare(ToCppType(attr->type), attr->name);
if (attr->default_value.size() > 0) {
this->stack_.declare_arg(attr->default_value);
}
}
this->stack_.line()
.comment("print method")
.func_def("operator<<", "friend std::ostream&")
.func_arg("out", "std::ostream&")
.func_arg("attrs", "const " + MetaAttrCls(plugin) + "&")
.func_start()
.line("out << \"[" + MetaAttrCls(plugin) + "] : \";");
for (const auto& attr : plugin->attrs) {
this->stack_.line("out << \"| " + attr->name + "(" + attr->type + ")=\" << attrs." +
attr->name + ";");
}
this->stack_.func_end("out").struct_end();
}
/*! \brief Codegen plugin attr define*/
virtual void CodeGenAttrDefine(const Plugin& plugin) {}
/*! \brief Codegen plugin op declare*/
virtual void CodeGenOpDeclare(const Plugin& plugin) = 0;
/*! \brief Codegen plugin op define*/
virtual void CodeGenOpDefine(const Plugin& plugin) = 0;
/*! \brief Codegen plugin runtime*/
virtual void CodeGenOpRuntime(const Plugin& plugin) {}
/*! \brief Codegen cmake file*/
virtual void CodeGenCmake(const std::set<ffi::String>& devices) {
CodeGenPreCmake(devices);
CodeGenPostCmake(devices);
}
/*! \brief Codegen cmake start*/
void CodeGenPreCmake(const std::set<ffi::String>& devices,
const ffi::Map<ffi::String, ffi::String>& extra_flags =
ffi::Map<ffi::String, ffi::String>()) {
const auto& p_name = this->config()->project_name;
stack_.line("cmake_minimum_required(VERSION " + this->config()->cmake_version + " FATAL_ERROR)")
.line("project(" + p_name + ")");
if (devices.count("cuda")) {
stack_.line("find_package(CUDA)").line("add_definitions(-DPLUGIN_ENABLE_CUDA)");
}
stack_.line();
for (const auto& pair : extra_flags) {
if (pair.second.size() == 0) {
stack_.line("add_definitions(-D" + pair.first + ")");
} else {
stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")");
}
}
for (const auto& pair : this->config()->flags) {
if (pair.second.size() == 0) {
stack_.line("add_definitions(-D" + pair.first + ")");
} else {
stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")");
}
}
stack_.line();
}
/*! \brief Codegen cmake end*/
void CodeGenPostCmake(const std::set<ffi::String>& devices,
const ffi::Array<ffi::String>& extra_includes = ffi::Array<ffi::String>(),
const ffi::Array<ffi::String>& extra_libs = ffi::Array<ffi::String>()) {
const auto& p_name = this->config()->project_name;
stack_.line()
.line("file(GLOB_RECURSE PLUGIN_HEADERS src/*.h)")
.line("file(GLOB_RECURSE PLUGIN_CC_SRCS src/*.cc)");
if (devices.count("cuda")) {
stack_.line("file(GLOB_RECURSE PLUGIN_CU_SRCS src/*.cu)");
}
if (devices.count("cuda")) {
stack_.line("cuda_add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS} ${PLUGIN_CU_SRCS})");
} else {
stack_.line("add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS})");
}
// define includes
ffi::String includes = StringUtils::Join(extra_includes, " ");
if (this->config()->includes.size() > 0) {
includes = includes + " " + StringUtils::Join(this->config()->includes, " ");
}
if (includes.size() > 0) {
stack_.line("target_include_directories(" + p_name + " PUBLIC " + includes + ")");
}
// define libs
ffi::String link_libs = StringUtils::Join(extra_libs, " ");
const auto& libs = StringUtils::Join(this->config()->libs, " ");
if (libs.size() > 0) {
link_libs = link_libs + " " + libs;
}
if (link_libs.size() > 0) {
stack_.line("target_link_libraries(" + p_name + " " + link_libs + ")");
}
const auto& install_dir = this->config()->install_dir;
if (install_dir.size() > 0) {
stack_.line()
.line("SET(LIBRARY_OUTPUT_PATH " + install_dir + "/lib)")
.line("file(COPY ${PLUGIN_HEADERS} DESTINATION " + install_dir + "/include)");
if (this->config()->libs.size() > 0) {
stack_.line("file(COPY " + libs + " DESTINATION " + install_dir + "/lib)");
}
}
}
/*! \brief Codegen manager depends*/
virtual void CodeGenManagerDepends() {
this->stack_.line("import os")
.line("import shutil")
.line("import ctypes")
.line("from typing import Any, List, Dict")
.line();
}
/*! \brief Codegen manager methods*/
virtual void CodeGenManagerMethods() {
// init method
stack_.func_def("__init__")
.func_arg("self", "object")
.func_arg("root", "str", "None")
.func_start()
.cond_if("root is None")
.assign("root", "os.path.dirname(__name__)")
.cond_end()
.assign(DocUtils::ToAttrAccess("self", "_lib_folder"), "os.path.join(root, \"lib\")")
.func_call("assert")
.inplace_start("os.path.isdir")
.call_arg(DocUtils::ToAttrAccess("self", "_lib_folder"))
.inplace_end()
.assign(DocUtils::ToAttrAccess("self", "_include_folder"),
"os.path.join(root, \"include\")")
.func_call("assert")
.inplace_start("os.path.isdir")
.call_arg(DocUtils::ToAttrAccess("self", "_include_folder"))
.inplace_end()
.assign(DocUtils::ToAttrAccess("self", "_manager_file"),
"os.path.join(root, \"manager.py\")")
.func_call("assert")
.inplace_start("os.path.isfile")
.call_arg(DocUtils::ToAttrAccess("self", "_manager_file"))
.inplace_end()
.func_call("setup", "", "self")
.func_end();
// list headers
this->stack_.func_def("list_includes")
.func_arg("self", "object")
.func_arg("as_abs", "bool", "False")
.func_start()
.assign("includes", "[]")
.for_start("f", "os.listdir(self._include_folder)")
.cond_if("as_abs")
.func_call("append", "", "includes")
.inplace_start("os.path.join")
.call_arg(DocUtils::ToAttrAccess("self", "_include_folder"))
.call_arg("f")
.inplace_end()
.cond_else()
.func_call("append", "", "includes")
.call_arg("f")
.cond_end()
.for_end()
.func_end("includes");
// copy the headers
this->stack_.func_def("copy_includes")
.func_arg("self", "object")
.func_arg("dst", "str")
.func_start()
.cond_if("not os.path.isdir(dst)")
.func_call("makedirs", "", "os")
.call_arg("dst")
.cond_end()
.for_start("header", "os.listdir(self._include_folder)")
.func_call("shutil.copyfile")
.inplace_start("os.path.join")
.call_arg(DocUtils::ToAttrAccess("self", "_include_folder"))
.call_arg("header")
.inplace_end()
.inplace_start("os.path.join")
.call_arg("dst")
.call_arg("header")
.inplace_end()
.for_end()
.func_end();
// list libs
this->stack_.func_def("list_libs")
.func_arg("self", "object")
.func_arg("as_abs", "bool", "False")
.func_start()
.assign("libs", "[]")
.for_start("f", "os.listdir(self._lib_folder)")
.cond_if("as_abs")
.func_call("append", "", "libs")
.inplace_start("os.path.join")
.call_arg(DocUtils::ToAttrAccess("self", "_lib_folder"))
.call_arg("f")
.inplace_end()
.cond_else()
.func_call("append", "", "libs")
.call_arg("f")
.cond_end()
.for_end()
.func_end("libs");
// copy the libs
this->stack_.func_def("copy_libs")
.func_arg("self", "object")
.func_arg("dst", "str")
.func_start()
.cond_if("not os.path.isdir(dst)")
.func_call("makedirs", "", "os")
.call_arg("dst")
.cond_end()
.for_start("lib", "os.listdir(self._lib_folder)")
.func_call("shutil.copyfile")
.inplace_start("os.path.join")
.call_arg(DocUtils::ToAttrAccess("self", "_lib_folder"))
.call_arg("lib")
.inplace_end()
.inplace_start("os.path.join")
.call_arg("dst")
.call_arg("lib")
.inplace_end()
.for_end()
.func_end();
// export method
this->stack_.func_def("export")
.func_arg("self", "object")
.func_arg("dst", "str")
.func_start()
.func_call("copy_includes", "", "self")
.inplace_start("os.path.join")
.call_arg("dst")
.call_arg(DocUtils::ToStr("include"))
.inplace_end()
.func_call("copy_libs", "", "self")
.inplace_start("os.path.join")
.call_arg("dst")
.call_arg(DocUtils::ToStr("lib"))
.inplace_end()
.func_call("shutil.copyfile")
.call_arg(DocUtils::ToAttrAccess("self", "_manager_file"))
.inplace_start("os.path.join")
.call_arg("dst")
.call_arg(DocUtils::ToStr("manager.py"))
.inplace_end()
.func_end();
// get op names
this->stack_.func_def("get_op_names", "List[str]")
.func_arg("self", "object")
.func_start()
.assign("names", "[]");
for (const auto& name : ListPluginNames()) {
this->stack_.func_call("append", "", "names").call_arg(DocUtils::ToStr(name));
}
this->stack_.func_end("names");
// get ops info
this->stack_.func_def("get_ops_info", "dict")
.func_arg("self", "object")
.func_start()
.assign("info", "{}");
for (const auto& name : ListPluginNames()) {
ICHECK(this->config()->ops_info.count(name)) << "Can not find op info for " << name;
const auto& info = this->config()->ops_info[name];
this->stack_.assign(DocUtils::ToIndex("info", DocUtils::ToStr(name)), info);
}
this->stack_.func_end("info");
}
/*! \brief Codegen manager for plugin*/
virtual void CodeGenOpBuilder(const Plugin& plugin) {}
/*! \brief Codegen convert depends*/
virtual void CodeGenConvertDepends() {
this->stack_.line("from tvm import relax")
.line("from tvm.relax import call_dps_packed")
.line("from tvm.contrib.msc.plugin import utils as plugin_utils")
.line("from tvm.contrib.msc.plugin.op import _ffi_api as _plugin_api")
.line("from tvm.contrib.msc.core import utils as msc_utils")
.line();
}
/*! \brief Codegen convert function for plugin*/
virtual const ffi::String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; }
/*! \brief Change code stack to cpp source*/
const ffi::String ToCppSource(const std::string& print_options = "") {
CppPrinter printer(print_options);
for (const auto& d : this->stack_.GetDocs()) {
printer.Append(d);
}
this->stack_.Reset();
return printer.GetString();
}
/*! \brief Change code stack to python source*/
const ffi::String ToPySource(const std::string& print_options = "") {
PythonPrinter printer(print_options);
for (const auto& d : this->stack_.GetDocs()) {
printer.Append(d);
}
this->stack_.Reset();
return printer.GetString();
}
std::vector<std::unordered_map<int, ffi::String>> GetDtypeMatrix(const Plugin& plugin) {
std::vector<std::unordered_map<int, ffi::String>> matrix;
if (plugin->support_dtypes.size() == 0) {
std::unordered_map<int, ffi::String> dtypes;
for (size_t i = 0; i < plugin->inputs.size(); i++) {
dtypes[i] = plugin->inputs[i]->dtype;
}
matrix.push_back(dtypes);
} else {
ffi::Array<ffi::String> templates;
ffi::Array<ffi::Array<ffi::String>> condidates;
for (const auto& pair : plugin->support_dtypes) {
templates.push_back(pair.first);
condidates.push_back(pair.second);
}
for (const auto& t_dtypes : ArrayUtils::Product(condidates)) {
std::unordered_map<int, ffi::String> dtypes;
for (size_t i = 0; i < templates.size(); i++) {
for (size_t in_idx = 0; in_idx < plugin->inputs.size(); in_idx++) {
if (plugin->inputs[in_idx]->dtype == templates[i]) {
dtypes[in_idx] = t_dtypes[i];
}
}
}
for (size_t i = 0; i < plugin->inputs.size(); i++) {
if (dtypes.count(i)) {
continue;
}
dtypes[i] = plugin->inputs[i]->dtype;
}
matrix.push_back(dtypes);
}
}
return matrix;
}
const ffi::Map<ffi::String, ffi::String> GetTensorDtypes(
const Plugin& plugin, const std::unordered_map<int, ffi::String>& dtypes) {
ffi::Map<ffi::String, ffi::String> tensor_dtypes;
for (const auto& pair : dtypes) {
const ffi::String& ref_dtype = plugin->inputs[pair.first]->dtype;
for (const auto& t : plugin->inputs) {
if (t->dtype == ref_dtype) {
tensor_dtypes.Set(t->name, pair.second);
}
}
for (const auto& t : plugin->outputs) {
if (t->dtype == ref_dtype) {
tensor_dtypes.Set(t->name, pair.second);
}
}
for (const auto& t : plugin->buffers) {
if (t->dtype == ref_dtype) {
tensor_dtypes.Set(t->name, pair.second);
}
}
}
return tensor_dtypes;
}
/*! \brief Change plugin comment in python*/
const ffi::String GetPyComment(const Plugin& plugin) {
ffi::String comment = "Python wrapper for " + plugin->name + "\nInputs\n------";
for (const auto& t : plugin->inputs) {
comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe;
}
comment = comment + "\nOutputs\n-------";
for (const auto& t : plugin->outputs) {
comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe;
}
if (plugin->attrs.size() > 0) {
comment = comment + "\nAttributes\n-----------";
for (const auto& a : plugin->attrs) {
comment = comment + "\n" + a->name + ": " + ToPyType(a->type) + "\n " + a->describe;
}
}
return comment;
}
/*! \brief Get class name for meta attrs*/
const ffi::String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; }
/*! \brief Get converter name for plugin*/
const ffi::String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; }
/*! \brief Check if the type is list type. */
bool IsListType(const ffi::String& type) { return StringUtils::StartsWith(type, "list"); }
/*! \brief Get type of element. */
const ffi::String GetEleType(const ffi::String& type) {
if (!IsListType(type)) {
return "";
}
return StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", "");
}
/*! \brief Type name in cpp*/
virtual const ffi::String ToCppType(const ffi::String& type) {
if (IsListType(type)) {
const auto& ele_type = GetEleType(type);
return "std::vector<" + ToCppType(ele_type) + ">";
}
if (type == "int64") {
return "int64_t";
}
if (type == "int32" || type == "int") {
return "int32_t";
}
if (type == "int8") {
return "int8_t";
}
if (type == "string") {
return "std::string";
}
return type;
}
/*! \brief Type name in python*/
virtual const ffi::String ToPyType(const ffi::String& type) {
if (IsListType(type)) {
const auto& ele_type = GetEleType(type);
return "List[" + ToPyType(ele_type) + "]";
}
if (type == "int64" || type == "int32" || type == "int" || type == "int8") {
return "int";
}
if (type == "string") {
return "str";
}
return type;
}
/*!
* \brief Compare version with version in config
* 0 for same version, 1 for greater version, -1 for less version
*/
int CompareVersion(size_t major, size_t minor, size_t patch) {
return CommonUtils::CompareVersion(this->config()->version, {major, minor, patch});
}
/*! \brief The config of plugin codegen*/
const std::shared_ptr<ConfigType> config() { return config_; }
/*! \brief The stack of codes*/
CodeStack stack_;
private:
std::shared_ptr<ConfigType> config_;
};
} // namespace msc
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_