blob: 5d51e9a2b3a1f2aea9a3c1985be8c4db4ed8db4f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lib_api.cc
* \brief APIs to interact with libraries
* This API specifies function prototypes to
* register custom ops, partitioner, and passes
* for library authors
* See example/extension/lib_custom_op/README.md
* See example/extension/lib_subgraph/README.md
* See example/extension/lib_pass/README.md
*/
#include "mxnet/lib_api.h"
// Singleton accessor: returns the process-wide error-message registry.
// C++11 magic-static makes construction thread-safe, but the accessors
// below do not lock; concurrent add() calls are unsynchronized.
mxnet::ext::MXerrorMsgs* mxnet::ext::MXerrorMsgs::get() {
  static MXerrorMsgs inst;
  return &inst;
}
// Start a new message prefixed with "<file>[<line>]: " and return a stream
// the caller appends to (this is what the MX_ERROR_MSG macro uses).
// NOTE(review): the returned reference points into 'messages'; if that
// container is a std::vector a later add() could invalidate it — confirm
// the container type in lib_api.h.
std::stringstream& mxnet::ext::MXerrorMsgs::add(const char* file, int line) {
  messages.emplace_back();
  messages.back() << file << "[" << line << "]: ";
  return messages.back();
}
// Number of messages recorded so far.
int mxnet::ext::MXerrorMsgs::size() {
  return messages.size();
}
// Return a heap-allocated *copy* of message 'idx'; ownership transfers to
// the caller across the C ABI. Throws std::out_of_range for a bad index.
const std::string* mxnet::ext::MXerrorMsgs::get(int idx) {
  return new std::string(messages.at(idx).str());
}
// Default context: sentinel "error" device / id -1 so uninitialized use is visible.
mxnet::ext::MXContext::MXContext() : dev_type("error"), dev_id(-1) {}
// Construct from an owned device-type string (e.g. "cpu"/"gpu") and device id.
mxnet::ext::MXContext::MXContext(std::string dev_type_, int dev_id_)
    : dev_type(std::move(dev_type_)), dev_id(dev_id_) {}
// Convenience overload for C-string device types coming across the C ABI.
mxnet::ext::MXContext::MXContext(const char* dev_type_, int dev_id_)
    : dev_type(dev_type_), dev_id(dev_id_) {}
/*! \brief factory helpers for the common contexts; the zero-argument
 * variants simply delegate to device id 0. */
mxnet::ext::MXContext mxnet::ext::MXContext::CPU(int dev_id) {
  return MXContext("cpu", dev_id);
}
mxnet::ext::MXContext mxnet::ext::MXContext::GPU(int dev_id) {
  return MXContext("gpu", dev_id);
}
mxnet::ext::MXContext mxnet::ext::MXContext::CPU() {
  return MXContext::CPU(0);
}
mxnet::ext::MXContext mxnet::ext::MXContext::GPU() {
  return MXContext::GPU(0);
}
// Populate this sparse view from raw buffers handed across the C ABI.
// Nothing is copied: the pointers must outlive this object.
// CSR layout: 'idx' holds num_idx column indices and 'idx_ptr' holds the
// num_idx_ptr row-pointer entries. Row-sparse layout: 'idx' holds row
// indices and 'idx_ptr' is null.
void mxnet::ext::MXSparse::set(void* data_ptr,
                               const int64_t* dims,
                               int ndims,
                               void* idx,
                               int64_t num_idx,
                               void* idx_ptr,
                               int64_t num_idx_ptr) {
  data = data_ptr;
  // If CSR, num of non-zero elements is num_idx,
  // If row sparse, num of elements is num_idx * width.
  data_len = num_idx;
  if (!idx_ptr) {
    // row-sparse: each stored row spans the product of the trailing dims
    for (int i = 1; i < ndims; ++i)
      data_len *= dims[i];
  }
  indices = reinterpret_cast<int64_t*>(idx);
  indices_len = num_idx;
  if (idx_ptr) {
    indptr = reinterpret_cast<int64_t*>(idx_ptr);
    indptr_len = num_idx_ptr;
  }
}
// Default tensor: no data, unset dtype, default (dense) storage.
mxnet::ext::MXTensor::MXTensor()
    : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
// Copy constructor: shallow-copies the data pointer (no deep copy of the
// underlying buffer) and rebuilds the cached DLTensor view.
mxnet::ext::MXTensor::MXTensor(const MXTensor& oth)
    : data_ptr(oth.data_ptr),
      shape(oth.shape),
      dtype(oth.dtype),
      verID(oth.verID),
      ctx(oth.ctx),
      stype(oth.stype) {
  setDLTensor();
}
// Construct a tensor view over externally owned data; vID is the version
// identifier later compared by isSame().
mxnet::ext::MXTensor::MXTensor(void* data_ptr,
                               std::vector<int64_t> shape,
                               MXDType dtype,
                               size_t vID,
                               MXContext mx_ctx,
                               MXStorageType stype)
    : data_ptr(data_ptr),
      shape(std::move(shape)),
      dtype(dtype),
      verID(vID),
      ctx(std::move(mx_ctx)),
      stype(stype) {
  setDLTensor();
}
/*! \brief (re)bind this tensor to externally owned data and metadata,
 * then refresh the cached DLTensor view. */
void mxnet::ext::MXTensor::setTensor(void* dptr,
                                     MXDType type,
                                     const int64_t* dims,
                                     int ndims,
                                     size_t vID,
                                     MXContext mx_ctx,
                                     MXStorageType storage_type) {
  data_ptr = dptr;
  dtype = type;
  verID = vID;
  ctx = mx_ctx;
  stype = storage_type;
  // replace the previous shape with the incoming dims in one step
  shape.assign(dims, dims + ndims);
  setDLTensor();
}
// Rebuild the cached DLPack view (dltensor) from this tensor's current
// data pointer, shape, context, and dtype. Must be called after any
// metadata change. Unknown device strings fall back to kDLExtDev; an
// unknown dtype throws std::runtime_error.
void mxnet::ext::MXTensor::setDLTensor() {
  dltensor.data = data_ptr;
  dltensor.ndim = shape.size();
  dltensor.shape = const_cast<int64_t*>(shape.data());
  dltensor.strides = nullptr;  // null strides == compact row-major layout
  dltensor.byte_offset = 0;
  dltensor.dtype.lanes = 1;  // scalar elements, no vector lanes
  dltensor.ctx.device_id = ctx.dev_id;
  // translate the device-type string into DLPack's device enum
  if (ctx.dev_type == "cpu")
    dltensor.ctx.device_type = kDLCPU;
  else if (ctx.dev_type == "gpu")
    dltensor.ctx.device_type = kDLGPU;
  else if (ctx.dev_type == "opencl")
    dltensor.ctx.device_type = kDLOpenCL;
  else if (ctx.dev_type == "vulcan")  // NOTE(review): likely a misspelling of
                                      // "vulkan"; kept as-is to match producers
    dltensor.ctx.device_type = kDLVulkan;
  else if (ctx.dev_type == "metal")
    dltensor.ctx.device_type = kDLMetal;
  else if (ctx.dev_type == "vpi")
    dltensor.ctx.device_type = kDLVPI;
  else if (ctx.dev_type == "rocm")
    dltensor.ctx.device_type = kDLROCM;
  else
    dltensor.ctx.device_type = kDLExtDev;  // anything else is an extension device
  // translate MXDType into DLPack's (code, bits) pair
  switch (dtype) {
    case kFloat32:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 32;
      break;
    case kFloat64:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 64;
      break;
    case kFloat16:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 16;
      break;
    case kUint8:
      dltensor.dtype.code = kDLUInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt32:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 32;
      break;
    case kInt8:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt64:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 64;
      break;
    default:
      // leave the dtype zeroed before throwing so a caught exception does
      // not leave a stale (code, bits) combination behind
      dltensor.dtype.code = 0;
      dltensor.dtype.bits = 0;
      throw std::runtime_error(
          "Error! Invalid dtype flag: " + std::to_string(static_cast<int>(dtype)) +
          " when constructing MXTensor");
  }
}
int64_t mxnet::ext::MXTensor::size() const {
int64_t size = 1;
for (auto& s : shape)
size *= s;
return size;
}
bool mxnet::ext::MXTensor::isSame(const MXTensor& oth) const {
return data_ptr == oth.data_ptr && dtype == oth.dtype && verID == oth.verID &&
ctx.dev_type == oth.ctx.dev_type && ctx.dev_id == oth.ctx.dev_id && shape == oth.shape &&
stype == oth.stype;
}
/*! \brief PassResource bundles the tensor allocator MXNet provides to graph
 * passes: the new-args/new-aux maps to fill plus the ndarray-malloc callback. */
mxnet::ext::PassResource::PassResource(std::unordered_map<std::string, MXTensor>* new_args,
                                       std::unordered_map<std::string, MXTensor>* new_aux,
                                       nd_malloc_t nd_malloc,
                                       const void* nd_alloc)
    : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {}
namespace {
/*! \brief shared implementation for alloc_arg/alloc_aux (previously duplicated):
 * asks MXNet to allocate a buffer via nd_malloc (is_arg==1 for args, 0 for aux,
 * matching the original call sites), wraps it in a dense MXTensor stored under
 * 'name' in *dest, and returns a pointer to the stored tensor. */
mxnet::ext::MXTensor* allocParam(std::unordered_map<std::string, mxnet::ext::MXTensor>* dest,
                                 mxnet::ext::nd_malloc_t nd_malloc,
                                 const void* nd_alloc,
                                 const std::string& name,
                                 const std::vector<int64_t>& shapes,
                                 const mxnet::ext::MXContext& ctx,
                                 mxnet::ext::MXDType dtype,
                                 int is_arg) {
  void* data;
  nd_malloc(nd_alloc,
            shapes.data(),
            shapes.size(),
            ctx.dev_type.c_str(),
            ctx.dev_id,
            dtype,
            name.c_str(),
            is_arg,
            &data);
  (*dest)[name] = mxnet::ext::MXTensor(data, shapes, dtype, 0, ctx, mxnet::ext::kDefaultStorage);
  return &(dest->at(name));
}
}  // namespace
/*! \brief allocate a new arg tensor for a graph pass; the tensor is owned by
 * the new-args map and the returned pointer stays valid while it lives there */
mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_arg(const std::string& name,
                                                          const std::vector<int64_t>& shapes,
                                                          const mxnet::ext::MXContext& ctx,
                                                          mxnet::ext::MXDType dtype) const {
  return allocParam(new_args_, nd_malloc_, nd_alloc_, name, shapes, ctx, dtype, 1);
}
/*! \brief allocate a new aux tensor for a graph pass; same ownership rules as
 * alloc_arg but stored in the new-aux map */
mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_aux(const std::string& name,
                                                          const std::vector<int64_t>& shapes,
                                                          const mxnet::ext::MXContext& ctx,
                                                          mxnet::ext::MXDType dtype) const {
  return allocParam(new_aux_, nd_malloc_, nd_alloc_, name, shapes, ctx, dtype, 0);
}
// OpResource: per-invocation callbacks MXNet hands to a custom op's
// Forward/Backward — workspace allocators, the CUDA stream, the sparse
// allocator, and RNG state pointers.
mxnet::ext::OpResource::OpResource(xpu_malloc_t cpu_malloc_fp,
                                   void* cpu_alloc_fp,
                                   xpu_malloc_t gpu_malloc_fp,
                                   void* gpu_alloc_fp,
                                   void* stream,
                                   sparse_malloc_t sparse_malloc_fp,
                                   void* sparse_alloc_fp,
                                   void* rng_cpu_states,
                                   void* rng_gpu_states)
    : cpu_malloc(cpu_malloc_fp),
      gpu_malloc(gpu_malloc_fp),
      cpu_alloc(cpu_alloc_fp),
      gpu_alloc(gpu_alloc_fp),
      cuda_stream(stream),
      sparse_malloc(sparse_malloc_fp),
      sparse_alloc(sparse_alloc_fp),
      rand_cpu_states(rng_cpu_states),
      rand_gpu_states(rng_gpu_states) {}
// Allocate 'size' bytes of CPU workspace via MXNet's allocator.
void* mxnet::ext::OpResource::alloc_cpu(int size) const {
  return cpu_malloc(cpu_alloc, size);
}
// Allocate 'size' bytes of GPU workspace via MXNet's allocator.
void* mxnet::ext::OpResource::alloc_gpu(int size) const {
  return gpu_malloc(gpu_alloc, size);
}
// Ask MXNet to (re)allocate the buffers of output sparse tensor 'index';
// the sparse struct's data/indices/indptr pointers are filled in place.
void mxnet::ext::OpResource::alloc_sparse(mxnet::ext::MXSparse* sparse,
                                          int index,
                                          int indices_len,
                                          int indptr_len) const {
  sparse_malloc(sparse_alloc,
                index,
                indices_len,
                indptr_len,
                &(sparse->data),
                &(sparse->indices),
                &(sparse->indptr));
}
// Opaque CPU RNG state pointer, cast to the engine type declared in lib_api.h.
mxnet::ext::mx_cpu_rand_t* mxnet::ext::OpResource::get_cpu_rand_states() const {
  return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
}
/*! \brief extract the index-th "[...]" entry from a composite shape string,
 * e.g. getShapeAt("[[1,2],[3,4]]", 1) yields "[3,4]" */
std::string mxnet::ext::getShapeAt(const std::string& shape, unsigned index) {
  int start = 1;  // start at 1 to skip the outermost '['
  // hop across 'index' opening brackets to reach the requested entry
  for (unsigned hops = 0; hops < index; ++hops)
    start = shape.find('[', start + 1);
  // the entry runs through the next closing bracket, inclusive
  int end = shape.find(']', start);
  return shape.substr(start, end - start + 1);
}
/*! \brief extract the index-th entry from a comma-separated dtype string,
 * e.g. getDtypeAt("[0,1,2]", 1) yields "1" */
std::string mxnet::ext::getDtypeAt(const std::string& dtype, unsigned index) {
  int begin = 0;
  // hop over 'index' commas to land on the separator before the entry
  for (unsigned hops = 0; hops < index; ++hops)
    begin = dtype.find(',', begin + 1);
  // the entry ends at the next comma, or the closing bracket for the last one
  int end = dtype.find(',', begin + 1);
  if (end == -1)
    end = dtype.find(']', begin + 1);
  return dtype.substr(begin + 1, end - begin - 1);
}
// Default JSON value: ERR type marks "no value / parse failure".
mxnet::ext::JsonVal::JsonVal() : type(ERR), num(-1), str("") {}
// Empty value of a given type (containers start empty, num/str unset).
mxnet::ext::JsonVal::JsonVal(mxnet::ext::JsonType t) : type(t), num(-1), str("") {}
// String value.
mxnet::ext::JsonVal::JsonVal(std::string s) : type(STR), num(-1), str(std::move(s)) {}
// Number value; the decimal rendering is kept in 'str' for dump()/toString().
mxnet::ext::JsonVal::JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {}
// Fully explicit constructor used internally.
mxnet::ext::JsonVal::JsonVal(JsonType t, int n, std::string s)
    : type(t), num(n), str(std::move(s)) {}
// Ordering so JsonVal can be used as a std::map key.
// NOTE(review): for LIST and MAP this is NOT a strict weak ordering — two
// equal lists compare "less" in both directions, and lists of different
// sizes compare "not less" in both directions. It appears safe in this file
// only because map keys are STR values (see fromJson/toJson); confirm
// before keying maps on LIST/MAP values.
bool mxnet::ext::JsonVal::operator<(const mxnet::ext::JsonVal& o) const {
  // for string JSON objects compare the string
  if (type == STR)
    return type == o.type && str < o.str;
  // for number JSON objects compare the number
  if (type == NUM)
    return type == o.type && num < o.num;
  // for list JSON objects, compare the size of list, and then each object in the list
  if (type == LIST) {
    if (list.size() != o.list.size())
      return false;
    for (unsigned int i = 0; i < list.size(); i++)
      if (list[i] < o.list[i])
        return false;  // if we find an object that doesnt match return
    return true;       // all objects in lists matched
  }
  // for map JSON objects, compare the size of map, and then each key/value in the maps
  if (type == MAP) {
    if (map.size() != o.map.size())
      return false;
    for (auto& item : map) {
      // if one map is missing a key in another return
      if (o.map.find(item.first) == o.map.end())
        return false;
      if (item.second < o.map.at(item.first))
        return false;
    }
    return true;
  }
  // mixed types fall back to comparing the type tags themselves
  return type < o.type;
}
// Serialize this value to JSON-ish text (note: MAP entries are rendered as
// `key : value` with spaces, exactly as fromString expects to re-parse).
// The output format is consumed by MXNet; do not change the literals.
std::string mxnet::ext::JsonVal::dump() const {
  std::string ret;
  switch (type) {
    case ERR:
      ret = "json(Error)";
      break;
    case STR:
      ret = "\"" + str + "\"";
      break;
    case NUM:
      ret = str;
      break;
    case LIST:
      // comma-separate entries; the size()-1 expression is unsigned but the
      // loop body is never reached for an empty list, so no underflow occurs
      ret = "[";
      for (unsigned i = 0; i < list.size(); i++) {
        auto& item = list[i];
        ret += item.dump();
        if (i < list.size() - 1)
          ret += ",";
      }
      ret += "]";
      break;
    case MAP:
      // 'cnt' is declared in the last case label, so no initialization is
      // jumped over and this compiles without an enclosing block
      ret = "{";
      unsigned cnt = 0;
      for (auto& item : map) {
        ret += item.first.dump() + " : " + item.second.dump();
        if (cnt++ < map.size() - 1)
          ret += ",";
      }
      ret += "}";
      break;
  }
  return ret;
}
/*! \brief parse a JSON document from the beginning of the string */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) {
  unsigned int start = 0;
  return JsonVal::parse(json, &start);
}
// Parse a string value; *idx points just past the opening quote on entry and
// just past the closing quote on success. A quote preceded by a backslash is
// treated as escaped. NOTE(review): only the single preceding character is
// checked, so a string ending in an escaped backslash (\\") would terminate
// one character late — confirm whether producers ever emit that.
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) {
  JsonVal ret(STR);
  while (*idx < json.size()) {
    if (json[*idx] == '"' &&
        (ret.str.size() == 0 || (ret.str.size() > 0 && ret.str.back() != '\\'))) {
      ++(*idx);
      return ret;
    } else {
      ret.str += json[*idx];
      ++(*idx);
    }
  }
  MX_ERROR_MSG << "Error! Unable to parse string: '" << json.substr(*idx) << "'" << std::endl;
  return JsonVal();
}
// Parse a run of digits into a NUM value. Note: no sign, decimal point, or
// exponent support — parse() only dispatches here on a leading digit.
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_num(const std::string& json, unsigned int* idx) {
  JsonVal ret(NUM);
  while (*idx < json.size()) {
    if (json[*idx] >= '0' && json[*idx] <= '9') {
      ret.str += json[*idx];
      ++(*idx);
    } else {
      break;
    }
  }
  ret.num = std::stoi(ret.str);
  return ret;
}
// Parse a list; *idx points just past '[' on entry and just past the matching
// ']' on success. Commas are skipped by parse() (they produce ERR values that
// are dropped here).
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_list(const std::string& json, unsigned int* idx) {
  JsonVal ret(LIST);
  while (*idx < json.size()) {
    if (json[*idx] == ']') {
      ++(*idx);
      return ret;
    } else {
      JsonVal item = JsonVal::parse(json, idx);
      if (item.type != ERR)
        ret.list.push_back(item);
    }
  }
  MX_ERROR_MSG << "Error! Unable to parse list: '" << json.substr(*idx) << "'" << std::endl;
  return JsonVal();
}
// Parse a map; *idx points just past '{' on entry. Parsed values alternate
// between key and value: a pending key is held in 'key' (ERR type means "no
// key pending") and cleared once its value arrives.
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_map(const std::string& json, unsigned int* idx) {
  JsonVal ret(MAP), key;
  while (*idx < json.size()) {
    if (json[*idx] == '}') {
      ++(*idx);
      return ret;
    } else {
      JsonVal item = JsonVal::parse(json, idx);
      if (key.type == ERR) {
        key = item;
      } else {
        ret.map[key] = item;
        key.type = ERR;
      }
    }
  }
  MX_ERROR_MSG << "Error! Unable to parse map: '" << json.substr(*idx) << "'" << std::endl;
  return mxnet::ext::JsonVal();
}
// Dispatch on the next significant character: '"' -> string, digit -> number,
// '[' -> list, '{' -> map. A closing ']'/'}' returns the (possibly ERR) value
// so the enclosing parse_list/parse_map can see the terminator. All other
// characters (commas, colons, whitespace) are skipped one at a time.
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned int* idx) {
  JsonVal ret;
  while (*idx < json.size()) {
    if (json[*idx] == '"') {
      ++(*idx);
      ret = JsonVal::parse_string(json, idx);
    } else if (json[*idx] >= '0' && json[*idx] <= '9') {
      ret = JsonVal::parse_num(json, idx);
    } else if (json[*idx] == '[') {
      ++(*idx);
      ret = JsonVal::parse_list(json, idx);
    } else if (json[*idx] == '{') {
      ++(*idx);
      ret = JsonVal::parse_map(json, idx);
    } else if (json[*idx] == ']' || json[*idx] == '}') {
      return ret;
    }
    if (ret.type != ERR)
      return ret;
    ++(*idx);
  }
  return ret;
}
/*! \brief human-readable debug rendering of this value (NOT valid JSON;
 * use dump() for serialization) */
std::string mxnet::ext::JsonVal::toString() const {
  std::string out;
  switch (type) {
    case ERR:
      out = "json(Error)";
      break;
    case STR:
      out = "json(STR:" + str + ")";
      break;
    case NUM:
      out = "json(INT:" + str + ")";
      break;
    case LIST:
      // every entry gets a trailing comma, matching the original format
      out = "json(LIST:[";
      for (size_t i = 0; i < list.size(); ++i)
        out += list[i].toString() + ",";
      out += "])";
      break;
    case MAP:
      out = "json(MAP:{";
      for (auto it = map.begin(); it != map.end(); ++it)
        out += it->first.toString() + " : " + it->second.toString() + ",";
      out += "})";
      break;
  }
  return out;
}
/*! \brief default Node: no tensor bound and no pass resource attached */
mxnet::ext::Node::Node() {
  tensor = nullptr;
  // Also null the pass resource: alloc_arg/alloc_aux test `if (!res)`, so it
  // must not be left indeterminate when a Node is created outside a pass.
  res = nullptr;
}
// Internal: attach the pass resource so alloc_arg/alloc_aux become usable.
void mxnet::ext::Node::_setPassResource(mxnet::ext::PassResource* res_) {
  res = res_;
}
// Allocate a new arg tensor for this node and bind it to 'tensor'.
// Only valid inside a graph pass; throws std::runtime_error otherwise.
void mxnet::ext::Node::alloc_arg(const std::vector<int64_t>& shapes,
                                 const mxnet::ext::MXContext& ctx,
                                 mxnet::ext::MXDType dtype) {
  if (!res)
    throw std::runtime_error("Node not initialized. Cannot use alloc_arg outside of graph passes.");
  tensor = res->alloc_arg(name, shapes, ctx, dtype);
}
// Allocate a new aux tensor for this node and bind it to 'tensor'.
// Only valid inside a graph pass; throws std::runtime_error otherwise.
void mxnet::ext::Node::alloc_aux(const std::vector<int64_t>& shapes,
                                 const mxnet::ext::MXContext& ctx,
                                 mxnet::ext::MXDType dtype) {
  if (!res)
    throw std::runtime_error("Node not initialized. Cannot use alloc_aux outside of graph passes.");
  tensor = res->alloc_aux(name, shapes, ctx, dtype);
}
// A fresh graph has no pass resource attached.
mxnet::ext::Graph::Graph() : res(nullptr) {}
// The graph owns its nodes (allocated in fromJson/addNode) and frees them here.
// NOTE(review): Node::subgraphs holds heap-allocated Graph* created in
// fromJson; confirm Node's destructor (in lib_api.h) releases them,
// otherwise they leak.
mxnet::ext::Graph::~Graph() {
  for (auto& node : nodes)
    delete node;
}
/*! \brief parse a JSON string and build a heap-allocated Graph from it;
 * the caller owns the returned graph */
mxnet::ext::Graph* mxnet::ext::Graph::fromString(const std::string& json) {
  return fromJson(JsonVal::parse(json));
}
/* \brief build a heap-allocated Graph (caller owns it) from a parsed JSON
 * value in the NNVM graph format: a MAP with "nodes" (list of node maps) and
 * "heads" (list of [node-id, output-index, ...] triples). */
mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) {
  // get nodes list
  JsonVal nodes = val.map[JsonVal("nodes")];
  Graph* g = new Graph();
  std::map<int, Node*> nodeMap;  // JSON node-id -> created Node
  // loop over nodes
  for (int i = 0; i < nodes.list.size(); i++) {
    Node* n = new Node();  // owned by g; freed in ~Graph
    g->nodes.push_back(n);
    JsonVal node = nodes.list[i];
    // set the op info
    n->op = node.map[JsonVal("op")].str;
    n->name = node.map[JsonVal("name")].str;
    // if op is null it is an input to the graph
    if (n->op.compare("null") == 0)
      g->inputs.push_back(n);
    // set attrs
    JsonVal attributes = node.map[JsonVal("attrs")];
    for (auto& kv : attributes.map) {
      n->attrs[kv.first.str] = kv.second.str;
    }
    // set subgraphs, parsing each into a graph (recursively)
    if (node.map.count(JsonVal("subgraphs")) > 0) {
      JsonVal subgraphs = node.map[JsonVal("subgraphs")];
      for (auto& subgraph : subgraphs.list) {
        n->subgraphs.push_back(fromJson(subgraph));
      }
    }
    // set node inputs
    // (assumes inputs only reference node ids that appeared earlier and are
    // therefore already in nodeMap — the NNVM format lists nodes in order)
    JsonVal node_inputs = node.map[JsonVal("inputs")];
    n->inputs.resize(node_inputs.list.size());
    for (int j = 0; j < node_inputs.list.size(); j++) {
      JsonVal input = node_inputs.list[j];
      NodeEntry& entry = n->inputs[j];
      // get pointer to other node
      entry.node = nodeMap[input.list[0].num];
      // get the other node's output index
      entry.entry = input.list[1].num;
      // set other nodes output as connected to this node
      entry.node->outputs.push_back({n, j});
    }
    nodeMap[i] = n;
  }
  // set graph level outputs
  JsonVal& heads = val.map[JsonVal("heads")];
  g->outputs.resize(heads.list.size());
  for (int i = 0; i < heads.list.size(); i++) {
    JsonVal head = heads.list[i];
    g->outputs[i].node = nodeMap[head.list[0].num];
    g->outputs[i].entry = head.list[1].num;
  }
  // add all attributes to the graph, except the structural keys that were
  // consumed above (node_row_ptr/arg_nodes are regenerated by toJson)
  for (auto& kv : val.map) {
    if (kv.first.str.compare("nodes") != 0 && kv.first.str.compare("heads") != 0 &&
        kv.first.str.compare("node_row_ptr") != 0 && kv.first.str.compare("arg_nodes") != 0) {
      g->attrs[kv.first.str] = kv.second;
    }
  }
  return g;
}
/* \brief convert graph object back to JSON object in the NNVM format
 * (nodes, heads, arg_nodes, node_row_ptr plus pass-through graph attrs) */
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const {
  // top level object is a map
  JsonVal val(MAP);
  // add attributes
  for (auto& kv : attrs) {
    val.map[JsonVal(kv.first)] = kv.second;
  }
  // sort graph nodes in topological order, create mapping of node to index
  std::map<Node*, int> nodeMap;
  std::vector<Node*> sorted = topological_sort();
  // nodes are in reverse topological order in the vector (back is first)
  // so loop from end to front over the vector 'sorted'
  for (int i = sorted.size() - 1; i >= 0; i--) {
    nodeMap[sorted[i]] = sorted.size() - 1 - i;
  }
  // create node_row_ptr entry (one sequential id per node)
  val.map[JsonVal("node_row_ptr")] = JsonVal(LIST);
  JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")];
  for (int i = 0; i < nodes.size(); i++)
    node_row_ptr.list.emplace_back(i);
  // add all input nodes (by their topological index)
  val.map[JsonVal("arg_nodes")] = JsonVal(LIST);
  JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")];
  for (auto& input : inputs)
    arg_nodes.list.emplace_back(nodeMap[input]);
  // add all output nodes as [node-index, output-entry, 0] triples
  val.map[JsonVal("heads")] = JsonVal(LIST);
  JsonVal& heads = val.map[JsonVal("heads")];
  for (int i = 0; i < outputs.size(); i++) {
    heads.list.emplace_back(LIST);
    JsonVal& out = heads.list[i];
    out.list.emplace_back(nodeMap[outputs[i].node]);
    out.list.emplace_back(outputs[i].entry);
    out.list.emplace_back(0);
  }
  // add all graph nodes, again walking 'sorted' from back to front so they
  // are emitted in topological order matching nodeMap's indices
  val.map[JsonVal("nodes")] = JsonVal(LIST);
  JsonVal& nodes_ = val.map[JsonVal("nodes")];
  for (int i = sorted.size() - 1; i >= 0; i--) {
    // each node is a map
    nodes_.list.emplace_back(MAP);
    Node* n = sorted[i];
    JsonVal& n_ = nodes_.list[nodes_.list.size() - 1];
    n_.map[JsonVal("op")] = JsonVal(n->op);
    n_.map[JsonVal("name")] = JsonVal(n->name);
    n_.map[JsonVal("inputs")] = JsonVal(LIST);
    // add inputs for this node as [producer-index, entry, 0] triples
    JsonVal& inputs_ = n_.map[JsonVal("inputs")];
    for (int j = 0; j < n->inputs.size(); j++) {
      inputs_.list.emplace_back(LIST);
      NodeEntry& entry = n->inputs[j];
      JsonVal& in = inputs_.list[j];
      in.list.emplace_back(nodeMap[entry.node]);
      in.list.emplace_back(entry.entry);
      in.list.emplace_back(0);
    }
    // add subgraphs for this node, convert each back to JSON
    if (n->subgraphs.size() > 0) {
      n_.map[JsonVal("subgraphs")] = JsonVal(LIST);
      JsonVal& subgraphs_ = n_.map[JsonVal("subgraphs")];
      for (Graph* subgraph : n->subgraphs) {
        subgraphs_.list.push_back(subgraph->toJson());
      }
    }
    // add attributes for this node
    n_.map[JsonVal("attrs")] = JsonVal(MAP);
    JsonVal& attrs_ = n_.map[JsonVal("attrs")];
    for (auto& kv : n->attrs) {
      attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second);
    }
  }
  return val;
}
/* \brief convert graph object to JSON string */
std::string mxnet::ext::Graph::toString() const {
return toJson().dump();
}
/* \brief visits a node "n" */
// Recursive helper for DFS: removes n from the to-visit set, recurses into
// any unvisited consumers (nodes reading n's outputs), then calls the
// handler — i.e. a post-order visit, so consumers are handled before n.
// NOTE(review): recursion depth is bounded by the longest consumer chain;
// extremely deep graphs could exhaust the stack.
void mxnet::ext::Graph::_dfs_util(Node* n,
                                  std::unordered_set<mxnet::ext::Node*>* to_visit,
                                  std::function<void(mxnet::ext::Node*)> handler) const {
  to_visit->erase(n);  // remove node now that we're visiting it
  for (NodeEntry& e : n->outputs) {
    Node* o = e.node;
    if (to_visit->count(o) != 0) {
      _dfs_util(o, to_visit, handler);  // visit neighbor
    }
  }
  handler(n);  // post-order visit this node
}
/* \brief post-order DFS graph traversal */
// Starts from the graph inputs, then sweeps up any nodes unreachable from
// the inputs so every node is visited exactly once.
void mxnet::ext::Graph::DFS(std::function<void(Node*)> handler) const {
  std::unordered_set<Node*> to_visit;
  // put all nodes in set to visit
  for (auto& n : nodes)
    to_visit.insert(n);
  // visit all inputs first
  for (auto& i : inputs)
    if (to_visit.count(i) != 0)
      _dfs_util(i, &to_visit, handler);
  // visit any nodes left
  while (to_visit.size() > 0)
    _dfs_util(*(to_visit.begin()), &to_visit, handler);
}
/* \brief sort graph nodes in topological order */
std::vector<mxnet::ext::Node*> mxnet::ext::Graph::topological_sort() const {
std::vector<mxnet::ext::Node*> sorted;
auto handler = [&](mxnet::ext::Node* n) {
sorted.push_back(n); // when visiting each node, add it in order to the vector
};
DFS(handler);
return sorted;
}
/* \brief print out graph details to stdout for debugging.
 * 'indent' controls left padding so nested subgraphs render indented. */
void mxnet::ext::Graph::print(int indent) const {
  std::string space = "";
  for (int i = 0; i < indent; i++)
    space += " ";
  std::cout << space << "########### Graph #############" << std::endl;
  std::cout << space << "attributes: " << std::endl;
  for (auto& kv : attrs)
    std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl;
  std::cout << space << "inputs: " << inputs.size() << std::endl;
  std::cout << space << "outputs: " << outputs.size() << std::endl;
  std::cout << space << "nodes: " << nodes.size() << std::endl;
  std::vector<mxnet::ext::Node*> sorted = topological_sort();
  // loop over each node and print out its inputs/outputs
  // (sorted is in reverse topological order, so walk it back to front)
  for (int i = sorted.size() - 1; i >= 0; i--) {
    std::cout << space << "Node: " << sorted[i]->name << std::endl;
    for (auto& input : sorted[i]->inputs) {
      std::cout << space << "\tInput: " << input.node->name << " " << input.entry << std::endl;
    }
    for (auto& output : sorted[i]->outputs) {
      std::cout << space << "\tOutput: " << output.node->name << " " << output.entry << std::endl;
    }
    if (sorted[i]->subgraphs.size() > 0) {
      // recurse with extra indentation for nested subgraphs
      for (auto& subgraph : sorted[i]->subgraphs) {
        std::cout << space << "\tSubgraph:" << std::endl;
        subgraph->print(indent + 2);
      }
    }
  }
  std::cout << space << "###############################" << std::endl;
}
/* \brief add a new node to this graph; the graph takes ownership (freed in
 * ~Graph) and the node inherits the pass resource when one is attached */
mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) {
  Node* n = new Node();
  nodes.push_back(n);
  n->name = name;
  n->op = op;
  // propagate the pass resource so the new node can alloc_arg/alloc_aux
  if (res)
    n->_setPassResource(res);
  return n;
}
/* \brief get node at index in graph; throws std::out_of_range for a bad index */
mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) {
  // use bounds-checked access, consistent with the const overload below
  // (previously operator[], which is undefined behavior out of range)
  return nodes.at(idx);
}
/* \brief get const node at index in const graph; throws std::out_of_range
 * for a bad index */
const mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) const {
  return nodes.at(idx);
}
/* \brief get attribute on graph; throws std::out_of_range if the key is absent */
const mxnet::ext::JsonVal& mxnet::ext::Graph::getAttr(const std::string& key) const {
  return attrs.at(key);
}
/* \brief get number of nodes in the graph */
size_t mxnet::ext::Graph::size() const {
  return nodes.size();
}
// internally set passResource to enable tensor allocation for graph passes
void mxnet::ext::Graph::_setPassResource(PassResource* res_) {
  res = res_;
  // set passResource for each node
  for (Node* node : nodes) {
    node->_setPassResource(res);
  }
}
// internally set arg/aux params when available: binds each input node's
// 'tensor' to the matching entry in the args or aux map (left unbound if
// the name is found in neither).
void mxnet::ext::Graph::_setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
                                   std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
  // set params for each input node
  for (Node* node : inputs) {
    std::string name = node->name;
    if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0)
      // mapping name back to original node name from subgraph input name
      name = node->attrs["argName"];
    if (args->count(name) > 0)
      node->tensor = &args->at(name);
    else if (aux->count(name) > 0)
      node->tensor = &aux->at(name);
  }
}
// A fresh operator registration: stores the (externally owned) name and
// leaves every optional callback unset until the fluent setters run.
mxnet::ext::CustomOp::CustomOp(const char* op_name)
    : name(op_name),
      parse_attrs(nullptr),
      infer_type(nullptr),
      infer_storage_type(nullptr),
      infer_shape(nullptr),
      mutate_inputs(nullptr),
      isSGop(false) {}
// Register the forward compute function for one context (e.g. "cpu"/"gpu").
// Throws via raiseDuplicateContextError() if that context already has one.
// Returns *this for fluent chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setForward(mxnet::ext::fcomp_t fcomp, const char* ctx) {
  if (forward_ctx_map.count(ctx) > 0)
    raiseDuplicateContextError();
  forward_ctx_map[ctx] = fcomp;
  return *this;
}
// Register the backward (gradient) compute function for one context; same
// duplicate-context rule and chaining behavior as setForward.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setBackward(mxnet::ext::fcomp_t fgrad,
                                                        const char* ctx) {
  if (backward_ctx_map.count(ctx) > 0)
    raiseDuplicateContextError();
  backward_ctx_map[ctx] = fgrad;
  return *this;
}
// Register the attribute-parsing callback; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setParseAttrs(mxnet::ext::parseAttrs_t func) {
  parse_attrs = func;
  return *this;
}
// Register the type-inference callback; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferType(mxnet::ext::inferType_t func) {
  infer_type = func;
  return *this;
}
// Register the storage-type-inference callback; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferSType(mxnet::ext::inferSType_t func) {
  infer_storage_type = func;
  return *this;
}
// Register the shape-inference callback; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferShape(mxnet::ext::inferShape_t func) {
  infer_shape = func;
  return *this;
}
// Register the mutate-inputs callback; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setMutateInputs(mxnet::ext::mutateInputs_t func) {
  mutate_inputs = func;
  return *this;
}
// Register a stateful-op factory for one context; throws on a duplicate
// context like setForward/setBackward. Returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setCreateOpState(mxnet::ext::createOpState_t func,
                                                             const char* ctx) {
  if (create_op_ctx_map.count(ctx) > 0)
    raiseDuplicateContextError();
  create_op_ctx_map[ctx] = func;
  return *this;
}
// Mark this operator as a subgraph operator; returns *this for chaining.
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setIsSubgraphOp() {
  isSGop = true;
  return *this;
}
/*! \brief flatten the context->function maps into parallel vectors for the
 * C ABI (_opRegGet hands out pointers into these vectors).
 * The vectors are rebuilt from scratch on every call so repeated queries do
 * not accumulate duplicate entries (the previous version only appended). */
void mxnet::ext::CustomOp::mapToVector() {
  forward_ctx_cstr.clear();
  forward_fp.clear();
  backward_ctx_cstr.clear();
  backward_fp.clear();
  create_op_ctx_cstr.clear();
  create_op_fp.clear();
  for (const auto& kv : forward_ctx_map) {
    forward_ctx_cstr.push_back(kv.first);
    forward_fp.push_back(kv.second);
  }
  for (const auto& kv : backward_ctx_map) {
    backward_ctx_cstr.push_back(kv.first);
    backward_fp.push_back(kv.second);
  }
  for (const auto& kv : create_op_ctx_map) {
    create_op_ctx_cstr.push_back(kv.first);
    create_op_fp.push_back(kv.second);
  }
}
/*! \brief throw when two functions are registered under the same context for
 * one operator (the message previously began with a duplicated "Error! Error!") */
void mxnet::ext::CustomOp::raiseDuplicateContextError() {
  std::string op_name_str(name);
  throw std::runtime_error(
      "Error! Cannot register multiple functions under same context for operator '" +
      op_name_str + "'");
}
// Stateful-op base: both flags start false ("no warning suppressed",
// "not created yet").
mxnet::ext::CustomStatefulOp::CustomStatefulOp() : ignore_warn(false), created(false) {}
mxnet::ext::CustomStatefulOp::~CustomStatefulOp() = default;
// Wrapper destructor: hands the wrapped instance back to its destroy callback.
// NOTE(review): no null check on destroy_; assumes the wrapper is always
// constructed with a valid callback — confirm in lib_api.h.
mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper() {
  destroy_(instance);
}
// Default pass: sentinel "ERROR" name marks an unnamed/unregistered pass.
mxnet::ext::CustomPass::CustomPass() : name("ERROR") {}
// Named pass, as created by the registration macros.
mxnet::ext::CustomPass::CustomPass(const char* pass_name) : name(pass_name) {}
// Register the pass body; returns *this for fluent chaining.
mxnet::ext::CustomPass& mxnet::ext::CustomPass::setBody(graphPass_t fn) {
  pass = fn;
  return *this;
}
// Default partitioner: sentinel "ERROR" name marks an unregistered backend.
mxnet::ext::CustomPartitioner::CustomPartitioner() : name("ERROR") {}
// Named partitioner backend, as created by the registration macros.
mxnet::ext::CustomPartitioner::CustomPartitioner(const char* backend_name) : name(backend_name) {}
// Add a (strategy, subgraph-op) pair. 'strategies' and 'op_names' are kept
// as parallel vectors indexed by strategy id (see the getters below).
// Returns *this for chaining.
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::addStrategy(const char* prop_name,
                                                                          const char* sg_name) {
  strategies.push_back(prop_name);
  op_names.push_back(sg_name);
  return *this;
}
// Register the supported-ops callback for a strategy; returns *this.
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setSupportedOps(
    const char* prop_name,
    mxnet::ext::supportedOps_t fn) {
  supported_map[std::string(prop_name)] = fn;
  return *this;
}
// Register the create-selector callback for a strategy; returns *this.
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setCreateSelector(
    const char* prop_name,
    mxnet::ext::createSelector_t fn) {
  selector_map[std::string(prop_name)] = fn;
  return *this;
}
// Register the review-subgraph callback for a strategy; returns *this.
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setReviewSubgraph(
    const char* prop_name,
    mxnet::ext::reviewSubgraph_t fn) {
  review_map[std::string(prop_name)] = fn;
  return *this;
}
/*! \brief look up the supportedOps callback registered for strategy stg_id;
 * returns nullptr when none was registered. Uses a single map lookup instead
 * of the previous count-then-operator[] double lookup. */
mxnet::ext::supportedOps_t mxnet::ext::CustomPartitioner::getSupportedOps(int stg_id) {
  auto it = supported_map.find(std::string(strategies[stg_id]));
  return it != supported_map.end() ? it->second : nullptr;
}
/*! \brief look up the createSelector callback for strategy stg_id
 * (nullptr when none was registered) */
mxnet::ext::createSelector_t mxnet::ext::CustomPartitioner::getCreateSelector(int stg_id) {
  auto it = selector_map.find(std::string(strategies[stg_id]));
  return it != selector_map.end() ? it->second : nullptr;
}
/*! \brief look up the reviewSubgraph callback for strategy stg_id
 * (nullptr when none was registered) */
mxnet::ext::reviewSubgraph_t mxnet::ext::CustomPartitioner::getReviewSubgraph(int stg_id) {
  auto it = review_map.find(std::string(strategies[stg_id]));
  return it != review_map.end() ? it->second : nullptr;
}
/*! \brief returns MXNet library version */
// Exported C entry point (via the MX_INT_RET macro); MXNet calls this after
// loading the library to check ABI compatibility.
MX_INT_RET _opVersion() {
  return MX_LIBRARY_VERSION;
}
/*! \brief returns number of ops registered in this library */
MX_INT_RET _opRegSize() {
  return mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->size();
}
/*! \brief returns operator registration at specified index.
 * All out-pointers (names, context strings, function-pointer arrays) alias
 * storage inside the registry's CustomOp entry and stay valid only while the
 * library remains loaded; nothing is copied across the C ABI. */
MX_VOID_RET _opRegGet(int idx,
                      const char** name,
                      int* isSGop,
                      const char*** forward_ctx,
                      mxnet::ext::fcomp_t** forward_fp,
                      int* forward_count,
                      const char*** backward_ctx,
                      mxnet::ext::fcomp_t** backward_fp,
                      int* backward_count,
                      const char*** create_op_ctx,
                      mxnet::ext::createOpState_t** create_op_fp,
                      int* create_op_count,
                      mxnet::ext::parseAttrs_t* parse,
                      mxnet::ext::inferType_t* type,
                      mxnet::ext::inferSType_t* stype,
                      mxnet::ext::inferShape_t* shape,
                      mxnet::ext::mutateInputs_t* mutate) {
  mxnet::ext::CustomOp& op = mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->get(idx);
  *name = op.name;
  *parse = op.parse_attrs;
  *type = op.infer_type;
  *stype = op.infer_storage_type;
  *shape = op.infer_shape;
  *mutate = op.mutate_inputs;
  *isSGop = op.isSGop;
  // flatten the context->function maps into the parallel vectors exposed below
  op.mapToVector();
  *forward_ctx = op.forward_ctx_cstr.data();
  *forward_fp = op.forward_fp.data();
  *forward_count = op.forward_fp.size();
  *backward_ctx = op.backward_ctx_cstr.data();
  *backward_fp = op.backward_fp.data();
  *backward_count = op.backward_fp.size();
  *create_op_ctx = op.create_op_ctx_cstr.data();
  *create_op_fp = op.create_op_fp.data();
  *create_op_count = op.create_op_fp.size();
}
/*! \brief calls free from the external library for library allocated arrays */
MX_VOID_RET _opCallFree(void* ptr) {
  free(ptr);
}
/*! \brief returns status of calling parse attributes function for operator from library */
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             int* num_in,
                             int* num_out) {
  // rebuild the parallel key/value arrays as an attribute map
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; ++i)
    attrs[keys[i]] = vals[i];
  return parseAttrs(attrs, num_in, num_out);
}
/*! \brief returns status of calling inferShape function for operator from library.
 * A zero return means failure (nothing is allocated); on success the modified
 * input shapes and the output shapes are returned in malloc'd arrays that the
 * caller owns and releases via _opCallFree. */
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             unsigned int** inshapes,
                             int* indims,
                             int num_in,
                             unsigned int*** mod_inshapes,
                             int** mod_indims,
                             unsigned int*** outshapes,
                             int** outdims,
                             int num_out) {
  // create map of attributes from list
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; i++) {
    attrs[std::string(keys[i])] = std::string(vals[i]);
  }
  // create a vector of shapes for inputs
  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; i++) {
    for (int j = 0; j < indims[i]; j++) {
      in_shapes[i].push_back(inshapes[i][j]);
    }
  }
  // create a vector of shapes for outputs
  std::vector<std::vector<unsigned int> > out_shapes(num_out);
  int retval = inferShape(attrs, &in_shapes, &out_shapes);
  if (!retval)
    return retval;  // failure: return before allocating anything
  // allocate space for modified input dims, shape
  // (malloc so the MXNet side can release via _opCallFree)
  *mod_indims = static_cast<int*>(malloc(num_in * sizeof(int)));
  *mod_inshapes = static_cast<unsigned**>(malloc(num_in * sizeof(unsigned*)));
  // copy modified input shapes
  for (int i = 0; i < num_in; i++) {
    (*mod_indims)[i] = in_shapes[i].size();
    (*mod_inshapes)[i] = static_cast<unsigned*>(malloc((*mod_indims)[i] * sizeof(unsigned)));
    for (int j = 0; j < (*mod_indims)[i]; j++) {
      (*mod_inshapes)[i][j] = in_shapes[i][j];
    }
  }
  // allocate space for output dims, shape
  *outdims = static_cast<int*>(malloc(num_out * sizeof(int)));
  *outshapes = static_cast<unsigned**>(malloc(num_out * sizeof(unsigned*)));
  // copy output shapes
  for (int i = 0; i < num_out; i++) {
    (*outdims)[i] = out_shapes[i].size();
    (*outshapes)[i] = static_cast<unsigned*>(malloc((*outdims)[i] * sizeof(unsigned)));
    for (int j = 0; j < (*outdims)[i]; j++) {
      (*outshapes)[i][j] = out_shapes[i][j];
    }
  }
  return retval;
}
/*! \brief returns status of calling inferType function for operator from library */
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType,
                            const char* const* keys,
                            const char* const* vals,
                            int num,
                            int* intypes,
                            int num_in,
                            int* outtypes,
                            int num_out) {
  // rebuild the parallel key/value arrays as an attribute map
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; ++i)
    attrs[keys[i]] = vals[i];
  // copy input types so the user function may modify them; outputs start
  // as -1 ("unset")
  std::vector<int> in_types(intypes, intypes + num_in);
  std::vector<int> out_types(num_out, -1);
  int retval = inferType(attrs, &in_types, &out_types);
  // a zero return signals failure; propagate it without touching the buffers
  if (!retval)
    return retval;
  // write any modified input types and the inferred output types back out
  for (int i = 0; i < num_in; ++i)
    intypes[i] = in_types[i];
  for (int i = 0; i < num_out; ++i)
    outtypes[i] = out_types[i];
  return retval;
}
/*! \brief returns status of calling inferSType function for operator from library
 *
 * Same bridge pattern as _opCallInferType, but for storage types: builds the
 * attribute map, hands input/output storage-type vectors to the user's
 * inferSType callback, then copies the results back to the caller's arrays.
 */
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             int* instypes,
                             int num_in,
                             int* outstypes,
                             int num_out) {
  // attribute list -> map for the callback
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; k++) {
    attrs[std::string(keys[k])] = std::string(vals[k]);
  }
  // input storage types copied in; outputs start unknown (-1)
  std::vector<int> in_stypes(instypes, instypes + num_in);
  std::vector<int> out_stypes(num_out, -1);
  int status = inferSType(attrs, &in_stypes, &out_stypes);
  if (!status)
    return status;
  // propagate modified input and inferred output storage types to the caller
  std::copy(in_stypes.begin(), in_stypes.end(), instypes);
  std::copy(out_stypes.begin(), out_stypes.end(), outstypes);
  return status;
}
/*! \brief returns status of calling Forward/Backward function for operator from library
 *
 * Bridges the C ABI to the user's fcomp callback: rebuilds the attribute map,
 * wraps each raw input/output buffer in an MXTensor (dense) or MXSparse
 * (row-sparse / CSR) view, bundles the allocator callbacks into an OpResource,
 * and invokes fcomp.  Returns the callback's status code (0 = failure).
 */
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp,
                           const char* const* keys,
                           const char* const* vals,
                           int num,
                           const int64_t** inshapes,
                           int* indims,
                           void** indata,
                           int* intypes,
                           size_t* inIDs,
                           const char** indev_type,
                           int* indev_id,
                           int num_in,
                           const int64_t** outshapes,
                           int* outdims,
                           void** outdata,
                           int* outtypes,
                           size_t* outIDs,
                           const char** outdev_type,
                           int* outdev_id,
                           int num_out,
                           mxnet::ext::xpu_malloc_t cpu_malloc,
                           void* cpu_alloc,
                           mxnet::ext::xpu_malloc_t gpu_malloc,
                           void* gpu_alloc,
                           void* cuda_stream,
                           mxnet::ext::sparse_malloc_t sparse_malloc,
                           void* sparse_alloc,
                           int* instypes,
                           int* outstypes,
                           void** in_indices,
                           void** out_indices,
                           void** in_indptr,
                           void** out_indptr,
                           int64_t* in_indices_shapes,
                           int64_t* out_indices_shapes,
                           int64_t* in_indptr_shapes,
                           int64_t* out_indptr_shapes,
                           void* rng_cpu_states,
                           void* rng_gpu_states) {
  // create map of attributes from list
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; i++) {
    attrs[std::string(keys[i])] = std::string(vals[i]);
  }
  // create a vector of tensors for inputs
  std::vector<mxnet::ext::MXTensor> inputs(num_in);
  // create a vector for sparse inputs
  // NOTE: the sparse MXTensor wrappers below store raw pointers into this
  // vector, so it must stay alive until fcomp returns (it does: both are
  // function-locals and fcomp is called before we return).
  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
  for (int i = 0; i < num_in; i++) {
    // Dense representation (storage type 0).
    if (instypes[i] == 0) {
      inputs[i].setTensor(indata[i],
                          (mxnet::ext::MXDType)intypes[i],
                          inshapes[i],
                          indims[i],
                          inIDs[i],
                          mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                          mxnet::ext::kDefaultStorage);
    } else {
      // Sparse representation: 1 = row-sparse, anything else = CSR.
      mxnet::ext::MXStorageType type;
      if (instypes[i] == 1) {
        type = mxnet::ext::kRowSparseStorage;
        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
      } else {
        type = mxnet::ext::kCSRStorage;
        in_sparse[i].set(indata[i],
                         inshapes[i],
                         indims[i],
                         in_indices[i],
                         in_indices_shapes[i],
                         in_indptr[i],
                         in_indptr_shapes[i]);
      }
      // the tensor's data pointer aliases the MXSparse element above
      inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]),
                          (mxnet::ext::MXDType)intypes[i],
                          inshapes[i],
                          indims[i],
                          inIDs[i],
                          mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                          type);
    }
  }
  // create a vector of tensors for outputs (same dense/sparse scheme as inputs)
  std::vector<mxnet::ext::MXTensor> outputs(num_out);
  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
  for (int i = 0; i < num_out; i++) {
    // Dense representation.
    if (outstypes[i] == 0) {
      outputs[i].setTensor(outdata[i],
                           (mxnet::ext::MXDType)outtypes[i],
                           outshapes[i],
                           outdims[i],
                           outIDs[i],
                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
                           mxnet::ext::kDefaultStorage);
    } else {
      // Sparse representation: 1 = row-sparse, anything else = CSR.
      mxnet::ext::MXStorageType type;
      if (outstypes[i] == 1) {
        type = mxnet::ext::kRowSparseStorage;
        out_sparse[i].set(
            outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]);
      } else {
        type = mxnet::ext::kCSRStorage;
        out_sparse[i].set(outdata[i],
                          outshapes[i],
                          outdims[i],
                          out_indices[i],
                          out_indices_shapes[i],
                          out_indptr[i],
                          out_indptr_shapes[i]);
      }
      outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
                           (mxnet::ext::MXDType)outtypes[i],
                           outshapes[i],
                           outdims[i],
                           outIDs[i],
                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
                           type);
    }
  }
  // bundle allocators, stream and RNG state handles for the callback
  mxnet::ext::OpResource res(cpu_malloc,
                             cpu_alloc,
                             gpu_malloc,
                             gpu_alloc,
                             cuda_stream,
                             sparse_malloc,
                             sparse_alloc,
                             rng_cpu_states,
                             rng_gpu_states);
  return fcomp(attrs, &inputs, &outputs, res);
}
/*! \brief returns status of calling mutateInputs function for operator from library
 *
 * Asks the user's mutateInputs callback which input indices the operator
 * mutates, then returns them through a malloc'd int array owned by the caller.
 */
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate,
                               const char* const* keys,
                               const char* const* vals,
                               int num,
                               int** mutate_indices,
                               int* indices_size) {
  // attribute list -> map for the callback
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; k++) {
    attrs[std::string(keys[k])] = std::string(vals[k]);
  }
  // callback fills in the indices of mutated inputs
  std::vector<int> mut_ind;
  int status = mutate(attrs, &mut_ind);
  if (!status)
    return status;
  // hand the indices back through a C array (caller takes ownership)
  *indices_size = mut_ind.size();
  *mutate_indices = static_cast<int*>(malloc(*indices_size * sizeof(int)));
  std::copy(mut_ind.begin(), mut_ind.end(), *mutate_indices);
  return status;
}
/*! \brief returns status of calling createStatefulOp function for operator from library
 *
 * Marshals attributes, context, input shapes and types into STL containers,
 * then asks the user's createOpState callback to populate *state_op with a
 * new CustomStatefulOp instance created inside the custom library.
 */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op,
                                const char* const* keys,
                                const char* const* vals,
                                int num,
                                const char* dev_type,
                                int dev_id,
                                unsigned int** inshapes,
                                int* indims,
                                int num_in,
                                const int* intypes,
                                void** state_op) {
  // attribute list -> map for the callback
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; k++) {
    attrs[std::string(keys[k])] = std::string(vals[k]);
  }
  mxnet::ext::MXContext ctx(dev_type, dev_id);
  // per-input shape vectors, copied from the raw arrays
  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; i++) {
    in_shapes[i].assign(inshapes[i], inshapes[i] + indims[i]);
  }
  // input dtypes copied into a vector
  std::vector<int> in_types(intypes, intypes + num_in);
  // state_op is an out-param: the callback writes the new instance through it
  mxnet::ext::CustomStatefulOp** op_ptr =
      reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
  return create_op(attrs, ctx, in_shapes, in_types, op_ptr);
}
/*! \brief calls StatefulOp destructor for operator from library
 *
 * \param state_op opaque pointer to a CustomStatefulOp instance previously
 *                 produced by _opCallCreateOpState; deleted here, so this
 *                 wrapper assumes ownership of the instance.
 */
MX_VOID_RET _opCallDestroyOpState(void* state_op) {
  // static_cast is the correct named cast for void* -> T*; reinterpret_cast
  // is unnecessary here (C++ Core Guidelines on casts). delete on nullptr is
  // a no-op, so no null check is needed.
  delete static_cast<mxnet::ext::CustomStatefulOp*>(state_op);
}
/*! \brief returns status of calling Stateful Forward/Backward for operator from library
 *
 * Stateful counterpart of _opCallFCompute: wraps the raw buffers in
 * MXTensor/MXSparse views, builds an OpResource, then dispatches to the
 * Forward or Backward method of the CustomStatefulOp held in state_op,
 * selected by the is_forward flag.
 */
MX_INT_RET _opCallFStatefulCompute(int is_forward,
                                   void* state_op,
                                   const int64_t** inshapes,
                                   int* indims,
                                   void** indata,
                                   int* intypes,
                                   size_t* inIDs,
                                   const char** indev_type,
                                   int* indev_id,
                                   int num_in,
                                   const int64_t** outshapes,
                                   int* outdims,
                                   void** outdata,
                                   int* outtypes,
                                   size_t* outIDs,
                                   const char** outdev_type,
                                   int* outdev_id,
                                   int num_out,
                                   mxnet::ext::xpu_malloc_t cpu_malloc,
                                   void* cpu_alloc,
                                   mxnet::ext::xpu_malloc_t gpu_malloc,
                                   void* gpu_alloc,
                                   void* stream,
                                   mxnet::ext::sparse_malloc_t sparse_malloc,
                                   void* sparse_alloc,
                                   int* instypes,
                                   int* outstypes,
                                   void** in_indices,
                                   void** out_indices,
                                   void** in_indptr,
                                   void** out_indptr,
                                   int64_t* in_indices_shapes,
                                   int64_t* out_indices_shapes,
                                   int64_t* in_indptr_shapes,
                                   int64_t* out_indptr_shapes,
                                   void* rng_cpu_states,
                                   void* rng_gpu_states) {
  // create a vector of tensors for inputs
  std::vector<mxnet::ext::MXTensor> inputs(num_in);
  // create a vector for sparse inputs
  // NOTE: sparse MXTensor wrappers alias elements of this vector; it stays
  // alive until after the Forward/Backward call below.
  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
  for (int i = 0; i < num_in; i++) {
    if (instypes[i] == 0) {
      // Dense representation (storage type 0).
      inputs[i].setTensor(indata[i],
                          (mxnet::ext::MXDType)intypes[i],
                          inshapes[i],
                          indims[i],
                          inIDs[i],
                          mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                          mxnet::ext::kDefaultStorage);
    } else {
      // Sparse representation: 1 = row-sparse, anything else = CSR.
      mxnet::ext::MXStorageType type;
      if (instypes[i] == 1) {
        type = mxnet::ext::kRowSparseStorage;
        in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
      } else {
        type = mxnet::ext::kCSRStorage;
        in_sparse[i].set(indata[i],
                         inshapes[i],
                         indims[i],
                         in_indices[i],
                         in_indices_shapes[i],
                         in_indptr[i],
                         in_indptr_shapes[i]);
      }
      inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]),
                          (mxnet::ext::MXDType)intypes[i],
                          inshapes[i],
                          indims[i],
                          inIDs[i],
                          mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                          type);
    }
  }
  // create a vector of tensors for outputs
  std::vector<mxnet::ext::MXTensor> outputs(num_out);
  // create a vector for sparse outputs (same aliasing note as in_sparse)
  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
  for (int i = 0; i < num_out; i++) {
    if (outstypes[i] == 0) {
      // Dense representation.
      outputs[i].setTensor(outdata[i],
                           (mxnet::ext::MXDType)outtypes[i],
                           outshapes[i],
                           outdims[i],
                           outIDs[i],
                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
                           mxnet::ext::kDefaultStorage);
    } else {
      // Sparse representation: 1 = row-sparse, anything else = CSR.
      mxnet::ext::MXStorageType type;
      if (outstypes[i] == 1) {
        type = mxnet::ext::kRowSparseStorage;
        out_sparse[i].set(
            outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]);
      } else {
        type = mxnet::ext::kCSRStorage;
        out_sparse[i].set(outdata[i],
                          outshapes[i],
                          outdims[i],
                          out_indices[i],
                          out_indices_shapes[i],
                          out_indptr[i],
                          out_indptr_shapes[i]);
      }
      outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
                           (mxnet::ext::MXDType)outtypes[i],
                           outshapes[i],
                           outdims[i],
                           outIDs[i],
                           mxnet::ext::MXContext(outdev_type[i], outdev_id[i]),
                           type);
    }
  }
  // bundle allocators, stream and RNG state handles for the callback
  mxnet::ext::OpResource res(cpu_malloc,
                             cpu_alloc,
                             gpu_malloc,
                             gpu_alloc,
                             stream,
                             sparse_malloc,
                             sparse_alloc,
                             rng_cpu_states,
                             rng_gpu_states);
  // dispatch to the stateful op instance created earlier by createOpState
  mxnet::ext::CustomStatefulOp* op_ptr = reinterpret_cast<mxnet::ext::CustomStatefulOp*>(state_op);
  if (is_forward) {
    return op_ptr->Forward(&inputs, &outputs, res);
  }
  return op_ptr->Backward(&inputs, &outputs, res);
}
/*! \brief returns number of partitioners registered in this library */
MX_INT_RET _partRegSize() {
  auto* registry = mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get();
  return registry->size();
}
/* returns the number of strategies registered for the partitioner at the
 * specified index, and reports its name through the out-param */
MX_INT_RET _partRegGetCount(int idx, const char** name) {
  auto part = mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(idx);
  *name = part.name;
  return part.strategies.size();
}
/*! \brief returns partitioner registration at specified index
 *
 * Looks up partitioner part_idx and exposes the strategy name, operator name,
 * and the three callbacks registered for strategy stg_idx via out-params.
 */
MX_VOID_RET _partRegGet(int part_idx,
                        int stg_idx,
                        const char** strategy,
                        mxnet::ext::supportedOps_t* supportedOps,
                        mxnet::ext::createSelector_t* createSelector,
                        mxnet::ext::reviewSubgraph_t* reviewSubgraph,
                        const char** op_name) {
  auto entry = mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(part_idx);
  *strategy       = entry.strategies[stg_idx];
  *op_name        = entry.op_names[stg_idx];
  *supportedOps   = entry.getSupportedOps(stg_idx);
  *createSelector = entry.getCreateSelector(stg_idx);
  *reviewSubgraph = entry.getReviewSubgraph(stg_idx);
}
/*! \brief returns status of calling supported ops function from library
 *
 * Rebuilds the graph from JSON, runs the user's supportedOps callback to mark
 * which nodes it supports, and copies the per-node answers into ids.
 */
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps,
                                 const char* json,
                                 int num_ids,
                                 int* ids,
                                 const char* const* opt_keys,
                                 const char* const* opt_vals,
                                 int num_opts) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);
  // options list -> map for the callback
  std::unordered_map<std::string, std::string> opts;
  for (int k = 0; k < num_opts; k++) {
    opts[std::string(opt_keys[k])] = std::string(opt_vals[k]);
  }
  // per-node answers start at the sentinel -2; the callback overwrites them
  std::vector<int> _ids(num_ids, -2);
  mxnet::ext::MXReturnValue status = supportedOps(graph, &_ids, opts);
  if (!status)
    return status;
  // surface the callback's answers through the caller's int array
  std::copy(_ids.begin(), _ids.end(), ids);
  return status;
}
/*! \brief returns status of calling create selector function from library
 *
 * Rebuilds the graph from JSON and asks the user's createSelector callback to
 * populate *selector with a CustomOpSelector instance created inside the
 * custom library.
 */
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector,
                                   const char* json,
                                   void** selector,
                                   const char* const* opt_keys,
                                   const char* const* opt_vals,
                                   int num_opts) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);
  // options list -> map for the callback
  std::unordered_map<std::string, std::string> opts;
  for (int k = 0; k < num_opts; k++) {
    opts[std::string(opt_keys[k])] = std::string(opt_vals[k]);
  }
  // selector is an out-param: the callback writes the new instance through it
  mxnet::ext::CustomOpSelector** out_sel =
      reinterpret_cast<mxnet::ext::CustomOpSelector**>(selector);
  return createSelector(graph, out_sel, opts);
}
/*! \brief forwards a node-selection query to the selector instance */
MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) {
  auto* selector = reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected = selector->Select(nodeID);
}
/*! \brief forwards an input-selection query to the selector instance */
MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected) {
  auto* selector = reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected = selector->SelectInput(nodeID, input_nodeID);
}
/*! \brief forwards an output-selection query to the selector instance */
MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected) {
  auto* selector = reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected = selector->SelectOutput(nodeID, output_nodeID);
}
/*! \brief returns status of calling filter function from library
 *
 * Passes the candidate node list to the selector's Filter method and returns
 * the surviving nodes through a malloc'd int array owned by the caller.
 */
MX_VOID_RET _partCallFilter(void* sel_inst,
                            int* candidates,
                            int num_candidates,
                            int** keep,
                            int* num_keep) {
  auto* selector = reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  // candidate array -> vector for the callback
  std::vector<int> cand(candidates, candidates + num_candidates);
  std::vector<int> kept;
  selector->Filter(cand, &kept);
  // hand the kept nodes back through a C array (caller takes ownership)
  *num_keep = kept.size();
  *keep = static_cast<int*>(malloc(kept.size() * sizeof(int)));
  std::copy(kept.begin(), kept.end(), *keep);
}
/*! \brief resets the selector instance between partitioning passes */
MX_VOID_RET _partCallReset(void* sel_inst) {
  reinterpret_cast<mxnet::ext::CustomOpSelector*>(sel_inst)->Reset();
}
/*! \brief returns status of calling review subgraph function from library
 *
 * Rebuilds the candidate subgraph from JSON, attaches the named arg/aux
 * tensors to it, and invokes the user's reviewSubgraph callback.  On success
 * *accept reports whether the subgraph was accepted, and any attributes the
 * callback added are returned as malloc'd key/value C-string arrays
 * (*attr_keys / *attr_vals / *num_attrs) owned by the caller.
 */
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph,
                                   const char* json,
                                   int subgraph_id,
                                   int* accept,
                                   const char* const* opt_keys,
                                   const char* const* opt_vals,
                                   int num_opts,
                                   char*** attr_keys,
                                   char*** attr_vals,
                                   int* num_attrs,
                                   const char* const* arg_names,
                                   int num_args,
                                   void* const* arg_data,
                                   const int64_t* const* arg_shapes,
                                   const int* arg_dims,
                                   const int* arg_types,
                                   const size_t* arg_IDs,
                                   const char* const* arg_dev_type,
                                   const int* arg_dev_id,
                                   const char* const* aux_names,
                                   int num_aux,
                                   void* const* aux_data,
                                   const int64_t* const* aux_shapes,
                                   const int* aux_dims,
                                   const int* aux_types,
                                   const size_t* aux_IDs,
                                   const char* const* aux_dev_type,
                                   const int* aux_dev_id) {
  // reconstruct the subgraph from its JSON serialization
  mxnet::ext::Graph* subgraph = mxnet::ext::Graph::fromString(json);
  bool accept_bool = false;
  // create map of options from list
  std::unordered_map<std::string, std::string> opts;
  for (int i = 0; i < num_opts; i++)
    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
  // create a map of named tensors for args
  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
  for (int i = 0; i < num_args; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(arg_dims[i]);
    for (int j = 0; j < arg_dims[i]; j++)
      shapes.push_back(arg_shapes[i][j]);
    mxnet::ext::MXTensor tensor(arg_data[i],
                                shapes,
                                (mxnet::ext::MXDType)arg_types[i],
                                arg_IDs[i],
                                mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
    args[arg_names[i]] = tensor;
  }
  // create a map of named tensors for aux
  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
  for (int i = 0; i < num_aux; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(aux_dims[i]);
    for (int j = 0; j < aux_dims[i]; j++)
      shapes.push_back(aux_shapes[i][j]);
    mxnet::ext::MXTensor tensor(aux_data[i],
                                shapes,
                                (mxnet::ext::MXDType)aux_types[i],
                                aux_IDs[i],
                                mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i]));
    aux[aux_names[i]] = tensor;
  }
  subgraph->_setParams(&args, &aux);
  // the callback decides acceptance and may attach extra subgraph attributes
  std::unordered_map<std::string, std::string> attrs;
  mxnet::ext::MXReturnValue retval =
      reviewSubgraph(subgraph, subgraph_id, &accept_bool, opts, &attrs);
  if (!retval)
    return retval;
  *accept = accept_bool;
  if (attrs.size() > 0) {
    *num_attrs = attrs.size();
    // allocate space for attributes
    *attr_keys = static_cast<char**>(malloc(*num_attrs * sizeof(char*)));
    *attr_vals = static_cast<char**>(malloc(*num_attrs * sizeof(char*)));
    // copy attributes; iterate by const reference to avoid copying every
    // pair<const string,string> per iteration (performance-for-range-copy)
    int i = 0;
    for (const auto& kv : attrs) {
      (*attr_keys)[i] = static_cast<char*>(malloc((kv.first.size() + 1) * sizeof(char)));  // NOLINT
      (*attr_vals)[i] =
          static_cast<char*>(malloc((kv.second.size() + 1) * sizeof(char)));  // NOLINT
      snprintf((*attr_keys)[i], kv.first.size() + 1, "%s", kv.first.c_str());
      snprintf((*attr_vals)[i], kv.second.size() + 1, "%s", kv.second.c_str());
      i++;
    }
  }
  return retval;
}
/*! \brief returns number of graph passes registered in this library */
MX_INT_RET _passRegSize() {
  auto* registry = mxnet::ext::Registry<mxnet::ext::CustomPass>::get();
  return registry->size();
}
/*! \brief returns pass registration (callback and name) at specified index */
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, const char** pass_name) {
  auto entry = mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->get(pass_idx);
  *pass_name = entry.name;
  *graphPass = entry.pass;
}
/*! \brief returns status of calling graph pass function from library
 *
 * Rebuilds the graph from JSON, attaches the named arg/aux tensors and a
 * PassResource (for allocating new params via nd_malloc), runs the user's
 * graphPass callback, and on success serializes the transformed graph back
 * into *out_graph as a malloc'd C string whose ownership transfers to the
 * caller.
 */
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass,
                              const char* json,
                              char** out_graph,
                              const char* const* opt_keys,
                              const char* const* opt_vals,
                              int num_opts,
                              const char* pass_name,
                              const char* const* arg_names,
                              int num_args,
                              void* const* arg_data,
                              const int64_t* const* arg_shapes,
                              const int* arg_dims,
                              const int* arg_types,
                              const size_t* arg_IDs,
                              const char* const* arg_dev_type,
                              const int* arg_dev_id,
                              const char* const* aux_names,
                              int num_aux,
                              void* const* aux_data,
                              const int64_t* const* aux_shapes,
                              const int* aux_dims,
                              const int* aux_types,
                              const size_t* aux_IDs,
                              const char* const* aux_dev_type,
                              const int* aux_dev_id,
                              mxnet::ext::nd_malloc_t nd_malloc,
                              const void* nd_alloc) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);
  // create map of attributes from list
  std::unordered_map<std::string, std::string> opts;
  for (int i = 0; i < num_opts; i++)
    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
  // create a map of named tensors for args
  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
  for (int i = 0; i < num_args; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(arg_dims[i]);
    for (int j = 0; j < arg_dims[i]; j++)
      shapes.push_back(arg_shapes[i][j]);
    mxnet::ext::MXTensor tensor(arg_data[i],
                                shapes,
                                (mxnet::ext::MXDType)arg_types[i],
                                arg_IDs[i],
                                mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
    args[arg_names[i]] = tensor;
  }
  // create a map of named tensors for aux
  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
  for (int i = 0; i < num_aux; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(aux_dims[i]);
    for (int j = 0; j < aux_dims[i]; j++)
      shapes.push_back(aux_shapes[i][j]);
    mxnet::ext::MXTensor tensor(aux_data[i],
                                shapes,
                                (mxnet::ext::MXDType)aux_types[i],
                                aux_IDs[i],
                                mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i]));
    aux[aux_names[i]] = tensor;
  }
  // the PassResource holds pointers to these local maps; they remain valid
  // for the duration of the graphPass call below.
  std::unordered_map<std::string, mxnet::ext::MXTensor> new_args, new_aux;
  mxnet::ext::PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
  graph->_setParams(&args, &aux);
  graph->_setPassResource(&res);
  mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
  if (!retval)
    return retval;
  // serialize the (possibly modified) graph back to a malloc'd C string
  std::string tmp = graph->toString();
  *out_graph = static_cast<char*>(malloc((tmp.size() + 1) * sizeof(char)));  // NOLINT
  snprintf((*out_graph), tmp.size() + 1, "%s", tmp.c_str());
  return retval;
}
/*!
 * \brief Checks if the MXNet version is supported by the library.
 * If supported, initializes the library.
 * \param version MXNet version number passed to library and defined as:
 *                MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH)
 * \return Non-zero value on error i.e. library incompatible with passed MXNet version
 * \note This is only a declaration; the definition is presumably supplied by
 *       the custom-library author (see the example extension READMEs cited at
 *       the top of this file) -- confirm against those examples.  On Windows
 *       it is exported with __cdecl so MXNet can resolve it by name.
 */
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
__declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
#else
mxnet::ext::MXReturnValue
#endif
initialize(int version);
/*! \brief returns the number of error messages accumulated in this library */
MX_INT_RET _msgSize() {
  auto* msgs = mxnet::ext::MXerrorMsgs::get();
  return msgs->size();
}
/*! \brief returns the error message at the specified index
 * (the original comment said "operator registration" -- a copy-paste slip;
 * this accessor reads from the MXerrorMsgs singleton, not the op registry)
 * NOTE(review): MXerrorMsgs::get(idx) returns a heap-allocated std::string
 * that is never freed (see its definition at the top of this file); the
 * leak is what keeps *msg valid after this call returns.
 */
MX_VOID_RET _msgGet(int idx, const char** msg) {
  *msg = mxnet::ext::MXerrorMsgs::get()->get(idx)->c_str();
}