| #!/usr/bin/env python |
| |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: disable=too-many-lines, unused-argument |
| """numpy ndarray and util functions.""" |
| |
| |
| try: |
| from __builtin__ import all as py_all |
| from __builtin__ import slice as py_slice |
| except ImportError: |
| from builtins import all as py_all |
| from builtins import slice as py_slice |
| |
| from array import array as native_array |
| import functools |
| import ctypes |
| import sys |
| import datetime |
| import warnings |
| import numpy as _np |
| from .. import _deferred_compute as dc |
| from ..autograd import is_recording |
| from ..ndarray import NDArray, dtype_np_to_mx, _GRAD_REQ_MAP |
| from ..ndarray import indexing_key_expand_implicit_axes, get_indexing_dispatch_code,\ |
| get_oshape_of_gather_nd_op |
| from ..ndarray._internal import _set_np_ndarray_class |
| from . import _op as _mx_np_op |
| from ..base import check_call, _LIB, NDArrayHandle, c_array, mx_int, mx_int64 |
| from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types |
| from ..runtime import Features |
| from ..device import Device |
| from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\ |
| is_np_default_dtype, wrap_ctx_to_device_func,\ |
| dtype_from_number, wrap_data_api_statical_func,\ |
| wrap_sort_functions |
| from ..device import current_device |
| from ..ndarray import numpy as _mx_nd_np |
| from ..ndarray.numpy import _internal as _npi |
| from ..ndarray.ndarray import _storage_type |
| from ..dlpack import ndarray_from_numpy, ndarray_to_dlpack_for_write, DLDeviceType,\ |
| ndarray_from_dlpack |
| from .utils import _get_np_op |
| from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import |
| from . import fallback |
| |
| |
# Public names of this module, i.e. what ``from <module> import *`` exports.
# The list below is extended with every name exported by the ``fallback``
# module (see the ``+=`` at the end of this block).
# NOTE(review): 'arctan2' appears twice in the list; this is harmless for
# ``import *`` but is presumably unintended -- confirm before de-duplicating.
__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median',
           'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to',
           'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'pow', 'power', 'bitwise_not',
           'delete', 'trace', 'transpose', 'copy', 'moveaxis', 'reshape', 'dot',
           'arctan2', 'atan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'bitwise_invert', 'invert',
           'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'asin', 'arccos', 'acos', 'arctan',
           'atan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square',
           'negative', 'histogram', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'asinh',
           'arccosh', 'acosh', 'arctanh', 'atanh', 'append', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
           'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit',
           'dsplit', 'flatnonzero', 'tril_indices', 'concatenate', 'concat', 'stack', 'vstack', 'row_stack',
           'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin',
           'amax', 'amin', 'max', 'min', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
           'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
           'logical_and', 'logical_or', 'logical_xor',
           'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
           'triu_indices_from', 'triu_indices', 'tri',
           'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
           'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
           'cross', 'kron', 'equal', 'not_equal', 'interp',
           'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
           'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
           'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
           'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
           'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
           'positive', 'logaddexp', 'floor_divide', 'permute_dims', 'bitwise_left_shift', 'bitwise_right_shift',
           'asarray', 'from_dlpack']

__all__ += fallback.__all__
| |
# Return code for dispatching indexing function call
# NOTE(review): these codes are not referenced in the visible part of this
# file; presumably they mirror the values produced by
# `get_indexing_dispatch_code` -- confirm against its definition.
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1
_NDARRAY_EMPTY_TUPLE_INDEXING = 2

# Return code for 0-d boolean array handler
_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
# Largest value of a signed 32-bit integer. `_new_alloc_handle` refuses to
# allocate arrays with more elements than this when INT64 tensor sizes are
# not enabled.
_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)

# Caching whether MXNet was built with INT64 support or not.
# `None` means "not queried yet"; `_int64_enabled` fills it in on first use.
_INT64_TENSOR_SIZE_ENABLED = None
| |
def _int64_enabled():
    """Tell whether the 'INT64_TENSOR_SIZE' feature is enabled.

    The answer is looked up through ``Features()`` only on the first call and
    is then remembered in the module-level ``_INT64_TENSOR_SIZE_ENABLED``.

    Returns
    -------
    The value returned by ``Features().is_enabled('INT64_TENSOR_SIZE')``.
    """
    global _INT64_TENSOR_SIZE_ENABLED
    cached = _INT64_TENSOR_SIZE_ENABLED
    if cached is None:
        # First call: query the runtime once and memoize the result.
        cached = Features().is_enabled('INT64_TENSOR_SIZE')
        _INT64_TENSOR_SIZE_ENABLED = cached
    return cached
| |
| # This function is copied from ndarray.py since pylint |
| # keeps giving false alarm error of undefined-all-variable |
def _new_alloc_handle(shape, device, delay_alloc, dtype=mx_real_t):  # pylint: disable=redefined-outer-name
    """Create a fresh backend handle for an array of the given shape on `device`.

    The handle is empty and is only meant to hold results.

    Parameters
    ----------
    shape : sequence of int
        Shape of the array to create.
    device : Device
        Object providing ``device_typeid`` and ``device_id``.
    delay_alloc : bool
        Forwarded to the backend as an integer flag.
    dtype : data type
        Translated to the backend type code with ``dtype_np_to_mx``.

    Returns
    -------
    handle
        A new empty `ndarray` handle.

    Raises
    ------
    Exception
        If INT64 tensor sizes are not enabled and the number of elements
        exceeds ``_SIGNED_INT32_UPPER_LIMIT``.
    """
    hdl = NDArrayHandle()

    if _int64_enabled():
        check_call(_LIB.MXNDArrayCreate64(
            c_array_buf(mx_int64, native_array('q', shape)),
            ctypes.c_int(len(shape)),
            ctypes.c_int(device.device_typeid),
            ctypes.c_int(device.device_id),
            ctypes.c_int(int(delay_alloc)),
            ctypes.c_int(int(dtype_np_to_mx(dtype))),
            ctypes.byref(hdl)))
        return hdl

    # 32-bit path. An oversized shape would already fail on the Python side
    # (before the call reaches the backend), so the element count is checked
    # here to produce a helpful message instead.
    element_count = functools.reduce(lambda total, dim: total * dim, shape, 1)
    if element_count > _SIGNED_INT32_UPPER_LIMIT:
        raise Exception("[_new_alloc_handle] Size of tensor you are trying to allocate is "
                        "larger than 2^31 elements. Please build with flag "
                        "USE_INT64_TENSOR_SIZE=1")
    check_call(_LIB.MXNDArrayCreate(
        c_array_buf(mx_uint, native_array('I', shape)),
        mx_uint(len(shape)),
        ctypes.c_int(device.device_typeid),
        ctypes.c_int(device.device_id),
        ctypes.c_int(int(delay_alloc)),
        ctypes.c_int(int(dtype_np_to_mx(dtype))),
        ctypes.byref(hdl)))
    return hdl
| |
| |
def _reshape_view(a, *shape):  # pylint: disable=redefined-outer-name
    """Return a **view** of `a` with a new shape; the data is not altered.

    Parameters
    ----------
    a : ndarray
        Array to reshape; its ``handle`` and ``writable`` attributes are used.
    shape : tuple of int, or n ints
        The new shape should not change the array size, namely
        ``np.prod(new_shape)`` should be equal to ``np.prod(a.shape)``.
        One dimension may be given as -1, in which case it is inferred from
        the remaining dimensions so that the total size is unchanged.
        At most one dimension of shape can be -1.

    Returns
    -------
    ndarray
        An array with desired shape that shares data with `a`.
    """
    dims = shape
    # Accept both `_reshape_view(a, 2, 3)` and `_reshape_view(a, (2, 3))`.
    if len(dims) == 1 and isinstance(dims[0], (list, tuple)):
        dims = dims[0]

    out_handle = NDArrayHandle()
    status = _LIB.MXNDArrayReshape64(
        a.handle,
        len(dims),
        c_array(ctypes.c_int64, dims),
        False,
        ctypes.byref(out_handle))
    check_call(status)
    return ndarray(handle=out_handle, writable=a.writable)
| |
def _as_mx_np_array(object, device=None, zero_copy=False):
    """Convert arrays or any array member of container to mxnet.numpy.ndarray on device.

    Parameters
    ----------
    object : None, ndarray, numpy.ndarray, scalar, bool, list or tuple
        Value to convert. Lists and tuples are converted element by element
        and rebuilt with the same container class.
    device : Device, optional
        Device used when a boolean scalar is turned into an array, and passed
        down when converting container members. It is not used for
        ``numpy.ndarray`` inputs.
    zero_copy : bool
        For ``numpy.ndarray`` inputs, request zero-copy conversion; it is only
        honoured when the input is C-contiguous.

    Returns
    -------
    The converted object. ``None``, existing ``ndarray`` instances and values
    matching ``integer_types`` / ``numeric_types`` are returned unchanged.

    Raises
    ------
    TypeError
        If `object` is of an unsupported type.
    """
    if object is None or isinstance(object, ndarray):
        return object
    if isinstance(object, _np.ndarray):
        from_numpy = ndarray_from_numpy(ndarray, array)
        return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS'])
    if isinstance(object, (integer_types, numeric_types)):
        return object
    # The original check used the `_np.bool` alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24; merely evaluating it raised
    # AttributeError there and made every later branch (list/tuple and the
    # TypeError) unreachable. Where the alias existed it was the builtin
    # `bool`, so using `bool` directly keeps the behavior.
    if isinstance(object, (_np.bool_, bool)):
        return array(object, dtype=_np.bool_, device=device)
    if isinstance(object, (list, tuple)):
        converted = [_as_mx_np_array(arr, device=device, zero_copy=zero_copy) for arr in object]
        return object.__class__(converted)
    raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))
| |
| |
def _as_onp_array(object, cur_device=None):
    """Convert object to numpy.ndarray.

    ``ndarray`` inputs are converted with ``asnumpy()``; lists, tuples and
    dicts are converted member by member and rebuilt with the same container
    class; anything else is returned unchanged.

    Parameters
    ----------
    object : any
        Value (possibly a nested container) to convert.
    cur_device : Device, optional
        Device seen so far; used to detect arrays living on different devices.

    Returns
    -------
    tuple
        ``(converted_object, device)`` where ``device`` is the device of the
        converted ndarray(s), or `cur_device` when no ndarray was found.

    Raises
    ------
    ValueError
        If the ndarrays inside a container are allocated on different devices.
    """
    def _update_device(cur_device, tmp_device):
        # Adopt the first device encountered, then insist every later one matches.
        if cur_device is None:
            return tmp_device
        if tmp_device is not None and cur_device != tmp_device:
            # The original passed `str(cur_device, tmp_device)` to format(),
            # which raised TypeError instead of this ValueError.
            raise ValueError('Ambiguous to set the device for the output ndarray since'
                             ' input ndarrays are allocated on different devices: {} and {}'
                             .format(cur_device, tmp_device))
        return cur_device

    if isinstance(object, ndarray):
        return object.asnumpy(), object.device
    if isinstance(object, (list, tuple)):
        converted = []
        for arr in object:
            arr, tmp_device = _as_onp_array(arr, cur_device)
            converted.append(arr)
            cur_device = _update_device(cur_device, tmp_device)
        return object.__class__(converted), cur_device
    if isinstance(object, dict):
        converted = {}
        for key, value in object.items():
            value, tmp_device = _as_onp_array(value, cur_device)
            converted[key] = value
            cur_device = _update_device(cur_device, tmp_device)
        return object.__class__(converted), cur_device
    return object, cur_device
| |
| |
| # Have to use 0 as default value for stype since pylint does not allow |
| # importing _STORAGE_TYPE_DEFAULT from ndarray.py. |
def _np_ndarray_cls(handle, writable=True, stype=0):
    """Wrap a backend handle in an ``ndarray``.

    Parameters
    ----------
    handle
        Backend array handle.
    writable : bool
        Passed on to the ``ndarray`` constructor.
    stype : int
        Storage type code. ``-1`` means "look it up from the handle" via
        ``_storage_type``; only code ``0`` is accepted afterwards.

    Raises
    ------
    ValueError
        If the (resolved) storage type is not ``0``.
    """
    resolved = _storage_type(handle) if stype == -1 else stype
    if resolved == 0:
        return ndarray(handle, writable=writable)
    raise ValueError('_np_ndarray_cls currently only supports default storage '
                     'type, while received stype = {}'.format(resolved))
| |
| |
# Hand `_np_ndarray_cls` to `_set_np_ndarray_class`.
# NOTE(review): presumably this makes the internal op layer wrap its output
# handles as `mxnet.numpy.ndarray` -- confirm in `ndarray._internal`.
_set_np_ndarray_class(_np_ndarray_cls)

# Lookup tables consulted by `ndarray.__array_function__` (keyed by the
# official NumPy function object) and `ndarray.__array_ufunc__` (keyed by the
# ufunc name). They start empty here and are not filled in the visible part
# of this file.
_NUMPY_ARRAY_FUNCTION_DICT = {}
_NUMPY_ARRAY_UFUNC_DICT = {}
# Operators for which the "fallback operator" warning has already been
# logged, so that each one is reported only once.
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD = {}
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD = {}
| |
def wrap_mxnp_np_ufunc(func):
    """
    Decorate a Python overload-able binary op so that it accepts an official
    NumPy array as its second operand.

    Parameters
    ----------
    func : a python overload-able binary function to be wrapped for type casting.

    Returns
    -------
    Function
        A two-argument function that converts ``x2`` with `_as_mx_np_array`
        (on ``x1``'s device) when ``x2`` is a ``numpy.ndarray`` and then
        calls `func`; any other ``x2`` is passed through untouched.
    """
    def _wrap_mxnp_np_ufunc(x1, x2):
        rhs = _as_mx_np_array(x2, device=x1.device) if isinstance(x2, _np.ndarray) else x2
        return func(x1, rhs)

    # Carry over name, docstring, etc. of the wrapped function.
    return functools.wraps(func)(_wrap_mxnp_np_ufunc)
| |
| @set_module('mxnet.numpy') |
| class ndarray(NDArray): # pylint: disable=invalid-name |
| """ |
| ndarray(handle, writable=True): |
| |
| An array object represents a multidimensional, homogeneous array of fixed-size items. |
| An associated data-type object describes the format of each element in the array |
| (its byte-order, how many bytes it occupies in memory, whether it is an integer, a |
| floating point number, or something else, etc.). Arrays should be constructed using |
| `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported. |
| |
| Arrays should be constructed using `array`, `zeros` or `empty` (refer |
| to the See Also section below). The parameters given here refer to |
| a low-level method (`ndarray(...)`) for instantiating an array. |
| |
| For more information, refer to the `mxnet.numpy` module and examine the |
| methods and attributes of an array. |
| |
| Parameters |
| ---------- |
| handle: int |
| The ndarray handle in backend (C++). |
| writable: bool |
| Indicates whether inplace-assignment is allowed for the array. |
| |
| Attributes |
| ---------- |
| T : ndarray |
| Transpose of the array. |
| dtype : dtype object |
| Describes the format of the elements in the array. |
| size : int |
| Number of elements in the array. |
| ndim : int |
| The array's number of dimensions. |
| shape : tuple of ints |
| Shape of the array. |
| |
| See Also |
| -------- |
| array : Construct an array. |
| zeros : Create an array, each element of which is zero. |
| empty : Create an array, but leave its allocated memory unchanged (i.e., |
| it contains "garbage"). |
| """ |
| |
@staticmethod
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):  # pylint: disable=bad-staticmethod-argument
    """
    Dispatch official NumPy unary/binary operator calls on mxnet.numpy.ndarray
    to this function. The operators must comply with the ufunc definition in NumPy.
    The following code is adapted from CuPy.
    Casting rules for operator with mx_np and onp (inplace op will keep its type)
    | Expression | a type | b type | out type|
    | --- | --- | --- | --- |
    | `a += b` | onp | mx_np | onp |
    | `a += b` | mx_np | onp | mx_np |
    | `c = a + b` | onp | mx_np | mx_np |
    | `c = a + b` | mx_np | onp | mx_np |

    Parameters
    ----------
    ufunc : numpy.ufunc
        The official NumPy ufunc being called; dispatch is done on its name.
    method : str
        Ufunc method name. Only ``'__call__'`` is handled.
    inputs : tuple
        Positional operands of the ufunc call.
    kwargs : dict
        Keyword arguments of the call. ``out``, if present, must be a
        one-element tuple and is unpacked before dispatch.

    Returns
    -------
    The operator result, or ``NotImplemented`` when `method` is not
    ``'__call__'``.

    Raises
    ------
    ValueError
        If ``out`` does not hold exactly one array, or if a fallback to the
        official NumPy operator is needed while autograd is recording.
    """
    ufunc_list = ["add", "subtract", "multiply", "divide", "true_divide", "floor_divide", "power",
                  "remainder", "bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift",
                  "greater", "greater_equal", "less", "less_equal", "not_equal", "equal", "matmul"]
    if 'out' in kwargs:
        # need to unfold tuple argument in kwargs
        out = kwargs['out']
        if len(out) != 1:
            raise ValueError('The `out` parameter must have exactly one ndarray')
        kwargs['out'] = out[0]

    if method != '__call__':
        return NotImplemented

    name = ufunc.__name__
    mx_ufunc = _NUMPY_ARRAY_UFUNC_DICT.get(name, None)
    onp_op = _get_np_op(name)
    if mx_ufunc is None:
        # try to fallback to official NumPy op
        if is_recording():
            # The message must be formatted *before* building the exception:
            # the original called `.format` on the ValueError instance, which
            # raised AttributeError instead of this error.
            raise ValueError(
                "Falling back to NumPy operator {} with autograd active is not supported. "
                "Please consider moving the operator to the outside of the autograd scope."
                .format(name))
        new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
        if onp_op not in _FALLBACK_ARRAY_UFUNC_WARNED_RECORD:
            import logging
            logging.warning("np.%s is a fallback operator, "
                            "which is actually using official numpy's implementation", name)
            _FALLBACK_ARRAY_UFUNC_WARNED_RECORD[onp_op] = True
        out = onp_op(*new_inputs, **kwargs)
        return _as_mx_np_array(out, device=inputs[0].device)

    # ops with np mx_np: the left operand is an official NumPy array
    if name in ufunc_list and isinstance(inputs[0], _np.ndarray):
        if 'out' in kwargs:
            # inplace: compute with official NumPy so the result keeps its type
            new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
            return onp_op(*new_inputs, **kwargs)
        # out-of-place: promote NumPy operands onto the device of the second operand
        new_inputs = [_as_mx_np_array(arg, device=inputs[1].device)
                      if isinstance(arg, _np.ndarray) else arg for arg in inputs]
        return mx_ufunc(*new_inputs, **kwargs)

    return mx_ufunc(*inputs, **kwargs)
| |
@staticmethod
def __array_function__(self, func, types, args, kwargs):  # pylint: disable=bad-staticmethod-argument
    """
    Dispatch official NumPy operators that comply with the array function protocol to
    this function.

    Parameters
    ----------
    func : callable
        The official NumPy function that was called.
    types : collection of type
        Types of the arguments taking part in the dispatch.
    args : tuple
        Positional arguments of the original call.
    kwargs : dict
        Keyword arguments of the original call.

    Returns
    -------
    The result of the registered implementation for `func` or, when none is
    registered, the result of `func` itself on official NumPy arrays,
    converted back with `_as_mx_np_array`.

    Raises
    ------
    ValueError
        If a fallback to the official NumPy operator is needed while autograd
        is recording, or if no device could be determined from the inputs on
        the fallback path.
    """
    mx_np_func = _NUMPY_ARRAY_FUNCTION_DICT.get(func, None)
    func_name = func.__name__
    if mx_np_func is None:
        # try to fallback to official NumPy op
        if is_recording():
            # Format the message first; the original called `.format` on the
            # ValueError instance, which raised AttributeError instead.
            raise ValueError(
                "Falling back to NumPy operator {} with autograd active is not supported. "
                "Please consider moving the operator to the outside of the autograd scope."
                .format(func_name))
        cur_device = None
        new_args, cur_device = _as_onp_array(args, cur_device)
        new_kwargs, cur_device = _as_onp_array(kwargs, cur_device)
        if cur_device is None:
            raise ValueError('Unknown device for the input ndarrays. It is probably a bug. Please'
                             ' create an issue on GitHub.')
        if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
            import logging
            logging.warning("np.%s is a fallback operator, "
                            "which is actually using official numpy's implementation.", func_name)
            _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD[func] = True
        out = func(*new_args, **new_kwargs)
        return _as_mx_np_array(out, device=cur_device)

    if py_all(issubclass(t, ndarray) for t in types):
        return mx_np_func(*args, **kwargs)

    # Mixed inputs: take the device from the first argument that exposes one
    # (positional arguments first, then keyword arguments) and convert the
    # remaining inputs onto it.
    # NOTE(review): `hasattr(a, 'device')` also matches non-mxnet objects that
    # happen to expose a `device` attribute -- confirm this is intended.
    try:
        cur_device = next(a.device for a in args if hasattr(a, 'device'))
    except StopIteration:
        cur_device = next(a.device for a in kwargs.values() if hasattr(a, 'device'))
    new_args = _as_mx_np_array(args, device=cur_device,
                               zero_copy=func_name in {'may_share_memory', 'shares_memory'})
    new_kwargs = {k: _as_mx_np_array(v, cur_device) for k, v in kwargs.items()}
    return mx_np_func(*new_args, **new_kwargs)
| |
| |
def __array_namespace__(self, api_version=None):
    """
    Return the module object that provides the array API functions.

    Notes
    -----
    This is a standard API in
    https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none.

    Parameters
    ----------
    self : ndarray
        The array instance; the module named by ``self.__module__`` is returned.
    api_version : Optional, string
        Version of the array API specification, in `YYYY.MM` form.
        ``None`` skips the version check. Only versions whose year is 2021
        are accepted.

    Returns
    -------
    module
        ``sys.modules[self.__module__]``.

    Raises
    ------
    ValueError
        If `api_version` cannot be parsed as `YYYY.MM` or its year is not 2021.
    """
    if api_version is None:
        return sys.modules[self.__module__]
    try:
        # An unsupported year is funnelled into the same handler as a parse failure.
        if datetime.datetime.strptime(api_version, '%Y.%m').year != 2021:
            raise ValueError
    except ValueError:
        raise ValueError(f"Unrecognized array API version: {api_version!r}")
    return sys.modules[self.__module__]
| |
| |
def __dlpack__(self, stream=None):
    """Export the array as a DLPack capsule for consumption by from_dlpack().

    Parameters
    ----------
    stream : int, optional
        A Python integer representing a pointer to a stream (CUDA or ROCm).
        Stream is provided by the consumer to the producer to instruct the producer
        to ensure that operations can safely be performed on the array. The pointer must
        be positive integer or -1. If stream is -1, the value must be used by the consumer
        to signal "producer must not perform any synchronization".

    Returns
    -------
    capsule : PyCapsule
        A DLPack capsule for the array, containing a DLPackManagedTensor.

    Raises
    ------
    TypeError
        If `stream` is neither ``None`` nor exactly of type ``int``.
    ValueError
        If a stream is given but the array's device type is not ``"gpu"``.
    """
    stream_given = stream is not None
    # Exact type check on purpose: subclasses of int (e.g. bool) are rejected.
    if stream_given and type(stream) is not int:
        raise TypeError('The input stream must be int or None')
    if stream_given:
        device_type = self.device.device_type
        if device_type != "gpu":
            raise ValueError('Stream {} is not supported in current device {}'
                             .format(stream, device_type))
        if stream != -1:
            # -1 means "no synchronization"; any other stream is pushed to the backend.
            check_call(_LIB.MXPushStreamDep(self.handle, ctypes.c_int64(stream)))
    exporter = ndarray_to_dlpack_for_write()
    return exporter(self)
| |
| |
def __dlpack_device__(self):
    """Returns device type and device ID in DLPack format.

    Returns
    -------
    tuple
        ``(dl_device_type, device_id)`` where ``dl_device_type`` is the
        ``DLDeviceType`` member matching the array's device type.

    Raises
    ------
    ValueError
        If the array's device type is not one of 'cpu', 'gpu' or 'cpu_pinned'.
    """
    devtype_map = {'cpu': DLDeviceType.DLCPU,
                   'gpu': DLDeviceType.DLGPU,
                   'cpu_pinned': DLDeviceType.DLCPUPINNED}
    device = self.device
    device_type = device.device_type
    if device_type not in devtype_map:
        # Message typo fixed ("Unkown" -> "Unknown").
        raise ValueError('Unknown device type {} for DLPack'.format(device_type))
    return (devtype_map[device_type], device.device_id)
| |
| |
def _get_np_basic_indexing(self, key):
    """
    This function indexes ``self`` with a tuple of `slice` objects only.

    Besides slices, the code below also deals with ``None`` entries (which
    become new axes of length 1) and with integer entries (which are
    bounds-checked and whose axes are dropped from the result).

    Parameters
    ----------
    key : tuple
        Indexing key. After dropping its ``None`` entries it must contain
        exactly ``self.ndim`` items.

    Returns
    -------
    The indexed array, reshaped to its final shape. On the contiguous path
    the result is built from a flat slice of ``self`` obtained through the
    backend; otherwise it is produced by ``_npi.slice``.

    Raises
    ------
    RuntimeError
        If the key has fewer non-``None`` entries than ``self.ndim``.
    IndexError
        If the key has more non-``None`` entries than ``self.ndim``, or an
        integer entry is out of bounds for its axis.
    """
    # Drop `None` entries; they only add axes and do not consume a dimension.
    key_nd = tuple(idx for idx in key if idx is not None)
    if len(key_nd) < self.ndim:
        raise RuntimeError(
            'too few indices after normalization: expected `ndim` ({}) '
            'but got {}. This is a bug, please report it!'
            ''.format(self.ndim, len(key_nd))
        )
    if len(key_nd) > self.ndim:
        raise IndexError(
            'too many indices ({}) for array with {} dimensions'
            ''.format(len(key_nd), self.ndim)
        )

    # Positions of the `None` entries in the original key.
    none_axes = [ax for ax in range(len(key)) if key[ax] is None]  # pylint: disable=invalid-name
    # `slc_key` is the key with integers turned into slices; `int_axes` lists
    # the axes that were indexed by an integer.
    slc_key, int_axes = self._basic_indexing_key_int_to_slice(key_nd)
    new_axes = self._new_axes_after_basic_indexing(none_axes, key)

    # Check bounds for integer axes
    for ax in int_axes:  # pylint: disable=invalid-name
        if not -self.shape[ax] <= key_nd[ax] < self.shape[ax]:
            raise IndexError(
                'index {} is out of bounds for axis {} with size {}'
                ''.format(key_nd[ax], ax, self.shape[ax]))

    if self._basic_indexing_slice_is_contiguous(slc_key, self.shape):
        # Create a shared-memory view by using low-level flat slicing
        flat_begin, flat_end = self._basic_indexing_contiguous_flat_begin_end(
            slc_key, self.shape
        )
        handle = NDArrayHandle()
        flat_self = self.reshape_view(-1)
        # Pick the backend entry point matching the index width in use.
        if _int64_enabled():
            check_call(
                _LIB.MXNDArraySlice64(
                    flat_self.handle,
                    ctypes.c_int64(flat_begin),
                    ctypes.c_int64(flat_end),
                    ctypes.byref(handle),
                )
            )
        else:
            check_call(
                _LIB.MXNDArraySlice(
                    flat_self.handle,
                    ctypes.c_uint32(flat_begin),
                    ctypes.c_uint32(flat_end),
                    ctypes.byref(handle),
                )
            )
        sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape)
        sliced = self.__class__(handle=handle, writable=self.writable)
        # NOTE(review): shapes containing a 0 go through `reshape` instead of
        # `reshape_view` -- presumably the view path cannot handle empty
        # arrays; confirm.
        if 0 in sliced_shape:
            sliced = sliced.reshape(sliced_shape)
        else:
            sliced = sliced.reshape_view(sliced_shape)

    else:
        # Non-contiguous selection: fall back to the slice operator.
        begin, end, step = self._basic_indexing_key_to_begin_end_step(
            slc_key, self.shape, keep_none=True
        )
        sliced = _npi.slice(self, begin, end, step)

    # Reshape to final shape due to integer and `None` entries in `key`.
    final_shape = [sliced.shape[i] for i in range(sliced.ndim) if i not in int_axes]
    for ax in new_axes:  # pylint: disable=invalid-name
        final_shape.insert(ax, 1)

    # Same empty-array distinction as above for the final reshape.
    if sliced.size == 0:
        return sliced.reshape(tuple(final_shape))
    else:
        return sliced.reshape_view(tuple(final_shape))
| |
def _get_np_empty_tuple_indexing(self, key):
    """Build the result for a key that contains empty-tuple entries.

    The output shape gets a 1 for every ``None`` entry, a 0 for every ``()``
    entry and the matching dimension of ``self`` for every full slice
    (``:``); other entries contribute nothing. The result is created with
    ``empty`` using ``self.dtype``.
    """
    out_shape = []
    none_seen = 0
    for pos, entry in enumerate(key):
        if entry is None:
            # expand dimension
            out_shape.append(1)
            none_seen += 1
        elif entry == ():
            # 0 shape
            out_shape.append(0)
        elif entry == slice(None, None, None):
            # `None` entries do not consume an axis of `self`, hence the offset.
            out_shape.append(self.shape[pos - none_seen])
    return empty(out_shape, dtype=self.dtype)
| |
def _get_np_advanced_indexing(self, key):
    """Return ``self[key]`` for an advanced-indexing key.

    The key is turned into index arrays by ``_get_index_nd``, the elements
    are gathered with ``_npi.gather_nd``, and an axis of length 1 is inserted
    for every position reported in ``new_axes``.
    """
    indices, new_axes = self._get_index_nd(key)
    if type(indices) == NDArray:  # pylint: disable=unidiomatic-typecheck
        indices = indices.as_np_ndarray()
    else:
        as_np = [idx if isinstance(idx, self.__class__) else idx.as_np_ndarray()
                 for idx in indices]
        indices = _mx_nd_np.stack(as_np)

    gathered = _npi.gather_nd(self, indices)
    if not new_axes:
        return gathered

    # Reshape due to `None` entries in `key`.
    final_shape = list(gathered.shape)
    for axis in new_axes:
        final_shape.insert(axis, 1)
    return gathered.reshape(tuple(final_shape))
| |
def _set_np_advanced_indexing(self, key, value):
    """Assign `value` to ``self[key]`` for an advanced-indexing key.

    Called by ``__setitem__``. The key is converted to index arrays, `value`
    is prepared for the shape a gather with those indices would produce, and
    the result is scattered into ``self``.
    """
    indices, new_axes = self._get_index_nd(key)
    if type(indices) == NDArray:  # pylint: disable=unidiomatic-typecheck
        indices = indices.as_np_ndarray()
    else:
        as_np = [idx if isinstance(idx, self.__class__) else idx.as_np_ndarray()
                 for idx in indices]
        indices = _mx_nd_np.stack(as_np)

    target_shape = get_oshape_of_gather_nd_op(self.shape, indices.shape)
    prepared = self._prepare_value_nd(value, bcast_shape=target_shape, squeeze_axes=new_axes)
    self._scatter_set_nd(prepared, indices)
| |
| # pylint: disable=redefined-outer-name |
def _get_np_boolean_indexing(self, key, ndim, shape):
    """
    Fast path for indexing with a single boolean array.

    There are two types of boolean indices (which are equivalent,
    for the most part though). This function handles the single boolean
    index; other cases are expanded into (multiple) integer array indices
    and handled by advanced indexing.

    Parameters
    ----------
    key : ndarray
        Boolean mask; its shape must match the leading dimensions of `shape`.
    ndim : int
        Number of dimensions of the indexed array.
    shape : tuple of int
        Shape of the indexed array.

    Raises
    ------
    IndexError
        If the mask has more dimensions than the array, or one of its
        dimensions differs from the corresponding array dimension.
    """
    mask_shape = key.shape
    mask_ndim = len(mask_shape)
    if mask_ndim > ndim:
        raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}'
                         .format(mask_ndim, ndim))
    for axis, mask_dim in enumerate(mask_shape):
        if mask_dim != shape[axis]:
            raise IndexError('boolean index did not match indexed array along dimension {};'
                             ' dimension is {} but corresponding boolean dimension is {}'
                             .format(axis, shape[axis], mask_dim))

    # Collapse the masked leading dimensions into one and keep the rest.
    trailing_dims = shape[mask_ndim:]
    flat_data = _reshape_view(self, -1, *trailing_dims)
    flat_mask = _reshape_view(key, -1)
    if flat_data.size == 0 and flat_mask.size == 0:
        return flat_data
    selected = _npi.boolean_mask(flat_data, flat_mask)
    return _reshape_view(selected, -1, *trailing_dims)
| |
def _set_np_boolean_indexing(self, key, value):
    """
    Fast path for assignment through a single boolean array.

    There are two types of boolean indices (which are equivalent,
    for the most part though). This function handles the single boolean
    assign; other cases are expanded into (multiple) integer array indices
    and handled by advanced assign.

    Parameters
    ----------
    key : ndarray
        Boolean mask passed to the backend operator as ``mask``.
    value : scalar or ndarray
        Value to write; Python ``bool`` scalars are converted to ``int``.

    Raises
    ------
    NotImplementedError
        If `value` is neither a numeric scalar nor an ``ndarray``.
    """
    if isinstance(value, numeric_types):
        scalar = int(value) if isinstance(value, bool) else value
        _npi.boolean_mask_assign_scalar(data=self, mask=key, value=scalar,
                                        start_axis=0, out=self)
        return
    if isinstance(value, ndarray):
        _npi.boolean_mask_assign_tensor(data=self, mask=key, value=value,
                                        start_axis=0, out=self)
        return
    raise NotImplementedError(f'type {type(value)} is not supported.')
| |
| # pylint: disable=too-many-return-statements |
| def __getitem__(self, key): |
| """Return self[key]. |
| |
| Returns a sliced view of this array if the elements fetched are contiguous in memory; |
| otherwise, returns a newly created NDArray. |
| This functions supports advanced indexing defined in the following reference with |
| some restrictions. Boolean indexing is supported only for a single boolean ndarray |
| as a key. Mixing boolean ndarray with other index types is not supported in ``advanced`` |
| indexing. |
| |
| For basic indexing, i.e., if ``key`` consists only of integers, |
| ``slice``, ``Ellipsis`` (``...``) and ``None``, a mutable view is |
| returned that shares memory with this array if the accessed portion is |
| contiguous in memory. |
| Otherwise, a newly created ``ndarray`` is returned. |
| |
| This functions supports advanced indexing as defined in `the NumPy |
| advanced indexing documentation |
| <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_. |
| |
| Parameters |
| ---------- |
| key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types |
| Indexing key. |
| |
| Examples |
| -------- |
| The default is to give explicit indices for all axes: |
| |
| >>> x = np.arange(6).reshape(2, 3) |
| >>> x |
| array([[0., 1., 2.], |
| [3., 4., 5.]]) |
| >>> x[0, :2] |
| array([0., 1.]) |
| >>> x[:, :-1] |
| array([[0., 1.], |
| [3., 4.]]) |
| |
| If fewer indices are given, they are automatically supplemented by an |
| appropriate number of ``slice(None)`` ("``:``") to the right. For |
| instance, a single integer indexes along the first axis: |
| |
| >>> x[0] |
| array([0., 1., 2.]) |
| >>> x[1:] |
| array([[3., 4., 5.]]) |
| |
| To omit a range of axes that should be kept as-is, an `Ellipsis` |
| ("``...``") can be used: |
| |
| >>> x = np.arange(16).reshape(2, 2, 2, 2) |
| >>> x[0, ..., 1] |
| array([[1., 3.], |
| [5., 7.]]) |
| >>> x[0, :, :, 1] # equivalent |
| array([[1., 3.], |
| [5., 7.]]) |
| |
| New axes of length 1 can be created by inserting ``None`` |
| (`numpy.newaxis`) in the index: |
| |
| >>> x = np.arange(6).reshape(2, 3) |
| >>> x[None, :, :] |
| array([[[0., 1., 2.], |
| [3., 4., 5.]]]) |
| >>> x[None, :, :].shape |
| (1, 2, 3) |
| |
| If the indexed portion of the array is contiguous in memory, no data |
| is copied. Instead, a shared-memory view of the original array is |
| returned, and changes to that view affect the original array: |
| |
| >>> x = np.arange(8).reshape(2, 2, 2) |
| >>> y = x[0] # contiguous |
| >>> y |
| array([[0., 1.], |
| [2., 3.]]) |
| >>> y[:] = -1 |
| >>> x |
| array([[[-1., -1.], |
| [-1., -1.]], |
| [[ 4., 5.], |
| [ 6., 7.]]]) |
| >>> x = np.arange(8).reshape(2, 2, 2) |
| >>> y = x[1, :1, :] # contiguous |
| >>> y |
| array([[4., 5.]]) |
| >>> y[:] = -1 |
| >>> x |
| array([[[ 0., 1.], |
| [ 2., 3.]], |
| [[-1., -1.], |
| [ 6., 7.]]]) |
| >>> x = np.arange(0, 8).reshape(2, 2, 2) |
| >>> y = x[:, :, 1] # not contiguous |
| >>> y |
| array([[1., 3.], |
| [5., 7.]]) |
| >>> y[:] = -1 |
| >>> x |
| array([[[0., 1.], |
| [2., 3.]], |
| [[4., 5.], |
| [6., 7.]]]) |
| |
| If the indexing key contains `list`, `numpy.ndarray` or `NDArray` |
| objects, advanced indexing is triggered, which always returns a |
| copy: |
| |
| >>> x = np.arange(8).reshape(2, 2, 2) |
| >>> x[[0, 1]] |
| array([[[0., 1.], |
| [2., 3.]], |
| [[4., 5.], |
| [6., 7.]]]) |
| >>> x[[0, 1], :] # equivalent |
| array([[[0., 1.], |
| [2., 3.]], |
| [[4., 5.], |
| [6., 7.]]]) |
| >>> y = np.array([0, 1], dtype='int32') |
| >>> x[1:, y] |
| array([[[4., 5.], |
| [6., 7.]]]) |
| >>> y = np.array([0, 1], dtype='int32') |
| >>> x[1:, y] |
| array([[[4., 5.], |
| [6., 7.]]]) |
| |
| Get negative elements in an ndarray through boolean array indexing |
| >>> x = np.array([1., -1., -2., 3]) |
| >>> x[x < 0] |
| array([-1., -2.]) |
| |
| For more imformation related to boolean indexing, please refer to |
| https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html. |
| """ |
| ndim = self.ndim # pylint: disable=redefined-outer-name |
| shape = self.shape # pylint: disable=redefined-outer-name |
| if isinstance(key, bool): # otherwise will be treated as 0 and 1 |
| key = array(key, dtype=_np.bool, device=self.device) |
| if isinstance(key, list): |
| try: |
| new_key = _np.array(key) |
| if new_key.dtype == _np.bool_: |
| key = new_key |
| except Exception as err: |
| raise TypeError('{}'.format(str(err))) |
| if isinstance(key, _np.ndarray): |
| if dc.is_deferred_compute(): |
| raise TypeError('Indexing with a numpy array is not supported in HybridBlock.') |
| if key.dtype == _np.bool_: |
| key = array(key, dtype='bool', device=self.device) |
| |
| # Handle single boolean index of matching dimensionality and size first for higher speed |
| # If the boolean array is mixed with other idices, it is instead expanded into (multiple) |
| # integer array indices and will be handled by advanced indexing. |
| # Come before the check self.dim == 0 as it also handle the 0-dim case. |
| if isinstance(key, ndarray) and key.dtype == _np.bool_: |
| return self._get_np_boolean_indexing(key, ndim, shape) |
| |
| all = __builtins__['all'] # `def all` below shadows the all builtin |
| if ndim == 0 and key != (): |
| raise IndexError('scalar tensor can only accept `()` as index') |
| # Handle simple cases for higher speed |
| if isinstance(key, tuple) and len(key) == 0: |
| return self |
| if isinstance(key, tuple) and len(key) == ndim\ |
| and py_all(isinstance(idx, integer_types) for idx in key): |
| out = self |
| for idx in key: |
| out = out[idx] |
| return out |
| if isinstance(key, integer_types): |
| # Equivalent to isinstance(key, integer_types) case in numpy/_symbol.py |
| if key > shape[0] - 1: |
| raise IndexError( |
| 'index {} is out of bounds for axis 0 with size {}'.format( |
| key, shape[0])) |
| return self._at(key) |
| elif isinstance(key, py_slice): |
| # Unlike numpy/_symbol.py, calls MXNDArraySlice64 writable memory |
| # sharing if key.step not in [None, 1]. Equivalent otherwise to |
| # isinstance(key, py_slice) case in _symbol.py otherwise. |
| if key.step is None or key.step == 1: |
| if key.start is not None or key.stop is not None: |
| return self._slice(key.start, key.stop) |
| else: |
| return self |
| elif key.step != 0: |
| start = [None] if key.start is None else key.start |
| stop = [None] if key.stop is None else key.stop |
| return _npi.slice(self, start, stop, key.step) |
| else: |
| raise ValueError("slice step cannot be zero") |
| elif isinstance(key, tuple) and \ |
| all((isinstance(arr, NDArray) and _np.issubdtype(arr.dtype, _np.integer) and \ |
| arr.ndim > 0) for arr in key): |
| # Equivalent case in numpy/_symbol.py |
| return _npi.advanced_indexing_multiple(self, _mx_nd_np.stack(key)) |
| elif isinstance(key, tuple) and dc.is_deferred_compute(): |
| # Equivalent to isinstance(key, tuple) case in numpy/_symbol.py |
| # Only enabled in deferred compute mode, as this codepath prevents |
| # memory sharing which may be desired in non-deferred compute |
| # imperative mode. |
| begin = [] |
| end = [] |
| step = [] |
| new_shape = () |
| assert len(key) # len(key) == 0 is handled a above |
| unsupported = False |
| for index in key: |
| if isinstance(index, py_slice): |
| if index.step is not None and index.step == 0: |
| raise ValueError("slice step cannot be zero") |
| begin.append(index.start) |
| end.append(index.stop) |
| step.append(index.step) |
| new_shape += (-2,) |
| elif isinstance(index, integer_types): |
| if index >= 0: |
| begin.append(index) |
| end.append(index+1) |
| step.append(1) |
| else: |
| begin.append(index) |
| end.append(index - 1) |
| step.append(-1) |
| new_shape += (-3,) |
| else: |
| unsupported = True |
| break |
| if not unsupported: |
| new_shape += (-4,) |
| sliced = _npi.slice(self, begin, end, step) |
| return _mx_nd_np.reshape(sliced, new_shape) |
| |
| # Special handling for cases only supported in imperative mode |
| if dc.is_deferred_compute(): |
| raise TypeError('The type of indexing used is not supported in HybridBlock.') |
| # For 0-d boolean indices: A new axis is added, |
| # but at the same time no axis is "used". So if we have True, |
| # we add a new axis (a bit like with np.newaxis). If it is |
| # False, we add a new axis, but this axis has 0 entries. |
| # prepend is defined to handle this case. |
| # prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d boolean scalar |
| # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means an zero dim must be expanded |
| # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be prepended |
| key, prepend = indexing_key_expand_implicit_axes(key, self.shape) |
| indexing_dispatch_code = get_indexing_dispatch_code(key) |
| if indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING: |
| # won't be affected by zero-dim boolean indices |
| return self._get_np_empty_tuple_indexing(key) |
| elif indexing_dispatch_code == _NDARRAY_BASIC_INDEXING: |
| if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE: |
| return empty((0,) + self._get_np_basic_indexing(key).shape, |
| dtype=self.dtype, device=self.device) |
| if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE: |
| key = (_np.newaxis,) + key |
| return self._get_np_basic_indexing(key) |
| elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING: |
| if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE: |
| return empty((0,) + self._get_np_adanced_indexing(key).shape, |
| dtype=self.dtype, device=self.device) |
| if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE: |
| key = (_np.newaxis,) + key |
| return self._get_np_advanced_indexing(key) |
| else: |
| raise RuntimeError |
| |
| # pylint: disable=inconsistent-return-statements |
def __setitem__(self, key, value):
    """Set ``self[key]`` to ``value``.

    Basic and advanced indexing are supported as defined in `the NumPy
    advanced indexing documentation
    <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_.
    A Python ``bool`` key or a single boolean ``ndarray`` key is dispatched
    to ``_set_np_boolean_indexing``.

    Parameters
    ----------
    key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types
        The indexing key.
    value : scalar or array-like object that can be broadcast to the shape of self[key]
        The value to set.

    Raises
    ------
    TypeError
        If ``value`` is an ``NDArray`` that is not an ``ndarray``.
    IndexError
        If a zero-dimensional array is indexed with anything other than
        ``()``, or if the key holds more indices than the array has
        dimensions.
    ValueError
        If a zero-dimensional array is assigned something that is neither a
        scalar nor a size-1 array, or if the key type is not supported.

    Examples
    --------
    >>> x = np.zeros((2, 3))
    >>> x[:] = 1
    >>> x
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]])
    >>> x[:, 1:2] = 2
    >>> x
    array([[ 1.,  2.,  1.],
           [ 1.,  2.,  1.]])
    >>> x[1:2, 1:] = 3
    >>> x
    array([[ 1.,  2.,  1.],
           [ 1.,  3.,  3.]])
    >>> x[1:, 0:2] = np.zeros((1, 2))
    >>> x
    array([[ 1.,  2.,  1.],
           [ 0.,  0.,  3.]])
    >>> x[1, 2] = 4
    >>> x
    array([[ 1.,  2.,  1.],
           [ 0.,  0.,  4.]])
    >>> x[[0], [1, 2]] = 5
    >>> x
    array([[ 1.,  5.,  5.],
           [ 0.,  0.,  4.]])
    >>> x[::-1, 0:2:2] = [6]
    >>> x
    array([[ 6.,  5.,  5.],
           [ 6.,  0.,  4.]])

    For information related to boolean indexing, please refer to
    https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html.
    """
    if isinstance(value, NDArray) and not isinstance(value, ndarray):
        raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')
    if isinstance(key, bool):  # otherwise will be treated as 0 and 1
        # ``_np.bool_`` is used instead of the ``_np.bool`` alias, which is
        # not available in every NumPy release. The key is created on
        # ``self.device``, mirroring what ``__getitem__`` does.
        key = array(key, dtype=_np.bool_, device=self.device)

    # Handle single boolean assign of matching dimensionality and size first for higher speed.
    # If the boolean array is mixed with other indices, it is instead expanded into (multiple)
    # integer array indices and will be handled by advanced assign.
    # This comes before the ``self.ndim == 0`` check as it also handles the 0-dim case.
    if isinstance(key, ndarray) and key.dtype == _np.bool_:
        return self._set_np_boolean_indexing(key, value)

    # handle basic and advanced indexing
    if self.ndim == 0:
        if not isinstance(key, tuple) or len(key) != 0:
            raise IndexError('scalar tensor can only accept `()` as index')
        if isinstance(value, numeric_types):
            self._full(value)
        elif isinstance(value, ndarray) and value.size == 1:
            if value.shape != self.shape:
                value = value.reshape(self.shape)
            value.copyto(self)
        elif isinstance(value, (_np.ndarray, _np.generic)) and value.size == 1:
            if isinstance(value, _np.generic) or value.shape != self.shape:
                value = value.reshape(self.shape)
            self._sync_copyfrom(value)
        else:
            raise ValueError('setting an array element with a sequence.')
    else:
        # For 0-d boolean indices: A new axis is added,
        # but at the same time no axis is "used". So if we have True,
        # we add a new axis (a bit like with np.newaxis). If it is
        # False, we add a new axis, but this axis has 0 entries.
        # prepend is defined to handle this case.
        # prepend == _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d boolean scalar
        # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means a zero dim must be expanded
        # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be expanded
        # Apart from the FALSE case below, prepend has no influence on __setitem__.
        key, prepend = indexing_key_expand_implicit_axes(key, self.shape)
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
            return  # no action is needed
        # ``None`` entries do not consume an axis, so they are excluded
        # before the index count is compared with ``ndim``.
        slc_key = tuple(idx for idx in key if idx is not None)
        if len(slc_key) < self.ndim:
            raise RuntimeError(
                'too few indices after normalization: expected `ndim` ({}) '
                'but got {}. This is a bug, please report it!'
                ''.format(self.ndim, len(slc_key))
            )
        # ``self.ndim`` is known to be non-zero in this branch.
        if len(slc_key) > self.ndim:
            raise IndexError(
                'too many indices ({}) for array with {} dimensions'
                ''.format(len(slc_key), self.ndim)
            )
        indexing_dispatch_code = get_indexing_dispatch_code(slc_key)
        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
            self._set_nd_basic_indexing(key, value)  # function is inherited from NDArray class
        elif indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING:
            pass  # no action needed
        elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
            self._set_np_advanced_indexing(key, value)
        else:
            raise ValueError(
                'Indexing NDArray with index {} of type {} is not supported'
                ''.format(key, type(key))
            )
| |
def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
    """Return ``value`` as an ``ndarray`` shaped like ``bcast_shape``.

    The result is created on ``self.device`` with ``self.dtype``. It is used
    when setting items: the returned array is squeezed according to
    ``squeeze_axes`` because the value is assigned to the not-yet-expanded
    space of the original array.

    Parameters
    ----------
    value : numeric type or array-like
        The value to convert. An instance of this class is moved to
        ``self.device`` and cast to ``self.dtype`` when needed; anything else
        that is not a numeric scalar is passed to ``array``.
    bcast_shape : tuple
        The shape the result must have.
    squeeze_axes : sequence of int, optional
        Axes to squeeze out of the converted value when it has more
        dimensions than ``bcast_shape``.

    Returns
    -------
    ndarray
        An array whose shape equals ``bcast_shape``.

    Raises
    ------
    TypeError
        If ``value`` cannot be converted by ``array``.

    Notes
    -----
    mxnet.numpy.ndarray does not support NDArray as assigned value.
    """
    if isinstance(value, numeric_types):
        value_nd = full(bcast_shape, value, device=self.device, dtype=self.dtype)
    elif isinstance(value, self.__class__):
        value_nd = value.to_device(self.device)
        if value_nd.dtype != self.dtype:
            value_nd = value_nd.astype(self.dtype)
    else:
        try:
            value_nd = array(value, device=self.device, dtype=self.dtype)
        except Exception:
            # Only ordinary errors are translated; KeyboardInterrupt and
            # SystemExit are allowed to propagate unchanged.
            raise TypeError('mxnet.np.ndarray does not support assignment with non-array-like '
                            'object {} of type {}'.format(value, type(value)))

    # For advanced indexing setitem, if there is None in indices, we need to squeeze the
    # assigned value_nd since None is also ignored in slicing the original array.
    if squeeze_axes and value_nd.ndim > len(bcast_shape):
        valid_axes = tuple(ax for ax in squeeze_axes if ax < len(value_nd.shape))
        value_nd = value_nd.squeeze(axis=valid_axes)

    # handle the cases like the following
    # a = np.zeros((3, 3)), b = np.ones((1, 1, 1, 1, 3)), a[0] = b
    # b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed
    if value_nd.ndim > len(bcast_shape):
        leading_unit_axes = []
        for i in range(value_nd.ndim - len(bcast_shape)):
            if value_nd.shape[i] != 1:
                break
            leading_unit_axes.append(i)
        if leading_unit_axes:
            value_nd = value_nd.squeeze(leading_unit_axes)

    if value_nd.shape != bcast_shape:
        if value_nd.size == 0:
            value_nd = value_nd.reshape(bcast_shape)
        else:
            value_nd = value_nd.broadcast_to(bcast_shape)
    return value_nd
| |
@wrap_mxnp_np_ufunc
def __add__(self, other):
    """Return ``self + other``, computed by ``add(self, other)``."""
    total = add(self, other)
    return total
| |
@wrap_mxnp_np_ufunc
def __iadd__(self, other):
    """Perform ``self += other`` by calling ``add`` with ``out=self``.

    Raises ``ValueError`` when ``self.writable`` is false.
    """
    if self.writable:
        return add(self, other, out=self)
    raise ValueError('trying to add to a readonly ndarray')
| |
@wrap_mxnp_np_ufunc
def __radd__(self, other):
    """Return ``other + self`` (reflected), computed by ``add(other, self)``."""
    total = add(other, self)
    return total
| |
def __invert__(self):
    """Return ``~self``, computed by ``invert(self)``."""
    inverted = invert(self)
    return inverted
| |
@wrap_mxnp_np_ufunc
def __and__(self, other):
    """Return ``self & other``, computed by ``bitwise_and(self, other)``."""
    result = bitwise_and(self, other)
    return result
| |
@wrap_mxnp_np_ufunc
def __rand__(self, other):
    """Return ``other & self`` (reflected), computed by ``bitwise_and(other, self)``."""
    result = bitwise_and(other, self)
    return result
| |
@wrap_mxnp_np_ufunc
def __or__(self, other):
    """Return ``self | other``, computed by ``bitwise_or(self, other)``."""
    result = bitwise_or(self, other)
    return result
| |
@wrap_mxnp_np_ufunc
def __ror__(self, other):
    """Return ``other | self`` (reflected), computed by ``bitwise_or(other, self)``."""
    result = bitwise_or(other, self)
    return result
| |
@wrap_mxnp_np_ufunc
def __xor__(self, other):
    """Return ``self ^ other``, computed by ``bitwise_xor(self, other)``."""
    result = bitwise_xor(self, other)
    return result
| |
@wrap_mxnp_np_ufunc
def __rxor__(self, other):
    """Return ``other ^ self`` (reflected), computed by ``bitwise_xor(other, self)``."""
    result = bitwise_xor(other, self)
    return result
| |
@wrap_mxnp_np_ufunc
def __lshift__(self, other):
    """Return ``self << other``, computed by ``bitwise_left_shift(self, other)``."""
    shifted = bitwise_left_shift(self, other)
    return shifted
| |
@wrap_mxnp_np_ufunc
def __rshift__(self, other):
    """Return ``self >> other``, computed by ``bitwise_right_shift(self, other)``."""
    shifted = bitwise_right_shift(self, other)
    return shifted
| |
@wrap_mxnp_np_ufunc
def __iand__(self, other):
    """Perform ``self &= other`` by calling ``bitwise_and`` with ``out=self``."""
    result = bitwise_and(self, other, out=self)
    return result
| |
@wrap_mxnp_np_ufunc
def __ior__(self, other):
    """Perform the in-place bitwise OR of ``self`` with ``other``.

    Calls ``bitwise_or`` with ``out=self``.
    """
    result = bitwise_or(self, other, out=self)
    return result
| |
@wrap_mxnp_np_ufunc
def __ixor__(self, other):
    """Perform ``self ^= other`` by calling ``bitwise_xor`` with ``out=self``."""
    result = bitwise_xor(self, other, out=self)
    return result
| |
@wrap_mxnp_np_ufunc
def __ilshift__(self, other):
    """Perform ``self <<= other`` by calling ``bitwise_left_shift`` with ``out=self``."""
    shifted = bitwise_left_shift(self, other, out=self)
    return shifted
| |
@wrap_mxnp_np_ufunc
def __irshift__(self, other):
    """Perform ``self >>= other`` by calling ``bitwise_right_shift`` with ``out=self``."""
    shifted = bitwise_right_shift(self, other, out=self)
    return shifted
| |
@wrap_mxnp_np_ufunc
def __rlshift__(self, other):
    """Return ``other << self`` (reflected), computed by ``bitwise_left_shift(other, self)``."""
    shifted = bitwise_left_shift(other, self)
    return shifted
| |
@wrap_mxnp_np_ufunc
def __rrshift__(self, other):
    """Return ``other >> self`` (reflected), computed by ``bitwise_right_shift(other, self)``."""
    shifted = bitwise_right_shift(other, self)
    return shifted
| |
def __round__(self, n=0):
    """Return ``round(self, decimals=n)``; ``n`` defaults to 0."""
    rounded = round(self, decimals=n)
    return rounded
| |
def __abs__(self):
    """Return ``abs(self)``, computed by ``absolute(self)``."""
    magnitude = absolute(self)
    return magnitude
| |
def __ceil__(self):
    """Return the result of ``ceil(self)``."""
    ceiling = ceil(self)
    return ceiling
| |
def __floor__(self):
    """Return the result of ``floor(self)``."""
    floored = floor(self)
    return floored
| |
def __trunc__(self):
    """Return the result of ``trunc(self)``."""
    truncated = trunc(self)
    return truncated
| |
@wrap_mxnp_np_ufunc
def __sub__(self, other):
    """Return ``self - other``, computed by ``subtract(self, other)``."""
    difference = subtract(self, other)
    return difference
| |
@wrap_mxnp_np_ufunc
def __isub__(self, other):
    """Perform ``self -= other`` by calling ``subtract`` with ``out=self``.

    Raises ``ValueError`` when ``self.writable`` is false.
    """
    if self.writable:
        return subtract(self, other, out=self)
    raise ValueError('trying to subtract from a readonly ndarray')
| |
@wrap_mxnp_np_ufunc
def __rsub__(self, other):
    """Return ``other - self`` (reflected), computed by ``subtract(other, self)``."""
    difference = subtract(other, self)
    return difference
| |
@wrap_mxnp_np_ufunc
def __mul__(self, other):
    """Return ``self * other``, computed by ``multiply(self, other)``."""
    product = multiply(self, other)
    return product
| |
@wrap_mxnp_np_ufunc
def __floordiv__(self, other):
    """Return ``self // other``, computed by ``floor_divide(self, other)``."""
    quotient = floor_divide(self, other)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __ifloordiv__(self, other):
    """Perform ``self //= other`` by calling ``floor_divide`` with ``out=self``.

    Raises ``ValueError`` when ``self.writable`` is false.
    """
    if self.writable:
        return floor_divide(self, other, out=self)
    raise ValueError('trying to divide from a readonly ndarray')
| |
@wrap_mxnp_np_ufunc
def __rfloordiv__(self, other):
    """Return ``other // self`` (reflected), computed by ``floor_divide(other, self)``."""
    quotient = floor_divide(other, self)
    return quotient
| |
def __neg__(self):
    """Return ``-self``, computed by ``negative(self)``."""
    negated = negative(self)
    return negated
| |
def __pos__(self):
    """Return ``+self``, computed by ``positive(self)``."""
    result = positive(self)
    return result
| |
@wrap_mxnp_np_ufunc
def __imul__(self, other):
    r"""Perform ``x \*= y`` by calling ``multiply`` with ``out=self``.

    Raises
    ------
    ValueError
        If ``self.writable`` is false.
    """
    if not self.writable:
        # The message names the operation actually attempted (it previously
        # said "add", copied from ``__iadd__``).
        raise ValueError('trying to multiply a readonly ndarray')
    return multiply(self, other, out=self)
| |
@wrap_mxnp_np_ufunc
def __rmul__(self, other):
    """Return ``other * self`` by delegating to ``self.__mul__(other)``."""
    product = self.__mul__(other)
    return product
| |
@wrap_mxnp_np_ufunc
def __div__(self, other):
    """Return ``self / other``, computed by ``divide(self, other)``."""
    quotient = divide(self, other)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __rdiv__(self, other):
    """Return ``other / self`` (reflected), computed by ``divide(other, self)``."""
    quotient = divide(other, self)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __idiv__(self, other):
    """Perform ``self /= other`` by calling ``divide`` with ``out=self``."""
    quotient = divide(self, other, out=self)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __truediv__(self, other):
    """Return ``self / other``, computed by ``divide(self, other)``."""
    quotient = divide(self, other)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __rtruediv__(self, other):
    """Return ``other / self`` (reflected), computed by ``divide(other, self)``."""
    quotient = divide(other, self)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __itruediv__(self, other):
    """Perform ``self /= other`` by calling ``divide`` with ``out=self``."""
    quotient = divide(self, other, out=self)
    return quotient
| |
@wrap_mxnp_np_ufunc
def __mod__(self, other):
    """Return ``self % other``, computed by ``mod(self, other)``."""
    remainder = mod(self, other)
    return remainder
| |
@wrap_mxnp_np_ufunc
def __rmod__(self, other):
    """Return ``other % self`` (reflected), computed by ``mod(other, self)``."""
    remainder = mod(other, self)
    return remainder
| |
@wrap_mxnp_np_ufunc
def __imod__(self, other):
    """Perform ``self %= other`` by calling ``mod`` with ``out=self``."""
    remainder = mod(self, other, out=self)
    return remainder
| |
@wrap_mxnp_np_ufunc
def __pow__(self, other):
    """Return ``self ** other``, computed by ``power(self, other)``."""
    raised = power(self, other)
    return raised
| |
@wrap_mxnp_np_ufunc
def __rpow__(self, other):
    """Return ``other ** self`` (reflected), computed by ``power(other, self)``."""
    raised = power(other, self)
    return raised
| |
@wrap_mxnp_np_ufunc
def __ipow__(self, other):
    """Perform ``self **= other`` by calling ``power`` with ``out=self``."""
    raised = power(self, other, out=self)
    return raised
| |
@wrap_mxnp_np_ufunc
def __eq__(self, other):
    """Return ``self == other``, computed by ``equal(self, other)``."""
    comparison = equal(self, other)
    return comparison
| |
def __hash__(self):
    """Hashing is not implemented: always raises ``NotImplementedError``."""
    raise NotImplementedError()
| |
@wrap_mxnp_np_ufunc
def __ne__(self, other):
    """Return ``self != other``, computed by ``not_equal(self, other)``."""
    comparison = not_equal(self, other)
    return comparison
| |
@wrap_mxnp_np_ufunc
def __gt__(self, other):
    """Return ``self > other``, computed by ``greater(self, other)``."""
    comparison = greater(self, other)
    return comparison
| |
@wrap_mxnp_np_ufunc
def __ge__(self, other):
    """Return ``self >= other``, computed by ``greater_equal(self, other)``."""
    comparison = greater_equal(self, other)
    return comparison
| |
@wrap_mxnp_np_ufunc
def __lt__(self, other):
    """Return ``self < other``, computed by ``less(self, other)``."""
    comparison = less(self, other)
    return comparison
| |
@wrap_mxnp_np_ufunc
def __le__(self, other):
    """Return ``self <= other``, computed by ``less_equal(self, other)``."""
    comparison = less_equal(self, other)
    return comparison
| |
@wrap_mxnp_np_ufunc
def __matmul__(self, other):
    """Return ``self @ other``, computed by ``matmul(self, other)``."""
    product = matmul(self, other)
    return product
| |
@wrap_mxnp_np_ufunc
def __rmatmul__(self, other):
    """Return ``other @ self`` (reflected), computed by ``matmul(other, self)``."""
    product = matmul(other, self)
    return product
| |
@wrap_mxnp_np_ufunc
def __imatmul__(self, other):
    """Perform ``self @= other`` by calling ``matmul`` with ``out=self``."""
    product = matmul(self, other, out=self)
    return product
| |
def __bool__(self):
    """Return the truth value of the array.

    A one-element array yields ``bool(self.item())``. An empty array yields
    ``False`` after emitting a ``DeprecationWarning``. Any other size raises
    ``ValueError``.
    """
    count = self.size
    if count == 1:
        return bool(self.item())
    if count != 0:
        raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.")
    # NOTE: this installs the 'default' action in the global warning filters
    # before the warning is issued.
    warnings.simplefilter('default')
    warnings.warn('The truth value of an empty array is ambiguous. Returning False, but in'
                  ' future this will result in an error.', DeprecationWarning)
    return False
| |
| __nonzero__ = __bool__ |
| |
def __index__(self):
    """Return ``self.item()`` for a zero-dimensional integer array.

    Raises ``TypeError`` when the array has dimensions or its dtype is not a
    sub-dtype of ``numpy.integer``.
    """
    is_integer_scalar = self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer)
    if not is_integer_scalar:
        raise TypeError('only integer scalar arrays can be converted to a scalar index')
    return self.item()
| |
def __float__(self):
    """Return ``float(self.item())`` for a one-element array.

    Raises ``TypeError`` when ``self.size`` is not 1.
    """
    if self.size == 1:
        return float(self.item())
    raise TypeError('only size-1 arrays can be converted to Python scalars')
| |
def __int__(self):
    """Return ``int(self.item())`` for a one-element array.

    Raises ``TypeError`` when ``self.size`` is not 1.
    """
    if self.size == 1:
        return int(self.item())
    raise TypeError('only size-1 arrays can be converted to Python scalars')
| |
def __len__(self):
    """Return the size of the first axis.

    Raises ``TypeError`` for an array whose shape is empty.
    """
    dims = self.shape
    if len(dims) < 1:
        raise TypeError('len() of unsized object')
    return dims[0]
| |
def __reduce__(self):
    """Return the pickling triple ``(ndarray, (None,), self.__getstate__())``."""
    constructor_args = (None,)
    state = self.__getstate__()
    return ndarray, constructor_args, state
| |
def item(self, *args):
    """Copy an element of the array to a standard Python scalar and return it.

    The array is first converted with ``asnumpy()`` and the arguments are
    forwarded unchanged to the resulting object's ``item`` method.

    Parameters
    ----------
    *args : Arguments (variable number and type)
        none: only works for arrays with one element (``a.size == 1``); that
        element is copied into a standard Python scalar object and returned.

        int_type: interpreted as a flat index into the array, specifying
        which element to copy and return.

        tuple of int_types: like a single int_type argument, except that the
        argument is interpreted as an nd-index into the array.

    Returns
    -------
    z : Standard Python scalar object
        A copy of the specified element of the array as a suitable Python scalar.
    """
    # TODO(junwu): no need to call asnumpy() on the whole array.
    host_copy = self.asnumpy()
    return host_copy.item(*args)
| |
def nonzero(self):
    """Return the indices of the elements that are non-zero.

    Delegates to the ``nonzero`` function. Refer to `numpy.nonzero` for full
    documentation.

    See Also
    --------
    numpy.nonzero : equivalent function
    """
    indices = nonzero(self)
    return indices
| |
@property
# pylint: disable= invalid-name, undefined-variable
def T(self):
    """Return ``self.transpose()``.

    A warning is emitted when the array does not have exactly two
    dimensions; the transpose is returned in either case.
    """
    is_matrix = self.ndim == 2
    if not is_matrix:
        warnings.warn('x.T requires x to have 2 dimensions. '
                      'Use x.mT to transpose stacks of matrices and '
                      'permute_dims() to permute dimensions.')
    return self.transpose()
| # pylint: enable= invalid-name, undefined-variable |
| |
@property
# pylint: disable= invalid-name, undefined-variable
def mT(self):
    """Return the array with its last two axes swapped.

    Raises ``ValueError`` when the array has fewer than two dimensions.
    """
    if self.ndim >= 2:
        return _mx_nd_np.swapaxes(self, -1, -2)
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")
| # pylint: enable= invalid-name, undefined-variable |
| |
def all(self, axis=None, out=None, keepdims=False):
    """Forward to ``_mx_nd_np.all`` with the same ``axis``, ``out`` and ``keepdims``."""
    reduced = _mx_nd_np.all(self, axis=axis, out=out, keepdims=keepdims)
    return reduced
| |
def any(self, axis=None, out=None, keepdims=False):
    """Forward to ``_mx_nd_np.any`` with the same ``axis``, ``out`` and ``keepdims``."""
    reduced = _mx_nd_np.any(self, axis=axis, out=out, keepdims=keepdims)
    return reduced
| |
def as_nd_ndarray(self):
    """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods.

    A new handle is obtained through ``MXShallowCopyNDArray`` and wrapped in
    an ``NDArray`` that carries the same ``writable`` flag.
    """
    shallow_handle = NDArrayHandle()
    status = _LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(shallow_handle))
    check_call(status)
    return NDArray(handle=shallow_handle, writable=self.writable)
| |
def as_np_ndarray(self):
    """Return ``self`` unchanged.

    Convenience counterpart of the conversion to a numpy ndarray; this
    object already is one, so no copy is made.
    """
    already_np = self
    return already_np
| |
def __repr__(self):
    """
    Return a string representation of the array.

    The text starts from the ``repr`` of ``self.asnumpy()``. A ``dtype=``
    suffix is removed when the dtype equals the current default float dtype,
    and added when the dtype is neither that default nor bool and NumPy did
    not print one. A ``device=`` suffix is added for devices other than CPU.
    A freed array is rendered as ``<FREED classname>``.

    Examples
    --------
    >>> from mxnet import np, npx
    >>> a = np.random.uniform(size=(2, 3))
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
           [0.84426576, 0.60276335, 0.8579456 ]])
    >>> print(a)
    [[0.5488135  0.5928446  0.71518934]
     [0.84426576 0.60276335 0.8579456 ]]
    >>> a.dtype
    dtype('float32')
    >>> npx.set_np_float64()
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
           [0.84426576, 0.60276335, 0.8579456 ]], dtype=float32)
    >>> npx.set_np_float64(default_float64=False)
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
           [0.84426576, 0.60276335, 0.8579456 ]])
    >>> b = a.astype(np.float64)
    >>> b
    array([[0.54881352, 0.59284461, 0.71518934],
           [0.84426576, 0.60276335, 0.85794562]], dtype=float64)
    >>> print(b)
    [[0.54881352 0.59284461 0.71518934]
     [0.84426576 0.60276335 0.85794562]]
    >>> b.dtype
    dtype('float64')
    >>> c = a.copyto(npx.gpu(0))
    >>> c
    array([[0.5488135 , 0.5928446 , 0.71518934],
           [0.84426576, 0.60276335, 0.8579456 ]], device=gpu(0))
    >>> print(c)
    [[0.5488135  0.5928446  0.71518934]
     [0.84426576 0.60276335 0.8579456 ]] @gpu(0)
    >>> d = b.copyto(npx.gpu(0))
    >>> d
    array([[0.54881352, 0.59284461, 0.71518934],
           [0.84426576, 0.60276335, 0.85794562]], dtype=float64, device=gpu(0))
    >>> print(d)
    [[0.54881352 0.59284461 0.71518934]
     [0.84426576 0.60276335 0.85794562]] @gpu(0)

    """
    if not self._alive:
        return '<FREED {}>'.format(self.__class__.__name__)

    text = self.asnumpy().__repr__()
    own_dtype = self.dtype
    default_dtype = _np.float64 if is_np_default_dtype() else _np.float32
    if 'dtype=' in text:
        # NumPy printed a dtype; drop it when it is just the default one.
        if own_dtype == default_dtype:
            text = text[:text.rindex(',')] + ')'
    elif own_dtype not in (default_dtype, _np.bool_):
        # NumPy printed no dtype although it differs from the default.
        text = text[:-1] + ', dtype={})'.format(own_dtype)

    where = self.device
    if where.device_type != 'cpu':
        text = text[:-1] + ', device={})'.format(str(where))
    return text
| |
def __str__(self):
    """Return ``str(self.asnumpy())``, suffixed with ``@device`` off-CPU.

    The suffix is omitted for CPU arrays and for zero-dimensional arrays.
    """
    text = self.asnumpy().__str__()
    where = self.device
    needs_suffix = not (where.device_type == 'cpu' or self.ndim == 0)
    if needs_suffix:
        return '{array} @{device}'.format(array=text, device=where)
    return text
| |
def __format__(self, fmt):
    """Format the array according to ``fmt``.

    A zero-dimensional array formats its ``item()`` with ``fmt``. Any other
    array only accepts an empty ``fmt`` and then formats its ``__str__()``
    text; a non-empty ``fmt`` raises ``TypeError``.
    """
    if self.ndim == 0:
        return self.item().__format__(fmt)
    if len(fmt) != 0:
        raise TypeError("Cannot format mxnet.numpy.ndarray with format_spec")
    return self.__str__().__format__(fmt)
| |
def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
    """Attach a gradient buffer to this ndarray, so that `backward`
    can compute gradient with respect to it.

    A zero-filled array shaped like ``self`` is created and registered as
    the gradient buffer through ``MXAutogradMarkVariables``.

    Parameters
    ----------
    grad_req : {'write', 'add', 'null'}
        How gradient will be accumulated.
        * 'write': gradient will be overwritten on every backward.
        * 'add': gradient will be added to existing value on every backward.
        * 'null': do not compute gradient for this NDArray.
    """
    grad_buffer = _mx_nd_np.zeros_like(self)  # pylint: disable=undefined-variable
    req_code = _GRAD_REQ_MAP[grad_req]
    status = _LIB.MXAutogradMarkVariables(
        1,
        ctypes.pointer(self.handle),
        ctypes.pointer(mx_uint(req_code)),
        ctypes.pointer(grad_buffer.handle),
    )
    check_call(status)
| |
def drop_grad(self):
    """Free the memory of the marked ndarray.

    Calls ``MXAutogradDropGrads`` for this array's handle.
    """
    handle_ptr = ctypes.pointer(self.handle)
    check_call(_LIB.MXAutogradDropGrads(1, handle_ptr))
| |
@property
def grad(self):
    """Return the gradient buffer attached to this ndarray.

    ``None`` is returned when ``MXNDArrayGetGrad`` yields a null handle.
    """
    grad_handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayGetGrad(self.handle, ctypes.byref(grad_handle)))
    return None if grad_handle.value is None else _np_ndarray_cls(grad_handle)
| |
def detach(self):
    """Return a new ndarray, detached from the current graph.

    The handle comes from ``MXNDArrayDetach``.
    """
    detached_handle = NDArrayHandle()
    status = _LIB.MXNDArrayDetach(self.handle, ctypes.byref(detached_handle))
    check_call(status)
    return _np_ndarray_cls(detached_handle)
| |
def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True):  # pylint: disable=arguments-differ,unused-argument, too-many-arguments
    """
    Copy of the array, cast to a specified type.

    Parameters
    ----------
    dtype : str or dtype
        Typecode or data-type to which the array is cast. ``None`` is
        treated as ``float32``.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout order of the result.
        'C' means C order, 'F' means Fortran order, 'A'
        means 'F' order if all the arrays are Fortran contiguous,
        'C' order otherwise, and 'K' means as close to the
        order the array elements appear in memory as possible.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Defaults to 'unsafe'
        for backwards compatibility.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.
    subok : bool, optional
        If True, then sub-classes will be passed-through (default), otherwise
        the returned array will be forced to be a base-class array.
    copy : bool, optional
        Default `True`. By default, astype always returns a newly
        allocated ndarray on the same device. If this is set to
        `False`, and the dtype requested is the same as the ndarray's
        dtype, the ndarray is returned instead of a copy.

    Returns
    -------
    arr_t : ndarray
        Unless `copy` is False and the other conditions for returning the input
        array are satisfied (see description for `copy` input parameter), `arr_t`
        is a new array of the same shape as the input array with `dtype`.

    Raises
    ------
    ValueError
        If `order` is not ``None``, 'K' or 'C', if `casting` is not
        'unsafe', or if `subok` is false.

    Notes
    -----
    This function differs from the official `ndarray`'s ``astype`` function in the following
    aspects:
        * `order` only supports 'C' and 'K'.
        * `casting` only supports 'unsafe'.
        * `subok` only supports ``True``.
    """
    if not (order is None or order == 'K' or order == 'C'):
        raise ValueError("order must be either 'K' or 'C'")
    if casting != 'unsafe':
        raise ValueError("casting must be equal to 'unsafe'")
    if not subok:
        raise ValueError('subok must be equal to True')

    target_dtype = _np.float32 if dtype is None else dtype
    # Without a forced copy, an array that already has the target dtype is
    # returned as is.
    if not copy and _np.dtype(target_dtype) == self.dtype:
        return self
    return _npi.cast(self, dtype=target_dtype)
| |
def copyto(self, other):
    """Copy the value of this array to another array or device.

    If ``other`` is a ``ndarray`` object, then ``other.shape`` and
    ``self.shape`` should be the same. This function copies the value from
    ``self`` to ``other``.

    If ``other`` is a device, a new ``np.ndarray`` will be first created on
    the target device, and the value of ``self`` is copied.

    Parameters
    ----------
    other : ndarray or Device
        The destination array or device.

    Returns
    -------
    out: ndarray
        The copied array. If ``other`` is an ``ndarray``, then the return value
        and ``other`` will point to the same ``ndarray``. When ``other``
        shares its handle with ``self``, a ``RuntimeWarning`` is issued and
        ``False`` is returned instead.

    Raises
    ------
    TypeError
        If ``other`` is neither an ``ndarray`` nor a ``Device``.

    Examples
    --------
    >>> x = np.ones((2, 3))
    >>> y = np.zeros((2, 3), device=npx.gpu(0))
    >>> z = x.copyto(y)
    >>> z is y
    True
    >>> y
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]])
    """
    if isinstance(other, ndarray):
        if other.handle is not self.handle:
            return _npi.copyto(self, out=other)
        warnings.warn('You are attempting to copy an array to itself', RuntimeWarning)
        return False
    if isinstance(other, Device):
        # Allocate the destination on the requested device first.
        destination = ndarray(_new_alloc_handle(self.shape, other, True, self.dtype))
        return _npi.copyto(self, out=destination)
    raise TypeError('copyto does not support type ' + str(type(other)))
| |
def asscalar(self):
    """Unsupported on this class: always raises ``AttributeError``."""
    message = 'mxnet.numpy.ndarray object has no attribute asscalar'
    raise AttributeError(message)
| |
def argmax(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Return indices of the maximum values along the given axis.

    Forwards ``axis``, ``out`` and ``keepdims`` positionally to ``argmax``.
    Refer to `mxnet.numpy.argmax` for full documentation.
    """
    indices = argmax(self, axis, out, keepdims)
    return indices
| |
def as_in_context(self, context):
    """This function has been deprecated. Please refer to ``ndarray.to_device``.

    Emits a ``DeprecationWarning``, then converts through
    ``as_nd_ndarray().as_in_context(context)`` and back with
    ``as_np_ndarray()``.
    """
    warnings.warn('ndarray.as_in_context has been renamed to'
                  ' ndarray.to_device', DeprecationWarning)
    legacy_array = self.as_nd_ndarray()
    relocated = legacy_array.as_in_context(context)
    return relocated.as_np_ndarray()
| |
def as_in_ctx(self, ctx):
    """This function has been deprecated. Please refer to ``ndarray.to_device``.

    Emits a ``DeprecationWarning`` and returns ``self.to_device(ctx)``.

    Parameters
    ----------
    ctx : Device
        The target device, forwarded to ``to_device``.
    """
    # The warning names this method as the deprecated one (the message
    # previously said ``to_device`` had been renamed to itself).
    warnings.warn('ndarray.as_in_ctx has been renamed to'
                  ' ndarray.to_device', DeprecationWarning)
    return self.to_device(ctx)
| |
@property
def ctx(self):
    """This property has been deprecated. Please refer to ``ndarray.device``.

    Emits a ``DeprecationWarning`` and returns ``self.device``.
    """
    notice = 'ndarray.ctx has been renamed to ndarray.device'
    warnings.warn(notice, DeprecationWarning)
    return self.device
| |
| |
def to_device(self, device):
    """Return an array on the target device with the same value as this array.

    If the target device equals ``self.device``, then ``self`` is
    returned. Otherwise the result of ``self.copyto(device)`` is returned.

    Parameters
    ----------
    device : Device
        The target device.

    Returns
    -------
    ndarray
        The target array.
    """
    already_there = self.device == device
    return self if already_there else self.copyto(device)
| |
@property
def device(self):
    """Hardware device the array data resides on.

    Queries the backend through ``MXNDArrayGetContext`` for a device type id
    and a device id, then builds a ``Device`` from them.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4])
    >>> x.device
    cpu(0)
    >>> type(x.device)
    <class 'mxnet.device.Device'>
    >>> y = np.zeros((2, 3), npx.gpu(0))
    >>> y.device
    gpu(0)
    """
    type_id = ctypes.c_int()
    device_index = ctypes.c_int()
    status = _LIB.MXNDArrayGetContext(
        self.handle, ctypes.byref(type_id), ctypes.byref(device_index))
    check_call(status)
    # Translate the numeric type id into the string name Device expects.
    type_name = Device.devtype2str[type_id.value]
    return Device(type_name, device_index.value)
| |
| |
@property
def context(self):
    """This property has been deprecated. Please refer to ``ndarray.device``.

    Emits a ``DeprecationWarning`` and returns the ``context`` attribute of
    ``self.as_nd_ndarray()``.
    """
    # Point users at ``device`` directly: ``ctx`` is itself deprecated in
    # favour of ``device``, so recommending it only chains deprecations.
    warnings.warn('ndarray.context has been renamed to ndarray.device', DeprecationWarning)
    return self.as_nd_ndarray().context
| |
def copy(self, order='C'):  # pylint: disable=arguments-differ
    """Return a copy of the array, keeping the same device.

    Implemented as ``self.copyto(self.device)``.

    Parameters
    ----------
    order : str
        The memory layout of the copy. Currently, only c-contiguous memory
        layout is supported.

    Raises
    ------
    NotImplementedError
        If ``order`` is anything other than ``'C'``.

    Examples
    --------
    >>> x = np.ones((2, 3))
    >>> y = x.copy()
    >>> y
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]])
    """
    if order == 'C':
        return self.copyto(self.device)
    message = ('ndarray.copy only supports order=\'C\', while '
               'received {}'.format(str(order)))
    raise NotImplementedError(message)
| |
def dot(self, b, out=None):
    """Dot product of this array with ``b``.

    Forwards to the module-level ``dot`` with ``out`` passed by keyword.
    Refer to ``numpy.dot`` for full documentation.
    """
    product = dot(self, b, out=out)
    return product
| |
def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
    """Returns a copy of the array with a new shape.

    The only keyword argument accepted is ``order``, and only the value
    ``'C'`` is supported.

    Raises
    ------
    TypeError
        If more than one keyword argument is given, if the single keyword
        is not ``order``, or if no positional shape argument is given.
    NotImplementedError
        If ``order`` is not ``'C'``.

    Notes
    -----
    Unlike the free function `numpy.reshape`, this method on `ndarray` allows
    the elements of the shape parameter to be passed in as separate arguments.
    For example, ``a.reshape(10, 11)`` is equivalent to
    ``a.reshape((10, 11))``.
    """
    num_keywords = len(kwargs)
    if num_keywords > 1:
        raise TypeError('function takes at most 1 keyword argument')

    order = 'C'
    if num_keywords == 1:
        if 'order' not in kwargs:
            unexpected = next(iter(kwargs))
            raise TypeError("'{}' is an invalid keyword argument for this function"
                            .format(unexpected))
        order = kwargs.pop('order')
        if order != 'C':
            raise NotImplementedError('only supports C-order,'
                                      ' while received {}'.format(order))

    if not args:
        raise TypeError('reshape() takes exactly 1 argument (0 given)')

    # A lone tuple is the shape itself; otherwise the positional arguments
    # collectively form the shape.
    passed_single_tuple = len(args) == 1 and isinstance(args[0], tuple)
    newshape = args[0] if passed_single_tuple else args
    return _mx_nd_np.reshape(self, newshape=newshape, order=order)
| |
def reshape_like(self, *args, **kwargs):
    """Unsupported fluent method ``reshape_like``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute reshape_like'
    raise AttributeError(message)
| |
def reshape_view(self, *shape, **kwargs):  # pylint: disable=redefined-outer-name
    """Returns a **view** of this array with a new shape without altering any data.

    Inherited from NDArray.reshape: all arguments are forwarded unchanged to
    the parent class's ``reshape``.
    """
    parent = super(ndarray, self)
    return parent.reshape(*shape, **kwargs)
| |
def zeros_like(self, *args, **kwargs):
    """Unsupported fluent method ``zeros_like``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute zeros_like'
    raise AttributeError(message)
| |
def ones_like(self, *args, **kwargs):
    """Unsupported fluent method ``ones_like``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute ones_like'
    raise AttributeError(message)
| |
def broadcast_axes(self, *args, **kwargs):
    """Unsupported fluent method ``broadcast_axes``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    # The message must name this method; it previously reported
    # ``broadcast_like``, which is a different attribute.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_axes')
| |
def repeat(self, repeats, axis=None):  # pylint: disable=arguments-differ
    """Repeat elements of an array.

    Forwards this array to the module-level ``repeat`` with ``repeats`` and
    ``axis`` passed by keyword.
    """
    repeated = repeat(self, repeats=repeats, axis=axis)
    return repeated
| |
def pad(self, *args, **kwargs):
    """Unsupported fluent method ``pad``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute pad'
    raise AttributeError(message)
| |
def swapaxes(self, axis1, axis2):  # pylint: disable=arguments-differ
    """Return a copy of the array with axis1 and axis2 interchanged.

    Forwards to the module-level ``swapaxes``.
    Refer to `mxnet.numpy.swapaxes` for full documentation.
    """
    swapped = swapaxes(self, axis1, axis2)
    return swapped
| |
def split(self, *args, **kwargs):
    """Unsupported fluent method ``split``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute split'
    raise AttributeError(message)
| |
def split_v2(self, *args, **kwargs):
    """Unsupported fluent method ``split_v2``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute split_v2'
    raise AttributeError(message)
| |
def slice(self, *args, **kwargs):
    """Unsupported fluent method ``slice``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice'
    raise AttributeError(message)
| |
def slice_axis(self, *args, **kwargs):
    """Unsupported fluent method ``slice_axis``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice_axis'
    raise AttributeError(message)
| |
def slice_like(self, *args, **kwargs):
    """Unsupported fluent method ``slice_like``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice_like'
    raise AttributeError(message)
| |
def slice_assign_scalar(self, value, begin, end, step):
    """
    Assign a scalar to a cropped subset of this ndarray, in place.

    The value is broadcast to the shape of the cropped region and cast to
    the dtype of this ndarray. The underlying operator is invoked with
    ``out=self``.

    Parameters
    ----------
    value: numeric value
        The scalar to write into the cropped region.
    begin: tuple of begin indices
    end: tuple of end indices
    step: tuple of step lengths

    Returns
    -------
    This ndarray.

    Examples
    --------
    >>> x = np.ones((2, 2, 2))
    >>> y = x.slice_assign_scalar(0, (0, 0, None), (1, 1, None), (None, None, None))
    >>> y
    array([[[0., 0.],
            [1., 1.]],

           [[1., 1.],
            [1., 1.]]])
    >>> x
    array([[[0., 0.],
            [1., 1.]],

           [[1., 1.],
            [1., 1.]]])
    """
    region = dict(begin=begin, end=end, step=step)
    return _npi.slice_assign_scalar(self, value, out=self, **region)
| |
def slice_assign(self, rhs, begin, end, step):
    """
    Assign ``rhs`` to a cropped subset of this ndarray, in place.

    The underlying operator is invoked with ``out=self``.

    Parameters
    ----------
    rhs: ndarray.
        rhs and this NDArray should be of the same data type, and on the same device.
        The shape of rhs should be the same as the cropped shape of this ndarray.
    begin: tuple of begin indices
    end: tuple of end indices
    step: tuple of step lengths

    Returns
    -------
    out : ndarray
        This ndarray.

    Examples
    --------
    >>> x = np.ones((2, 2, 2))
    >>> assigned = np.zeros((1, 1, 2))
    >>> y = x.slice_assign(assigned, (0, 0, None), (1, 1, None), (None, None, None))
    >>> y
    array([[[0., 0.],
            [1., 1.]],

           [[1., 1.],
            [1., 1.]]])
    >>> x
    array([[[0., 0.],
            [1., 1.]],

           [[1., 1.],
            [1., 1.]]])
    """
    region = dict(begin=begin, end=end, step=step)
    return _npi.slice_assign(self, rhs, out=self, **region)
| |
def take(self, indices, axis=None, mode='raise'):  # pylint: disable=arguments-differ, redefined-outer-name
    """Fluent form of the module-level ``take`` with this array as data.

    ``indices`` and ``axis`` are forwarded positionally and ``mode`` by
    keyword; the arguments have the same meaning as for :py:func:`take`.
    """
    taken = take(self, indices, axis, mode=mode)
    return taken
| |
def one_hot(self, *args, **kwargs):
    """Unsupported fluent method ``one_hot``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute one_hot'
    raise AttributeError(message)
| |
def pick(self, *args, **kwargs):
    """Unsupported fluent method ``pick``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute pick'
    raise AttributeError(message)
| |
def sort(self, axis=-1, descending=False, stable=True):  # pylint: disable=arguments-differ
    """Fluent form of the module-level ``sort`` with this array as data.

    ``axis``, ``descending`` and ``stable`` are forwarded by keyword and have
    the same meaning as for :py:func:`sort`.
    """
    options = dict(axis=axis, descending=descending, stable=stable)
    return sort(self, **options)
| |
def topk(self, *args, **kwargs):
    """Unsupported fluent method ``topk``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute topk'
    raise AttributeError(message)
| |
def argsort(self, axis=-1, descending=False, stable=True):  # pylint: disable=arguments-differ
    """Fluent form of the module-level ``argsort`` with this array as data.

    ``axis``, ``descending`` and ``stable`` are forwarded by keyword and have
    the same meaning as for :py:func:`argsort`.
    """
    options = dict(axis=axis, descending=descending, stable=stable)
    return argsort(self, **options)
| |
def argmax_channel(self, *args, **kwargs):
    """Unsupported fluent method ``argmax_channel``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute argmax_channel'
    raise AttributeError(message)
| |
def argmin(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Return the indices of the minimum values along ``axis``.

    Forwards this array together with ``axis``, ``out`` and ``keepdims``
    (positionally, in that order) to the module-level ``argmin``.
    Refer to `mxnet.numpy.argmin` for full documentation.
    """
    indices = argmin(self, axis, out, keepdims)
    return indices
| |
def clip(self, min=None, max=None, out=None):  # pylint: disable=arguments-differ
    """Return an array whose values are limited to [min, max].

    One of max or min must be given. Forwards to the module-level ``clip``
    with ``min`` and ``max`` passed positionally and ``out`` by keyword.
    """
    clipped = clip(self, min, max, out=out)
    return clipped
| |
def abs(self, *args, **kwargs):
    """Unsupported fluent method ``abs``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute abs'
    raise AttributeError(message)
| |
def sign(self, *args, **kwargs):
    """Unsupported fluent method ``sign``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute sign'
    raise AttributeError(message)
| |
def flatten(self, order='C'):  # pylint: disable=arguments-differ
    """Return a copy of the array collapsed into one dimension.

    Implemented as ``self.reshape(-1, order=order)``.
    """
    flat_shape = -1
    return self.reshape(flat_shape, order=order)
| |
def shape_array(self, *args, **kwargs):
    """Unsupported fluent method ``shape_array``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute shape_array'
    raise AttributeError(message)
| |
def size_array(self, *args, **kwargs):
    """Unsupported fluent method ``size_array``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute size_array'
    raise AttributeError(message)
| |
def expand_dims(self, *args, **kwargs):  # pylint: disable=arguments-differ,unused-argument
    """Unsupported fluent method ``expand_dims``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute expand_dims'
    raise AttributeError(message)
| |
def tile(self, reps):  # pylint: disable=arguments-differ
    """Construct an array by repeating this array the number of times given by reps.

    Forwards to the module-level ``tile`` with ``reps`` passed by keyword.
    Refer to `mxnet.numpy.tile` for full documentation.
    """
    tiled = tile(self, reps=reps)
    return tiled
| |
def transpose(self, *axes):  # pylint: disable=arguments-differ
    """Permute the dimensions of an array.

    The positional arguments are normalised before being handed to the
    module-level ``transpose`` as ``axes``:

    * no arguments, or a single ``None``: ``axes=None``;
    * a single tuple or list: that sequence is used as ``axes``;
    * anything else: the tuple of positional arguments is used as ``axes``.
    """
    if not axes:
        permutation = None
    elif len(axes) == 1 and (axes[0] is None or isinstance(axes[0], (tuple, list))):
        # A lone None / tuple / list is unwrapped from the varargs tuple.
        permutation = axes[0]
    else:
        permutation = axes
    return transpose(self, axes=permutation)
| |
def flip(self, *args, **kwargs):
    """Unsupported fluent method ``flip``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute flip'
    raise AttributeError(message)
| |
def depth_to_space(self, *args, **kwargs):
    """Unsupported fluent method ``depth_to_space``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute depth_to_space'
    raise AttributeError(message)
| |
def space_to_depth(self, *args, **kwargs):
    """Unsupported fluent method ``space_to_depth``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute space_to_depth'
    raise AttributeError(message)
| |
def diag(self, k=0, **kwargs):
    """Unsupported fluent method ``diag``; the call always fails.

    ``k`` and any keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute diag'
    raise AttributeError(message)
| |
def diagonal(self, offset=0, axis1=0, axis2=1):  # pylint: disable=arguments-differ
    """Return the diagonal with the given offset.

    If array has more than two dimensions, then the axes specified by axis1 and
    axis2 are used to determine the 2-D sub-array whose diagonal is returned.

    Forwards to the module-level ``diagonal`` with every argument passed by
    keyword. Refer to `mxnet.numpy.diagonal` for full documentation.
    """
    options = dict(offset=offset, axis1=axis1, axis2=axis2)
    return diagonal(self, **options)
| |
def sum(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Return the sum of the array elements over the given axis.

    Forwards to the module-level ``sum`` with every argument passed by
    keyword.
    """
    options = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims)
    return sum(self, **options)
| |
def nansum(self, *args, **kwargs):
    """Unsupported fluent method ``nansum``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute nansum'
    raise AttributeError(message)
| |
def prod(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Return the product of the array elements over the given axis.

    Forwards to ``_mx_np_op.prod`` with every argument passed by keyword.
    """
    options = dict(axis=axis, dtype=dtype, keepdims=keepdims, out=out)
    return _mx_np_op.prod(self, **options)
| |
def nanprod(self, *args, **kwargs):
    """Unsupported fluent method ``nanprod``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute nanprod'
    raise AttributeError(message)
| |
def mean(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Returns the average of the array elements along given axis.

    Forwards to the module-level ``mean`` with every argument passed by
    keyword.
    """
    options = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims)
    return mean(self, **options)
| |
| # pylint: disable=too-many-arguments, arguments-differ |
| |
@wrap_data_api_statical_func
def std(self, axis=None, dtype=None, out=None, correction=0, keepdims=False):
    """Returns the standard deviation of the array elements along given axis.

    Forwards to the module-level ``std`` with every argument passed by
    keyword.
    """
    options = dict(axis=axis, dtype=dtype, correction=correction,
                   keepdims=keepdims, out=out)
    return std(self, **options)
| |
@wrap_data_api_statical_func
def var(self, axis=None, dtype=None, out=None, correction=0, keepdims=False):
    """Returns the variance of the array elements, along given axis.

    Forwards to the module-level ``var`` with every argument passed by
    keyword.
    """
    options = dict(axis=axis, dtype=dtype, out=out,
                   correction=correction, keepdims=keepdims)
    return var(self, **options)
| # pylint: enable=too-many-arguments, arguments-differ |
| |
def cumsum(self, axis=None, dtype=None, out=None):
    """Return the cumulative sum of the elements along the given axis.

    Forwards to ``_mx_nd_np.cumsum`` with every argument passed by keyword.
    """
    options = dict(axis=axis, dtype=dtype, out=out)
    return _mx_nd_np.cumsum(self, **options)
| |
def tolist(self):
    """Convert through ``asnumpy()`` and return the result of its ``tolist()``."""
    host_array = self.asnumpy()
    return host_array.tolist()
| |
def max(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Return the maximum along a given axis.

    Forwards to ``_mx_nd_np.max`` with every argument passed by keyword.
    """
    options = dict(axis=axis, out=out, keepdims=keepdims)
    return _mx_nd_np.max(self, **options)
| |
def min(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Fluent form of :py:func:`min` with this array as data.

    Forwards to ``_mx_nd_np.min`` with every argument passed by keyword.
    """
    options = dict(axis=axis, out=out, keepdims=keepdims)
    return _mx_nd_np.min(self, **options)
| |
def norm(self, *args, **kwargs):
    """Unsupported fluent method ``norm``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute norm'
    raise AttributeError(message)
| |
def round(self, decimals=0, out=None, **kwargs):  # pylint: disable=arguments-differ
    """Fluent form of the module-level ``round`` with this array as data.

    ``decimals`` and ``out`` are forwarded by keyword along with any extra
    keyword arguments; they have the same meaning as for :py:func:`round`.
    """
    rounded = round(self, decimals=decimals, out=out, **kwargs)
    return rounded
| |
def rint(self, *args, **kwargs):
    """Unsupported fluent method ``rint``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute rint'
    raise AttributeError(message)
| |
def fix(self, *args, **kwargs):
    """Unsupported fluent method ``fix``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute fix'
    raise AttributeError(message)
| |
def floor(self, *args, **kwargs):
    """Unsupported fluent method ``floor``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute floor'
    raise AttributeError(message)
| |
def ceil(self, *args, **kwargs):
    """Unsupported fluent method ``ceil``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute ceil'
    raise AttributeError(message)
| |
def trunc(self, *args, **kwargs):
    """Unsupported fluent method ``trunc``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute trunc'
    raise AttributeError(message)
| |
def sin(self, *args, **kwargs):
    """Unsupported fluent method ``sin``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute sin'
    raise AttributeError(message)
| |
def cos(self, *args, **kwargs):
    """Unsupported fluent method ``cos``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute cos'
    raise AttributeError(message)
| |
def tan(self, *args, **kwargs):
    """Unsupported fluent method ``tan``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute tan'
    raise AttributeError(message)
| |
def arcsin(self, *args, **kwargs):
    """Unsupported fluent method ``arcsin``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arcsin'
    raise AttributeError(message)
| |
def arccos(self, *args, **kwargs):
    """Unsupported fluent method ``arccos``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arccos'
    raise AttributeError(message)
| |
def arctan(self, *args, **kwargs):
    """Unsupported fluent method ``arctan``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arctan'
    raise AttributeError(message)
| |
def degrees(self, *args, **kwargs):
    """Unsupported fluent method ``degrees``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute degrees'
    raise AttributeError(message)
| |
def radians(self, *args, **kwargs):
    """Unsupported fluent method ``radians``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute radians'
    raise AttributeError(message)
| |
def sinh(self, *args, **kwargs):
    """Unsupported fluent method ``sinh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute sinh'
    raise AttributeError(message)
| |
def cosh(self, *args, **kwargs):
    """Unsupported fluent method ``cosh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute cosh'
    raise AttributeError(message)
| |
def tanh(self, *args, **kwargs):
    """Unsupported fluent method ``tanh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute tanh'
    raise AttributeError(message)
| |
def arcsinh(self, *args, **kwargs):
    """Unsupported fluent method ``arcsinh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arcsinh'
    raise AttributeError(message)
| |
def arccosh(self, *args, **kwargs):
    """Unsupported fluent method ``arccosh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arccosh'
    raise AttributeError(message)
| |
def arctanh(self, *args, **kwargs):
    """Unsupported fluent method ``arctanh``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute arctanh'
    raise AttributeError(message)
| |
def exp(self, *args, **kwargs):
    """Unsupported fluent method ``exp``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute exp'
    raise AttributeError(message)
| |
def expm1(self, *args, **kwargs):
    """Unsupported fluent method ``expm1``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute expm1'
    raise AttributeError(message)
| |
def log(self, *args, **kwargs):
    """Unsupported fluent method ``log``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log'
    raise AttributeError(message)
| |
def log10(self, *args, **kwargs):
    """Unsupported fluent method ``log10``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log10'
    raise AttributeError(message)
| |
def log2(self, *args, **kwargs):
    """Unsupported fluent method ``log2``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log2'
    raise AttributeError(message)
| |
def log1p(self, *args, **kwargs):
    """Unsupported fluent method ``log1p``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log1p'
    raise AttributeError(message)
| |
def log_sigmoid(self, *args, **kwargs):
    """Unsupported fluent method ``log_sigmoid``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log_sigmoid'
    raise AttributeError(message)
| |
def sqrt(self, *args, **kwargs):
    """Unsupported fluent method ``sqrt``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute sqrt'
    raise AttributeError(message)
| |
def rsqrt(self, *args, **kwargs):
    """Unsupported fluent method ``rsqrt``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute rsqrt'
    raise AttributeError(message)
| |
def cbrt(self, *args, **kwargs):
    """Unsupported fluent method ``cbrt``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    # The message previously misspelled the attribute as ``cqrt``.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute cbrt')
| |
def rcbrt(self, *args, **kwargs):
    """Unsupported fluent method ``rcbrt``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    # The message previously misspelled the attribute as ``rcqrt``.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute rcbrt')
| |
def square(self, *args, **kwargs):
    """Unsupported fluent method ``square``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute square'
    raise AttributeError(message)
| |
def reciprocal(self, *args, **kwargs):
    """Unsupported fluent method ``reciprocal``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute reciprocal'
    raise AttributeError(message)
| |
def relu(self, *args, **kwargs):
    """Unsupported fluent method ``relu``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute relu'
    raise AttributeError(message)
| |
def sigmoid(self, *args, **kwargs):
    """Unsupported fluent method ``sigmoid``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute sigmoid'
    raise AttributeError(message)
| |
def softmax(self, *args, **kwargs):
    """Unsupported fluent method ``softmax``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute softmax'
    raise AttributeError(message)
| |
def log_softmax(self, *args, **kwargs):
    """Unsupported fluent method ``log_softmax``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute log_softmax'
    raise AttributeError(message)
| |
def softmin(self, *args, **kwargs):
    """Unsupported fluent method ``softmin``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute softmin'
    raise AttributeError(message)
| |
def mish(self, *args, **kwargs):
    """Unsupported fluent method ``mish``; the call always fails.

    All positional and keyword arguments are accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute mish'
    raise AttributeError(message)
| |
def squeeze(self, axis=None):  # pylint: disable=arguments-differ
    """Remove single-dimensional entries from the shape of this array.

    Forwards to the module-level ``squeeze`` with ``axis`` passed by keyword.
    """
    squeezed = squeeze(self, axis=axis)
    return squeezed
| |
def broadcast_to(self, shape):  # pylint: disable=redefined-outer-name
    """Forward this array and ``shape`` to ``_mx_nd_np.broadcast_to``."""
    broadcasted = _mx_nd_np.broadcast_to(self, shape)
    return broadcasted
| |
def broadcast_like(self, other):
    """Unsupported method ``broadcast_like``; the call always fails.

    ``other`` is accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute broadcast_like'
    raise AttributeError(message)
| |
def _full(self, value):
    """
    Currently for internal use only. Implemented for __setitem__.

    Calls ``_mx_nd_np.full`` with this array's shape, device and dtype and
    with ``out=self``, so the filled result is written into this array.
    """
    fill_options = dict(device=self.device, dtype=self.dtype, out=self)
    return _mx_nd_np.full(self.shape, value, **fill_options)
| |
| # pylint: disable=redefined-outer-name |
def _scatter_set_nd(self, value_nd, indices):
    """
    Scatter ``value_nd`` into this array at ``indices`` via ``_npi.scatter_set_nd``.

    This is added as an ndarray class method in order to support polymorphism
    in NDArray and numpy.ndarray indexing. The operator is invoked with
    ``out=self`` and ``shape=self.shape``.
    """
    scatter_args = dict(lhs=self, rhs=value_nd, indices=indices,
                        shape=self.shape, out=self)
    return _npi.scatter_set_nd(**scatter_args)
| # pylint: enable=redefined-outer-name |
| |
@property
def shape(self):
    """Tuple of array dimensions.

    The dimensions are read from the backend through ``MXNDArrayGetShape64``
    when ``_int64_enabled()`` is true and through ``MXNDArrayGetShape``
    otherwise. ``None`` is returned when the backend reports a dimension
    count of ``-1``.

    Examples
    --------
    >>> x = mx.np.array([1, 2, 3, 4])
    >>> x.shape
    (4L,)
    >>> y = mx.np.zeros((2, 3, 4))
    >>> y.shape
    (2L, 3L, 4L)
    >>> z = mx.np.array(3)
    >>> z.shape
    ()
    """
    dim_count = mx_int()
    if _int64_enabled():
        dims_ptr = ctypes.POINTER(mx_int64)()
        query_shape = _LIB.MXNDArrayGetShape64
    else:
        dims_ptr = ctypes.POINTER(mx_int)()
        query_shape = _LIB.MXNDArrayGetShape
    check_call(query_shape(
        self.handle, ctypes.byref(dim_count), ctypes.byref(dims_ptr)))
    if dim_count.value == -1:
        return None
    return tuple(dims_ptr[:dim_count.value])  # pylint: disable=invalid-slice-index
| |
@property
def ndim(self):
    """Number of array dimensions, computed as ``len(self.shape)``."""
    dims = self.shape
    return len(dims)
| |
@property
def size(self):
    """Number of elements in the array, as reported by the parent class."""
    parent = super(ndarray, self)
    return parent.size
| |
@property
def dtype(self):
    """Data-type of the array's elements.

    The parent class's ``dtype`` is wrapped in ``numpy.dtype``.

    Returns
    -------
    numpy.dtype
        This NDArray's data type.

    Examples
    --------
    >>> x = np.zeros((2,3))
    >>> x.dtype
    dtype('float32')
    >>> y = np.zeros((2,3), dtype='int32')
    >>> y.dtype
    dtype('int32')
    """
    parent_dtype = super(ndarray, self).dtype
    return _np.dtype(parent_dtype)
| |
def tostype(self, stype):
    """Unsupported method ``tostype``; the call always fails.

    ``stype`` is accepted and ignored.

    Raises
    ------
    AttributeError
        Unconditionally.
    """
    message = 'mxnet.numpy.ndarray object has no attribute tostype'
    raise AttributeError(message)
| |
| |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def empty(shape, dtype=None, order='C', device=None):  # pylint: disable=redefined-outer-name
    """Return a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the empty array, e.g., ``(2, 3)`` or ``2``.
    dtype : data-type, optional
        Desired output data-type for the array, e.g, `numpy.int8`.
        Note that this behavior is different from NumPy's `empty` function where `float64`
        is the default value, here you can set your default dtype as 'float32' or 'float64'
        because `float32` is considered as the default data type in deep learning.
        When npx.is_np_default_dtype() returns False, default dtype is float32;
        When npx.is_np_default_dtype() returns True, default dtype is float64.
        Passing the Python builtin ``float`` is treated the same as ``None``.
    order : {'C'}, optional, default: 'C'
        How to store multi-dimensional data in memory, currently only row-major
        (C-style) is supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of uninitialized (arbitrary) data of the given shape, dtype, and order.

    Raises
    ------
    NotImplementedError
        If ``order`` is not ``'C'``.

    Examples
    --------
    >>> np.empty([2, 2])
    array([[ 0.000000e+00, -2.524355e-29],
           [          nan, -8.592023e+09]])  # uninitialized

    >>> np.empty([2, 2], dtype=int)
    array([[8751743591039004782, 3196766424264760104],
           [7583328881310196768,     562950123910254]], dtype=int64)  # uninitialized
    """
    if order != 'C':
        message = '`empty` only supports order equal to `C`, while received {}' \
            .format(str(order))
        raise NotImplementedError(message)

    target_device = current_device() if device is None else device

    if dtype is None or dtype is float:
        # No explicit dtype: pick the default floating type for the active mode.
        element_type = _np.float64 if is_np_default_dtype() else _np.float32
    else:
        element_type = dtype

    # A bare int is shorthand for a one-dimensional shape.
    dims = (shape,) if isinstance(shape, int) else shape

    handle = _new_alloc_handle(dims, target_device, False, element_type)
    return ndarray(handle=handle)
| |
| |
| # pylint: disable=redefined-outer-name |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def array(object, dtype=None, device=None):
    """
    Create an array.

    Parameters
    ----------
    object : array_like or `numpy.ndarray` or `mxnet.numpy.ndarray`
        An array, any object exposing the array interface, an object whose
        __array__ method returns an array, or any (nested) sequence.
    dtype : data-type, optional
        The desired data-type for the array. An explicitly passed `dtype` is
        always honoured. When it is omitted:

        * if `object` is a `mxnet.numpy.ndarray`, ``object.dtype`` is used;
        * if `object` is a `numpy.ndarray`, ``object.dtype`` is used when
          npx.is_np_default_dtype() returns True, `float32` otherwise;
        * for any other input, ``object.dtype`` is used when the object has a
          ``dtype`` attribute, otherwise the default dtype is used.

        Default dtype can be set to be consistent with official numpy by
        `npx.set_np(dtype=True)`.

        * When npx.is_np_default_dtype() returns False, default dtype is float32;
        * When npx.is_np_default_dtype() returns True, default dtype is float64.

    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        An array object satisfying the specified requirements.

    Raises
    ------
    ValueError
        If `object` is a legacy ``mx.nd.NDArray`` (use ``as_np_ndarray`` instead).
    TypeError
        If official NumPy's ``array`` function fails to convert `object`; the
        message is the one produced by NumPy.

    Examples
    --------
    >>> np.array([1, 2, 3])
    array([1., 2., 3.])

    >>> np.array([[1, 2], [3, 4]])
    array([[1., 2.],
           [3., 4.]])

    >>> np.array([[1, 0], [0, 1]], dtype=bool)
    array([[ True, False],
           [False,  True]])

    >>> np.array([1, 2, 3]).dtype
    dtype('float32')

    >>> npx.set_np(dtype=True)
    >>> np.array([1, 2, 3]).dtype
    dtype('float64')
    """
    if device is None:
        device = current_device()
    if isinstance(object, _np.ndarray) and dtype is None:
        # The previous code also tested `object.dtype is _np.float64`. That is an
        # identity comparison between a numpy.dtype instance and a scalar type,
        # which is always False, so the only effective rule was "fill in a
        # missing dtype". The dead clause is dropped; an explicit dtype is kept.
        dtype = object.dtype if is_np_default_dtype() else _np.float32
    # Deliberately a fresh `if` (not `elif`): official NumPy arrays must still
    # reach the final `else` branch below to be converted with the chosen dtype.
    if isinstance(object, ndarray):
        dtype = object.dtype if dtype is None else dtype
    elif isinstance(object, NDArray):
        raise ValueError("If you're trying to create a mxnet.numpy.ndarray "
                         "from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.")
    else:
        if dtype is None:
            default_dtype = _np.float64 if is_np_default_dtype() else _np.float32
            dtype = object.dtype if hasattr(object, "dtype") else default_dtype
        try:
            object = _np.array(object, dtype=dtype)
        except Exception as e:
            # printing out the error raised by official NumPy's array function
            # for transparency on users' side
            raise TypeError('{}'.format(str(e)))
    ret = empty(object.shape, dtype=dtype, device=device)
    # Zero-dimensional data is assigned through the empty-tuple index; every
    # other shape is copied with a full slice.
    if len(object.shape) == 0:
        ret[()] = object
    else:
        ret[:] = object
    return ret
| # pylint: enable=redefined-outer-name |
| |
| |
@set_module('mxnet.numpy')
def shape(a):
    """
    Report the shape of an array-like input.

    Parameters
    ----------
    a : array_like
        The input whose shape is requested.

    Returns
    -------
    shape : tuple of ints
        One entry per dimension; each entry is the length of the matching
        array dimension.

    See Also
    --------
    ndarray.shape : Equivalent array method.

    Examples
    --------
    >>> np.shape(np.eye(3))
    (3, 3)
    >>> np.shape([[1, 2]])
    (1, 2)
    >>> np.shape([0])
    (1,)
    >>> np.shape(0)
    ()
    """
    dims = _mx_nd_np.shape(a)
    return dims
| |
| |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def zeros(shape, dtype=None, order='C', device=None):  # pylint: disable=redefined-outer-name
    """Create a zero-filled array with the requested shape and type.

    Multi-dimensional data can currently only be stored in row-major
    (C-style) order.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the new array.
    dtype : str or numpy.dtype, optional
        Element type of the result. When omitted, the default depends on
        ``npx.is_np_default_dtype()``: float32 if it returns False, float64
        if it returns True. This differs from NumPy's `zeros`, where
        `float64` is always the default; `float32` is considered the default
        data type in deep learning.
    order : {'C'}, optional, default: 'C'
        Memory layout for multi-dimensional data; only row-major (C-style)
        is supported at the moment.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of zeros with the given shape, dtype, and device.

    Examples
    --------
    >>> np.zeros(5)
    array([0., 0., 0., 0., 0.])

    >>> np.zeros((5,), dtype=int)
    array([0, 0, 0, 0, 0], dtype=int64)

    >>> np.zeros((2, 1))
    array([[0.],
           [0.]])
    """
    zero_filled = _mx_nd_np.zeros(shape, dtype, order, device)
    return zero_filled
| |
| |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def ones(shape, dtype=None, order='C', device=None):  # pylint: disable=redefined-outer-name
    """Create an array of the requested shape and type with every entry set to one.

    Multi-dimensional data can currently only be stored in row-major
    (C-style) order.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the new array.
    dtype : str or numpy.dtype, optional
        Element type of the result. When omitted, it follows the current
        default dtype: float32 if ``npx.is_np_default_dtype()`` returns
        False, float64 if it returns True. This differs from NumPy's `ones`,
        where `float64` is always the default.
    order : {'C'}, optional, default: 'C'
        Memory layout for multi-dimensional data; only row-major (C-style)
        is supported at the moment.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of ones with the given shape, dtype, and device.

    Examples
    --------
    >>> np.ones(5)
    array([1., 1., 1., 1., 1.])

    >>> np.ones((5,), dtype=int)
    array([1, 1, 1, 1, 1], dtype=int64)

    >>> np.ones((2, 1))
    array([[1.],
           [1.]])

    >>> s = (2,2)
    >>> np.ones(s)
    array([[1., 1.],
           [1., 1.]])
    """
    one_filled = _mx_nd_np.ones(shape, dtype, order, device)
    return one_filled
| |
| |
@set_module('mxnet.numpy')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
    """
    Expand an array to a new shape by broadcasting.

    Parameters
    ----------
    array : ndarray or scalar
        Input to be broadcast.
    shape : tuple
        Target shape.

    Returns
    -------
    broadcast : array
        A readonly view on the original array with the given shape. It is
        typically not contiguous, and several elements of the broadcast
        result may refer to the same memory location.

    Raises
    ------
    MXNetError
        If NumPy's broadcasting rules do not allow `array` to take the
        requested shape.
    """
    broadcasted = _mx_nd_np.broadcast_to(array, shape)
    return broadcasted
| |
| |
| # pylint: disable=too-many-arguments, redefined-outer-name |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def full(shape, fill_value, dtype=None, order='C', device=None, out=None):
    r"""Create an array of the requested shape and type with every entry set to `fill_value`.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new array, e.g., ``(2, 3)`` or ``2``.
    fill_value : scalar or ndarray
        Value used to fill the array.
    dtype : data-type, optional
        When None, the output data type must be inferred from `fill_value`:
        an int gives the default integer dtype, a float gives the default
        floating-point dtype, and a bool gives a boolean dtype. Default: None.
    order : {'C'}, optional
        Whether multidimensional data is stored in C- or Fortran-contiguous
        (row- or column-wise) order in memory. Only C order is currently
        supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.
    out : ndarray or None, optional
        A location into which the result is stored.
        If provided, it must have the same shape and dtype as input ndarray.
        If not provided or `None`, a freshly-allocated array is returned.

    Returns
    -------
    out : ndarray
        Array of `fill_value` with the given shape, dtype, and order.
        When `fill_value` is an ndarray, the result lives on the same device
        as `fill_value`, whatever `device` was passed.

    .. note::
       Differences from the original numpy.full:

       * An extra `device` argument selects the device
       * An extra `out` argument is accepted
       * `order` selection is currently not supported

    See Also
    --------
    empty : Return a new uninitialized array.
    ones : Return a new array setting values to one.
    zeros : Return a new array setting values to zero.

    Examples
    --------
    >>> np.full((2, 2), 10)
    array([[10., 10.],
           [10., 10.]])
    >>> np.full((2, 2), 2, dtype=np.int32, device=mx.cpu(0))
    array([[2, 2],
           [2, 2]], dtype=int32)
    """
    filled = _mx_nd_np.full(shape, fill_value, order=order, device=device, dtype=dtype, out=out)
    return filled
| # pylint: enable=too-many-arguments, redefined-outer-name |
| |
| |
| # pylint: disable=redefined-outer-name, too-many-arguments |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def empty_like(prototype, dtype=None, device=None, order='C', subok=False, shape=None): # pylint: disable=W0621
    """
    Return a new array with the same shape and type as a given array.

    Parameters
    ----------
    prototype : ndarray
        The shape and data-type of `prototype` define these same attributes
        of the returned array.
    dtype : data-type, optional
        Overrides the data type of the result.
    device : Device, optional
        Device on which the returned array should live. When ``None``, the
        array produced by the underlying ``empty_like`` implementation is
        returned without any device transfer.
    order : {'C'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory. Currently only supports C order.
    subok : {False}, optional
        If True, then the newly created array will use the sub-class
        type of 'a', otherwise it will be a base-class array. Defaults
        to False.
        (Only support False at this moment)
    shape : int or sequence of ints, optional.
        Overrides the shape of the result. If order='K' and the number of
        dimensions is unchanged, will try to keep order, otherwise,
        order='C' is implied.
        (Not supported at this moment)

    Returns
    -------
    out : ndarray
        Array of uninitialized (arbitrary) data with the same
        shape and type as `prototype`.

    See Also
    --------
    ones_like : Return an array of ones with shape and type of input.
    zeros_like : Return an array of zeros with shape and type of input.
    full_like : Return a new array with shape of input filled with value.
    empty : Return a new uninitialized array.

    Notes
    -----
    This function does *not* initialize the returned array; to do that use
    `zeros_like` or `ones_like` instead. It may be marginally faster than
    the functions that do set the array values.

    Examples
    --------
    >>> a = np.array([[1,2,3], [4,5,6]])
    >>> np.empty_like(a)
    array([[-5764607523034234880, -2305834244544065442,           4563075075], # uninitialized
           [          4567052944, -5764607523034234880,      844424930131968]])
    >>> a = np.array([[1., 2., 3.],[4.,5.,6.]])
    >>> np.empty_like(a)
    array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized
           [2.0e-323, 2.5e-323, 3.0e-323]])
    """
    ret = _mx_nd_np.empty_like(prototype, dtype=dtype, order=order, subok=subok, shape=shape)
    if device is None:
        return ret
    # `to_device` returns the array that lives on `device` rather than moving
    # `ret` in place; the previous code dropped that return value, so the
    # `device` argument had no effect on the result.
    return ret.to_device(device)
| # pylint: enable=redefined-outer-name |
| |
| |
| # pylint: disable=redefined-outer-name |
@set_module('mxnet.numpy')
def all(a, axis=None, out=None, keepdims=False):
    """
    Check whether every array element along the given axis evaluates to True.

    Parameters
    ----------
    a : ndarray
        Input array or object that can be converted to an array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the logical AND reduction runs. With the
        default (axis = None) the reduction covers all dimensions of the
        input array.
    keepdims : bool, optional
        When True, the reduced axes stay in the result as dimensions of
        size one, so the result broadcasts correctly against the input
        array.
    out : ndarray, optional
        Alternate output array in which to place the result. It must have
        the same shape as the expected output and its type is preserved

    Returns
    --------
    all : ndarray, bool
        A new boolean or array, unless `out` is given, in which case a
        reference to `out` is returned.

    Examples:
    ---------
    >>> np.all([[True,False],[True,True]])
    False

    >>> np.all([[True,False],[True,True]], axis=0)
    array([ True, False])

    >>> np.all([-1, 4, 5])
    True

    >>> np.all([1.0, np.nan])
    True

    >>> o=np.array(False)
    >>> z=np.all([-1, 4, 5], out=o)
    >>> id(z), id(o), z
    (28293632, 28293632, array(True)) # may vary
    """
    reduced = _mx_nd_np.all(a, axis=axis, out=out, keepdims=keepdims)
    return reduced
| |
| |
@set_module('mxnet.numpy')
def any(a, axis=None, out=None, keepdims=False):
    """
    Check whether at least one array element along the given axis evaluates to True.
    Returns single boolean unless axis is not None

    Parameters
    ----------
    a : ndarray
        Input array or object that can be converted to an array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the logical OR reduction runs. With the
        default (axis = None) the reduction covers all dimensions of the
        input array.
    keepdims : bool, optional
        When True, the reduced axes stay in the result as dimensions of
        size one, so the result broadcasts correctly against the input
        array.
    out : ndarray, optional
        Alternate output array in which to place the result. It must have
        the same shape as the expected output and its type is preserved

    Returns
    --------
    any : bool or ndarray
        A new boolean or ndarray, unless `out` is given, in which case a
        reference to `out` is returned.

    Examples:
    ---------
    >>> np.any([[True, False], [True, True]])
    True

    >>> np.any([[True, False], [False, False]], axis=0)
    array([ True, False])

    >>> np.any([-1, 0, 5])
    True

    >>> np.any(np.nan)
    True

    >>> o=np.array(False)
    >>> z=np.any([-1, 4, 5], out=o)
    >>> z, o
    (array(True), array(True))
    >>> # Check now that z is a reference to o
    >>> z is o
    True
    >>> id(z), id(o) # identity of z and o              # doctest: +SKIP
    (191614240, 191614240)
    """
    reduced = _mx_nd_np.any(a, axis=axis, out=out, keepdims=keepdims)
    return reduced
| |
| |
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def identity(n, dtype=None, device=None):
    """
    Build the identity array.

    An identity array is square, with ones along the main diagonal.

    Parameters
    ----------
    n : int
        Number of rows (and columns) in `n` x `n` output.
    dtype : data-type, optional
        Data-type of the output.
        When npx.is_np_default_dtype() returns False, default dtype is float32;
        When npx.is_np_default_dtype() returns True, default dtype is float64.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        `n` x `n` array whose main diagonal is one and whose remaining
        elements are 0.

    Examples
    --------
    >>> np.identity(3)
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])
    """
    eye_like = _mx_nd_np.identity(n, dtype, device)
    return eye_like
| # pylint: enable=redefined-outer-name |
| |
| |
| # pylint: disable=redefined-outer-name |
@set_module('mxnet.numpy')
def take(a, indices, axis=None, mode='raise', out=None):
    r"""
    Pick elements out of an array along an axis.

    With a non-None axis this behaves like "fancy" indexing (indexing arrays
    using arrays), but it can be more convenient when elements along one
    particular axis are needed. For example ``np.take(arr, indices, axis=3)``
    is equivalent to ``arr[:,:,:,indices,...]``.

    Without fancy indexing, the same operation can be spelled with `ndindex`,
    where each of ``ii``, ``jj``, and ``kk`` is a tuple of indices::

        Ni, Nk = a.shape[:axis], a.shape[axis+1:]
        Nj = indices.shape
        for ii in ndindex(Ni):
            for jj in ndindex(Nj):
                for kk in ndindex(Nk):
                    out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

    Parameters
    ----------
    a : ndarray
        The source array.
    indices : ndarray
        The indices of the values to extract. Also allow scalars for indices.
    axis : int, optional
        Axis along which values are selected. By default the flattened input
        array is used.
    out : ndarray, optional
        If provided, the result will be placed in this array. It should
        be of the appropriate shape and dtype.
    mode : {'clip', 'wrap'}, optional
        Specifies how out-of-bounds indices will behave.

        * 'clip' -- clip to the range
        * 'wrap' -- wrap around

        'clip' mode means that all indices that are too large are replaced
        by the index that addresses the last element along that axis. Note
        that this disables indexing with negative numbers.

        NOTE(review): the signature default is ``'raise'``, which is passed
        straight to the underlying implementation; its behaviour is not
        described here -- confirm against that implementation.

    Returns
    -------
    out : ndarray
        The returned array has the same type as `a`.

    .. note::

       This function differs from the original `numpy.take
       <https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
       the following way(s):

       * Only ndarray or scalar ndarray is accepted as valid input.

    Examples
    --------
    >>> a = np.array([4, 3, 5, 7, 6, 8])
    >>> indices = np.array([0, 1, 4])
    >>> np.take(a, indices)
    array([4., 3., 6.])

    In this example for `a` is an ndarray, "fancy" indexing can be used.

    >>> a[indices]
    array([4., 3., 6.])

    If `indices` is not one dimensional, the output also has these dimensions.

    >>> np.take(a, np.array([[0, 1], [2, 3]]))
    array([[4., 3.],
           [5., 7.]])
    """
    picked = _mx_nd_np.take(a, indices, axis, mode, out)
    return picked
| # pylint: enable=redefined-outer-name |
| |
| |
@set_module('mxnet.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
    """
    Compute the sorted unique elements of an array.

    Besides the unique elements themselves, three optional outputs are
    available:

    * the indices of the input array that give the unique values
    * the indices of the unique array that reconstruct the input array
    * the number of times each unique value comes up in the input array

    Parameters
    ----------
    ar : ndarray
        Input array. It is flattened if it is not already 1-D, unless `axis`
        is specified.
    return_index : bool, optional
        If True, also return the indices of `ar` (along the specified axis,
        if provided, or in the flattened array) that result in the unique array.
    return_inverse : bool, optional
        If True, also return the indices of the unique array (for the specified
        axis, if provided) that can be used to reconstruct `ar`.
    return_counts : bool, optional
        If True, also return the number of times each unique item appears
        in `ar`.
    axis : int or None, optional
        The axis to operate on. If None, `ar` will be flattened. If an integer,
        the subarrays indexed by the given axis will be flattened and treated
        as the elements of a 1-D array with the dimension of the given axis,
        see the notes for more details. The default is None.

    Returns
    -------
    unique : ndarray
        The sorted unique values.
    unique_indices : ndarray, optional
        Indices of the first occurrence of each unique value in the original
        array. Only provided if `return_index` is True.
    unique_inverse : ndarray, optional
        Indices that rebuild the original array from the unique array. Only
        provided if `return_inverse` is True.
    unique_counts : ndarray, optional
        How many times each unique value occurs in the original array. Only
        provided if `return_counts` is True.

    .. note::

       When an axis is specified the subarrays indexed by the axis are sorted.
       This is done by making the specified axis the first dimension of the array
       and then flattening the subarrays in C order. The flattened subarrays are
       then viewed as a structured type with each element given a label, with the
       effect that we end up with a 1-D array of structured types that can be
       treated in the same way as any other 1-D array. The result is that the
       flattened subarrays are sorted in lexicographic order starting with the
       first element.

       This function differs from the original `numpy.unique
       <https://docs.scipy.org/doc/numpy/reference/generated/numpy.unique.html>`_ in
       the following aspects:

       * Only support ndarray as input.
       * Object arrays or structured arrays are not supported.

    Examples
    --------
    >>> np.unique(np.array([1, 1, 2, 2, 3, 3]))
    array([1., 2., 3.])
    >>> a = np.array([[1, 1], [2, 3]])
    >>> np.unique(a)
    array([1., 2., 3.])

    Return the unique rows of a 2D array

    >>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
    >>> np.unique(a, axis=0)
    array([[1., 0., 0.],
           [2., 3., 4.]])

    Return the indices of the original array that give the unique values:

    >>> a = np.array([1, 2, 6, 4, 2, 3, 2])
    >>> u, indices = np.unique(a, return_index=True)
    >>> u
    array([1., 2., 3., 4., 6.])
    >>> indices
    array([0, 1, 5, 3, 2], dtype=int64)
    >>> a[indices]
    array([1., 2., 3., 4., 6.])

    Reconstruct the input array from the unique values:

    >>> a = np.array([1, 2, 6, 4, 2, 3, 2])
    >>> u, indices = np.unique(a, return_inverse=True)
    >>> u
    array([1., 2., 3., 4., 6.])
    >>> indices
    array([0, 1, 4, 3, 1, 2, 1], dtype=int64)
    >>> u[indices]
    array([1., 2., 6., 4., 2., 3., 2.])
    """
    outputs = _mx_nd_np.unique(ar, return_index, return_inverse, return_counts, axis)
    return outputs
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def add(x1, x2, out=None, **kwargs):
    """
    Element-wise addition of the two arguments.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The operands. If x1.shape != x2.shape, they must be broadcastable to
        a common shape (which may be the shape of one or the other).

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.add(1.0, 4.0)
    5.0
    >>>
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.add(x1, x2)
    array([[ 0.,  2.,  4.],
           [ 3.,  5.,  7.],
           [ 6.,  8., 10.]])
    """
    total = _mx_nd_np.add(x1, x2, out)
    return total
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def subtract(x1, x2, out=None, **kwargs):
    r"""Element-wise subtraction of the two arguments.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The operands; `x2` is subtracted from `x1`. If x1.shape != x2.shape,
        they must be broadcastable to a common shape (which may be the shape
        of one or the other).
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    subtract : ndarray or scalar
        The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.

    .. note::
       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.subtract(1.0, 4.0)
    -3.0
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.subtract(x1, x2)
    array([[0., 0., 0.],
           [3., 3., 3.],
           [6., 6., 6.]])
    """
    difference = _mx_nd_np.subtract(x1, x2, out)
    return difference
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def multiply(x1, x2, out=None, **kwargs):
    """
    Element-wise multiplication of the two arguments.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The operands. If x1.shape != x2.shape, they must be broadcastable to
        a common shape (which may be the shape of one or the other).

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        The product of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.

    .. note::
       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.multiply(2.0, 4.0)
    8.0
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.multiply(x1, x2)
    array([[ 0.,  1.,  4.],
           [ 0.,  4., 10.],
           [ 0.,  7., 16.]])
    """
    product = _mx_nd_np.multiply(x1, x2, out)
    return product
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def divide(x1, x2, out=None, **kwargs):
    """Element-wise true division of the inputs.

    .. note::
       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types including boolean, the output is of float32 or
         float64 type, which depends on your current default dtype:

         * When ``npx.is_np_default_dtype()`` returns False, default dtype is float32.
         * When ``npx.is_np_default_dtype()`` returns True, default dtype is float64.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> x = np.arange(5)
    >>> np.divide(x, 4)
    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
    """
    quotient = _mx_nd_np.divide(x1, x2, out=out)
    return quotient
| |
| |
@set_module('mxnet.numpy')
def true_divide(x1, x2, out=None):
    """Element-wise true division of the inputs.

    Unlike Python's traditional 'floor division', the result is a true
    division: the output type is adjusted to present the best answer,
    regardless of input types.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), the output is of float32 or
         float64 type, which depends on your current default dtype.
         When npx.is_np_default_dtype() returns False, default dtype is float32;
         When npx.is_np_default_dtype() returns True, default dtype is float64.

    Examples
    --------
    >>> x = np.arange(5)
    >>> np.true_divide(x, 4)
    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
    """
    quotient = _mx_nd_np.true_divide(x1, x2, out=out)
    return quotient
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def floor_divide(x1, x2, out=None):
    """Element-wise largest integer smaller than or equal to the division of the inputs.

    Equivalent to the Python // operator; it pairs with the Python % (remainder)
    function so that a = a % b + b * (a // b) up to roundoff.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), the output is the more
         precise type

    Examples
    --------
    >>> np.floor_divide(7,3)
    2
    >>> np.floor_divide([1., 2., 3., 4.], 2.5)
    array([ 0.,  0.,  1.,  1.])
    """
    floored = _mx_nd_np.floor_divide(x1, x2, out=out)
    return floored
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def mod(x1, x2, out=None, **kwargs):
    """
    Compute the remainder of division, element by element.

    Parameters
    ----------
    x1 : ndarray or scalar
        Dividend array.

    x2 : ndarray or scalar
        Divisor array.

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> np.mod(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    result = _mx_nd_np.mod(x1, x2, out=out)
    return result
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
    """
    Compute the remainder of division, element by element.

    Parameters
    ----------
    x1 : ndarray or scalar
        Dividend array.

    x2 : ndarray or scalar
        Divisor array.

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> np.fmod(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    result = _mx_nd_np.fmod(x1, x2, out=out)
    return result
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def matmul(a, b, out=None, **kwargs):
    r"""Matrix product of two arrays.

    Parameters
    ----------
    a, b : ndarray
        Input arrays, scalars not allowed.
    out : ndarray, optional
        A location into which the result is stored.
        If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m).
        If not provided or None, a freshly-allocated array is returned.
    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    y : ndarray
        The matrix product of the inputs.
        This is a scalar only when both a and b are 1-d vectors.

    Raises
    ------
    MXNetError
        If the last dimension of a is not the same size as the second-to-last dimension of b.
        If a scalar value is passed in.

    See Also
    --------
    tensordot : Sum products over arbitrary axes.
    dot : alternative matrix product with different broadcasting rules.
    einsum : Einstein summation convention.

    .. note::

        The behavior depends on the arguments in the following way.

        * If both arguments are ``2-D`` they are multiplied like conventional matrices.
        * If either argument is ``N-D``, ``N > 2``, it is treated as a stack of matrices
          residing in the last two indexes and broadcast accordingly.
        * If the first argument is ``1-D``, it is promoted to a matrix by prepending
          a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
        * If the second argument is ``1-D``, it is promoted to a matrix by appending a 1
          to its dimensions. After matrix multiplication the appended 1 is removed.

        matmul differs from dot in two important ways:

        * Multiplication by scalars is not allowed, use multiply instead.
        * Stacks of matrices are broadcast together as if the matrices were elements,
          respecting the signature ``(n,k),(k,m)->(n,m)``:

        >>> a = np.ones([9, 5, 7, 4])
        >>> c = np.ones([9, 5, 4, 3])
        >>> np.dot(a, c).shape
        (9, 5, 7, 9, 5, 3)
        >>> np.matmul(a, c).shape
        (9, 5, 7, 3)
        >>> # n is 7, k is 4, m is 3

    Examples
    --------
    For 2-D arrays it is the matrix product:

    >>> a = np.array([[1, 0],
    ...               [0, 1]])
    >>> b = np.array([[4, 1],
    ...               [2, 2]])
    >>> np.matmul(a, b)
    array([[4., 1.],
           [2., 2.]])

    For 2-D mixed with 1-D, the result is the usual.

    >>> a = np.array([[1, 0],
    ...               [0, 1]])
    >>> b = np.array([1, 2])
    >>> np.matmul(a, b)
    array([1., 2.])
    >>> np.matmul(b, a)
    array([1., 2.])

    Broadcasting is conventional for stacks of arrays

    >>> a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
    >>> b = np.arange(2 * 2 * 4).reshape((2, 4, 2))
    >>> np.matmul(a, b).shape
    (2, 2, 2)
    >>> np.matmul(a, b)[0, 1, 1]
    array(98.)
    >>> sum(a[0, 1, :] * b[0, :, 1])
    array(98.)

    Scalar multiplication raises an error.

    >>> np.matmul([1, 2], 3)
    Traceback (most recent call last):
    ...
    mxnet.base.MXNetError: ... : Multiplication by scalars is not allowed.

    """
    # Only a, b and `out` reach the ndarray-level implementation.
    return _mx_nd_np.matmul(a, b, out=out)
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def remainder(x1, x2, out=None, **kwargs):
    """
    Compute the remainder of division, element by element.

    Parameters
    ----------
    x1 : ndarray or scalar
        Dividend array.

    x2 : ndarray or scalar
        Divisor array.

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> np.remainder(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    result = _mx_nd_np.remainder(x1, x2, out=out)
    return result
| |
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def power(x1, x2, out=None, **kwargs):
    """
    Raise each element of the first array to the matching power in the second array.

    Parameters
    ----------
    x1 : ndarray or scalar
        The bases.

    x2 : ndarray or scalar
        The exponent.

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    out : ndarray or scalar
        The bases in x1 raised to the exponents in x2.
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> x1 = np.arange(6)
    >>> np.power(x1, 3)
    array([ 0., 1., 8., 27., 64., 125.])

    Raise the bases to different exponents.

    >>> x2 = np.array([1.0, 2.0, 3.0, 3.0, 2.0, 1.0])
    >>> np.power(x1, x2)
    array([ 0., 1., 8., 27., 16., 5.])

    The effect of broadcasting.

    >>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]])
    >>> x2
    array([[1., 2., 3., 3., 2., 1.],
           [1., 2., 3., 3., 2., 1.]])

    >>> np.power(x1, x2)
    array([[ 0., 1., 8., 27., 16., 5.],
           [ 0., 1., 8., 27., 16., 5.]])
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    raised = _mx_nd_np.power(x1, x2, out=out)
    return raised
| |
# `pow` is bound to the very same function object as `power` (no copy is made),
# so the two names share a single docstring: the text assigned below is also what
# `power.__doc__` returns afterwards.
pow = power
# The attribute must be spelled `__doc__` (two trailing underscores); that is the
# attribute help() and documentation tooling read. Any other spelling only creates
# an unrelated attribute on the function.
pow.__doc__ = """
    First array elements raised to powers from second array, element-wise.

    Notes
    -----
    `pow` is an alias for `power`. It is a standard API in
    https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#pow-x1-x2
    instead of an official NumPy operator.

    >>> np.pow is np.power
    True

    Parameters
    ----------
    x1 : ndarray or scalar
        The bases.

    x2 : ndarray or scalar
        The exponent.

    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        The bases in x1 raised to the exponents in x2.
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> x1 = np.arange(6)
    >>> np.pow(x1, 3)
    array([ 0., 1., 8., 27., 64., 125.])

    Raise the bases to different exponents.

    >>> x2 = np.array([1.0, 2.0, 3.0, 3.0, 2.0, 1.0])
    >>> np.pow(x1, x2)
    array([ 0., 1., 8., 27., 16., 5.])

    The effect of broadcasting.

    >>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]])
    >>> x2
    array([[1., 2., 3., 3., 2., 1.],
           [1., 2., 3., 3., 2., 1.]])

    >>> np.pow(x1, x2)
    array([[ 0., 1., 8., 27., 16., 5.],
           [ 0., 1., 8., 27., 16., 5.]])
    """
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def gcd(x1, x2, out=None, **kwargs):
    """
    Returns the greatest common divisor of ``|x1|`` and ``|x2|``

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The arrays for computing greatest common divisor. If x1.shape != x2.shape,
        they must be broadcastable to a common shape (which may be the shape of
        one or the other).

    out : ndarray or None, optional
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    y : ndarray or scalar
        The greatest common divisor of the absolute value of the inputs.
        This is a scalar if both `x1` and `x2` are scalars.

    See Also
    --------
    lcm : The lowest common multiple

    Examples
    --------
    >>> np.gcd(12, 20)
    4
    >>> np.gcd(np.arange(6, dtype=int), 20)
    array([20, 1, 2, 1, 4, 5], dtype=int64)
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    return _mx_nd_np.gcd(x1, x2, out=out)
| |
| |
@set_module('mxnet.numpy')
@wrap_np_binary_func
def lcm(x1, x2, out=None, **kwargs):
    """
    Compute the lowest common multiple of ``|x1|`` and ``|x2|``.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The arrays for computing lowest common multiple. If x1.shape != x2.shape,
        they must be broadcastable to a common shape (which may be the shape of
        one or the other).

    out : ndarray or None, optional
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    **kwargs
        Accepted by this signature, but not passed on to the underlying
        implementation by this function.

    Returns
    -------
    y : ndarray or scalar
        The lowest common multiple of the absolute value of the inputs.
        This is a scalar if both `x1` and `x2` are scalars.

    See Also
    --------
    gcd : The greatest common divisor

    Examples
    --------
    >>> np.lcm(12, 20)
    60
    >>> np.lcm(np.arange(6, dtype=int), 20)
    array([ 0, 20, 20, 60, 20, 20], dtype=int64)
    """
    # Only x1, x2 and `out` reach the ndarray-level implementation.
    multiple = _mx_nd_np.lcm(x1, x2, out=out)
    return multiple
| |
| |
@set_module('mxnet.numpy')
@wrap_np_unary_func
def sin(x, out=None, **kwargs):
    r"""
    Trigonometric sine, element-wise.

    Parameters
    ----------
    x : ndarray or scalar
        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
    out : ndarray or None
        A location into which the result is stored. If provided, it
        must have a shape that the inputs broadcast to. If not provided
        or None, a freshly-allocated array is returned. The dtype of the
        output is the same as that of the input if the input is an ndarray.
    **kwargs
        Extra keyword arguments; passed through unchanged to the underlying
        implementation.

    Returns
    -------
    y : ndarray or scalar
        The sine of each element of x. This is a scalar if `x` is a scalar.

    Notes
    -----
    This function only supports input type of float.

    Examples
    --------
    >>> np.sin(np.pi/2.)
    1.0
    >>> np.sin(np.array((0., 30., 45., 60., 90.)) * np.pi / 180.)
    array([0. , 0.5 , 0.70710677, 0.86602545, 1. ])
    """
    return _mx_nd_np.sin(x, out=out, **kwargs)
| |
| |
@set_module('mxnet.numpy')
@wrap_np_unary_func
def cos(x, out=None, **kwargs):
    r"""
    Cosine, element-wise.

    Parameters
    ----------
    x : ndarray or scalar
        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
    out : ndarray or None
        A location into which the result is stored. If provided, it
        must have a shape that the inputs broadcast to. If not provided
        or None, a freshly-allocated array is returned. The dtype of the
        output is the same as that of the input if the input is an ndarray.
    **kwargs
        Extra keyword arguments; passed through unchanged to the underlying
        implementation.

    Returns
    -------
    y : ndarray or scalar
        The corresponding cosine values. This is a scalar if x is a scalar.

    Notes
    -----
    This function only supports input type of float.

    Examples
    --------
    >>> np.cos(np.array([0, np.pi/2, np.pi]))
    array([ 1.000000e+00, -4.371139e-08, -1.000000e+00])
    >>> # Example of providing the optional output parameter
    >>> out1 = np.array([0], dtype='f')
    >>> out2 = np.cos(np.array([0.1]), out1)
    >>> out2 is out1
    True
    """
    return _mx_nd_np.cos(x, out=out, **kwargs)
| |
| |
@set_module('mxnet.numpy')
@wrap_np_unary_func
def sinh(x, out=None, **kwargs):
    """
    Hyperbolic sine, element-wise.
    Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``.

    Parameters
    ----------
    x : ndarray or scalar
        Input array or scalar.
    out : ndarray or None
        A location into which the result is stored. If provided, it
        must have a shape that the inputs broadcast to. If not provided
        or None, a freshly-allocated array is returned. The dtype of the
        output is the same as that of the input if the input is an ndarray.
    **kwargs
        Extra keyword arguments; passed through unchanged to the underlying
        implementation.

    Returns
    -------
    y : ndarray or scalar
        The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar.

    Notes
    -----
    This function only supports input type of float.

    Examples
    --------
    >>> np.sinh(0)
    0.0
    >>> # Example of providing the optional output parameter
    >>> out1 = np.array([0], dtype='f')
    >>> out2 = np.sinh(np.array([0.1]), out1)
    >>> out2 is out1
    True
    """
    return _mx_nd_np.sinh(x, out=out, **kwargs)
| |
| |
@set_module('mxnet.numpy')
@wrap_np_unary_func
def cosh(x, out=None, **kwargs):
    """
    Hyperbolic cosine, element-wise.
    Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``.

    Parameters
    ----------
    x : ndarray or scalar
        Input array or scalar.
    out : ndarray or None
        A location into which the result is stored. If provided, it
        must have a shape that the inputs broadcast to. If not provided
        or None, a freshly-allocated array is returned. The dtype of the
        output is the same as that of the input if the input is an ndarray.
    **kwargs
        Extra keyword arguments; passed through unchanged to the underlying
        implementation.

    Returns
    -------
    y : ndarray or scalar
        The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar.

    Notes
    -----
    This function only supports input type of float.

    Examples
    --------
    >>> np.cosh(0)
    1.0
    """
    return _mx_nd_np.cosh(x, out=out, **kwargs)
| |
| |
| @set_module('mxnet.numpy') |
| @wrap_np_unary_func |
| def tanh(x, out=None, **kwargs): |
| """ |
| Compute hyperbolic tangent element-wise. |
| Equivalent to ``np.sinh(x)/np.cosh(x)``. |
| |
| Parameters |
| ---------- |
| x : ndarray or scalar. |
| Input array. |
| out : ndarray or None |
| A location into which the result is stored. If provided, it |
| must have a shape that the inputs fill into. If not provided |
| or None, a freshly-allocated array is returned. The dtype of the |
| output and input must be the same. |
| |
| Returns |
| ---------- |
| y : ndarray or scalar |
| The corresponding hyperbolic tangent values. |
| |
| .. note:: |
| If `out` is provided, the function writes the result into it, |
| and returns a reference to `out`. (See Examples) |
| |
| * input x does not support complex computation (like imaginary number) |
| |
| >>> np.tanh(np.pi*1j) |
| TypeError: type <type 'complex'> not supported |
| |
| Examples |
| -------- |
|