blob: ed96795fface9f2de705877d1cfc448fa85a6caa [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines, unused-argument
"""numpy ndarray and util functions."""
try:
from __builtin__ import all as py_all
from __builtin__ import slice as py_slice
except ImportError:
from builtins import all as py_all
from builtins import slice as py_slice
from array import array as native_array
import functools
import ctypes
import sys
import datetime
import warnings
import numpy as _np
from .. import _deferred_compute as dc
from ..autograd import is_recording
from ..ndarray import NDArray, dtype_np_to_mx, _GRAD_REQ_MAP
from ..ndarray import indexing_key_expand_implicit_axes, get_indexing_dispatch_code,\
get_oshape_of_gather_nd_op
from ..ndarray._internal import _set_np_ndarray_class
from . import _op as _mx_np_op
from ..base import check_call, _LIB, NDArrayHandle, c_array, mx_int, mx_int64
from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types
from ..runtime import Features
from ..device import Device
from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\
is_np_default_dtype, wrap_ctx_to_device_func,\
dtype_from_number, wrap_data_api_statical_func,\
wrap_sort_functions
from ..device import current_device
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy, ndarray_to_dlpack_for_write, DLDeviceType,\
ndarray_from_dlpack
from .utils import _get_np_op
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback
# Public API of this module; the names exported by the NumPy fallback
# operators (``fallback.__all__``) are appended below.
__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median',
           'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to',
           'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'pow', 'power', 'bitwise_not',
           'delete', 'trace', 'transpose', 'copy', 'moveaxis', 'reshape', 'dot',
           'arctan2', 'atan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'bitwise_invert', 'invert',
           'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'asin', 'arccos', 'acos', 'arctan',
           'atan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square',
           'negative', 'histogram', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'asinh',
           'arccosh', 'acosh', 'arctanh', 'atanh', 'append', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
           'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit',
           'dsplit', 'flatnonzero', 'tril_indices', 'concatenate', 'concat', 'stack', 'vstack', 'row_stack',
           'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin',
           'amax', 'amin', 'max', 'min', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
           'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
           'logical_and', 'logical_or', 'logical_xor',
           'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'hypot',
           'triu_indices_from', 'triu_indices', 'tri',
           'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
           'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
           'cross', 'kron', 'equal', 'not_equal', 'interp',
           'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
           'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
           'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
           'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
           'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
           'positive', 'logaddexp', 'floor_divide', 'permute_dims', 'bitwise_left_shift', 'bitwise_right_shift',
           'asarray', 'from_dlpack']
__all__ += fallback.__all__
# Return codes for dispatching an indexing call: ``ndarray.__getitem__``
# compares the result of ``get_indexing_dispatch_code(key)`` against these.
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1
_NDARRAY_EMPTY_TUPLE_INDEXING = 2
# Values of the ``prepend`` flag returned by ``indexing_key_expand_implicit_axes``
# describing a 0-d boolean scalar in the key: none present, a ``False`` scalar
# (the result gets an extra leading axis of length 0), or a ``True`` scalar
# (a new leading axis of length 1 is prepended).
_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
# Largest element count ``_new_alloc_handle`` accepts when MXNet is built
# without INT64 tensor-size support.
_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)
# Caching whether MXNet was built with INT64 support or not; stays ``None``
# until ``_int64_enabled()`` first queries the runtime features.
_INT64_TENSOR_SIZE_ENABLED = None
def _int64_enabled():
    """Report whether this MXNet build supports INT64 tensor sizes.

    The runtime feature list is queried only on the first call; the answer is
    memoized in the module-level ``_INT64_TENSOR_SIZE_ENABLED``.
    """
    global _INT64_TENSOR_SIZE_ENABLED
    if _INT64_TENSOR_SIZE_ENABLED is not None:
        return _INT64_TENSOR_SIZE_ENABLED
    _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
    return _INT64_TENSOR_SIZE_ENABLED
# This function is copied from ndarray.py since pylint
# keeps giving false alarm error of undefined-all-variable
def _new_alloc_handle(shape, device, delay_alloc, dtype=mx_real_t):  # pylint: disable=redefined-outer-name
    """Return a new handle with specified shape and device.

    Empty handle is only used to hold results.

    Parameters
    ----------
    shape : sequence of int
        Shape of the array to allocate.
    device : Device
        Target device; its ``device_typeid`` and ``device_id`` are passed to
        the backend.
    delay_alloc : bool
        Forwarded to the backend as the delay-allocation flag.
    dtype : numpy dtype, optional
        Element type, converted with ``dtype_np_to_mx``. Defaults to ``mx_real_t``.

    Returns
    -------
    handle
        A new empty `ndarray` handle.

    Raises
    ------
    ValueError
        If MXNet was built without INT64 tensor-size support and ``shape``
        describes more than ``2**31 - 1`` elements.
    """
    hdl = NDArrayHandle()
    if _int64_enabled():
        check_call(_LIB.MXNDArrayCreate64(
            c_array_buf(mx_int64, native_array('q', shape)),
            ctypes.c_int(len(shape)),
            ctypes.c_int(device.device_typeid),
            ctypes.c_int(device.device_id),
            ctypes.c_int(int(delay_alloc)),
            ctypes.c_int(int(dtype_np_to_mx(dtype))),
            ctypes.byref(hdl)))
        return hdl
    # Without INT64 support the total element count must fit in a signed
    # 32-bit integer. Reject oversized shapes here, before the call reaches the
    # backend, so the user gets an actionable message.
    array_size = 1
    for dim in shape:
        array_size *= dim
    if array_size > _SIGNED_INT32_UPPER_LIMIT:
        raise ValueError("[_new_alloc_handle] Size of tensor you are trying to allocate is " +
                         "larger than 2^31 - 1 elements. Please build with flag " +
                         "USE_INT64_TENSOR_SIZE=1")
    check_call(_LIB.MXNDArrayCreate(
        c_array_buf(mx_uint, native_array('I', shape)),
        mx_uint(len(shape)),
        ctypes.c_int(device.device_typeid),
        ctypes.c_int(device.device_id),
        ctypes.c_int(int(delay_alloc)),
        ctypes.c_int(int(dtype_np_to_mx(dtype))),
        ctypes.byref(hdl)))
    return hdl
def _reshape_view(a, *shape):  # pylint: disable=redefined-outer-name
    """Return a **view** of ``a`` with a new shape; no data is copied.

    Parameters
    ----------
    a : ndarray
        Source array. The returned view shares its data and inherits its
        ``writable`` flag.
    shape : tuple of int, or n ints
        The new shape should not change the array size, namely
        ``np.prod(new_shape)`` should be equal to ``np.prod(a.shape)``.
        Some dimensions of the shape can take special value -1, which
        infers the dimension of the output shape by using the remainder of the
        input dimensions keeping the size of the new array same as that of the input array.
        At most one dimension of shape can be -1.

    Returns
    -------
    ndarray
        An array with desired shape that shares data with ``a``.
    """
    # Accept both ``_reshape_view(a, (2, 3))`` and ``_reshape_view(a, 2, 3)``.
    packed = len(shape) == 1 and isinstance(shape[0], (list, tuple))
    new_shape = shape[0] if packed else shape
    out_handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayReshape64(
        a.handle,
        len(new_shape),
        c_array(ctypes.c_int64, new_shape),
        False,
        ctypes.byref(out_handle)))
    return ndarray(handle=out_handle, writable=a.writable)
def _as_mx_np_array(object, device=None, zero_copy=False):
"""Convert arrays or any array member of container to mxnet.numpy.ndarray on device."""
if object is None or isinstance(object, ndarray):
return object
elif isinstance(object, _np.ndarray):
from_numpy = ndarray_from_numpy(ndarray, array)
return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS'])
elif isinstance(object, (integer_types, numeric_types)):
return object
elif isinstance(object, (_np.bool_, _np.bool)):
return array(object, dtype=_np.bool_, device=device)
elif isinstance(object, (list, tuple)):
tmp = [_as_mx_np_array(arr, device=device, zero_copy=zero_copy) for arr in object]
return object.__class__(tmp)
else:
raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))
def _as_onp_array(object, cur_device=None):
    """Convert object to numpy.ndarray.

    ``ndarray`` instances are copied to host with ``asnumpy()``. Lists, tuples
    and dicts are converted element-wise (recursively) and rebuilt with their
    original class. Anything else is returned unchanged.

    Parameters
    ----------
    object : any
        The value to convert.
    cur_device : Device or None
        Device already seen by the caller. Used as the starting point when
        checking that all ndarrays inside a container share one device.

    Returns
    -------
    tuple
        ``(converted, device)``. For a bare ``ndarray`` this is its own device.
        For containers it is the device common to the ndarrays found, or
        ``cur_device`` if none were found.

    Raises
    ------
    ValueError
        If ndarrays inside a container are allocated on different devices.
    """
    def _update_device(cur_device, tmp_device):
        if cur_device is None:
            cur_device = tmp_device
        elif tmp_device is not None and cur_device != tmp_device:
            # Each device must be stringified separately: ``str(a, b)`` would
            # interpret ``b`` as an encoding and raise TypeError instead.
            raise ValueError('Ambiguous to set the device for the output ndarray since'
                             ' input ndarrays are allocated on different devices: {} and {}'
                             .format(str(cur_device), str(tmp_device)))
        return cur_device

    if isinstance(object, ndarray):
        return object.asnumpy(), object.device
    elif isinstance(object, (list, tuple)):
        tmp = []
        for arr in object:
            arr, tmp_device = _as_onp_array(arr, cur_device)
            tmp.append(arr)
            cur_device = _update_device(cur_device, tmp_device)
        return object.__class__(tmp), cur_device
    elif isinstance(object, dict):
        tmp = {}
        for key, value in object.items():
            value, tmp_device = _as_onp_array(value, cur_device)
            tmp[key] = value
            cur_device = _update_device(cur_device, tmp_device)
        return object.__class__(tmp), cur_device
    else:
        return object, cur_device
# Have to use 0 as default value for stype since pylint does not allow
# importing _STORAGE_TYPE_DEFAULT from ndarray.py.
def _np_ndarray_cls(handle, writable=True, stype=0):
    """Wrap a backend ``handle`` in an ``mxnet.numpy.ndarray``.

    ``stype == -1`` means "look the storage type up from the handle". Only the
    default storage type (0) is supported.

    Raises
    ------
    ValueError
        If the (resolved) storage type is not the default one.
    """
    resolved_stype = _storage_type(handle) if stype == -1 else stype
    if resolved_stype == 0:
        return ndarray(handle, writable=writable)
    raise ValueError('_np_ndarray_cls currently only supports default storage '
                     'type, while received stype = {}'.format(resolved_stype))


# Register the factory above with the ndarray internals.
_set_np_ndarray_class(_np_ndarray_cls)
# Registries mapping official NumPy operators to mxnet.numpy implementations
# (presumably populated elsewhere — not visible here). ``__array_function__``
# looks entries up by the NumPy function object; ``__array_ufunc__`` looks them
# up by the ufunc's ``__name__``.
_NUMPY_ARRAY_FUNCTION_DICT = {}
_NUMPY_ARRAY_UFUNC_DICT = {}
# Operators that already logged the "fallback operator" warning, so the
# warning is emitted only once per operator.
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD = {}
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD = {}
def wrap_mxnp_np_ufunc(func):
    """
    Decorate a python overload-able binary operator so that it also accepts an
    official NumPy array as its second operand.

    When ``x2`` is a ``numpy.ndarray``, it is converted with ``_as_mx_np_array``
    onto ``x1.device`` before ``func`` is called. Any other ``x2`` is passed
    through untouched.

    Parameters
    ----------
    func : a python overload-able binary function to be wrapped for type casting.

    Returns
    -------
    Function
        A function wrapped with type casted.
    """
    def _wrap_mxnp_np_ufunc(x1, x2):
        if not isinstance(x2, _np.ndarray):
            return func(x1, x2)
        return func(x1, _as_mx_np_array(x2, device=x1.device))
    return functools.wraps(func)(_wrap_mxnp_np_ufunc)
@set_module('mxnet.numpy')
class ndarray(NDArray): # pylint: disable=invalid-name
"""
ndarray(handle, writable=True):
An array object represents a multidimensional, homogeneous array of fixed-size items.
An associated data-type object describes the format of each element in the array
(its byte-order, how many bytes it occupies in memory, whether it is an integer, a
floating point number, or something else, etc.). Currently, only c-contiguous
arrays are supported.
Arrays should be constructed using `array`, `zeros` or `empty` (refer
to the See Also section below). The parameters given here refer to
a low-level method (`ndarray(...)`) for instantiating an array.
For more information, refer to the `mxnet.numpy` module and examine the
methods and attributes of an array.
Parameters
----------
handle: int
The ndarray handle in backend (C++).
writable: bool
Indicates whether inplace-assignment is allowed for the array.
Attributes
----------
T : ndarray
Transpose of the array.
dtype : dtype object
Describes the format of the elements in the array.
size : int
Number of elements in the array.
ndim : int
The array's number of dimensions.
shape : tuple of ints
Shape of the array.
See Also
--------
array : Construct an array.
zeros : Create an array, each element of which is zero.
empty : Create an array, but leave its allocated memory unchanged (i.e.,
it contains "garbage").
"""
@staticmethod
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):  # pylint: disable=bad-staticmethod-argument
    """
    Dispatch official NumPy unary/binary operator calls on mxnet.numpy.ndarray
    to this function. The operators must comply with the ufunc definition in NumPy.
    The following code is adapted from CuPy.

    Casting rules for operator with mx_np and onp (inplace op will keep its type)

    | Expression | a type | b type | out type|
    | --- | --- | --- | --- |
    | `a += b` | onp | mx_np | onp |
    | `a += b` | mx_np | onp | mx_np |
    | `c = a + b` | onp | mx_np | mx_np |
    | `c = a + b` | mx_np | onp | mx_np |

    Only ``method == '__call__'`` is handled. Any other ufunc method returns
    ``NotImplemented``. A ufunc without an mxnet implementation falls back to
    the official NumPy operator on host copies of the inputs.

    Raises
    ------
    ValueError
        If ``out`` does not hold exactly one array, or if a NumPy fallback would
        be needed while autograd is recording.
    """
    ufunc_list = {"add", "subtract", "multiply", "divide", "true_divide", "floor_divide", "power",
                  "remainder", "bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift",
                  "greater", "greater_equal", "less", "less_equal", "not_equal", "equal", "matmul"}
    if 'out' in kwargs:
        # NumPy passes ``out`` as a tuple; unfold its single entry.
        out = kwargs['out']
        if len(out) != 1:
            raise ValueError('The `out` parameter must have exactly one ndarray')
        kwargs['out'] = out[0]
    if method != '__call__':
        return NotImplemented
    name = ufunc.__name__
    mx_ufunc = _NUMPY_ARRAY_UFUNC_DICT.get(name, None)
    onp_op = _get_np_op(name)
    if mx_ufunc is None:
        # try to fallback to official NumPy op
        if is_recording():
            # ``.format`` must be applied to the message string, not to the exception.
            raise ValueError("Falling back to NumPy operator {} with autograd active is not supported. "
                             "Please consider moving the operator to the outside of the autograd scope."
                             .format(name))
        new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
        if onp_op not in _FALLBACK_ARRAY_UFUNC_WARNED_RECORD:
            import logging
            logging.warning("np.%s is a fallback operator, "
                            "which is actually using official numpy's implementation", name)
            _FALLBACK_ARRAY_UFUNC_WARNED_RECORD[onp_op] = True
        out = onp_op(*new_inputs, **kwargs)
        # The first operand may be an official NumPy array or a scalar, so take the
        # device from the first mxnet operand instead of ``inputs[0]``.
        device = next((arg.device for arg in inputs if isinstance(arg, ndarray)), None)
        return _as_mx_np_array(out, device=device)
    if name in ufunc_list and isinstance(inputs[0], _np.ndarray):
        # ops with np mx_np
        if 'out' in kwargs:
            # inplace: the official NumPy left operand keeps its type
            new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
            return onp_op(*new_inputs, **kwargs)
        new_inputs = [_as_mx_np_array(arg, device=inputs[1].device)
                      if isinstance(arg, _np.ndarray) else arg for arg in inputs]
        return mx_ufunc(*new_inputs, **kwargs)
    return mx_ufunc(*inputs, **kwargs)
@staticmethod
def __array_function__(self, func, types, args, kwargs):  # pylint: disable=bad-staticmethod-argument
    """
    Dispatch official NumPy operators that comply with the array function protocol to
    this function.

    If ``func`` has an mxnet implementation registered in
    ``_NUMPY_ARRAY_FUNCTION_DICT``, that implementation is called. When ``types``
    contains non-mxnet array types, official ``numpy.ndarray`` arguments are
    first converted with ``_as_mx_np_array``. Otherwise the official NumPy
    function runs on host copies of the inputs, and its result is converted
    back with ``_as_mx_np_array``.

    Raises
    ------
    ValueError
        If a NumPy fallback would be needed while autograd is recording, or no
        input device can be determined for a fallback call.
    """
    mx_np_func = _NUMPY_ARRAY_FUNCTION_DICT.get(func, None)
    func_name = func.__name__
    if mx_np_func is None:
        # try to fallback to official NumPy op
        if is_recording():
            # ``.format`` must be applied to the message string, not to the exception.
            raise ValueError("Falling back to NumPy operator {} with autograd active is not supported. "
                             "Please consider moving the operator to the outside of the autograd scope."
                             .format(func_name))
        cur_device = None
        new_args, cur_device = _as_onp_array(args, cur_device)
        new_kwargs, cur_device = _as_onp_array(kwargs, cur_device)
        if cur_device is None:
            raise ValueError('Unknown device for the input ndarrays. It is probably a bug. Please'
                             ' create an issue on GitHub.')
        if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
            import logging
            logging.warning("np.%s is a fallback operator, "
                            "which is actually using official numpy's implementation.", func_name)
            _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD[func] = True
        out = func(*new_args, **new_kwargs)
        return _as_mx_np_array(out, device=cur_device)
    if py_all(issubclass(t, ndarray) for t in types):
        return mx_np_func(*args, **kwargs)

    def _find_device(obj):
        """Return the device of the first NDArray in ``obj`` (depth-first), or None."""
        if isinstance(obj, NDArray):
            return obj.device
        if isinstance(obj, (list, tuple)):
            items = obj
        elif isinstance(obj, dict):
            items = obj.values()
        else:
            return None
        for item in items:
            found = _find_device(item)
            if found is not None:
                return found
        return None

    # Arrays may be nested in containers (e.g. ``np.concatenate([onp_a, mx_b])``),
    # so search recursively instead of letting ``next()`` leak StopIteration.
    cur_device = _find_device(args)
    if cur_device is None:
        cur_device = _find_device(kwargs)
    new_args = _as_mx_np_array(args, device=cur_device,
                               zero_copy=func_name in {'may_share_memory', 'shares_memory'})
    new_kwargs = {k: _as_mx_np_array(v, cur_device) for k, v in kwargs.items()}
    return mx_np_func(*new_args, **new_kwargs)
def __array_namespace__(self, api_version=None):
    """
    Returns an object that has all the array API functions on it.

    Notes
    -----
    This is a standard API in
    https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none.

    Parameters
    ----------
    self : ndarray
        The array whose namespace is requested.
    api_version : Optional, string
        string representing the version of the array API specification to be returned, in `YYYY.MM` form.
        If it is None, it should return the namespace corresponding to latest version of the array API
        specification.

    Returns
    -------
    module
        The module named by ``self.__module__``.

    Raises
    ------
    ValueError
        If ``api_version`` is not in ``YYYY.MM`` form or its year is not 2021.
    """
    if api_version is not None:
        try:
            date = datetime.datetime.strptime(api_version, '%Y.%m')
        except ValueError:
            date = None
        # Raise outside the ``except`` block so no unrelated exception context is chained.
        if date is None or date.year != 2021:
            raise ValueError(f"Unrecognized array API version: {api_version!r}")
    return sys.modules[self.__module__]
def __dlpack__(self, stream=None):
    """Export this array as a DLPack capsule for consumption by ``from_dlpack()``.

    Parameters
    ----------
    stream : int, optional
        A Python integer representing a pointer to a stream (CUDA or ROCm),
        provided by the consumer so the producer can ensure that operations on
        the array can safely be performed. It must be a positive integer or -1.
        -1 signals that the producer must not perform any synchronization.
        A stream is only accepted for arrays on a GPU device.

    Returns
    -------
    capsule : PyCapsule
        A DLPack capsule for the array, containing a DLPackManagedTensor.

    Raises
    ------
    TypeError
        If ``stream`` is not ``None`` and its type is not exactly ``int``.
    ValueError
        If ``stream`` is given for an array that is not on a GPU device.
    """
    push_dependency = False
    if stream is not None:
        if type(stream) is not int:  # pylint: disable=unidiomatic-typecheck
            raise TypeError('The input stream must be int or None')
        if self.device.device_type != "gpu":
            raise ValueError('Stream {} is not supported in current device {}'
                             .format(stream, self.device.device_type))
        # -1 means "no synchronization": skip the backend stream dependency.
        push_dependency = stream != -1
    if push_dependency:
        check_call(_LIB.MXPushStreamDep(self.handle, ctypes.c_int64(stream)))
    exporter = ndarray_to_dlpack_for_write()
    return exporter(self)
def __dlpack_device__(self):
    """Returns device type and device ID in DLPack format.

    Returns
    -------
    tuple
        ``(DLDeviceType, device_id)`` for this array's device.

    Raises
    ------
    ValueError
        If the device type is not one of ``cpu``, ``gpu`` or ``cpu_pinned``.
    """
    dev = self.device
    dl_type = {
        'cpu': DLDeviceType.DLCPU,
        'gpu': DLDeviceType.DLGPU,
        'cpu_pinned': DLDeviceType.DLCPUPINNED,
    }.get(dev.device_type)
    if dl_type is None:
        raise ValueError('Unkown device type {} for DLPack'.format(dev.device_type))
    return (dl_type, dev.device_id)
def _get_np_basic_indexing(self, key):
    """
    This function indexes ``self`` with a basic-indexing ``key``.

    ``key`` is a tuple whose entries are slices, integers or ``None``. The
    non-``None`` entries must number exactly ``self.ndim``. ``None`` entries
    insert new length-1 axes. Integer entries are converted to slices, and the
    corresponding axes are dropped from the result.

    When the selected region is contiguous in memory, the result shares memory
    with ``self`` (created by flat slicing of a reshaped view). Otherwise it is
    produced by ``_npi.slice``.

    Raises
    ------
    RuntimeError
        If fewer non-``None`` indices than ``self.ndim`` are present (an
        internal bug).
    IndexError
        If too many indices are given or an integer index is out of bounds.
    """
    # Entries that actually consume an axis of ``self``.
    key_nd = tuple(idx for idx in key if idx is not None)
    if len(key_nd) < self.ndim:
        raise RuntimeError(
            'too few indices after normalization: expected `ndim` ({}) '
            'but got {}. This is a bug, please report it!'
            ''.format(self.ndim, len(key_nd))
        )
    if len(key_nd) > self.ndim:
        raise IndexError(
            'too many indices ({}) for array with {} dimensions'
            ''.format(len(key_nd), self.ndim)
        )
    none_axes = [ax for ax in range(len(key)) if key[ax] is None]  # pylint: disable=invalid-name
    # Integers become slices; ``int_axes`` remembers which axes to drop afterwards.
    slc_key, int_axes = self._basic_indexing_key_int_to_slice(key_nd)
    new_axes = self._new_axes_after_basic_indexing(none_axes, key)
    # Check bounds for integer axes
    for ax in int_axes:  # pylint: disable=invalid-name
        if not -self.shape[ax] <= key_nd[ax] < self.shape[ax]:
            raise IndexError(
                'index {} is out of bounds for axis {} with size {}'
                ''.format(key_nd[ax], ax, self.shape[ax]))
    if self._basic_indexing_slice_is_contiguous(slc_key, self.shape):
        # Create a shared-memory view by using low-level flat slicing
        flat_begin, flat_end = self._basic_indexing_contiguous_flat_begin_end(
            slc_key, self.shape
        )
        handle = NDArrayHandle()
        flat_self = self.reshape_view(-1)
        # The 32-bit slice API is used when INT64 tensor sizes are not enabled.
        if _int64_enabled():
            check_call(
                _LIB.MXNDArraySlice64(
                    flat_self.handle,
                    ctypes.c_int64(flat_begin),
                    ctypes.c_int64(flat_end),
                    ctypes.byref(handle),
                )
            )
        else:
            check_call(
                _LIB.MXNDArraySlice(
                    flat_self.handle,
                    ctypes.c_uint32(flat_begin),
                    ctypes.c_uint32(flat_end),
                    ctypes.byref(handle),
                )
            )
        sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape)
        sliced = self.__class__(handle=handle, writable=self.writable)
        # Zero-size results use ``reshape`` rather than ``reshape_view``
        # (presumably because a view is not supported for them — confirm).
        if 0 in sliced_shape:
            sliced = sliced.reshape(sliced_shape)
        else:
            sliced = sliced.reshape_view(sliced_shape)
    else:
        # Non-contiguous selection: materialize it with the slice operator.
        begin, end, step = self._basic_indexing_key_to_begin_end_step(
            slc_key, self.shape, keep_none=True
        )
        sliced = _npi.slice(self, begin, end, step)
    # Reshape to final shape due to integer and `None` entries in `key`.
    final_shape = [sliced.shape[i] for i in range(sliced.ndim) if i not in int_axes]
    for ax in new_axes:  # pylint: disable=invalid-name
        final_shape.insert(ax, 1)
    if sliced.size == 0:
        return sliced.reshape(tuple(final_shape))
    else:
        return sliced.reshape_view(tuple(final_shape))
def _get_np_empty_tuple_indexing(self, key):
    """Return an empty result for a key that contains an empty tuple ``()``.

    ``__getitem__`` calls this with the key returned by
    ``indexing_key_expand_implicit_axes``. Each ``None`` entry contributes an
    axis of length 1, each ``()`` an axis of length 0, and each ``slice(None)``
    the corresponding axis of ``self``. Other entries contribute no axis.

    The result is allocated on ``self.device``, consistent with the other
    empty results produced by ``__getitem__``.
    """
    new_shape = []
    num_none = 0
    for i, idx in enumerate(key):
        if idx is None:
            new_shape.append(1)  # expand dimension
            num_none += 1
        elif idx == ():
            new_shape.append(0)  # 0 shape
        elif idx == slice(None, None, None):
            # ``None`` entries do not consume an axis of ``self``; shift back.
            new_shape.append(self.shape[i - num_none])
    return empty(new_shape, dtype=self.dtype, device=self.device)
def _get_np_advanced_indexing(self, key):
    """Return ``self[key]`` for an advanced-indexing ``key`` using ``gather_nd``.

    The index arrays from ``_get_index_nd`` are stacked into one mxnet.numpy
    index array. Length-1 axes requested by ``None`` entries in ``key`` are
    re-inserted into the gathered result.
    """
    index_arrays, added_axes = self._get_index_nd(key)
    if type(index_arrays) is NDArray:  # pylint: disable=unidiomatic-typecheck
        stacked = index_arrays.as_np_ndarray()
    else:
        stacked = _mx_nd_np.stack([
            idx if isinstance(idx, self.__class__) else idx.as_np_ndarray()
            for idx in index_arrays
        ])
    gathered = _npi.gather_nd(self, stacked)
    if not added_axes:
        return gathered
    # Reshape due to `None` entries in `key`.
    out_shape = list(gathered.shape)
    for axis in added_axes:  # pylint: disable=invalid-name
        out_shape.insert(axis, 1)
    return gathered.reshape(tuple(out_shape))
def _set_np_advanced_indexing(self, key, value):
    """This function is called by __setitem__ when key is an advanced index.

    The index arrays from ``_get_index_nd`` are stacked into one mxnet.numpy
    index array. ``value`` is prepared to the output shape a ``gather_nd`` with
    those indices would have (squeezing the ``None``-introduced axes) and then
    written into ``self`` with ``_scatter_set_nd``.
    """
    index_arrays, added_axes = self._get_index_nd(key)
    if type(index_arrays) is NDArray:  # pylint: disable=unidiomatic-typecheck
        stacked = index_arrays.as_np_ndarray()
    else:
        stacked = _mx_nd_np.stack([
            idx if isinstance(idx, self.__class__) else idx.as_np_ndarray()
            for idx in index_arrays
        ])
    target_shape = get_oshape_of_gather_nd_op(self.shape, stacked.shape)
    prepared = self._prepare_value_nd(value, bcast_shape=target_shape, squeeze_axes=added_axes)
    self._scatter_set_nd(prepared, stacked)
# pylint: disable=redefined-outer-name
def _get_np_boolean_indexing(self, key, ndim, shape):
    """
    Index ``self`` with a single boolean ``ndarray`` mask (fast path).

    There are two types of boolean indices (which are equivalent, for the most
    part though). This function handles a lone boolean mask for higher speed.
    Otherwise the mask is expanded into (multiple) integer array indices and
    handled by advanced indexing.

    ``key`` must not have more dimensions than ``ndim``, and its dimensions
    must match the leading entries of ``shape``. The masked leading axes are
    flattened into one, and the trailing dimensions are kept.

    Raises
    ------
    IndexError
        If ``key`` has too many dimensions or a dimension does not match.
    """
    mask_shape = key.shape
    mask_ndim = len(mask_shape)
    if mask_ndim > ndim:
        raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}'
                         .format(mask_ndim, ndim))
    for axis, mask_len in enumerate(mask_shape):
        if mask_len != shape[axis]:
            raise IndexError('boolean index did not match indexed array along dimension {};'
                             ' dimension is {} but corresponding boolean dimension is {}'
                             .format(axis, shape[axis], mask_len))
    trailing = shape[mask_ndim:]
    flat_data = _reshape_view(self, -1, *trailing)
    flat_mask = _reshape_view(key, -1)
    if flat_data.size == 0 and flat_mask.size == 0:
        return flat_data
    selected = _npi.boolean_mask(flat_data, flat_mask)
    return _reshape_view(selected, -1, *trailing)
def _set_np_boolean_indexing(self, key, value):
    """
    Assign ``value`` to the elements of ``self`` selected by boolean mask ``key``.

    There are two types of boolean indices (which are equivalent, for the most
    part though). This function handles a single boolean assign for higher
    speed. Otherwise the mask is expanded into (multiple) integer array indices
    and handled by advanced assign.

    Raises
    ------
    NotImplementedError
        If ``value`` is neither a numeric scalar nor an ``ndarray``.
    """
    if isinstance(value, numeric_types):
        # Python bools are passed to the operator as ints.
        scalar = int(value) if isinstance(value, bool) else value
        _npi.boolean_mask_assign_scalar(data=self, mask=key, value=scalar,
                                        start_axis=0, out=self)
        return
    if not isinstance(value, ndarray):
        raise NotImplementedError(f'type {type(value)} is not supported.')
    _npi.boolean_mask_assign_tensor(data=self, mask=key, value=value, start_axis=0, out=self)
# pylint: disable=too-many-return-statements
def __getitem__(self, key):
    """Return self[key].

    Returns a sliced view of this array if the elements fetched are contiguous in memory;
    otherwise, returns a newly created NDArray.
    This function supports advanced indexing with some restrictions. Boolean indexing is
    supported only for a single boolean ndarray as a key. Mixing boolean ndarray with other
    index types is not supported in ``advanced`` indexing.

    For basic indexing, i.e., if ``key`` consists only of integers,
    ``slice``, ``Ellipsis`` (``...``) and ``None``, a mutable view is
    returned that shares memory with this array if the accessed portion is
    contiguous in memory.
    Otherwise, a newly created ``ndarray`` is returned.

    This function supports advanced indexing as defined in `the NumPy
    advanced indexing documentation
    <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_.

    Parameters
    ----------
    key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types
        Indexing key.

    Raises
    ------
    IndexError
        If an index is out of bounds, or a key other than ``()`` is used on a
        0-d array.
    TypeError
        If a list key cannot be converted by ``numpy.array``, or the kind of
        indexing used is not supported in deferred-compute (HybridBlock) mode.
    ValueError
        If a slice step is zero.

    Examples
    --------
    The default is to give explicit indices for all axes:

    >>> x = np.arange(6).reshape(2, 3)
    >>> x
    array([[0., 1., 2.],
           [3., 4., 5.]])
    >>> x[0, :2]
    array([0., 1.])
    >>> x[:, :-1]
    array([[0., 1.],
           [3., 4.]])

    If fewer indices are given, they are automatically supplemented by an
    appropriate number of ``slice(None)`` ("``:``") to the right. For
    instance, a single integer indexes along the first axis:

    >>> x[0]
    array([0., 1., 2.])
    >>> x[1:]
    array([[3., 4., 5.]])

    To omit a range of axes that should be kept as-is, an `Ellipsis`
    ("``...``") can be used:

    >>> x = np.arange(16).reshape(2, 2, 2, 2)
    >>> x[0, ..., 1]
    array([[1., 3.],
           [5., 7.]])
    >>> x[0, :, :, 1]  # equivalent
    array([[1., 3.],
           [5., 7.]])

    New axes of length 1 can be created by inserting ``None``
    (`numpy.newaxis`) in the index:

    >>> x = np.arange(6).reshape(2, 3)
    >>> x[None, :, :]
    array([[[0., 1., 2.],
            [3., 4., 5.]]])
    >>> x[None, :, :].shape
    (1, 2, 3)

    If the indexed portion of the array is contiguous in memory, no data
    is copied. Instead, a shared-memory view of the original array is
    returned, and changes to that view affect the original array:

    >>> x = np.arange(8).reshape(2, 2, 2)
    >>> y = x[0]  # contiguous
    >>> y
    array([[0., 1.],
           [2., 3.]])
    >>> y[:] = -1
    >>> x
    array([[[-1., -1.],
            [-1., -1.]],
           [[ 4.,  5.],
            [ 6.,  7.]]])
    >>> x = np.arange(8).reshape(2, 2, 2)
    >>> y = x[1, :1, :]  # contiguous
    >>> y
    array([[4., 5.]])
    >>> y[:] = -1
    >>> x
    array([[[ 0.,  1.],
            [ 2.,  3.]],
           [[-1., -1.],
            [ 6.,  7.]]])
    >>> x = np.arange(0, 8).reshape(2, 2, 2)
    >>> y = x[:, :, 1]  # not contiguous
    >>> y
    array([[1., 3.],
           [5., 7.]])
    >>> y[:] = -1
    >>> x
    array([[[0., 1.],
            [2., 3.]],
           [[4., 5.],
            [6., 7.]]])

    If the indexing key contains `list`, `numpy.ndarray` or `NDArray`
    objects, advanced indexing is triggered, which always returns a
    copy:

    >>> x = np.arange(8).reshape(2, 2, 2)
    >>> x[[0, 1]]
    array([[[0., 1.],
            [2., 3.]],
           [[4., 5.],
            [6., 7.]]])
    >>> x[[0, 1], :]  # equivalent
    array([[[0., 1.],
            [2., 3.]],
           [[4., 5.],
            [6., 7.]]])
    >>> y = np.array([0, 1], dtype='int32')
    >>> x[1:, y]
    array([[[4., 5.],
            [6., 7.]]])

    Get negative elements in an ndarray through boolean array indexing

    >>> x = np.array([1., -1., -2., 3])
    >>> x[x < 0]
    array([-1., -2.])

    For more information related to boolean indexing, please refer to
    https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html.
    """
    ndim = self.ndim  # pylint: disable=redefined-outer-name
    shape = self.shape  # pylint: disable=redefined-outer-name
    if isinstance(key, bool):  # otherwise will be treated as 0 and 1
        # ``_np.bool_`` rather than ``_np.bool``, which NumPy 1.24 removed.
        key = array(key, dtype=_np.bool_, device=self.device)
    if isinstance(key, list):
        try:
            new_key = _np.array(key)
        except Exception as err:  # pylint: disable=broad-except
            raise TypeError('{}'.format(str(err))) from err
        if new_key.dtype == _np.bool_:
            key = new_key
    if isinstance(key, _np.ndarray):
        if dc.is_deferred_compute():
            raise TypeError('Indexing with a numpy array is not supported in HybridBlock.')
        if key.dtype == _np.bool_:
            key = array(key, dtype='bool', device=self.device)

    # Handle single boolean index of matching dimensionality and size first for higher speed
    # If the boolean array is mixed with other indices, it is instead expanded into (multiple)
    # integer array indices and will be handled by advanced indexing.
    # Come before the check self.dim == 0 as it also handle the 0-dim case.
    if isinstance(key, ndarray) and key.dtype == _np.bool_:
        return self._get_np_boolean_indexing(key, ndim, shape)

    if ndim == 0 and key != ():
        raise IndexError('scalar tensor can only accept `()` as index')
    # Handle simple cases for higher speed
    if isinstance(key, tuple) and len(key) == 0:
        return self
    if isinstance(key, tuple) and len(key) == ndim \
            and py_all(isinstance(idx, integer_types) for idx in key):
        out = self
        for idx in key:
            out = out[idx]
        return out
    if isinstance(key, integer_types):
        # Equivalent to isinstance(key, integer_types) case in numpy/_symbol.py
        if key > shape[0] - 1:
            raise IndexError(
                'index {} is out of bounds for axis 0 with size {}'.format(
                    key, shape[0]))
        return self._at(key)
    elif isinstance(key, py_slice):
        # Unlike numpy/_symbol.py, calls MXNDArraySlice64 writable memory
        # sharing if key.step not in [None, 1]. Equivalent otherwise to
        # isinstance(key, py_slice) case in _symbol.py otherwise.
        if key.step is None or key.step == 1:
            if key.start is not None or key.stop is not None:
                return self._slice(key.start, key.stop)
            else:
                return self
        elif key.step != 0:
            start = [None] if key.start is None else key.start
            stop = [None] if key.stop is None else key.stop
            return _npi.slice(self, start, stop, key.step)
        else:
            raise ValueError("slice step cannot be zero")
    elif isinstance(key, tuple) and \
            py_all((isinstance(arr, NDArray) and _np.issubdtype(arr.dtype, _np.integer) and
                    arr.ndim > 0) for arr in key):
        # Equivalent case in numpy/_symbol.py
        return _npi.advanced_indexing_multiple(self, _mx_nd_np.stack(key))
    elif isinstance(key, tuple) and dc.is_deferred_compute():
        # Equivalent to isinstance(key, tuple) case in numpy/_symbol.py
        # Only enabled in deferred compute mode, as this codepath prevents
        # memory sharing which may be desired in non-deferred compute
        # imperative mode.
        begin = []
        end = []
        step = []
        new_shape = ()
        assert len(key)  # len(key) == 0 is handled a above
        unsupported = False
        for index in key:
            if isinstance(index, py_slice):
                if index.step is not None and index.step == 0:
                    raise ValueError("slice step cannot be zero")
                begin.append(index.start)
                end.append(index.stop)
                step.append(index.step)
                new_shape += (-2,)
            elif isinstance(index, integer_types):
                if index >= 0:
                    begin.append(index)
                    end.append(index+1)
                    step.append(1)
                else:
                    begin.append(index)
                    end.append(index - 1)
                    step.append(-1)
                new_shape += (-3,)
            else:
                unsupported = True
                break
        if not unsupported:
            new_shape += (-4,)
            sliced = _npi.slice(self, begin, end, step)
            return _mx_nd_np.reshape(sliced, new_shape)

    # Special handling for cases only supported in imperative mode
    if dc.is_deferred_compute():
        raise TypeError('The type of indexing used is not supported in HybridBlock.')
    # For 0-d boolean indices: A new axis is added,
    # but at the same time no axis is "used". So if we have True,
    # we add a new axis (a bit like with np.newaxis). If it is
    # False, we add a new axis, but this axis has 0 entries.
    # prepend is defined to handle this case.
    # prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d boolean scalar
    # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means an zero dim must be expanded
    # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be prepended
    key, prepend = indexing_key_expand_implicit_axes(key, self.shape)
    indexing_dispatch_code = get_indexing_dispatch_code(key)
    if indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING:
        # won't be affected by zero-dim boolean indices
        return self._get_np_empty_tuple_indexing(key)
    elif indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
            return empty((0,) + self._get_np_basic_indexing(key).shape,
                         dtype=self.dtype, device=self.device)
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE:
            key = (_np.newaxis,) + key
        return self._get_np_basic_indexing(key)
    elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
            # Was misspelled ``_get_np_adanced_indexing`` (AttributeError at runtime).
            return empty((0,) + self._get_np_advanced_indexing(key).shape,
                         dtype=self.dtype, device=self.device)
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE:
            key = (_np.newaxis,) + key
        return self._get_np_advanced_indexing(key)
    else:
        raise RuntimeError('unsupported indexing dispatch code {} for key {}'
                           .format(indexing_dispatch_code, key))
# pylint: disable=inconsistent-return-statements
def __setitem__(self, key, value):
    """Sets ``self[key]`` to ``value``.

    This function supports basic indexing, advanced indexing as defined in `the NumPy
    advanced indexing documentation
    <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_,
    and boolean indexing (a boolean ``ndarray`` or a Python ``bool`` used as key).

    Parameters
    ----------
    key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types
        The indexing key.
    value : scalar or array-like object that can be broadcast to the shape of self[key]
        The value to set.

    Raises
    ------
    TypeError
        If ``value`` is a legacy ``mx.nd.NDArray``.
    IndexError
        If a 0-d array is indexed with anything other than ``()``, or if more
        indices than dimensions are given.
    ValueError
        If a value that is neither a scalar nor a size-1 array is assigned to a
        0-d array, or if the key type is not supported.

    Examples
    --------
    >>> x = np.zeros((2, 3))
    >>> x[:] = 1
    >>> x
    array([[ 1., 1., 1.],
    [ 1., 1., 1.]])
    >>> x[:, 1:2] = 2
    >>> x
    array([[ 1., 2., 1.],
    [ 1., 2., 1.]])
    >>> x[1:2, 1:] = 3
    >>> x
    array([[ 1., 2., 1.],
    [ 1., 3., 3.]])
    >>> x[1:, 0:2] = np.zeros((1, 2))
    >>> x
    array([[ 1., 2., 1.],
    [ 0., 0., 3.]])
    >>> x[1, 2] = 4
    >>> x
    array([[ 1., 2., 1.],
    [ 0., 0., 4.]])
    >>> x[[0], [1, 2]] = 5
    >>> x
    array([[ 1., 5., 5.],
    [ 0., 0., 4.]])
    >>> x[::-1, 0:2:2] = [6]
    >>> x
    array([[ 6., 5., 5.],
    [ 6., 0., 4.]])

    For information related to boolean indexing, please refer to
    https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html.
    """
    if isinstance(value, NDArray) and not isinstance(value, ndarray):
        raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')
    # A Python bool key would otherwise be treated as the integer 0 or 1.
    # ``_np.bool_`` is used because the ``_np.bool`` alias was removed in
    # NumPy 1.24 (accessing it raises AttributeError there).
    if isinstance(key, bool):
        key = array(key, dtype=_np.bool_)
    # Handle a single boolean array key of matching dimensionality and size first, for
    # speed. A boolean array mixed with other indices is instead expanded into (multiple)
    # integer array indices and handled by advanced assignment.
    # This comes before the ``self.ndim == 0`` check because it also handles the 0-dim case.
    if isinstance(key, ndarray) and key.dtype == _np.bool_:
        return self._set_np_boolean_indexing(key, value)

    # handle basic and advanced indexing
    if self.ndim == 0:
        if not isinstance(key, tuple) or len(key) != 0:
            raise IndexError('scalar tensor can only accept `()` as index')
        if isinstance(value, numeric_types):
            self._full(value)
        elif isinstance(value, ndarray) and value.size == 1:
            if value.shape != self.shape:
                value = value.reshape(self.shape)
            value.copyto(self)
        elif isinstance(value, (_np.ndarray, _np.generic)) and value.size == 1:
            if isinstance(value, _np.generic) or value.shape != self.shape:
                value = value.reshape(self.shape)
            self._sync_copyfrom(value)
        else:
            raise ValueError('setting an array element with a sequence.')
    else:
        # For 0-d boolean indices: A new axis is added,
        # but at the same time no axis is "used". So if we have True,
        # we add a new axis (a bit like with np.newaxis). If it is
        # False, we add a new axis, but this axis has 0 entries.
        # prepend is defined to handle this case.
        # prepend == _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d boolean scalar
        # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means a zero dim must be expanded
        # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be expanded
        # Only the FALSE case matters for __setitem__: there is nothing to assign to.
        key, prepend = indexing_key_expand_implicit_axes(key, self.shape)
        if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
            return  # no action is needed
        slc_key = tuple(idx for idx in key if idx is not None)
        if len(slc_key) < self.ndim:
            raise RuntimeError(
                'too few indices after normalization: expected `ndim` ({}) '
                'but got {}. This is a bug, please report it!'
                ''.format(self.ndim, len(slc_key))
            )
        # self.ndim != 0 is guaranteed in this branch.
        if len(slc_key) > self.ndim:
            raise IndexError(
                'too many indices ({}) for array with {} dimensions'
                ''.format(len(slc_key), self.ndim)
            )
        indexing_dispatch_code = get_indexing_dispatch_code(slc_key)
        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
            self._set_nd_basic_indexing(key, value)  # method is inherited from the NDArray class
        elif indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING:
            pass  # no action needed
        elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
            self._set_np_advanced_indexing(key, value)
        else:
            raise ValueError(
                'Indexing NDArray with index {} of type {} is not supported'
                ''.format(key, type(key))
            )
def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
    """Return a broadcast `ndarray` with the same device and dtype as ``self``.

    For setting items, the returned `ndarray` is squeezed according to
    ``squeeze_axes`` since ``value_nd`` is assigned to not-yet-expanded space in
    the original array.

    Parameters
    ----------
    value : numeric type or array-like
        The value to be assigned.
    bcast_shape : tuple of int
        The shape the returned array must have.
    squeeze_axes : sequence of int, optional
        Axes to squeeze from the value array before broadcasting.

    Returns
    -------
    ndarray
        ``value`` placed on ``self.device`` with ``self.dtype``, shaped to ``bcast_shape``.

    Raises
    ------
    TypeError
        If ``value`` cannot be converted to an array.

    Note: mxnet.numpy.ndarray does not support a legacy NDArray as the assigned value.
    """
    if isinstance(value, numeric_types):
        value_nd = full(bcast_shape, value, device=self.device, dtype=self.dtype)
    elif isinstance(value, self.__class__):
        value_nd = value.to_device(self.device)
        if value_nd.dtype != self.dtype:
            value_nd = value_nd.astype(self.dtype)
    else:
        try:
            value_nd = array(value, device=self.device, dtype=self.dtype)
        except Exception:  # pylint: disable=broad-except
            # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
            # still propagate; the conversion error stays attached as __context__.
            raise TypeError('mxnet.np.ndarray does not support assignment with non-array-like '
                            'object {} of type {}'.format(value, type(value)))
    # For advanced indexing setitem, if there is None in indices, we need to squeeze the
    # assigned value_nd since None is also ignored in slicing the original array.
    if squeeze_axes and value_nd.ndim > len(bcast_shape):
        squeeze_axes = tuple(ax for ax in squeeze_axes if ax < len(value_nd.shape))
        value_nd = value_nd.squeeze(axis=squeeze_axes)
    # Trim leading size-1 axes that would otherwise block broadcasting, e.g.
    # a = np.zeros((3, 3)), b = np.ones((1, 1, 1, 1, 3)), a[0] = b
    if value_nd.ndim > len(bcast_shape):
        leading_ones = []
        for i in range(value_nd.ndim - len(bcast_shape)):
            if value_nd.shape[i] != 1:
                break
            leading_ones.append(i)
        if leading_ones:
            value_nd = value_nd.squeeze(leading_ones)
    if value_nd.shape != bcast_shape:
        if value_nd.size == 0:
            value_nd = value_nd.reshape(bcast_shape)
        else:
            value_nd = value_nd.broadcast_to(bcast_shape)
    return value_nd
@wrap_mxnp_np_ufunc
def __add__(self, other):
    """Element-wise addition: ``x.__add__(y) <=> x + y``."""
    total = add(self, other)
    return total
@wrap_mxnp_np_ufunc
def __iadd__(self, other):
    """In-place addition: ``x.__iadd__(y) <=> x += y``.

    The result is written back into ``self``; a read-only array raises ValueError.
    """
    if self.writable:
        return add(self, other, out=self)
    raise ValueError('trying to add to a readonly ndarray')
@wrap_mxnp_np_ufunc
def __radd__(self, other):
    """Reflected addition: ``x.__radd__(y) <=> y + x``."""
    total = add(other, self)
    return total
def __invert__(self):
    """Bitwise inversion: ``x.__invert__() <=> ~x``."""
    inverted = invert(self)
    return inverted
@wrap_mxnp_np_ufunc
def __and__(self, other):
    """Element-wise bitwise AND: ``x.__and__(y) <=> x & y``."""
    result = bitwise_and(self, other)
    return result
@wrap_mxnp_np_ufunc
def __rand__(self, other):
    """Reflected bitwise AND: ``x.__rand__(y) <=> y & x``."""
    result = bitwise_and(other, self)
    return result
@wrap_mxnp_np_ufunc
def __or__(self, other):
    """Element-wise bitwise OR: ``x.__or__(y) <=> x | y``."""
    result = bitwise_or(self, other)
    return result
@wrap_mxnp_np_ufunc
def __ror__(self, other):
    """Reflected bitwise OR: ``x.__ror__(y) <=> y | x``."""
    result = bitwise_or(other, self)
    return result
@wrap_mxnp_np_ufunc
def __xor__(self, other):
    """Element-wise bitwise XOR: ``x.__xor__(y) <=> x ^ y``."""
    result = bitwise_xor(self, other)
    return result
@wrap_mxnp_np_ufunc
def __rxor__(self, other):
    """Reflected bitwise XOR: ``x.__rxor__(y) <=> y ^ x``."""
    result = bitwise_xor(other, self)
    return result
@wrap_mxnp_np_ufunc
def __lshift__(self, other):
    """Element-wise left shift: ``x.__lshift__(y) <=> x << y``."""
    shifted = bitwise_left_shift(self, other)
    return shifted
@wrap_mxnp_np_ufunc
def __rshift__(self, other):
    """Element-wise right shift: ``x.__rshift__(y) <=> x >> y``."""
    shifted = bitwise_right_shift(self, other)
    return shifted
@wrap_mxnp_np_ufunc
def __iand__(self, other):
    """In-place bitwise AND: ``x.__iand__(y) <=> x &= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to apply bitwise and to a readonly ndarray')
    return bitwise_and(self, other, out=self)
@wrap_mxnp_np_ufunc
def __ior__(self, other):
    r"""In-place bitwise OR: ``x.__ior__(y) <=> x \|= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to apply bitwise or to a readonly ndarray')
    return bitwise_or(self, other, out=self)
@wrap_mxnp_np_ufunc
def __ixor__(self, other):
    """In-place bitwise XOR: ``x.__ixor__(y) <=> x ^= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to apply bitwise xor to a readonly ndarray')
    return bitwise_xor(self, other, out=self)
@wrap_mxnp_np_ufunc
def __ilshift__(self, other):
    """In-place left shift: ``x.__ilshift__(y) <=> x <<= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to left-shift a readonly ndarray')
    return bitwise_left_shift(self, other, out=self)
@wrap_mxnp_np_ufunc
def __irshift__(self, other):
    """In-place right shift: ``x.__irshift__(y) <=> x >>= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to right-shift a readonly ndarray')
    return bitwise_right_shift(self, other, out=self)
@wrap_mxnp_np_ufunc
def __rlshift__(self, other):
    """Reflected left shift: ``x.__rlshift__(y) <=> y << x``."""
    shifted = bitwise_left_shift(other, self)
    return shifted
@wrap_mxnp_np_ufunc
def __rrshift__(self, other):
    """Reflected right shift: ``x.__rrshift__(y) <=> y >> x``."""
    shifted = bitwise_right_shift(other, self)
    return shifted
def __round__(self, n=0):
    """Support for the built-in ``round(x, n)``; ``n`` is forwarded as ``decimals``."""
    rounded = round(self, decimals=n)
    return rounded
def __abs__(self):
    """Support for the built-in ``abs(x)`` (element-wise absolute value)."""
    magnitude = absolute(self)
    return magnitude
def __ceil__(self):
    """Support for ``math.ceil(x)`` (element-wise ceiling)."""
    result = ceil(self)
    return result
def __floor__(self):
    """Support for ``math.floor(x)`` (element-wise floor)."""
    result = floor(self)
    return result
def __trunc__(self):
    """Support for ``math.trunc(x)`` (element-wise truncation)."""
    result = trunc(self)
    return result
@wrap_mxnp_np_ufunc
def __sub__(self, other):
    """Element-wise subtraction: ``x.__sub__(y) <=> x - y``."""
    difference = subtract(self, other)
    return difference
@wrap_mxnp_np_ufunc
def __isub__(self, other):
    """In-place subtraction: ``x.__isub__(y) <=> x -= y``.

    The result is written back into ``self``; a read-only array raises ValueError.
    """
    if self.writable:
        return subtract(self, other, out=self)
    raise ValueError('trying to subtract from a readonly ndarray')
@wrap_mxnp_np_ufunc
def __rsub__(self, other):
    """Reflected subtraction: ``x.__rsub__(y) <=> y - x``."""
    difference = subtract(other, self)
    return difference
@wrap_mxnp_np_ufunc
def __mul__(self, other):
    """Element-wise multiplication: ``x.__mul__(y) <=> x * y``."""
    product = multiply(self, other)
    return product
@wrap_mxnp_np_ufunc
def __floordiv__(self, other):
    """Element-wise floor division: ``x.__floordiv__(y) <=> x // y``."""
    quotient = floor_divide(self, other)
    return quotient
@wrap_mxnp_np_ufunc
def __ifloordiv__(self, other):
    """In-place floor division: ``x.__ifloordiv__(y) <=> x //= y``.

    The result is written back into ``self``; a read-only array raises ValueError.
    """
    if self.writable:
        return floor_divide(self, other, out=self)
    raise ValueError('trying to divide from a readonly ndarray')
@wrap_mxnp_np_ufunc
def __rfloordiv__(self, other):
    """Reflected floor division: ``x.__rfloordiv__(y) <=> y // x``."""
    quotient = floor_divide(other, self)
    return quotient
def __neg__(self):
    """Element-wise negation: ``x.__neg__() <=> -x``."""
    negated = negative(self)
    return negated
def __pos__(self):
    """Unary plus: ``x.__pos__() <=> +x``."""
    result = positive(self)
    return result
@wrap_mxnp_np_ufunc
def __imul__(self, other):
    r"""In-place multiplication: ``x.__imul__(y) <=> x \*= y``.

    Raises
    ------
    ValueError
        If this array is read-only.
    """
    if not self.writable:
        # The message previously said "add", copy-pasted from __iadd__.
        raise ValueError('trying to multiply a readonly ndarray')
    return multiply(self, other, out=self)
@wrap_mxnp_np_ufunc
def __rmul__(self, other):
    """Reflected multiplication: ``x.__rmul__(y) <=> y * x``.

    Calls ``multiply`` directly (like the other reflected operators) instead of
    re-entering the decorated ``__mul__``, which applied the ufunc wrapper twice.
    """
    return multiply(other, self)
@wrap_mxnp_np_ufunc
def __div__(self, other):
    """Element-wise division (Python 2 name): ``x.__div__(y) <=> x / y``."""
    quotient = divide(self, other)
    return quotient
@wrap_mxnp_np_ufunc
def __rdiv__(self, other):
    """Reflected division (Python 2 name): ``x.__rdiv__(y) <=> y / x``."""
    quotient = divide(other, self)
    return quotient
@wrap_mxnp_np_ufunc
def __idiv__(self, other):
    """In-place division (Python 2 name): ``x.__idiv__(y) <=> x /= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__ifloordiv__``).
    """
    if not self.writable:
        raise ValueError('trying to divide from a readonly ndarray')
    return divide(self, other, out=self)
@wrap_mxnp_np_ufunc
def __truediv__(self, other):
    """Element-wise true division: ``x.__truediv__(y) <=> x / y``."""
    quotient = divide(self, other)
    return quotient
@wrap_mxnp_np_ufunc
def __rtruediv__(self, other):
    """Reflected true division: ``x.__rtruediv__(y) <=> y / x``."""
    quotient = divide(other, self)
    return quotient
@wrap_mxnp_np_ufunc
def __itruediv__(self, other):
    """In-place true division: ``x.__itruediv__(y) <=> x /= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__ifloordiv__``).
    """
    if not self.writable:
        raise ValueError('trying to divide from a readonly ndarray')
    return divide(self, other, out=self)
@wrap_mxnp_np_ufunc
def __mod__(self, other):
    """Element-wise remainder: ``x.__mod__(y) <=> x % y``."""
    remainder = mod(self, other)
    return remainder
@wrap_mxnp_np_ufunc
def __rmod__(self, other):
    """Reflected remainder: ``x.__rmod__(y) <=> y % x``."""
    remainder = mod(other, self)
    return remainder
@wrap_mxnp_np_ufunc
def __imod__(self, other):
    """In-place remainder: ``x.__imod__(y) <=> x %= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to take modulo from a readonly ndarray')
    return mod(self, other, out=self)
@wrap_mxnp_np_ufunc
def __pow__(self, other):
    """Element-wise power: ``x.__pow__(y) <=> x ** y``."""
    raised = power(self, other)
    return raised
@wrap_mxnp_np_ufunc
def __rpow__(self, other):
    """Reflected power: ``x.__rpow__(y) <=> y ** x``."""
    raised = power(other, self)
    return raised
@wrap_mxnp_np_ufunc
def __ipow__(self, other):
    """In-place power: ``x.__ipow__(y) <=> x **= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to take power of a readonly ndarray')
    return power(self, other, out=self)
@wrap_mxnp_np_ufunc
def __eq__(self, other):
    """Element-wise equality: ``x.__eq__(y) <=> x == y`` (returns an array)."""
    mask = equal(self, other)
    return mask
def __hash__(self):
    """ndarray objects are not hashable: ``hash(x)`` always raises NotImplementedError."""
    raise NotImplementedError
@wrap_mxnp_np_ufunc
def __ne__(self, other):
    """Element-wise inequality: ``x.__ne__(y) <=> x != y`` (returns an array)."""
    mask = not_equal(self, other)
    return mask
@wrap_mxnp_np_ufunc
def __gt__(self, other):
    """Element-wise greater-than: ``x.__gt__(y) <=> x > y``."""
    mask = greater(self, other)
    return mask
@wrap_mxnp_np_ufunc
def __ge__(self, other):
    """Element-wise greater-or-equal: ``x.__ge__(y) <=> x >= y``."""
    mask = greater_equal(self, other)
    return mask
@wrap_mxnp_np_ufunc
def __lt__(self, other):
    """Element-wise less-than: ``x.__lt__(y) <=> x < y``."""
    mask = less(self, other)
    return mask
@wrap_mxnp_np_ufunc
def __le__(self, other):
    """Element-wise less-or-equal: ``x.__le__(y) <=> x <= y``."""
    mask = less_equal(self, other)
    return mask
@wrap_mxnp_np_ufunc
def __matmul__(self, other):
    """Matrix product: ``x.__matmul__(y) <=> x @ y``."""
    product = matmul(self, other)
    return product
@wrap_mxnp_np_ufunc
def __rmatmul__(self, other):
    """Reflected matrix product: ``x.__rmatmul__(y) <=> y @ x``."""
    product = matmul(other, self)
    return product
@wrap_mxnp_np_ufunc
def __imatmul__(self, other):
    """In-place matrix product: ``x.__imatmul__(y) <=> x @= y``.

    Raises
    ------
    ValueError
        If this array is read-only (same guard as ``__iadd__``/``__isub__``).
    """
    if not self.writable:
        raise ValueError('trying to matmul to a readonly ndarray')
    return matmul(self, other, out=self)
def __bool__(self):
    """Truth value of the array.

    Returns
    -------
    bool
        ``False`` for an empty array (after emitting a DeprecationWarning),
        otherwise the truth value of the single element.

    Raises
    ------
    ValueError
        If the array has more than one element.
    """
    num_elements = self.size
    if num_elements == 0:
        # Force the DeprecationWarning to be shown (it is ignored by default
        # outside __main__) without permanently rewriting the process-wide
        # warning filters, which the previous bare simplefilter() call did.
        with warnings.catch_warnings():
            warnings.simplefilter('default')
            warnings.warn('The truth value of an empty array is ambiguous. Returning False, but in'
                          ' future this will result in an error.', DeprecationWarning,
                          stacklevel=2)
        return False
    if num_elements == 1:
        return bool(self.item())
    raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.")
__nonzero__ = __bool__
def __index__(self):
    """Convert a 0-d integer array to a Python int so it can be used as an index.

    Raises TypeError for any other array.
    """
    is_int_scalar = self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer)
    if not is_int_scalar:
        raise TypeError('only integer scalar arrays can be converted to a scalar index')
    return self.item()
def __float__(self):
    """Convert a size-1 array to a Python float; other sizes raise TypeError."""
    if self.size == 1:
        return float(self.item())
    raise TypeError('only size-1 arrays can be converted to Python scalars')
def __int__(self):
    """Convert a size-1 array to a Python int; other sizes raise TypeError."""
    if self.size == 1:
        return int(self.item())
    raise TypeError('only size-1 arrays can be converted to Python scalars')
def __len__(self):
    """Length of the first axis; 0-d arrays raise TypeError."""
    dims = self.shape
    if not dims:
        raise TypeError('len() of unsized object')
    return dims[0]
def __reduce__(self):
    """Pickle support: rebuild via ``ndarray(None)`` and restore ``__getstate__()``."""
    state = self.__getstate__()
    return ndarray, (None,), state
def item(self, *args):
    """Copy one element of the array into a standard Python scalar and return it.

    Parameters
    ----------
    *args : Arguments (variable number and type)
        none: only valid for arrays with exactly one element (a.size == 1);
        that element is returned.
        int_type: interpreted as a flat index into the array.
        tuple of int_types: interpreted as an nd-index into the array.

    Returns
    -------
    z : Standard Python scalar object
        A copy of the selected element as a suitable Python scalar.
    """
    # TODO(junwu): no need to call asnumpy() on the whole array.
    host_copy = self.asnumpy()
    return host_copy.item(*args)
def nonzero(self):
    """Indices of the non-zero elements of this array.

    Fluent form of the module-level ``nonzero``; see `numpy.nonzero` for details.

    See Also
    --------
    numpy.nonzero : equivalent function
    """
    indices = nonzero(self)
    return indices
@property
# pylint: disable= invalid-name, undefined-variable
def T(self):
    """Transpose of the array, i.e. ``self.transpose()``; this always returns a copy of self.

    A UserWarning is emitted for arrays that are not 2-D, recommending ``mT`` or
    ``permute_dims()`` instead; the transpose is still returned.
    """
    if self.ndim == 2:
        return self.transpose()
    warnings.warn('x.T requires x to have 2 dimensions. '
                  'Use x.mT to transpose stacks of matrices and '
                  'permute_dims() to permute dimensions.')
    return self.transpose()
# pylint: enable= invalid-name, undefined-variable
@property
# pylint: disable= invalid-name, undefined-variable
def mT(self):
    """Matrix transpose: the array with its last two axes swapped.

    Unlike ``T``/``transpose()`` (which reverse all axes), this treats the array
    as a stack of matrices and transposes each one.

    Raises
    ------
    ValueError
        If the array has fewer than 2 dimensions.
    """
    if self.ndim < 2:
        raise ValueError("x must be at least 2-dimensional for matrix_transpose")
    return _mx_nd_np.swapaxes(self, -1, -2)
# pylint: enable= invalid-name, undefined-variable
def all(self, axis=None, out=None, keepdims=False):
    """Test whether all elements along ``axis`` evaluate to True.

    Fluent form of ``all``; arguments are forwarded unchanged.
    """
    reduced = _mx_nd_np.all(self, axis=axis, out=out, keepdims=keepdims)
    return reduced
def any(self, axis=None, out=None, keepdims=False):
    """Test whether any element along ``axis`` evaluates to True.

    Fluent form of ``any``; arguments are forwarded unchanged.
    """
    reduced = _mx_nd_np.any(self, axis=axis, out=out, keepdims=keepdims)
    return reduced
def as_nd_ndarray(self):
    """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods.

    A shallow copy of the handle is taken (MXShallowCopyNDArray) and the
    ``writable`` flag is carried over.
    """
    shallow_handle = NDArrayHandle()
    check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(shallow_handle)))
    return NDArray(handle=shallow_handle, writable=self.writable)
def as_np_ndarray(self):
    """Zero-copy conversion to a numpy-style ndarray.

    This object already is one, so it is returned unchanged.
    """
    same = self
    return same
def __repr__(self):
    """
    Return the NumPy-style representation of the array.

    NumPy's dtype annotation is adjusted to MXNet's default float dtype:
    it is removed when the dtype equals the current default, and added when
    NumPy omitted it but the dtype is neither the default nor bool. For
    devices other than CPU, ``device=...`` is appended. Freed arrays print
    as ``<FREED ndarray>``.

    Examples
    --------
    >>> from mxnet import np, npx
    >>> a = np.random.uniform(size=(2, 3))
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
    [0.84426576, 0.60276335, 0.8579456 ]])
    >>> print(a)
    [[0.5488135 0.5928446 0.71518934]
    [0.84426576 0.60276335 0.8579456 ]]
    >>> a.dtype
    dtype('float32')
    >>> npx.set_np_float64()
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
    [0.84426576, 0.60276335, 0.8579456 ]], dtype=float32)
    >>> npx.set_np_float64(default_float64=False)
    >>> a
    array([[0.5488135 , 0.5928446 , 0.71518934],
    [0.84426576, 0.60276335, 0.8579456 ]])
    >>> b = a.astype(np.float64)
    >>> b
    array([[0.54881352, 0.59284461, 0.71518934],
    [0.84426576, 0.60276335, 0.85794562]], dtype=float64)
    >>> print(b)
    [[0.54881352 0.59284461 0.71518934]
    [0.84426576 0.60276335 0.85794562]]
    >>> b.dtype
    dtype('float64')
    >>> c = a.copyto(npx.gpu(0))
    >>> c
    array([[0.5488135 , 0.5928446 , 0.71518934],
    [0.84426576, 0.60276335, 0.8579456 ]], device=gpu(0))
    >>> print(c)
    [[0.5488135 0.5928446 0.71518934]
    [0.84426576 0.60276335 0.8579456 ]] @gpu(0)
    >>> d = b.copyto(npx.gpu(0))
    >>> d
    array([[0.54881352, 0.59284461, 0.71518934],
    [0.84426576, 0.60276335, 0.85794562]], dtype=float64, device=gpu(0))
    >>> print(d)
    [[0.54881352 0.59284461 0.71518934]
    [0.84426576 0.60276335 0.85794562]] @gpu(0)
    """
    if not self._alive:
        return '<FREED {}>'.format(self.__class__.__name__)
    text = self.asnumpy().__repr__()
    dtype = self.dtype
    default_dtype = _np.float64 if is_np_default_dtype() else _np.float32
    if 'dtype=' in text:
        if dtype == default_dtype:
            # NumPy printed a dtype that is MXNet's default: drop the annotation.
            text = text[:text.rindex(',')] + ')'
    elif dtype not in (default_dtype, _np.bool_):
        # NumPy omitted a dtype that is not MXNet's default: add it.
        text = text[:-1] + ', dtype={})'.format(dtype)
    device = self.device
    if device.device_type != 'cpu':
        text = text[:-1] + ', device={})'.format(str(device))
    return text
def __str__(self):
    """Return NumPy's ``str`` of the array, suffixed with `` @<device>`` for
    non-scalar arrays that do not live on the CPU."""
    text = self.asnumpy().__str__()
    device = self.device
    plain = device.device_type == 'cpu' or self.ndim == 0
    return text if plain else '{array} @{device}'.format(array=text, device=device)
def __format__(self, fmt):
    """Support ``format()``/f-strings.

    0-d arrays format their scalar with ``fmt``; other arrays only accept an
    empty format spec (falling back to ``str``) and raise TypeError otherwise.
    """
    if self.ndim == 0:
        return self.item().__format__(fmt)
    if fmt:
        raise TypeError("Cannot format mxnet.numpy.ndarray with format_spec")
    return self.__str__().__format__(fmt)
def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
    """Attach a gradient buffer to this ndarray so that `backward`
    can compute the gradient with respect to it.

    A zero-filled buffer shaped like ``self`` is allocated and registered
    through MXAutogradMarkVariables.

    Parameters
    ----------
    grad_req : {'write', 'add', 'null'}
        How gradients are accumulated.
        * 'write': the gradient is overwritten on every backward.
        * 'add': the gradient is added to the existing value on every backward.
        * 'null': no gradient is computed for this NDArray.
    """
    grad_buffer = _mx_nd_np.zeros_like(self)  # pylint: disable=undefined-variable
    req_code = mx_uint(_GRAD_REQ_MAP[grad_req])
    check_call(_LIB.MXAutogradMarkVariables(
        1, ctypes.pointer(self.handle),
        ctypes.pointer(req_code),
        ctypes.pointer(grad_buffer.handle)))
def drop_grad(self):
    """Free the gradient memory of this marked ndarray (via MXAutogradDropGrads)."""
    handle_ptr = ctypes.pointer(self.handle)
    check_call(_LIB.MXAutogradDropGrads(1, handle_ptr))
@property
def grad(self):
    """Gradient buffer attached to this ndarray, or ``None`` if none is attached."""
    grad_handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayGetGrad(self.handle, ctypes.byref(grad_handle)))
    return None if grad_handle.value is None else _np_ndarray_cls(grad_handle)
def detach(self):
    """Return a new ndarray that is detached from the current autograd graph."""
    detached_handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(detached_handle)))
    return _np_ndarray_cls(detached_handle)
def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): # pylint: disable=arguments-differ,unused-argument, too-many-arguments
    """
    Return the array cast to ``dtype``.

    Parameters
    ----------
    dtype : str or dtype
        Typecode or data-type to which the array is cast. ``None`` means float32.
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout order of the result: 'C' (C order), 'F' (Fortran order),
        'A' ('F' if all arrays are Fortran contiguous, else 'C'), or 'K'
        (as close to the in-memory element order as possible). Default 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Which kinds of data casting may occur. Defaults to 'unsafe'
        for backwards compatibility.
        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.
    subok : bool, optional
        If True, sub-classes are passed through (default); otherwise the
        result would be forced to the base class.
    copy : bool, optional
        Default `True`: a newly allocated ndarray on the same device is always
        returned. With `False`, ``self`` is returned when the requested dtype
        already matches.

    Returns
    -------
    arr_t : ndarray
        ``self`` when `copy` is False and the dtype already matches, otherwise
        a new array of the same shape with `dtype`.

    Notes
    -----
    Compared to the official NumPy ``ndarray.astype``:
    * `order` only supports 'C' and 'K' (``None`` is also accepted).
    * `casting` only supports 'unsafe'.
    * `subok` only supports ``True``.
    Unsupported values raise ValueError.
    """
    if order not in (None, 'K', 'C'):
        raise ValueError('order must be either \'K\' or \'C\'')
    if casting != 'unsafe':
        raise ValueError('casting must be equal to \'unsafe\'')
    if not subok:
        raise ValueError('subok must be equal to True')
    target = _np.float32 if dtype is None else dtype
    if not copy and _np.dtype(target) == self.dtype:
        return self
    return _npi.cast(self, dtype=target)
def copyto(self, other):
    """Copy the values of this array into another array or onto a device.

    If ``other`` is an ``ndarray``, ``other.shape`` and ``self.shape`` should
    match and the values of ``self`` are written into ``other``. Copying an
    array onto itself emits a RuntimeWarning and returns ``False``.
    If ``other`` is a device, a new ``np.ndarray`` is created on that device
    and the values of ``self`` are copied into it.

    Parameters
    ----------
    other : ndarray or Device
        The destination array or device.

    Returns
    -------
    out: ndarray
        The copied array. If ``other`` is an ``ndarray``, the return value
        and ``other`` refer to the same ``ndarray``.

    Raises
    ------
    TypeError
        If ``other`` is neither an ndarray nor a Device.

    Examples
    --------
    >>> x = np.ones((2, 3))
    >>> y = np.zeros((2, 3), device=npx.gpu(0))
    >>> z = x.copyto(y)
    >>> z is y
    True
    >>> y
    array([[ 1., 1., 1.],
    [ 1., 1., 1.]])
    """
    if isinstance(other, ndarray):
        if other.handle is self.handle:
            warnings.warn('You are attempting to copy an array to itself', RuntimeWarning)
            return False
        return _npi.copyto(self, out=other)
    if isinstance(other, Device):
        target = ndarray(_new_alloc_handle(self.shape, other, True, self.dtype))
        return _npi.copyto(self, out=target)
    raise TypeError('copyto does not support type ' + str(type(other)))
def asscalar(self):
    """Not supported on mxnet.numpy.ndarray; always raises AttributeError (use ``item()``)."""
    message = 'mxnet.numpy.ndarray object has no attribute asscalar'
    raise AttributeError(message)
def argmax(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
    """Indices of the maximum values along ``axis``.

    See `mxnet.numpy.argmax` for full documentation."""
    indices = argmax(self, axis, out, keepdims)
    return indices
def as_in_context(self, context):
    """Deprecated alias; use ``ndarray.to_device`` instead.

    Emits a DeprecationWarning and round-trips through the legacy NDArray
    ``as_in_context``.
    """
    warnings.warn('ndarray.as_in_context has been renamed to'
                  ' ndarray.to_device', DeprecationWarning)
    legacy = self.as_nd_ndarray().as_in_context(context)
    return legacy.as_np_ndarray()
def as_in_ctx(self, ctx):
    """Deprecated alias of ``ndarray.to_device``.

    Emits a DeprecationWarning and returns ``self.to_device(ctx)``.
    """
    # The message previously read "ndarray.to_device has been renamed to
    # ndarray.to_device", which did not name the deprecated method.
    warnings.warn('ndarray.as_in_ctx has been renamed to'
                  ' ndarray.to_device', DeprecationWarning, stacklevel=2)
    return self.to_device(ctx)
@property
def ctx(self):
    """Deprecated alias of ``ndarray.device``; emits a DeprecationWarning."""
    warnings.warn('ndarray.ctx has been renamed to ndarray.device', DeprecationWarning)
    current = self.device
    return current
def to_device(self, device):
    """Return this array's values on ``device``.

    No copy is made when ``device`` equals ``self.device`` (``self`` itself is
    returned); otherwise the data is copied via ``copyto``.

    Parameters
    ----------
    device : Device
        The target device.

    Returns
    -------
    ndarray
        The array on the target device.
    """
    if self.device != device:
        return self.copyto(device)
    return self
@property
def device(self):
    """The hardware device holding this array's data.

    Queried from the backend via MXNDArrayGetContext.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4])
    >>> x.device
    cpu(0)
    >>> type(x.device)
    <class 'mxnet.device.Device'>
    >>> y = np.zeros((2, 3), npx.gpu(0))
    >>> y.device
    gpu(0)
    """
    type_id = ctypes.c_int()
    device_id = ctypes.c_int()
    check_call(_LIB.MXNDArrayGetContext(
        self.handle, ctypes.byref(type_id), ctypes.byref(device_id)))
    device_type = Device.devtype2str[type_id.value]
    return Device(device_type, device_id.value)
@property
def context(self):
    """Deprecated: use ``ndarray.device`` instead.

    Emits a DeprecationWarning and returns the ``context`` of the legacy
    NDArray view of this array.
    """
    # Point users at ``device``: ``ctx`` (the previous suggestion) is itself
    # deprecated in favour of ``device``.
    warnings.warn('ndarray.context has been renamed to ndarray.device', DeprecationWarning)
    return self.as_nd_ndarray().context
def copy(self, order='C'):  # pylint: disable=arguments-differ
    """Return a copy of the array on the same device.

    Parameters
    ----------
    order : str
        Memory layout of the copy. Only C-contiguous ('C') is supported;
        anything else raises NotImplementedError.

    Examples
    --------
    >>> x = np.ones((2, 3))
    >>> y = x.copy()
    >>> y
    array([[ 1., 1., 1.],
    [ 1., 1., 1.]])
    """
    if order == 'C':
        return self.copyto(self.device)
    raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
                              'received {}'.format(str(order)))
def dot(self, b, out=None):
    """Dot product of this array with ``b``.

    See ``numpy.dot`` for full documentation."""
    product = dot(self, b, out=out)
    return product
def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
    """Return a copy of the array with a new shape.

    Notes
    -----
    Unlike the free function `numpy.reshape`, this method accepts the shape
    either as one tuple or as separate arguments, so ``a.reshape(10, 11)`` is
    equivalent to ``a.reshape((10, 11))``. The only keyword accepted is
    ``order``, and only ``'C'`` is supported.
    """
    if len(kwargs) > 1:
        raise TypeError('function takes at most 1 keyword argument')
    order = 'C'
    if kwargs:
        (keyword,) = kwargs
        if keyword != 'order':
            raise TypeError("'{}' is an invalid keyword argument for this function"
                            .format(keyword))
        order = kwargs['order']
    if order != 'C':
        raise NotImplementedError('only supports C-order,'
                                  ' while received {}'.format(order))
    if not args:
        raise TypeError('reshape() takes exactly 1 argument (0 given)')
    single_tuple = len(args) == 1 and isinstance(args[0], tuple)
    newshape = args[0] if single_tuple else args
    return _mx_nd_np.reshape(self, newshape=newshape, order=order)
def reshape_like(self, *args, **kwargs):
    """Fluent ``reshape_like`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute reshape_like'
    raise AttributeError(message)
def reshape_view(self, *shape, **kwargs):  # pylint: disable=redefined-outer-name
    """Return a **view** of this array with a new shape, without copying data.

    Delegates to the inherited ``NDArray.reshape``.
    """
    parent = super(ndarray, self)
    return parent.reshape(*shape, **kwargs)
def zeros_like(self, *args, **kwargs):
    """Fluent ``zeros_like`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute zeros_like'
    raise AttributeError(message)
def ones_like(self, *args, **kwargs):
    """Fluent ``ones_like`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute ones_like'
    raise AttributeError(message)
def broadcast_axes(self, *args, **kwargs):
    """Fluent ``broadcast_axes`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    # The message previously named ``broadcast_like`` instead of this method.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_axes')
def repeat(self, repeats, axis=None):  # pylint: disable=arguments-differ
    """Repeat the elements of this array ``repeats`` times along ``axis``."""
    repeated = repeat(self, repeats=repeats, axis=axis)
    return repeated
def pad(self, *args, **kwargs):
    """Fluent ``pad`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute pad'
    raise AttributeError(message)
def swapaxes(self, axis1, axis2):  # pylint: disable=arguments-differ
    """Return a copy of the array with ``axis1`` and ``axis2`` interchanged.

    See `mxnet.numpy.swapaxes` for full documentation.
    """
    swapped = swapaxes(self, axis1, axis2)
    return swapped
def split(self, *args, **kwargs):
    """Fluent ``split`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute split'
    raise AttributeError(message)
def split_v2(self, *args, **kwargs):
    """Fluent ``split_v2`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute split_v2'
    raise AttributeError(message)
def slice(self, *args, **kwargs):
    """Fluent ``slice`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice'
    raise AttributeError(message)
def slice_axis(self, *args, **kwargs):
    """Fluent ``slice_axis`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice_axis'
    raise AttributeError(message)
def slice_like(self, *args, **kwargs):
    """Fluent ``slice_like`` is not supported by mxnet.numpy.ndarray.

    Always raises AttributeError.
    """
    message = 'mxnet.numpy.ndarray object has no attribute slice_like'
    raise AttributeError(message)
def slice_assign_scalar(self, value, begin, end, step):
    """
    Assign a scalar to a cropped subset of this ndarray, in place.

    ``value`` is broadcast to the shape of the cropped region and cast to
    this ndarray's dtype.

    Parameters
    ----------
    value: numeric value
        The scalar to write into the cropped region.
    begin: tuple of begin indices
    end: tuple of end indices
    step: tuple of step lengths

    Returns
    -------
    This ndarray.

    Examples
    --------
    >>> x = np.ones((2, 2, 2))
    >>> y = x.slice_assign_scalar(0, (0, 0, None), (1, 1, None), (None, None, None))
    >>> y
    array([[[0., 0.],
    [1., 1.]],
    [[1., 1.],
    [1., 1.]]])
    >>> x
    array([[[0., 0.],
    [1., 1.]],
    [[1., 1.],
    [1., 1.]]])
    """
    updated = _npi.slice_assign_scalar(self, value, begin=begin, end=end, step=step, out=self)
    return updated
def slice_assign(self, rhs, begin, end, step):
    """
    Assign ``rhs`` to a cropped subset of this ndarray, in place.

    Parameters
    ----------
    rhs: ndarray.
        Must have the same data type as this NDArray, live on the same device,
        and match the cropped shape of this ndarray.
    begin: tuple of begin indices
    end: tuple of end indices
    step: tuple of step lengths

    Returns
    -------
    out : ndarray
        This ndarray.

    Examples
    --------
    >>> x = np.ones((2, 2, 2))
    >>> assigned = np.zeros((1, 1, 2))
    >>> y = x.slice_assign(assigned, (0, 0, None), (1, 1, None), (None, None, None))
    >>> y
    array([[[0., 0.],
    [1., 1.]],
    [[1., 1.],
    [1., 1.]]])
    >>> x
    array([[[0., 0.],
    [1., 1.]],
    [[1., 1.],
    [1., 1.]]])
    """
    updated = _npi.slice_assign(self, rhs, begin=begin, end=end, step=step, out=self)
    return updated
def take(self, indices, axis=None, mode='raise'): # pylint: disable=arguments-differ, redefined-outer-name
    """Fluent form of :py:func:`take`, using this array as the source.

    `indices` and `axis` are forwarded positionally, and `mode` by keyword.
    """
    taken = take(self, indices, axis, mode=mode)
    return taken
def one_hot(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute one_hot')
def pick(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')
def sort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
    """Fluent form of :py:func:`sort` applied to this array.

    `axis`, `descending` and `stable` are forwarded by keyword.
    """
    options = dict(axis=axis, descending=descending, stable=stable)
    return sort(self, **options)
def topk(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')
def argsort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
    """Fluent form of :py:func:`argsort` applied to this array.

    `axis`, `descending` and `stable` are forwarded by keyword.
    """
    options = dict(axis=axis, descending=descending, stable=stable)
    return argsort(self, **options)
def argmax_channel(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute argmax_channel')
def argmin(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Return indices of the minimum values along the given axis.

    Fluent form of :py:func:`argmin`; refer to `mxnet.numpy.argmin` for full
    documentation.
    """
    return argmin(self, axis, out, keepdims)
def clip(self, min=None, max=None, out=None): # pylint: disable=arguments-differ
    """Limit the values of this array to the interval [min, max].

    At least one of `min` and `max` must be supplied. See :py:func:`clip`.
    """
    lower, upper = min, max
    return clip(self, lower, upper, out=out)
def abs(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute abs')
def sign(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sign')
def flatten(self, order='C'): # pylint: disable=arguments-differ
    """Collapse this array into one dimension.

    Implemented as ``self.reshape(-1, order=order)``.
    """
    whole_length = -1
    return self.reshape(whole_length, order=order)
def shape_array(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute shape_array')
def size_array(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute size_array')
def expand_dims(self, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute expand_dims')
def tile(self, reps): # pylint: disable=arguments-differ
    """Build an array by repeating this one as described by `reps`.

    Fluent form of `mxnet.numpy.tile`; see it for full documentation.
    """
    tiled = tile(self, reps=reps)
    return tiled
def transpose(self, *axes): # pylint: disable=arguments-differ
    """Permute the dimensions of this array.

    The axes may be given in several forms:

    * no arguments, or a single ``None``: the default permutation;
    * a single tuple or list: that sequence is used as the permutation;
    * otherwise: the positional arguments themselves, as a tuple.
    """
    if not axes:
        perm = None
    elif len(axes) == 1 and isinstance(axes[0], (tuple, list)):
        perm = axes[0]
    elif len(axes) == 1 and axes[0] is None:
        perm = None
    else:
        perm = axes
    return transpose(self, axes=perm)
def flip(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute flip')
def depth_to_space(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute depth_to_space')
def space_to_depth(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute space_to_depth')
def diag(self, k=0, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; `k` and any keyword arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute diag')
def diagonal(self, offset=0, axis1=0, axis2=1): # pylint: disable=arguments-differ
    """Return the diagonal of this array at the given `offset`.

    For arrays with more than two dimensions, `axis1` and `axis2` select the
    2-D sub-arrays whose diagonals are taken. See `mxnet.numpy.diagonal` for
    full documentation.
    """
    return diagonal(self, axis2=axis2, axis1=axis1, offset=offset)
def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Sum this array's elements over `axis`; fluent form of :py:func:`sum`."""
    return sum(self, keepdims=keepdims, out=out, dtype=dtype, axis=axis)
def nansum(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute nansum')
def prod(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Multiply this array's elements together over `axis`.

    Dispatches to the ``_mx_np_op.prod`` operator.
    """
    return _mx_np_op.prod(self, out=out, keepdims=keepdims, dtype=dtype, axis=axis)
def nanprod(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute nanprod')
def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Average this array's elements over `axis`; fluent form of :py:func:`mean`."""
    return mean(self, keepdims=keepdims, out=out, dtype=dtype, axis=axis)
# pylint: disable=too-many-arguments, arguments-differ
@wrap_data_api_statical_func
def std(self, axis=None, dtype=None, out=None, correction=0, keepdims=False):
    """Standard deviation of this array's elements over `axis`.

    Fluent form of :py:func:`std`; every argument is forwarded by keyword.
    """
    return std(self, axis=axis, dtype=dtype, out=out, correction=correction, keepdims=keepdims)
@wrap_data_api_statical_func
def var(self, axis=None, dtype=None, out=None, correction=0, keepdims=False):
    """Variance of this array's elements over `axis`.

    Fluent form of :py:func:`var`; every argument is forwarded by keyword.
    """
    return var(self, keepdims=keepdims, correction=correction, out=out, dtype=dtype, axis=axis)
# pylint: enable=too-many-arguments, arguments-differ
def cumsum(self, axis=None, dtype=None, out=None):
    """Running (cumulative) sum of this array's elements along `axis`."""
    return _mx_nd_np.cumsum(self, out=out, dtype=dtype, axis=axis)
def tolist(self):
    """Return this array's contents as (nested) Python lists.

    The array is first copied to a host NumPy array with ``asnumpy()``, and
    that copy's ``tolist()`` produces the result.
    """
    host_copy = self.asnumpy()
    return host_copy.tolist()
def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Largest element(s) of this array along `axis`."""
    return _mx_nd_np.max(self, keepdims=keepdims, out=out, axis=axis)
def min(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ
    """Smallest element(s) of this array along `axis`.

    Fluent form of :py:func:`min` with this array as the data argument.
    """
    return _mx_nd_np.min(self, keepdims=keepdims, out=out, axis=axis)
def norm(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute norm')
def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ
    """Fluent form of :py:func:`round` with this array as the data argument.

    `decimals`, `out` and any extra keyword arguments are forwarded unchanged.
    """
    rounded = round(self, decimals=decimals, out=out, **kwargs)
    return rounded
def rint(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute rint')
def fix(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute fix')
def floor(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute floor')
def ceil(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute ceil')
def trunc(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute trunc')
def sin(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sin')
def cos(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute cos')
def tan(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute tan')
def arcsin(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arcsin')
def arccos(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arccos')
def arctan(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arctan')
def degrees(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute degrees')
def radians(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute radians')
def sinh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sinh')
def cosh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute cosh')
def tanh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute tanh')
def arcsinh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arcsinh')
def arccosh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arccosh')
def arctanh(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute arctanh')
def exp(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute exp')
def expm1(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute expm1')
def log(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log')
def log10(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log10')
def log2(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log2')
def log1p(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log1p')
def log_sigmoid(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log_sigmoid')
def sqrt(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sqrt')
def rsqrt(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute rsqrt')
def cbrt(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    # The message names this method ('cbrt'); it previously said 'cqrt'.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute cbrt')
def rcbrt(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    # The message names this method ('rcbrt'); it previously said 'rcqrt'.
    raise AttributeError('mxnet.numpy.ndarray object has no attribute rcbrt')
def square(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute square')
def reciprocal(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute reciprocal')
def relu(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute relu')
def sigmoid(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sigmoid')
def softmax(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute softmax')
def log_softmax(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log_softmax')
def softmin(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin')
def mish(self, *args, **kwargs):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; any arguments are ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute mish')
def squeeze(self, axis=None): # pylint: disable=arguments-differ
    """Drop length-one dimensions from this array; fluent form of :py:func:`squeeze`."""
    squeezed = squeeze(self, axis=axis)
    return squeezed
def broadcast_to(self, shape): # pylint: disable=redefined-outer-name
    """Broadcast this array to `shape`.

    Delegates to ``_mx_nd_np.broadcast_to`` with this array as the input.
    """
    target_shape = shape
    return _mx_nd_np.broadcast_to(self, target_shape)
def broadcast_like(self, other):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; `other` is ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like')
def _full(self, value):
    """
    Fill this array in place with `value` (internal helper for __setitem__).

    Calls ``_mx_nd_np.full`` with this array's own shape, device and dtype, and
    with ``out=self``, so the storage is overwritten rather than reallocated.
    """
    target_shape, target_device, target_dtype = self.shape, self.device, self.dtype
    return _mx_nd_np.full(target_shape, value, device=target_device, dtype=target_dtype, out=self)
# pylint: disable=redefined-outer-name
def _scatter_set_nd(self, value_nd, indices):
    """
    Scatter `value_nd` into this array at `indices`, writing in place.

    Exists as a method so that NDArray and numpy.ndarray indexing can share
    the same polymorphic code path.
    """
    scatter = _npi.scatter_set_nd
    return scatter(lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self)
# pylint: enable=redefined-outer-name
@property
def shape(self):
    """Tuple of array dimensions.

    The shape is queried from the backend. The 64-bit shape API is used when
    ``_int64_enabled()`` is true. Returns ``None`` when the backend reports an
    unknown number of dimensions (``-1``).

    Examples
    --------
    >>> x = mx.np.array([1, 2, 3, 4])
    >>> x.shape
    (4,)
    >>> y = mx.np.zeros((2, 3, 4))
    >>> y.shape
    (2, 3, 4)
    >>> z = mx.np.array(3)
    >>> z.shape
    ()
    """
    num_dim = mx_int()
    if _int64_enabled():
        pdata = ctypes.POINTER(mx_int64)()
        get_shape = _LIB.MXNDArrayGetShape64
    else:
        pdata = ctypes.POINTER(mx_int)()
        get_shape = _LIB.MXNDArrayGetShape
    check_call(get_shape(self.handle, ctypes.byref(num_dim), ctypes.byref(pdata)))
    if num_dim.value == -1:
        return None
    return tuple(pdata[:num_dim.value]) # pylint: disable=invalid-slice-index
@property
def ndim(self):
    """Number of dimensions, i.e. the length of ``self.shape``."""
    dims = self.shape
    return len(dims)
@property
def size(self):
    """Total number of elements, as reported by the parent class's ``size``."""
    parent_view = super(ndarray, self)
    return parent_view.size
@property
def dtype(self):
    """Element data type of this array, wrapped as a ``numpy.dtype``.

    Returns
    -------
    numpy.dtype
        The parent class's ``dtype`` converted with ``numpy.dtype(...)``.

    Examples
    --------
    >>> x = np.zeros((2,3))
    >>> x.dtype
    dtype('float32')
    >>> y = np.zeros((2,3), dtype='int32')
    >>> y.dtype
    dtype('int32')
    """
    raw_dtype = super(ndarray, self).dtype
    return _np.dtype(raw_dtype)
def tostype(self, stype):
    """Not supported on ``mxnet.numpy.ndarray``.

    Calling this method always raises; `stype` is ignored.

    Raises
    ------
    AttributeError
        Always.
    """
    raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def empty(shape, dtype=None, order='C', device=None): # pylint: disable=redefined-outer-name
    """Return a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the empty array, e.g., ``(2, 3)`` or ``2``. A single integer,
        including a NumPy integer scalar, is treated as a 1-D shape.
    dtype : data-type, optional
        Desired output data-type for the array, e.g., `numpy.int8`.
        ``None`` and the builtin ``float`` both select the default floating
        point type:
        when npx.is_np_default_dtype() returns False, default dtype is float32;
        when npx.is_np_default_dtype() returns True, default dtype is float64.
        Note that this differs from NumPy's `empty` function, where `float64`
        is always the default. `float32` is the common default in deep learning.
    order : {'C'}, optional, default: 'C'
        How to store multi-dimensional data in memory. Currently only row-major
        (C-style) is supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of uninitialized (arbitrary) data of the given shape, dtype, and order.

    Raises
    ------
    NotImplementedError
        If `order` is not ``'C'``.

    Examples
    --------
    >>> np.empty([2, 2])
    array([[ 0.000000e+00, -2.524355e-29],
           [ nan, -8.592023e+09]]) # uninitialized
    >>> np.empty([2, 2], dtype=int)
    array([[8751743591039004782, 3196766424264760104],
           [7583328881310196768, 562950123910254]], dtype=int64) # uninitialized
    """
    if order != 'C':
        raise NotImplementedError('`empty` only supports order equal to `C`, while received {}'
                                  .format(str(order)))
    if device is None:
        device = current_device()
    if dtype is None or dtype is float:
        dtype = _np.float64 if is_np_default_dtype() else _np.float32
    # Accept NumPy integer scalars (e.g. np.int64(3)) as well as Python ints.
    # Previously they were not wrapped into a tuple.
    if isinstance(shape, (int, _np.integer)):
        shape = (shape,)
    return ndarray(handle=_new_alloc_handle(shape, device, False, dtype))
# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def array(object, dtype=None, device=None):
    """
    Create an array.

    Parameters
    ----------
    object : array_like or `numpy.ndarray` or `mxnet.numpy.ndarray`
        An array, any object exposing the array interface, an object whose
        __array__ method returns an array, or any (nested) sequence.
    dtype : data-type, optional
        The desired data-type for the array. When omitted:

        * for an `mxnet.numpy.ndarray` input, ``object.dtype`` is used;
        * for a `numpy.ndarray` input, float32 is used, unless
          ``npx.is_np_default_dtype()`` returns True, in which case
          ``object.dtype`` is kept;
        * otherwise ``object.dtype`` is used if the object has one, else the
          default float type (float32, or float64 when
          ``npx.is_np_default_dtype()`` returns True).

        The defaults can be made consistent with official NumPy via
        `npx.set_np(dtype=True)`.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        An array object satisfying the specified requirements.

    Raises
    ------
    ValueError
        If `object` is a legacy `mx.nd.NDArray`; use ``as_np_ndarray`` instead.
    TypeError
        If NumPy cannot convert `object`. The message is NumPy's error text.

    Examples
    --------
    >>> np.array([1, 2, 3])
    array([1., 2., 3.])
    >>> np.array([[1, 2], [3, 4]])
    array([[1., 2.],
           [3., 4.]])
    >>> np.array([[1, 0], [0, 1]], dtype=bool)
    array([[ True, False],
           [False, True]])
    >>> np.array([1, 2, 3]).dtype
    dtype('float32')
    >>> npx.set_np(dtype=True)
    >>> np.array([1, 2, 3]).dtype
    dtype('float64')
    """
    if device is None:
        device = current_device()
    if isinstance(object, _np.ndarray):
        if is_np_default_dtype():
            dtype = object.dtype if dtype is None else dtype
        elif dtype is None:
            # Outside numpy-default-dtype mode, NumPy inputs default to float32
            # whatever their own dtype; an explicit dtype is always respected.
            dtype = _np.float32
    # NOTE(review): ndarray is checked before NDArray, presumably because
    # ndarray subclasses NDArray — confirm before reordering.
    if isinstance(object, ndarray):
        dtype = object.dtype if dtype is None else dtype
    elif isinstance(object, NDArray):
        raise ValueError("If you're trying to create a mxnet.numpy.ndarray "
                         "from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.")
    else:
        if dtype is None:
            default_dtype = _np.float64 if is_np_default_dtype() else _np.float32
            dtype = object.dtype if hasattr(object, "dtype") else default_dtype
        try:
            object = _np.array(object, dtype=dtype)
        except Exception as e:
            # printing out the error raised by official NumPy's array function
            # for transparency on users' side
            raise TypeError('{}'.format(str(e)))
    ret = empty(object.shape, dtype=dtype, device=device)
    # A 0-d target cannot be assigned through ``[:]``; use the empty-tuple index.
    if len(object.shape) == 0:
        ret[()] = object
    else:
        ret[:] = object
    return ret
# pylint: enable=redefined-outer-name
@set_module('mxnet.numpy')
def shape(a):
    """
    Get the shape of an array.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    shape : tuple of ints
        One entry per dimension, giving that dimension's length.

    See Also
    --------
    ndarray.shape : Equivalent array method.

    Examples
    --------
    >>> np.shape(np.eye(3))
    (3, 3)
    >>> np.shape([[1, 2]])
    (1, 2)
    >>> np.shape([0])
    (1,)
    >>> np.shape(0)
    ()
    """
    dims = _mx_nd_np.shape(a)
    return dims
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def zeros(shape, dtype=None, order='C', device=None): # pylint: disable=redefined-outer-name
    """Return a new array of given shape and type, filled with zeros.

    This function currently only supports storing multi-dimensional data
    in row-major (C-style).

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the new array.
    dtype : str or numpy.dtype, optional
        An optional value type.
        When npx.is_np_default_dtype() returns False, default dtype is float32;
        when npx.is_np_default_dtype() returns True, default dtype is float64.
        Note that this differs from NumPy's `zeros` function, where `float64`
        is always the default. Here either 'float32' or 'float64' can be the
        default, because `float32` is the common default in deep learning.
    order : {'C'}, optional, default: 'C'
        How to store multi-dimensional data in memory. Currently only row-major
        (C-style) is supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of zeros with the given shape, dtype, and device.

    Examples
    --------
    >>> np.zeros(5)
    array([0., 0., 0., 0., 0.])
    >>> np.zeros((5,), dtype=int)
    array([0, 0, 0, 0, 0], dtype=int64)
    >>> np.zeros((2, 1))
    array([[0.],
           [0.]])
    """
    return _mx_nd_np.zeros(shape, dtype, order, device)
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def ones(shape, dtype=None, order='C', device=None): # pylint: disable=redefined-outer-name
    """Return a new array of given shape and type, filled with ones.

    This function currently only supports storing multi-dimensional data
    in row-major (C-style).

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the new array.
    dtype : str or numpy.dtype, optional
        An optional value type. The default depends on the current default dtype:
        when npx.is_np_default_dtype() returns False, default dtype is float32;
        when npx.is_np_default_dtype() returns True, default dtype is float64.
        Note that this differs from NumPy's `ones` function, where
        `float64` is always the default.
    order : {'C'}, optional, default: 'C'
        How to store multi-dimensional data in memory. Currently only row-major
        (C-style) is supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        Array of ones with the given shape, dtype, and device.

    Examples
    --------
    >>> np.ones(5)
    array([1., 1., 1., 1., 1.])
    >>> np.ones((5,), dtype=int)
    array([1, 1, 1, 1, 1], dtype=int64)
    >>> np.ones((2, 1))
    array([[1.],
           [1.]])
    >>> s = (2,2)
    >>> np.ones(s)
    array([[1., 1.],
           [1., 1.]])
    """
    return _mx_nd_np.ones(shape, dtype, order, device)
@set_module('mxnet.numpy')
def broadcast_to(array, shape): # pylint: disable=redefined-outer-name
    """
    Broadcast `array` to a new `shape`.

    Parameters
    ----------
    array : ndarray or scalar
        Value to broadcast.
    shape : tuple
        Desired shape of the result.

    Returns
    -------
    broadcast : array
        A readonly view on the original array with the given shape. It is
        typically not contiguous, and several of its elements may refer to
        the same memory location.

    Raises
    ------
    MXNetError
        If `array` cannot be broadcast to `shape` under NumPy's broadcasting rules.
    """
    result = _mx_nd_np.broadcast_to(array, shape)
    return result
# pylint: disable=too-many-arguments, redefined-outer-name
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def full(shape, fill_value, dtype=None, order='C', device=None, out=None):
    r"""Return a new array of the given shape and type, filled with `fill_value`.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new array, e.g., ``(2, 3)`` or ``2``.
    fill_value : scalar or ndarray
        Fill value.
    dtype : data-type, optional
        If ``None``, the output dtype is inferred from `fill_value`:
        an int gives the default integer dtype, a float gives the default
        floating-point dtype, and a bool gives a boolean dtype. Default: None.
    order : {'C'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory. Currently only C order is supported.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.
    out : ndarray or None, optional
        A location into which the result is stored.
        If provided, it must have the same shape and dtype as input ndarray.
        If not provided or `None`, a freshly-allocated array is returned.

    Returns
    -------
    out : ndarray
        Array of `fill_value` with the given shape, dtype, and order.
        If `fill_value` is an ndarray, out will have the same device as
        `fill_value`, regardless of the provided `device`.

    .. note::
       This function differs from the original numpy.full in the following way(s):

       * Has an additional `device` argument to specify the device
       * Has an additional `out` argument
       * Currently does not support `order` selection

    See Also
    --------
    empty : Return a new uninitialized array.
    ones : Return a new array setting values to one.
    zeros : Return a new array setting values to zero.

    Examples
    --------
    >>> np.full((2, 2), 10)
    array([[10., 10.],
           [10., 10.]])
    >>> np.full((2, 2), 2, dtype=np.int32, device=mx.cpu(0))
    array([[2, 2],
           [2, 2]], dtype=int32)
    """
    return _mx_nd_np.full(shape, fill_value, dtype=dtype, order=order, device=device, out=out)
# pylint: enable=too-many-arguments, redefined-outer-name
# pylint: disable=redefined-outer-name, too-many-arguments
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def empty_like(prototype, dtype=None, device=None, order='C', subok=False, shape=None): # pylint: disable=W0621
    """
    Return a new array with the same shape and type as a given array.

    Parameters
    ----------
    prototype : ndarray
        The shape and data-type of `prototype` define these same attributes
        of the returned array.
    dtype : data-type, optional
        Overrides the data type of the result.
    device : Device, optional
        Device on which the result should live. When given, the array created
        from `prototype` is moved there with ``to_device``. When ``None``, no
        move is made.
    order : {'C'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory. Currently only C order is supported.
    subok : {False}, optional
        If True, then the newly created array will use the sub-class
        type of 'a', otherwise it will be a base-class array. Defaults
        to False.
        (Only support False at this moment)
    shape : int or sequence of ints, optional.
        Overrides the shape of the result. If order='K' and the number of
        dimensions is unchanged, will try to keep order, otherwise,
        order='C' is implied.
        (Not supported at this moment)

    Returns
    -------
    out : ndarray
        Array of uninitialized (arbitrary) data with the same
        shape and type as `prototype`.

    See Also
    --------
    ones_like : Return an array of ones with shape and type of input.
    zeros_like : Return an array of zeros with shape and type of input.
    full_like : Return a new array with shape of input filled with value.
    empty : Return a new uninitialized array.

    Notes
    -----
    This function does *not* initialize the returned array; to do that use
    `zeros_like` or `ones_like` instead. It may be marginally faster than
    the functions that do set the array values.

    Examples
    --------
    >>> a = np.array([[1,2,3], [4,5,6]])
    >>> np.empty_like(a)
    array([[-5764607523034234880, -2305834244544065442, 4563075075], # uninitialized
           [ 4567052944, -5764607523034234880, 844424930131968]])
    >>> a = np.array([[1., 2., 3.],[4.,5.,6.]])
    >>> np.empty_like(a)
    array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized
           [2.0e-323, 2.5e-323, 3.0e-323]])
    """
    ret = _mx_nd_np.empty_like(prototype, dtype=dtype, order=order, subok=subok, shape=shape)
    if device is not None:
        # ``to_device`` returns the array on the target device (Array-API
        # semantics) rather than relocating ``ret`` in place. The result must be
        # kept; previously it was discarded, so `device` had no effect.
        ret = ret.to_device(device)
    return ret
# pylint: enable=redefined-outer-name
# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
def all(a, axis=None, out=None, keepdims=False):
    """
    Check whether every array element along the given axis evaluates to True.

    Parameters
    ----------
    a : ndarray
        Input array, or an object that can be converted to one.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the logical AND reduction runs. With the
        default ``axis=None`` the reduction spans every dimension of `a`.
    out : ndarray, optional
        Alternate array that receives the result. Its shape must match the
        expected output, and its type is preserved.
    keepdims : bool, optional
        When True, each reduced axis stays in the result as a dimension of
        size one, so the result broadcasts correctly against `a`.

    Returns
    -------
    all : ndarray, bool
        A new boolean array, unless `out` is given, in which case a
        reference to `out` is returned.

    Examples
    --------
    >>> np.all([[True,False],[True,True]])
    False

    >>> np.all([[True,False],[True,True]], axis=0)
    array([ True, False])

    >>> np.all([-1, 4, 5])
    True

    >>> np.all([1.0, np.nan])
    True

    >>> o=np.array(False)
    >>> z=np.all([-1, 4, 5], out=o)
    >>> id(z), id(o), z
    (28293632, 28293632, array(True)) # may vary
    """
    reduced = _mx_nd_np.all(a, keepdims=keepdims, axis=axis, out=out)
    return reduced
@set_module('mxnet.numpy')
def any(a, axis=None, out=None, keepdims=False):
    """
    Test whether any array element along a given axis evaluates to True.

    Returns single boolean unless axis is not None.

    Parameters
    ----------
    a : ndarray
        Input array or object that can be converted to an array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which a logical OR reduction is performed.
        The default (axis = None) is to perform a logical OR over
        all the dimensions of the input array.
    out : ndarray, optional
        Alternate output array in which to place the result. It must have
        the same shape as the expected output and its type is preserved.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in
        the result as dimensions with size one. With this option,
        the result will broadcast correctly against the input array.

    Returns
    -------
    any : bool or ndarray
        A new boolean or ndarray is returned unless out is specified,
        in which case a reference to out is returned.

    Examples
    --------
    >>> np.any([[True, False], [True, True]])
    True

    >>> np.any([[True, False], [False, False]], axis=0)
    array([ True, False])

    >>> np.any([-1, 0, 5])
    True

    >>> np.any(np.nan)
    True

    >>> o=np.array(False)
    >>> z=np.any([-1, 4, 5], out=o)
    >>> z, o
    (array(True), array(True))
    >>> # Check now that z is a reference to o
    >>> z is o
    True
    >>> id(z), id(o) # identity of z and o # doctest: +SKIP
    (191614240, 191614240)
    """
    return _mx_nd_np.any(a, axis=axis, out=out, keepdims=keepdims)
@set_module('mxnet.numpy')
@wrap_ctx_to_device_func
def identity(n, dtype=None, device=None):
    """
    Return the identity array.

    The identity array is a square array with ones on
    the main diagonal.

    Parameters
    ----------
    n : int
        Number of rows (and columns) in `n` x `n` output.
    dtype : data-type, optional
        Data-type of the output.
        When npx.is_np_default_dtype() returns False, default dtype is float32;
        When npx.is_np_default_dtype() returns True, default dtype is float64.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.

    Returns
    -------
    out : ndarray
        `n` x `n` array with its main diagonal set to one,
        and all other elements 0.

    Examples
    --------
    >>> np.identity(3)
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])
    """
    return _mx_nd_np.identity(n, dtype, device)
# pylint: enable=redefined-outer-name
# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
def take(a, indices, axis=None, mode='raise', out=None):
    r"""
    Take elements from an array along an axis.

    When axis is not None, this function does the same thing as "fancy"
    indexing (indexing arrays using arrays); however, it can be easier to use
    if you need elements along a given axis. A call such as
    ``np.take(arr, indices, axis=3)`` is equivalent to
    ``arr[:,:,:,indices,...]``.

    Explained without fancy indexing, this is equivalent to the following use
    of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
    indices::

        Ni, Nk = a.shape[:axis], a.shape[axis+1:]
        Nj = indices.shape
        for ii in ndindex(Ni):
            for jj in ndindex(Nj):
                for kk in ndindex(Nk):
                    out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

    Parameters
    ----------
    a : ndarray
        The source array.
    indices : ndarray
        The indices of the values to extract. Also allow scalars for indices.
    axis : int, optional
        The axis over which to select values. By default, the flattened
        input array is used.
    mode : {'raise', 'clip', 'wrap'}, optional
        Specifies how out-of-bounds indices will behave. Default is 'raise'.

        * 'raise' -- raise an error if an index is out of bounds (default)
        * 'clip' -- clip to the range
        * 'wrap' -- wrap around

        'clip' mode means that all indices that are too large are replaced
        by the index that addresses the last element along that axis. Note
        that this disables indexing with negative numbers.
    out : ndarray, optional
        If provided, the result will be placed in this array. It should
        be of the appropriate shape and dtype.

    Returns
    -------
    out : ndarray
        The returned array has the same type as `a`.

    .. note::

       This function differs from the original `numpy.take
       <https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
       the following way(s):

       * Only ndarray or scalar ndarray is accepted as valid input.

    Examples
    --------
    >>> a = np.array([4, 3, 5, 7, 6, 8])
    >>> indices = np.array([0, 1, 4])
    >>> np.take(a, indices)
    array([4., 3., 6.])

    In this example for `a` is an ndarray, "fancy" indexing can be used.

    >>> a[indices]
    array([4., 3., 6.])

    If `indices` is not one dimensional, the output also has these dimensions.

    >>> np.take(a, np.array([[0, 1], [2, 3]]))
    array([[4., 3.],
           [5., 7.]])
    """
    return _mx_nd_np.take(a, indices, axis, mode, out)
# pylint: enable=redefined-outer-name
@set_module('mxnet.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
    """
    Compute the sorted unique elements of an array.

    Besides the unique elements themselves, up to three optional outputs
    can be requested:

    * the input indices at which each unique value first appears
    * the unique-array indices that rebuild the input array
    * how many times each unique value occurs in the input array

    Parameters
    ----------
    ar : ndarray
        Input array. It is flattened when not already 1-D, unless `axis`
        is given.
    return_index : bool, optional
        When True, additionally return the indices of `ar` (along `axis` if
        given, otherwise in the flattened array) that produce the unique array.
    return_inverse : bool, optional
        When True, additionally return the indices of the unique array (along
        `axis` if given) that can be used to reconstruct `ar`.
    return_counts : bool, optional
        When True, additionally return how many times each unique item
        appears in `ar`.
    axis : int or None, optional
        Axis to operate on. None flattens `ar`. An integer makes the
        subarrays indexed by that axis be flattened and treated as the
        elements of a 1-D array along that axis; see the notes for details.
        Defaults to None.

    Returns
    -------
    unique : ndarray
        The sorted unique values.
    unique_indices : ndarray, optional
        Indices of the first occurrences of the unique values in the
        original array. Only provided if `return_index` is True.
    unique_inverse : ndarray, optional
        Indices that rebuild the original array from the unique array.
        Only provided if `return_inverse` is True.
    unique_counts : ndarray, optional
        Number of occurrences of each unique value in the original array.
        Only provided if `return_counts` is True.

    .. note::

       When an axis is specified the subarrays indexed by the axis are sorted.
       This is done by making the specified axis the first dimension of the array
       and then flattening the subarrays in C order. The flattened subarrays are
       then viewed as a structured type with each element given a label, with the
       effect that we end up with a 1-D array of structured types that can be
       treated in the same way as any other 1-D array. The result is that the
       flattened subarrays are sorted in lexicographic order starting with the
       first element.

       This function differs from the original `numpy.unique
       <https://docs.scipy.org/doc/numpy/reference/generated/numpy.unique.html>`_ in
       the following aspects:

       * Only support ndarray as input.
       * Object arrays or structured arrays are not supported.

    Examples
    --------
    >>> np.unique(np.array([1, 1, 2, 2, 3, 3]))
    array([1., 2., 3.])
    >>> a = np.array([[1, 1], [2, 3]])
    >>> np.unique(a)
    array([1., 2., 3.])

    Return the unique rows of a 2D array

    >>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
    >>> np.unique(a, axis=0)
    array([[1., 0., 0.],
           [2., 3., 4.]])

    Return the indices of the original array that give the unique values:

    >>> a = np.array([1, 2, 6, 4, 2, 3, 2])
    >>> u, indices = np.unique(a, return_index=True)
    >>> u
    array([1., 2., 3., 4., 6.])
    >>> indices
    array([0, 1, 5, 3, 2], dtype=int64)
    >>> a[indices]
    array([1., 2., 3., 4., 6.])

    Reconstruct the input array from the unique values:

    >>> a = np.array([1, 2, 6, 4, 2, 3, 2])
    >>> u, indices = np.unique(a, return_inverse=True)
    >>> u
    array([1., 2., 3., 4., 6.])
    >>> indices
    array([0, 1, 4, 3, 1, 2, 1], dtype=int64)
    >>> u[indices]
    array([1., 2., 6., 4., 2., 3., 2.])
    """
    outputs = _mx_nd_np.unique(ar, return_index, return_inverse, return_counts, axis)
    return outputs
@set_module('mxnet.numpy')
@wrap_np_binary_func
def add(x1, x2, out=None, **kwargs):
    """
    Element-wise addition of the arguments.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        Operands to add. When their shapes differ they must broadcast to a
        common shape (possibly the shape of one of them).
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    add : ndarray or scalar
        The element-wise sum of x1 and x2; a scalar when both inputs are scalars.

    .. note::

       This operator now supports automatic type promotion. The output type
       is chosen as follows:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.add(1.0, 4.0)
    5.0
    >>>
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.add(x1, x2)
    array([[ 0.,  2.,  4.],
           [ 3.,  5.,  7.],
           [ 6.,  8., 10.]])
    """
    summed = _mx_nd_np.add(x1, x2, out)
    return summed
@set_module('mxnet.numpy')
@wrap_np_binary_func
def subtract(x1, x2, out=None, **kwargs):
    r"""Element-wise subtraction of the arguments.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        Operands to subtract from one another. When their shapes differ they
        must broadcast to a common shape (possibly the shape of one of them).
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    subtract : ndarray or scalar
        The element-wise difference of x1 and x2; a scalar when both inputs
        are scalars.

    .. note::

       This operator now supports automatic type promotion. The output type
       is chosen as follows:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.subtract(1.0, 4.0)
    -3.0
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.subtract(x1, x2)
    array([[0., 0., 0.],
           [3., 3., 3.],
           [6., 6., 6.]])
    """
    difference = _mx_nd_np.subtract(x1, x2, out)
    return difference
@set_module('mxnet.numpy')
@wrap_np_binary_func
def multiply(x1, x2, out=None, **kwargs):
    """
    Multiply arguments element-wise.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The arrays to be multiplied. If x1.shape != x2.shape, they must be broadcastable to
        a common shape (which may be the shape of one or the other).
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        The product of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), not supported yet.

    Examples
    --------
    >>> np.multiply(2.0, 4.0)
    8.0
    >>> x1 = np.arange(9.0).reshape((3, 3))
    >>> x2 = np.arange(3.0)
    >>> np.multiply(x1, x2)
    array([[ 0.,  1.,  4.],
           [ 0.,  4., 10.],
           [ 0.,  7., 16.]])
    """
    return _mx_nd_np.multiply(x1, x2, out)
@set_module('mxnet.numpy')
@wrap_np_binary_func
def divide(x1, x2, out=None, **kwargs):
    """Returns a true division of the inputs, element-wise.

    Parameters
    ----------
    x1 : ndarray or scalar
        Dividend array.
    x2 : ndarray or scalar
        Divisor array.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types including boolean, the output is of float32 or
         float64 type, which depends on your current default dtype:

         * When ``npx.is_np_default_dtype()`` returns False, default dtype is float32.
         * When ``npx.is_np_default_dtype()`` returns True, default dtype is float64.

    Examples
    --------
    >>> x = np.arange(5)
    >>> np.divide(x, 4)
    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
    """
    return _mx_nd_np.divide(x1, x2, out=out)
@set_module('mxnet.numpy')
def true_divide(x1, x2, out=None):
    """Element-wise true division of the inputs.

    Unlike Python's traditional 'floor division', this computes a true
    division, whose output type is adjusted to best represent the answer
    regardless of the input types.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    out : ndarray or scalar
        A scalar when both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The output type
       is chosen as follows:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), the output is of float32 or
         float64 type, which depends on your current default dtype.
         When npx.is_np_default_dtype() returns False, default dtype is float32;
         When npx.is_np_default_dtype() returns True, default dtype is float64.

    Examples
    --------
    >>> x = np.arange(5)
    >>> np.true_divide(x, 4)
    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
    """
    quotient = _mx_nd_np.true_divide(x1, x2, out=out)
    return quotient
@set_module('mxnet.numpy')
@wrap_np_binary_func
def floor_divide(x1, x2, out=None):
    """Return the largest integer less than or equal to the division of the inputs.

    It is equivalent to the Python ``//`` operator and pairs with the Python
    ``%`` (`remainder`) function, so that ``a = a % b + b * (a // b)`` up to
    roundoff.

    Parameters
    ----------
    x1 : ndarray or scalar
        Dividend array.
    x2 : ndarray or scalar
        Divisor array.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        This is a scalar if both x1 and x2 are scalars.

    .. note::

       This operator now supports automatic type promotion. The resulting type will be determined
       according to the following rules:

       * If both inputs are of floating number types, the output is the more precise type.
       * If only one of the inputs is floating number type, the result is that type.
       * If both inputs are of integer types (including boolean), the output is the more
         precise type.

    Examples
    --------
    >>> np.floor_divide(7,3)
    2
    >>> np.floor_divide([1., 2., 3., 4.], 2.5)
    array([ 0.,  0.,  1.,  1.])
    """
    return _mx_nd_np.floor_divide(x1, x2, out=out)
@set_module('mxnet.numpy')
@wrap_np_binary_func
def mod(x1, x2, out=None, **kwargs):
    """
    Compute the element-wise remainder of a division.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    out : ndarray or scalar
        A scalar when both x1 and x2 are scalars.

    Examples
    --------
    >>> np.mod(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    leftover = _mx_nd_np.mod(x1, x2, out=out)
    return leftover
@set_module('mxnet.numpy')
@wrap_np_binary_func
def fmod(x1, x2, out=None, **kwargs):
    """
    Compute the element-wise remainder of a division (``fmod`` flavour).

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    out : ndarray or scalar
        A scalar when both x1 and x2 are scalars.

    Examples
    --------
    >>> np.fmod(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    leftover = _mx_nd_np.fmod(x1, x2, out=out)
    return leftover
@set_module('mxnet.numpy')
@wrap_np_binary_func
def matmul(a, b, out=None, **kwargs):
    r"""Matrix product of two arrays.

    Parameters
    ----------
    a, b : ndarray
        Input arrays, scalars not allowed.
    out : ndarray, optional
        A location into which the result is stored.
        If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m).
        If not provided or None, a freshly-allocated array is returned.

    Returns
    -------
    y : ndarray
        The matrix product of the inputs.
        This is a scalar only when both a, b are 1-d vectors.

    Raises
    ------
    MXNetError
        If the last dimension of a is not the same size as the second-to-last dimension of b.
        If a scalar value is passed in.

    See Also
    --------
    tensordot : Sum products over arbitrary axes.
    dot : alternative matrix product with different broadcasting rules.
    einsum : Einstein summation convention.

    .. note::

       The behavior depends on the arguments in the following way.

       * If both arguments are ``2-D`` they are multiplied like conventional matrices.
       * If either argument is ``N-D``, ``N > 2``, it is treated as a stack of matrices
         residing in the last two indexes and broadcast accordingly.
       * If the first argument is ``1-D``, it is promoted to a matrix by prepending
         a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
       * If the second argument is ``1-D``, it is promoted to a matrix by appending a 1
         to its dimensions. After matrix multiplication the appended 1 is removed.

       matmul differs from dot in two important ways:

       * Multiplication by scalars is not allowed, use multiply instead.
       * Stacks of matrices are broadcast together as if the matrices were elements,
         respecting the signature ``(n,k),(k,m)->(n,m)``:

       >>> a = np.ones([9, 5, 7, 4])
       >>> c = np.ones([9, 5, 4, 3])
       >>> np.dot(a, c).shape
       (9, 5, 7, 9, 5, 3)
       >>> np.matmul(a, c).shape
       (9, 5, 7, 3)
       >>> # n is 7, k is 4, m is 3

    Examples
    --------
    For 2-D arrays it is the matrix product:

    >>> a = np.array([[1, 0],
    ...               [0, 1]])
    >>> b = np.array([[4, 1],
    ...               [2, 2]])
    >>> np.matmul(a, b)
    array([[4., 1.],
           [2., 2.]])

    For 2-D mixed with 1-D, the result is the usual.

    >>> a = np.array([[1, 0],
    ...               [0, 1]])
    >>> b = np.array([1, 2])
    >>> np.matmul(a, b)
    array([1., 2.])
    >>> np.matmul(b, a)
    array([1., 2.])

    Broadcasting is conventional for stacks of arrays

    >>> a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
    >>> b = np.arange(2 * 2 * 4).reshape((2, 4, 2))
    >>> np.matmul(a, b).shape
    (2, 2, 2)
    >>> np.matmul(a, b)[0, 1, 1]
    array(98.)
    >>> sum(a[0, 1, :] * b[0, :, 1])
    array(98.)

    Scalar multiplication raises an error.

    >>> np.matmul([1, 2], 3)
    Traceback (most recent call last):
    ...
    mxnet.base.MXNetError: ... : Multiplication by scalars is not allowed.
    """
    return _mx_nd_np.matmul(a, b, out=out)
@set_module('mxnet.numpy')
@wrap_np_binary_func
def remainder(x1, x2, out=None, **kwargs):
    """
    Compute the element-wise remainder of a division.

    Parameters
    ----------
    x1 : ndarray or scalar
        The dividend.
    x2 : ndarray or scalar
        The divisor.
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    out : ndarray or scalar
        A scalar when both x1 and x2 are scalars.

    Examples
    --------
    >>> np.remainder(np.arange(7), 5)
    array([0., 1., 2., 3., 4., 0., 1.])
    """
    leftover = _mx_nd_np.remainder(x1, x2, out=out)
    return leftover
@set_module('mxnet.numpy')
@wrap_np_binary_func
def power(x1, x2, out=None, **kwargs):
    """
    Raise elements of the first array to the powers in the second array, element-wise.

    Parameters
    ----------
    x1 : ndarray or scalar
        The bases.
    x2 : ndarray or scalar
        The exponents.
    out : ndarray
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    out : ndarray or scalar
        Each base in x1 raised to the matching exponent in x2; a scalar when
        both x1 and x2 are scalars.

    Examples
    --------
    >>> x1 = np.arange(6)
    >>> np.power(x1, 3)
    array([  0.,   1.,   8.,  27.,  64., 125.])

    Raise the bases to different exponents.

    >>> x2 = np.array([1.0, 2.0, 3.0, 3.0, 2.0, 1.0])
    >>> np.power(x1, x2)
    array([ 0.,  1.,  8., 27., 16.,  5.])

    The effect of broadcasting.

    >>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]])
    >>> x2
    array([[1., 2., 3., 3., 2., 1.],
           [1., 2., 3., 3., 2., 1.]])
    >>> np.power(x1, x2)
    array([[ 0.,  1.,  8., 27., 16.,  5.],
           [ 0.,  1.,  8., 27., 16.,  5.]])
    """
    raised = _mx_nd_np.power(x1, x2, out=out)
    return raised
# ``pow`` is the Python Array API standard name for ``power``. It is bound to
# the very same function object (so ``np.pow is np.power``), which means the
# ``__doc__`` assignment below also replaces the docstring shown for ``power``;
# the text is written to describe both names.
pow = power
pow.__doc__ = """
    First array elements raised to powers from second array, element-wise.

    Notes
    -----
    `pow` is an alias for `power`. It is a standard API in
    https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#pow-x1-x2
    instead of an official NumPy operator.

    >>> np.pow is np.power
    True

    Parameters
    ----------
    x1 : ndarray or scalar
        The bases.
    x2 : ndarray or scalar
        The exponent.
    out : ndarray
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    out : ndarray or scalar
        The bases in x1 raised to the exponents in x2.
        This is a scalar if both x1 and x2 are scalars.

    Examples
    --------
    >>> x1 = np.arange(6)
    >>> np.pow(x1, 3)
    array([  0.,   1.,   8.,  27.,  64., 125.])

    Raise the bases to different exponents.

    >>> x2 = np.array([1.0, 2.0, 3.0, 3.0, 2.0, 1.0])
    >>> np.pow(x1, x2)
    array([ 0.,  1.,  8., 27., 16.,  5.])

    The effect of broadcasting.

    >>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]])
    >>> x2
    array([[1., 2., 3., 3., 2., 1.],
           [1., 2., 3., 3., 2., 1.]])
    >>> np.pow(x1, x2)
    array([[ 0.,  1.,  8., 27., 16.,  5.],
           [ 0.,  1.,  8., 27., 16.,  5.]])
    """
@set_module('mxnet.numpy')
@wrap_np_binary_func
def gcd(x1, x2, out=None, **kwargs):
    """
    Returns the greatest common divisor of ``|x1|`` and ``|x2|``.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        The arrays for computing greatest common divisor. If x1.shape != x2.shape,
        they must be broadcastable to a common shape (which may be the shape of
        one or the other).
    out : ndarray or None, optional
        A location into which the result is stored. If provided, it must have a shape
        that the inputs broadcast to. If not provided or None, a freshly-allocated array
        is returned.

    Returns
    -------
    y : ndarray or scalar
        The greatest common divisor of the absolute value of the inputs.
        This is a scalar if both `x1` and `x2` are scalars.

    See Also
    --------
    lcm : The lowest common multiple

    Examples
    --------
    >>> np.gcd(12, 20)
    4
    >>> np.gcd(np.arange(6, dtype=int), 20)
    array([20,  1,  2,  1,  4,  5], dtype=int64)
    """
    return _mx_nd_np.gcd(x1, x2, out=out)
@set_module('mxnet.numpy')
@wrap_np_binary_func
def lcm(x1, x2, out=None, **kwargs):
    """
    Compute the lowest common multiple of ``|x1|`` and ``|x2|``.

    Parameters
    ----------
    x1, x2 : ndarrays or scalar values
        Operands whose lowest common multiple is computed. When their shapes
        differ they must broadcast to a common shape (possibly the shape of
        one of them).
    out : ndarray or None, optional
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned.

    Returns
    -------
    y : ndarray or scalar
        The lowest common multiple of the absolute values of the inputs; a
        scalar when both `x1` and `x2` are scalars.

    See Also
    --------
    gcd : The greatest common divisor

    Examples
    --------
    >>> np.lcm(12, 20)
    60
    >>> np.lcm(np.arange(6, dtype=int), 20)
    array([ 0, 20, 20, 60, 20, 20], dtype=int64)
    """
    multiple = _mx_nd_np.lcm(x1, x2, out=out)
    return multiple
@set_module('mxnet.numpy')
@wrap_np_unary_func
def sin(x, out=None, **kwargs):
    r"""
    Element-wise trigonometric sine.

    Parameters
    ----------
    x : ndarray or scalar
        Angle in radians (:math:`2 \pi` rad equals 360 degrees).
    out : ndarray or None
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned. For ndarray input the output dtype
        matches the input dtype.

    Returns
    -------
    y : ndarray or scalar
        Sine of each element of x; a scalar when `x` is a scalar.

    Notes
    -----
    Only floating-point input is supported.

    Examples
    --------
    >>> np.sin(np.pi/2.)
    1.0
    >>> np.sin(np.array((0., 30., 45., 60., 90.)) * np.pi / 180.)
    array([0.        , 0.5       , 0.70710677, 0.86602545, 1.        ])
    """
    sine = _mx_nd_np.sin(x, out=out, **kwargs)
    return sine
@set_module('mxnet.numpy')
@wrap_np_unary_func
def cos(x, out=None, **kwargs):
    r"""
    Element-wise cosine.

    Parameters
    ----------
    x : ndarray or scalar
        Angle in radians (:math:`2 \pi` rad equals 360 degrees).
    out : ndarray or None
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned. For ndarray input the output dtype
        matches the input dtype.

    Returns
    -------
    y : ndarray or scalar
        Cosine of each element of x; a scalar when x is a scalar.

    Notes
    -----
    Only floating-point input is supported.

    Examples
    --------
    >>> np.cos(np.array([0, np.pi/2, np.pi]))
    array([ 1.000000e+00, -4.371139e-08, -1.000000e+00])
    >>> # Example of providing the optional output parameter
    >>> out1 = np.array([0], dtype='f')
    >>> out2 = np.cos(np.array([0.1]), out1)
    >>> out2 is out1
    True
    """
    cosine = _mx_nd_np.cos(x, out=out, **kwargs)
    return cosine
@set_module('mxnet.numpy')
@wrap_np_unary_func
def sinh(x, out=None, **kwargs):
    """
    Element-wise hyperbolic sine.

    Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``.

    Parameters
    ----------
    x : ndarray or scalar
        Input array or scalar.
    out : ndarray or None
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned. For ndarray input the output dtype
        matches the input dtype.

    Returns
    -------
    y : ndarray or scalar
        Hyperbolic sine of each element; a scalar when `x` is a scalar.

    Notes
    -----
    Only floating-point input is supported.

    Examples
    --------
    >>> np.sinh(0)
    0.0
    >>> # Example of providing the optional output parameter
    >>> out1 = np.array([0], dtype='f')
    >>> out2 = np.sinh(np.array([0.1]), out1)
    >>> out2 is out1
    True
    """
    hyperbolic = _mx_nd_np.sinh(x, out=out, **kwargs)
    return hyperbolic
@set_module('mxnet.numpy')
@wrap_np_unary_func
def cosh(x, out=None, **kwargs):
    """
    Element-wise hyperbolic cosine.

    Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``.

    Parameters
    ----------
    x : ndarray or scalar
        Input array or scalar.
    out : ndarray or None
        Optional destination for the result. When given, its shape must be
        the one the inputs broadcast to; when omitted or None, a newly
        allocated array is returned. For ndarray input the output dtype
        matches the input dtype.

    Returns
    -------
    y : ndarray or scalar
        Hyperbolic cosine of each element; a scalar when `x` is a scalar.

    Notes
    -----
    Only floating-point input is supported.

    Examples
    --------
    >>> np.cosh(0)
    1.0
    """
    hyperbolic = _mx_nd_np.cosh(x, out=out, **kwargs)
    return hyperbolic
@set_module('mxnet.numpy')
@wrap_np_unary_func
def tanh(x, out=None, **kwargs):
"""
Compute hyperbolic tangent element-wise.
Equivalent to ``np.sinh(x)/np.cosh(x)``.
Parameters
----------
x : ndarray or scalar.
Input array.
out : ndarray or None
A location into which the result is stored. If provided, it
must have a shape that the inputs fill into. If not provided
or None, a freshly-allocated array is returned. The dtype of the
output and input must be the same.
Returns
----------
y : ndarray or scalar
The corresponding hyperbolic tangent values.
.. note::
If `out` is provided, the function writes the result into it,
and returns a reference to `out`. (See Examples)
* input x does not support complex computation (like imaginary number)
>>> np.tanh(np.pi*1j)
TypeError: type <type 'complex'> not supported
Examples
--------