/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_symbolic.cc
* \brief Symbol-related C API of MXNet
*/
#include "mxnet/base.h"
#include "mxnet/c_api.h"
#include "mxnet/imperative.h"
#include "nnvm/c_api.h"
#include "nnvm/pass.h"
#include "nnvm/pass_functions.h"
#include "nnvm/symbolic.h"
#include "./c_api_common.h"
#include "../common/exec_utils.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/subgraph_property.h"
namespace mxnet {
namespace op {
void RegisterLegacyOpProp();
void RegisterLegacyNDFunc();
}  // namespace op
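// Attribute keys that are stored internally with a double-underscore wrapper,
// e.g. the user-facing key "lr_mult" is kept as "__lr_mult__". kNamespaceSeparator
// joins a symbol name and an attribute key when attributes are listed recursively.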
const std::vector<std::string> kHiddenKeys = {
"ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage"
};
const std::vector<std::string> kReplacedHiddenKeys = {
"__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__"
};
const char *kNamespaceSeparator = "$";
DMLC_JSON_ENABLE_ANY(int, int);
// Convert an nnvm symbol to an nnvm graph.
nnvm::Graph Symbol2Graph(const nnvm::Symbol &s) {
nnvm::Graph g;
g.outputs = s.outputs;
g.attrs["mxnet_version"] = std::make_shared<nnvm::any>(static_cast<int>(MXNET_VERSION));
if (Imperative::Get()->is_np_shape()) {
g.attrs["is_np_shape"] = std::make_shared<nnvm::any>(
static_cast<int>(Imperative::Get()->is_np_shape()));
}
return g;
}
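// Return the indices, among the graph's input nodes, of the inputs that are
// read-only arguments, i.e. not mutable auxiliary states.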
std::vector<uint32_t> ReadOnlyArgIndices(const nnvm::IndexedGraph& idx) {
std::vector<uint32_t> ret;
auto& arg_nodes = idx.input_nodes();
for (uint32_t i = 0; i < arg_nodes.size(); ++i) {
if (idx.mutable_input_nodes().count(arg_nodes[i]) == 0) {
ret.push_back(i);
}
}
return ret;
}
} // namespace mxnet
// symbolic configuration generation API.
// Redirect to NNVM's C API
int MXListAllOpNames(nn_uint *out_size,
const char ***out_array) {
mxnet::op::RegisterLegacyOpProp();
mxnet::op::RegisterLegacyNDFunc();
return NNListAllOpNames(out_size, out_array);
}
int MXSymbolListAtomicSymbolCreators(uint32_t *out_size,
AtomicSymbolCreator **out_array) {
mxnet::op::RegisterLegacyOpProp();
mxnet::op::RegisterLegacyNDFunc();
return NNListUniqueOps(out_size, out_array);
}
int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
uint32_t *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **key_var_num_args,
const char **return_type) {
static auto& map_key_var_args = nnvm::Op::GetAttr<std::string>("key_var_num_args");
const Op* op = static_cast<Op*>(creator);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_str.resize(0);
if (map_key_var_args.count(op) != 0) {
*key_var_num_args = map_key_var_args[op].c_str();
} else {
*key_var_num_args = ret->ret_str.c_str();
}
return NNGetOpInfo(
creator, name, description,
num_args, arg_names, arg_type_infos,
arg_descriptions, return_type);
}
int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
uint32_t num_param,
const char **keys,
const char **vals,
SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
const nnvm::Op* op = static_cast<const nnvm::Op*>(creator);
std::unordered_map<std::string, std::string> kwargs;
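// A key that matches a hidden key (e.g. "lr_mult") is stored with the
// double-underscore wrapper ("__lr_mult__"); a key that merely ends with a
// hidden key (e.g. a hypothetical "fc1_weight_lr_mult") is rejected as
// deprecated below.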
for (nn_uint i = 0; i < num_param; ++i) {
bool flag = false;
for (const auto &k : kHiddenKeys) {
std::string tmp(keys[i]);
size_t pos = tmp.rfind(k);
if (pos == 0) {
kwargs.insert({"__" + tmp + "__", std::string(vals[i])});
flag = true;
break;
} else if (pos != std::string::npos && pos == tmp.length() - k.length()) {
std::ostringstream os;
os << "setting variable attributes with " << keys[i] << " is deprecated. "
<< "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n"
<< "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)";
throw dmlc::Error(os.str());
}
}
if (!flag)
kwargs.insert({std::string(keys[i]), std::string(vals[i])});
}
*s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs));
*out = s;
API_END_HANDLE_ERROR(delete s;);
}
int MXSymbolCreateVariable(const char *name, SymbolHandle *out) {
return NNSymbolCreateVariable(name, out);
}
int MXSymbolCreateGroup(uint32_t num_symbols,
SymbolHandle *symbols,
SymbolHandle *out) {
return NNSymbolCreateGroup(num_symbols, symbols, out);
}
int MXSymbolGetOutput(SymbolHandle symbol,
uint32_t index,
SymbolHandle *out) {
return NNSymbolGetOutput(symbol, index, out);
}
int MXSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
*s = static_cast<nnvm::Symbol*>(symbol)->GetInternals();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int MXSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
*s = static_cast<nnvm::Symbol*>(symbol)->GetChildren();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int MXSymbolFree(SymbolHandle symbol) {
return NNSymbolFree(symbol);
}
int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
return NNSymbolCopy(symbol, out);
}
int MXSymbolPrint(SymbolHandle symbol, const char **out_str) {
return NNSymbolPrint(symbol, out_str);
}
int MXSymbolGetName(SymbolHandle symbol,
const char** out,
int* success) {
return NNSymbolGetAttr(symbol, "name", out, success);
}
int MXSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int* success) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
if (s->GetAttr(key, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) {
std::string skey = "__" + std::string(key) + "__";
if (s->GetAttr(skey, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
*success = 1;
}
}
}
API_END();
}
int MXSymbolSetAttr(SymbolHandle symbol,
const char* key,
const char* value) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
API_BEGIN();
std::vector<std::pair<std::string, std::string> > kwargs;
std::string skey(key), sval(value);
for (const auto &k : kHiddenKeys) {
size_t pos = skey.rfind(k);
if (pos == 0 && k.length() == skey.length()) {
skey = "__" + skey + "__";
break;
} else if (pos != std::string::npos && pos + k.length() == skey.length()) {
std::ostringstream os;
os << "setting variable attributes with " << key << " is deprecated. "
<< "please instead use\nw = Variable(" << k << "=" << value << ")\n"
<< "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)";
throw dmlc::Error(os.str());
}
}
kwargs.emplace_back(std::move(skey), std::move(sval));
s->SetAttrs(kwargs);
API_END();
}
int MXSymbolListAttr(SymbolHandle symbol,
uint32_t *out_size,
const char*** out) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
std::vector<std::tuple<std::string, std::string, std::string> > attr =
s->ListAttrsRecursive();
std::vector<std::string>& attr_list = ret->ret_vec_str;
attr_list.clear();
for (const auto& tp : attr) {
attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp));
attr_list.emplace_back(std::get<2>(tp));
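// A replaced hidden key such as "__lr_mult__" is additionally listed under
// its user-facing name ("lr_mult") by stripping the double-underscore wrapper.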
if (std::find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp))
!= kReplacedHiddenKeys.end()) {
attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator +
std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4));
attr_list.push_back(std::get<2>(tp));
}
}
*out_size = attr_list.size()/2;
ret->ret_vec_charp.clear();
for (const auto& attr : attr_list) {
ret->ret_vec_charp.push_back(attr.c_str());
}
*out = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int MXSymbolListAttrShallow(SymbolHandle symbol,
uint32_t *out_size,
const char*** out) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
std::unordered_map<std::string, std::string> attr =
s->ListAttrs(static_cast<nnvm::Symbol::ListAttrOption>(1)); // NOLINT(*)
std::vector<std::string>& attr_list = ret->ret_vec_str;
attr_list.clear();
for (const auto& kv : attr) {
attr_list.push_back(kv.first);
attr_list.push_back(kv.second);
if (std::find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first)
!= kReplacedHiddenKeys.end()) {
attr_list.push_back(kv.first.substr(2, kv.first.length() - 4));
attr_list.push_back(kv.second);
}
}
*out_size = attr_list.size()/2;
ret->ret_vec_charp.clear();
for (auto &attr : attr_list) {
ret->ret_vec_charp.push_back(attr.c_str());
}
*out = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int MXSymbolListOutputs(SymbolHandle symbol,
uint32_t *out_size,
const char ***out_str_array) {
return NNSymbolListOutputNames(symbol, out_size, out_str_array);
}
int MXSymbolGetNumOutputs(SymbolHandle symbol,
uint32_t *output_count) {
return NNSymbolGetNumOutputs(symbol, output_count);
}
int MXSymbolCompose(SymbolHandle sym,
const char *name,
uint32_t num_args,
const char** keys,
SymbolHandle* args) {
return NNSymbolCompose(sym, name, num_args, keys, args);
}
// Adapter functions that re-implement the corresponding NNVM functions.
// The second argument to NNSymbolListInputNames selects read-only arguments (1)
// or auxiliary states (2).
int MXSymbolListArguments(SymbolHandle symbol,
uint32_t *out_size,
const char ***out_str_array) {
return NNSymbolListInputNames(symbol, 1, out_size, out_str_array);
}
int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
uint32_t *out_size,
const char ***out_str_array) {
return NNSymbolListInputNames(symbol, 2, out_size, out_str_array);
}
int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out) {
API_BEGIN();
Op *e = static_cast<Op *>(creator);
*out = e->name.c_str();
API_END();
}
namespace mxnet {
extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
extern bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries);
}
int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) {
API_BEGIN();
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
*input_size = input_syms.size();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_handles.clear();
ret->ret_handles.reserve(*input_size);
for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
*input_arr = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
API_END_HANDLE_ERROR();
}
int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
int *input_size) {
// Given a graph, we want to fetch the nodes that have been marked as part of
// a subgraph.
API_BEGIN();
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
const std::string subg_attr = "__subgraph_name__";
auto out_node = s->outputs[0].node;
auto it = out_node->attrs.dict.find(subg_attr);
if (it != out_node->attrs.dict.end()) {
const std::string &subg_name = it->second;
std::vector<nnvm::NodeEntry *> input_entries;
DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
(nnvm::ObjectPtr n) {
// If the node itself isn't in the subgraph, we ignore it.
auto it = n->attrs.dict.find(subg_attr);
if (it == n->attrs.dict.end() || it->second != subg_name)
return;
// We search for nodes whose node entries aren't in the subgraph.
for (size_t j = 0; j < n->inputs.size(); j++) {
auto in_node = n->inputs[j].node;
auto it = in_node->attrs.dict.find(subg_attr);
if (it == in_node->attrs.dict.end() || it->second != subg_name)
input_entries.push_back(&n->inputs[j]);
}
});
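// Cut the graph at the collected entries: CutGraphInputs replaces each
// crossing entry with a new variable node and stores the original producer
// entries in orig_entries, which become the returned input symbols.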
std::vector<nnvm::NodeEntry> orig_entries;
CutGraphInputs(input_entries, false, &orig_entries);
std::vector<nnvm::Symbol *> input_syms(orig_entries.size());
for (size_t i = 0; i < input_syms.size(); i++) {
input_syms[i] = new nnvm::Symbol();
input_syms[i]->outputs.push_back(orig_entries[i]);
}
*input_size = input_syms.size();
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_handles.clear();
ret->ret_handles.reserve(*input_size);
for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
*input_symbols = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
} else {
*input_size = 0;
}
API_END_HANDLE_ERROR();
}
/*!
* \brief Convert the shape attr in graph nodes to comply with NumPy semantics for
* legacy models (saved before 1.6.0) when the global flag is_np_shape is turned on,
* i.e. use -1 to indicate an unknown number of dimensions or an unknown dimension size.
*/
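// For example, under NumPy semantics a variable whose legacy "__shape__"
// attribute reads "(0,2)" (0 meaning an unknown dimension size) is rewritten
// as "(-1,2)".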
void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) {
if (Imperative::Get()->is_np_shape()
&& (!g->HasAttr("is_np_shape") || !g->GetAttr<int>("is_np_shape"))) {
DFSVisit(g->outputs, [](nnvm::ObjectPtr n) {
if (n->is_variable()) {
auto it = n->attrs.dict.find("__shape__");
if (it != n->attrs.dict.end()) {
mxnet::TShape shape;
std::istringstream is(it->second);
is >> shape;
common::ConvertToNumpyShape(&shape);
std::ostringstream os;
os << shape;
it->second = os.str();
}
}
});
}
}
int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
dmlc::istream is(fi.get());
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(
std::string(std::istreambuf_iterator<char>(is), std::istreambuf_iterator<char>()));
g = nnvm::ApplyPass(g, "LoadLegacyJSON");
ConvertShapeAttrToNumPyCompatible(&g);
s->outputs = g.outputs;
*out = s;
is.set_stream(nullptr);
API_END_HANDLE_ERROR(delete s);
}
int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(std::string(json));
g = nnvm::ApplyPass(g, "LoadLegacyJSON");
ConvertShapeAttrToNumPyCompatible(&g);
s->outputs = g.outputs;
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) {
nnvm::Symbol* s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *source = static_cast<nnvm::Symbol*>(sym_handle);
*s = source->Copy();
s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
API_BEGIN();
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
dmlc::ostream os(fo.get());
os << nnvm::pass::SaveJSON(Symbol2Graph(*s));
// reset file pointer, force flush
os.set_stream(nullptr);
API_END();
}
int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s));
*out_json = ret->ret_str.c_str();
API_END();
}
namespace mxnet {
template<typename AttrType>
void MatchArguments(
const nnvm::IndexedGraph& idx,
const std::unordered_map<std::string, AttrType>& known_arg_attrs,
std::vector<AttrType>* arg_attrs,
const char* source) {
auto& arg_nodes = idx.input_nodes();
CHECK_EQ(arg_attrs->size(), arg_nodes.size());
size_t nmatched = 0;
for (size_t i = 0; i < arg_nodes.size(); ++i) {
const std::string& name = idx[arg_nodes[i]].source->attrs.name;
auto it = known_arg_attrs.find(name);
if (it != known_arg_attrs.end()) {
arg_attrs->at(i) = it->second;
++nmatched;
}
}
if (nmatched != known_arg_attrs.size()) {
std::unordered_set<std::string> keys;
std::ostringstream head, msg;
msg << "\nCandidate arguments:\n";
for (size_t i = 0; i < arg_nodes.size(); ++i) {
std::string arg_name = idx[arg_nodes[i]].source->attrs.name;
keys.insert(arg_name);
msg << "\t[" << i << ']' << arg_name << '\n';
}
for (const auto& kv : known_arg_attrs) {
const std::string& key = kv.first;
if (keys.count(key) == 0) {
LOG(FATAL) << source
<< "Keyword argument name " << key << " not found."
<< msg.str();
}
}
}
}
} // namespace mxnet
int MXSymbolInferShape(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const uint32_t *arg_ind_ptr,
const uint32_t *arg_shape_data,
uint32_t *in_shape_size,
const uint32_t **in_shape_ndim,
const uint32_t ***in_shape_data,
uint32_t *out_shape_size,
const uint32_t **out_shape_ndim,
const uint32_t ***out_shape_data,
uint32_t *aux_shape_size,
const uint32_t **aux_shape_ndim,
const uint32_t ***aux_shape_data,
int *complete) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
nnvm::Graph g = Symbol2Graph(*s);
mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
if (keys == nullptr && num_args != 0) {
std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
CHECK_LE(num_args, read_only_args.size());
for (uint32_t i = 0; i < num_args; ++i) {
arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(
arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
}
} else {
std::unordered_map<std::string, mxnet::TShape> kwargs;
for (uint32_t i = 0; i < num_args; ++i) {
kwargs[keys[i]] = mxnet::ShapeTypeCast(
arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
}
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
}
try {
g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}
// if legacy shape semantics are in use, convert the inferred NumPy shapes back to legacy shapes
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
if (!Imperative::Get()->is_np_shape()) {
common::ConvertToLegacyShape(&shapes);
}
// copy back
CopyAttr(g.indexed_graph(), shapes,
&(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
// copy data back
MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes,
&(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer));
MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes,
&(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer));
MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes,
&(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer));
*in_shape_size = static_cast<uint32_t>(ret->arg_shapes.size());
*in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim);
*in_shape_data = dmlc::BeginPtr(ret->arg_shape_data);
*out_shape_size = static_cast<uint32_t>(ret->out_shapes.size());
*out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim);
*out_shape_data = dmlc::BeginPtr(ret->out_shape_data);
*aux_shape_size = static_cast<uint32_t>(ret->aux_shapes.size());
*aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim);
*aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data);
// mark complete
*complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
API_END();
}
template<typename dtype, typename stype, typename itype>
inline void SymbolInferShape(const char** keys,
uint32_t num_args,
const dtype* arg_shape_data,
const itype* arg_ind_ptr,
const int** in_shape_ndim,
const dtype*** in_shape_data,
const int** out_shape_ndim,
const dtype*** out_shape_data,
const int** aux_shape_ndim,
const dtype*** aux_shape_data,
nnvm::Symbol* s,
MXAPIThreadLocalEntry<dtype>* ret,
stype* in_shape_size,
stype* out_shape_size,
stype* aux_shape_size,
int* complete) {
nnvm::Graph g = Symbol2Graph(*s);
mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape());
if (keys == nullptr && num_args != 0) {
std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
CHECK_LE(num_args, read_only_args.size());
for (uint32_t i = 0; i < num_args; ++i) {
arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
arg_shape_data + arg_ind_ptr[i + 1]);
}
} else {
std::unordered_map<std::string, mxnet::TShape> kwargs;
for (uint32_t i = 0; i < num_args; ++i) {
kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
arg_shape_data + arg_ind_ptr[i + 1]);
}
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
}
try {
g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
} catch (const mxnet::op::InferShapeError& err) {
throw dmlc::Error(err.msg);
}
// if legacy shape semantics are in use, convert the inferred NumPy shapes back to legacy shapes
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
if (!Imperative::Get()->is_np_shape()) {
common::ConvertToLegacyShape(&shapes);
}
// copy back
CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
// copy data back
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
&(ret->arg_shape_ndim_ex),
&(ret->arg_shape_data_ex),
&(ret->arg_shape_buffer_ex));
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
&(ret->out_shape_ndim_ex),
&(ret->out_shape_data_ex),
&(ret->out_shape_buffer_ex));
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
&(ret->aux_shape_ndim_ex),
&(ret->aux_shape_data_ex),
&(ret->aux_shape_buffer_ex));
*in_shape_size = static_cast<stype>(ret->arg_shapes.size());
*in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex);
*in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex);
*out_shape_size = static_cast<stype>(ret->out_shapes.size());
*out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex);
*out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex);
*aux_shape_size = static_cast<stype>(ret->aux_shapes.size());
*aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex);
*aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex);
// mark complete
*complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
}
/*!
* \brief Executor for Symbol Shape Inference
* This API is available when MXNet is built with the flag
* USE_INT64_TENSOR_SIZE=0 (by default)
* \param sym symbol handle
* \param num_args number of args
* \param keys keys
* \param arg_ind_ptr arg index pointer
* \param arg_shape_data arg shape data
* \param in_shape_size input shape size
* \param in_shape_ndim input shape number of dims
* \param in_shape_data input shape data
* \param out_shape_size output shape size
* \param out_shape_ndim output shape number of dims
* \param out_shape_data output shape data
* \param aux_shape_size shape size of auxiliary states
* \param aux_shape_ndim number of dims of auxiliary states shape
* \param aux_shape_data shape data of auxiliary states
* \param complete indicates completion of Shape Inference
* \return 0 when success, -1 when failure happens
*/
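/*
 * A sketch of the CSR-style shape arguments, assuming two known inputs
 * (hypothetically named "data" and "weight") with shapes (2,3) and (4,):
 *   keys           = {"data", "weight"}
 *   arg_ind_ptr    = {0, 2, 3}   // num_args + 1 offsets into arg_shape_data
 *   arg_shape_data = {2, 3, 4}   // concatenated dimension sizes
 */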
int MXSymbolInferShapeEx(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const uint32_t *arg_ind_ptr,
const int *arg_shape_data,
uint32_t *in_shape_size,
const int **in_shape_ndim,
const int ***in_shape_data,
uint32_t *out_shape_size,
const int **out_shape_ndim,
const int ***out_shape_data,
uint32_t *aux_shape_size,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
SymbolInferShape<int, uint32_t, uint32_t>(keys,
num_args,
arg_shape_data,
arg_ind_ptr,
in_shape_ndim,
in_shape_data,
out_shape_ndim,
out_shape_data,
aux_shape_ndim,
aux_shape_data,
s,
ret,
in_shape_size,
out_shape_size,
aux_shape_size,
complete);
API_END();
}
/*!
* \brief Executor for Symbol Shape Inference
* This API is available when MXNet is built with the flag
* USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
* \param sym symbol handle
* \param num_args number of args
* \param keys keys
* \param arg_ind_ptr arg index pointer
* \param arg_shape_data arg shape data
* \param in_shape_size input shape size
* \param in_shape_ndim input shape number of dims
* \param in_shape_data input shape data
* \param out_shape_size output shape size
* \param out_shape_ndim output shape number of dims
* \param out_shape_data output shape data
* \param aux_shape_size shape size of auxiliary states
* \param aux_shape_ndim number of dims of auxiliary states shape
* \param aux_shape_data shape data of auxiliary states
* \param complete indicates completion of Shape Inference
* \return 0 when success, -1 when failure happens
*/
int MXSymbolInferShapeEx64(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const int64_t *arg_ind_ptr,
const int64_t *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const int64_t ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const int64_t ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const int64_t ***aux_shape_data,
int *complete) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
API_BEGIN();
SymbolInferShape<int64_t, size_t, int64_t>(keys,
num_args,
arg_shape_data,
arg_ind_ptr,
in_shape_ndim,
in_shape_data,
out_shape_ndim,
out_shape_data,
aux_shape_ndim,
aux_shape_data,
s,
ret,
in_shape_size,
out_shape_size,
aux_shape_size,
complete);
API_END();
}
int MXSymbolInferShapePartial(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const uint32_t *arg_ind_ptr,
const uint32_t *arg_shape_data,
uint32_t *in_shape_size,
const uint32_t **in_shape_ndim,
const uint32_t ***in_shape_data,
uint32_t *out_shape_size,
const uint32_t **out_shape_ndim,
const uint32_t ***out_shape_data,
uint32_t *aux_shape_size,
const uint32_t **aux_shape_ndim,
const uint32_t ***aux_shape_data,
int *complete) {
int succ = 0;
*complete = 1;
return MXSymbolInferShape(sym, num_args, keys,
arg_ind_ptr, arg_shape_data,
in_shape_size, in_shape_ndim, in_shape_data,
out_shape_size, out_shape_ndim, out_shape_data,
aux_shape_size, aux_shape_ndim, aux_shape_data,
&succ);
}
/*!
* \brief Executor for Symbol Partial Shape Inference
* This API is available when MXNet is built with the flag
* USE_INT64_TENSOR_SIZE=0 (by default)
* \param sym symbol handle
* \param num_args number of args
* \param keys keys
* \param arg_ind_ptr arg index pointer
* \param arg_shape_data arg shape data
* \param in_shape_size input shape size
* \param in_shape_ndim input shape number of dims
* \param in_shape_data input shape data
* \param out_shape_size output shape size
* \param out_shape_ndim output shape number of dims
* \param out_shape_data output shape data
* \param aux_shape_size shape size of auxiliary states
* \param aux_shape_ndim number of dims of auxiliary states shape
* \param aux_shape_data shape data of auxiliary states
* \param complete indicates completion of Shape Inference
* \return 0 when success, -1 when failure happens
*/
int MXSymbolInferShapePartialEx(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const uint32_t *arg_ind_ptr,
const int *arg_shape_data,
uint32_t *in_shape_size,
const int **in_shape_ndim,
const int ***in_shape_data,
uint32_t *out_shape_size,
const int **out_shape_ndim,
const int ***out_shape_data,
uint32_t *aux_shape_size,
const int **aux_shape_ndim,
const int ***aux_shape_data,
int *complete) {
int succ = 0;
*complete = 1;
return MXSymbolInferShapeEx(sym, num_args, keys,
arg_ind_ptr, arg_shape_data,
in_shape_size, in_shape_ndim, in_shape_data,
out_shape_size, out_shape_ndim, out_shape_data,
aux_shape_size, aux_shape_ndim, aux_shape_data,
&succ);
}
/*!
* \brief Executor for Symbol Partial Shape Inference
* This API is available when MXNet is built with the flag
* USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support
* \param sym symbol handle
* \param num_args number of args
* \param keys keys
* \param arg_ind_ptr arg index pointer
* \param arg_shape_data arg shape data
* \param in_shape_size input shape size
* \param in_shape_ndim input shape number of dims
* \param in_shape_data input shape data
* \param out_shape_size output shape size
* \param out_shape_ndim output shape number of dims
* \param out_shape_data output shape data
* \param aux_shape_size shape size of auxiliary states
* \param aux_shape_ndim number of dims of auxiliary states shape
* \param aux_shape_data shape data of auxiliary states
* \param complete indicates completion of Shape Inference
* \return 0 when success, -1 when failure happens
*/
int MXSymbolInferShapePartialEx64(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const int64_t *arg_ind_ptr,
const int64_t *arg_shape_data,
size_t *in_shape_size,
const int **in_shape_ndim,
const int64_t ***in_shape_data,
size_t *out_shape_size,
const int **out_shape_ndim,
const int64_t ***out_shape_data,
size_t *aux_shape_size,
const int **aux_shape_ndim,
const int64_t ***aux_shape_data,
int *complete) {
int succ = 0;
*complete = 1;
return MXSymbolInferShapeEx64(sym, num_args, keys,
arg_ind_ptr, arg_shape_data,
in_shape_size, in_shape_ndim, in_shape_data,
out_shape_size, out_shape_ndim, out_shape_data,
aux_shape_size, aux_shape_ndim, aux_shape_data,
&succ);
}
int MXSymbolInferType(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const int *arg_type_data,
uint32_t *in_type_size,
const int **in_type_data,
uint32_t *out_type_size,
const int **out_type_data,
uint32_t *aux_type_size,
const int **aux_type_data,
int *complete) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
nnvm::Graph g = Symbol2Graph(*s);
nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
if (keys == nullptr && num_args != 0) {
std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
CHECK_LE(num_args, read_only_args.size());
for (uint32_t i = 0; i < num_args; ++i) {
arg_types[read_only_args[i]] = arg_type_data[i];
}
} else {
std::unordered_map<std::string, int> kwargs;
for (uint32_t i = 0; i < num_args; ++i) {
kwargs[keys[i]] = arg_type_data[i];
}
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
}
g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "__dtype__");
// copy back
CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
&(ret->arg_types), &(ret->out_types), &(ret->aux_types));
*in_type_size = static_cast<uint32_t>(ret->arg_types.size());
*in_type_data = dmlc::BeginPtr(ret->arg_types);
*out_type_size = static_cast<uint32_t>(ret->out_types.size());
*out_type_data = dmlc::BeginPtr(ret->out_types);
*aux_type_size = static_cast<uint32_t>(ret->aux_types.size());
*aux_type_data = dmlc::BeginPtr(ret->aux_types);
*complete = (g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0);
API_END();
}
int MXSymbolInferTypePartial(SymbolHandle sym,
uint32_t num_args,
const char** keys,
const int *arg_type_data,
uint32_t *in_type_size,
const int **in_type_data,
uint32_t *out_type_size,
const int **out_type_data,
uint32_t *aux_type_size,
const int **aux_type_data,
int *complete) {
int succ = 0;
*complete = 1;
return MXSymbolInferType(sym, num_args, keys,
arg_type_data,
in_type_size, in_type_data,
out_type_size, out_type_data,
aux_type_size, aux_type_data,
&succ);
}
int MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHandle* out) {
API_BEGIN();
LOG(FATAL) << "not implemented";
API_END();
}
int MXQuantizeSymbol(SymbolHandle sym_handle,
SymbolHandle *ret_sym_handle,
const int* dev_type,
const uint32_t num_excluded_sym_names,
const char **excluded_sym_names,
const uint32_t num_excluded_op_names,
const char **excluded_op_names,
const uint32_t num_offline,
const char **offline_params,
const char *quantized_dtype,
const bool calib_quantize,
const char *quantize_mode,
const char *quantize_granularity,
mx_uint* out_num_calib_names,
const char ***out_calib_names) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
nnvm::Graph g = Symbol2Graph(*sym);
int target_dev = *dev_type;
std::unordered_set<std::string> excluded_node_names;
for (size_t i = 0; i < num_excluded_sym_names; ++i) {
excluded_node_names.emplace(excluded_sym_names[i]);
}
std::unordered_set<std::string> excluded_op;
for (size_t i = 0; i < num_excluded_op_names; ++i) {
excluded_op.emplace(excluded_op_names[i]);
}
std::unordered_set<std::string> offline;
for (size_t i = 0; i < num_offline; ++i) {
offline.emplace(offline_params[i]);
}
std::string quantized_type(quantized_dtype);
std::string quantized_mode(quantize_mode);
std::string quantized_granularity(quantize_granularity);
g.attrs["excluded_nodes"] = std::make_shared<nnvm::any>(std::move(excluded_node_names));
g.attrs["excluded_ops"] = std::make_shared<nnvm::any>(std::move(excluded_op));
g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
g.attrs["target_ctx"] = std::make_shared<nnvm::any>(target_dev);
g.attrs["quantize_mode"] = std::make_shared<nnvm::any>(std::move(quantized_mode));
g.attrs["quantize_granularity"] = std::make_shared<nnvm::any>(std::move(quantized_granularity));
g = ApplyPass(std::move(g), "QuantizeGraph");
std::vector<std::string> calib_nodes = g.GetAttr<std::vector<std::string>>("calib_nodes");
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
ret->ret_vec_str = std::move(calib_nodes);
*out_num_calib_names = ret->ret_vec_str.size();
ret->ret_vec_charp.clear();
ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
for (const auto &str : ret->ret_vec_str) {
ret->ret_vec_charp.push_back(str.c_str());
}
*out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp);
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
// Helper function that populates node_name -> dtype maps
// for the given indexed graph and inferred dtypes.
static void _SetInputDTypes(
const nnvm::IndexedGraph& idx,
const nnvm::DTypeVector& inferred_dtypes,
std::unordered_map<std::string, int>* node_name_dtype_map,
std::unordered_map<std::string, int>* node_without_dtype_map) {
const std::string dtype_keyword = "__dtype__";
for (uint32_t nid : idx.input_nodes()) {
const auto& node = idx[nid].source;
const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword);
// Input nodes are classified into nodes_with_dtype and nodes_without_dtype.
// This classification is required because, when param_names is not provided,
// we want to update the dtypes of only those nodes that already have a dtype set.
// inferred_dtypes holds the inferred dtype for each node; an unknown
// dtype (-1) defaults to fp32.
if (node_with_dtype != node->attrs.dict.end()) {
if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
(*node_name_dtype_map)[node->attrs.name] = 0;
} else {
(*node_name_dtype_map)[node->attrs.name] =
inferred_dtypes[idx.entry_id(nid, 0)];
}
} else {
if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) {
(*node_without_dtype_map)[node->attrs.name] = 0;
} else {
(*node_without_dtype_map)[node->attrs.name] =
inferred_dtypes[idx.entry_id(nid, 0)];
}
}
}
}
// Helper function that updates the node dtype attrs for a vector of ObjectPtrs,
// given the node name -> dtype information and the names of model_params.
// If model_params is provided, the function updates the dtype of model params only.
// If model_params is empty, the function updates the dtype of all nodes that
// already had a dtype set.
// args is a const reference to a vector of ObjectPtrs. The ObjectPtrs themselves
// are immutable, but the Nodes they point to are mutated by this function.
static void _UpdateSymDTypeAttrs(
const std::unordered_map<std::string, int>& node_name_dtype_map,
const std::unordered_map<std::string, int>& node_without_dtype_map,
const std::unordered_set<std::string>& model_params,
const std::vector<nnvm::ObjectPtr>& args) {
const std::string dtype_keyword = "__dtype__";
// Update args to have the right dtype attrs
if (model_params.size() > 0) {
// if model params provided, set dtype only for model params
for (size_t i = 0; i < args.size(); ++i) {
const std::string& node_name = args[i]->attrs.name;
auto it_model_params = model_params.find(node_name);
auto it_with_dtype = node_name_dtype_map.find(node_name);
auto it_without_dtype = node_without_dtype_map.find(node_name);
if (it_model_params != model_params.end()) {
// need to update __dtype__ attribute if already set, else set it
if (it_with_dtype != node_name_dtype_map.end()) {
args[i]->attrs.dict[dtype_keyword] =
std::to_string(it_with_dtype->second);
} else {
CHECK(it_without_dtype != node_without_dtype_map.end())
<< "make sure all nodes without dtype have properly been added "
"in node_without_dtype_map";
args[i]->attrs.dict[dtype_keyword] =
std::to_string(it_without_dtype->second);
}
}
}
} else {
// if model params are not provided, update __dtype__ only for the inputs
// that already had it set; don't touch the rest
for (size_t i = 0; i < args.size(); ++i) {
auto it = node_name_dtype_map.find(args[i]->attrs.name);
if (it != node_name_dtype_map.end()) {
if (args[i]->attrs.dict.find(dtype_keyword) !=
args[i]->attrs.dict.end()) {
args[i]->attrs.dict[dtype_keyword] = std::to_string(it->second);
}
}
}
}
}
int MXReducePrecisionSymbol(SymbolHandle sym_handle,
SymbolHandle *ret_sym_handle,
uint32_t num_args,
const int *arg_type_data,
uint32_t num_ind_ptr,
const int* ind_ptr,
const int* target_dtype,
const int cast_optional_params,
const uint32_t num_target_dtype_op_names,
const uint32_t num_fp32_op_names,
const uint32_t num_widest_dtype_op_names,
const uint32_t num_conditional_fp32_op_names,
const uint32_t num_excluded_symbols,
const uint32_t num_model_params,
const char **target_dtype_op_names,
const char **fp32_op_names,
const char **widest_dtype_op_names,
const char **conditional_fp32_op_names,
const char **excluded_symbols,
const char **param_names,
const char **param_vals,
const char **model_param_names,
const char **arg_names) {
nnvm::Symbol *result_sym = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
nnvm::Graph g = Symbol2Graph(*sym);
std::unordered_set<std::string> target_dtype_ops;
std::unordered_set<std::string> fp32_ops;
std::unordered_set<std::string> widest_dtype_ops;
std::unordered_set<std::string> excluded_syms;
std::unordered_set<std::string> model_params;
// conditional_fp32_ops maps op_name -> (map of param_name -> param_values) for
// ops that must be conditionally cast to FP32 when the parameter takes one of
// the listed values
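// e.g. {"Activation" -> {"act_type" -> ["softrelu"]}} keeps Activation nodes
// in FP32 whenever their act_type parameter equals "softrelu".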
std::unordered_map<std::string,
std::unordered_map<std::string,
std::vector<std::string>>> conditional_fp32_ops;
int target_dt = *target_dtype;
for (size_t i = 0; i < num_target_dtype_op_names; ++i) {
target_dtype_ops.emplace(target_dtype_op_names[i]);
}
for (size_t i = 0; i < num_fp32_op_names; ++i) {
fp32_ops.emplace(fp32_op_names[i]);
}
for (size_t i = 0; i < num_widest_dtype_op_names; ++i) {
widest_dtype_ops.emplace(widest_dtype_op_names[i]);
}
for (size_t i = 0; i < num_excluded_symbols; ++i) {
excluded_syms.emplace(excluded_symbols[i]);
}
for (size_t i = 0; i < num_model_params; ++i) {
model_params.emplace(model_param_names[i]);
}
for (size_t i = 0; i < num_ind_ptr - 1; ++i) {
for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) {
conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]]
.emplace_back(std::string(param_vals[j]));
}
}
std::unordered_map<std::string, int> kwargs;
std::unordered_map<std::string, int> node_name_dtype_map, node_without_dtype_map;
nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
for (uint32_t i = 0; i < num_args; ++i) {
kwargs[arg_names[i]] = arg_type_data[i];
node_name_dtype_map[arg_names[i]] = arg_type_data[i];
}
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
g.attrs["target_dtype_ops"] =
std::make_shared<nnvm::any>(std::move(target_dtype_ops));
g.attrs["fp32_ops"] = std::make_shared<nnvm::any>(std::move(fp32_ops));
g.attrs["widest_dtype_ops"] =
std::make_shared<nnvm::any>(std::move(widest_dtype_ops));
g.attrs["conditional_fp32_ops"] =
std::make_shared<nnvm::any>(std::move(conditional_fp32_ops));
g.attrs["excluded_syms"] =
std::make_shared<nnvm::any>(std::move(excluded_syms));
g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
g.attrs["data_name_types"] = std::make_shared<nnvm::any>(kwargs);
g.attrs["cast_optional_params"] = std::make_shared<nnvm::any>(cast_optional_params);
g = ApplyPass(std::move(g), "ReducePrecision");
// Run type inference again, since the inferred type of some inputs
// may have changed
g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "");
const nnvm::DTypeVector &inferred_dtypes =
g.GetAttr<nnvm::DTypeVector>("dtype");
g.attrs["inferred_dtypes"] = std::make_shared<dmlc::any>(std::move(inferred_dtypes));
g.attrs["target_dtype"] = std::make_shared<nnvm::any>(target_dt);
if (cast_optional_params) {
g = ApplyPass(std::move(g), "AMPInferUnknown");
const nnvm::DTypeVector &inferred_dtype_result =
g.GetAttr<nnvm::DTypeVector>("inferred_dtype_result");
const nnvm::IndexedGraph &idx = g.indexed_graph();
// set node name -> input dtype mapping using infer dtype
_SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map);
} else {
const nnvm::IndexedGraph &idx = g.indexed_graph();
// set node name -> input dtype mapping using infer dtype
_SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map);
}
result_sym->outputs = g.outputs;
*ret_sym_handle = result_sym;
nnvm::Symbol *ret_sym = static_cast<nnvm::Symbol *>(*ret_sym_handle);
const std::vector<nnvm::ObjectPtr>& args = ret_sym->ListInputs(nnvm::Symbol::kAll);
// update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set
// in the symbol, else set dtype for the model_params
_UpdateSymDTypeAttrs(node_name_dtype_map, node_without_dtype_map, model_params, args);
API_END_HANDLE_ERROR(delete result_sym);
}
int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
const uint32_t num_layers,
const char** layer_names,
const float* min_ranges,
const float* max_ranges,
SymbolHandle* ret_qsym_handle) {
nnvm::Symbol* s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(qsym_handle);
nnvm::Graph g = Symbol2Graph(*sym);
std::unordered_map<std::string, std::pair<float, float>> calib_table;
for (size_t i = 0; i < num_layers; ++i) {
calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i]));
}
g.attrs["calib_table"] = std::make_shared<nnvm::any>(std::move(calib_table));
g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph");
s->outputs = g.outputs;
*ret_qsym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
SymbolHandle *ret_sym_handle) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
if (property->HasAttr("disable") && property->GetAttr<bool>("disable")) {
auto full_name = property->HasAttr("property_name")
? property->GetAttr<std::string>("property_name")
: std::string();
LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
<< " is disabled.";
continue;
}
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
property->RemoveAttr("graph");
g.attrs.erase("subgraph_property");
s->outputs = g.outputs;
}
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
const auto &node = source->outputs[0].node;
for (const auto &other_node : source->outputs) {
if (node.get() != other_node.node.get()) {
LOG(FATAL)
<< "Generating atomic symbol from other symbol only works for nongrouped symbol.";
}
}
const auto *op = node->op();
const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
*s = nnvm::Symbol::CreateFunctor(op, attrs);
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
nnvm::Symbol* out_sym = new nnvm::Symbol;
API_BEGIN();
nnvm::Symbol* src_sym = static_cast<nnvm::Symbol*>(src);
*out_sym = *src_sym;
*out = out_sym;
API_END_HANDLE_ERROR(delete out_sym);
}
int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
const int dev_type,
SymbolHandle* ret_sym_handle,
const mx_uint args_len,
NDArrayHandle* in_args_handle,
const mx_uint aux_len,
NDArrayHandle* in_aux_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
int* new_aux_cnt,
NDArrayHandle** new_aux_handle,
char*** new_aux_names_handle) {
// create copy of input symbol
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
const auto& indexed_graph = g.indexed_graph();
const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
size_t num_forward_inputs = input_names.size();
NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);
// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (uint32_t i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (uint32_t i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (uint32_t i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
arg_dtypes[i] = it_type->second;
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
arg_stypes[i] = it_type->second;
}
}
args_top++;
}
}
g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
}
// set args/aux as attributes on graph so that subgraph property can use them
std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
} else {
// args/aux were not specified, so set nullptr/empty-lists
NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> arg_names;
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> aux_names;
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
}
// create a data structure from pointer array
std::unordered_map<std::string, std::string> options_map;
for (mx_uint i = 0; i < num_options; ++i)
options_map.emplace(keys[i], vals[i]);
// set dedup option as attribute on graph to enable dedup during partitioning
if (options_map.count("dedup_subgraph") > 0 &&
options_map.at("dedup_subgraph").compare("True") == 0)
g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
// use subgraph backend
const auto backend = mxnet::op::SubgraphBackendRegistry
::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
property->PrePartition(g, options_map);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
g.attrs.erase("subgraph_property");
property->PostPartition(g);
}
} else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
// use graph pass
g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
g = ApplyPass(std::move(g), backend_name);
std::vector<NDArray*> new_args = g.GetAttr<std::vector<NDArray*>>("new_args");
std::vector<NDArray*> new_aux = g.GetAttr<std::vector<NDArray*>>("new_aux");
std::vector<std::string> new_arg_names = g.GetAttr<std::vector<std::string>>("new_arg_names");
std::vector<std::string> new_aux_names = g.GetAttr<std::vector<std::string>>("new_aux_names");
g.attrs.erase("new_args");
g.attrs.erase("new_aux");
g.attrs.erase("new_arg_names");
g.attrs.erase("new_aux_names");
NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
char** new_arg_cstr = new char*[new_arg_names.size()];
char** new_aux_cstr = new char*[new_aux_names.size()];
for (unsigned i = 0; i < new_arg_names.size(); i++) {
new_arg_arr[i] = new_args[i];
std::string& s = new_arg_names[i];
char* tmp = new char[s.length()+1];
s.copy(tmp, s.length());
tmp[s.length()] = '\0';
new_arg_cstr[i] = tmp;
}
for (unsigned i = 0; i < new_aux_names.size(); i++) {
new_aux_arr[i] = new_aux[i];
std::string& s = new_aux_names[i];
char* tmp = new char[s.length()+1];
s.copy(tmp, s.length());
tmp[s.length()] = '\0';
new_aux_cstr[i] = tmp;
}
*new_args_cnt = new_arg_names.size();
*new_aux_cnt = new_aux_names.size();
*new_arg_names_handle = new_arg_cstr;
*new_aux_names_handle = new_aux_cstr;
*new_args_ptr = new_arg_arr;
*new_aux_ptr = new_aux_arr;
} else {
// no graph pass or subgraph backend is registered under this name
LOG(ERROR) << "Cannot optimize for backend '" << backend_name
<< "': no graph pass or subgraph backend registered under that name";
}
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}