blob: ed513c0d7785637dca93d79d844dbdcf0c869cec [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file c_api.cc
* \brief C API of mxnet
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/recordio.h>
#include <dmlc/omp.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/io.h>
#include <mxnet/c_api.h>
#include <mxnet/kvstore.h>
#include <mxnet/rtc.h>
#include <mxnet/storage.h>
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <utility>
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
using namespace mxnet;
// Internal function to get the information
// from function registry
// Used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo
template<typename FunRegType>
inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
*name = e->name.c_str();
*description = e->description.c_str();
*num_args = static_cast<mx_uint>(e->arguments.size());
if (return_type) *return_type = e->return_type.c_str();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < e->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(e->arguments[i].name.c_str());
}
for (size_t i = 0; i < e->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(e->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < e->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(e->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + e->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (e->arguments.size() * 2);
API_END();
}
// NOTE: return value is added in API_END
int MXRandomSeed(int seed) {
API_BEGIN();
mxnet::RandomSeed(seed);
API_END();
}
int MXRandomSeedContext(int seed, int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
mxnet::RandomSeed(ctx, seed);
API_END();
}
int MXNotifyShutdown() {
API_BEGIN();
Engine::Get()->NotifyShutdown();
API_END();
}
int MXSetNumOMPThreads(int thread_num) {
API_BEGIN();
omp_set_num_threads(thread_num);
API_END();
}
int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
API_BEGIN();
*prev_bulk_size = Engine::Get()->set_bulk_size(bulk_size);
API_END();
}
int MXGetGPUCount(int* out) {
API_BEGIN();
*out = Context::GetGPUCount();
API_END();
}
int MXGetVersion(int *out) {
API_BEGIN();
*out = static_cast<int>(MXNET_VERSION);
API_END();
}
int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray();
API_END();
}
int MXNDArrayCreate(const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0);
API_END();
}
int MXNDArrayCreateEx(const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0,
dtype);
API_END();
}
int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out) {
API_BEGIN();
std::vector<int> aux_types;
std::vector<TShape> aux_shapes;
auto shape_start = aux_shape;
for (size_t i = 0; i < num_aux; i++) {
// types
aux_types.push_back(aux_type[i]);
// shapes
aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]);
shape_start += aux_ndims[i];
}
*out = new NDArray(
NDArrayStorageType(storage_type),
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0,
dtype, aux_types, aux_shapes);
API_END();
}
int MXNDArrayLoadFromRawBytes(const void *buf,
size_t size,
NDArrayHandle *out) {
NDArray *ptr = nullptr;
API_BEGIN();
dmlc::MemoryFixedSizeStream strm((void*)buf, size); // NOLINT(*)
ptr = new NDArray();
if (!ptr->Load(&strm)) {
throw dmlc::Error("Invalid NDArray serialization format");
}
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}
int MXNDArraySaveRawBytes(NDArrayHandle handle,
size_t *out_size,
const char **out_buf) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_str.resize(0);
dmlc::MemoryStringStream strm(&ret->ret_str);
static_cast<NDArray*>(handle)->Save(&strm);
*out_size = ret->ret_str.length();
*out_buf = ret->ret_str.c_str();
API_END();
}
int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
const void *data,
size_t size) {
API_BEGIN();
static_cast<NDArray*>(handle)->SyncCopyFromCPU(data, size);
API_END();
}
int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
void *data,
size_t size) {
API_BEGIN();
static_cast<NDArray*>(handle)->SyncCopyToCPU(data, size);
API_END();
}
/*!
* \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0
* This function blocks. Do not use it in performance critical code.
* \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated
* \param handle_src handle of a src ndarray which has default storage type
* \param i dst data blob indicator
*/
int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
const NDArrayHandle handle_src,
const int i) {
API_BEGIN();
NDArray* dst = static_cast<NDArray*>(handle_dst);
NDArray* src = static_cast<NDArray*>(handle_src);
dst->SyncCopyFromNDArray(*src, -1, i);
API_END();
}
int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
arr->SyncCheckFormat(full_check);
API_END();
}
int MXNDArrayWaitToRead(NDArrayHandle handle) {
API_BEGIN();
static_cast<NDArray*>(handle)->WaitToRead();
API_END();
}
int MXNDArrayWaitToWrite(NDArrayHandle handle) {
API_BEGIN();
static_cast<NDArray*>(handle)->WaitToWrite();
API_END();
}
int MXNDArrayWaitAll() {
API_BEGIN();
Engine::Get()->WaitForAll();
API_END();
}
int MXNDArraySave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys) {
API_BEGIN();
std::vector<NDArray> data(num_args);
std::vector<std::string> names;
for (mx_uint i = 0; i < num_args; ++i) {
data[i] = *static_cast<NDArray*>(args[i]);
}
if (keys != nullptr) {
names.resize(num_args);
for (mx_uint i = 0; i < num_args; ++i) {
names[i] = keys[i];
}
}
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
mxnet::NDArray::Save(fo.get(), data, names);
}
API_END();
}
int MXNDArrayLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
API_BEGIN();
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
ret->ret_handles.resize(data.size());
for (size_t i = 0; i < data.size(); ++i) {
NDArray *ptr = new NDArray();
*ptr = data[i];
ret->ret_handles[i] = ptr;
}
ret->ret_vec_charp.resize(names.size());
for (size_t i = 0; i < names.size(); ++i) {
ret->ret_vec_charp[i] = names[i].c_str();
}
*out_size = static_cast<mx_uint>(data.size());
*out_arr = dmlc::BeginPtr(ret->ret_handles);
*out_name_size = static_cast<mx_uint>(names.size());
*out_names = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
size_t size,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
API_BEGIN();
CHECK_NOTNULL(ndarray_buffer);
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream(
const_cast<void*>(ndarray_buffer), size));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
ret->ret_handles.resize(data.size());
for (size_t i = 0; i < data.size(); ++i) {
NDArray *ptr = new NDArray();
*ptr = data[i];
ret->ret_handles[i] = ptr;
}
ret->ret_vec_charp.resize(names.size());
for (size_t i = 0; i < names.size(); ++i) {
ret->ret_vec_charp[i] = names[i].c_str();
}
*out_size = static_cast<mx_uint>(data.size());
*out_arr = dmlc::BeginPtr(ret->ret_handles);
*out_name_size = static_cast<mx_uint>(names.size());
*out_names = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int MXNDArrayFree(NDArrayHandle handle) {
API_BEGIN();
delete static_cast<NDArray*>(handle);
API_END();
}
int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
*ptr = static_cast<NDArray*>(handle)->SliceWithRecord(
slice_begin, slice_end);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}
int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
*ptr = static_cast<NDArray*>(handle)->AtWithRecord(idx);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}
MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
int ndim,
int *dims,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
TShape new_shape(dims, dims+ndim);
int size = 1;
int pos = -1;
for (int i = 0; i < ndim; ++i) {
int dim = dims[i];
if (dim == -1) {
CHECK_EQ(pos, -1)
<< "Invalid new shape " << new_shape
<< ": more than one dimensions are -1";
pos = i;
} else {
if (dim == 0) {
CHECK_LT(i, arr->shape().ndim())
<< "Invalid new shape " << new_shape
<< ": 0 dimension exceeds original shape " << arr->shape();
dim = arr->shape()[i];
}
size *= dim;
new_shape[i] = dim;
}
}
if (pos >= 0) {
new_shape[pos] = arr->shape().Size() / size;
}
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
bool reverse,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
<< arr->shape();
TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
}
int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
*out_storage_type = arr->storage_type();
} else {
*out_storage_type = kUndefinedStorage;
}
API_END();
}
int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const TShape &s = arr->shape();
*out_dim = s.ndim();
std::vector<uint32_t>& buffer = ret->arg_shape_buffer;
buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
*out_pdata = buffer.data();
} else {
*out_dim = 0;
}
API_END();
}
int MXNDArrayGetData(NDArrayHandle handle,
void **out_pdata) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
*out_pdata = arr->data().dptr_;
} else {
*out_pdata = nullptr;
}
API_END();
}
int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
*out_dtype = arr->dtype();
} else {
*out_dtype = -1;
}
API_END();
}
int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
*out_type = arr->aux_type(i);
API_END();
}
/*!
* \brief Get a deep copy of the ith aux data blob
* in the form of an NDArray of default storage type.
* This function blocks. Do not use it in performance critical code.
*/
int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
*out = new NDArray(arr->aux_ndarray(i));
API_END();
}
/*!
* \brief Get a deep copy of the data blob
* in the form of an NDArray of default storage type.
* This function blocks. Do not use it in performance critical code.
*/
int MXNDArrayGetDataNDArray(NDArrayHandle handle,
NDArrayHandle *out) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
*out = new NDArray(arr->data_ndarray());
API_END();
}
int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_type,
int *out_dev_id) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const Context &ctx = arr->ctx();
*out_dev_type = ctx.dev_type;
*out_dev_id = ctx.dev_id;
} else {
*out_dev_type = 0;
*out_dev_id = 0;
}
API_END();
}
int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
NDArray ret = arr->grad();
if (ret.is_none()) {
*out = NULL;
} else {
*out = new NDArray(ret);
}
API_END();
}
int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
*out = new NDArray(arr->Detach());
API_END();
}
int MXNDArraySetGradState(NDArrayHandle handle, int state) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
arr->set_fresh_out_grad(static_cast<bool>(state));
API_END();
}
int MXNDArrayGetGradState(NDArrayHandle handle, int *out) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
*out = arr->fresh_out_grad();
API_END();
}
int MXListFunctions(mx_uint *out_size,
FunctionHandle **out_array) {
API_BEGIN();
auto &vec = dmlc::Registry<NDArrayFunctionReg>::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}
int MXGetFunction(const char *name,
FunctionHandle *out) {
API_BEGIN();
*out = dmlc::Registry<NDArrayFunctionReg>::Find(name);
API_END();
}
int MXFuncGetInfo(FunctionHandle fun,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
return MXAPIGetFunctionRegInfo(static_cast<const NDArrayFunctionReg *>(fun),
name, description, num_args,
arg_names, arg_type_infos, arg_descriptions,
return_type);
}
int MXFuncDescribe(FunctionHandle fun,
mx_uint *num_use_vars,
mx_uint *num_scalars,
mx_uint *num_mutate_vars,
int *type_mask) {
API_BEGIN();
auto *f = static_cast<const NDArrayFunctionReg*>(fun);
*num_use_vars = f->num_use_vars;
*num_scalars = f->num_scalars;
*num_mutate_vars = f->num_mutate_vars;
*type_mask = f->type_mask;
API_END();
}
int MXFuncInvoke(FunctionHandle fun,
NDArrayHandle *use_vars,
mx_float *scalar_args,
NDArrayHandle *mutate_vars) {
API_BEGIN();
auto *f = static_cast<const NDArrayFunctionReg*>(fun);
f->body((NDArray**)(use_vars), // NOLINT(*)
scalar_args,
(NDArray**)(mutate_vars), // NOLINT(*)
0,
NULL,
NULL);
API_END();
}
int MXFuncInvokeEx(FunctionHandle fun,
NDArrayHandle *use_vars,
mx_float *scalar_args,
NDArrayHandle *mutate_vars,
int num_params,
char **param_keys,
char **param_vals) {
API_BEGIN();
auto *f = static_cast<const NDArrayFunctionReg*>(fun);
f->body((NDArray**)(use_vars), // NOLINT(*)
scalar_args,
(NDArray**)(mutate_vars), // NOLINT(*)
num_params,
param_keys,
param_vals);
API_END();
}
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
int MXListDataIters(mx_uint *out_size,
DataIterCreator **out_array) {
API_BEGIN();
auto &vec = dmlc::Registry<DataIteratorReg>::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (DataIterCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}
int MXDataIterGetIterInfo(DataIterCreator creator,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions) {
DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
return MXAPIGetFunctionRegInfo(e, name, description, num_args,
arg_names, arg_type_infos, arg_descriptions,
NULL);
}
int MXDataIterCreateIter(DataIterCreator creator,
mx_uint num_param,
const char **keys,
const char **vals,
DataIterHandle *out) {
IIterator<DataBatch> *iter = nullptr;
API_BEGIN();
DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
iter = e->body();
std::vector<std::pair<std::string, std::string> > kwargs;
for (mx_uint i = 0; i < num_param; ++i) {
kwargs.push_back({std::string(keys[i]), std::string(vals[i])});
}
iter->Init(kwargs);
*out = iter;
API_END_HANDLE_ERROR(delete iter);
}
int MXDataIterFree(DataIterHandle handle) {
API_BEGIN();
delete static_cast<IIterator<DataBatch> *>(handle);
API_END();
}
int MXDataIterBeforeFirst(DataIterHandle handle) {
API_BEGIN();
static_cast<IIterator<DataBatch>* >(handle)->BeforeFirst();
API_END();
}
int MXDataIterNext(DataIterHandle handle, int *out) {
API_BEGIN();
*out = static_cast<IIterator<DataBatch>* >(handle)->Next();
API_END();
}
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
NDArray* pndarray = new NDArray();
// temp hack to make label 1D
// TODO(tianjun) make label 1D when label_width=0
TShape shape = db.data[1].shape();
if (shape[1] == 1) {
*pndarray = db.data[1].Reshape(mshadow::Shape1(shape[0]));
} else {
*pndarray = db.data[1];
}
*out = pndarray;
API_END();
}
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
*out_size = db.index.size();
*out_index = const_cast<uint64_t*>(db.index.data());
API_END();
}
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
NDArray* pndarray = new NDArray();
*pndarray = db.data[0];
*out = pndarray;
API_END();
}
int MXDataIterGetPadNum(DataIterHandle handle, int *pad) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
*pad = db.num_batch_padd;
API_END();
}
int MXKVStoreCreate(const char *type,
KVStoreHandle *out) {
API_BEGIN();
*out = KVStore::Create(type);
API_END();
}
int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params,
const char** keys, const char** vals) {
API_BEGIN();
std::vector<std::pair<std::string, std::string> > params;
for (mx_uint i = 0; i < num_params; ++i) {
std::pair<std::string, std::string> p;
p.first = keys[i];
p.second = vals[i];
params.push_back(p);
}
static_cast<KVStore*>(handle)->SetGradientCompression(params);
API_END();
}
int MXKVStoreFree(KVStoreHandle handle) {
API_BEGIN();
delete static_cast<KVStore*>(handle);
API_END();
}
int MXKVStoreInit(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Init(v_keys, v_vals);
API_END();
}
int MXKVStoreInitEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Init(v_keys, v_vals);
API_END();
}
int MXKVStorePush(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Push(v_keys, v_vals, priority);
API_END();
}
int MXKVStorePushEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Push(v_keys, v_vals, priority);
API_END();
}
int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
API_END();
}
int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
API_END();
}
int MXKVStorePullWithSparse(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority,
bool ignore_sparse) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
API_END();
}
int MXKVStorePullWithSparseEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority,
bool ignore_sparse) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
API_END();
}
int MXKVStorePullRowSparse(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
const NDArrayHandle* row_ids,
int priority) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_val_rowids[i] = std::make_pair(static_cast<NDArray*>(vals[i]),
*static_cast<NDArray*>(row_ids[i]));
}
static_cast<KVStore*>(handle)->PullRowSparse(v_keys, v_val_rowids, priority);
API_END();
}
int MXKVStorePullRowSparseEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
const NDArrayHandle* row_ids,
int priority) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_val_rowids[i] = std::make_pair(static_cast<NDArray*>(vals[i]),
*static_cast<NDArray*>(row_ids[i]));
}
static_cast<KVStore*>(handle)->PullRowSparse(v_keys, v_val_rowids, priority);
API_END();
}
void MXKVStoreSetUpdaterImpl(KVStoreHandle handle,
MXKVStoreUpdater updater,
void* updater_handle) {
MXKVStoreUpdater * updater_temp = updater;
void* updater_handle_temp = updater_handle;
std::function<void(int, const NDArray&, NDArray*)> updt
= [updater_temp, updater_handle_temp](int key, const NDArray& recv, NDArray* local) {
NDArray* recv_copy = new NDArray();
*recv_copy = recv;
NDArray* local_copy = new NDArray();
*local_copy = *local;
updater_temp(key, recv_copy, local_copy, updater_handle_temp);
};
static_cast<KVStore*>(handle)->set_updater(updt);
}
int MXKVStoreSetUpdater(KVStoreHandle handle,
MXKVStoreUpdater updater,
void* updater_handle) {
API_BEGIN();
MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
API_END();
}
int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
MXKVStoreUpdater updater,
MXKVStoreStrUpdater str_updater,
void* updater_handle) {
API_BEGIN();
// set updater with int keys
MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
// set updater with string keys
MXKVStoreStrUpdater * updater_temp = str_updater;
void* updater_handle_temp = updater_handle;
std::function<void(const std::string&, const NDArray&, NDArray*)> updt
= [updater_temp, updater_handle_temp]
(const std::string& key, const NDArray& recv, NDArray* local) {
NDArray* recv_copy = new NDArray();
*recv_copy = recv;
NDArray* local_copy = new NDArray();
*local_copy = *local;
updater_temp(key.c_str(), recv_copy, local_copy, updater_handle_temp);
};
static_cast<KVStore*>(handle)->set_updater(updt);
API_END();
}
int MXKVStoreGetRank(KVStoreHandle handle, int *rank) {
API_BEGIN();
*rank = static_cast<KVStore*>(handle)->get_rank();
API_END();
}
int MXKVStoreGetGroupSize(KVStoreHandle handle, int *size) {
API_BEGIN();
*size = static_cast<KVStore*>(handle)->get_group_size();
API_END();
}
int MXKVStoreBarrier(KVStoreHandle handle) {
API_BEGIN();
static_cast<KVStore*>(handle)->Barrier();
API_END();
}
int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
const int barrier_before_exit) {
API_BEGIN();
static_cast<KVStore*>(handle)->set_barrier_before_exit(barrier_before_exit);
API_END();
}
int MXInitPSEnv(mx_uint num_vars,
const char **keys,
const char **vals) {
API_BEGIN();
std::unordered_map<std::string, std::string> kwargs;
for (mx_uint i = 0; i < num_vars; ++i) {
kwargs[std::string(keys[i])] = std::string(vals[i]);
}
KVStore::InitPSEnv(kwargs);
API_END();
}
int MXKVStoreIsWorkerNode(int *ret) {
API_BEGIN();
*ret = KVStore::IsWorkerNode();
API_END();
}
int MXKVStoreIsServerNode(int *ret) {
API_BEGIN();
*ret = KVStore::IsServerNode();
API_END();
}
int MXKVStoreIsSchedulerNode(int *ret) {
API_BEGIN();
*ret = KVStore::IsSchedulerNode();
API_END();
}
int MXKVStoreRunServer(KVStoreHandle handle,
MXKVStoreServerController controller,
void *controller_handle) {
API_BEGIN();
MXKVStoreServerController *controller_temp = controller;
void *controller_handle_temp = controller_handle;
auto ctrl = [controller_temp, controller_handle_temp](int head, const std::string& body) {
controller_temp(head, body.c_str(), controller_handle_temp);
};
static_cast<KVStore*>(handle)->RunServer(ctrl);
API_END();
}
int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
int cmd_id,
const char* cmd_body) {
API_BEGIN();
static_cast<KVStore*>(handle)->SendCommandToServers(
cmd_id, std::string(cmd_body));
API_END();
}
int MXKVStoreGetType(KVStoreHandle handle,
const char** type) {
API_BEGIN();
*CHECK_NOTNULL(type) = static_cast<KVStore*>(handle)->type().c_str();
API_END();
}
int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
const int node_id,
int *number,
const int timeout_sec) {
API_BEGIN();
*number = static_cast<KVStore*>(handle)->get_num_dead_node(node_id, timeout_sec);
API_END();
}
struct MXRecordIOContext {
dmlc::RecordIOWriter *writer;
dmlc::RecordIOReader *reader;
dmlc::Stream *stream;
std::string *read_buff;
};
int MXRecordIOWriterCreate(const char *uri,
RecordIOHandle *out) {
API_BEGIN();
dmlc::Stream *stream = dmlc::Stream::Create(uri, "w");
MXRecordIOContext *context = new MXRecordIOContext;
context->writer = new dmlc::RecordIOWriter(stream);
context->reader = NULL;
context->stream = stream;
context->read_buff = NULL;
*out = reinterpret_cast<RecordIOHandle>(context);
API_END();
}
int MXRecordIOWriterFree(RecordIOHandle handle) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
delete context->writer;
delete context->stream;
delete context;
API_END();
}
int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
const char *buf, size_t size) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
context->writer->WriteRecord(reinterpret_cast<const void*>(buf), size);
API_END();
}
int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
*pos = context->writer->Tell();
API_END();
}
int MXRecordIOReaderCreate(const char *uri,
RecordIOHandle *out) {
API_BEGIN();
dmlc::Stream *stream = dmlc::Stream::Create(uri, "r");
MXRecordIOContext *context = new MXRecordIOContext;
context->reader = new dmlc::RecordIOReader(stream);
context->writer = NULL;
context->stream = stream;
context->read_buff = new std::string();
*out = reinterpret_cast<RecordIOHandle>(context);
API_END();
}
int MXRecordIOReaderFree(RecordIOHandle handle) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
delete context->reader;
delete context->stream;
delete context->read_buff;
delete context;
API_END();
}
int MXRecordIOReaderReadRecord(RecordIOHandle handle,
char const **buf, size_t *size) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
if (context->reader->NextRecord(context->read_buff)) {
*buf = context->read_buff->c_str();
*size = context->read_buff->size();
} else {
*buf = NULL;
*size = 0;
}
API_END();
}
int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
context->reader->Seek(pos);
API_END();
}
int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos) {
API_BEGIN();
MXRecordIOContext *context =
reinterpret_cast<MXRecordIOContext*>(handle);
*pos = context->reader->Tell();
API_END();
}
int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
char** input_names, char** output_names,
NDArrayHandle* inputs, NDArrayHandle* outputs,
char* kernel, RtcHandle *out) {
API_BEGIN();
LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
NDArrayHandle* inputs, NDArrayHandle* outputs,
mx_uint gridDimX,
mx_uint gridDimY,
mx_uint gridDimZ,
mx_uint blockDimX,
mx_uint blockDimY,
mx_uint blockDimZ) {
API_BEGIN();
LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
int MXRtcFree(RtcHandle handle) {
API_BEGIN();
LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
API_BEGIN();
mxnet::op::custom::CustomOperator::Get()->Register(op_type, creator);
API_END();
}
int MXRtcCudaModuleCreate(const char* source, int num_options,
const char** options, int num_exports,
const char** exports, CudaModuleHandle *out) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
std::vector<std::string> str_opts;
for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]);
std::vector<std::string> str_exports;
for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]);
*out = new rtc::CudaModule(source, str_opts, str_exports);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
int MXRtcCudaModuleFree(CudaModuleHandle handle) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
delete reinterpret_cast<rtc::CudaModule*>(handle);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_args,
int* is_ndarray, int* is_const, int* arg_types,
CudaKernelHandle *out) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
auto module = reinterpret_cast<rtc::CudaModule*>(handle);
std::vector<rtc::CudaModule::ArgType> signature;
for (int i = 0; i < num_args; ++i) {
signature.push_back(rtc::CudaModule::ArgType{
static_cast<bool>(is_ndarray[i]), static_cast<bool>(is_const[i]),
static_cast<mshadow::TypeFlag>(arg_types[i])});
}
auto kernel = module->GetKernel(name, signature);
*out = new std::shared_ptr<rtc::CudaModule::Kernel>(kernel);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
int MXRtcCudaKernelFree(CudaKernelHandle handle) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
delete reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
mx_uint grid_dim_x, mx_uint grid_dim_y,
mx_uint grid_dim_z, mx_uint block_dim_x,
mx_uint block_dim_y, mx_uint block_dim_z,
mx_uint shared_mem) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
auto kernel = reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
const auto& signature = (*kernel)->signature();
std::vector<dmlc::any> any_args;
for (size_t i = 0; i < signature.size(); ++i) {
if (signature[i].is_ndarray) {
any_args.emplace_back(*static_cast<NDArray*>(args[i]));
} else {
MSHADOW_TYPE_SWITCH(signature[i].dtype, DType, {
any_args.emplace_back(*static_cast<DType*>(args[i]));
});
}
}
(*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y,
grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem);
#else
LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation.";
#endif
API_END();
}
int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shared_id) {
API_BEGIN();
NDArray* arr = reinterpret_cast<NDArray*>(handle);
Storage::Handle shandle;
if (arr->ctx().dev_type == Context::kCPUShared) {
arr->WaitToRead();
shandle = arr->storage_handle();
Storage::Get()->SharedIncrementRefCount(shandle);
} else {
NDArray new_arr(arr->shape(), Context::CPUShared(0), false, arr->dtype());
CopyFromTo(*arr, new_arr);
new_arr.WaitToRead();
shandle = new_arr.storage_handle();
Storage::Get()->SharedIncrementRefCount(shandle);
}
*shared_pid = shandle.shared_pid;
*shared_id = shandle.shared_id;
API_END();
}
int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(shared_pid, shared_id, TShape(shape, shape + ndim), dtype);
API_END();
}