blob: 1dd784ba2249a62c0a2cbaf6e13b6f168b162e14 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file c_predict_api.cc
* \brief C predict API of mxnet
*/
#include <dmlc/base.h>
#include <dmlc/memory_io.h>
#include <mxnet/c_predict_api.h>
#include <mxnet/executor.h>
#include <mxnet/ndarray.h>
#include <nnvm/pass_functions.h>
#include <memory>
#include <unordered_set>
#include <unordered_map>
#include "./c_api_common.h"
#include "../operator/operator_common.h"
using namespace mxnet;
// predictor interface
struct MXAPIPredictor {
// output arrays
std::vector<NDArray> out_arrays;
// argument arrays
std::vector<NDArray> arg_arrays;
// output shapes
std::vector<TShape> out_shapes;
// uint32_t buffer for output shapes
std::vector<uint32_t> out_shapes_buffer;
// key to arguments
std::unordered_map<std::string, size_t> key2arg;
// executor
std::unique_ptr<Executor> exec;
};
struct MXAPINDList {
std::vector<std::string> keys;
std::vector<TShape> shapes;
std::vector<uint32_t> shapes_buffer;
std::vector<size_t> indptr;
std::vector<mx_float> data;
};
int MXPredCreate(const char* symbol_json_str,
const void* param_bytes,
int param_size,
int dev_type, int dev_id,
mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
PredictorHandle* out) {
return MXPredCreatePartialOut(
symbol_json_str,
param_bytes,
param_size,
dev_type,
dev_id,
num_input_nodes,
input_keys,
input_shape_indptr,
input_shape_data,
0,
NULL,
out);
}
namespace mxnet {
} // namespace mxnet
int MXPredCreatePartialOut(const char* symbol_json_str,
const void* param_bytes,
int param_size,
int dev_type, int dev_id,
mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
mx_uint num_output_nodes,
const char** output_keys,
PredictorHandle* out) {
using nnvm::Symbol;
MXAPIPredictor* ret = new MXAPIPredictor();
API_BEGIN();
Symbol sym;
// make sure symbols are registered
{
mx_uint outSize;
const char **outArray;
MXListAllOpNames(&outSize, &outArray);
}
// load in the symbol.
{
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
}
// looks likely to output the internal results
if (num_output_nodes != 0) {
Symbol internal = sym.GetInternals();
std::vector<std::string> all_out = internal.ListOutputNames();
std::vector<Symbol> out_syms(num_output_nodes);
for (mx_uint i = 0; i < num_output_nodes; ++i) {
std::string out_key(output_keys[i]);
out_key += "_output";
for (size_t j = 0; j < all_out.size(); ++j) {
if (all_out[j] == out_key) {
out_syms[i] = internal[j];
break;
}
CHECK_NE(j, all_out.size() - 1) << "didn't find node name: " << out_key;
}
}
sym = nnvm::Symbol::CreateGroup(out_syms);
}
// load the parameters
std::unordered_map<std::string, NDArray> arg_params, aux_params;
{
std::unordered_set<std::string> arg_names, aux_names;
std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
std::vector<std::string> aux_names_vec = sym.ListInputNames(Symbol::kAuxiliaryStates);
for (size_t i = 0; i < arg_names_vec.size(); ++i) {
arg_names.insert(arg_names_vec[i]);
}
for (size_t i = 0; i < aux_names_vec.size(); ++i) {
aux_names.insert(aux_names_vec[i]);
}
std::vector<NDArray> data;
std::vector<std::string> names;
dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size); // NOLINT(*)
NDArray::Load(&fi, &data, &names);
CHECK_EQ(names.size(), data.size())
<< "Invalid param file format";
for (size_t i = 0; i < names.size(); ++i) {
if (!strncmp(names[i].c_str(), "aux:", 4)) {
std::string name(names[i].c_str() + 4);
if (aux_names.count(name) != 0) {
aux_params[name] = data[i];
}
}
if (!strncmp(names[i].c_str(), "arg:", 4)) {
std::string name(names[i].c_str() + 4);
if (arg_names.count(name) != 0) {
arg_params[name] = data[i];
}
}
}
}
// shape inference and bind
std::unordered_map<std::string, TShape> known_shape;
for (mx_uint i = 0; i < num_input_nodes; ++i) {
known_shape[std::string(input_keys[i])] =
TShape(input_shape_data + input_shape_indptr[i],
input_shape_data + input_shape_indptr[i + 1]);
}
std::vector<std::string> arg_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
std::vector<std::string> aux_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
std::vector<TShape> out_shapes(sym.ListOutputNames().size());
std::vector<TShape> aux_shapes(aux_names.size());
std::vector<TShape> arg_shapes;
for (size_t i = 0; i < arg_names.size(); ++i) {
std::string key = arg_names[i];
ret->key2arg[key] = i;
}
try {
std::vector<TShape> in_shapes;
for (std::string key : sym.ListInputNames(Symbol::kAll)) {
if (known_shape.count(key) != 0) {
in_shapes.push_back(known_shape[key]);
} else {
in_shapes.push_back(TShape());
}
}
nnvm::Graph g; g.outputs = sym.outputs;
g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__");
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
CHECK(infer_complete)
<< "The shape information of is not enough to get the shapes";
CopyAttr(g.indexed_graph(),
g.GetAttr<nnvm::ShapeVector>("shape"),
&arg_shapes, &out_shapes, &aux_shapes);
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
std::vector<NDArray> arg_arrays, aux_arrays;
for (size_t i = 0; i < arg_shapes.size(); ++i) {
NDArray nd = NDArray(arg_shapes[i], ctx);
if (arg_params.count(arg_names[i]) != 0) {
CopyFromTo(arg_params[arg_names[i]], &nd);
}
arg_arrays.push_back(nd);
}
for (size_t i = 0; i < aux_shapes.size(); ++i) {
NDArray nd = NDArray(aux_shapes[i], ctx);
if (aux_params.count(aux_names[i]) != 0) {
CopyFromTo(aux_params[aux_names[i]], &nd);
}
aux_arrays.push_back(nd);
}
ret->arg_arrays = arg_arrays;
// bind
{
std::map<std::string, Context> ctx_map;
std::vector<NDArray> grad_store(arg_arrays.size());
std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
ret->exec.reset(Executor::Bind(sym, ctx, ctx_map,
arg_arrays,
grad_store, grad_req,
aux_arrays));
ret->out_shapes = out_shapes;
ret->out_arrays = ret->exec->outputs();
}
*out = ret;
API_END_HANDLE_ERROR(delete ret);
}
int MXPredGetOutputShape(PredictorHandle handle,
mx_uint out_index,
mx_uint** shape_data,
mx_uint* shape_ndim) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
CHECK_LT(out_index, p->out_arrays.size())
<< "Index exceed number of outputs";
const TShape& s = p->out_shapes[out_index];
p->out_shapes_buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data());
*shape_data = p->out_shapes_buffer.data();
*shape_ndim = p->out_shapes[out_index].ndim();
API_END();
}
int MXPredSetInput(PredictorHandle handle,
const char* key,
const mx_float* data,
mx_uint size) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
auto it = p->key2arg.find(key);
if (it == p->key2arg.end()) {
LOG(FATAL) << "cannot find input key " << key;
}
NDArray& nd = p->arg_arrays[it->second];
nd.SyncCopyFromCPU(data, size);
API_END();
}
int MXPredForward(PredictorHandle handle) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
p->exec->Forward(false);
API_END();
}
int MXPredPartialForward(PredictorHandle handle, int step, int* step_left) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
p->exec->PartialForward(false, step, step_left);
API_END();
}
int MXPredGetOutput(PredictorHandle handle,
mx_uint index,
mx_float* data,
mx_uint size) {
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
CHECK_LT(index, p->out_arrays.size())
<< "Output index out of range";
const NDArray& nd = p->out_arrays[index];
nd.SyncCopyToCPU(data, size);
API_END();
}
int MXPredFree(PredictorHandle handle) {
API_BEGIN();
delete static_cast<MXAPIPredictor*>(handle);
API_END();
}
int MXNDListCreate(const char* nd_file_bytes,
int nd_file_size,
NDListHandle *out,
mx_uint* out_length) {
MXAPINDList* ret = new MXAPINDList();
API_BEGIN();
std::vector<NDArray> arrays;
dmlc::MemoryFixedSizeStream fi((void*)nd_file_bytes, nd_file_size); // NOLINT(*)
NDArray::Load(&fi,
&(arrays),
&(ret->keys));
if (ret->keys.size() == 0) {
ret->keys.resize(arrays.size());
}
ret->indptr.push_back(0);
for (size_t i = 0; i < arrays.size(); ++i) {
TShape shape = arrays[i].shape();
size_t begin = ret->data.size();
size_t size = shape.Size();
ret->shapes.push_back(shape);
ret->data.resize(begin + size);
arrays[i].SyncCopyToCPU(dmlc::BeginPtr(ret->data) + begin, size);
ret->indptr.push_back(begin + size);
}
*out = ret;
*out_length = static_cast<mx_uint>(arrays.size());
API_END();
}
int MXNDListGet(NDListHandle handle,
mx_uint index,
const char** out_key,
const mx_float** out_data,
const mx_uint** out_shape,
mx_uint* out_ndim) {
MXAPINDList* p = static_cast<MXAPINDList*>(handle);
API_BEGIN();
CHECK_LT(index, p->shapes.size())
<< "Index out of range";
*out_key = p->keys[index].c_str();
*out_data = dmlc::BeginPtr(p->data) + p->indptr[index];
const TShape& s = p->shapes[index];
p->shapes_buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), p->shapes_buffer.data());
*out_shape = p->shapes_buffer.data();
*out_ndim = p->shapes[index].ndim();
API_END();
}
int MXNDListFree(NDListHandle handle) {
API_BEGIN();
delete static_cast<MXAPINDList*>(handle);
API_END();
}