blob: 8bca2746de9fab36667f0d432206a1dd48be969e [file] [log] [blame]
# coding: utf-8
# pylint: disable=invalid-name, no-member
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
import sys
import ctypes
import atexit
import warnings
import inspect
import numpy as np
from . import libinfo
# Re-enable display of DeprecationWarning (the 'default' action prints the
# first occurrence per location). This changes the process-wide warning
# filters as a side effect of importing this module.
warnings.filterwarnings('default', category=DeprecationWarning)
# Public names exported by `from <this module> import *`.
__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
# Python 2/3 compatibility aliases.
# The test is written as ``>= 3`` (matching the ``< 3`` check used for
# ``c_str``). Any future major version then takes the Python 3 path instead
# of the Python 2 branch, whose names (``basestring``, ``long``) do not exist
# there.
if sys.version_info[0] >= 3:
    string_types = str,
    numeric_types = (float, int, np.float32, np.int32)

    def py_str(x):
        """Convert a byte string returned by the C API to a Python ``str``.

        Needed on Python 3 to turn ``ctypes.c_char_p(...).value`` (bytes)
        back into text.

        Parameters
        ----------
        x : bytes
            UTF-8 encoded byte string.

        Returns
        -------
        str
            The decoded text.
        """
        return x.decode('utf-8')
else:
    # pylint: disable=undefined-variable
    string_types = basestring,
    numeric_types = (float, int, long, np.float32, np.int32)

    def py_str(x):
        """Return ``x`` unchanged; Python 2 ``str`` is already a byte string."""
        return x
class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
return '_Null'
_Null = _NullType()
class MXNetError(Exception):
    """Exception raised by MXNet functions, e.g. by ``check_call`` when a C API call fails."""
def _load_lib():
    """Load the MXNet shared library.

    The first path returned by ``libinfo.find_lib_path()`` is opened with
    ``ctypes.CDLL`` in ``RTLD_LOCAL`` mode.

    Returns
    -------
    lib : ctypes.CDLL
        Handle to the loaded library, with ``MXGetLastError`` declared to
        return a C string.
    """
    candidates = libinfo.find_lib_path()
    library = ctypes.CDLL(candidates[0], ctypes.RTLD_LOCAL)
    # MXGetLastError returns a C string; without this the result would be
    # interpreted as an int.
    library.MXGetLastError.restype = ctypes.c_char_p
    return library
# Version string, re-exported from libinfo.
__version__ = libinfo.__version__
# Handle to the native MXNet library, loaded once when this module is
# imported; the helpers below (check_call, _notify_shutdown) call through it.
_LIB = _load_lib()
# Scalar ctypes used at the C API boundary, plus the numpy dtype that matches
# mx_float.
mx_uint = ctypes.c_uint
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
mx_real_t = np.float32

# Opaque handles exchanged with the C API; on the Python side every one of
# them is a plain ``void *``.
(NDArrayHandle,
 FunctionHandle,
 OpHandle,
 CachedOpHandle,
 SymbolHandle,
 ExecutorHandle,
 DataIterCreatorHandle,
 DataIterHandle,
 KVStoreHandle,
 RecordIOHandle,
 RtcHandle) = (ctypes.c_void_p,) * 11
#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
    """Raise ``MXNetError`` if a C API call reported failure.

    Wrap every C API call with this function.

    Parameters
    ----------
    ret : int
        Return code of the C API call; 0 means success.

    Raises
    ------
    MXNetError
        If ``ret`` is non-zero. The message is the library's last error
        string, fetched with ``MXGetLastError`` and decoded by ``py_str``.
    """
    if ret == 0:
        return
    message = py_str(_LIB.MXGetLastError())
    raise MXNetError(message)
if sys.version_info[0] < 3:
    def c_str(string):
        """Create ctypes char * from a Python string.

        Parameters
        ----------
        string : string type
            Python string.

        Returns
        -------
        str : c_char_p
            A char pointer that can be passed to C API.

        Examples
        --------
        >>> x = mx.base.c_str("Hello, World")
        >>> print x.value
        Hello, World
        """
        # The argument is handed to ctypes as-is.
        return ctypes.c_char_p(string)
else:
    def c_str(string):
        """Create ctypes char * from a Python string.

        Parameters
        ----------
        string : str or bytes
            Text, which is encoded as UTF-8, or an already-encoded byte
            string, which is passed through unchanged.

        Returns
        -------
        str : c_char_p
            A char pointer that can be passed to C API.

        Examples
        --------
        >>> x = mx.base.c_str("Hello, World")
        >>> print(x.value)
        b'Hello, World'
        """
        # Accept pre-encoded bytes as well as text. Previously bytes input
        # failed with AttributeError (bytes has no ``encode``).
        if not isinstance(string, bytes):
            string = string.encode('utf-8')
        return ctypes.c_char_p(string)
def c_array(ctype, values):
    """Build a ctypes array of element type ``ctype`` holding ``values``.

    Parameters
    ----------
    ctype : ctypes data type
        Element type of the resulting array, such as ``mx_float``.
    values : tuple or list
        Elements to copy into the array; its length sets the array length.

    Returns
    -------
    out : ctypes array
        A new ``ctype * len(values)`` array initialised from ``values``.

    Examples
    --------
    >>> x = mx.base.c_array(mx.base.mx_float, [1, 2, 3])
    >>> len(x)
    3
    >>> x[1]
    2.0
    """
    array_type = ctype * len(values)
    return array_type(*values)
def ctypes2buffer(cptr, length):
    """Copy ``length`` bytes starting at ``cptr`` into a new ``bytearray``.

    Parameters
    ----------
    cptr : ctypes.POINTER(ctypes.c_char)
        Pointer to the start of the raw memory region.
    length : int
        Number of bytes to copy.

    Returns
    -------
    buffer : bytearray
        An independent copy of the raw bytes.

    Raises
    ------
    TypeError
        If ``cptr`` is not a ``ctypes.POINTER(ctypes.c_char)``.
    RuntimeError
        If ``ctypes.memmove`` returns a null address.
    """
    char_ptr_type = ctypes.POINTER(ctypes.c_char)
    if not isinstance(cptr, char_ptr_type):
        raise TypeError('expected char pointer')
    out = bytearray(length)
    # View the bytearray's storage as a ctypes char array so memmove writes
    # straight into it.
    dest = (ctypes.c_char * length).from_buffer(out)
    copied_to = ctypes.memmove(dest, cptr, length)
    if not copied_to:
        raise RuntimeError('memmove failed')
    return out
def ctypes2numpy_shared(cptr, shape):
    """Wrap the float memory behind ``cptr`` as a numpy array without copying.

    The returned array shares memory with the pointer, so writes through
    either one are visible through the other.

    Parameters
    ----------
    cptr : ctypes.POINTER(mx_float)
        Pointer to the first element of the memory region.
    shape : tuple
        Shape of target `NDArray`.

    Returns
    -------
    out : numpy_array
        A float32 numpy array of the given shape viewing that memory.

    Raises
    ------
    RuntimeError
        If ``cptr`` is not a ``ctypes.POINTER(mx_float)``.
    """
    if not isinstance(cptr, ctypes.POINTER(mx_float)):
        raise RuntimeError('expected float pointer')
    # Total element count; an integer product so it can size a ctypes array.
    num_elems = 1
    for dim in shape:
        num_elems *= dim
    base_addr = ctypes.addressof(cptr.contents)
    flat_view = (mx_float * num_elems).from_address(base_addr)
    return np.frombuffer(flat_view, dtype=np.float32).reshape(shape)
def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True):
    """Build the "Parameters" section of a numpy-style docstring.

    Parameters
    ----------
    arg_names : list of str
        Argument names.
    arg_types : list of str
        Argument type information.
    arg_descs : list of str
        Argument description information.
    remove_dup : boolean, optional
        If true, only the first entry for a repeated name is kept.

    Returns
    -------
    docstr : str
        Python docstring of parameter sections.
    """
    seen = set()
    entries = []
    for name, type_info, desc in zip(arg_names, arg_types, arg_descs):
        # 'num_args' is never documented; repeated names are dropped when
        # remove_dup is set.
        if (remove_dup and name in seen) or name == 'num_args':
            continue
        seen.add(name)
        entry = '%s : %s' % (name, type_info)
        if len(desc) > 0:
            entry += '\n ' + desc
        entries.append(entry)
    return 'Parameters\n----------\n' + '\n'.join(entries) + '\n'
@atexit.register
def _notify_shutdown():
    """Tell the MXNet library that the Python interpreter is exiting.

    Registered with ``atexit`` (``atexit.register`` returns the function, so
    the decorator keeps the module-level name bound). Failure from the C call
    is raised as ``MXNetError`` by ``check_call``.
    """
    check_call(_LIB.MXNotifyShutdown())
def add_fileline_to_docstring(module, incursive=True):
    """Append the source file and line to each function's docstring.

    Parameters
    ----------
    module : module, class or str
        Object whose members are processed; a string is looked up in
        ``sys.modules``.
    incursive : bool, optional
        If true, classes found among the members are processed as well
        (one level deep only).

    Examples
    --------
    # Put the following code at the end of a file
    add_fileline_to_docstring(__name__)
    """
    def _add_fileline(obj):
        """Tag ``obj.__doc__`` with the file and starting line of ``obj``.

        Objects without a docstring, already tagged, or whose source cannot
        be located are left unchanged.
        """
        doc = obj.__doc__
        if doc is None or 'From:' in doc:
            return
        fname = inspect.getsourcefile(obj)
        if fname is None:
            return
        try:
            _, lineno = inspect.getsourcelines(obj)
        except IOError:
            return
        obj.__doc__ = doc + '\n\nFrom:%s:%d' % (fname, lineno)

    if isinstance(module, str):
        module = sys.modules[module]
    for _, member in inspect.getmembers(module):
        if inspect.isbuiltin(member):
            continue
        if inspect.isfunction(member):
            _add_fileline(member)
        if inspect.ismethod(member):
            # Bound methods share their docstring with the underlying function.
            _add_fileline(member.__func__)
        if incursive and inspect.isclass(member):
            add_fileline_to_docstring(member, False)