| # coding: utf-8 |
| # pylint: disable=invalid-name, protected-access, too-many-arguments |
| # pylint: disable=global-statement, unused-import |
| """NDArray configuration API.""" |
| from __future__ import absolute_import as _abs |
| |
| import ctypes |
| import sys as _sys |
| import numpy as np |
| |
| from ..base import _LIB |
| from ..base import c_array, py_str, c_str, mx_uint, _Null |
| from ..base import NDArrayHandle, OpHandle |
| from ..base import check_call |
| from ..ndarray_doc import _build_doc |
| from .common import CachedOp |
| |
| |
class NDArrayBase(object):
    """Base data structure for ndarray.

    Holds a C API ``NDArrayHandle`` plus a ``writable`` flag. The handle is
    released through ``MXNDArrayFree`` when the Python object is finalized.
    """
    __slots__ = ["handle", "writable"]
    # pylint: disable= no-member
    def __init__(self, handle, writable=True):
        """initialize a new NDArray

        Parameters
        ----------
        handle : NDArrayHandle or None
            NDArray handle of C API. ``None`` is accepted without a type
            check (``__reduce__`` constructs instances this way).
        writable : bool, optional
            Stored unchanged on the instance. Defaults to True.
        """
        if handle is not None:
            assert isinstance(handle, NDArrayHandle)
        self.handle = handle
        self.writable = writable

    def __del__(self):
        """Release the underlying C array via ``MXNDArrayFree``."""
        # If __init__ raised before ``handle`` was assigned (e.g. the type
        # assertion above failed), the slot is unset and reading it raises
        # AttributeError during finalization. Nothing was acquired, so skip.
        try:
            handle = self.handle
        except AttributeError:
            return
        check_call(_LIB.MXNDArrayFree(handle))

    def __reduce__(self):
        # Pickle support: rebuild through the registered ndarray class with a
        # None handle, then restore from ``__getstate__``. That method is not
        # defined here; presumably a subclass provides it (hence the
        # no-member disable above).
        return (_ndarray_cls, (None,), self.__getstate__())
| |
| |
# Class used to wrap raw NDArray handles coming back from the C API, and the
# constructor NDArrayBase.__reduce__ hands to pickle. Starts as None and is
# assigned by _set_ndarray_class.
_ndarray_cls = None


def _set_ndarray_class(cls):
    """Set the ndarray class used to wrap C API handles to be cls.

    Parameters
    ----------
    cls : type
        Class that the invoke wrappers in this module call with a single
        handle argument to wrap each output. ``NDArrayBase.__reduce__`` also
        uses it as the reconstruction callable.
    """
    global _ndarray_cls
    _ndarray_cls = cls
| |
| |
def _imperative_invoke(handle, ndargs, keys, vals, out):
    """ctypes implementation of imperative invoke wrapper

    Calls ``MXImperativeInvoke`` for one operator.

    Parameters
    ----------
    handle : handle of the operator to invoke (presumably an ``OpHandle``);
        converted with ``ctypes.c_void_p`` before the call.
    ndargs : sequence of objects exposing ``.handle``, used as the inputs.
    keys : sequence of parameter names, each encoded with ``c_str``.
    vals : sequence of parameter values, each passed through ``str`` and then
        ``c_str``.
    out : NDArrayBase, sequence of NDArrayBase, or None
        Pre-allocated outputs. A single array is treated as a 1-tuple.

    Returns
    -------
    ``out`` itself when it was given. Otherwise, one ``_ndarray_cls`` instance
    if exactly one output came back, else a list of them (possibly empty).
    """
    if out is None:
        caller_out = None
        # Zero count and a NULL pointer; the C call fills both in, and they
        # are read back below.
        out_ptr = ctypes.POINTER(NDArrayHandle)()
        n_out = ctypes.c_int(0)
    else:
        caller_out = out
        targets = (out,) if isinstance(out, NDArrayBase) else out
        n_out = ctypes.c_int(len(targets))
        out_ptr = ctypes.cast(
            c_array(NDArrayHandle, [arr.handle for arr in targets]),
            ctypes.POINTER(NDArrayHandle))

    # Build the C arguments in the same order as the call signature.
    op_ptr = ctypes.c_void_p(handle)
    n_in = ctypes.c_int(len(ndargs))
    in_handles = c_array(NDArrayHandle, [arr.handle for arr in ndargs])
    out_count_ref = ctypes.byref(n_out)
    out_ptr_ref = ctypes.byref(out_ptr)
    n_params = ctypes.c_int(len(keys))
    param_keys = c_array(ctypes.c_char_p, [c_str(k) for k in keys])
    param_vals = c_array(ctypes.c_char_p, [c_str(str(v)) for v in vals])

    check_call(_LIB.MXImperativeInvoke(
        op_ptr, n_in, in_handles, out_count_ref, out_ptr_ref,
        n_params, param_keys, param_vals))

    if caller_out is not None:
        return caller_out
    wrapped = [_ndarray_cls(ctypes.cast(out_ptr[k], NDArrayHandle))
               for k in range(n_out.value)]
    return wrapped[0] if n_out.value == 1 else wrapped
| |
| |
def invoke(cached_op, args, out=None, name=None):  # pylint: disable=unused-argument
    """ctypes implementation of imperative invoke wrapper

    Calls ``MXCachedInvoke`` on a cached operator.

    Parameters
    ----------
    cached_op : object
        Its ``handle`` attribute is passed straight to ``MXCachedInvoke``.
    args : sequence of objects exposing ``.handle``, used as the inputs.
    out : NDArrayBase or sequence of NDArrayBase, optional
        Pre-allocated outputs. A single array is treated as a 1-tuple.
    name : optional
        Accepted but unused.

    Returns
    -------
    ``out`` itself when it was given. Otherwise, one ``_ndarray_cls`` instance
    if exactly one output came back, else a list of them (possibly empty).
    """
    if out is None:
        caller_out = None
        # Zero count and a NULL pointer; the C call fills both in, and they
        # are read back below.
        out_ptr = ctypes.POINTER(NDArrayHandle)()
        n_out = ctypes.c_int(0)
    else:
        caller_out = out
        targets = (out,) if isinstance(out, NDArrayBase) else out
        n_out = ctypes.c_int(len(targets))
        out_ptr = ctypes.cast(
            c_array(NDArrayHandle, [arr.handle for arr in targets]),
            ctypes.POINTER(NDArrayHandle))

    # Build the C arguments in the same order as the call signature.
    op_handle = cached_op.handle
    n_in = ctypes.c_int(len(args))
    in_handles = c_array(NDArrayHandle, [arr.handle for arr in args])
    out_count_ref = ctypes.byref(n_out)
    out_ptr_ref = ctypes.byref(out_ptr)

    check_call(_LIB.MXCachedInvoke(
        op_handle, n_in, in_handles, out_count_ref, out_ptr_ref))

    if caller_out is not None:
        return caller_out
    wrapped = [_ndarray_cls(ctypes.cast(out_ptr[k], NDArrayHandle))
               for k in range(n_out.value)]
    return wrapped[0] if n_out.value == 1 else wrapped