blob: 9df5985839e5f49fc59206817057fbcd0fd58f1a [file] [log] [blame]
from ..base import MXNetError
from libcpp.vector cimport vector
from libcpp.string cimport string
from cpython.version cimport PY_MAJOR_VERSION
# Untyped (void*) handles for objects living on the C side. This module only
# passes them through to the C API functions declared below and never
# dereferences them itself.
ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
# Unsigned integer type used for counts in the C API signatures below
# (presumably mirrors nnvm's own nn_uint typedef -- confirm against nnvm/c_api.h).
ctypedef unsigned nn_uint
cdef py_str(const char* x):
    """Convert a NUL-terminated C string into the native Python ``str`` type.

    On Python 3 the bytes are decoded as UTF-8; on Python 2, where ``str``
    is already a byte string, the raw bytes are returned as-is.
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x
cdef c_str(pystr):
    """Encode a Python string as UTF-8 bytes suitable for the C API.

    Parameters
    ----------
    pystr : str or bytes
        Python string to convert. ``bytes`` input is assumed to be already
        encoded and is returned unchanged; anything else is encoded with
        ``.encode("utf-8")``.

    Returns
    -------
    bytes
        UTF-8 encoded bytes. Cython can coerce this object to ``const char*``
        at the call site; that pointer is only valid while the returned
        bytes object is kept alive by the caller.
    """
    if isinstance(pystr, bytes):
        # Already encoded: on Python 3 ``bytes`` has no ``.encode`` method,
        # and on Python 2 ``str.encode("utf-8")`` would first ASCII-decode.
        return pystr
    return pystr.encode("utf-8")
cdef CALL(int ret):
    """Raise ``MXNetError`` if a C API call reported failure.

    Parameters
    ----------
    ret : int
        Status code returned by an NNVM/MXNet C API function. Any nonzero
        value is treated as an error.

    Raises
    ------
    MXNetError
        Carrying the message from ``NNGetLastError()``.
    """
    if ret != 0:
        # NNGetLastError() yields a const char*; without py_str it would be
        # coerced to ``bytes`` and the Py3 error message would read b'...'.
        raise MXNetError(py_str(NNGetLastError()))
cdef const char** CBeginPtr(vector[const char*]& vec):
    """Return a pointer to the first element of ``vec``, or NULL if it is empty.

    The returned pointer refers to ``vec``'s own storage, so it stays valid
    only while ``vec`` is alive and is not resized.
    """
    if vec.empty():
        # Taking &vec[0] of an empty vector is undefined behaviour.
        return NULL
    return &vec[0]
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
    """Build a vector of C string pointers viewing the strings in ``vec``.

    Element ``i`` of the result is ``vec[i].c_str()``, i.e. a pointer into
    the ``std::string`` stored in ``vec``. The result must therefore not
    outlive ``vec``, and ``vec`` must not be modified while it is in use.
    """
    # Explicitly typed index: guarantees a pure C loop and matches the
    # unsigned type of vector::size() rather than relying on inference.
    cdef size_t i
    cdef vector[const char*] svec
    svec.resize(vec.size())
    for i in range(vec.size()):
        svec[i] = vec[i].c_str()
    return svec
# Declarations of the NNVM C API functions used by this module. The signatures
# must match nnvm/c_api.h exactly. The int-returning functions report a status
# code; CALL() above treats any nonzero value as failure.
cdef extern from "nnvm/c_api.h":
    # Error message text; CALL() reads this when a status code is nonzero.
    const char* NNGetLastError();
    # Presumably resolves an operator by name, writing the result via `handle`.
    int NNGetOpHandle(const char *op_name,
                      OpHandle *handle);
    # The pointer-to-pointer arguments look like out-parameters filled in by
    # the C side (name, docs, per-argument info) -- confirm ownership of the
    # returned strings in nnvm/c_api.h before freeing or caching them.
    int NNGetOpInfo(OpHandle op,
                    const char **name,
                    const char **description,
                    nn_uint *num_doc_args,
                    const char ***arg_names,
                    const char ***arg_type_infos,
                    const char ***arg_descriptions,
                    const char **return_type);
    # Releases a SymbolHandle.
    int NNSymbolFree(SymbolHandle symbol);
    # `keys` may presumably be NULL for positional args (see CBeginPtr, which
    # yields NULL for an empty vector) -- confirm against nnvm/c_api.h.
    int NNSymbolCompose(SymbolHandle sym,
                        const char* name,
                        nn_uint num_args,
                        const char** keys,
                        SymbolHandle* args);
# Declarations of the MXNet C API functions used by this module. The signatures
# must match mxnet/c_api.h exactly. Each returns an int status code; CALL()
# above treats any nonzero value as failure.
cdef extern from "mxnet/c_api.h":
    # Writes the number of operator names and the name array via the
    # out-pointers (ownership of the array is not visible here -- confirm).
    int MXListAllOpNames(nn_uint *out_size,
                         const char ***out_array);
    # Same shape as NNGetOpInfo plus `key_var_args`; the pointer arguments
    # appear to be out-parameters filled in by the C side.
    int MXSymbolGetAtomicSymbolInfo(OpHandle creator,
                                    const char **name,
                                    const char **description,
                                    nn_uint *num_doc_args,
                                    const char ***arg_names,
                                    const char ***arg_type_infos,
                                    const char ***arg_descriptions,
                                    const char **key_var_args,
                                    const char **return_type);
    # `keys`/`vals` are parallel arrays of length `num_param`; the new
    # symbol handle is written through `out`.
    int MXSymbolCreateAtomicSymbol(OpHandle op,
                                   nn_uint num_param,
                                   const char **keys,
                                   const char **vals,
                                   SymbolHandle *out);
    int MXSymbolSetAttr(SymbolHandle symbol,
                        const char* key,
                        const char* value);
    # `num_outputs` and `outputs` are passed by pointer, so they presumably
    # serve as in/out parameters -- confirm the exact contract (including who
    # owns the output array) in mxnet/c_api.h.
    int MXImperativeInvoke(OpHandle creator,
                           int num_inputs,
                           NDArrayHandle *inputs,
                           int *num_outputs,
                           NDArrayHandle **outputs,
                           int num_params,
                           const char **param_keys,
                           const char **param_vals);
    # Releases an NDArrayHandle.
    int MXNDArrayFree(NDArrayHandle handle);