blob: 96005abffa9c2790e0a76280c95bdbcbcedaf103 [file] [log] [blame]
from __future__ import absolute_import as _abs
import sys as _sys
import ctypes as _ctypes
import numpy as np
from ..ndarray_doc import _build_doc
include "./base.pyi"
cdef class NDArrayBase:
    """Extension-type base that owns a backend NDArray handle.

    The class stores the raw C handle plus a writable flag and exposes
    them to Python through the ``handle`` and ``writable`` properties.
    """
    # Raw C handle of the underlying NDArray (NULL when unset).
    cdef NDArrayHandle chandle
    # Integer flag surfaced to Python as the boolean ``writable`` property.
    cdef int cwritable

    cdef _set_handle(self, handle):
        """Store the pointer carried by ``handle`` in ``chandle``.

        ``handle`` is either None (clears the handle to NULL) or an object
        whose ``.value`` attribute is the pointer as an integer, as produced
        by the ``handle`` property getter (a ``ctypes.c_void_p``).
        """
        cdef unsigned long long ptr
        if handle is None:
            self.chandle = NULL
        else:
            ptr = handle.value
            # Cast to the declared type of ``chandle`` (was SymbolHandle).
            self.chandle = <NDArrayHandle>(ptr)

    property handle:
        def __get__(self):
            # Expose the C pointer as a ctypes void pointer, or None if unset.
            if self.chandle == NULL:
                return None
            else:
                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)

        def __set__(self, value):
            self._set_handle(value)

    property writable:
        def __get__(self):
            return bool(self.cwritable)

    def __init__(self, handle, writable=True):
        self._set_handle(handle)
        self.cwritable = writable

    def __dealloc__(self):
        # Release the backend array when the Python object is destroyed.
        CALL(MXNDArrayFree(self.chandle))

    def __reduce__(self):
        # Pickle support: rebuild through the registered NDArray class with an
        # empty handle, then restore state.  ``__getstate__`` is not defined in
        # this class -- presumably supplied by the Python subclass; confirm.
        return (_ndarray_cls, (None,), self.__getstate__())
# Class instantiated whenever a new array wrapper is created; defaults to the
# extension base type until a replacement is registered.
_ndarray_cls = NDArrayBase


cdef _set_ndarray_class(cls):
    """Register ``cls`` as the class used to wrap newly created arrays."""
    global _ndarray_cls
    _ndarray_cls = cls
cdef NewArray(NDArrayHandle handle):
    """Wrap ``handle`` in a new instance of the registered NDArray class."""
    cdef NDArrayBase wrapped
    # Build the Python object without a handle, then fill in the C-level
    # fields directly so the raw pointer never round-trips through ctypes.
    instance = _ndarray_cls(None)
    wrapped = <NDArrayBase>instance
    wrapped.chandle = handle
    wrapped.cwritable = True
    return instance
cdef _make_ndarray_function(OpHandle handle, string name):
    """Create a Python-callable NDArray function for the operator ``handle``.

    Operator metadata is queried through ``MXSymbolGetAtomicSymbolInfo`` and
    used to build the docstring and the ordered list of non-array parameter
    names.  The returned closure forwards its arguments to
    ``MXImperativeInvoke``.

    Parameters
    ----------
    handle : OpHandle
        Handle of the operator to wrap.
    name : string
        Name under which the function is exposed.

    Returns
    -------
    function
        A function accepting ``*args``/``**kwargs`` that invokes the operator.
    """
    cdef const char *real_name
    cdef const char *desc
    cdef nn_uint num_args
    cdef const char** arg_names
    cdef const char** arg_types
    cdef const char** arg_descs
    cdef const char* return_type
    cdef const char* key_var_num_args

    CALL(MXSymbolGetAtomicSymbolInfo(
        handle, &real_name, &desc,
        &num_args, &arg_names,
        &arg_types, &arg_descs,
        &key_var_num_args, &return_type))
    func_name = py_str(name.c_str())
    key_vargs = py_str(key_var_num_args)

    # Decode the C string tables once and reuse them below.
    arg_name_list = [py_str(arg_names[i]) for i in range(num_args)]
    arg_type_list = [py_str(arg_types[i]) for i in range(num_args)]
    arg_desc_list = [py_str(arg_descs[i]) for i in range(num_args)]

    doc_str = _build_doc(func_name,
                         py_str(desc),
                         arg_name_list,
                         arg_type_list,
                         arg_desc_list,
                         key_vargs,
                         py_str(return_type) if return_type != NULL else '')

    # Names of the parameters that are not NDArray/Symbol inputs, in
    # declaration order; positional non-array arguments are matched to these.
    arguments = [arg_name
                 for arg_name, arg_type in zip(arg_name_list, arg_type_list)
                 if not arg_type.startswith(('NDArray', 'Symbol'))]
    num_param_args = len(arguments)

    # Definition of internal functions.
    def generic_ndarray_function(*args, **kwargs):
        """Invoke this function by passing in parameters

        Parameters
        ----------
        *args
            Positional arguments of input scalars and NDArray
        out : NDArray or tuple of NDArray, optional
            Output NDArray, used to hold the output result.

        Returns
        -------
        out : NDArray
            The result NDArray(tuple) of result of computation.
        """
        cdef vector[string] sparam_keys
        cdef vector[string] sparam_vals
        cdef vector[NDArrayHandle] nd_args
        cdef vector[NDArrayHandle] output_vars
        cdef NDArrayHandle* p_output_vars
        cdef int pos_param_arg
        cdef int num_output

        # Positional arguments: NDArrays become operator inputs, everything
        # else is stringified and bound to the next non-array parameter name.
        pos_param_arg = 0
        for v in args:
            if isinstance(v, NDArrayBase):
                nd_args.push_back((<NDArrayBase>v).chandle)
            else:
                if pos_param_arg >= num_param_args:
                    raise ValueError("Too many positional arguments")
                if arguments[pos_param_arg] == 'dtype':
                    # dtype values are passed by their canonical numpy name.
                    sparam_vals.push_back(c_str(np.dtype(v).name))
                else:
                    sparam_vals.push_back(c_str(str(v)))
                sparam_keys.push_back(c_str(arguments[pos_param_arg]))
                pos_param_arg = pos_param_arg + 1

        # Keyword arguments: ``out`` selects the output array(s); all other
        # keys are forwarded as string parameters.
        original_output = None
        for k, v in kwargs.items():
            if k == "out":
                original_output = v
                if isinstance(v, NDArrayBase):
                    output_vars.push_back((<NDArrayBase>v).chandle)
                else:
                    for item in v:
                        if not isinstance(item, NDArrayBase):
                            raise ValueError("out need to be of type NDArray")
                        # Use the validated element, not the container ``v``.
                        output_vars.push_back((<NDArrayBase>item).chandle)
            elif k == 'dtype':
                sparam_vals.push_back(c_str(np.dtype(v).name))
                sparam_keys.push_back(c_str(k))
            else:
                sparam_vals.push_back(c_str(str(v)))
                sparam_keys.push_back(c_str(k))

        num_output = output_vars.size()
        if output_vars.size() == 0:
            # No caller-supplied outputs: pass a NULL output pointer.
            # NOTE(review): presumably this asks the backend to allocate the
            # outputs -- confirm against the MXImperativeInvoke contract.
            output_vars.resize(1)
            p_output_vars = NULL
        else:
            p_output_vars = &output_vars[0]

        cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
        cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)

        CALL(MXImperativeInvoke(
            handle,
            <int>nd_args.size(),
            &nd_args[0] if nd_args.size() != 0 else NULL,
            &num_output,
            &p_output_vars,
            <int>param_keys.size(),
            CBeginPtr(param_keys),
            CBeginPtr(param_vals)))

        # When the caller provided ``out``, hand the same object(s) back.
        if original_output is not None:
            return original_output
        if num_output == 1:
            return NewArray(p_output_vars[0])
        else:
            return tuple(NewArray(p_output_vars[i]) for i in range(num_output))
    # End of function declaration

    generic_ndarray_function.__name__ = func_name
    generic_ndarray_function.__doc__ = doc_str
    generic_ndarray_function.__module__ = 'mxnet.ndarray'
    return generic_ndarray_function
def _init_ndarray_module(nd_class, root_namespace):
    """Register ``nd_class`` and install one function per backend operator.

    Functions whose name starts with an underscore are attached to
    ``<root_namespace>._ndarray_internal``; all others are attached to
    ``<root_namespace>.ndarray``.  Both modules must already be present in
    ``sys.modules``.
    """
    cdef const char** name_ptrs
    cdef nn_uint num_ops
    cdef vector[string] names
    cdef OpHandle op_handle

    _set_ndarray_class(nd_class)

    # Copy every operator name into an owned std::string before making any
    # further C API calls.
    CALL(MXListAllOpNames(&num_ops, &name_ptrs))
    for idx in range(num_ops):
        names.push_back(string(name_ptrs[idx]))

    public_module = _sys.modules["%s.ndarray" % root_namespace]
    internal_module = _sys.modules["%s._ndarray_internal" % root_namespace]

    for idx in range(names.size()):
        CALL(NNGetOpHandle(names[idx].c_str(), &op_handle))
        function = _make_ndarray_function(op_handle, names[idx])
        # Underscore-prefixed operators are internal-only.
        if function.__name__.startswith('_'):
            target = internal_module
        else:
            target = public_module
        setattr(target, function.__name__, function)