# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import absolute_import
from distutils.version import StrictVersion
import sys
import copy
import itertools
from mxnet.gluon.parameter import Parameter
import numpy as onp
import platform
import mxnet as mx
import scipy.stats as ss
import scipy.special as scipy_special
import pytest
import mxnet.ndarray.numpy._internal as _npi
from functools import reduce
from mxnet import np, npx
from mxnet.gluon import HybridBlock
from mxnet.base import MXNetError
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient, use_np, collapse_sum_like, effective_dtype
from mxnet.test_utils import new_matrix_with_real_eigvals_nd
from mxnet.test_utils import new_sym_matrix_with_real_eigvals_nd
from common import assertRaises, retry, xfail_when_nonstandard_decimal_separator
import random
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
from mxnet.numpy_op_signature import _get_builtin_op
from mxnet.test_utils import is_op_runnable, has_tvm_ops, rand_shape_2d
from mxnet.operator import get_all_registered_operators
from common import assert_raises_cuda_not_satisfied
from numpy.testing import assert_allclose
@use_np
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
@pytest.mark.parametrize('a_shape,b_shape,axes', [
((3, 5), (5, 4), 1),
((3,), (3,), 1),
((3, 4, 5, 3, 2), (5, 3, 2, 1, 2), 3),
((3, 5, 4, 3, 2), (2, 3, 5, 1, 2), [[1, 3, 4], [2, 1, 0]]),
((3, 5, 4), (5, 4, 3), [[1, 0, 2], [0, 2, 1]]),
((3, 5, 4), (5, 3, 4), [[2, 0], [-1, -2]]),
((2, 2), (2, 2), 2),
((3, 5, 4), (5, ), [[-2], [0]]),
((3, 5, 4), (5, ), [[1], [0]]),
((2,), (2, 3), 1),
((3,), (3,), 0),
((2,), (2, 3), 0),
((3, 5, 4), (5, ), 0),
((2, 3, 4), (4, 3, 2), [[], []]),
((3, 0), (0, 5), 1),
((3, 0), (0, 4), [[1], [0]]),
((0, 3), (3, 5), 1),
((0, 3), (5, 0), [[0], [1]])
])
def test_np_tensordot(a_shape, b_shape, axes, hybridize, dtype):
class TestTensordot(HybridBlock):
def __init__(self, axes):
super(TestTensordot, self).__init__()
self._axes = axes
def forward(self, a, b):
return np.tensordot(a, b, self._axes)
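# Reference gradient for tensordot: the forward pass is equivalent to permuting the
# summed axes of `a` to the back (and of `b` to the front), flattening both operands
# into 2-D matrices A' (ad1 x ad2) and B' (bd1 x bd2) with ad2 == bd1, and computing
# A' @ B'. Hence dA' = out_grad @ B'.T and dB' = A'.T @ out_grad; the helper below
# reshapes these back and undoes the permutation via reverse_a_axes / reverse_b_axes.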
def tensordot_backward(out_grad, a, b, axes=2):
if (a.ndim < 1) or (b.ndim < 1):
raise ValueError('An input is zero-dim')
if onp.isscalar(axes):
a_axes_summed = [i + a.ndim - axes for i in range(axes)]
b_axes_summed = [i for i in range(axes)]
else:
if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if onp.isscalar(a_axes_summed):
a_axes_summed = a_axes_summed,
if onp.isscalar(b_axes_summed):
b_axes_summed = b_axes_summed,
for i in range(len(a_axes_summed)):
a_axes_summed[i] = (a_axes_summed[i] + a.ndim) % a.ndim
for i in range(len(b_axes_summed)):
b_axes_summed[i] = (b_axes_summed[i] + b.ndim) % b.ndim
if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')
a_axes_remained = []
for i in range(a.ndim):
if not (i in a_axes_summed):
a_axes_remained.append(i)
a_axes = a_axes_remained[:] + a_axes_summed[:]
b_axes_remained = []
for i in range(b.ndim):
if not (i in b_axes_summed):
b_axes_remained.append(i)
b_axes = b_axes_summed[:] + b_axes_remained[:]
ad1 = onp.prod([a.shape[i] for i in a_axes_remained]) if len(a_axes_remained) > 0 else 1
ad2 = onp.prod([a.shape[i] for i in a_axes_summed]) if len(a_axes_summed) > 0 else 1
bd1 = onp.prod([b.shape[i] for i in b_axes_summed]) if len(b_axes_summed) > 0 else 1
bd2 = onp.prod([b.shape[i] for i in b_axes_remained]) if len(b_axes_remained) > 0 else 1
out_grad = out_grad.reshape((ad1, bd2))
new_a = onp.transpose(a, a_axes)
new_a_shape = new_a.shape[:]
new_a = new_a.reshape((ad1, ad2))
new_b = onp.transpose(b, b_axes)
new_b_shape = new_b.shape[:]
new_b = new_b.reshape((bd1, bd2))
reverse_a_axes = [0 for i in a_axes]
for i in range(len(a_axes)):
reverse_a_axes[a_axes[i]] = i
reverse_b_axes = [0 for i in b_axes]
for i in range(len(b_axes)):
reverse_b_axes[b_axes[i]] = i
grad_b = onp.dot(new_a.T, out_grad).reshape(new_b_shape)
grad_b = onp.transpose(grad_b, reverse_b_axes)
grad_a = onp.dot(out_grad, new_b.T).reshape(new_a_shape)
grad_a = onp.transpose(grad_a, reverse_a_axes)
return [grad_a, grad_b]
test_tensordot = TestTensordot(axes)
if hybridize:
test_tensordot.hybridize()
a = rand_ndarray(shape = a_shape, dtype = dtype).as_np_ndarray()
b = rand_ndarray(shape = b_shape, dtype = dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
np_out = onp.tensordot(a.asnumpy(), b.asnumpy(), axes)
with mx.autograd.record():
mx_out = test_tensordot(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5)
mx_out.backward()
np_backward = tensordot_backward(onp.ones(np_out.shape), a.asnumpy(), b.asnumpy(), axes)
assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-3, atol=1e-5)
assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-3, atol=1e-5)
# Test imperative once again
mx_out = np.tensordot(a, b, axes)
np_out = onp.tensordot(a.asnumpy(), b.asnumpy(), axes)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test numeric gradient
if (onp.prod(a_shape) > 0 and onp.prod(b_shape) > 0):
a_sym = mx.sym.Variable("a").as_np_ndarray()
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.tensordot(a_sym, b_sym, axes).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
rtol=1e-1, atol=1e-1, dtype = dtype)
# General Gradient Test
for a_grad_status in ['add', 'write']:
for b_grad_status in ['add', 'write']:
a = mx.np.random.normal(0, 1, a_shape)
b = mx.np.random.normal(0, 1, b_shape)
a.attach_grad(a_grad_status)
b.attach_grad(b_grad_status)
if a_grad_status == 'add':
ori_a_grad = mx.np.random.normal(0, 1, a_shape)
if a.ndim == 0:
a.grad[()] = ori_a_grad
else:
a.grad[:] = ori_a_grad
if b_grad_status == 'add':
ori_b_grad = mx.np.random.normal(0, 1, b_shape)
if b.ndim == 0:
b.grad[()] = ori_b_grad
else:
b.grad[:] = ori_b_grad
with mx.autograd.record():
mx_out = mx.np.tensordot(a, b, axes)
out_grad = mx.np.random.normal(0, 1, mx_out.shape)
loss = (mx_out * out_grad).sum()
loss.backward()
gt_in_grad = tensordot_backward(out_grad.asnumpy(), a.asnumpy(), b.asnumpy(), axes)
if a_grad_status == 'add':
gt_in_grad[0] += ori_a_grad
if b_grad_status == 'add':
gt_in_grad[1] += ori_b_grad
assert_almost_equal(a.grad.asnumpy(), gt_in_grad[0], rtol=1e-2, atol=1e-2)
assert_almost_equal(b.grad.asnumpy(), gt_in_grad[1], rtol=1e-2, atol=1e-2)
@use_np
@pytest.mark.parametrize('shape_a,shape_b', [
((3, 0), (0, 4)),
((3,), (3,)),
((3, 4), (4, 5)),
((), ()),
((3, 4, 5), ()),
((), (3, 4, 5)),
((3, 4, 5), (5, )),
((3, 4, 5), (5, 2)),
((5,), (5, 2)),
((3, 5, 4), (5, 4, 3)),
((3, 4), (5, 4, 3)),
((4,), (5, 4, 3))
])
def test_np_dot(shape_a, shape_b):
eps = 1e-3
np_a = onp.random.uniform(-1.0, 1.0, shape_a)
np_a[abs(np_a) < eps] = 2 * eps
np_b = onp.random.uniform(-1.0, 1.0, shape_b)
np_b[abs(np_b) < eps] = 2 * eps
a = mx.nd.array(np_a)
b = mx.nd.array(np_b)
np_res = onp.dot(np_a, np_b)
mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
assert mx_res.shape == np_res.shape
assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5)
mx_a = mx.sym.Variable("a")
mx_b = mx.sym.Variable("b")
mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_nd_ndarray()
if (len(shape_a) > 0 and len(shape_b) > 0 and onp.prod(shape_a) > 0 and onp.prod(shape_b) > 0):
check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3)
@use_np
@pytest.mark.parametrize('shape_a,shape_b', [
((4, 5), (2, 3)),
((3, 4, 5), (6, ))
])
def test_np_dot_error(shape_a, shape_b):
a = mx.nd.array(random.random()) if len(shape_a) == 0 else rand_ndarray(shape_a)
b = mx.nd.array(random.random()) if len(shape_b) == 0 else rand_ndarray(shape_b)
with pytest.raises(mx.base.MXNetError):
mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
@use_np
@pytest.mark.parametrize('shape', [(), (5,), (3, 3)])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_np_vdot(shape, dtype, hybridize):
class TestVdot(HybridBlock):
def __init__(self):
super(TestVdot, self).__init__()
def forward(self, a, b):
return np.vdot(a, b)
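# np.vdot flattens both inputs and returns sum(a * b); with the implicit scalar head
# gradient of 1, the input gradients are simply the swapped operands: d/da = b, d/db = a.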
def vdot_backward(a, b):
return [b, a]
test_vdot = TestVdot()
if hybridize:
test_vdot.hybridize()
a = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
b = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
np_out = onp.vdot(a.asnumpy(), b.asnumpy())
with mx.autograd.record():
mx_out = test_vdot(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5)
mx_out.backward()
np_backward = vdot_backward(a.asnumpy(), b.asnumpy())
assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-2, atol=1e-2)
assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-2, atol=1e-2)
# Test imperative once again
mx_out = np.vdot(a, b)
np_out = onp.vdot(a.asnumpy(), b.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test numeric gradient
if len(shape) > 0 and onp.prod(shape) > 0:
a_sym = mx.sym.Variable("a").as_np_ndarray()
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.vdot(a_sym, b_sym).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
rtol=1e-1, atol=1e-1, dtype=dtype)
@use_np
@pytest.mark.parametrize('a_shape,b_shape', [
((3,), (3,)),
((2, 3), (3,)),
((3,), (2, 3))
])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_np_inner(a_shape, b_shape, dtype, hybridize):
class TestInner(HybridBlock):
def __init__(self):
super(TestInner, self).__init__()
def forward(self, a, b):
return np.inner(a, b)
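# np.inner contracts the last axis of `a` with the last axis of `b`; the reference
# backward below mirrors the tensordot gradient (flatten to a 2-D matmul, multiply by
# the transposed counterpart, undo the reshape/permutation) using a head gradient of ones.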
def inner_backward(a, b):
a_axes_summed = [a.ndim - 1]
b_axes_summed = [b.ndim - 1]
a_axes_remained = []
for i in range(a.ndim):
if not (i in a_axes_summed):
a_axes_remained.append(i)
a_axes = a_axes_remained[:] + a_axes_summed[:]
b_axes_remained = []
for i in range(b.ndim):
if not (i in b_axes_summed):
b_axes_remained.append(i)
b_axes = b_axes_summed[:] + b_axes_remained[:]
ad1 = onp.prod([a.shape[i] for i in a_axes_remained]) if len(a_axes_remained) > 0 else 1
ad2 = onp.prod([a.shape[i] for i in a_axes_summed]) if len(a_axes_summed) > 0 else 1
bd1 = onp.prod([b.shape[i] for i in b_axes_summed]) if len(b_axes_summed) > 0 else 1
bd2 = onp.prod([b.shape[i] for i in b_axes_remained]) if len(b_axes_remained) > 0 else 1
out_grad = onp.ones((ad1, bd2))
new_a = onp.transpose(a, a_axes)
new_a_shape = new_a.shape[:]
new_a = new_a.reshape((ad1, ad2))
new_b = onp.transpose(b, b_axes)
new_b_shape = new_b.shape[:]
new_b = new_b.reshape((bd1, bd2))
reverse_a_axes = [0 for i in a_axes]
for i in range(len(a_axes)):
reverse_a_axes[a_axes[i]] = i
reverse_b_axes = [0 for i in b_axes]
for i in range(len(b_axes)):
reverse_b_axes[b_axes[i]] = i
grad_b = onp.dot(new_a.T, out_grad).reshape(new_b_shape)
grad_b = onp.transpose(grad_b, reverse_b_axes)
grad_a = onp.dot(out_grad, new_b.T).reshape(new_a_shape)
grad_a = onp.transpose(grad_a, reverse_a_axes)
return [grad_a, grad_b]
test_inner = TestInner()
if hybridize:
test_inner.hybridize()
a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
np_out = onp.inner(a.asnumpy(), b.asnumpy())
with mx.autograd.record():
mx_out = test_inner(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5)
mx_out.backward()
np_backward = inner_backward(a.asnumpy(), b.asnumpy())
assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-2, atol=1e-2)
assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-2, atol=1e-2)
# Test imperative once again
mx_out = np.inner(a, b)
np_out = onp.inner(a.asnumpy(), b.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test numeric gradient
a_sym = mx.sym.Variable("a").as_np_ndarray()
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.inner(a_sym, b_sym).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
rtol=1e-1, atol=1e-1, dtype=dtype)
@use_np
@pytest.mark.parametrize('a_shape,b_shape', [
((3,), (3,)),
((2, 3), (6,)),
((6,), (2, 3))
])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_np_outer(a_shape, b_shape, dtype, hybridize):
class TestOuter(HybridBlock):
def __init__(self):
super(TestOuter, self).__init__()
def forward(self, a, b):
return np.outer(a, b)
test_outer = TestOuter()
if hybridize:
test_outer.hybridize()
a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
np_out = onp.outer(a.asnumpy(), b.asnumpy())
with mx.autograd.record():
mx_out = test_outer(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
# Test imperative once again
mx_out = np.outer(a, b)
np_out = onp.outer(a.asnumpy(), b.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test numeric gradient
a_sym = mx.sym.Variable("a").as_np_ndarray()
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.outer(a_sym, b_sym).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
rtol=1e-1, atol=1e-1, dtype=dtype)
@use_np
@pytest.mark.parametrize('shape_a,shape_b', [
((3,), (3,)),
((3, 4), (4, 5)),
((3, 0), (0, 4)),
((4, 5), (5,)),
((3, 4, 5), (5,)),
((5,), (5, 2)),
((2,), (4, 2, 3)),
((2, 1, 3, 4, 5), (5, 2)),
((1, 3, 5, 4), (1, 4, 3)),
((3, 5, 4), (2, 1, 4, 3)),
((3, 4), (1, 5, 4, 3))
])
@pytest.mark.parametrize('grad_req_a', ['write', 'add', 'null'])
@pytest.mark.parametrize('grad_req_b', ['write', 'add', 'null'])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_np_matmul(shape_a, shape_b, grad_req_a, grad_req_b,
dtype, hybridize):
class TestMatmul(HybridBlock):
def __init__(self):
super(TestMatmul, self).__init__()
def forward(self, a, b):
return np.matmul(a, b)
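# Reference gradient for np.matmul. ShapeInfer promotes 1-D operands to matrices and
# broadcasts the leading (batch) dimensions the way numpy's matmul does; the gradients
# are then grad_a = out_grad @ b^T and grad_b = a^T @ out_grad over the last two axes,
# and ShapeReduce sums each gradient over the broadcast batch dimensions so it matches
# the original operand shape.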
def matmul_backward(a, b):
def ShapeInfer(mat_a, mat_b):
if mat_a.ndim == 1:
mat_a = mat_a.reshape((1, mat_a.size))
if mat_b.ndim == 1:
mat_b = mat_b.reshape((mat_b.size, 1))
ndim = max(mat_a.ndim, mat_b.ndim)
newshape_a = list(onp.array(mat_a, ndmin=ndim).shape)
newshape_b = list(onp.array(mat_b, ndmin=ndim).shape)
if ndim >= 3:
pre_shape = onp.fmax(newshape_a[ndim - 3::-1], newshape_b[ndim - 3::-1])
newshape_a[ndim - 3::-1] = pre_shape
newshape_b[ndim - 3::-1] = pre_shape
else:
pre_shape = onp.array([])
out_shape = onp.append(pre_shape[::-1].astype(onp.int64), [newshape_a[ndim - 2], newshape_b[ndim - 1]])
return [ndim, newshape_a, newshape_b, out_shape]
def ShapeReduce(mat, shape, is_b=False):
ndim = mat.ndim
if is_b and len(shape) == 1:
rng = onp.arange(ndim - 2)
else:
pre_len = ndim - len(shape)
in_pre = onp.array(mat.shape[pre_len : ndim - 2])
out_pre = onp.array(shape[:len(shape) - 2])
diff = onp.nonzero(in_pre != out_pre)[0] + pre_len
rng = onp.append(onp.arange(ndim - len(shape)), diff)
mat = onp.sum(mat, axis=tuple(rng))
return mat.reshape(shape)
a_shape = a.shape
b_shape = b.shape
[ndim, newshape_a, newshape_b, out_shape] = ShapeInfer(a, b)
new_a = onp.broadcast_to(a, newshape_a)
if len(b_shape) == 1:
new_b = onp.broadcast_to(b.reshape((b.size, 1)), newshape_b)
else:
new_b = onp.broadcast_to(b, newshape_b)
ad1 = new_a.shape[ndim - 2]
ad2 = new_a.shape[ndim - 1]
bd1 = new_b.shape[ndim - 2]
bd2 = new_b.shape[ndim - 1]
a_T = onp.moveaxis(new_a, [ndim - 2, ndim - 1], [ndim - 1, ndim - 2])
b_T = onp.moveaxis(new_b, [ndim - 2, ndim - 1], [ndim - 1, ndim - 2])
out_grad = onp.ones(out_shape)
grad_b = onp.matmul(a_T, out_grad)
grad_b = ShapeReduce(grad_b, b_shape, is_b=True)
grad_a = onp.matmul(out_grad, b_T)
grad_a = ShapeReduce(grad_a, a_shape)
return [grad_a, grad_b]
eps = 1E-4
test_matmul = TestMatmul()
if hybridize:
test_matmul.hybridize()
np_a = onp.random.uniform(-1.0, 1.0, shape_a).astype(dtype)
np_a[abs(np_a) < eps] = 2 * eps
np_b = onp.random.uniform(-1.0, 1.0, shape_b).astype(dtype)
np_b[abs(np_b) < eps] = 2 * eps
a = mx.np.array(np_a, dtype=dtype)
a.attach_grad(grad_req=grad_req_a)
b = mx.np.array(np_b, dtype=dtype)
b.attach_grad(grad_req=grad_req_b)
np_out = onp.matmul(np_a, np_b)
with mx.autograd.record():
mx_out = test_matmul(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(np_out, mx_out.asnumpy(), rtol=eps, atol=eps)
if grad_req_a != 'null' or grad_req_b != 'null':
mx_out.backward()
np_backward = matmul_backward(np_a, np_b)
if grad_req_a == 'null':
assert a.grad is None
else:
assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = eps, atol=eps)
if grad_req_b == 'null':
assert b.grad is None
else:
assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = eps, atol=eps)
mx_out = np.matmul(a, b)
np_out = onp.matmul(np_a, np_b)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=eps, atol=eps)
@pytest.mark.parametrize('shape_a,shape_b', [
((1,), (2,)), # mismatched vector vector
((2, 1,), (2,)), # mismatched matrix vector
((2,), (1, 2)), # mismatched vector matrix
((1, 2), (3, 1)), # mismatched matrix matrix
((1,), ()), # vector scalar
((), (1,)), # scalar vector
((1, 1), ()), # matrix scalar
((), (1, 1)), # scalar matrix
((2, 2, 1), (3, 1, 2)), # cannot broadcast
])
def test_np_matmul_error(shape_a, shape_b):
a = np.random.uniform(size=shape_a)
b = np.random.uniform(size=shape_b)
with pytest.raises(MXNetError):
np.matmul(a, b)
@use_np
@pytest.mark.parametrize('a_shape,b_shape', [
((3,), (3,)),
((2, 3), (3,)),
((2, 3, 4), (2,)),
((3, 2), ())
])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
@pytest.mark.parametrize('hybridize', [True, False])
def test_np_kron(a_shape, b_shape, dtype, hybridize):
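# Reference gradient for np.kron: the output element at index k = ia * b.shape + jb
# (element-wise over dimensions) equals a[ia] * b[jb], so d out[k] / d a[ia] = b[jb]
# and d out[k] / d b[jb] = a[ia]; the loops below accumulate ograd accordingly.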
def np_kron_backward(ograd, a, b):
ndim = ograd.ndim
# Make ndim equal
if ndim > a.ndim:
a = a.reshape((1,)*(ndim - a.ndim) + a.shape)
else:
b = b.reshape((1,)*(ndim - b.ndim) + b.shape)
assert a.ndim == b.ndim
# Compute agrad
agrad = onp.zeros(a.shape)
for i in range(a.size):
ia = onp.asarray(onp.unravel_index(i, a.shape))
for j in range(b.size):
jb = onp.asarray(onp.unravel_index(j, b.shape))
k = ia * onp.asarray(b.shape) + jb
agrad[tuple(ia)] += ograd[tuple(k)] * b[tuple(jb)]
# Compute bgrad
bgrad = onp.zeros(b.shape)
for j in range(b.size):
jb = onp.asarray(onp.unravel_index(j, b.shape))
for i in range(a.size):
ia = onp.asarray(onp.unravel_index(i, a.shape))
k = ia * onp.asarray(b.shape) + jb
bgrad[tuple(jb)] += ograd[tuple(k)] * a[tuple(ia)]
return [agrad, bgrad]
class TestKron(HybridBlock):
def __init__(self):
super(TestKron, self).__init__()
def forward(self, a, b):
return np.kron(a, b)
test_kron = TestKron()
if hybridize:
test_kron.hybridize()
a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
a.attach_grad()
b.attach_grad()
np_out = onp.kron(a.asnumpy(), b.asnumpy())
with mx.autograd.record():
mx_out = test_kron(a, b)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
# Test imperative once again
mx_out = np.kron(a, b)
np_out = onp.kron(a.asnumpy(), b.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
# test numeric gradient
a_sym = mx.sym.Variable("a").as_np_ndarray()
b_sym = mx.sym.Variable("b").as_np_ndarray()
mx_sym = mx.sym.np.kron(a_sym, b_sym).as_nd_ndarray()
check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
rtol=1e-2, atol=1e-2, dtype=dtype)
# test gradient via backward implemented by numpy
np_backward = np_kron_backward(onp.ones(np_out.shape, dtype = dtype), a.asnumpy(), b.asnumpy())
assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2)
assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2)
@use_np
@pytest.mark.parametrize('shape', [rand_shape_nd(4, dim=4), (4, 0, 4, 0)])
@pytest.mark.parametrize('axis', [0, 1, 2, 3, (), None])
@pytest.mark.parametrize('keepdims', [True, False])
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int8', 'int32', 'int64'])
@pytest.mark.parametrize('itype,acc_type', [
('float16', 'float32'),
('float32', 'float64'),
('float64', 'float64'),
('int8', 'int32'),
('int32', 'int64'),
('int64', 'int64'),
('bool', 'int64')
])
@pytest.mark.parametrize('hybridize', [True, False])
def test_np_sum(shape, axis, keepdims, itype, acc_type, dtype, hybridize):
class TestSum(HybridBlock):
def __init__(self, axis=None, dtype=None, keepdims=False):
super(TestSum, self).__init__()
self._axis = axis
self._dtype = dtype
self._keepdims = keepdims
def forward(self, a, *args, **kwargs):
return np.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
class TestSumConv(HybridBlock):
def __init__(self, axis=None, dtype=None, keepdims=False):
super(TestSumConv, self).__init__()
self._axis = axis
self._dtype = dtype
self._keepdims = keepdims
def forward(self, a, *args, **kwargs):
return a.sum(axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
def is_int(dtype):
return 'int' in dtype
is_windows = sys.platform.startswith('win')
if (is_int(dtype) and not is_int(itype)) or (is_windows and is_int(itype))\
or (itype == 'bool' and\
(dtype not in ('float32', 'float64', 'int32', 'int64') or is_windows)):
return
# test gluon
test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims)
test_sum_conv = TestSumConv(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_sum.hybridize()
test_sum_conv.hybridize()
if is_int(itype):
x = onp.random.randint(-128, 128, shape, dtype=itype)
x = np.array(x)
elif itype == 'bool':
x = onp.random.randint(0, 2, shape) < 1
x = np.array(x, dtype='bool')
else:
x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
expected_ret = onp.sum(x.asnumpy(), axis=axis, dtype=acc_type, keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
if itype == 'bool':
if is_op_runnable() and (not is_windows): # special handling of boolean ndarray
y = test_sum(x)
y_conv = test_sum_conv(x)
assert y.dtype == expected_ret.dtype
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5,
use_broadcast=False)
assert y_conv.dtype == expected_ret.dtype
assert_almost_equal(y_conv.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5,
use_broadcast=False)
return
x.attach_grad()
with mx.autograd.record():
y = test_sum(x)
y_conv = test_sum_conv(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5, use_broadcast=False)
assert y_conv.shape == expected_ret.shape
assert_almost_equal(y_conv.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5, use_broadcast=False)
y.backward()
assert same(x.grad.asnumpy(), onp.ones(shape=x.shape, dtype=x.dtype))
# test numeric
if itype == 'float32' and dtype == 'float32' and shape != (4, 0, 4, 0):
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
numeric_eps=1e-3, rtol=1e-2, atol=1e-3, dtype=onp.float32)
# test imperative
mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = onp.sum(x.asnumpy(), axis=axis, dtype=acc_type, keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
@use_np
@pytest.mark.parametrize('bool_agg', ['all', 'any'])
@pytest.mark.parametrize('shape', [
(), (5, ), (10, ), (2, 5), (5, 5), (10, 10),
(4, 4, 4), (4, 6, 9), (6, 6, 6), (6, 0, 5),
(7, 8, 9, 10), (7, 9, 11, 13), (0, 7, 7, 5)
])
@pytest.mark.parametrize('axis', [True, False])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('keepdim', [True, False])
@pytest.mark.parametrize('dtype', [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64, np.bool])
def test_np_bool_agg(bool_agg, shape, axis, keepdim, dtype, hybridize):
class TestOp(HybridBlock):
def __init__(self, axis=None, keepdims=False) :
super(TestOp, self).__init__()
self._axis = axis
self._keepdims = keepdims
def forward(self, a):
return getattr(np, bool_agg)(a, axis=self._axis, keepdims=self._keepdims)
ndim = len(shape)
samples = random.randint(0, ndim)
axis = None if not axis else tuple(random.sample([i for i in range(0, ndim)], samples))
x = np.random.normal(0, 5.0, size=shape).astype(dtype)
test_op = TestOp(axis=axis, keepdims=keepdim)
if hybridize:
test_op.hybridize()
y = test_op(x)
expected_ret = getattr(onp, bool_agg)(x.asnumpy(), axis=axis, keepdims=keepdim)
assert_almost_equal(y.asnumpy(), expected_ret)
# test imperative
mx_outs = getattr(np, bool_agg)(x, axis=axis, keepdims=keepdim)
np_outs = getattr(onp, bool_agg)(x.asnumpy(), axis=axis, keepdims=keepdim)
assert_almost_equal(mx_outs.asnumpy(), np_outs)
@use_np
@pytest.mark.parametrize('func', ['max', 'min'])
@pytest.mark.parametrize('in_data_dim', [2, 3, 4])
@pytest.mark.parametrize('itype', ['float16', 'float32', 'float64', 'int'])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('keepdims', [True, False])
def test_np_max_min(func, in_data_dim, itype, keepdims, hybridize):
class TestOp(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestOp, self).__init__()
self._axis = axis
self._keepdims = keepdims
def forward(self, a, *args, **kwargs):
return getattr(a, func)(axis=self._axis, keepdims=self._keepdims)
def is_int(dtype):
return 'int' == dtype
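# For the integer case the input is np.arange(120).reshape((2, 3, 4, 5)), which is
# strictly increasing, so along any axis the maximum sits at the last index and the
# minimum at the first; get_grad builds the matching one-hot gradient (index -1 for
# 'max', index 0 for 'min').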
def get_grad(axis, func_name):
index = -1 if func_name == 'max' else 0
if axis == ():
return onp.ones((2,3,4,5))
else:
temp = onp.zeros((2,3,4,5))
if axis == 0:
temp[index,:,:,:] = 1
return temp
elif axis == 1:
temp[:,index,:,:] = 1
return temp
elif axis == 2:
temp[:,:,index,:] = 1
return temp
elif (axis == 3 or axis == -1):
temp[:,:,:,index] = 1
return temp
elif not axis:
temp[index,index,index,index] = 1
return temp
raise ValueError('axis should be int or None or ()')
shape = rand_shape_nd(in_data_dim, dim=3)
for axis in ([i for i in range(in_data_dim)] + [(), None] + [-1]):
test_gluon = TestOp(axis=axis, keepdims=keepdims)
if hybridize:
test_gluon.hybridize()
if is_int(itype):
x = np.arange(120).reshape((2, 3, 4, 5))
else:
x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
x.attach_grad()
ref_op = getattr(onp, 'a'+func)
expected_ret = ref_op(x.asnumpy(), axis=axis, keepdims=keepdims)
with mx.autograd.record():
y = test_gluon(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
y.backward()
# only check the gradient with hardcoded input
if is_int(itype):
assert same(x.grad.asnumpy(), get_grad(axis, func)), \
'x={}\ny={}\nx.grad={}\nnumpy={}'.format(x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis, func))
# test imperative
mx_out = getattr(np, func)(x, axis=axis, keepdims=keepdims)
np_out = ref_op(x.asnumpy(), axis=axis, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
@use_np
@pytest.mark.parametrize('func', ['max', 'min'])
@pytest.mark.parametrize('shape,exception', [
((), False),
((0), True),
((2, 0), True),
((0, 2, 1), True)
])
def test_np_max_min_error(func, shape, exception):
# test zero-size and zero-dim inputs
def _test_np_exception(func, shape, dim):
x = np.random.uniform(-1.0, 1.0, shape)
out = getattr(x, func)()
assert out.ndim == dim, 'dimension mismatch, out.ndim={}, dim={}'.format(out.ndim, dim)
dim = 0
if exception:
assertRaises(MXNetError, _test_np_exception, func, shape, dim)
else:
_test_np_exception(func, shape, dim)
@use_np
@pytest.mark.parametrize('a_shape,w_shape,axes', [
((3, 5), (3, 5), None),
((4, 5, 6), (4, 5, 6), (0, 2)),
((3,), (3,), 0),
((2, 3), (3,), 1),
((2, 3, 4), (2,), 0),
((2, 3, 4), (3,), 1),
((2, 3, 4), (4,), -1),
((2, 3, 4, 5), (5,), 3)
])
@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('is_weighted', [True, False])
@pytest.mark.parametrize('returned', [True, False])
@pytest.mark.parametrize('req_a', ['null', 'add', 'write'])
@pytest.mark.flaky
def test_np_average(a_shape, w_shape, axes, is_weighted, req_a,
hybridize, returned, dtype):
class TestAverage(HybridBlock):
def __init__(self, axis=None, returned=False):
super(TestAverage, self).__init__()
# store the reduction axis and the `returned` flag for use in forward
self._axis = axis
self._returned = returned
def forward(self, a, weights):
return np.average(a, weights=weights, axis=self._axis, returned=self._returned)
def avg_backward(a, w, avg, axes, init_a_grad=None, init_w_grad=None):
# avg = sum(a * w) / sum(w)
if axes is not None and not isinstance(axes, tuple) and axes < 0:
axes += a.ndim
if w is None:
a_grad = onp.ones(shape=a.shape, dtype=a.dtype)/(a.size/avg.size)
if init_a_grad is not None:
a_grad += init_a_grad.asnumpy()
return [a_grad, None]
onedim = a.ndim != w.ndim
if onedim:
new_shape = [a.shape[i] if i == axes else 1 for i in range(a.ndim)]
w = w.reshape(new_shape)
w = onp.broadcast_to(w, a.shape)
# partial a = w / sum(w)
# partial w = (a*sum(w) - sum(a*w)) / (sum(w) * sum(w))
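# (both follow from the quotient rule applied to avg = sum(a*w) / sum(w); when w was
# originally 1-D it is broadcast above, so its gradient is summed back over the
# remaining axes further below)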
scl = onp.sum(w, axis=axes, keepdims=True)
a_grad = onp.divide(w, scl)
w_grad = onp.divide(a*scl-onp.sum(a*w, axis=axes, keepdims=True), scl*scl)
if onedim:
axis = list(range(a.ndim))
axis.remove(axes)
w_grad = onp.sum(w_grad, axis=tuple(axis))
if init_a_grad is not None:
a_grad += init_a_grad.asnumpy()
if init_w_grad is not None:
w_grad += init_w_grad.asnumpy()
return [a_grad, w_grad]
if req_a == 'null' and not is_weighted:
return
rtol, atol = 1e-3, 1e-4
test_average = TestAverage(axes, returned)
if hybridize:
test_average.hybridize()
a = np.random.uniform(-1.0, 1.0, size=a_shape, dtype=dtype)
a.attach_grad(req_a)
init_a_grad = np.random.uniform(-1.0, 1.0, size=a_shape, dtype=dtype) if req_a == 'add' else None
init_w_grad = None
req_w = req_a
w, np_w = None, None
if is_weighted:
w = np.random.uniform(-1.0, 1.0, size=w_shape, dtype=dtype)
if req_a == 'null':
req_w = random.choice(['add', 'write'])
w.attach_grad(req_w)
if req_w == 'add':
init_w_grad = np.random.uniform(-1.0, 1.0, size=w_shape, dtype=dtype)
np_w = w.asnumpy()
np_out = onp.average(a.asnumpy(), axis=axes, weights=np_w, returned=returned)
with mx.autograd.record():
mx_out = test_average(a, w)
if returned:
np_out, np_sum_of_weights = np_out
mx_out, mx_sum_of_weights = mx_out
assert_almost_equal(mx_sum_of_weights.asnumpy(), np_sum_of_weights, rtol=rtol, atol=atol)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
if req_a == 'add':
a.grad[:] = init_a_grad
if is_weighted and req_w == 'add':
w.grad[:] = init_w_grad
mx_out.backward()
# Code to get reference backward value
a_grad, w_grad = avg_backward(a.asnumpy(), np_w, np_out, axes, init_a_grad, init_w_grad)
if is_weighted:
assert_almost_equal(w.grad.asnumpy(), w_grad, rtol=rtol*10, atol=atol*10)
if req_a == 'null':
assert a.grad is None
else:
assert_almost_equal(a.grad.asnumpy(), a_grad, rtol=rtol, atol=atol)
# Test imperative once again
np_out = onp.average(a.asnumpy(), weights=np_w, axis=axes, returned=returned)
mx_out = np.average(a, weights=w, axis=axes, returned=returned)
if returned:
np_out, np_sum_of_weights = np_out
mx_out, mx_sum_of_weights = mx_out
assert_almost_equal(mx_sum_of_weights.asnumpy(), np_sum_of_weights, rtol=rtol, atol=atol)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
@use_np
def test_np_mean():
class TestMean(HybridBlock):
def __init__(self, axis=None, dtype=None, keepdims=False):
super(TestMean, self).__init__()
self._axis = axis
self._dtype = dtype
self._keepdims = keepdims
def forward(self, a, *args, **kwargs):
return a.mean(axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
def is_int(dtype):
return 'int' in dtype
is_windows = sys.platform.startswith('win')
in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
ft_types = ['float16', 'float32', 'float64']
it_types = ['bool', 'int8', 'int32', 'int64']
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
if dtype == 'bool':
continue
# test gluon
test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_mean.hybridize()
x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
x = x.as_np_ndarray()
x.attach_grad()
expected_ret = onp.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
with mx.autograd.record():
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
y.backward()
N = x.size / y.size
assert same(x.grad.asnumpy(), onp.ones(shape=x.shape, dtype=x.dtype) / N)
# test numeric
if itype == 'float32' and dtype == 'float32':
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=onp.float32)
# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = onp.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
if dtype == 'bool':
continue
# test gluon
test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
if hybridize:
test_mean.hybridize()
if itype == 'bool':
x = np.array(onp.random.uniform(size=shape) > 0.5)
else:
x = np.random.uniform(-128, 127, size=shape).astype(itype)
expected_ret = onp.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)
if itype == 'bool':
if is_op_runnable() and (not is_windows) and dtype not in ['float16', 'int8']: # special handling of boolean ndarray
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
continue
y = test_mean(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
np_out = onp.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
@use_np
def test_np_moment():
class TestMoment(HybridBlock):
def __init__(self, name, axis=None, dtype=None, keepdims=False, ddof=0):
super(TestMoment, self).__init__()
self._moment_name = name
self._axis = axis
self._dtype = dtype
self._keepdims = keepdims
self._ddof = ddof
def forward(self, a, *args, **kwargs):
return getattr(a, self._moment_name)(axis=self._axis, dtype=self._dtype,
keepdims=self._keepdims, ddof=self._ddof)
def is_int(dtype):
return 'int' in dtype
def legalize_shape(shape):
shape_ = list(shape)
for i in range(len(shape_)):
shape_[i] += 1
return tuple(shape_)
in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
shape = legalize_shape(shape)
acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
'int8': 'float64', 'int32': 'float64', 'int64': 'float64'}
for name in ['var', 'std']:
for hybridize in [False, True]:
for ddof in [0, 1]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
for dtype in ['float16', 'float32', 'float64']:
if is_int(dtype) and not is_int(itype) or is_int(itype) and is_int(dtype):
continue
atol = 3e-4 if itype == 'float16' or dtype == 'float16' else 1e-5
rtol = 1e-2 if itype == 'float16' or dtype == 'float16' else 1e-3
# test gluon
test_moment = TestMoment(name, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)
if hybridize:
test_moment.hybridize()
if is_int(itype):
x = onp.random.randint(-16, 16, shape, dtype=itype)
x = mx.nd.array(x)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x.attach_grad()
expected_ret = getattr(onp, name)(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims, ddof=ddof)
expected_ret = expected_ret.astype(dtype)
y = test_moment(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=rtol, atol=atol, use_broadcast=False, equal_nan=True)
# test imperative
mx_out = getattr(np, name)(x, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)
np_out = getattr(onp, name)(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims, ddof=ddof).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False, equal_nan=True)
@use_np
def test_np_shape():
shapes = [
(),
(0, 1),
(2, 3),
(2, 3, 4),
]
for shape in shapes:
mx_a = np.random.uniform(size=shape)
np_a = onp.random.uniform(size=shape)
mx_shape = np.shape(mx_a)
np_shape = onp.shape(np_a)
assert mx_shape == np_shape
@use_np
@pytest.mark.parametrize('config', [
(0.0, 1.0, 10),
(-2, 4, 30),
(5.234324, 8.98324, 324),
(2, 10, 100)
])
@pytest.mark.parametrize('dtype', ['int32', 'float16', 'float32', 'float64', None])
@pytest.mark.parametrize('endpoint', [True, False])
@pytest.mark.parametrize('retstep', [True, False])
def test_np_linspace(config, dtype, endpoint, retstep):
if isinstance(config, tuple):
mx_ret = np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = onp.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
else:
mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = onp.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
if retstep:
assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
assert same(mx_ret[1], np_ret[1])
else:
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
@use_np
@pytest.mark.parametrize('config', [
(0.0, 1.0, 10),
(-2, 4, 30),
(5.234324, 8.98324, 324),
(2, 10, 100)
])
@pytest.mark.parametrize('dtype', ['int32', 'float16', 'float32', 'float64', None])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('endpoint', [True, False])
def test_np_linspace_gluon(config, dtype, endpoint, hybridize):
class TestLinspace(HybridBlock):
def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0):
super(TestLinspace, self).__init__()
self._start = start
self._stop = stop
self._num = num
self._endpoint = endpoint
self._retstep = retstep
self._dtype = dtype
def forward(self, x):
if self._retstep:
raise ValueError("linspace didn't support retstep = True inside HybridBlock")
else:
return x + np.linspace(self._start, self._stop, num=self._num, \
endpoint=self._endpoint, retstep=self._retstep, dtype=self._dtype)
x = np.zeros(shape=(), dtype=dtype)
if isinstance(config, tuple):
net = TestLinspace(*config, endpoint=endpoint, dtype=dtype)
np_out = onp.linspace(*config, endpoint=endpoint, dtype=dtype)
else:
net = TestLinspace(config, endpoint=endpoint, dtype=dtype)
np_out = onp.linspace(config, endpoint=endpoint, dtype=dtype)
if hybridize:
net.hybridize()
mx_out = net(x)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)
@use_np
@pytest.mark.parametrize('config', [
(0, 10, -1),
(0, 1, 2.5)
])
def test_np_linspace_error(config):
with pytest.raises(MXNetError):
np.linspace(*config)
@use_np
def test_np_linspace_arange():
# check linspace equivalent to arange
for test_index in range(1000):
assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), onp.arange(test_index + 1))
@use_np
@pytest.mark.parametrize('config', [
(0.0, 1.0, 20),
(2, 8, 0),
(22, 11, 1),
(2.22, 9.99, 11),
(4.99999, 12.11111111, 111)
])
@pytest.mark.parametrize('dtype', ['float32', 'float64', None])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('endpoint', [True, False])
@pytest.mark.parametrize('base', [0, 1, 5, 8, 10, 33])
def test_np_logspace(config, dtype, endpoint, hybridize, base):
class TestLogspace(HybridBlock):
def __init__(self, start, stop, num=50, endpoint=None, base=50.0, dtype=None, axis=0):
super(TestLogspace, self).__init__()
self._start = start
self._stop = stop
self._num = num
self._endpoint = endpoint
self._base = base
self._dtype = dtype
self.axis = axis
def forward(self, x):
return x + np.logspace(self._start, self._stop, self._num, self._endpoint, self._base, self._dtype, self.axis)
x = np.zeros(shape=(), dtype=dtype)
net = TestLogspace(*config, endpoint=endpoint, base=base, dtype=dtype)
np_out = onp.logspace(*config, endpoint=endpoint, base=base, dtype=dtype)
if hybridize:
net.hybridize()
mx_out = net(x)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)
if dtype is not None:
assert mx_out.dtype == np_out.dtype
# Test imperative once again
mx_ret = np.logspace(*config, endpoint=endpoint, base=base, dtype=dtype)
np_ret = onp.logspace(*config, endpoint=endpoint, base=base, dtype=dtype)
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
if dtype is not None:
assert mx_out.dtype == np_out.dtype
@use_np
@pytest.mark.parametrize('start,end,step', [
([], [], None),
([], [], []),
([1], [4], None),
([1], [10], [3]),
([10], [0], [-2]),
([None], [None], [None]),
([None], [None], [-1]),
([10], [None], [-1]),
([1, 0, 3], [-2, 10, -4], [None, 2, 3]),
([-2, -3, -5, -6], [1, 3, 4, 5], None),
([-2, -3, -5, -6], [1, 3, 4, 5], [-1, -2, -3, -4]),
([2, -3, -5, -6], [2, 3, 4, 5], None),
([2, -3, -5, 5], [3, 3, 4, 5], None),
])
@pytest.mark.parametrize('hybridize', [True, False])
def test_npx_slice(start, end, step, hybridize):
class TestSlice(HybridBlock):
def __init__(self, begin, end, step):
super(TestSlice, self).__init__()
self._begin = begin
self._end = end
self._step = step
def forward(self, a):
return npx.slice(a, begin=self._begin, end=self._end, step=self._step)
shape = (8, 16, 9, 9)
np_array = onp.arange(onp.prod(shape), dtype='int32').reshape(shape)
test_slice = TestSlice(begin=start, end=end, step=step)
if hybridize:
test_slice.hybridize()
a = np.array(np_array, dtype=np_array.dtype)
a.attach_grad()
basic_index = tuple([
slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i])
for i in range(len(start))
])
expected_ret = np_array[basic_index]
with mx.autograd.record():
y = test_slice(a)
assert same(y.asnumpy(), expected_ret)
# test backward
mx.autograd.backward(y)
expected_grad = onp.zeros(shape)
expected_grad[basic_index] = 1
assert same(a.grad.asnumpy(), expected_grad)
@use_np
def test_npx_index_add():
class TestIndexAdd(HybridBlock):
def __init__(self):
super(TestIndexAdd, self).__init__()
def forward(self, a, ind, val):
return npx.index_add(a, ind, val)
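# Reference forward for npx.index_add: `ind` has shape (ind_ndim, ind_num) and each of
# its columns addresses a (possibly partial) index into `a`; `val`, broadcast against
# the addressed slice, is added in place at that location, so repeated index columns
# accumulate in this reference implementation. The backward helper below passes
# out_grad straight through to `a` and accumulates it into the `val` gradient, summing
# over any broadcast dimensions.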
def index_add_forward(a, ind, val, ind_ndim, ind_num):
if val.dtype != a.dtype:
val = val.astype(a.dtype)
ind_arr = ind.transpose()
if ind_arr.ndim == 0:
ind_arr = onp.array([ind_arr])
for i in range(ind_arr.shape[0]):
t_ind = ind_arr[i]
t_ind = tuple(t_ind.tolist()) if type(t_ind) is onp.ndarray else t_ind.tolist()
if val.ndim + ind_ndim > a.ndim:
t_val = val[tuple([0 if val.shape[0]==1 else i])]
if type(t_val) is onp.ndarray and t_val.shape[0] == 1:
a[t_ind] += onp.squeeze(t_val, axis=0)
else:
a[t_ind] += t_val
else:
a[t_ind] += val
return a
def index_add_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a, grad_req_val):
if grad_req_a == 'add':
init_a_grad = onp.array(a_grad)
if grad_req_val == 'add':
init_val_grad = onp.array(val_grad)
a_grad = onp.zeros(a_grad.shape) + out_grad
a_grad = a_grad.astype(a_grad.dtype)
val_grad = onp.zeros(val_grad.shape).astype(val_grad.dtype)
ind_arr = ind.transpose()
if ind_arr.ndim == 0:
ind_arr = onp.array([ind_arr])
for i in range(ind_arr.shape[0]):
t_ind = ind_arr[i]
t_ind = tuple(ind_arr[i].tolist()) if type(ind_arr[i]) is onp.ndarray else ind_arr[i].tolist()
if val_grad.ndim + ind_ndim > a_grad.ndim:
idx = 0 if val_grad.shape[0]==1 else i
t_grad = out_grad[t_ind]
t_grad_shape = onp.array(t_grad.shape)
val_grad_shape = onp.array(val_grad[idx].shape)
if type(val_grad[idx]) is not onp.ndarray:
t_grad = onp.sum(t_grad)
else:
is_not_equal = t_grad_shape - val_grad_shape
if onp.any(is_not_equal):
broadcast_dim = onp.nonzero(onp.where(is_not_equal, 1, 0))
t_grad = onp.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
val_grad[idx] += t_grad
else:
t_grad = out_grad[t_ind]
if type(val_grad) is not onp.ndarray or val_grad.shape == ():
t_grad = onp.sum(t_grad)
else:
if type(t_grad) is onp.ndarray:
ext_dim = t_grad.ndim - val_grad.ndim
if ext_dim:
t_grad = onp.sum(t_grad, axis=tuple(onp.arange(ext_dim)))
t_grad_shape = onp.array(t_grad.shape)
val_grad_shape = onp.array(val_grad.shape)
is_not_equal = t_grad_shape - val_grad_shape
if onp.any(is_not_equal):
broadcast_dim = onp.nonzero(onp.where(is_not_equal, 1, 0))
t_grad = onp.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
val_grad += t_grad
if grad_req_a == 'add':
a_grad += init_a_grad
if grad_req_val == 'add':
val_grad += init_val_grad
return a_grad, val_grad
# a.shape, ind.shape, val.shape, ind_ndim, ind_num
configs = [((2, ), np.array(1, dtype=onp.int32), (1, ), 1, 1)]
shape = tuple(onp.random.randint(1, 6, size=(4))) # a.shape
for ind_ndim in range(1, 5): # ind.shape: (ind_ndim, ind_num)
ind_num = onp.random.randint(1, 7)
ind = []
for ind_dim in range(ind_ndim):
ind.append(onp.random.randint(0, shape[ind_dim], size=(ind_num)))
ind = onp.array(ind).astype(onp.int32)
# case: val is scalar
configs.append(tuple([shape, ind, (), ind_ndim, ind_num]))
for _ in range(1, 5 - ind_ndim):
val_shape = [1 if onp.random.randint(0, 5)==0 else ind_num]
for val_dim in range(ind_ndim, 4):
val_shape.append(1 if onp.random.randint(0, 5)==0 else shape[val_dim])
# case: val is tensor
configs.append(tuple([shape, ind, tuple(val_shape), ind_ndim, ind_num]))
dtypes = ['float32', 'float64', 'int32', 'int64']
grad_req = ['write', 'null', 'add']
for hybridize, grad_req_a, grad_req_val, dtype, indtype in \
itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
for a_shape, ind, val_shape, ind_ndim, ind_num in configs:
eps = 1e-3
atype = dtype
valtype = dtype
test_index_add = TestIndexAdd()
if hybridize:
test_index_add.hybridize()
a = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
a.attach_grad(grad_req=grad_req_a)
val = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
val.attach_grad(grad_req=grad_req_val)
expected_ret = index_add_forward(a.asnumpy(), ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num)
with mx.autograd.record():
mx_ret = test_index_add(a, np.array(ind).astype(indtype), val)
assert mx_ret.shape == a.shape
assert expected_ret.shape == a.shape
assert mx_ret.dtype == a.dtype
assert expected_ret.dtype == a.dtype
assert_almost_equal(mx_ret.asnumpy(), expected_ret, rtol=eps, atol=eps)
if atype not in ['float16', 'float32', 'float64'] or valtype not in ['float16', 'float32', 'float64']:
continue
if grad_req_a != 'null' or grad_req_val != 'null':
init_a_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
init_val_grad = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
out_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
if grad_req_a == 'add':
if init_a_grad.ndim == 0:
a.grad[()] = init_a_grad.item()
else:
a.grad[:] = init_a_grad
if grad_req_val == 'add':
if init_val_grad.ndim == 0:
val.grad[()] = init_val_grad.item()
else:
val.grad[:] = init_val_grad
mx_ret.backward(out_grad)
expected_bwd_a, expected_bwd_val = index_add_bwd(out_grad.asnumpy(), init_a_grad.asnumpy(), ind,
init_val_grad.asnumpy(), ind_ndim, ind_num,
grad_req_a, grad_req_val)
if grad_req_a == 'null':
assert a.grad is None
else:
assert_almost_equal(a.grad.asnumpy(), expected_bwd_a, rtol = eps, atol=eps)
if grad_req_val == 'null':
assert val.grad is None
else:
assert_almost_equal(val.grad.asnumpy(), expected_bwd_val, rtol = eps, atol=eps)
mx_out = npx.index_add(a, np.array(ind).astype(indtype), val)
assert_almost_equal(mx_out.asnumpy(), expected_ret, rtol=eps, atol=eps)
@use_np
def test_npx_index_update():
class TestIndexUpdate(HybridBlock):
def __init__(self):
super(TestIndexUpdate, self).__init__()
def forward(self, a, ind, val):
return npx.index_update(a, ind, val)
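# check_index_update_forward verifies the result without building a full expected
# output: for every indexed location it zeroes the entry in both mx_ret and the
# original `a` wherever mx_ret already matches the corresponding `val` element, then
# checks that everything left over (the untouched positions) still agrees with `a`.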
def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, eps):
if val.dtype != a.dtype:
val = val.astype(a.dtype)
ind_arr = ind.transpose()
if ind_arr.ndim == 0:
ind_arr = onp.array([ind_arr])
for i in range(ind_arr.shape[0]):
t_ind = ind_arr[i]
t_ind = tuple(t_ind.tolist()) if type(t_ind) is onp.ndarray else t_ind.tolist()
if val.ndim + ind_ndim > a.ndim:
t_val = val[tuple([0 if val.shape[0]==1 else i])]
if type(t_val) is onp.ndarray and t_val.shape[0] == 1:
expect_tmp = onp.squeeze(t_val, axis=0)
else:
expect_tmp = t_val
else:
expect_tmp = val
mx_tmp = mx_ret[t_ind]
close_pos = onp.where(onp.isclose(expect_tmp, mx_tmp, rtol=eps, atol=eps))
if a[t_ind].ndim == 0:
if close_pos[0].size == 1:
mx_ret[t_ind] = 0
a[t_ind] = 0
else:
mx_ret[t_ind][close_pos] = 0
a[t_ind][close_pos] = 0
assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)
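# index_update_bwd differs from index_add_bwd in one respect: positions of `a` that
# were overwritten receive no gradient (a_grad[t_ind] = 0 below), since those outputs
# come entirely from `val`; the `val` gradient is still the out_grad at the indexed
# locations, reduced over any broadcast dimensions.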
def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a, grad_req_val):
if grad_req_a == 'add':
init_a_grad = onp.array(a_grad)
if grad_req_val == 'add':
init_val_grad = onp.array(val_grad)
a_grad = onp.zeros(a_grad.shape) + out_grad
a_grad = a_grad.astype(a_grad.dtype)
val_grad = onp.zeros(val_grad.shape).astype(val_grad.dtype)
ind_arr = ind.transpose()
if ind_arr.ndim == 0:
ind_arr = onp.array([ind_arr])
for i in range(ind_arr.shape[0]):
t_ind = ind_arr[i]
t_ind = tuple(ind_arr[i].tolist()) if type(ind_arr[i]) is onp.ndarray else ind_arr[i].tolist()
a_grad[t_ind] = 0
if val_grad.ndim + ind_ndim > a_grad.ndim:
idx = 0 if val_grad.shape[0]==1 else i
t_grad = out_grad[t_ind]
t_grad_shape = onp.array(t_grad.shape)
val_grad_shape = onp.array(val_grad[idx].shape)
if type(val_grad[idx]) is not onp.ndarray:
t_grad = onp.sum(t_grad)
else:
is_not_equal = t_grad_shape - val_grad_shape
if onp.any(is_not_equal):
broadcast_dim = onp.nonzero(onp.where(is_not_equal, 1, 0))
t_grad = onp.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
val_grad[idx] += t_grad
else:
t_grad = out_grad[t_ind]
if type(val_grad) is not onp.ndarray or val_grad.shape == ():
t_grad = onp.sum(t_grad)
else:
if type(t_grad) is onp.ndarray:
ext_dim = t_grad.ndim - val_grad.ndim
if ext_dim:
t_grad = onp.sum(t_grad, axis=tuple(onp.arange(ext_dim)))
t_grad_shape = onp.array(t_grad.shape)
val_grad_shape = onp.array(val_grad.shape)
is_not_equal = t_grad_shape - val_grad_shape
if onp.any(is_not_equal):
broadcast_dim = onp.nonzero(onp.where(is_not_equal, 1, 0))
t_grad = onp.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
val_grad += t_grad
if grad_req_a == 'add':
a_grad += init_a_grad
if grad_req_val == 'add':
val_grad += init_val_grad
return a_grad, val_grad
# a.shape, ind.shape, val.shape, ind_ndim, ind_num
configs = [((2, ), np.array(1, dtype=onp.int32), (1, ), 1, 1)]
shape = tuple(onp.random.randint(1, 6, size=(4))) # a.shape
for ind_ndim in range(1, 5): # ind.shape: (ind_ndim, ind_num)
ind_num = onp.random.randint(1, 7)
ind = []
for ind_dim in range(ind_ndim):
ind.append(onp.random.randint(0, shape[ind_dim], size=(ind_num)))
ind = onp.array(ind).astype(onp.int32)
# case: val is scalar
configs.append(tuple([shape, ind, (), ind_ndim, ind_num]))
for _ in range(1, 5 - ind_ndim):
val_shape = [1 if onp.random.randint(0, 5)==0 else ind_num]
for val_dim in range(ind_ndim, 4):
val_shape.append(1 if onp.random.randint(0, 5)==0 else shape[val_dim])
# case: val is tensor
configs.append(tuple([shape, ind, tuple(val_shape), ind_ndim, ind_num]))
dtypes = ['float32', 'float64', 'int32', 'int64']
grad_req = ['write', 'null', 'add']
for hybridize, grad_req_a, grad_req_val, dtype, indtype in \
itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
for a_shape, ind, val_shape, ind_ndim, ind_num in configs:
eps = 1e-3
atype = dtype
valtype = dtype
test_index_update = TestIndexUpdate()
if hybridize:
test_index_update.hybridize()
a = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
a.attach_grad(grad_req=grad_req_a)
val = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
val.attach_grad(grad_req=grad_req_val)
with mx.autograd.record():
mx_ret = test_index_update(a, np.array(ind).astype(indtype), val)
assert mx_ret.shape == a.shape
assert mx_ret.dtype == a.dtype
check_index_update_forward(mx_ret.asnumpy(), a.asnumpy(), ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
if atype not in ['float16', 'float32', 'float64'] or valtype not in ['float16', 'float32', 'float64']:
continue
if grad_req_a != 'null' or grad_req_val != 'null':
init_a_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
init_val_grad = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
out_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
if grad_req_a == 'add':
if init_a_grad.ndim == 0:
a.grad[()] = init_a_grad.item()
else:
a.grad[:] = init_a_grad
if grad_req_val == 'add':
if init_val_grad.ndim == 0:
val.grad[()] = init_val_grad.item()
else:
val.grad[:] = init_val_grad
mx_ret.backward(out_grad)
expected_bwd_a, expected_bwd_val = index_update_bwd(out_grad.asnumpy(), init_a_grad.asnumpy(), ind,
init_val_grad.asnumpy(), ind_ndim, ind_num,
grad_req_a, grad_req_val)
if grad_req_a == 'null':
assert a.grad is None
else:
assert_almost_equal(a.grad.asnumpy(), expected_bwd_a, rtol = eps, atol=eps)
if grad_req_val == 'null':
assert val.grad is None
else:
assert_almost_equal(val.grad.asnumpy(), expected_bwd_val, rtol = eps, atol=eps)
mx_out = npx.index_update(a, np.array(ind).astype(indtype), val)
check_index_update_forward(mx_out.asnumpy(), a.asnumpy(), ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
@use_np
def test_npx_batch_dot():
device = mx.device.current_device()
dtypes = ['float32', 'float64']
if device.device_type == 'gpu':
dtypes += ['float16']
eps_dict = {'float32': 1E-4, 'float64': 1E-4, 'float16': 1E-3}
class TestBatchDot(HybridBlock):
def __init__(self, transpose_a, transpose_b):
super(TestBatchDot, self).__init__()
self._transpose_a = transpose_a
self._transpose_b = transpose_b
def forward(self, lhs, rhs):
return npx.batch_dot(lhs, rhs,
transpose_a=self._transpose_a,
transpose_b=self._transpose_b)
def batch_dot_numpy(lhs, rhs, transpose_a, transpose_b):
assert lhs.ndim == rhs.ndim >= 3
if transpose_a:
lhs = lhs.swapaxes(-1, -2)
if transpose_b:
rhs = rhs.swapaxes(-1, -2)
return onp.matmul(lhs, rhs)
def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, rhs_req,
init_lhs_grad, init_rhs_grad):
if transpose_a and transpose_b:
# Gradient of z = dot(x.T, y.T)
# dx = dot(dz, y).T = dot(y.T, dz.T)
# dy = dot(x, dz).T = dot(dz.T, x.T)
lhs_grad = batch_dot_numpy(rhs, ograd, transpose_a=True, transpose_b=True)
rhs_grad = batch_dot_numpy(ograd, lhs, transpose_a=True, transpose_b=True)
elif not transpose_a and transpose_b:
# Gradient of z = dot(x, y.T)
# dx = dot(dz, y)
# dy = dot(x.T, dz).T = dot(dz.T, x)
lhs_grad = batch_dot_numpy(ograd, rhs, transpose_a=False, transpose_b=False)
rhs_grad = batch_dot_numpy(ograd, lhs, transpose_a=True, transpose_b=False)
elif transpose_a and not transpose_b:
# Gradient of z = dot(x.T, y)
# dx = dot(dz, y.T).T = dot(y, dz.T)
# dy = dot(x, dz)
lhs_grad = batch_dot_numpy(rhs, ograd, transpose_a=False, transpose_b=True)
rhs_grad = batch_dot_numpy(lhs, ograd, transpose_a=False, transpose_b=False)
else:
# Gradient of z = dot(x, y)
# dx = dot(dz, y.T)
# dy = dot(x.T, dz)
lhs_grad = batch_dot_numpy(ograd, rhs, transpose_a=False, transpose_b=True)
rhs_grad = batch_dot_numpy(lhs, ograd, transpose_a=True, transpose_b=False)
if lhs_req == 'add':
lhs_grad += init_lhs_grad
if rhs_req == 'add':
rhs_grad += init_rhs_grad
return lhs_grad, rhs_grad
configs = [
((2, 3, 0), (2, 4, 0), False, True),
((2, 4, 3), (2, 4, 3), True, False),
((0, 3, 0), (0, 0, 2), False, False),
((3, 2, 3, 2), (3, 2, 2, 3), True, True),
((3, 1, 5, 2), (3, 1, 2, 1), False, False)
]
bad_configs = [
((5, 3, 2), (5, 1, 3), False, False),
((2, 5, 3, 1), (2, 4, 3, 1), True, False)
]
for hybridize in [True, False]:
for lhs_shape, rhs_shape, transpose_a, transpose_b in configs:
for dtype in dtypes:
eps = eps_dict[dtype]
for lhs_grad_req in ['write', 'add']:
for rhs_grad_req in ['write', 'add']:
f_batch_dot = TestBatchDot(transpose_a=transpose_a,
transpose_b=transpose_b)
if hybridize:
f_batch_dot.hybridize()
lhs_val = mx.np.array(onp.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype)
rhs_val = mx.np.array(onp.random.uniform(-1.0, 1.0, rhs_shape), dtype=dtype)
lhs_val.attach_grad(grad_req=lhs_grad_req)
rhs_val.attach_grad(grad_req=rhs_grad_req)
gt_out = batch_dot_numpy(lhs_val.asnumpy(), rhs_val.asnumpy(),
transpose_a, transpose_b)
init_lhs_grad = mx.np.random.uniform(-1.0, 1.0, lhs_shape, dtype=dtype)
init_rhs_grad = mx.np.random.uniform(-1.0, 1.0, rhs_shape, dtype=dtype)
o_grad = mx.np.random.uniform(-1.0, 1.0, gt_out.shape, dtype=dtype)
if lhs_grad_req == 'add':
lhs_val.grad[:] = init_lhs_grad
if rhs_grad_req == 'add':
rhs_val.grad[:] = init_rhs_grad
with mx.autograd.record():
out = f_batch_dot(lhs_val, rhs_val)
out.backward(o_grad)
assert_almost_equal(out.asnumpy(), gt_out, rtol=eps, atol=eps)
gt_lhs_grad, gt_rhs_grad = gt_grad_batch_dot_numpy(lhs_val.asnumpy(),
rhs_val.asnumpy(),
o_grad.asnumpy(),
transpose_a=transpose_a,
transpose_b=transpose_b,
lhs_req=lhs_grad_req,
rhs_req=rhs_grad_req,
init_lhs_grad=init_lhs_grad.asnumpy(),
init_rhs_grad=init_rhs_grad.asnumpy())
assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=eps, atol=eps)
assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=eps, atol=eps)
for lhs_shape, rhs_shape, transpose_a, transpose_b in bad_configs:
for dtype in dtypes:
lhs_val = mx.np.array(onp.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype)
rhs_val = mx.np.array(onp.random.uniform(-1.0, 1.0, rhs_shape), dtype=dtype)
pytest.raises(MXNetError, lambda: mx.npx.batch_dot(lhs_val, rhs_val,
transpose_a=transpose_a,
transpose_b=transpose_b))
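# Standalone numeric sanity check (illustrative only, not invoked by the test
# above) for the gradient identities spelled out in gt_grad_batch_dot_numpy:
# for plain z = dot(x, y), dx = dot(dz, y.T) and dy = dot(x.T, dz).  The toy
# shapes, the seed and the finite-difference step `h` are arbitrary choices;
# `onp` is the plain-NumPy import used throughout this file.
def _check_plain_batch_dot_grad(h=1e-6):
    rng = onp.random.default_rng(0)
    x = rng.standard_normal((2, 3, 4))
    y = rng.standard_normal((2, 4, 5))
    dz = rng.standard_normal((2, 3, 5))
    dx = onp.matmul(dz, y.swapaxes(-1, -2))   # dot(dz, y.T)
    dy = onp.matmul(x.swapaxes(-1, -2), dz)   # dot(x.T, dz)
    # finite-difference probe of d sum(z * dz) / d x[0, 0, 0]
    xp = x.copy()
    xp[0, 0, 0] += h
    num = ((onp.matmul(xp, y) - onp.matmul(x, y)) * dz).sum() / h
    assert onp.isclose(num, dx[0, 0, 0], rtol=1e-3, atol=1e-3)
    return dx, dy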
@use_np
@pytest.mark.parametrize('shape', [(4, 2), (4, 3, 4),
(4, 6, 4, 5), (4, 5, 6, 4, 5)])
@pytest.mark.parametrize('fix_gamma', [False, True])
@pytest.mark.parametrize('cudnn_off', [False, True])
@pytest.mark.parametrize('output_mean_var', [False, True])
@pytest.mark.flaky
def test_npx_batch_norm(shape, fix_gamma, cudnn_off, output_mean_var):
momentum = 0.9
epsilon = 1e-5
class TestBatchNorm(HybridBlock):
def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs):
super().__init__()
self.eps = eps
self.fix_gamma = fix_gamma
self.momentum = momentum
self.kwargs = kwargs
def forward(self, data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var):
op = npx.batch_norm
output = op(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var,
momentum=self.momentum, eps=self.eps,
fix_gamma=self.fix_gamma, **self.kwargs)
return output
def _test_batchnorm_impl(axis,
data_grad_req, gamma_grad_req, beta_grad_req):
kwargs = dict(output_mean_var=output_mean_var)
kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs)
nch = shape[axis]
if not fix_gamma:
bn_gamma = np.random.uniform(size=(nch,))
bn_gamma.attach_grad(grad_req=gamma_grad_req)
else:
bn_gamma = np.ones((nch,))
bn_beta = np.random.uniform(size=(nch,))
bn_beta.attach_grad(grad_req=beta_grad_req)
bn_running_mean = np.zeros(nch)
bn_running_var = np.ones(nch)
running_mean = np.zeros(nch)
running_var = np.ones(nch)
num_iters = 10
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]
expand_shape = tuple(expand_shape)
data = np.random.uniform(size=shape)
data.attach_grad(grad_req=data_grad_req)
adX, adW, adb = 0, 0, 0  # accumulated reference gradients for data, gamma and beta
is_train = data_grad_req != 'null' or \
(not fix_gamma and gamma_grad_req != 'null') or \
beta_grad_req != 'null'
for _ in range(num_iters):
if data_grad_req != 'add':
data = np.random.uniform(size=shape)
data.attach_grad(grad_req=data_grad_req)
ograd = np.random.uniform(size=shape)
with mx.autograd.record():
output = op(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var)
if output_mean_var:
output, output_mean, output_std = output
if is_train:
output.backward(ograd)
mx.nd.waitall()
assert 0 <= axis < data.ndim
reduce_axis = tuple(i for i in range(data.ndim) if i != axis)
assert len(reduce_axis) == data.ndim - 1
data_mean = data.mean(
axis=reduce_axis, keepdims=True)
data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis,
keepdims=True)
target_output = (data - data_mean) / \
np.sqrt(data_var + epsilon) * \
bn_gamma.reshape(expand_shape) + \
bn_beta.reshape(expand_shape)
# squeeze data_mean and data_var
data_mean_flat = data_mean.squeeze()
data_var_flat = data_var.squeeze()
running_mean = running_mean * momentum + \
data_mean_flat * (1 - momentum)
m = onp.prod(shape) / shape[axis]
# cudnn uses m-1 in the denominator of its sample variance calculation, not m
sample_var_adjust = 1.0 if cudnn_off or fix_gamma else m / (m-1)
running_var = running_var * momentum + \
data_var_flat * sample_var_adjust * (1 - momentum)
W = bn_gamma.reshape(expand_shape)
dnx = ograd * W
xsm = data - data_mean
nd = 1.0 / np.sqrt(data_var + epsilon)
nx = xsm * nd
dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True) * (-0.5) * np.power(nd, 3)
dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \
dvar * xsm.mean(axis=reduce_axis, keepdims=True) * 2.0
dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
dW = (ograd * nx).sum(axis=reduce_axis)
db = ograd.sum(axis=reduce_axis)
adX = dX if data_grad_req != 'add' else adX + dX
adW = dW if gamma_grad_req != 'add' else adW + dW
adb = db if beta_grad_req != 'add' else adb + db
atol, rtol = 5e-2, 5e-2
if output_mean_var:
assert_almost_equal(output_mean.asnumpy(), data_mean_flat.asnumpy(),
atol=atol, rtol=rtol)
assert_almost_equal(output_std.asnumpy(),
(1.0 / np.sqrt(data_var_flat + epsilon)).asnumpy(),
atol=atol, rtol=rtol)
assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
atol=atol, rtol=rtol)
if is_train:
assert_almost_equal(bn_running_mean.asnumpy(), running_mean.asnumpy(),
atol=atol, rtol=rtol)
assert_almost_equal(bn_running_var.asnumpy(), running_var.asnumpy(),
atol=atol, rtol=rtol)
if data_grad_req != 'null':
assert_almost_equal(data.grad.asnumpy(),
adX.asnumpy(), atol=atol, rtol=rtol)
if not fix_gamma:
if gamma_grad_req != 'null':
assert_almost_equal(
bn_gamma.grad.asnumpy(), adW.asnumpy(),
atol=atol, rtol=rtol)
else:
assert((bn_gamma.asnumpy() == 1).all())
if beta_grad_req != 'null':
assert_almost_equal(
bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
for data_grad_req in grad_reqs:
for gamma_grad_req in grad_reqs:
if fix_gamma and gamma_grad_req != 'null':
continue
for beta_grad_req in grad_reqs:
for axis in range(len(shape)):
_test_batchnorm_impl(axis,
data_grad_req, gamma_grad_req, beta_grad_req)
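# Compact standalone restatement (illustrative sketch only, not used by the
# test) of the manual BatchNorm backward computed inside _test_batchnorm_impl
# above, written for a plain 2D (N, C) input normalized over axis 0.  The
# helper name is ours and `eps` is an arbitrary default; the dvar/dmean/dX
# expressions are the same ones used in the reference above.
def _batchnorm_backward_sketch(x, gamma, dy, eps=1e-5):
    m = x.shape[0]
    mean = x.mean(axis=0, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=0, keepdims=True)
    nd = 1.0 / onp.sqrt(var + eps)          # inverse standard deviation
    xsm = x - mean
    dnx = dy * gamma                        # gradient w.r.t. the normalized input
    dvar = (dnx * xsm).sum(axis=0, keepdims=True) * (-0.5) * nd ** 3
    dmean = -nd * dnx.sum(axis=0, keepdims=True) - dvar * 2.0 * xsm.mean(axis=0, keepdims=True)
    dx = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
    dgamma = (dy * xsm * nd).sum(axis=0)
    dbeta = dy.sum(axis=0)
    return dx, dgamma, dbeta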
def np_softmax(x, axis=-1):
if x.shape[axis] == 0:
return onp.sum(x, axis=axis, keepdims=True)
x = x - onp.max(x, axis=axis, keepdims=True)
x = onp.exp(x)
x /= onp.sum(x, axis=axis, keepdims=True)
return x
def np_log_softmax(x, axis=-1):
return onp.log(np_softmax(x, axis))
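# Quick illustrative check (our addition, not used by the tests) that the
# max-subtraction inside np_softmax above is purely a numerical-stability
# trick: shifting the input by a constant along `axis` leaves the softmax
# output unchanged, so very large inputs do not overflow onp.exp.
def _softmax_shift_invariance_demo():
    x = onp.array([[1.0, 2.0, 3.0], [1000.0, 1001.0, 1002.0]])
    # both rows have the same relative spacing, hence identical softmax rows
    assert onp.allclose(np_softmax(x)[0], np_softmax(x)[1])
    return np_softmax(x)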
@use_np
def test_npx_softmax():
class TestSoftmax(HybridBlock):
def __init__(self, axis):
super(TestSoftmax, self).__init__()
self._axis = axis
def forward(self, a):
return npx.softmax(a, axis=self._axis)
class TestLogSoftmax(HybridBlock):
def __init__(self, axis):
super(TestLogSoftmax, self).__init__()
self._axis = axis
def forward(self, a):
return npx.log_softmax(a, axis=self._axis)
# (operator, function) tuples
tested_ops = [(TestSoftmax, np_softmax),
(TestLogSoftmax, np_log_softmax)]
# only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
for SoftmaxOp, softmax_function in tested_ops:
for hybridize in [True, False]:
for shape in [(3, 0, 4), (0, 0)]:
mx_a = np.random.uniform(size=shape)
mx_a.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax_op = SoftmaxOp(axis)
if hybridize:
test_softmax_op.hybridize()
with mx.autograd.record():
mx_out = test_softmax_op(mx_a)
mx_out.wait_to_read()
np_out = softmax_function(mx_a.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
mx_out.backward()
mx_a.grad.wait_to_read()
assert_almost_equal(mx_a.grad.asnumpy(), onp.zeros(shape), rtol=1e-3, atol=1e-5)
def np_masked_softmax(data, mask, axis=-1, temperature=1.0):
neg = -1e18  # large negative fill pushed into masked-out positions
if data.dtype == onp.float16:
neg = -1e4  # float16 cannot represent -1e18
temp = onp.where(mask, data, neg)
result = (np_softmax(temp, axis=axis) / temperature) * mask
return result
def np_masked_log_softmax(data, mask, axis=-1, temperature=1.0):
neg = -1e18
if data.dtype == onp.float16:
neg = -1e4
data = onp.where(mask, data, neg)
return onp.where(mask, np_log_softmax(data, axis=axis) / temperature, -onp.inf)
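# Illustrative example (our addition, not part of the test) of what
# np_masked_softmax above computes: masked-out positions are replaced by a
# very large negative value before the softmax, so they receive zero weight
# and the surviving entries renormalize among themselves.
def _masked_softmax_demo():
    data = onp.array([[1.0, 2.0, 3.0]])
    mask = onp.array([[1, 0, 1]], dtype=bool)
    out = np_masked_softmax(data, mask)
    assert onp.isclose(out[0, 1], 0.0)             # masked position gets no weight
    assert onp.isclose(out[0, [0, 2]].sum(), 1.0)  # the rest still sums to one
    return out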
@use_np
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('shape', [(3, 0, 4), (0, 0)])
def test_npx_masked_softmax(hybridize, shape):
class TestMaskedSoftmax(HybridBlock):
def __init__(self, axis):
super(TestMaskedSoftmax, self).__init__()
self._axis = axis
def forward(self, a, mask):
return npx.masked_softmax(a, mask, axis=self._axis)
class TestMaskedLogSoftmax(HybridBlock):
def __init__(self, axis):
super(TestMaskedLogSoftmax, self).__init__()
self._axis = axis
def forward(self, a, mask):
return npx.masked_log_softmax(a, mask, axis=self._axis)
# (operator, function) tuples
tested_ops = [(TestMaskedSoftmax, np_masked_softmax),
(TestMaskedLogSoftmax, np_masked_log_softmax)]
# only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
for SoftmaxOp, softmax_function in tested_ops:
mx_a = np.random.uniform(size=shape)
mask = np.random.randint(0, 2, shape)
mx_a.attach_grad()
mask.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax_op = SoftmaxOp(axis)
if hybridize:
test_softmax_op.hybridize()
with mx.autograd.record():
mx_out = test_softmax_op(mx_a, mask)
mx_out.wait_to_read()
np_out = softmax_function(mx_a.asnumpy(), mask.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
@use_np
def test_npi_boolean_assign():
class TestBooleanAssignScalar(HybridBlock):
def __init__(self, val, start_axis):
super(TestBooleanAssignScalar, self).__init__()
self._val = val
self._start_axis = start_axis
def forward(self, a, mask):
return _npi.boolean_mask_assign_scalar(a, mask, self._val, start_axis=self._start_axis, out=a)
class TestBooleanAssignTensor(HybridBlock):
def __init__(self, start_axis):
super(TestBooleanAssignTensor, self).__init__()
self._start_axis = start_axis
def forward(self, a, mask, value):
return _npi.boolean_mask_assign_tensor(a, mask, value, start_axis=self._start_axis, out=a)
configs = [
((3, 4), (3, 4), 0),
((3, 0), (3, 0), 0),
((), (), 0),
((2, 3, 4, 5), (2, 3), 0),
((2, 3, 4, 5), (3, 4), 1),
((2, 3, 4, 5), (4, 5), 2),
]
for hybridize in [False]:
for config in configs:
dshape, mshape, start_axis = config
test_data = np.random.uniform(size=dshape)
valid_num = 0
while valid_num == 0:
mx_mask = np.random.choice(np.array([False, True], dtype=np.bool), size=mshape)
if test_data.size == 0:
break
valid_num = int(mx_mask.asnumpy().sum())
np_mask = mx_mask.asnumpy().astype(bool)
vshape = []
vshape_broadcast = []
for i in range(len(dshape)):
if i < start_axis:
vshape.append(dshape[i])
vshape_broadcast.append(dshape[i])
elif i == start_axis:
vshape.append(valid_num)
vshape_broadcast.append(1)
elif i >= start_axis + len(mshape):
vshape.append(dshape[i])
vshape_broadcast.append(dshape[i])
vshape_broadcast = tuple(vshape_broadcast)
for val in [42.0, onp.array(42.), onp.array([42.]), onp.random.uniform(size=vshape), onp.random.uniform(size=vshape_broadcast)]:
mx_val = val if isinstance(val, float) else np.array(val, dtype=np.float32)
test_block = TestBooleanAssignScalar(val, start_axis) if isinstance(val, float) else TestBooleanAssignTensor(start_axis)
if hybridize:
test_block.hybridize()
np_data = test_data.asnumpy()
mx_data1 = test_data.copy()
mx_data2 = test_data.copy()
trailing_axis = len(np_data.shape) - len(np_mask.shape) - start_axis
if start_axis == 0:
if trailing_axis == 0:
np_data[np_mask] = val
mx_data1[mx_mask] = mx_val
elif trailing_axis == 1:
np_data[np_mask, :] = val
mx_data1[mx_mask, :] = mx_val
elif trailing_axis == 2:
np_data[np_mask, :, :] = val
mx_data1[mx_mask, :, :] = mx_val
elif start_axis == 1:
if trailing_axis == 0:
np_data[:, np_mask] = val
mx_data1[:, mx_mask] = mx_val
elif trailing_axis == 1:
np_data[:, np_mask, :] = val
mx_data1[:, mx_mask, :] = mx_val
elif start_axis == 2:
if trailing_axis == 0:
np_data[:, :, np_mask] = val
mx_data1[:, :, mx_mask] = mx_val
mx_data1 = test_block(mx_data2, mx_mask) if isinstance(val, float) else test_block(mx_data2, mx_mask, mx_val)
assert_almost_equal(mx_data1.asnumpy(), np_data, rtol=1e-3, atol=1e-5, use_broadcast=False)
assert_almost_equal(mx_data2.asnumpy(), np_data, rtol=1e-3, atol=1e-5, use_broadcast=False)
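# The start_axis/trailing_axis ladder above enumerates, case by case, the
# plain-NumPy indexing that _npi.boolean_mask_assign_* is compared against.
# A generic equivalent (illustrative sketch only, helper name is ours): put
# the boolean mask after `start_axis` full slices and let NumPy treat the
# remaining trailing axes as implicit full slices.
def _boolean_assign_reference(data, mask, value, start_axis):
    out = data.copy()
    out[(slice(None),) * start_axis + (mask,)] = value
    return out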
@use_np
def test_np_reshape():
class TestReshape(HybridBlock):
def __init__(self, newshape):
super(TestReshape, self).__init__()
self._newshape = newshape
def forward(self, a):
return np.reshape(a, self._newshape)
shape_pairs = [((2, 6), (6, 2)), ((2, 6), (3, 4)), ((1, 0), (0,)), ((0, 0), (0,)), ((), (1, 1, 1))]
for hybridize in [True, False]:
for shape_pair in shape_pairs:
shape1, shape2 = shape_pair
test_reshape = TestReshape(shape2)
if hybridize:
test_reshape.hybridize()
x = rand_ndarray(shape1).as_np_ndarray()
x.attach_grad()
np_out = onp.reshape(x.asnumpy(), shape2)
with mx.autograd.record():
mx_out = test_reshape(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = onp.ones(shape1)
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out = np.reshape(x, shape2)
np_out = onp.reshape(x.asnumpy(), shape2)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
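# Why the expected gradient above is onp.ones(shape1): reshape only changes
# the view of the data, so its backward is simply the head gradient reshaped
# back to the input shape, and backward() without an argument uses a head
# gradient of ones.  Illustrative reference (our helper name):
def _reshape_backward_reference(ograd, input_shape):
    return onp.reshape(ograd, input_shape)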
@use_np
@pytest.mark.parametrize('descending', [True, False])
@pytest.mark.parametrize('shape', [
(),
(2, 3),
(1, 0, 2),
])
@pytest.mark.parametrize('hybrid', [False, True])
def test_np_argsort(descending, shape, hybrid):
class TestArgsort(HybridBlock):
def __init__(self, axis, descending):
super(TestArgsort, self).__init__()
self._axis = axis
self._descending = descending
def forward(self, x):
return np.argsort(x, axis=self._axis, descending=self._descending)
data = np.random.uniform(size=shape)
np_data = data.asnumpy()
for axis in [None] + [i for i in range(-len(shape), len(shape))]:
if descending:
np_out = onp.argsort(-1 * np_data, axis)
else:
np_out = onp.argsort(np_data, axis)
test_argsort = TestArgsort(axis, descending)
if hybrid:
test_argsort.hybridize()
mx_out = test_argsort(data)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)
mx_out = np.argsort(data, axis, descending)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)
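# Tiny illustrative check (our addition) that the reference used above for
# descending=True, onp.argsort(-x), indeed yields indices that order the
# original values from largest to smallest.
def _descending_argsort_demo():
    x = onp.array([0.3, 0.1, 0.7, 0.5])
    idx = onp.argsort(-x)
    assert (onp.diff(x[idx]) <= 0).all()   # values are non-increasing
    return idx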
@use_np
@pytest.mark.parametrize('descending', [True, False])
@pytest.mark.parametrize('shape', [
(),
(1,),
(5,),
(4, 3),
(3, 5),
(4, 4),
(4, 5),
(5, 5),
(5, 6),
(6, 6),
(0, 1),
(6, 5, 6),
(2, 3, 3, 4),
(4, 2, 1, 2),
(0, 5, 3, 3),
(5, 0, 3, 3),
(3, 3, 0, 0),
])
@pytest.mark.parametrize('dtype', [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64])
@pytest.mark.parametrize('hybridize', [True, False])
def test_np_sort(shape, dtype, hybridize, descending):
class TestSort(HybridBlock):
def __init__(self,