blob: e55fa1af90e84e5d2244c7354a8eaf55ed548e0a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import numpy as np
from distutils.version import LooseVersion
import os
import pickle as pkl
import unittest
from nose.tools import raises
from common import setup_module, with_seed, assertRaises, TemporaryDirectory, teardown
from mxnet.test_utils import almost_equal
from mxnet.test_utils import assert_almost_equal, assert_exception
from mxnet.test_utils import default_context
from mxnet.test_utils import np_reduce
from mxnet.test_utils import same
from mxnet.test_utils import random_sample, rand_shape_nd
from numpy.testing import assert_allclose
import mxnet.autograd
def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list=(np.float32,)):
    """Check that a function gives consistent results on NDArray and NumPy inputs.

    For every dtype in `type_list`, random inputs are drawn uniformly from
    [rmin, 10), cast to that dtype, and fed to `uf` as NDArrays. The result is
    compared with the reference computed on the NumPy copies of the same data.

    Parameters
    ----------
    uf : callable
        Function under test; called with one NDArray per entry of `arg_shapes`.
    arg_shapes : int or sequence of shapes
        Either the explicit input shapes, or an int giving the number of inputs,
        in which case one random `dim`-dimensional shape is shared by all inputs.
    dim : int, optional
        Number of dimensions of the random shape; required when `arg_shapes` is an int.
    npuf : callable, optional
        NumPy reference function. When omitted, `uf` itself is applied to the
        NumPy arrays.
    rmin : number
        Lower bound of the uniform sampling range.
    type_list : iterable of dtypes
        Dtypes to test. The default is an immutable tuple so that it can never
        be shared-and-mutated between calls.
    """
    if isinstance(arg_shapes, int):
        assert dim, "dim must be given when arg_shapes is an int"
        # Cap each side length so the total element count stays around 1000.
        shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim))
        arg_shapes = [shape] * arg_shapes
    reference = uf if npuf is None else npuf
    for dtype in type_list:
        ndarray_arg = []
        numpy_arg = []
        for s in arg_shapes:
            npy = np.random.uniform(rmin, 10, s).astype(dtype)
            ndarray_arg.append(mx.nd.array(npy, dtype=dtype))
            numpy_arg.append(npy)
        out1 = uf(*ndarray_arg)
        out2 = reference(*numpy_arg).astype(dtype)

        assert out1.shape == out2.shape
        if isinstance(out1, mx.nd.NDArray):
            out1 = out1.asnumpy()
        if dtype == np.float16:
            # Half precision gets an explicit, looser relative tolerance.
            assert_almost_equal(out1, out2, rtol=2e-3)
        else:
            assert_almost_equal(out1, out2)
def random_ndarray(dim):
    """Return an NDArray with `dim` random-length axes holding uniform values in [-10, 10).

    Each axis length is drawn from [1, int(1000**(1/dim))).
    """
    side_limit = int(1000**(1.0/dim))
    extents = np.random.randint(1, side_limit, size=dim)
    values = np.random.uniform(-10, 10, tuple(extents))
    return mx.nd.array(values)
@with_seed()
def test_ndarray_setitem():
    """Check NDArray item assignment against the same assignments done in NumPy."""
    shape = (3, 4, 2)

    # Whole-array assignment from a scalar, an NDArray and a NumPy array, in that order.
    for fill in (1, mx.nd.ones(shape), np.ones(shape)):
        target = mx.nd.zeros(shape)
        target[:] = fill
        expected = np.ones(shape, dtype=target.dtype)
        assert same(target.asnumpy(), expected)

    # Indexing sub-arrays on the first axis: positive index, then negative index,
    # applied cumulatively to the same array.
    target = mx.nd.zeros(shape)
    expected = np.zeros(shape, dtype=target.dtype)
    for row in (1, -1):
        target[row] = 1
        expected[row] = 1
        assert same(target.asnumpy(), expected)

    # Index covering all dimensions, mixing slices with an integer, NDArray value.
    target = mx.nd.zeros(shape)
    val = mx.nd.ones((3, 2))
    expected = np.zeros(shape, dtype=target.dtype)
    for last in (1, -1):
        target[:, 1:3, last] = val
        expected[:, 1:3, last] = val.asnumpy()
        assert same(target.asnumpy(), expected)

    # Index made of slices only, scalar value; positive bounds first, then negative.
    target = mx.nd.zeros(shape)
    target[:, 1:3, 1:2] = 1
    expected = np.zeros(shape, dtype=target.dtype)
    expected[:, 1:3, 1:2] = 1
    assert same(target.asnumpy(), expected)
    target[:, -3:-1, -2:-1] = 1
    expected[:, -3:-1, -2:-1] = 1
    assert same(target.asnumpy(), expected)
@with_seed(0)
def test_ndarray_elementwise():
    """Compare element-wise NDArray arithmetic and unary ops with NumPy on random data."""
    nrepeat = 10
    maxdim = 4
    all_type = [np.float32, np.float64, np.float16, np.uint8, np.int8, np.int32, np.int64]
    real_type = [np.float32, np.float64, np.float16]
    # Ops that are checked for every dtype over the default sampling range.
    any_type_ops = (lambda x, y: x + y,
                    lambda x, y: x - y,
                    lambda x, y: x * y)
    divide = lambda x, y: x / y
    for _ in range(nrepeat):
        for dim in range(1, maxdim):
            for binary_op in any_type_ops:
                check_with_uniform(binary_op, 2, dim, type_list=all_type)
            # Division over the full range only for floating types; the all-type
            # run samples from [1, 10) instead.
            check_with_uniform(divide, 2, dim, type_list=real_type)
            check_with_uniform(divide, 2, dim, rmin=1, type_list=all_type)
            # Unary ops, sampled from [0, 10).
            check_with_uniform(mx.nd.sqrt, 1, dim, np.sqrt, rmin=0)
            check_with_uniform(mx.nd.square, 1, dim, np.square, rmin=0)
            check_with_uniform(lambda x: mx.nd.norm(x).asscalar(), 1, dim, np.linalg.norm)
@with_seed()
def test_ndarray_elementwisesum():
    """ElementWiseSum of 1x, 2x, 4x and 8x an int32 ones vector must equal 15x that vector."""
    base = mx.nd.ones((10,), dtype=np.int32)
    terms = [base, base*2, base*4, base*8]
    total = mx.nd.ElementWiseSum(*terms)
    assert same(total.asnumpy(), base.asnumpy()*15)
@with_seed()
def test_ndarray_negate():
    """Unary minus must return the negated values and leave the operand untouched."""
    reference = np.random.uniform(-10, 10, (2,3,4))
    nd_values = mx.nd.array(reference)
    assert_almost_equal(reference, nd_values.asnumpy())
    negated = -nd_values
    assert_almost_equal(-reference, negated.asnumpy())
    # Negation must not be implemented in place: the operand still has to
    # hold the original values after (-nd_values) was computed.
    assert_almost_equal(reference, nd_values.asnumpy())
@with_seed()
def test_ndarray_reshape():
    """Exercise NDArray.reshape with tuple and varargs shapes, including special codes."""
    tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
    true_res = mx.nd.arange(30) + 1

    def check(actual, expected):
        assert same(actual.asnumpy(), expected.asnumpy())

    # Shape passed as a single tuple.
    check(tensor.reshape((-1,)), true_res)
    check(tensor.reshape((2, -1)), true_res.reshape(2, 15))
    check(tensor.reshape((0, -1)), true_res.reshape(2, 15))
    check(tensor.reshape((-1, 2)), true_res.reshape(15, 2))
    # Shape passed as positional integers.
    check(tensor.reshape(6, 5), true_res.reshape(6, 5))
    check(tensor.reshape(-1, 2), true_res.reshape(15, 2))
    check(tensor.reshape(-1), true_res)
    check(tensor.reshape(30), true_res)
    check(tensor.reshape(0, -1), true_res.reshape(2, 15))
    check(tensor.reshape(-1, 6), true_res.reshape(5, 6))
    # Special codes -2, -3, -4 and the `reverse` flag.
    check(tensor.reshape(-2,), true_res.reshape(2, 3, 5))
    check(tensor.reshape(-3, -1), true_res.reshape(6, 5))
    check(tensor.reshape(-1, 15).reshape(0, -4, 3, -1), true_res.reshape(2, 3, 5))
    check(tensor.reshape(-1, 0), true_res.reshape(10, 3))
    check(tensor.reshape(-1, 0, reverse=True), true_res.reshape(6, 5))
@with_seed()
def test_ndarray_choose():
    """choose_element_0index must return, for each row, the entry at the given column."""
    shape = (100, 20)
    npy = np.arange(np.prod(shape)).reshape(shape)
    arr = mx.nd.array(npy)
    rows = np.arange(shape[0])
    for _ in range(3):
        indices = np.random.randint(shape[1], size=shape[0])
        expected = npy[rows, indices]
        chosen = mx.nd.choose_element_0index(arr, mx.nd.array(indices))
        assert same(expected, chosen.asnumpy())
@with_seed()
def test_ndarray_fill():
    """fill_element_0index must write one value per row at the given column."""
    shape = (100, 20)
    npy = np.arange(np.prod(shape)).reshape(shape)
    arr = mx.nd.array(npy)
    rows = np.arange(shape[0])
    for _ in range(3):
        indices = np.random.randint(shape[1], size=shape[0])
        val = np.random.randint(shape[1], size=shape[0])
        # Build the expectation from a fresh copy of the untouched source data.
        expected = npy.copy()
        expected[rows, indices] = val
        filled = mx.nd.fill_element_0index(arr, mx.nd.array(val), mx.nd.array(indices))
        assert same(expected, filled.asnumpy())
@with_seed()
def test_ndarray_onehot():
    """onehot_encode must write a one-hot row per index into the `out` array."""
    shape = (100, 20)
    npy = np.arange(np.prod(shape)).reshape(shape)
    arr = mx.nd.array(npy)
    rows = np.arange(shape[0])
    for _ in range(3):
        indices = np.random.randint(shape[1], size=shape[0])
        # Reset the reference to all zeros, then mark the selected column of each row.
        npy.fill(0.0)
        npy[rows, indices] = 1.0
        mx.nd.onehot_encode(mx.nd.array(indices), out=arr)
        assert same(npy, arr.asnumpy())
@with_seed()
def test_ndarray_copy():
    """copyto() onto cpu(0) must yield an element-wise identical array."""
    source = mx.nd.array(np.random.uniform(-10, 10, (10, 10)))
    duplicate = source.copyto(mx.Context('cpu', 0))
    mismatches = np.abs(source.asnumpy() != duplicate.asnumpy())
    assert np.sum(mismatches) == 0.0
@with_seed()
def test_ndarray_scalar():
    """Check arithmetic between NDArrays and Python scalars on 10x10 arrays.

    Every check compares the sum over all 100 elements with its expected
    value. The comparisons use the absolute difference: a one-sided
    `value - target < 1e-5` check would also accept any sum far below the
    target.
    """
    c = mx.nd.empty((10,10))
    d = mx.nd.empty((10,10))
    c[:] = 0.5
    d[:] = 1.0
    # d = 1 - 0.5 * 2 / 3 * 6 = -1 for every element.
    d -= c * 2 / 3 * 6.0
    # c = 0.5 + 0.5 = 1 for every element.
    c += 0.5
    assert abs(np.sum(c.asnumpy()) - 100) < 1e-5
    assert abs(np.sum(d.asnumpy()) + 100) < 1e-5
    c[:] = 2
    assert abs(np.sum(c.asnumpy()) - 200) < 1e-5
    # -2 + 2 = 0 for every element.
    d = -c + 2
    assert abs(np.sum(d.asnumpy())) < 1e-5
@with_seed(0)
def test_ndarray_pickle():
    """An NDArray must survive a pickle dumps/loads round trip unchanged."""
    maxdim = 5
    nrepeat = 10
    for _ in range(nrepeat):
        for dim in range(1, maxdim):
            lhs = random_ndarray(dim)
            rhs = mx.nd.empty(lhs.shape)
            lhs[:] = np.random.uniform(-10, 10, lhs.shape)
            rhs[:] = np.random.uniform(-10, 10, lhs.shape)
            # Pickle the result of an operation rather than a freshly created array.
            total = lhs + rhs
            restored = pkl.loads(pkl.dumps(total))
            assert np.sum(total.asnumpy() != restored.asnumpy()) == 0
@with_seed()
def test_ndarray_saveload():
    """Round-trip NDArrays through mx.nd.save / mx.nd.load as list, dict and single array.

    The file lives in a temporary directory so that nothing is written to the
    current working directory and the file is cleaned up even when an
    assertion fails part-way through.
    """
    nrepeat = 10
    with TemporaryDirectory(prefix='test_ndarray_saveload_') as tmpdir:
        fname = os.path.join(tmpdir, 'tmp_list.bin')
        for repeat in range(nrepeat):
            # test save/load as list
            data = [random_ndarray(np.random.randint(1, 5)) for _ in range(10)]
            mx.nd.save(fname, data)
            data2 = mx.nd.load(fname)
            assert len(data) == len(data2)
            for x, y in zip(data, data2):
                assert np.sum(x.asnumpy() != y.asnumpy()) == 0
            # test save/load as dict
            dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)}
            mx.nd.save(fname, dmap)
            dmap2 = mx.nd.load(fname)
            assert len(dmap2) == len(dmap)
            for k, x in dmap.items():
                y = dmap2[k]
                assert np.sum(x.asnumpy() != y.asnumpy()) == 0
            # test save/load as ndarray
            # we expect the single ndarray to be converted into a list containing the ndarray
            single_ndarray = data[0]
            mx.nd.save(fname, single_ndarray)
            single_ndarray_loaded = mx.nd.load(fname)
            assert len(single_ndarray_loaded) == 1
            single_ndarray_loaded = single_ndarray_loaded[0]
            assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0
@with_seed()
def test_ndarray_legacy_load():
    """Loading the checked-in 'legacy_ndarray.v0' file must yield six arange(128) arrays."""
    expected = [mx.nd.arange(128) for _ in range(6)]
    # The legacy file sits next to this test module.
    here = os.path.dirname(os.path.realpath(__file__))
    legacy_data = mx.nd.load(os.path.join(here, 'legacy_ndarray.v0'))
    assert len(expected) == len(legacy_data)
    for position, reference in enumerate(expected):
        assert same(reference.asnumpy(), legacy_data[position].asnumpy())
@with_seed()
def test_buffer_load():
    """Check mx.nd.load_frombuffer on saved lists, dicts and single arrays.

    For each container kind the saved file is read back as bytes and loaded
    from that buffer; a buffer with its last 10 bytes cut off must raise
    MXNetError.
    """
    nrepeat = 10

    def assert_identical(expected, loaded):
        assert np.sum(expected.asnumpy() != loaded.asnumpy()) == 0

    with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:

        def save_and_read(basename, payload):
            # Save `payload` under tmpdir and hand back the raw file content.
            fname = os.path.join(tmpdir, basename)
            mx.nd.save(fname, payload)
            with open(fname, 'rb') as dfile:
                return dfile.read()

        for repeat in range(nrepeat):
            # test load_buffer as list
            data = [random_ndarray(np.random.randint(1, 5)) for _ in range(10)]
            buf_data = save_and_read('list_{0}.param'.format(repeat), data)
            data2 = mx.nd.load_frombuffer(buf_data)
            assert len(data) == len(data2)
            for x, y in zip(data, data2):
                assert_identical(x, y)
            # test garbage values
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_data[:-10])

            # test load_buffer as dict
            dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)}
            buf_dmap = save_and_read('dict_{0}.param'.format(repeat), dmap)
            dmap2 = mx.nd.load_frombuffer(buf_dmap)
            assert len(dmap2) == len(dmap)
            for k, x in dmap.items():
                assert_identical(x, dmap2[k])
            # test garbage values
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_dmap[:-10])

            # we expect the single ndarray to be converted into a list containing the ndarray
            single_ndarray = data[0]
            buf_single_ndarray = save_and_read('single_{0}.param'.format(repeat), single_ndarray)
            loaded = mx.nd.load_frombuffer(buf_single_ndarray)
            assert len(loaded) == 1
            assert_identical(single_ndarray, loaded[0])
            # test garbage values
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single_ndarray[:-10])
@with_seed()
def test_ndarray_slice():
    """Compare NDArray slicing and integer/array indexing with NumPy."""
    # 1-D: read a slice, then write a slice back.
    nd_vec = mx.nd.array(np.random.uniform(-10, 10, (10,)))
    np_vec = nd_vec.asnumpy()
    assert same(nd_vec[3:8].asnumpy(), np_vec[3:8])
    np_vec[3:8] *= 10
    nd_vec[3:8] = np_vec[3:8]
    assert same(nd_vec[3:8].asnumpy(), np_vec[3:8])

    # 5-D: mixed integer/slice index, then full integer indices (positive and negative).
    nd_cube = mx.nd.random.uniform(shape=(3,4,5,6,7))
    np_cube = nd_cube.asnumpy()
    assert same(nd_cube[1,3:4,:,1:5].asnumpy(), np_cube[1,3:4,:,1:5])
    assert nd_cube[1,2,3,4,5].asscalar() == np_cube[1,2,3,4,5]
    assert nd_cube[-1,-2,-3,-4,-5].asscalar() == np_cube[-1,-2,-3,-4,-5]

    # Index arrays given as Python lists and as NDArrays.
    grid = mx.nd.array([[0, 1], [2, 3]])
    picked_by_list = grid[[1, 1, 0], [0, 1, 0]]
    assert (picked_by_list.asnumpy() == [2, 3, 0]).all()
    picked_by_nd = grid[mx.nd.array([1, 1, 0]), mx.nd.array([0, 1, 0])]
    assert (picked_by_nd.asnumpy() == [2, 3, 0]).all()

    # 2-D: negative indices for single elements, columns and rows.
    nd_mat = mx.nd.random.uniform(shape=(4, 4))
    np_mat = nd_mat.asnumpy()
    for neg in range(-4, 0):
        assert nd_mat[neg, neg].asscalar() == np_mat[neg, neg]
        assert same(nd_mat[:, neg].asnumpy(), np_mat[:, neg])
        assert same(nd_mat[neg, :].asnumpy(), np_mat[neg, :])
@with_seed()
def test_ndarray_crop():
    """Check crop, _crop_assign and _crop_assign_scalar on a (2, 3, 4) array of ones."""
    begin = (0, 0, 0)
    end = (2, 1, 3)

    # get crop
    source = mx.nd.ones((2, 3, 4))
    cropped = mx.nd.crop(source, begin=begin, end=end)
    assert same(cropped.asnumpy(), np.ones((2, 1, 3), dtype=cropped.dtype))

    # crop assign: overwrite the cropped region in place with zeros
    patch = mx.nd.zeros((2, 1, 3))
    mx.nd._internal._crop_assign(source, patch, begin=begin, end=end, out=source)
    expected = np.ones(source.shape, dtype=source.dtype)
    expected[0:2, 0:1, 0:3] = 0
    assert same(source.asnumpy(), expected)

    # crop assign with scalar
    source = mx.nd.ones((2, 3, 4))
    mx.nd._internal._crop_assign_scalar(source, scalar=5, begin=begin, end=end, out=source)
    expected = np.ones(source.shape, dtype=source.dtype)
    expected[0:2, 0:1, 0:3] = 5
    assert same(source.asnumpy(), expected)
@with_seed()
def test_ndarray_concatenate():
    """mx.nd.concatenate along axis 1 must match np.concatenate on the same float32 data."""
    axis = 1
    # The shapes differ only along the concatenation axis.
    shapes = [(2, 3, 4, 2), (2, 2, 4, 2), (2, 1, 4, 2)]
    np_parts = []
    nd_parts = []
    for shape in shapes:
        np_parts.append(np.random.uniform(-10, 10, shape).astype(np.float32))
    for part in np_parts:
        nd_parts.append(mx.nd.array(part))
    joined_nd = mx.nd.concatenate(nd_parts, axis=axis)
    joined_np = np.concatenate(np_parts, axis=axis)
    assert same(joined_np, joined_nd.asnumpy())
@with_seed()
def test_clip():
    """Every element of mx.nd.clip(A, -2, 2) must lie within [-2, 2]."""
    shape = (10,)
    samples = mx.random.uniform(-10, 10, shape)
    clipped = mx.nd.clip(samples, -2, 2).asnumpy()
    for value in clipped[:shape[0]]:
        assert value >= -2
        assert value <= 2
@with_seed()
def test_dot():
    """Compare mx.nd.dot with np.dot for every combination of the transpose flags."""
    # Non-zero atol required, as exposed by seed 828791701
    atol = 1e-5
    # (shape of a, shape of b, keyword arguments passed to mx.nd.dot)
    cases = [
        ((3, 4), (4, 5), {}),                                          # normal dot
        ((3, 4), (3, 5), {'transpose_a': True}),
        ((3, 4), (5, 4), {'transpose_b': True}),
        ((4, 3), (5, 4), {'transpose_a': True, 'transpose_b': True}),
    ]
    for shape_a, shape_b, dot_kwargs in cases:
        a = np.random.uniform(-3, 3, shape_a)
        b = np.random.uniform(-3, 3, shape_b)
        # Apply the same transposes on the NumPy side to get the reference.
        lhs = a.T if dot_kwargs.get('transpose_a') else a
        rhs = b.T if dot_kwargs.get('transpose_b') else b
        expected = np.dot(lhs, rhs)
        result = mx.nd.dot(mx.nd.array(a), mx.nd.array(b), **dot_kwargs)
        assert_almost_equal(expected, result.asnumpy(), atol=atol)
@with_seed()
def test_reduce():
    """Compare NDArray reductions (sum/max/min/argmax/argmin) with NumPy on random data."""
    sample_num = 200

    def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
        for _ in range(sample_num):
            ndim = np.random.randint(1, 6)
            shape = np.random.randint(1, 11, size=ndim)
            dat = np.random.rand(*shape) - 0.5
            keepdims = np.random.randint(0, 2)
            if multi_axes:
                # Pick a random subset of axes; an empty subset means "all axes".
                axis_flags = np.random.randint(0, 2, size=ndim)
                axes = tuple(axis for axis, flag in enumerate(axis_flags) if flag)
                if not axes:
                    axes = tuple(range(ndim))
            else:
                axes = np.random.randint(0, ndim)
            numpy_ret = numpy_reduce_func(dat, axis=axes, keepdims=keepdims)
            ndarray_ret = nd_reduce_func(mx.nd.array(dat), axis=axes, keepdims=keepdims)
            if type(ndarray_ret) is mx.ndarray.NDArray:
                ndarray_ret = ndarray_ret.asnumpy()
            same_shape = ndarray_ret.shape == numpy_ret.shape
            # A full reduction may come back as shape (1,) where NumPy gives ().
            scalar_result = ndarray_ret.shape == (1,) and numpy_ret.shape == ()
            assert same_shape or scalar_result, "nd:%s, numpy:%s" \
                %(ndarray_ret.shape, numpy_ret.shape)
            err = np.square(ndarray_ret - numpy_ret).mean()
            assert err < 1E-4

    def numpy_reducer(np_func, as_float32=False):
        # Wrap np_reduce so it accepts the axis/keepdims keywords used above.
        def reducer(data, axis, keepdims):
            if as_float32:
                data = np.float32(data)
            return np_reduce(data, axis, keepdims, np_func)
        return reducer

    test_reduce_inner(numpy_reducer(np.sum), mx.nd.sum, True)
    test_reduce_inner(numpy_reducer(np.max), mx.nd.max, True)
    test_reduce_inner(numpy_reducer(np.min), mx.nd.min, True)
    # argmax and argmin are sensitive to the precision of the calculation (repro seed 1985162693).
    # Force numpy to match mxnet's float32.
    test_reduce_inner(numpy_reducer(np.argmax, as_float32=True), mx.nd.argmax, False)
    test_reduce_inner(numpy_reducer(np.argmin, as_float32=True), mx.nd.argmin, False)
@with_seed()
def test_broadcast():
    """Check broadcast_to and broadcast_like against NumPy's implicit broadcasting."""
    sample_num = 1000

    def random_case():
        # Draw a target shape, then collapse a random subset of its axes to 1
        # to obtain the shape of the array that gets broadcast.
        ndim = np.random.randint(1, 6)
        target_shape = np.random.randint(1, 11, size=ndim)
        shape = target_shape.copy()
        axis_flags = np.random.randint(0, 2, size=ndim)
        shape[axis_flags != 0] = 1
        dat = np.random.rand(*shape) - 0.5
        return dat, target_shape

    def verify(ndarray_ret, dat, target_shape):
        if type(ndarray_ret) is mx.ndarray.NDArray:
            ndarray_ret = ndarray_ret.asnumpy()
        assert (ndarray_ret.shape == target_shape).all()
        # Subtracting the un-broadcast data relies on NumPy broadcasting it back.
        err = np.square(ndarray_ret - dat).mean()
        assert err < 1E-8

    def test_broadcast_to():
        for _ in range(sample_num):
            dat, target_shape = random_case()
            verify(mx.nd.array(dat).broadcast_to(shape=target_shape), dat, target_shape)

    def test_broadcast_like():
        for _ in range(sample_num):
            dat, target_shape = random_case()
            target = mx.nd.ones(shape=tuple(target_shape))
            verify(mx.nd.array(dat).broadcast_like(target), dat, target_shape)

    test_broadcast_to()
    test_broadcast_like()
@with_seed()
def test_broadcast_binary():
    """Check broadcasting binary operators on NDArrays against the same ops in NumPy."""
    N = 100

    def check_broadcast_binary(fn):
        for _ in range(N):
            ndim = np.random.randint(1, 6)
            oshape = np.random.randint(1, 6, size=(ndim,))
            # The right operand only keeps the trailing `bdim` axes.
            bdim = np.random.randint(1, ndim+1)
            lshape = list(oshape)
            rshape = list(oshape[ndim-bdim:])
            # Walk the shared trailing axes from the last one backwards and
            # randomly collapse the axis on the left side, the right side, or neither.
            for offset in range(1, bdim+1):
                sep = np.random.uniform(0, 1)
                if sep < 0.33:
                    lshape[-offset] = 1
                elif sep < 0.66:
                    rshape[-offset] = 1
            lhs = np.random.normal(0, 1, size=lshape)
            rhs = np.random.normal(0, 1, size=rshape)
            expected = fn(lhs, rhs)
            actual = fn(mx.nd.array(lhs), mx.nd.array(rhs)).asnumpy()
            assert_allclose(expected, actual, rtol=1e-4, atol=1e-4)

    def on_float32(compare):
        # Cast both operands to float32 before comparing them.
        return lambda x, y: compare(x.astype(np.float32), y.astype(np.float32))

    arithmetic_ops = (lambda x, y: x + y,
                      lambda x, y: x - y,
                      lambda x, y: x * y,
                      lambda x, y: x / y)
    for op in arithmetic_ops:
        check_broadcast_binary(op)

    # The following ops are sensitive to the precision of the calculation.
    # Force numpy to match mxnet's float32.
    comparison_ops = (lambda x, y: x > y,
                      lambda x, y: x < y,
                      lambda x, y: x >= y,
                      lambda x, y: x <= y,
                      lambda x, y: x == y)
    for op in comparison_ops:
        check_broadcast_binary(on_float32(op))
@with_seed()
def test_moveaxis():
    """Check mx.nd.moveaxis on a (2, 2, 3) tensor.

    The first check moves axis 0 to the last position. The destination is
    given as 2, the last valid axis of a 3-D tensor, rather than the
    out-of-range value 3: the expected (2, 3, 2) result below is exactly the
    "axis 0 moved to position 2" layout.
    NOTE(review): the moveaxis implementation is not visible from this file;
    confirm it treats destination 2 as the last axis.
    """
    X = mx.nd.array([[[1, 2, 3], [4, 5, 6]],
                     [[7, 8, 9], [10, 11, 12]]])
    res = mx.nd.moveaxis(X, 0, 2).asnumpy()
    true_res = mx.nd.array([[[ 1., 7.],
                             [ 2., 8.],
                             [ 3., 9.]],
                            [[ 4., 10.],
                             [ 5., 11.],
                             [ 6., 12.]]])
    assert same(res, true_res.asnumpy())
    # Moving the last axis to the front only needs a shape check.
    assert mx.nd.moveaxis(X, 2, 0).shape == (3, 2, 2)
@with_seed()
def test_arange():
    """Compare mx.nd.arange (including `repeat`) with np.arange."""
    for _ in range(5):
        start = np.random.rand() * 10
        stop = start + np.random.rand() * 100
        step = np.random.rand() * 4
        repeat = int(np.random.rand() * 5) + 1
        # Reference: every element of the plain range repeated `repeat` times in a row.
        expected = np.repeat(np.arange(start=start, stop=stop, step=step), repeat)
        actual = mx.nd.arange(start=start, stop=stop, step=step, repeat=repeat).asnumpy()
        assert_almost_equal(actual, expected)
    # Large int32 range with a large step.
    expected = np.arange(start=0, stop=10000**2, step=10001, dtype=np.int32)
    actual = mx.nd.arange(start=0, stop=10000**2, step=10001,
                          dtype="int32").asnumpy()
    assert_almost_equal(actual, expected)
@with_seed()
def test_order():
    """Compare mx.nd.topk / sort / argsort with NumPy-based references."""
    ctx = default_context()
    dat_size = 5
    total_size = dat_size*dat_size*dat_size*dat_size

    def gt_topk(dat, axis, ret_typ, k, is_ascend):
        """NumPy reference for topk with the given return type."""
        # Positions to take from a sorted sequence: the first k entries for
        # ascending order, the last k entries walked backwards for descending.
        if is_ascend:
            indices = np.arange(k)
        else:
            indices = np.arange(-1, -k-1, -1)
        if ret_typ == "indices":
            return np.take(dat.argsort(axis=axis), axis=axis, indices=indices, mode='wrap')
        if ret_typ == "value":
            return np.take(np.sort(dat, axis=axis), axis=axis, indices=indices, mode='wrap')
        # Any other ret_typ is handled as a mask: 1 at the selected positions, 0 elsewhere.
        assert dat.shape == (dat_size, dat_size, dat_size, dat_size)
        assert axis is None or axis ==1
        ret = np.zeros(dat.shape)
        gt_argsort = np.take(dat.argsort(axis=axis), axis=axis, indices=indices, mode='wrap')
        if axis is None:
            ret.ravel()[gt_argsort] = 1
        else:
            for i in range(dat_size):
                for j in range(dat_size):
                    for m in range(dat_size):
                        ret[i, gt_argsort[i, :, j, m], j, m] = 1
        return ret

    # Produce input data for the tests, including ensuring unique values if desired.
    # Numpy's argsort does not consistently return lowest-index-first for matching
    # values, making it hard to generate a numpy 'golden copy' to compare against
    # the mxnet operator. The 'mask' function is particularly hard to test given that
    # equal values might span the 'k' boundary. Issue exposed with seed 1405838964.
    def get_values(ensure_unique):
        while True:
            data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)))
            if not ensure_unique or data.size == len(set(data.flatten())):
                return data

    def check_topk(nd_data, np_data, axis, ret_typ, k, is_ascend):
        nd_ret_topk = mx.nd.topk(nd_data, axis=axis, ret_typ=ret_typ, k=k,
                                 is_ascend=is_ascend).asnumpy()
        gt = gt_topk(np_data, axis=axis, ret_typ=ret_typ, k=k, is_ascend=is_ascend)
        assert_almost_equal(nd_ret_topk, gt)

    def check_full_order(nd_op, ret_typ, nd_data, np_data, axis, is_ascend):
        # sort/argsort are compared with a top-k reference that keeps every element.
        k = total_size if axis is None else dat_size
        nd_ret = nd_op(nd_data, axis=axis, is_ascend=is_ascend).asnumpy()
        gt = gt_topk(np_data, axis=axis, ret_typ=ret_typ, k=k, is_ascend=is_ascend)
        assert_almost_equal(nd_ret, gt)

    a_npy = get_values(ensure_unique=True)
    a_nd = mx.nd.array(a_npy, ctx=ctx)

    # test for ret_typ=indices, then ret_typ=value
    for ret_typ in ("indices", "value"):
        check_topk(a_nd, a_npy, axis=1, ret_typ=ret_typ, k=3, is_ascend=True)
        check_topk(a_nd, a_npy, axis=3, ret_typ=ret_typ, k=2, is_ascend=False)
        check_topk(a_nd, a_npy, axis=None, ret_typ=ret_typ, k=21, is_ascend=False)

    # test for ret_typ=mask
    check_topk(a_nd, a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
    check_topk(a_nd, a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
    check_topk(a_nd, a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)

    # test for ret_typ=both
    both_val, both_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
    both_val = both_val.asnumpy()
    both_ind = both_ind.asnumpy()
    gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
    gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
    assert_almost_equal(both_val, gt_val)
    assert_almost_equal(both_ind, gt_ind)

    # test for sort
    check_full_order(mx.nd.sort, "value", a_nd, a_npy, axis=1, is_ascend=True)
    check_full_order(mx.nd.sort, "value", a_nd, a_npy, axis=None, is_ascend=False)

    # test for argsort
    check_full_order(mx.nd.argsort, "indices", a_nd, a_npy, axis=3, is_ascend=True)
    check_full_order(mx.nd.argsort, "indices", a_nd, a_npy, axis=None, is_ascend=False)

    # topk over a whole 1024-element range must return the range reversed.
    a = mx.nd.arange(0, 1024, step=1, repeat=1)
    assert_almost_equal(a.topk(k=1024).asnumpy(), a.asnumpy()[::-1])

    # Repeat those tests that don't involve indices. These should pass even with
    # duplicated input data values (over many repeated runs with different random seeds,
    # this will be tested).
    a_npy = get_values(ensure_unique=False)
    a_nd = mx.nd.array(a_npy, ctx=ctx)

    # test for ret_typ=value
    check_topk(a_nd, a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
    check_topk(a_nd, a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
    check_topk(a_nd, a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)

    # test for sort
    check_full_order(mx.nd.sort, "value", a_nd, a_npy, axis=1, is_ascend=True)
    check_full_order(mx.nd.sort, "value", a_nd, a_npy, axis=None, is_ascend=False)
@with_seed()
def test_ndarray_equal():
    """`==` between NDArrays, and between a scalar and an NDArray, yields 0/1 arrays."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    # zeros never equal ones
    result = x == y
    assert (result.asnumpy() == all_zero).all()
    # scalar on the left-hand side
    result = 0 == x
    assert (result.asnumpy() == all_one).all()
@with_seed()
def test_ndarray_not_equal():
    """`!=` between NDArrays, and between a scalar and an NDArray, yields 0/1 arrays."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    # zeros always differ from ones
    result = x != y
    assert (result.asnumpy() == all_one).all()
    # scalar on the left-hand side
    result = 0 != x
    assert (result.asnumpy() == all_zero).all()
@with_seed()
def test_ndarray_greater():
    """`>` with NDArray/NDArray, NDArray/scalar and scalar/NDArray operands."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    result = x > y
    assert (result.asnumpy() == all_zero).all()
    result = y > 0
    assert (result.asnumpy() == all_one).all()
    result = 0 > y
    assert (result.asnumpy() == all_zero).all()
@with_seed()
def test_ndarray_greater_equal():
    """`>=` with NDArray/NDArray, NDArray/scalar and scalar/NDArray operands."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    result = x >= y
    assert (result.asnumpy() == all_zero).all()
    result = y >= 0
    assert (result.asnumpy() == all_one).all()
    result = 0 >= y
    assert (result.asnumpy() == all_zero).all()
    # the equality boundary must count as true
    result = y >= 1
    assert (result.asnumpy() == all_one).all()
@with_seed()
def test_ndarray_lesser():
    """`<` with NDArray/NDArray, scalar/NDArray and NDArray/scalar operands."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    result = y < x
    assert (result.asnumpy() == all_zero).all()
    result = 0 < y
    assert (result.asnumpy() == all_one).all()
    result = y < 0
    assert (result.asnumpy() == all_zero).all()
@with_seed()
def test_ndarray_lesser_equal():
    """`<=` with NDArray/NDArray, scalar/NDArray and NDArray/scalar operands."""
    shape = (2, 3)
    all_zero = np.zeros(shape)
    all_one = np.ones(shape)
    x = mx.nd.zeros(shape)
    y = mx.nd.ones(shape)
    result = y <= x
    assert (result.asnumpy() == all_zero).all()
    result = 0 <= y
    assert (result.asnumpy() == all_one).all()
    result = y <= 0
    assert (result.asnumpy() == all_zero).all()
    # the equality boundary must count as true
    result = 1 <= y
    assert (result.asnumpy() == all_one).all()
@with_seed()
def test_take():
    """mx.nd.take must match NumPy fancy indexing along the first axis."""
    for data_ndim in range(2, 5):
        for idx_ndim in range(1, 4):
            # Random data shape (each side in [3, 6)) and index shape (each side in [3, 5)).
            data_shape = tuple(np.random.randint(low=3, high=6) for _ in range(data_ndim))
            data_real = np.random.normal(size=data_shape).astype('float32')
            idx_shape = tuple(np.random.randint(low=3, high=5) for _ in range(idx_ndim))
            # Indices stay within the bounds of the first data axis.
            idx_real = np.random.randint(low=0, high=data_shape[0], size=idx_shape)
            taken = mx.nd.take(mx.nd.array(data_real), mx.nd.array(idx_real))
            assert_almost_equal(taken.asnumpy(), data_real[idx_real])
@with_seed()
def test_iter():
    """Iterating an NDArray must yield the same items as integer indexing."""
    x = mx.nd.array([1, 2, 3])
    iterated = [item for item in x]
    for position in range(x.size):
        assert same(iterated[position].asnumpy(), x[position].asnumpy())
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/8049")
def test_cached():
    """Run a CachedOp built from `Convolution + 2` forward and backward.

    Covers repeated forward calls, writing into an `out=` array, backward
    passes under autograd (including `retain_graph`), and re-use of the same
    op with differently shaped inputs.
    """
    net = mx.sym.Convolution(kernel=(3, 3), num_filter=10) + 2
    cached_op = mx.nd.CachedOp(net)
    data = mx.nd.ones((3, 4, 10, 10))
    weight = mx.nd.ones((10, 4, 3, 3))
    bias = mx.nd.ones((10,))

    # Raising the bias from 1 to 2 must raise every output by exactly 1.
    first = cached_op(data, weight, bias)
    bias[:] = 2
    second = cached_op(data, weight, bias)
    assert_almost_equal(second.asnumpy(), first.asnumpy()+1)
    # Same result when the op writes into an existing array.
    second[:] = 0
    cached_op(data, weight, bias, out=second)
    assert_almost_equal(second.asnumpy(), first.asnumpy()+1)

    weight.attach_grad()
    bias.attach_grad()
    with mx.autograd.record():
        bias = bias + 1
        result = cached_op(data, weight, bias)
        result = result * 2
    result.backward()

    # Two backward passes over the same graph: the first one has to retain it.
    with mx.autograd.record():
        bias = bias + 1
        result = cached_op(data, weight, bias)
        result = result * 2
    result.backward(retain_graph=True)
    result.backward()

    # try a different shape
    data = mx.nd.ones((5, 2, 10, 10))
    weight = mx.nd.ones((10, 2, 3, 3))
    bias = mx.nd.ones((10,))
    data.attach_grad()
    with mx.autograd.record():
        bias = bias + 1
        result = cached_op(data, weight, bias)
        result = result * 2
    result.backward()
@with_seed()
def test_output():
    """Check array creation ops, including writing into an existing `out` array."""
    shape = (2,2)
    ones = mx.nd.ones(shape)
    zeros = mx.nd.zeros(shape)
    out = mx.nd.zeros(shape)

    # ones / zeros / full writing into the same pre-allocated array.
    mx.nd.ones(shape, out=out)
    assert_almost_equal(out.asnumpy(), ones.asnumpy())
    mx.nd.zeros(shape, out=out)
    assert_almost_equal(out.asnumpy(), zeros.asnumpy())
    mx.nd.full(shape, 2, out=out)
    assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)

    arange_out = mx.nd.arange(0, 20, dtype='int64')
    assert_almost_equal(arange_out.asnumpy(), np.arange(0, 20))

    # eye with random row counts, column counts and diagonal offsets.
    N_array = np.random.randint(1, high=8, size=10)
    M_array = np.random.randint(1, high=8, size=10)
    k_array = np.random.randint(-10, high=10, size=10)
    for N, M, k in zip(N_array, M_array, k_array):
        assert_almost_equal(np.eye(N, M, k), mx.nd.eye(N, M, k).asnumpy())
        assert_almost_equal(np.eye(N, k=k), mx.nd.eye(N, k=k).asnumpy())
@with_seed()
def test_ndarray_fluent():
    """Check that fluent NDArray methods match the `mx.ndarray` functions of the same name.

    For every operator name, `mx.ndarray.<name>(data, **kwargs)` is compared
    with `data.<name>(**kwargs)` on the same random input.
    """
    def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
        """Compare the regular and fluent call of operator `func` on random data.

        `kwargs` is forwarded to both calls; `equal_nan` is forwarded to
        `almost_equal`.
        """
        with mx.name.NameManager():
            data = mx.nd.random_uniform(shape=shape, ctx=default_context())
            regular = getattr(mx.ndarray, func)(data, **kwargs)
            fluent = getattr(data, func)(**kwargs)
            if isinstance(regular, list):
                # Multi-output operators return a list; compare output by output.
                for r, f in zip(regular, fluent):
                    assert almost_equal(r.asnumpy(), f.asnumpy(), equal_nan=equal_nan)
            else:
                assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)

    # Operators that take no extra arguments.
    for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
        check_fluent_regular(func, {})

    # Same, but compared with equal_nan=True.
    for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
                 'arcsinh', 'arctanh', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax']:
        check_fluent_regular(func, {}, equal_nan=True)

    # Operators that take a single axis.
    for func in ['expand_dims', 'flip', 'sort', 'topk', 'argsort', 'argmax', 'argmin']:
        check_fluent_regular(func, {'axis': 1})

    # Operators with their own specific arguments.
    check_fluent_regular('one_hot', {'depth': 15})
    check_fluent_regular('tile', {'reps': (1,2)})
    check_fluent_regular('repeat', {'repeats': 3})
    check_fluent_regular('transpose', {'axes': (1,0,2)})
    check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 6))
    check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6))
    check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
    check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.nd.zeros((3, 3))})
    check_fluent_regular('take', {'indices': mx.nd.array([2, 3])})
    check_fluent_regular('pick', {'axis': 1, 'index': mx.nd.array([[2], [3], [5], [6], [11]])})
    check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
    check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
    check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))
    check_fluent_regular('reshape_like', {'rhs': mx.nd.ones((30, 17))}, shape=(5, 17, 2, 3))

    # Reductions over several axes.
    for func in ['sum', 'nansum', 'prod', 'nanprod', 'mean', 'max', 'min', 'norm']:
        check_fluent_regular(func, {'axis': (1, 2)})

    check_fluent_regular('reshape', {'shape': (17, 1, 5)})
    check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)})
    check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4))
@raises(ValueError)
def test_bool_ambiguous():
    """Truth-testing an NDArray of shape (2, 3, 4) is expected to raise ValueError."""
    multi_element = mx.nd.ones((2, 3, 4))
    bool(multi_element)
def test_bool():
    """Check the truth value of an empty NDArray and of one-element NDArrays."""
    # An array without elements is falsy.
    empty = mx.nd.array([])
    assert not bool(empty)
    # A single zero is falsy ...
    lone_zero = mx.nd.zeros((1,))
    assert not bool(lone_zero)
    # ... and a single one is truthy.
    lone_one = mx.nd.ones((1,))
    assert bool(lone_one)
@with_seed()
def test_ndarray_indexing():
    """Compare NDArray indexing against NumPy for a table of index expressions.

    Each entry of ``index_list`` is used to read from an array, to write into
    an array, to check the gradient produced by reading under autograd, and to
    check the error raised when writing while autograd is recording.
    """
    def to_numpy_index(index):
        """Return `index` with NDArrays (top level or inside a tuple) converted via asnumpy()."""
        if isinstance(index, mx.nd.NDArray):
            return index.asnumpy()
        if isinstance(index, tuple):
            return tuple(elem.asnumpy() if isinstance(elem, mx.nd.NDArray) else elem
                         for elem in index)
        return index

    def check_getitem(np_array, index, is_scalar=False):
        """Read `index` from `np_array` and from its NDArray copy and compare.

        `is_scalar` tells whether the result is expected to be a scalar; in
        that case the NDArray result is converted with asscalar() instead of
        asnumpy() before it is compared with numpy's result.
        """
        expected = np_array[to_numpy_index(index)]
        mx_result = mx.nd.array(np_array, dtype=np_array.dtype)[index]
        actual = mx_result.asscalar() if is_scalar else mx_result.asnumpy()
        assert same(expected, actual), 'Failed with index=%s' % str(index)

    def check_setitem(np_array, index, is_scalar):
        """Write through `index` into a numpy array and an NDArray and compare the arrays."""
        def assign_and_compare(np_target, np_idx, mx_target, mx_idx, mx_value, np_value=None):
            # Value written on the numpy side: the explicit override when given,
            # otherwise the NDArray value converted to numpy, otherwise the value itself.
            if np_value is None:
                np_value = mx_value.asnumpy() if isinstance(mx_value, mx.nd.NDArray) else mx_value
            np_target[np_idx] = np_value
            mx_target[mx_idx] = mx_value
            assert same(np_target, mx_target.asnumpy())

        np_index = to_numpy_index(index)
        mx_array = mx.nd.array(np_array, dtype=np_array.dtype)
        # From here on `np_array` is the array returned by asnumpy(), not the argument.
        np_array = mx_array.asnumpy()

        if is_scalar:
            # value is a plain number
            assign_and_compare(np_array, np_index, mx_array, index,
                               np.random.randint(low=-10000, high=0))
            # value is a one-element list; numpy is given the bare element
            single_value_list = [np.random.randint(low=-10000, high=0)]
            assign_and_compare(np_array, np_index, mx_array, index,
                               single_value_list, single_value_list[0])
            return

        indexed_shape = np_array[np_index].shape
        full_value = np.random.randint(low=-10000, high=0, size=indexed_shape)
        # value is a numpy array of exactly the indexed shape (no broadcast)
        assign_and_compare(np_array, np_index, mx_array, index, full_value)
        # value is a plain number
        assign_and_compare(np_array, np_index, mx_array, index,
                           np.random.randint(low=-10000, high=0))
        if len(indexed_shape) > 1:
            last_dim = indexed_shape[-1]
            # value is an NDArray that has to be broadcast
            assign_and_compare(np_array, np_index, mx_array, index,
                               mx.nd.random.uniform(low=-10000, high=0, shape=(last_dim,)))
            # value is a numpy array that has to be broadcast
            assign_and_compare(np_array, np_index, mx_array, index,
                               np.random.randint(low=-10000, high=0, size=(last_dim,)))
            # value is a list that has to be broadcast
            assign_and_compare(np_array, np_index, mx_array, index,
                               [np.random.randint(low=-10000, high=0)] * last_dim)

    def check_getitem_autograd(np_array, index):
        """Compare the gradient of `x[index]` with zeros that have ones written at `index`."""
        x = mx.nd.array(np_array, dtype=np_array.dtype)
        x.attach_grad()
        with mx.autograd.record():
            y = x[index]
        y.backward()
        expected_grad = mx.nd.zeros_like(x)
        expected_grad[index] = mx.nd.ones_like(y)
        assert same(expected_grad.asnumpy(), x.grad.asnumpy())

    def check_setitem_autograd(np_array, index):
        """Check that writing through `index` while recording raises the expected MXNetError."""
        x = mx.nd.array(np_array, dtype=np_array.dtype)
        y = mx.nd.random.uniform(shape=x[index].shape)
        y.attach_grad()
        try:
            with mx.autograd.record():
                x[index] = y
                assert False  # should not reach here
        except mx.base.MXNetError as err:
            assert 'Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with' in str(err)

    def np_int(index, int_type=np.int32):
        """Convert the integers of a slice, or of a tuple of slices and integers, to `int_type`.

        ``None`` entries are passed through unchanged.
        """
        def convert(num):
            return num if num is None else int_type(num)

        def convert_slice(slc):
            return slice(convert(slc.start), convert(slc.stop), convert(slc.step))

        if isinstance(index, slice):
            return convert_slice(index)
        if isinstance(index, tuple):  # tuple of slices and integers
            return tuple(convert_slice(elem) if isinstance(elem, slice) else convert(elem)
                         for elem in index)
        assert False

    shape = (8, 16, 9, 9)
    np_array = np.arange(np.prod(shape), dtype='int32').reshape(shape)
    # Each entry is (index, is_scalar): the index expression under test and a flag
    # telling whether indexing with it is expected to yield a scalar, as in numpy.
    index_list = [
        (0, False), (np.int32(0), False), (np.int64(0), False),
        (5, False), (np.int32(5), False), (np.int64(5), False),
        (-1, False), (np.int32(-1), False), (np.int64(-1), False),
        (slice(5), False), (np_int(slice(5), np.int32), False), (np_int(slice(5), np.int64), False),
        (slice(1, 5), False), (np_int(slice(1, 5), np.int32), False), (np_int(slice(1, 5), np.int64), False),
        (slice(1, 5, 2), False), (np_int(slice(1, 5, 2), np.int32), False),
        (np_int(slice(1, 5, 2), np.int64), False),
        (slice(7, 0, -1), False), (np_int(slice(7, 0, -1)), False),
        (np_int(slice(7, 0, -1), np.int64), False),
        (slice(None, 6), False), (np_int(slice(None, 6)), False),
        (np_int(slice(None, 6), np.int64), False),
        (slice(None, 6, 3), False), (np_int(slice(None, 6, 3)), False),
        (np_int(slice(None, 6, 3), np.int64), False),
        (slice(1, None), False), (np_int(slice(1, None)), False),
        (np_int(slice(1, None), np.int64), False),
        (slice(1, None, 3), False), (np_int(slice(1, None, 3)), False),
        (np_int(slice(1, None, 3), np.int64), False),
        (slice(None, None, 2), False), (np_int(slice(None, None, 2)), False),
        (np_int(slice(None, None, 2), np.int64), False),
        (slice(None, None, -1), False),
        (np_int(slice(None, None, -1)), False), (np_int(slice(None, None, -1), np.int64), False),
        (slice(None, None, -2), False),
        (np_int(slice(None, None, -2), np.int32), False), (np_int(slice(None, None, -2), np.int64), False),
        ((slice(None), slice(None), 1, 8), False),
        ((slice(None), slice(None), -1, 8), False),
        ((slice(None), slice(None), 1, -8), False),
        ((slice(None), slice(None), -1, -8), False),
        (np_int((slice(None), slice(None), 1, 8)), False),
        (np_int((slice(None), slice(None), 1, 8), np.int64), False),
        ((slice(None), slice(None), 1, 8), False),
        (np_int((slice(None), slice(None), -1, -8)), False),
        (np_int((slice(None), slice(None), -1, -8), np.int64), False),
        ((slice(None), 2, slice(1, 5), 1), False),
        (np_int((slice(None), 2, slice(1, 5), 1)), False),
        (np_int((slice(None), 2, slice(1, 5), 1), np.int64), False),
        ((1, 2, 3), False),
        (np_int((1, 2, 3)), False),
        (np_int((1, 2, 3), np.int64), False),
        ((-1, -2, -3), False),
        (np_int((-1, -2, -3)), False),
        (np_int((-1, -2, -3), np.int64), False),
        ((1, 2, 3, 4), True),
        (np_int((1, 2, 3, 4)), True),
        (np_int((1, 2, 3, 4), np.int64), True),
        ((-4, -3, -2, -1), True),
        (np_int((-4, -3, -2, -1)), True),
        (np_int((-4, -3, -2, -1), np.int64), True),
        ((slice(None, None, -1), 2, slice(1, 5), 1), False),
        (np_int((slice(None, None, -1), 2, slice(1, 5), 1)), False),
        (np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), False),
        ((slice(None, None, -1), 2, slice(1, 7, 2), 1), False),
        (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), False),
        (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), False),
        ((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), False),
        (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), False),
        (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), False),
        ((slice(1, 8, 2), 1, slice(3, 8), 2), False),
        (np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), False),
        (np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), False),
        ([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 0, 3, 6, 2, 1], False),
        (np.array([6, 3], dtype=np.int32), False),
        (np.array([[3, 4], [0, 6]], dtype=np.int32), False),
        (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
        (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False),
        (np.array([[2], [0], [1]], dtype=np.int32), False),
        (np.array([[2], [0], [1]], dtype=np.int64), False),
        (mx.nd.array([4, 7], dtype=np.int32), False),
        (mx.nd.array([4, 7], dtype=np.int64), False),
        (mx.nd.array([[3, 6], [2, 1]], dtype=np.int32), False),
        (mx.nd.array([[3, 6], [2, 1]], dtype=np.int64), False),
        (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
        (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False),
        ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int32)), False),
        ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int64)), False),
        ((1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), False),
        ((1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), False),
        ((1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), False),
        ((1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), False),
        ((1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), False),
        ((1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), False),
        ((1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), False),
        ((1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), False),
        ((1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), False),
        ((1, [2], mx.nd.array([[4]], dtype=np.int32), mx.nd.array([[1, 3], [5, 7]], dtype='int64')),
         False),
        ([0], False), ([0, 1], False), ([1, 2, 3], False), ([2, 0, 5, 6], False),
        (([1, 1], [2, 3]), False), (([1], [4], [5]), False), (([1], [4], [5], [6]), False),
        (([[1]], [[2]]), False), (([[1]], [[2]], [[3]], [[4]]), False),
        ((slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), False),
        (([[[[1]]]], [[1]], slice(0, 3), [1, 5]), False),
        (([[[[1]]]], 3, slice(0, 3), [1, 3]), False),
        (([[[[1]]]], 3, slice(0, 3), 0), False),
        (([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False),
        (([1, 2], slice(3, 5), [2, 3], [3, 4]), False),
        (([1, 2], slice(3, 5), (2, 3), [3, 4]), False),
    ]
    for index, is_scalar in index_list:
        check_getitem(np_array, index, is_scalar)
        check_setitem(np_array, index, is_scalar)
        check_getitem_autograd(np_array, index)
        check_setitem_autograd(np_array, index)
def test_assign_float_value_to_ndarray():
    """Regression test for https://github.com/apache/incubator-mxnet/issues/8668

    A float32 value assigned to element 0 of an NDArray must read back
    identical to the numpy source.
    """
    expected = np.array([47.844944], dtype=np.float32)
    target = mx.nd.zeros(1, dtype=np.float32)
    # First assign the one-element numpy array, then its first element.
    for value in (expected, expected[0]):
        target[0] = value
        assert same(expected, target.asnumpy())
@with_seed()
def test_assign_a_row_to_ndarray():
    """Regression test for https://github.com/apache/incubator-mxnet/issues/9976

    The same row assignments are applied to a numpy array and to an NDArray
    built from it; after every step both must hold the same values.
    """
    rows, cols = 10, 10
    dtype = np.float32
    expected = np.random.random((rows, cols)).astype(dtype)
    actual = mx.nd.array(expected)

    def assert_in_sync():
        assert same(expected, actual.asnumpy())

    # row copied from another row of the same array
    expected[0] = expected[1]
    actual[0] = actual[1]
    assert_in_sync()

    # row given as a Python list
    row_as_list = np.random.random(cols).astype(dtype).tolist()
    expected[1] = row_as_list
    actual[1] = row_as_list
    assert_in_sync()

    # row given as a np.ndarray
    row_as_array = np.random.random(cols).astype(dtype)
    expected[2] = row_as_array
    actual[2] = row_as_array
    assert_in_sync()

    # row addressed with an explicit slice
    expected[0, :] = expected[1]
    actual[0, :] = actual[1]
    assert_in_sync()
@with_seed()
def test_ndarray_astype():
    """Check when NDArray.astype returns a new object and when it returns the source itself."""
    def cast_int32_zeros(dtype, **kwargs):
        """Create a (2, 3) int32 zero array and return it together with its astype result."""
        source = mx.nd.zeros((2, 3), dtype='int32')
        return source, source.astype(dtype, **kwargs)

    # Different dtype, default copy: a new ndarray is allocated.
    x, y = cast_int32_zeros('float32')
    assert y.dtype == np.float32
    assert x is not y

    # Different dtype, copy=False: a new ndarray is still allocated.
    x, y = cast_int32_zeros('float32', copy=False)
    assert y.dtype == np.float32
    assert x is not y

    # Same dtype, default copy: a new ndarray is allocated
    # even though the dtype does not change.
    x, y = cast_int32_zeros('int32')
    assert y.dtype == np.int32
    assert x is not y

    # Same dtype, copy=False: no new ndarray is allocated.
    x, y = cast_int32_zeros('int32', copy=False)
    assert x is y

    # np.int32 behaves the same as the string 'int32'.
    x, y = cast_int32_zeros(np.int32, copy=False)
    assert x is y
@with_seed()
def test_norm(ctx=default_context()):
    """Compare mx.nd.norm for ord 1 and 2 against reference norms computed on numpy data.

    The L2 reference uses scipy.linalg.norm when scipy can be imported and
    numpy.linalg.norm otherwise; the L1 reference is a sum of absolute values.

    Parameters
    ----------
    ctx : context on which the NDArray under test is created. The default is
        the value of default_context() at the time this function is defined.
    """
    try:
        import scipy
        assert LooseVersion(scipy.__version__) >= LooseVersion('0.1')
        from scipy.linalg import norm as sp_norm
    except (AssertionError, ImportError):
        print("Could not import scipy.linalg.norm or scipy is too old. "
              "Falling back to numpy.linalg.norm which is not numerically stable.")
        from numpy.linalg import norm as sp_norm

    def l1norm(input_data, axis=0, keepdims=False):
        return np.sum(abs(input_data), axis=axis, keepdims=keepdims)

    def l2norm(input_data, axis=0, keepdims=False):
        return sp_norm(input_data, axis=axis, keepdims=keepdims)

    # presumably picks one of 4, 5, 6 as the number of dimensions and a random
    # shape with that many dimensions -- helper semantics not visible here
    in_data_dim = random_sample([4, 5, 6], 1)[0]
    in_data_shape = rand_shape_nd(in_data_dim)
    np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
    mx_arr = mx.nd.array(np_arr, ctx=ctx)

    for norm_ord, reference_norm in ((1, l1norm), (2, l2norm)):
        for keep_dims in (True, False):
            for i in range(4):
                # Reduce over the single axis i and, while i < 3, also over the pair (i, i+1).
                axes_under_test = [i]
                if i < 3:
                    axes_under_test.append((i, i + 1))
                for axis in axes_under_test:
                    npy_out = reference_norm(np_arr, axis, keep_dims)
                    mx_out = mx.nd.norm(mx_arr, ord=norm_ord, axis=axis, keepdims=keep_dims)
                    assert npy_out.shape == mx_out.shape
                    mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
@with_seed()
def test_ndarray_cpu_shared_ctx():
    """An NDArray created with a 'cpu_shared' context must report that context."""
    shared_ctx = mx.Context('cpu_shared', 0)
    arr = mx.nd.zeros((1, 2, 3), ctx=shared_ctx)
    assert arr.context == shared_ctx
# When this file is executed as a script, hand the module over to nose.
if __name__ == '__main__':
    import nose
    nose.runmodule()