blob: 838b3948bf01fe376ab4f91cbd4e1185861a82ce [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/core/ir/plugin.h
* \brief Plugin description structures for MSC.
*/
#ifndef TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_
#define TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_
#include <dmlc/json.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/data_layout.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "../../../../node/attr_registry.h"
#include "../utils.h"
namespace tvm {
namespace contrib {
namespace msc {
/*!
 * \brief JSON serialization helper for a single plugin attribute.
 *
 * Save() emits the four fields as one JSON object; Load() reads them back and
 * insists that both "name" and "type" were present.
 */
struct JsonPluginAttr {
  std::string name;
  std::string type;
  std::string default_value;
  std::string describe;

  /*!
   * \brief Write the attribute as a JSON object.
   * \param writer The json writer receiving the keys name, type, default_value and describe.
   */
  void Save(dmlc::JSONWriter* writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("name", name);
    writer->WriteObjectKeyValue("type", type);
    writer->WriteObjectKeyValue("default_value", default_value);
    writer->WriteObjectKeyValue("describe", describe);
    writer->EndObject();
  }

  /*!
   * \brief Read the attribute from a JSON object.
   * \param reader The json reader positioned at the object.
   *
   * Fails the ICHECK when "name" or "type" is missing. An empty describe is
   * replaced by a text built from name and type.
   */
  void Load(dmlc::JSONReader* reader) {
    // Bit 1 is set once "name" was read, bit 2 once "type" was read.
    int bitmask = 0;
    reader->BeginObject();
    std::string key;
    while (reader->NextObjectItem(&key)) {
      if (key == "describe") {
        reader->Read(&describe);
      } else if (key == "default_value") {
        reader->Read(&default_value);
      } else if (key == "type") {
        reader->Read(&type);
        bitmask |= 2;
      } else if (key == "name") {
        reader->Read(&name);
        bitmask |= 1;
      }
    }
    ICHECK_EQ(bitmask, 1 | 2) << "name and type should be given for plugin attr";
    if (!describe.empty()) {
      return;
    }
    // No description supplied: derive one from the mandatory fields.
    describe = "Plugin attribute " + name + "(" + type + ")";
  }
};
/*!
 * \brief JSON serialization helper for a single plugin tensor.
 *
 * ndim starts at -1 and device at "default"; Load() keeps those values for
 * keys that are absent from the JSON object. Only "name" is mandatory.
 */
struct JsonPluginTensor {
  std::string name;
  std::string dtype;
  int64_t ndim{-1};
  std::string device{"default"};
  std::string describe;

  /*!
   * \brief Write the tensor as a JSON object.
   * \param writer The json writer receiving the keys name, dtype, ndim, device and describe.
   */
  void Save(dmlc::JSONWriter* writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("name", name);
    writer->WriteObjectKeyValue("dtype", dtype);
    writer->WriteObjectKeyValue("ndim", ndim);
    writer->WriteObjectKeyValue("device", device);
    writer->WriteObjectKeyValue("describe", describe);
    writer->EndObject();
  }

  /*!
   * \brief Read the tensor from a JSON object.
   * \param reader The json reader positioned at the object.
   *
   * Fails the ICHECK when "name" is missing. An empty describe is replaced by
   * a text built from name, dtype and device.
   */
  void Load(dmlc::JSONReader* reader) {
    // Bit 1 is set once "name" was read.
    int bitmask = 0;
    reader->BeginObject();
    std::string key;
    while (reader->NextObjectItem(&key)) {
      if (key == "describe") {
        reader->Read(&describe);
      } else if (key == "device") {
        reader->Read(&device);
      } else if (key == "ndim") {
        reader->Read(&ndim);
      } else if (key == "dtype") {
        reader->Read(&dtype);
      } else if (key == "name") {
        reader->Read(&name);
        bitmask |= 1;
      }
    }
    ICHECK_EQ(bitmask, 1) << "name should be given for plugin tensor";
    if (!describe.empty()) {
      return;
    }
    // No description supplied: derive one from the loaded fields.
    describe = "Plugin tensor " + name + "(" + dtype + " on " + device + ")";
  }
};
/*!
 * \brief JSON serialization helper for a single plugin extern.
 *
 * Carries the name of the extern together with its header, source and lib
 * strings. Only "name" is mandatory when loading.
 */
struct JsonPluginExtern {
  std::string name;
  std::string header;
  std::string source;
  std::string lib;
  std::string describe;

  /*!
   * \brief Write the extern as a JSON object.
   * \param writer The json writer receiving the keys name, header, source, lib and describe.
   */
  void Save(dmlc::JSONWriter* writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("name", name);
    writer->WriteObjectKeyValue("header", header);
    writer->WriteObjectKeyValue("source", source);
    writer->WriteObjectKeyValue("lib", lib);
    writer->WriteObjectKeyValue("describe", describe);
    writer->EndObject();
  }

  /*!
   * \brief Read the extern from a JSON object.
   * \param reader The json reader positioned at the object.
   *
   * Fails the ICHECK when "name" is missing. An empty describe is replaced by
   * a text built from name and header.
   */
  void Load(dmlc::JSONReader* reader) {
    // Bit 1 is set once "name" was read.
    int bitmask = 0;
    reader->BeginObject();
    std::string key;
    while (reader->NextObjectItem(&key)) {
      if (key == "describe") {
        reader->Read(&describe);
      } else if (key == "lib") {
        reader->Read(&lib);
      } else if (key == "source") {
        reader->Read(&source);
      } else if (key == "header") {
        reader->Read(&header);
      } else if (key == "name") {
        reader->Read(&name);
        bitmask |= 1;
      }
    }
    ICHECK_EQ(bitmask, 1) << "name should be given for plugin extern";
    if (!describe.empty()) {
      return;
    }
    // No description supplied: derive one from name and header.
    describe = "Plugin function " + name + "(from " + header + ")";
  }
};
/*!
* \brief Json serialize and deserialize for Plugin.
*/
struct JsonPlugin {
std::string name;
std::string version;
std::string describe;
std::vector<JsonPluginAttr> attrs;
std::vector<JsonPluginTensor> inputs;
std::vector<JsonPluginTensor> outputs;
std::vector<JsonPluginTensor> buffers;
std::unordered_map<std::string, JsonPluginExtern> externs;
std::unordered_map<std::string, std::vector<std::string>> support_dtypes;
std::unordered_map<std::string, std::string> options;
void Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("version", version);
writer->WriteObjectKeyValue("describe", describe);
writer->WriteObjectKeyValue("attrs", attrs);
writer->WriteObjectKeyValue("inputs", inputs);
writer->WriteObjectKeyValue("outputs", outputs);
writer->WriteObjectKeyValue("buffers", buffers);
writer->WriteObjectKeyValue("externs", externs);
writer->WriteObjectKeyValue("support_dtypes", support_dtypes);
writer->WriteObjectKeyValue("options", options);
writer->EndObject();
}
void Load(dmlc::JSONReader* reader) {
int bitmask = 0;
std::string key;
reader->BeginObject();
while (reader->NextObjectItem(&key)) {
if (key == "name") {
reader->Read(&name);
bitmask |= 1;
} else if (key == "version") {
reader->Read(&version);
} else if (key == "describe") {
reader->Read(&describe);
} else if (key == "attrs") {
reader->Read(&attrs);
} else if (key == "inputs") {
reader->Read(&inputs);
bitmask |= 2;
} else if (key == "outputs") {
reader->Read(&outputs);
bitmask |= 4;
} else if (key == "buffers") {
reader->Read(&buffers);
} else if (key == "externs") {
reader->Read(&externs);
} else if (key == "support_dtypes") {
reader->Read(&support_dtypes);
} else if (key == "options") {
reader->Read(&options);
}
}
ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, inputs and outputs should be given for plugin";
if (externs.size() > 0) {
ICHECK(externs.count("infer_output")) << "infer_output should be given as extern";
bool has_compute = false;
for (const auto& pair : externs) {
if (StringUtils::EndsWith(pair.first, "_compute")) {
has_compute = true;
}
}
ICHECK(has_compute) << "No compute function found, please check";
}
if (describe.size() == 0) {
describe = "Plugin " + name + "(" + version + ")";
}
}
};
/*!
* \brief Attribute in Plugin.
*
* Stores the name, type, default value and description of one plugin
* attribute as ffi::String fields, and declares conversions to and from
* JsonPluginAttr.
*/
class PluginAttrNode : public Object {
public:
/*! \brief The name of the attribute. */
ffi::String name;
/*! \brief The type of the attribute. */
ffi::String type;
/*! \brief The default value of the attribute. */
ffi::String default_value;
/*! \brief The description of the attribute. */
ffi::String describe;
/*! \brief Export the attribute to its json struct. */
const JsonPluginAttr ToJson() const;
/*! \brief Load the attribute from a json struct. */
void FromJson(const JsonPluginAttr& j_attr);
/*! \brief Load the attribute from a json string. */
void FromJson(const std::string& json_str);
/*! \brief Register the four fields with refl::ObjectDef through def_ro. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PluginAttrNode>()
.def_ro("name", &PluginAttrNode::name)
.def_ro("type", &PluginAttrNode::type)
.def_ro("default_value", &PluginAttrNode::default_value)
.def_ro("describe", &PluginAttrNode::describe);
}
/*! \brief Structural equality/hash kind declared for this node type. */
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
// Declares the object info with the type key "msc.core.PluginAttr".
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginAttr", PluginAttrNode, Object);
};
/*!
* \brief Managed reference to PluginAttrNode.
* \sa PluginAttrNode
*/
class PluginAttr : public ObjectRef {
public:
/*!
* \brief The constructor.
* \param name The name of the attribute.
* \param type The type of the attribute.
* \param default_value The default value of the attribute.
* \param describe The description of the attribute.
*/
TVM_DLL PluginAttr(const ffi::String& name, const ffi::String& type,
const ffi::String& default_value, const ffi::String& describe);
/*!
* \brief The json constructor.
* \param j_attr The json description of the attribute.
* \note NOTE(review): not declared explicit, so implicit conversion from
* JsonPluginAttr is allowed — confirm this is intended.
*/
TVM_DLL PluginAttr(const JsonPluginAttr& j_attr);
/*!
* \brief The json constructor.
* \param json_str The json description of the attribute, as a string.
* \note NOTE(review): not declared explicit, so implicit conversion from
* std::string is allowed — confirm this is intended.
*/
TVM_DLL PluginAttr(const std::string& json_str);
// Reference boilerplate generated by the FFI macro (nullable variant).
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginAttr, ObjectRef, PluginAttrNode);
};
/*!
* \brief Tensor in Plugin.
*
* Stores the name, dtype, ndim, device and description of one plugin tensor,
* and declares conversions to and from JsonPluginTensor.
*/
class PluginTensorNode : public Object {
public:
/*! \brief The name of the tensor. */
ffi::String name;
/*! \brief The dtype of the tensor. */
ffi::String dtype;
/*! \brief The ndim of the tensor. */
Integer ndim;
/*! \brief The device of the tensor. */
ffi::String device;
/*! \brief The description of the tensor. */
ffi::String describe;
/*! \brief Export the tensor to its json struct. */
const JsonPluginTensor ToJson() const;
/*! \brief Load the tensor from a json struct. */
void FromJson(const JsonPluginTensor& j_attr);
/*! \brief Load the tensor from a json string. */
void FromJson(const std::string& json_str);
/*! \brief Register the five fields with refl::ObjectDef through def_ro. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PluginTensorNode>()
.def_ro("name", &PluginTensorNode::name)
.def_ro("dtype", &PluginTensorNode::dtype)
.def_ro("ndim", &PluginTensorNode::ndim)
.def_ro("device", &PluginTensorNode::device)
.def_ro("describe", &PluginTensorNode::describe);
}
/*! \brief Structural equality/hash kind declared for this node type. */
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
// Declares the object info with the type key "msc.core.PluginTensor".
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginTensor", PluginTensorNode, Object);
};
/*!
* \brief Managed reference to PluginTensorNode.
* \sa PluginTensorNode
*/
class PluginTensor : public ObjectRef {
public:
/*!
* \brief The constructor.
* \param name The name of the tensor.
* \param dtype The dtype of the tensor.
* \param ndim The ndim of the tensor.
* \param device The device of the tensor.
* \param describe The description of the tensor.
*/
TVM_DLL PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim,
const ffi::String& device, const ffi::String& describe);
/*!
* \brief The json constructor.
* \param j_tensor The json description of the tensor.
* \note NOTE(review): not declared explicit, so implicit conversion from
* JsonPluginTensor is allowed — confirm this is intended.
*/
TVM_DLL PluginTensor(const JsonPluginTensor& j_tensor);
/*!
* \brief The json constructor.
* \param json_str The json description of the tensor, as a string.
* \note NOTE(review): not declared explicit, so implicit conversion from
* std::string is allowed — confirm this is intended.
*/
TVM_DLL PluginTensor(const std::string& json_str);
// Reference boilerplate generated by the FFI macro (nullable variant).
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginTensor, ObjectRef, PluginTensorNode);
};
/*!
* \brief Extern symbol in Plugin.
*
* Stores the name, header, source, lib and description of one extern, and
* declares conversions to and from JsonPluginExtern.
*/
class PluginExternNode : public Object {
public:
/*! \brief The name of the extern. */
ffi::String name;
/*! \brief The header of the extern. */
ffi::String header;
/*! \brief The source of the extern. */
ffi::String source;
/*! \brief The lib of the extern. */
ffi::String lib;
/*! \brief The description of the extern. */
ffi::String describe;
/*! \brief Export the extern to its json struct. */
const JsonPluginExtern ToJson() const;
/*! \brief Load the extern from a json struct. */
void FromJson(const JsonPluginExtern& j_attr);
/*! \brief Load the extern from a json string. */
void FromJson(const std::string& json_str);
/*! \brief Register the five fields with refl::ObjectDef through def_ro. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PluginExternNode>()
.def_ro("name", &PluginExternNode::name)
.def_ro("header", &PluginExternNode::header)
.def_ro("source", &PluginExternNode::source)
.def_ro("lib", &PluginExternNode::lib)
.def_ro("describe", &PluginExternNode::describe);
}
/*! \brief Structural equality/hash kind declared for this node type. */
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
// Declares the object info with the type key "msc.core.PluginExtern".
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginExtern", PluginExternNode, Object);
};
/*!
* \brief Managed reference to PluginExternNode.
* \sa PluginExternNode
*/
class PluginExtern : public ObjectRef {
public:
/*!
* \brief The constructor.
* \param name The name of the extern.
* \param header The header of the extern.
* \param source The source of the extern.
* \param lib The lib of the extern.
* \param describe The description of the extern.
*/
TVM_DLL PluginExtern(const ffi::String& name, const ffi::String& header,
const ffi::String& source, const ffi::String& lib,
const ffi::String& describe);
/*!
* \brief The json constructor.
* \param j_extern The json description of the extern.
* \note NOTE(review): not declared explicit, so implicit conversion from
* JsonPluginExtern is allowed — confirm this is intended.
*/
TVM_DLL PluginExtern(const JsonPluginExtern& j_extern);
/*!
* \brief The json constructor.
* \param json_str The json description of the extern, as a string.
* \note NOTE(review): not declared explicit, so implicit conversion from
* std::string is allowed — confirm this is intended.
*/
TVM_DLL PluginExtern(const std::string& json_str);
// Reference boilerplate generated by the FFI macro (nullable variant).
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginExtern, ObjectRef, PluginExternNode);
};
/*!
* \brief The Plugin in MSC.
*
* Aggregates the attributes, input/output/buffer tensors, externs, supported
* dtypes and options of one plugin, and declares conversions to and from
* JsonPlugin.
*/
class PluginNode : public Object {
public:
/*! \brief The name of the plugin. */
ffi::String name;
/*! \brief The version of the plugin. */
ffi::String version;
/*! \brief The description of the plugin. */
ffi::String describe;
/*! \brief The attributes of the plugin. */
ffi::Array<PluginAttr> attrs;
/*! \brief The inputs of the plugin. */
ffi::Array<PluginTensor> inputs;
/*! \brief The outputs of the plugin. */
ffi::Array<PluginTensor> outputs;
/*! \brief The buffers of the plugin. */
ffi::Array<PluginTensor> buffers;
/*! \brief The externs of the plugin, keyed by string. */
ffi::Map<ffi::String, PluginExtern> externs;
/*! \brief The supported dtypes of the plugin, as string lists keyed by string. */
ffi::Map<ffi::String, ffi::Array<ffi::String>> support_dtypes;
/*! \brief The options of the plugin, as string key/value pairs. */
ffi::Map<ffi::String, ffi::String> options;
/*! \brief Export the plugin to its json struct. */
const JsonPlugin ToJson() const;
/*! \brief Load the plugin from a json struct. */
void FromJson(const JsonPlugin& j_attr);
/*! \brief Load the plugin from a json string. */
void FromJson(const std::string& json_str);
/*!
* \brief Find input ref index for dtype.
* \param tensor The tensor to look up.
* \return The index as int; NOTE(review): the value used when nothing matches
* is defined in the implementation, not in this header — confirm.
*/
int FindDtypeRefIdx(const PluginTensor& tensor) const;
/*!
* \brief Find input ref index for device.
* \param tensor The tensor to look up.
* \return The index as int; NOTE(review): the value used when nothing matches
* is defined in the implementation, not in this header — confirm.
*/
int FindDeviceRefIdx(const PluginTensor& tensor) const;
/*! \brief Register all ten fields with refl::ObjectDef through def_ro. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PluginNode>()
.def_ro("name", &PluginNode::name)
.def_ro("version", &PluginNode::version)
.def_ro("describe", &PluginNode::describe)
.def_ro("attrs", &PluginNode::attrs)
.def_ro("inputs", &PluginNode::inputs)
.def_ro("outputs", &PluginNode::outputs)
.def_ro("buffers", &PluginNode::buffers)
.def_ro("externs", &PluginNode::externs)
.def_ro("support_dtypes", &PluginNode::support_dtypes)
.def_ro("options", &PluginNode::options);
}
/*! \brief Structural equality/hash kind declared for this node type. */
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
// Declares the object info with the type key "msc.core.Plugin".
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.Plugin", PluginNode, Object);
};
/*!
* \brief Managed reference to PluginNode.
* \sa PluginNode
*/
class Plugin : public ObjectRef {
public:
/*!
* \brief The constructor.
* \param name The name of the plugin.
* \param version The version of the plugin.
* \param describe The description of the plugin.
* \param attrs The attrs of the plugin.
* \param inputs The inputs of the plugin.
* \param outputs The outputs of the plugin.
* \param buffers The buffers of the plugin.
* \param externs The externs of the plugin.
* \param support_dtypes The supported dtypes of the plugin.
* \param options The options of the plugin.
*/
TVM_DLL Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe,
const ffi::Array<PluginAttr>& attrs, const ffi::Array<PluginTensor>& inputs,
const ffi::Array<PluginTensor>& outputs, const ffi::Array<PluginTensor>& buffers,
const ffi::Map<ffi::String, PluginExtern>& externs,
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& support_dtypes,
const ffi::Map<ffi::String, ffi::String>& options);
/*!
* \brief The json constructor.
* \param j_plugin The json description of the plugin.
* \note NOTE(review): not declared explicit, so implicit conversion from
* JsonPlugin is allowed — confirm this is intended.
*/
TVM_DLL Plugin(const JsonPlugin& j_plugin);
/*!
* \brief The json constructor.
* \param json_str The json description of the plugin, as a string.
* \note NOTE(review): not declared explicit, so implicit conversion from
* std::string is allowed — confirm this is intended.
*/
TVM_DLL Plugin(const std::string& json_str);
// Reference boilerplate generated by the FFI macro (nullable variant).
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Plugin, ObjectRef, PluginNode);
};
/*!
 * \brief Name-indexed store of Plugin objects.
 *
 * A single process-wide instance is handed out by Global(). The class does no
 * locking of its own.
 */
class PluginRegistry {
 public:
  /*!
   * \brief Build a plugin from its json string and store it under a name.
   * \param name The name of the item.
   * \param json_str The json_str.
   * \return Always true.
   * \note A plugin already stored under the same name is replaced.
   */
  bool Register(const ffi::String& name, const ffi::String& json_str) {
    plugin_map_.insert_or_assign(name, Plugin(json_str));
    return true;
  }

  /*!
   * \brief Check if a plugin is registered.
   * \param name The name of the item.
   * \return Whether the plugin is registered.
   */
  bool Registered(const ffi::String& name) const { return plugin_map_.count(name) != 0; }

  /*!
   * \brief Get a plugin from the registry.
   * \param name The name of the item.
   * \return The corresponding plugin; the ICHECK fails when the name is unknown.
   */
  const Plugin Get(const ffi::String& name) const {
    auto it = plugin_map_.find(name);
    ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name;
    return it->second;
  }

  /*!
   * \brief List all the plugin names in the registry.
   * \return The plugin names, in the iteration order of the underlying map.
   */
  ffi::Array<ffi::String> ListAllNames() const {
    ffi::Array<ffi::String> names;
    for (auto entry = plugin_map_.begin(); entry != plugin_map_.end(); ++entry) {
      names.push_back(entry->first);
    }
    return names;
  }

  /*!
   * \return a global singleton of the registry.
   */
  static PluginRegistry* Global() {
    // Allocated on first use and never deleted.
    static PluginRegistry* const registry = new PluginRegistry();
    return registry;
  }

 private:
  // map from name to plugins.
  std::unordered_map<ffi::String, Plugin> plugin_map_;
};
/*!
* \brief List all plugin names.
* \return The plugin names.
* \note NOTE(review): presumably backed by PluginRegistry::Global(); the
* definition is not in this header — confirm.
*/
const ffi::Array<ffi::String> ListPluginNames();
/*!
* \brief Get the registered plugin.
* \param name The name of the plugin.
* \return The corresponding plugin.
* \note NOTE(review): behavior for an unknown name is decided by the
* definition, which is not in this header — confirm.
*/
const Plugin GetPlugin(const ffi::String& name);
/*!
* \brief Check if a plugin is registered.
* \param name The name of the item.
* \return Whether the plugin is registered.
*/
bool IsPlugin(const ffi::String& name);
} // namespace msc
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_