blob: bea6437b4c64500e4cad664ee512fa33ee2272b4 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file c_api.cc
* \brief C API of mxnet
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/recordio.h>
#include <dmlc/omp.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/io.h>
#include <mxnet/c_api.h>
#include <mxnet/kvstore.h>
#include <mxnet/mxrtc.h>
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <utility>
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../engine/profiler.h"
using namespace mxnet;
/*!
 * \brief Internal helper that exposes the metadata stored in a registry entry.
 *
 * Used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo; in this
 * file it is also called by MXDataIterGetIterInfo.
 *
 * All argument strings are packed into one buffer (ret_vec_charp of the entry
 * returned by MXAPIThreadLocalStore::Get()) laid out as
 * [names..., type infos..., descriptions...]; the three output arrays are
 * pointers into that buffer, so they are overwritten by any later call that
 * reuses ret_vec_charp.
 *
 * \param e registry entry to describe
 * \param name receives e->name
 * \param description receives e->description
 * \param num_args receives the number of arguments of the entry
 * \param arg_names receives the array of argument names
 * \param arg_type_infos receives the array of argument type strings
 * \param arg_descriptions receives the array of argument descriptions
 * \param return_type receives e->return_type; skipped when the pointer is null
 */
template<typename FunRegType>
inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
                                   const char **name,
                                   const char **description,
                                   mx_uint *num_args,
                                   const char ***arg_names,
                                   const char ***arg_type_infos,
                                   const char ***arg_descriptions,
                                   const char **return_type) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  *name = e->name.c_str();
  *description = e->description.c_str();
  *num_args = static_cast<mx_uint>(e->arguments.size());
  if (return_type != nullptr) {
    *return_type = e->return_type.c_str();
  }

  // Fill the shared buffer section by section: names, then types, then docs.
  auto &packed = ret->ret_vec_charp;
  packed.clear();
  for (const auto &arg : e->arguments) {
    packed.push_back(arg.name.c_str());
  }
  for (const auto &arg : e->arguments) {
    packed.push_back(arg.type_info_str.c_str());
  }
  for (const auto &arg : e->arguments) {
    packed.push_back(arg.description.c_str());
  }

  // Each output points at the start of its own section of the buffer.
  const size_t section = e->arguments.size();
  *arg_names = dmlc::BeginPtr(packed);
  *arg_type_infos = dmlc::BeginPtr(packed) + section;
  *arg_descriptions = dmlc::BeginPtr(packed) + (section * 2);
  API_END();
}
// NOTE: return value is added in API_END
/*! \brief Forward \p seed to mxnet::RandomSeed. */
int MXRandomSeed(int seed) {
  API_BEGIN();
  mxnet::RandomSeed(seed);
  API_END();
}
/*! \brief Call NotifyShutdown() on the engine returned by Engine::Get(). */
int MXNotifyShutdown() {
  API_BEGIN();
  auto *engine = Engine::Get();
  engine->NotifyShutdown();
  API_END();
}
/*!
 * \brief Configure the profiler with a mode and an output file name.
 * \param mode profiler mode; kOnlySymbolic: 0, kAllOperator: 1
 * \param filename passed to the profiler as a std::string
 *
 * Fails with LOG(FATAL) when built without MXNET_USE_PROFILER.
 */
int MXSetProfilerConfig(int mode, const char* filename) {
  API_BEGIN();
#if MXNET_USE_PROFILER
  const auto profiler_mode = engine::Profiler::ProfilerMode(mode);
  engine::Profiler::Get()->SetConfig(profiler_mode, std::string(filename));
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}
/*!
 * \brief Dump the collected profile through the profiler singleton.
 *
 * Requires that the profiler reports IsEnableOutput(); otherwise the CHECK
 * fails. Fails with LOG(FATAL) when built without MXNET_USE_PROFILER.
 */
int MXDumpProfile() {
  API_BEGIN();
#if MXNET_USE_PROFILER
  engine::Profiler *profiler = engine::Profiler::Get();
  CHECK(profiler->IsEnableOutput())
    << "Profiler haven't been run. Config and start profiler first";
  // Reuse the instance that was just checked instead of fetching it again.
  profiler->DumpProfile();
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}
/*!
 * \brief Set the running state of the profiler.
 * \param state profiler state; kNotRunning: 0, kRunning: 1
 *
 * Fails with LOG(FATAL) when built without MXNET_USE_PROFILER.
 */
int MXSetProfilerState(int state) {
  API_BEGIN();
#if MXNET_USE_PROFILER
  const auto profiler_state = engine::Profiler::ProfilerState(state);
  engine::Profiler::Get()->SetState(profiler_state);
#else
  LOG(FATAL) << "Need to compile with USE_PROFILER=1 for MXNet Profiler";
#endif
  API_END();
}
/*! \brief Forward \p thread_num to omp_set_num_threads. */
int MXSetNumOMPThreads(int thread_num) {
  API_BEGIN();
  omp_set_num_threads(thread_num);
  API_END();
}
/*!
 * \brief Allocate a default-constructed NDArray on the heap.
 * \param out receives the new handle
 */
int MXNDArrayCreateNone(NDArrayHandle *out) {
  API_BEGIN();
  NDArray *empty = new NDArray();
  *out = empty;
  API_END();
}
int MXNDArrayCreate(const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0);
API_END();
}
int MXNDArrayCreateEx(const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0,
dtype);
API_END();
}
/*!
 * \brief Deserialize an NDArray from an in-memory buffer.
 * \param buf start of the serialized bytes
 * \param size number of bytes in \p buf
 * \param out receives the new handle on success
 *
 * Throws dmlc::Error inside the API scope when NDArray::Load reports failure;
 * the partially built array is deleted by the error handler.
 */
int MXNDArrayLoadFromRawBytes(const void *buf,
                              size_t size,
                              NDArrayHandle *out) {
  NDArray *ptr = nullptr;
  API_BEGIN();
  dmlc::MemoryFixedSizeStream strm(const_cast<void*>(buf), size);
  ptr = new NDArray();
  const bool loaded = ptr->Load(&strm);
  if (!loaded) {
    throw dmlc::Error("Invalid NDArray serialization format");
  }
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}
/*!
 * \brief Serialize an NDArray into a byte buffer.
 * \param handle the NDArray to save
 * \param out_size receives the length of the serialized data
 * \param out_buf receives a pointer to the serialized data
 *
 * The bytes are written into ret_str of the entry returned by
 * MXAPIThreadLocalStore::Get(), so the result is overwritten by any later
 * call that reuses ret_str.
 */
int MXNDArraySaveRawBytes(NDArrayHandle handle,
                          size_t *out_size,
                          const char **out_buf) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  ret->ret_str.resize(0);
  dmlc::MemoryStringStream strm(&ret->ret_str);
  arr->Save(&strm);
  *out_buf = ret->ret_str.c_str();
  *out_size = ret->ret_str.length();
  API_END();
}
/*!
 * \brief Call SyncCopyFromCPU on the NDArray behind \p handle.
 * \param handle destination NDArray
 * \param data source pointer, forwarded unchanged
 * \param size size argument, forwarded unchanged
 */
int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
                             const void *data,
                             size_t size) {
  API_BEGIN();
  NDArray *dst = static_cast<NDArray*>(handle);
  dst->SyncCopyFromCPU(data, size);
  API_END();
}
/*!
 * \brief Call SyncCopyToCPU on the NDArray behind \p handle.
 * \param handle source NDArray
 * \param data destination pointer, forwarded unchanged
 * \param size size argument, forwarded unchanged
 */
int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
                           void *data,
                           size_t size) {
  API_BEGIN();
  NDArray *src = static_cast<NDArray*>(handle);
  src->SyncCopyToCPU(data, size);
  API_END();
}
/*! \brief Call WaitToRead() on the NDArray behind \p handle. */
int MXNDArrayWaitToRead(NDArrayHandle handle) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  arr->WaitToRead();
  API_END();
}
/*! \brief Call WaitToWrite() on the NDArray behind \p handle. */
int MXNDArrayWaitToWrite(NDArrayHandle handle) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  arr->WaitToWrite();
  API_END();
}
/*! \brief Call WaitForAll() on the engine returned by Engine::Get(). */
int MXNDArrayWaitAll() {
  API_BEGIN();
  auto *engine = Engine::Get();
  engine->WaitForAll();
  API_END();
}
int MXNDArraySave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys) {
API_BEGIN();
std::vector<NDArray> data(num_args);
std::vector<std::string> names;
for (mx_uint i = 0; i < num_args; ++i) {
data[i] = *static_cast<NDArray*>(args[i]);
}
if (keys != nullptr) {
names.resize(num_args);
for (mx_uint i = 0; i < num_args; ++i) {
names[i] = keys[i];
}
}
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
mxnet::NDArray::Save(fo.get(), data, names);
}
API_END();
}
int MXNDArrayLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
API_BEGIN();
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
ret->ret_handles.resize(data.size());
for (size_t i = 0; i < data.size(); ++i) {
NDArray *ptr = new NDArray();
*ptr = data[i];
ret->ret_handles[i] = ptr;
}
ret->ret_vec_charp.resize(names.size());
for (size_t i = 0; i < names.size(); ++i) {
ret->ret_vec_charp[i] = names[i].c_str();
}
*out_size = static_cast<mx_uint>(data.size());
*out_arr = dmlc::BeginPtr(ret->ret_handles);
*out_name_size = static_cast<mx_uint>(names.size());
*out_names = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
/*! \brief Delete the NDArray object behind \p handle. */
int MXNDArrayFree(NDArrayHandle handle) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  delete arr;
  API_END();
}
/*!
 * \brief Store the result of NDArray::Slice in a new heap-allocated NDArray.
 * \param handle the NDArray to slice
 * \param slice_begin first argument forwarded to Slice
 * \param slice_end second argument forwarded to Slice
 * \param out receives the new handle; the new array is deleted on error
 */
int MXNDArraySlice(NDArrayHandle handle,
                   mx_uint slice_begin,
                   mx_uint slice_end,
                   NDArrayHandle *out) {
  NDArray *ptr = new NDArray();
  API_BEGIN();
  NDArray *src = static_cast<NDArray*>(handle);
  *ptr = src->Slice(slice_begin, slice_end);
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}
/*!
 * \brief Store the result of NDArray::At in a new heap-allocated NDArray.
 * \param handle the NDArray to index
 * \param idx index forwarded to At
 * \param out receives the new handle; the new array is deleted on error
 */
int MXNDArrayAt(NDArrayHandle handle,
                mx_uint idx,
                NDArrayHandle *out) {
  NDArray *ptr = new NDArray();
  API_BEGIN();
  NDArray *src = static_cast<NDArray*>(handle);
  *ptr = src->At(idx);
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}
/*!
 * \brief Reshape an NDArray and return the result as a new handle.
 *
 * Special values in \p dims:
 *  - -1: this dimension is inferred as (total element count) / (product of
 *        the other dimensions); at most one -1 is allowed.
 *  -  0: this dimension is copied from the same position of the source shape.
 *
 * \param handle the NDArray to reshape
 * \param ndim number of entries in \p dims
 * \param dims requested dimensions
 * \param out receives the new handle; the new array is deleted on error
 */
MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
                               int ndim,
                               int *dims,
                               NDArrayHandle *out) {
  NDArray *ptr = new NDArray();
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  TShape new_shape(dims, dims+ndim);
  // Product of every explicitly known dimension. Accumulated as size_t so
  // that large shapes do not overflow a signed int (undefined behaviour).
  size_t size = 1;
  // Index of the dimension to infer, or -1 when none was requested.
  int pos = -1;
  for (int i = 0; i < ndim; ++i) {
    int dim = dims[i];
    if (dim == -1) {
      CHECK_EQ(pos, -1)
        << "Invalid new shape " << new_shape
        << ": more than one dimensions are -1";
      pos = i;
    } else {
      if (dim == 0) {
        CHECK_LT(i, arr->shape().ndim())
          << "Invalid new shape " << new_shape
          << ": 0 dimension exceeds original shape " << arr->shape();
        dim = arr->shape()[i];
      }
      size *= dim;
      new_shape[i] = dim;
    }
  }
  if (pos >= 0) {
    // Report a proper error instead of dividing by zero.
    CHECK(size != 0)
      << "Invalid new shape " << new_shape
      << ": cannot infer the -1 dimension because the remaining dimensions multiply to 0";
    new_shape[pos] = arr->shape().Size() / size;
  }
  *ptr = arr->Reshape(new_shape);
  *out = ptr;
  API_END_HANDLE_ERROR(delete ptr);
}
/*!
 * \brief Query the shape of an NDArray.
 * \param handle the NDArray to inspect
 * \param out_dim receives the number of dimensions (0 for a none array)
 * \param out_pdata receives a pointer to the dimension sizes, or nullptr for
 *        a none array
 *
 * The sizes are stored in arg_shape_buffer of the entry returned by
 * MXAPIThreadLocalStore::Get(), so the pointer is overwritten by any later
 * call that reuses that buffer.
 */
int MXNDArrayGetShape(NDArrayHandle handle,
                      mx_uint *out_dim,
                      const mx_uint **out_pdata) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (!arr->is_none()) {
    const TShape &s = arr->shape();
    *out_dim = s.ndim();
    std::vector<uint32_t>& buffer = ret->arg_shape_buffer;
    buffer.resize(s.ndim());
    nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
    *out_pdata = buffer.data();
  } else {
    *out_dim = 0;
    // Never leave the output pointer uninitialised; this mirrors what
    // MXNDArrayGetData does for a none array.
    *out_pdata = nullptr;
  }
  API_END();
}
/*!
 * \brief Get the raw data pointer of a CPU NDArray.
 * \param handle the NDArray to inspect
 * \param out_pdata receives the data pointer, or nullptr for a none array
 *
 * The CHECKs fail when the array is not on a CPU device or when its blob is
 * not contiguous.
 */
int MXNDArrayGetData(NDArrayHandle handle,
                     void **out_pdata) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (arr->is_none()) {
    *out_pdata = nullptr;
  } else {
    CHECK(arr->ctx().dev_mask() == cpu::kDevMask)
      << "MXNDArrayGetData can only be called for NDArray on CPU";
    const TBlob &blob = arr->data();
    CHECK(blob.CheckContiguous());
    // Dispatch on the runtime dtype to pick the typed view of the blob.
    MSHADOW_REAL_TYPE_SWITCH(arr->dtype(), DType, {
      *out_pdata = blob.FlatTo2D<cpu, DType>().dptr_;
    });
  }
  API_END();
}
/*!
 * \brief Query the data type flag of an NDArray.
 * \param handle the NDArray to inspect
 * \param out_dtype receives arr->dtype(), or -1 for a none array
 */
int MXNDArrayGetDType(NDArrayHandle handle,
                      int *out_dtype) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (arr->is_none()) {
    *out_dtype = -1;
  } else {
    *out_dtype = arr->dtype();
  }
  API_END();
}
/*!
 * \brief Query the device context of an NDArray.
 * \param handle the NDArray to inspect
 * \param out_dev_type receives ctx.dev_type, or 0 for a none array
 * \param out_dev_id receives ctx.dev_id, or 0 for a none array
 */
int MXNDArrayGetContext(NDArrayHandle handle,
                        int *out_dev_type,
                        int *out_dev_id) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  if (arr->is_none()) {
    *out_dev_type = 0;
    *out_dev_id = 0;
  } else {
    const Context &device = arr->ctx();
    *out_dev_type = device.dev_type;
    *out_dev_id = device.dev_id;
  }
  API_END();
}
/*!
 * \brief Store the result of NDArray::Detach in a new heap-allocated NDArray.
 * \param handle the source NDArray
 * \param out receives the new handle
 */
int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) {
  API_BEGIN();
  NDArray *src = static_cast<NDArray*>(handle);
  NDArray *detached = new NDArray(src->Detach());
  *out = detached;
  API_END();
}
/*!
 * \brief Set the fresh_out_grad flag of an NDArray.
 * \param handle the NDArray to modify
 * \param state converted to bool (non-zero means true)
 */
int MXNDArraySetGradState(NDArrayHandle handle, int state) {
  API_BEGIN();
  const bool fresh = static_cast<bool>(state);
  static_cast<NDArray*>(handle)->set_fresh_out_grad(fresh);
  API_END();
}
/*!
 * \brief Read the fresh_out_grad flag of an NDArray.
 * \param handle the NDArray to inspect
 * \param out receives the value returned by fresh_out_grad()
 */
int MXNDArrayGetGradState(NDArrayHandle handle, int *out) {
  API_BEGIN();
  NDArray *src = static_cast<NDArray*>(handle);
  *out = src->fresh_out_grad();
  API_END();
}
/*!
 * \brief List every entry of the NDArrayFunctionReg registry.
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's own list of entries
 */
int MXListFunctions(mx_uint *out_size,
                    FunctionHandle **out_array) {
  API_BEGIN();
  auto &entries = dmlc::Registry<NDArrayFunctionReg>::List();
  *out_array = (FunctionHandle*)(dmlc::BeginPtr(entries));  // NOLINT(*)
  *out_size = static_cast<mx_uint>(entries.size());
  API_END();
}
/*!
 * \brief Look up an entry of the NDArrayFunctionReg registry by name.
 * \param name name passed to Registry::Find
 * \param out receives whatever Registry::Find returns
 */
int MXGetFunction(const char *name,
                  FunctionHandle *out) {
  API_BEGIN();
  const auto *entry = dmlc::Registry<NDArrayFunctionReg>::Find(name);
  *out = entry;
  API_END();
}
/*!
 * \brief Describe an NDArray function registry entry.
 *
 * Casts \p fun to NDArrayFunctionReg and delegates to
 * MXAPIGetFunctionRegInfo, which fills every output parameter.
 */
int MXFuncGetInfo(FunctionHandle fun,
                  const char **name,
                  const char **description,
                  mx_uint *num_args,
                  const char ***arg_names,
                  const char ***arg_type_infos,
                  const char ***arg_descriptions,
                  const char **return_type) {
  const NDArrayFunctionReg *entry = static_cast<const NDArrayFunctionReg *>(fun);
  return MXAPIGetFunctionRegInfo(entry,
                                 name, description, num_args,
                                 arg_names, arg_type_infos, arg_descriptions,
                                 return_type);
}
/*!
 * \brief Copy the calling-convention fields of an NDArray function entry.
 * \param fun the function handle
 * \param num_use_vars receives the entry's num_use_vars
 * \param num_scalars receives the entry's num_scalars
 * \param num_mutate_vars receives the entry's num_mutate_vars
 * \param type_mask receives the entry's type_mask
 */
int MXFuncDescribe(FunctionHandle fun,
                   mx_uint *num_use_vars,
                   mx_uint *num_scalars,
                   mx_uint *num_mutate_vars,
                   int *type_mask) {
  API_BEGIN();
  const NDArrayFunctionReg *entry = static_cast<const NDArrayFunctionReg*>(fun);
  *num_use_vars = entry->num_use_vars;
  *num_scalars = entry->num_scalars;
  *num_mutate_vars = entry->num_mutate_vars;
  *type_mask = entry->type_mask;
  API_END();
}
/*!
 * \brief Invoke an NDArray function without keyword parameters.
 * \param fun the function handle
 * \param use_vars handles passed to the body as NDArray**
 * \param scalar_args scalar arguments, forwarded unchanged
 * \param mutate_vars handles passed to the body as NDArray**
 *
 * The body is called with a parameter count of 0 and null key/value arrays.
 */
int MXFuncInvoke(FunctionHandle fun,
                 NDArrayHandle *use_vars,
                 mx_float *scalar_args,
                 NDArrayHandle *mutate_vars) {
  API_BEGIN();
  const NDArrayFunctionReg *entry = static_cast<const NDArrayFunctionReg*>(fun);
  NDArray **inputs = (NDArray**)(use_vars);  // NOLINT(*)
  NDArray **outputs = (NDArray**)(mutate_vars);  // NOLINT(*)
  entry->body(inputs, scalar_args, outputs, 0, NULL, NULL);
  API_END();
}
/*!
 * \brief Invoke an NDArray function with keyword parameters.
 * \param fun the function handle
 * \param use_vars handles passed to the body as NDArray**
 * \param scalar_args scalar arguments, forwarded unchanged
 * \param mutate_vars handles passed to the body as NDArray**
 * \param num_params number of keyword parameters
 * \param param_keys keyword parameter names, forwarded unchanged
 * \param param_vals keyword parameter values, forwarded unchanged
 */
int MXFuncInvokeEx(FunctionHandle fun,
                   NDArrayHandle *use_vars,
                   mx_float *scalar_args,
                   NDArrayHandle *mutate_vars,
                   int num_params,
                   char **param_keys,
                   char **param_vals) {
  API_BEGIN();
  const NDArrayFunctionReg *entry = static_cast<const NDArrayFunctionReg*>(fun);
  NDArray **inputs = (NDArray**)(use_vars);  // NOLINT(*)
  NDArray **outputs = (NDArray**)(mutate_vars);  // NOLINT(*)
  entry->body(inputs, scalar_args, outputs, num_params, param_keys, param_vals);
  API_END();
}
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
/*!
 * \brief List every entry of the DataIteratorReg registry.
 * \param out_size receives the number of entries
 * \param out_array receives a pointer to the registry's own list of entries
 */
int MXListDataIters(mx_uint *out_size,
                    DataIterCreator **out_array) {
  API_BEGIN();
  auto &entries = dmlc::Registry<DataIteratorReg>::List();
  *out_array = (DataIterCreator*)(dmlc::BeginPtr(entries));  // NOLINT(*)
  *out_size = static_cast<mx_uint>(entries.size());
  API_END();
}
/*!
 * \brief Describe a data iterator registry entry.
 *
 * Casts \p creator to DataIteratorReg and delegates to
 * MXAPIGetFunctionRegInfo with a null return_type, so no return type is
 * reported.
 */
int MXDataIterGetIterInfo(DataIterCreator creator,
                          const char **name,
                          const char **description,
                          mx_uint *num_args,
                          const char ***arg_names,
                          const char ***arg_type_infos,
                          const char ***arg_descriptions) {
  DataIteratorReg *entry = static_cast<DataIteratorReg *>(creator);
  return MXAPIGetFunctionRegInfo(entry,
                                 name, description, num_args,
                                 arg_names, arg_type_infos, arg_descriptions,
                                 NULL);
}
/*!
 * \brief Create and initialise a data iterator from a registry entry.
 * \param creator the DataIteratorReg entry whose body() builds the iterator
 * \param num_param number of key/value pairs
 * \param keys parameter names
 * \param vals parameter values
 * \param out receives the iterator handle; the iterator is deleted on error
 */
int MXDataIterCreateIter(DataIterCreator creator,
                         mx_uint num_param,
                         const char **keys,
                         const char **vals,
                         DataIterHandle *out) {
  IIterator<DataBatch> *iter = nullptr;
  API_BEGIN();
  DataIteratorReg *entry = static_cast<DataIteratorReg *>(creator);
  iter = entry->body();
  // Repack the parallel C arrays into the key/value list that Init expects.
  std::vector<std::pair<std::string, std::string> > params;
  for (mx_uint idx = 0; idx < num_param; ++idx) {
    params.push_back(std::make_pair(std::string(keys[idx]), std::string(vals[idx])));
  }
  iter->Init(params);
  *out = iter;
  API_END_HANDLE_ERROR(delete iter);
}
/*! \brief Delete the data iterator behind \p handle. */
int MXDataIterFree(DataIterHandle handle) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch> *>(handle);
  delete iter;
  API_END();
}
/*! \brief Call BeforeFirst() on the data iterator behind \p handle. */
int MXDataIterBeforeFirst(DataIterHandle handle) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch>* >(handle);
  iter->BeforeFirst();
  API_END();
}
/*!
 * \brief Call Next() on the data iterator behind \p handle.
 * \param handle the iterator
 * \param out receives the value returned by Next()
 */
int MXDataIterNext(DataIterHandle handle, int *out) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch>* >(handle);
  *out = iter->Next();
  API_END();
}
/*!
 * \brief Get the label of the current batch as a new NDArray handle.
 * \param handle the iterator whose current Value() is read
 * \param out receives a newly allocated NDArray referring to db.data[1]
 *
 * When the label has more than one dimension and its second dimension is 1,
 * it is reshaped to a 1-D array of length shape[0]; otherwise it is returned
 * as is. The new array is deleted if an error is raised before it is handed
 * to the caller.
 */
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
  // Declared outside the API scope so the error handler can release it.
  NDArray *pndarray = nullptr;
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  pndarray = new NDArray();
  // temp hack to make label 1D
  // TODO(tianjun) make label 1D when label_width=0
  TShape shape = db.data[1].shape();
  // Only look at shape[1] when it exists; a 1-D label has no second
  // dimension and reading it would be out of bounds.
  if (shape.ndim() > 1 && shape[1] == 1) {
    *pndarray = db.data[1].Reshape(mshadow::Shape1(shape[0]));
  } else {
    *pndarray = db.data[1];
  }
  *out = pndarray;
  API_END_HANDLE_ERROR(delete pndarray);
}
/*!
 * \brief Expose the index array of the current batch.
 * \param handle the iterator whose current Value() is read
 * \param out_index receives a pointer to the batch's own index storage
 * \param out_size receives the number of indices
 */
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch>* >(handle);
  const DataBatch& batch = iter->Value();
  *out_size = batch.index.size();
  // The batch is const here but the C API exposes a mutable pointer.
  *out_index = const_cast<uint64_t*>(batch.index.data());
  API_END();
}
/*!
 * \brief Get the data of the current batch as a new NDArray handle.
 * \param handle the iterator whose current Value() is read
 * \param out receives a newly allocated NDArray assigned from db.data[0]
 */
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch>* >(handle);
  const DataBatch& batch = iter->Value();
  NDArray* result = new NDArray();
  *result = batch.data[0];
  *out = result;
  API_END();
}
/*!
 * \brief Read num_batch_padd of the current batch.
 * \param handle the iterator whose current Value() is read
 * \param pad receives the batch's num_batch_padd field
 */
int MXDataIterGetPadNum(DataIterHandle handle, int *pad) {
  API_BEGIN();
  IIterator<DataBatch> *iter = static_cast<IIterator<DataBatch>* >(handle);
  const DataBatch& batch = iter->Value();
  *pad = batch.num_batch_padd;
  API_END();
}
/*!
 * \brief Create a KVStore through KVStore::Create.
 * \param type type string forwarded to KVStore::Create
 * \param out receives the new handle
 */
int MXKVStoreCreate(const char *type,
                    KVStoreHandle *out) {
  API_BEGIN();
  auto *store = KVStore::Create(type);
  *out = store;
  API_END();
}
/*! \brief Delete the KVStore behind \p handle. */
int MXKVStoreFree(KVStoreHandle handle) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  delete store;
  API_END();
}
/*!
 * \brief Call KVStore::Init with integer keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys integer keys
 * \param vals NDArray handles; each array is copied into the value list
 */
int MXKVStoreInit(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<int> key_list(num);
  std::vector<NDArray> value_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    value_list[idx] = *static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Init(key_list, value_list);
  API_END();
}
/*!
 * \brief Call KVStore::Init with string keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys string keys
 * \param vals NDArray handles; each array is copied into the value list
 */
int MXKVStoreInitEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals) {
  API_BEGIN();
  std::vector<std::string> key_list(num);
  std::vector<NDArray> value_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    value_list[idx] = *static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Init(key_list, value_list);
  API_END();
}
/*!
 * \brief Call KVStore::Push with integer keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys integer keys
 * \param vals NDArray handles; each array is copied into the value list
 * \param priority forwarded unchanged to Push
 */
int MXKVStorePush(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> key_list(num);
  std::vector<NDArray> value_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    value_list[idx] = *static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Push(key_list, value_list, priority);
  API_END();
}
/*!
 * \brief Call KVStore::Push with string keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys string keys
 * \param vals NDArray handles; each array is copied into the value list
 * \param priority forwarded unchanged to Push
 */
int MXKVStorePushEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> key_list(num);
  std::vector<NDArray> value_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    value_list[idx] = *static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Push(key_list, value_list, priority);
  API_END();
}
/*!
 * \brief Call KVStore::Pull with integer keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys integer keys
 * \param vals NDArray handles; passed to Pull as pointers, not copies
 * \param priority forwarded unchanged to Pull
 */
int MXKVStorePull(KVStoreHandle handle,
                  mx_uint num,
                  const int* keys,
                  NDArrayHandle* vals,
                  int priority) {
  API_BEGIN();
  std::vector<int> key_list(num);
  std::vector<NDArray*> target_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    target_list[idx] = static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Pull(key_list, target_list, priority);
  API_END();
}
/*!
 * \brief Call KVStore::Pull with string keys.
 * \param handle the KVStore
 * \param num number of key/value pairs
 * \param keys string keys
 * \param vals NDArray handles; passed to Pull as pointers, not copies
 * \param priority forwarded unchanged to Pull
 */
int MXKVStorePullEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> key_list(num);
  std::vector<NDArray*> target_list(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    key_list[idx] = keys[idx];
    target_list[idx] = static_cast<NDArray*>(vals[idx]);
  }
  KVStore *store = static_cast<KVStore*>(handle);
  store->Pull(key_list, target_list, priority);
  API_END();
}
/*!
 * \brief Install a C callback as the updater of a KVStore.
 * \param handle the KVStore
 * \param updater callback invoked as updater(key, recv, local, updater_handle)
 * \param updater_handle opaque pointer passed back to the callback
 *
 * For every update the wrapper allocates two new NDArray objects, assigned
 * from the received and the local array, and passes them to the callback.
 * NOTE(review): the wrapper never deletes these two objects; presumably the
 * callback takes ownership of them - confirm against the callers.
 */
int MXKVStoreSetUpdater(KVStoreHandle handle,
                        MXKVStoreUpdater updater,
                        void* updater_handle) {
  API_BEGIN();
  MXKVStoreUpdater *callback = updater;
  void *callback_arg = updater_handle;
  std::function<void(int, const NDArray&, NDArray*)> wrapper =
      [callback, callback_arg](int key, const NDArray& recv, NDArray* local) {
        NDArray *recv_handle = new NDArray();
        *recv_handle = recv;
        NDArray *local_handle = new NDArray();
        *local_handle = *local;
        callback(key, recv_handle, local_handle, callback_arg);
      };
  KVStore *store = static_cast<KVStore*>(handle);
  store->set_updater(wrapper);
  API_END();
}
/*!
 * \brief Read the rank of a KVStore.
 * \param handle the KVStore
 * \param rank receives the value returned by get_rank()
 */
int MXKVStoreGetRank(KVStoreHandle handle, int *rank) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  *rank = store->get_rank();
  API_END();
}
/*!
 * \brief Read the group size of a KVStore.
 * \param handle the KVStore
 * \param size receives the value returned by get_group_size()
 */
int MXKVStoreGetGroupSize(KVStoreHandle handle, int *size) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  *size = store->get_group_size();
  API_END();
}
/*! \brief Call Barrier() on the KVStore behind \p handle. */
int MXKVStoreBarrier(KVStoreHandle handle) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  store->Barrier();
  API_END();
}
/*!
 * \brief Forward \p barrier_before_exit to KVStore::set_barrier_before_exit.
 * \param handle the KVStore
 * \param barrier_before_exit value passed through unchanged
 */
int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
                                  const int barrier_before_exit) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  store->set_barrier_before_exit(barrier_before_exit);
  API_END();
}
/*!
 * \brief Pass a set of key/value settings to KVStore::InitPSEnv.
 * \param num_vars number of key/value pairs
 * \param keys setting names
 * \param vals setting values
 *
 * The pairs are collected into a map first, so a later duplicate key
 * overwrites an earlier one.
 */
int MXInitPSEnv(mx_uint num_vars,
                const char **keys,
                const char **vals) {
  API_BEGIN();
  std::unordered_map<std::string, std::string> settings;
  for (mx_uint idx = 0; idx < num_vars; ++idx) {
    const std::string key(keys[idx]);
    settings[key] = std::string(vals[idx]);
  }
  KVStore::InitPSEnv(settings);
  API_END();
}
/*! \brief Store the result of KVStore::IsWorkerNode() in \p ret. */
int MXKVStoreIsWorkerNode(int *ret) {
  API_BEGIN();
  const auto is_worker = KVStore::IsWorkerNode();
  *ret = is_worker;
  API_END();
}
/*! \brief Store the result of KVStore::IsServerNode() in \p ret. */
int MXKVStoreIsServerNode(int *ret) {
  API_BEGIN();
  const auto is_server = KVStore::IsServerNode();
  *ret = is_server;
  API_END();
}
/*! \brief Store the result of KVStore::IsSchedulerNode() in \p ret. */
int MXKVStoreIsSchedulerNode(int *ret) {
  API_BEGIN();
  const auto is_scheduler = KVStore::IsSchedulerNode();
  *ret = is_scheduler;
  API_END();
}
/*!
 * \brief Call KVStore::RunServer with a C callback as the controller.
 * \param handle the KVStore
 * \param controller callback invoked as controller(head, body, controller_handle)
 * \param controller_handle opaque pointer passed back to the callback
 *
 * The command body is handed to the callback as a C string.
 */
int MXKVStoreRunServer(KVStoreHandle handle,
                       MXKVStoreServerController controller,
                       void *controller_handle) {
  API_BEGIN();
  MXKVStoreServerController *callback = controller;
  void *callback_arg = controller_handle;
  auto ctrl = [callback, callback_arg](int head, const std::string& body) {
    callback(head, body.c_str(), callback_arg);
  };
  KVStore *store = static_cast<KVStore*>(handle);
  store->RunServer(ctrl);
  API_END();
}
/*!
 * \brief Call KVStore::SendCommandToServers.
 * \param handle the KVStore
 * \param cmd_id command id, forwarded unchanged
 * \param cmd_body command body, converted to std::string
 */
int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
                                   int cmd_id,
                                   const char* cmd_body) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  store->SendCommandToServers(cmd_id, std::string(cmd_body));
  API_END();
}
/*!
 * \brief Get the type string of a KVStore.
 * \param handle the KVStore
 * \param type receives the C string of store->type(); must not be null
 *        (enforced by CHECK_NOTNULL)
 */
int MXKVStoreGetType(KVStoreHandle handle,
                     const char** type) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  *CHECK_NOTNULL(type) = store->type().c_str();
  API_END();
}
/*!
 * \brief Call KVStore::get_num_dead_node.
 * \param handle the KVStore
 * \param node_id forwarded unchanged
 * \param number receives the value returned by get_num_dead_node
 * \param timeout_sec forwarded unchanged
 */
int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
                            const int node_id,
                            int *number,
                            const int timeout_sec) {
  API_BEGIN();
  KVStore *store = static_cast<KVStore*>(handle);
  *number = store->get_num_dead_node(node_id, timeout_sec);
  API_END();
}
/*!
 * \brief State behind a RecordIOHandle, shared by the writer and reader APIs.
 *
 * MXRecordIOWriterCreate sets writer and stream and leaves reader and
 * read_buff NULL; MXRecordIOReaderCreate sets reader, stream and read_buff
 * and leaves writer NULL.
 */
struct MXRecordIOContext {
// RecordIO writer; NULL for a context created by MXRecordIOReaderCreate.
dmlc::RecordIOWriter *writer;
// RecordIO reader; NULL for a context created by MXRecordIOWriterCreate.
dmlc::RecordIOReader *reader;
// Underlying stream used by the writer or reader; deleted by the Free calls.
dmlc::Stream *stream;
// Buffer holding the most recently read record; NULL for a writer context.
std::string *read_buff;
};
int MXRecordIOWriterCreate(const char *uri,
RecordIOHandle *out) {
API_BEGIN();
dmlc::Stream *stream = dmlc::Stream::Create(uri, "w");
MXRecordIOContext *context = new MXRecordIOContext;
context->writer = new dmlc::RecordIOWriter(stream);
context->reader = NULL;
context->stream = stream;
context->read_buff = NULL;
*out = reinterpret_cast<RecordIOHandle>(context);
API_END();
}
/*!
 * \brief Destroy a RecordIO writer context.
 *
 * Deletes the writer first, then its stream, then the context itself.
 */
int MXRecordIOWriterFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  delete ctx->writer;
  delete ctx->stream;
  delete ctx;
  API_END();
}
/*!
 * \brief Write one record through the context's writer.
 * \param handle the RecordIO writer handle
 * \param buf start of the record bytes
 * \param size number of bytes in \p buf
 */
int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
                                const char *buf, size_t size) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  const void *record = reinterpret_cast<const void*>(buf);
  ctx->writer->WriteRecord(record, size);
  API_END();
}
/*!
 * \brief Query the writer's position.
 * \param handle the RecordIO writer handle
 * \param pos receives the value returned by the writer's Tell()
 */
int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  *pos = ctx->writer->Tell();
  API_END();
}
int MXRecordIOReaderCreate(const char *uri,
RecordIOHandle *out) {
API_BEGIN();
dmlc::Stream *stream = dmlc::Stream::Create(uri, "r");
MXRecordIOContext *context = new MXRecordIOContext;
context->reader = new dmlc::RecordIOReader(stream);
context->writer = NULL;
context->stream = stream;
context->read_buff = new std::string();
*out = reinterpret_cast<RecordIOHandle>(context);
API_END();
}
/*!
 * \brief Destroy a RecordIO reader context.
 *
 * Deletes the reader, its stream, the read buffer, and finally the context.
 */
int MXRecordIOReaderFree(RecordIOHandle handle) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  delete ctx->reader;
  delete ctx->stream;
  delete ctx->read_buff;
  delete ctx;
  API_END();
}
/*!
 * \brief Read the next record into the context's buffer.
 * \param handle the RecordIO reader handle
 * \param buf receives a pointer into the context's read buffer, or NULL when
 *        NextRecord reports that no record was read
 * \param size receives the record length, or 0 when no record was read
 *
 * The returned pointer refers to the context's read_buff, which is reused by
 * the next read on the same handle.
 */
int MXRecordIOReaderReadRecord(RecordIOHandle handle,
                               char const **buf, size_t *size) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  const bool has_record = ctx->reader->NextRecord(ctx->read_buff);
  if (!has_record) {
    *buf = NULL;
    *size = 0;
  } else {
    *buf = ctx->read_buff->c_str();
    *size = ctx->read_buff->size();
  }
  API_END();
}
/*!
 * \brief Call Seek on the context's reader.
 * \param handle the RecordIO reader handle
 * \param pos position forwarded unchanged to Seek
 */
int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos) {
  API_BEGIN();
  MXRecordIOContext *ctx = reinterpret_cast<MXRecordIOContext*>(handle);
  ctx->reader->Seek(pos);
  API_END();
}
/*!
 * \brief Create an MXRtc object from named inputs, named outputs and a kernel.
 * \param name forwarded to the MXRtc constructor
 * \param num_input number of entries in \p input_names and \p inputs
 * \param num_output number of entries in \p output_names and \p outputs
 * \param input_names names paired with \p inputs
 * \param output_names names paired with \p outputs
 * \param inputs NDArray handles; each array is copied into the pair list
 * \param outputs NDArray handles; each array is copied into the pair list
 * \param kernel forwarded to the MXRtc constructor
 * \param out receives the new handle
 *
 * Fails with LOG(FATAL) unless built with both MXNET_USE_CUDA and
 * MXNET_USE_NVRTC.
 */
int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
                char** input_names, char** output_names,
                NDArrayHandle* inputs, NDArrayHandle* outputs,
                char* kernel, RtcHandle *out) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  typedef std::pair<std::string, NDArray> NamedArray;
  std::vector<NamedArray> input, output;
  for (mx_uint idx = 0; idx < num_input; ++idx) {
    NDArray *arr = reinterpret_cast<NDArray*>(inputs[idx]);
    input.push_back(NamedArray(input_names[idx], *arr));
  }
  for (mx_uint idx = 0; idx < num_output; ++idx) {
    NDArray *arr = reinterpret_cast<NDArray*>(outputs[idx]);
    output.push_back(NamedArray(output_names[idx], *arr));
  }
  MXRtc *rtc = new MXRtc(name, input, output, kernel);
  *out = reinterpret_cast<RtcHandle>(rtc);
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}
/*!
 * \brief Call MXRtc::push with the given arrays and launch dimensions.
 * \param handle the MXRtc handle
 * \param num_input number of entries in \p inputs
 * \param num_output number of entries in \p outputs
 * \param inputs NDArray handles; each array is copied into the input list
 * \param outputs NDArray handles; each array is copied into the output list
 * \param gridDimX,gridDimY,gridDimZ forwarded unchanged to push
 * \param blockDimX,blockDimY,blockDimZ forwarded unchanged to push
 *
 * Fails with LOG(FATAL) unless built with both MXNET_USE_CUDA and
 * MXNET_USE_NVRTC.
 */
int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
              NDArrayHandle* inputs, NDArrayHandle* outputs,
              mx_uint gridDimX,
              mx_uint gridDimY,
              mx_uint gridDimZ,
              mx_uint blockDimX,
              mx_uint blockDimY,
              mx_uint blockDimZ) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  std::vector<NDArray> input, output;
  for (mx_uint idx = 0; idx < num_input; ++idx) {
    NDArray *arr = reinterpret_cast<NDArray*>(inputs[idx]);
    input.push_back(*arr);
  }
  for (mx_uint idx = 0; idx < num_output; ++idx) {
    NDArray *arr = reinterpret_cast<NDArray*>(outputs[idx]);
    output.push_back(*arr);
  }
  MXRtc *rtc = reinterpret_cast<MXRtc*>(handle);
  rtc->push(input, output,
            gridDimX, gridDimY, gridDimZ,
            blockDimX, blockDimY, blockDimZ);
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}
/*!
 * \brief Delete the MXRtc object behind \p handle.
 *
 * Fails with LOG(FATAL) unless built with both MXNET_USE_CUDA and
 * MXNET_USE_NVRTC.
 */
int MXRtcFree(RtcHandle handle) {
  API_BEGIN();
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  MXRtc *rtc = reinterpret_cast<MXRtc*>(handle);
  delete rtc;
#else
  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
  API_END();
}
/*!
 * \brief Forward a custom operator registration to CustomOpProp::Register.
 * \param op_type operator type name
 * \param creator creator callback registered under \p op_type
 */
int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
  API_BEGIN();
  mxnet::op::CustomOpProp::Register(op_type, creator);
  API_END();
}