blob: f0e0e0977a581670c9322f1ce644af45dfb652be [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import print_function
from __future__ import division
import numpy as np
import mxnet as mx
import copy
import math
import random
import itertools
from distutils.version import LooseVersion
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.operator import *
from mxnet.base import py_str, MXNetError, _as_list
from common import assert_raises_cudnn_not_satisfied, assert_raises_cuda_not_satisfied, assertRaises
from common import xfail_when_nonstandard_decimal_separator, with_environment
import pytest
import os
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@pytest.mark.serial
def test_rnn_with_new_param():
    """Run the fused RNN op in every mode, uni- and bidirectional.

    For each configuration the training-mode and inference-mode outputs are
    compared, first with the initial parameters and again after a fresh
    parameter vector has been copied into the already-bound executor.
    """
    num_layers, input_size, seq_len, batch_size, state_size = 3, 128, 5, 64, 8
    # (mode, gate count); the gate count multiplies every weight/bias block.
    mode_gates = [('rnn_relu', 1), ('rnn_tanh', 1), ('gru', 3), ('lstm', 4)]

    def assert_train_matches_inference(executor):
        """Forward once per mode and require both outputs to agree."""
        executor.forward(is_train=True)
        train_out = executor.output_dict['rnn_output'].asnumpy()
        executor.forward(is_train=False)
        infer_out = executor.output_dict['rnn_output'].asnumpy()
        assert_allclose(train_out, infer_out, rtol=1e-2, atol=1e-4)

    def random_parameters(size):
        return mx.ndarray.random.uniform(low=-1, high=1, shape=size)

    for bidirectional in (False, True):
        directions = 2 if bidirectional else 1
        state_shape = (num_layers * directions, batch_size, state_size)
        for mode, ngates in mode_gates:
            # Recurrent weights plus the two bias vectors, per gate.
            recurrent_and_bias = state_size * state_size + state_size * 2
            first_layer_size = (input_size * state_size + recurrent_and_bias) * ngates
            rest_layer_size = ((state_size * directions * state_size + recurrent_and_bias)
                               * ngates * (num_layers - 1))
            param_size = (first_layer_size + rest_layer_size) * directions

            sym = mx.sym.RNN(mode=mode, num_layers=num_layers, bidirectional=bidirectional,
                             state_outputs=False, state_size=state_size, name='rnn')

            bind_dict = {}
            bind_dict['rnn_data'] = mx.ndarray.random.uniform(
                low=-1, high=1, shape=(seq_len, batch_size, input_size))
            bind_dict['rnn_parameters'] = random_parameters(param_size)
            bind_dict['rnn_state'] = mx.ndarray.zeros(shape=state_shape)
            if mode == 'lstm':
                # Only the LSTM carries a separate cell state.
                bind_dict['rnn_state_cell'] = mx.ndarray.zeros(shape=state_shape)

            ex = sym._bind(default_device(), bind_dict)
            assert_train_matches_inference(ex)

            # Swap in new parameters and repeat the comparison.
            bind_dict['rnn_parameters'] = random_parameters(param_size)
            ex.copy_params_from(bind_dict)
            assert_train_matches_inference(ex)
@pytest.mark.serial
def test_lstm_dropout():
    """Smoke-test a 5-layer LSTM with dropout p=0.5 on CPU in training mode."""
    seq_len, batch, input_dim, hidden = 300, 20, 800, 800
    inputs = {
        'data': mx.sym.Variable('x'),
        'parameters': mx.sym.Variable('params'),
        'state': mx.sym.Variable('state'),
        'state_cell': mx.sym.Variable('state_cell'),
    }
    net = mx.sym.RNN(state_size=hidden, num_layers=5, mode='lstm', p=0.5,
                     state_outputs=True, name='LSTM', **inputs)
    executor = net._simple_bind(ctx=mx.cpu(), x=(seq_len, batch, input_dim))
    outputs = executor.forward(is_train=True)
    # Wait for the first output to be readable before the test returns.
    outputs[0].wait_to_read()
@pytest.mark.serial
def test_gru_dropout():
    """Smoke-test a 5-layer GRU with dropout p=0.5 on CPU in training mode."""
    seq_len, batch, input_dim, hidden = 300, 20, 800, 800
    inputs = {
        'data': mx.sym.Variable('x'),
        'parameters': mx.sym.Variable('params'),
        'state': mx.sym.Variable('state'),
    }
    net = mx.sym.RNN(state_size=hidden, num_layers=5, mode='gru', p=0.5,
                     state_outputs=True, name='GRU', **inputs)
    executor = net._simple_bind(ctx=mx.cpu(), x=(seq_len, batch, input_dim))
    outputs = executor.forward(is_train=True)
    # Wait for the first output to be readable before the test returns.
    outputs[0].wait_to_read()
@pytest.mark.serial
def test_rnntanh_dropout():
    """Smoke-test a 5-layer tanh RNN with dropout p=0.5 on CPU in training mode."""
    seq_len, batch, input_dim, hidden = 300, 20, 800, 800
    inputs = {
        'data': mx.sym.Variable('x'),
        'parameters': mx.sym.Variable('params'),
        'state': mx.sym.Variable('state'),
    }
    net = mx.sym.RNN(state_size=hidden, num_layers=5, mode='rnn_tanh', p=0.5,
                     state_outputs=True, name='RNN_TANH', **inputs)
    executor = net._simple_bind(ctx=mx.cpu(), x=(seq_len, batch, input_dim))
    outputs = executor.forward(is_train=True)
    # Wait for the first output to be readable before the test returns.
    outputs[0].wait_to_read()
@pytest.mark.serial
def test_rnnrelu_dropout():
    """Smoke-test a 5-layer ReLU RNN with dropout p=0.5 on CPU in training mode."""
    seq_len, batch, input_dim, hidden = 300, 20, 800, 800
    inputs = {
        'data': mx.sym.Variable('x'),
        'parameters': mx.sym.Variable('params'),
        'state': mx.sym.Variable('state'),
    }
    net = mx.sym.RNN(state_size=hidden, num_layers=5, mode='rnn_relu', p=0.5,
                     state_outputs=True, name='RNN_RELU', **inputs)
    executor = net._simple_bind(ctx=mx.cpu(), x=(seq_len, batch, input_dim))
    outputs = executor.forward(is_train=True)
    # Wait for the first output to be readable before the test returns.
    outputs[0].wait_to_read()
def test_RNN_float64():
    """Bind and run a single-layer tanh RNN on float64 inputs.

    Returns immediately, without checking anything, when the default device
    is a GPU.
    """
    if default_device().device_type == 'gpu':
        return

    dtype = 'float64'
    named_shapes = (('in', [2, 1, 2]), ('par', [12]), ('s', [1, 1, 2]))

    def ones_by_name():
        """Fresh all-ones float64 arrays keyed by argument name."""
        return {name: mx.nd.ones(shape, dtype=dtype) for name, shape in named_shapes}

    sym = mx.sym.RNN(
        mx.sym.Variable('in'),
        mx.sym.Variable('par'),
        mx.sym.Variable('s'),
        state_size=2,
        num_layers=1,
        mode='rnn_tanh',
    )
    grad_buffers = ones_by_name()
    ex = sym._bind(default_device(), ones_by_name(),
                   args_grad=grad_buffers, grad_req='write')
    ex.forward()
    ex.outputs[0].wait_to_read()
def np_softmax(x, axis=-1, temperature=1.0):
    """Reference softmax of `x` along `axis`, sharpened or flattened by `temperature`.

    The per-axis maximum is subtracted before exponentiation so the largest
    exponent is zero; the input array itself is not modified.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted / temperature)
    weights /= np.sum(weights, axis=axis, keepdims=True)
    return weights
def check_elementwise_sum_with_shape(shape, n):
    """Verify ElementWiseSum forward and backward for `n` random inputs of `shape`."""
    variables = [mx.symbol.Variable('arg%d' % idx) for idx in range(n)]
    esum = mx.symbol.ElementWiseSum(*variables, name='esum')
    operands = [mx.nd.empty(shape) for _ in range(n)]
    operand_grads = [mx.nd.empty(shape) for _ in range(n)]
    for operand in operands:
        operand[:] = np.random.uniform(-10, 10, shape)
    executor = esum._bind(default_device(),
                          args=operands,
                          args_grad=operand_grads)

    # forward: the output must equal the plain sum of all inputs
    executor.forward(is_train=True)
    result = executor.outputs[0]
    expected = sum(operand.asnumpy() for operand in operands)
    assert_almost_equal(expected, result, rtol=1e-5, atol=1e-5)

    # backward: every input receives the head gradient unchanged
    head_grad = mx.nd.empty(shape)
    head_grad[:] = np.random.uniform(-10, 10, shape)
    executor.backward([head_grad])
    for operand_grad in operand_grads:
        assert_almost_equal(operand_grad, head_grad, rtol=1e-5, atol=1e-5)
@pytest.mark.serial
def test_elementwise_sum():
    """Run the ElementWiseSum check on random 1-D to 3-D shapes with 1 to 7 inputs."""
    nrepeat, maxdim = 2, 4
    for _, ndim in itertools.product(range(nrepeat), range(1, maxdim)):
        # Each side is drawn below 1000**(1/ndim), which bounds the element count.
        side_limit = int(1000 ** (1.0 / ndim))
        shape = tuple(np.random.randint(1, side_limit, size=ndim))
        num_inputs = np.random.randint(1, 8)
        check_elementwise_sum_with_shape(shape, num_inputs)
def check_concat_with_shape(shapes, dimension, skip_second):
    """Check Concat forward and backward for inputs with the given shapes.

    When `skip_second` is True no gradient buffer is bound for 'arg1'
    (this exercises issue #1130).
    """
    num_inputs = len(shapes)
    # Total extent along the concat axis (kept from the original; not used below).
    target_dim = sum(shape[dimension] for shape in shapes)

    variables = [mx.symbol.Variable('arg%d' % idx) for idx in range(num_inputs)]
    concat = mx.symbol.Concat(*variables, name='conc', dim=dimension)

    # Every input is filled with its own extent along the concat axis.
    inputs = []
    for shape in shapes:
        filled = mx.nd.empty(shape)
        filled[:] = shape[dimension]
        inputs.append(filled)
    inputs_np = [np.copy(nd_in.asnumpy()) for nd_in in inputs]

    arg_names = concat.list_arguments()
    grad_buffers = [mx.nd.empty(shape) for shape in shapes]
    bound_grads = {name: buf for name, buf in zip(arg_names, grad_buffers)
                   if not (skip_second and name == 'arg1')}

    _, out_shapes, _ = concat.infer_shape(**dict(zip(arg_names, shapes)))
    head_grad = mx.nd.empty(out_shapes[0])
    executor = concat._bind(default_device(),
                            args=inputs,
                            args_grad=bound_grads)

    # forward
    executor.forward(is_train=True)
    result = executor.outputs[0]
    expected = np.concatenate([nd_in.asnumpy() for nd_in in inputs], axis=dimension)
    assert_almost_equal(result, expected)

    # backward: the head gradient is output + 1, so each input gradient is input + 1
    result.copyto(head_grad)
    head_grad[:] += 1
    executor.backward([head_grad])
    for idx, name in enumerate(arg_names):
        if skip_second and name == 'arg1':
            continue
        assert_almost_equal(bound_grads[name], inputs_np[idx] + 1)
def test_concat():
    """Exercise Concat on 2-D, 3-D and 4-D inputs along each valid axis.

    Every configuration is checked with and without a gradient buffer for the
    second input, once with the non-negative axis index and once with the
    equivalent negative index.
    """
    # NOTE(review): the nesting below was reconstructed from a listing without
    # indentation; the check calls are assumed to sit inside the innermost
    # loop (run after every append) -- confirm against upstream.
    for dimension in range(4):
        n = 2
        # Candidate extents along the concat axis; a, b, c are the fixed extents
        # of the remaining axes.
        merge = [2, 3, 4, 5, 6]
        a = 2
        b = 3
        c = 4
        # test 2D
        if dimension<2:
            for dim in range(2, 6):
                shapes = []
                for i in range(dim):
                    if dimension == 0:
                        shapes.append((merge[i], a))
                    elif dimension == 1:
                        shapes.append((a, merge[i]))
                    check_concat_with_shape(shapes,dimension,True)
                    check_concat_with_shape(shapes,dimension,False)
                    # Test negative dim
                    check_concat_with_shape(shapes, dimension - 2, True)
                    check_concat_with_shape(shapes, dimension - 2, False)
        #test 3D
        if dimension<3:
            for dim in range(2, 6):
                shapes = []
                for i in range(dim):
                    if dimension == 0:
                        shapes.append((merge[i], a,b))
                    elif dimension ==1:
                        shapes.append((a,merge[i],b))
                    elif dimension ==2:
                        shapes.append((a,b,merge[i]))
                    check_concat_with_shape(shapes,dimension,True)
                    check_concat_with_shape(shapes,dimension,False)
                    # Test negative dim
                    check_concat_with_shape(shapes, dimension - 3, True)
                    check_concat_with_shape(shapes, dimension - 3, False)
        # test 4D
        for dim in range(2, 6):
            shapes = []
            for i in range(dim):
                if dimension == 0:
                    shapes.append((merge[i],a,b,c))
                elif dimension == 1:
                    shapes.append((a,merge[i],b,c))
                elif dimension ==2:
                    shapes.append((a,b,merge[i],c))
                elif dimension ==3:
                    shapes.append((a,b,c,merge[i]))
                check_concat_with_shape(shapes,dimension,True)
                check_concat_with_shape(shapes,dimension,False)
                # Test negative dim
                check_concat_with_shape(shapes, dimension - 4, True)
                check_concat_with_shape(shapes, dimension - 4, False)
def test_slice_channel():
    """Check SliceChannel forward and backward against numpy slicing."""
    def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis):
        """Slice a random `data_ndim`-D array into `num_outputs` parts along `axis`.

        With `squeeze_axis` the sliced axis has extent `num_outputs` and is
        removed from every output; otherwise its extent is a random multiple
        of `num_outputs` and each output keeps an equal share of it.
        """
        if squeeze_axis:
            shape = np.random.randint(2, 5, data_ndim).tolist()
            shape[axis] = num_outputs
            out_ele_shape = [ele for ele in shape]
            del out_ele_shape[axis]
        else:
            shape = np.random.randint(1, 5, data_ndim).tolist()
            shape[axis] *= num_outputs
            out_ele_shape = [ele for ele in shape]
            out_ele_shape[axis] //= num_outputs
        data_npy = np.random.normal(size=shape)
        out_grads_npy = [np.random.normal(size=out_ele_shape) for i in range(num_outputs)]
        data = mx.sym.Variable('data')
        sym = mx.sym.SliceChannel(data=data, num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
        exe = sym._simple_bind(ctx=default_device(), data=data_npy.shape)
        outputs = exe.forward(is_train=True, data=data_npy)
        assert len(exe.outputs) == num_outputs
        # shape[axis] is a multiple of num_outputs in both branches above, so
        # integer division is exact; this avoids the removed `np.int` alias.
        step = shape[axis] // num_outputs
        for i in range(num_outputs):
            gt = data_npy.take(np.arange(i * step, (i + 1) * step), axis=axis)
            if squeeze_axis:
                assert_almost_equal(outputs[i], gt.reshape(outputs[i].shape))
            else:
                assert_almost_equal(outputs[i], gt)
        # test backward
        ograd = [mx.nd.array(ele, dtype=outputs[i].dtype) for i, ele in enumerate(out_grads_npy)]
        exe.backward(out_grads=ograd)
        if squeeze_axis:
            # The squeezed axis has to be restored before the pieces are joined.
            assert_almost_equal(exe.grad_arrays[0],
                                np.concatenate([np.expand_dims(ele, axis=axis) for ele in out_grads_npy],
                                               axis=axis))
        else:
            assert_almost_equal(exe.grad_arrays[0],
                                np.concatenate(out_grads_npy, axis=axis))
    check_slice_channel(data_ndim=2, axis=1, num_outputs=3, squeeze_axis=True)
    check_slice_channel(data_ndim=4, axis=2, num_outputs=3, squeeze_axis=False)
    check_slice_channel(data_ndim=3, axis=-1, num_outputs=2, squeeze_axis=False)
    check_slice_channel(data_ndim=5, axis=-2, num_outputs=3, squeeze_axis=True)
def test_python_op():
    """Check that a default NumpyOp returns its input and passes the head gradient through."""
    length = 10
    sym = mx.operator.NumpyOp().get_symbol(mx.symbol.Variable('X'), name='numpy_op')
    inp = mx.ndarray.ones(length) * 10
    inp_grad = mx.ndarray.zeros(length)
    head_grad = mx.ndarray.ones(length)

    executor = sym._bind(default_device(), args=[inp], args_grad={'X': inp_grad})
    executor.forward(is_train=True)
    assert_almost_equal(inp, executor.outputs[0])
    executor.backward(head_grad)
    assert_almost_equal(head_grad, inp_grad)
def test_swapaxes():
    """Compare SwapAxis against np.swapaxes, chained and for single swaps."""
    data = mx.symbol.Variable('data')

    # Two chained swaps on a fixed (2, 3, 4) input.
    base_np = np.ones((2, 3, 4))
    base_np[0] = 1
    base_np[1] = 2
    base_nd = mx.nd.array(base_np)
    first_swap = mx.symbol.SwapAxis(data=data, dim1=0, dim2=2)
    chained = mx.symbol.SwapAxis(data=first_swap, dim1=1, dim2=2)
    executor = chained._bind(default_device(), args=[base_nd])
    executor.forward(is_train=True)
    expected = np.swapaxes(np.swapaxes(base_np, 0, 2), 1, 2)
    assert_almost_equal(executor.outputs[0], expected)

    # Single swaps, including negative axes and swapping an axis with itself.
    cases = [
        ((1, 1, 2), 0, 1),
        ((1, 1, 2), -1, -2),
        ((4, 5, 6, 7), 1, 1),
        ((4, 5, 6, 7), 2, 3),
        ((4, 5, 6, 7), -2, 2),
        ((4, 5, 6, 7), -2, -3),
    ]
    for case_shape, first_axis, second_axis in cases:
        sample_np = np.random.uniform(size=case_shape)
        sample_nd = mx.nd.array(sample_np, dtype=sample_np.dtype)
        expected = np.swapaxes(sample_np, axis1=first_axis, axis2=second_axis)
        swapped = mx.symbol.SwapAxis(data, dim1=first_axis, dim2=second_axis)
        executor = swapped._bind(default_device(), args=[sample_nd])
        executor.forward(is_train=True)
        assert_almost_equal(executor.outputs[0], expected)
@xfail_when_nonstandard_decimal_separator
def test_scalarop():
    """Check forward and backward of an expression mixing a symbol with scalars.

    The expression is ``2 / (4 - (1 + data + 1) * 2 / 5 - 0.8 - (data != 0))``.
    """
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    data_tmp = np.ones(shape)*5

    test = 2 / (4-((1+data+1)*2/5)-0.8-(data!=0))

    # Denominator of the expression, evaluated with numpy.
    npout_1 = (4-((1+data_tmp+1)*2/5)-0.8-(data_tmp!=0))
    npout = 2/npout_1

    check_symbolic_forward(test, [data_tmp], [npout])

    # The head gradient is 2 and the denominator has slope -2/5 (the comparison
    # term contributes no gradient), so d(2/denom) = 2 * (2 * 2/5) / denom**2.
    npout_grad = 2.*2/5
    npout_grad = 2*npout_grad /(npout_1 *npout_1 )

    check_symbolic_backward(test, [data_tmp], [np.ones(shape)*2], [npout_grad])
def test_scalar_pow():
    """Check value and gradient of ``data ** 2`` on a (1, 1) array of ones."""
    shape = (1, 1)
    base_np = np.ones(shape)
    squared = mx.symbol.Variable('data') ** 2
    check_numeric_gradient(squared, [base_np])
    check_symbolic_forward(squared, [base_np], [base_np ** 2])
    # d/dx x**2 = 2x
    check_symbolic_backward(squared, [base_np], [np.ones(shape)], [2 * base_np])
def test_symbol_pow():
    """Check value and both partial derivatives of ``data ** exp`` at 2 ** 3."""
    shape = (1, 1)
    base_sym = mx.symbol.Variable('data')
    base_np = np.ones(shape) * 2
    exponent_sym = mx.symbol.Variable('exp')
    exponent_np = np.ones(shape) * 3

    power = base_sym ** exponent_sym

    check_numeric_gradient(power, [base_np, exponent_np])
    check_symbolic_forward(power, [base_np, exponent_np], [base_np ** exponent_np])

    # d/dx x**y = y * x**(y - 1);  d/dy x**y = x**y * ln(x)
    grad_wrt_base = base_np ** (exponent_np - 1) * exponent_np
    grad_wrt_exponent = base_np ** (exponent_np) * np.log(base_np)
    check_symbolic_backward(power, [base_np, exponent_np], [np.ones(shape)],
                            [grad_wrt_base, grad_wrt_exponent])
def test_fully_connected():
    """Check FullyConnected on a 4-D input that is flattened to (5, 325)."""
    def random_data(shape, dtype=np.float32):
        """Draw uniform samples from [-0.5, 0.5), i.e. centered on 0.0."""
        return mx.nd.random.uniform(low=-0.5,
                                    high=0.5, shape=shape, dtype=dtype)

    fc = mx.sym.FullyConnected(data=mx.sym.var("data"), weight=mx.sym.var("weight"),
                               bias=mx.sym.var("bias"), num_hidden=10, no_bias=False, name='fc')

    data_nd = random_data(shape=(5, 5, 5, 13))
    weight_nd = random_data(shape=(10, 325))
    bias_nd = random_data(shape=(10))
    # Only referenced by the disabled check at the bottom.
    fc_bias2 = random_data(shape=(10, 1))

    data_np = data_nd.asnumpy().reshape(5, 325)
    weight_np = weight_nd.asnumpy()
    fc_bias_np = bias_nd.asnumpy()
    # Reference result: x . W^T + b
    res = np.dot(data_np, np.transpose(weight_np)) + fc_bias_np

    check_symbolic_forward(fc, {'data': data_np, 'weight': weight_np, 'bias': fc_bias_np},
                           {'fc_output': res})
    check_numeric_gradient(fc, {'data': data_np, 'weight': weight_np, 'bias': fc_bias_np})
    # TODO: Fix Bug #15032 when bias has ndim > 1
    #check_symbolic_forward(fc, {'data': data_np, 'weight': fc_weight.asnumpy(), 'bias': fc_bias2.asnumpy()}, {'fc_output': res})
def test_pow_fn():
    """Check ``pow(2, exp)`` and ``power(2, exp)`` with a scalar base."""
    shape = (3, 4)
    exponent = mx.symbol.Variable("exp")
    exponent_np = np.ones(shape) * 3
    candidates = [mx.sym.pow(2, exponent), mx.sym.power(2, exponent)]
    for sym in candidates:
        check_numeric_gradient(sym, [exponent_np], numeric_eps=1E-3)
        check_symbolic_forward(sym, [exponent_np], [2 ** exponent_np])
        # d/dx 2**x = ln(2) * 2**x
        check_symbolic_backward(sym, [exponent_np], [np.ones(shape)],
                                [np.log(2) * 2 ** exponent_np])
def test_relu():
    """Check relu forward, analytic gradient and numeric gradient on float32 data."""
    shape = (3, 4)
    eps = 1e-4
    relu = mx.sym.relu(mx.symbol.Variable("x"))
    inp = np.random.uniform(low=-1.0, high=1.0, size=shape).astype('float32')
    # Avoid finite difference method inaccuracies due to discontinuous gradient at the origin.
    # Here we replace small problematic inputs with 1.0. Repro issue with seed 97264195.
    inp[abs(inp) < eps] = 1.0

    expected_out = np.maximum(inp, 0.0)
    # Gradient is 1 where the input is positive and 0 elsewhere.
    expected_grad = np.float32(1.0) * (inp > np.float32(0.0))

    check_numeric_gradient(relu, [inp], numeric_eps=eps)
    check_symbolic_forward(relu, [inp], [expected_out])
    check_symbolic_backward(relu, [inp], [np.ones(shape)], [expected_grad])
# NOTE(haojin2): Skipping the numeric check tests for float16 data type due to precision issues,
# the analytical checks are still performed on each and every data type to verify the correctness.
def test_leaky_relu():
    """Check LeakyReLU in 'elu' and 'leaky' modes for 1-D to 3-D inputs and three dtypes."""
    def reference_forward(inp, act_type, slope=0.25):
        """Numpy forward pass; only negative entries are transformed."""
        negative = inp < 0
        result = inp.copy()
        if act_type == 'elu':
            result[negative] = slope * np.expm1(result[negative])
        elif act_type == 'leaky':
            result[negative] = slope * result[negative]
        return result

    def reference_backward(grad, inp, out, act_type, slope=0.25):
        """Numpy gradient; positive entries have slope 1."""
        negative = inp < 0
        local_grad = np.ones(inp.shape)
        if act_type == 'elu':
            # For x < 0, out = slope * expm1(x), so d(out)/dx = out + slope.
            local_grad[negative] = out[negative] + slope
        elif act_type == 'leaky':
            local_grad[negative] = slope
        return local_grad * grad

    eps, rtol, atol = 1e-4, 1e-2, 1e-3
    slp = 0.25
    for ndim in range(1, 4):
        shape = rand_shape_nd(ndim)
        x = mx.symbol.Variable("x")
        for dtype in [np.float16, np.float32, np.float64]:
            xa = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
            # Move inputs away from the kink at zero.
            xa[abs(xa) < eps] = 1.0
            for act_type in ['elu', 'leaky']:
                y = mx.symbol.LeakyReLU(data=x, slope=slp, act_type=act_type)
                ya = reference_forward(xa, slope=slp, act_type=act_type)
                ga = reference_backward(np.ones(shape), xa, ya, slope=slp, act_type=act_type)
                # Skip numeric check for float16 type to get rid of flaky behavior
                if dtype is not np.float16:
                    check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_backward(y, [xa], [np.ones(shape, dtype=dtype)], [ga],
                                        rtol=rtol, atol=atol, dtype=dtype)
# NOTE(haojin2): Skipping the numeric check tests for float16 data type due to precision issues,
# the analytical checks are still performed on each and every data type to verify the correctness.
def test_prelu():
    """Check LeakyReLU with act_type='prelu' for a 1-D gamma and a stacked 2-D gamma.

    Runs on a 2-D and a 4-D input in float16, float32 and float64; the numeric
    gradient check is skipped for float16.
    """
    def fprelu(x, gamma):
        """Numpy forward pass: non-positive entries are scaled by gamma."""
        pos_indices = x > 0
        out = x.copy()
        if len(x.shape) == 4:
            # Move the leading two axes last so gamma broadcasts against them,
            # then restore the original axis order.
            out = out.transpose(2,3,0,1)
            out = np.multiply(out, gamma)
            out = out.transpose(2,3,0,1)
        else:
            out = np.multiply(out, gamma)
        # Positive entries pass through unchanged.
        out[pos_indices] = x[pos_indices]
        return out
    def fprelu_grad(x, y, gamma):
        """Numpy gradients w.r.t. x and gamma for an all-ones head gradient."""
        pos_indices = x > 0
        if len(x.shape) == 4:
            grad_x = np.multiply(np.ones(x.shape).transpose(2,3,0,1), gamma)
            grad_x = grad_x.transpose(2,3,0,1)
        else:
            grad_x = np.multiply(np.ones(x.shape), gamma)
        grad_gam = np.zeros(gamma.shape)
        # copy_x keeps only the non-positive inputs; those are the entries that
        # contribute to the gamma gradient.
        copy_x = x.copy()
        copy_x[pos_indices] = 0.0
        grad_x[pos_indices] = 1.0
        # Reduce copy_x over the axes gamma was broadcast along.
        if len(gamma.shape) > 1 and len(x.shape) != 4:
            grad_gam = copy_x
        elif len(gamma.shape) > 1 and len(x.shape) == 4:
            grad_gam = np.sum(copy_x, axis=(2,3))
        elif gamma.shape[0] == 1:
            grad_gam = np.sum(np.sum(copy_x))
        elif gamma.shape[0] > 1 and len(x.shape) != 4:
            grad_gam = np.sum(copy_x, axis=0)
        elif gamma.shape[0] > 1 and len(x.shape) == 4:
            grad_gam = np.sum(copy_x, axis=(0,2,3))
        return (grad_x, grad_gam)
    x = mx.symbol.Variable("x")
    gamma = mx.symbol.Variable("gamma")
    for shape in [(3,4), (3,4,4,5)]:
        for dtype in [np.float16, np.float32, np.float64]:
            for gam in [np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype)]:
                # gam_full stacks gam three times, giving a (3, 4) gamma.
                gam_full = np.array([gam, gam, gam])
                xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype)
                rtol = 1e-2
                atol = 1e-3
                eps = 1e-4
                # Move inputs away from the kink at zero.
                xa[abs(xa) < eps] = 1.0
                y = mx.symbol.LeakyReLU(data=x, gamma=gamma, act_type='prelu')
                ya = fprelu(xa, gam)
                ya_full = fprelu(xa, gam_full)
                g_xa, g_gam = fprelu_grad(xa, ya, gamma=gam)
                g_xa_full, g_gam_full = fprelu_grad(xa, ya_full, gamma=gam_full)
                # Skip numeric check for float16 type to get rid of flaky behavior
                if dtype is not np.float16:
                    check_numeric_gradient(y, [xa, gam], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
                    check_numeric_gradient(y, [xa, gam_full], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_forward(y, [xa, gam], [ya], rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_backward(y, [xa, gam], [np.ones(ya.shape, dtype=dtype)],
                                        [g_xa, g_gam], rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_forward(y, [xa, gam_full], [ya_full], rtol=rtol, atol=atol, dtype=dtype)
                check_symbolic_backward(y, [xa, gam_full], [np.ones(ya_full.shape, dtype=dtype)],
                                        [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
def test_selu():
    """Check LeakyReLU with act_type='selu' against a numpy reference in three dtypes."""
    alpha = 1.6732632423543772848170429916717
    lamb = 1.0507009873554804934193349852946
    def fselu(x):
        """Numpy SELU: lamb * x for x >= 0, lamb * alpha * expm1(x) for x < 0."""
        neg_indices = x < 0
        out = x.copy()
        out[neg_indices] = alpha * np.expm1(out[neg_indices])
        return out * lamb
    def fselu_grad(grad, x, y):
        """Numpy SELU derivative; `grad` is accepted for symmetry but not used."""
        neg_indices = x < 0
        out = np.ones(x.shape).astype(x.dtype)
        # For x < 0, y = lamb * alpha * expm1(x), hence
        # dy/dx = lamb * alpha * exp(x) = y + lamb * alpha.
        # y already contains the factor lamb, so it is divided out before the
        # common multiplication below (the previous `y + alpha` scaled y by
        # lamb twice, an error that only the loose tolerance was hiding).
        out[neg_indices] = y[neg_indices] / lamb + alpha
        return out * lamb
    shape = (3, 4)
    x = mx.sym.Variable("x")
    y = mx.sym.LeakyReLU(data=x, act_type="selu")
    for dtype in [np.float16, np.float32, np.float64]:
        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
        if dtype is np.float16:
            xa /= 10.0
        # Move inputs away from the kink at zero.
        xa[abs(xa) < eps] = 0.01
        ya = fselu(xa)
        ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
        check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_backward(y, [xa], [np.ones(shape, dtype=dtype)], [ga], rtol=rtol, atol=atol, dtype=dtype)
def test_gelu():
    """Check LeakyReLU with act_type='gelu' against the tanh-based numpy reference."""
    CUBE_CONSTANT = 0.044715
    ROOT_TWO_OVER_PI = 0.7978845608028654

    def tanh_arg(x):
        """Argument of the tanh in the GELU approximation."""
        return ROOT_TWO_OVER_PI * (x + CUBE_CONSTANT * np.power(x, 3))

    def tanh_arg_grad(x):
        return ROOT_TWO_OVER_PI * (1.0 + 3.0 * CUBE_CONSTANT * np.power(x, 2))

    def one_plus_tanh(x):
        return 1.0 + np.tanh(tanh_arg(x))

    def one_plus_tanh_grad(x):
        # Kept from the original reference helpers; not used by the checks below.
        return (1.0 - np.tanh(tanh_arg(x)) * np.tanh(tanh_arg(x))) * tanh_arg_grad(x)

    def reference_forward(x):
        return 0.5 * x * one_plus_tanh(x)

    def reference_backward(grad, x, y):
        # With t = tanh(g(x)) and y = 0.5 * x * (1 + t):
        # dy/dx = y / x + y * (1 - t) * g'(x)
        return grad * (y / x + y * (1 - np.tanh(tanh_arg(x))) * tanh_arg_grad(x))

    shape = (3, 4)
    gelu = mx.sym.LeakyReLU(data=mx.sym.Variable("x"), act_type="gelu")
    for dtype in [np.float16, np.float32, np.float64]:
        xa = np.random.uniform(low=-0.1, high=0.1, size=shape).astype(dtype)
        if dtype is np.float16:
            eps, rtol, atol = 7.5e-4, 2e-2, 1e-3
            xa /= 10.0
        else:
            eps, rtol, atol = 1e-4, 1e-3, 1e-5
        # The reference gradient divides by x, so keep inputs away from zero.
        xa[abs(xa) < eps] = 0.01
        ya = reference_forward(xa)
        ga = reference_backward(np.ones(shape).astype(dtype), xa, ya)
        check_numeric_gradient(gelu, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_forward(gelu, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_backward(gelu, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
def test_sigmoid():
    """Check sigmoid forward and gradient against the numpy closed form."""
    shape = (3, 4)
    sigmoid = mx.sym.sigmoid(mx.symbol.Variable("x"))
    inp = np.random.uniform(low=-1.0, high=1.0, size=shape)
    expected_out = np.divide(1.0, (1.0 + np.exp(-inp)))
    check_numeric_gradient(sigmoid, [inp], numeric_eps=1E-3)
    check_symbolic_forward(sigmoid, [inp], [expected_out])
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    check_symbolic_backward(sigmoid, [inp], [np.ones(shape)],
                            [expected_out * (1 - expected_out)])
def test_log_sigmoid():
    """Check log_sigmoid forward and gradient against the numpy closed form."""
    shape = (3, 4)
    log_sigmoid = mx.sym.log_sigmoid(mx.symbol.Variable("x"))
    inp = np.random.uniform(low=-1.0, high=1.0, size=shape)
    expected_out = np.log(np.divide(1.0, np.add(1.0, np.exp(-inp))))
    # d/dx log(sigmoid(x)) = 1 / (1 + exp(x))
    expected_grad = np.divide(1.0, np.add(1.0, np.exp(inp)))
    check_numeric_gradient(log_sigmoid, [inp], numeric_eps=1E-3)
    check_symbolic_forward(log_sigmoid, [inp], [expected_out])
    check_symbolic_backward(log_sigmoid, [inp], [np.ones(shape)], [expected_grad])
def test_mish():
    """Check mish forward and gradient, where mish(x) = x * tanh(log1p(exp(x)))."""
    def reference_forward(inp):
        return inp * np.tanh(np.log1p(np.exp(inp)))

    def reference_backward(inp):
        tanh_softplus = np.tanh(np.log1p(np.exp(inp)))
        logistic = np.divide(1.0, (1.0 + np.exp(-inp)))
        # Product rule; the derivative of log1p(exp(x)) is the logistic function.
        return tanh_softplus + inp * logistic * (1.0 - tanh_softplus * tanh_softplus)

    shape = (3, 4)
    mish = mx.sym.mish(mx.symbol.Variable("x"))
    inp = np.random.uniform(low=-1.0, high=1.0, size=shape)
    expected_out = reference_forward(inp)
    expected_grad = reference_backward(inp)
    check_numeric_gradient(mish, [inp], numeric_eps=1E-3)
    check_symbolic_forward(mish, [inp], [expected_out])
    check_symbolic_backward(mish, [inp], [np.ones(shape)], [expected_grad])
def test_shape_array():
    """Check that shape_array returns the input's shape and yields a zero input gradient."""
    for i in range(1,6):
        shape = rand_shape_nd(i)
        x = mx.sym.var('x')
        y = mx.sym.shape_array(x)
        xa = mx.nd.array(np.random.ranf(shape))
        xg = mx.nd.empty(xa.shape)
        ya = np.shape(xa)
        yg = mx.nd.ones(ya)
        exe = y._bind(ctx=default_device(), args={'x': xa},
                      args_grad={'x': xg})
        exe.forward(is_train=True)
        exe.backward([yg])
        yo = exe.outputs[0].asnumpy()
        # The comparison result used to be discarded, so the forward output was
        # never actually verified; assert it.
        assert same(yo, ya)
        assert_almost_equal(xg, np.zeros_like(xg.asnumpy()))
def test_size_array():
    """Check that size_array returns the element count and yields a zero input gradient."""
    for i in range(1,6):
        shape = rand_shape_nd(i)
        x = mx.sym.var('x')
        y = mx.sym.size_array(x)
        xa = mx.nd.array(np.random.ranf(shape))
        xg = mx.nd.empty(xa.shape)
        ya = np.size(xa)
        yg = mx.nd.ones(ya)
        exe = y._bind(ctx=default_device(), args={'x': xa},
                      args_grad={'x': xg})
        exe.forward(is_train=True)
        exe.backward([yg])
        yo = exe.outputs[0].asnumpy()
        # The comparison result used to be discarded, so the forward output was
        # never verified. The output is documented as a one-element array while
        # `ya` is a Python int, and an array-equality check would reject that
        # pair on shape alone, so the single value is compared explicitly.
        assert yo.size == 1
        assert int(yo.item()) == ya
        assert_almost_equal(xg, np.zeros_like(xg.asnumpy()))
def test_hard_sigmoid():
    """Check hard_sigmoid, i.e. ``alpha * x + beta`` clipped to [0, 1], in three dtypes."""
    def reference_forward(inp, alpha=0.2, beta=0.5):
        lower = np.zeros(inp.shape, dtype=inp.dtype)
        upper = np.ones(inp.shape, dtype=inp.dtype)
        return np.maximum(lower, np.minimum(upper, alpha * inp + beta))

    def reference_backward(inp, out_grad, alpha=0.2, beta=0.5):
        forward = reference_forward(inp, alpha, beta)
        grad = out_grad * alpha
        # No gradient flows where the output is clipped.
        grad[forward <= 0.0] = 0.0
        grad[forward >= 1.0] = 0.0
        return grad

    shape = (3, 4)
    hard_sigmoid = mx.sym.hard_sigmoid(mx.symbol.Variable("x"))
    atol = 1e-3
    eps = 1e-3
    for dtype in [np.float16, np.float32, np.float64]:
        rtol = 1e-2 if dtype is np.float16 else 1e-3
        xa = np.random.uniform(low=-3.0, high=3.0, size=shape).astype(dtype)
        # function not differentiable at x=2.5 and -2.5
        xa[abs(xa - 2.5) < eps] -= 2 * eps
        xa[abs(xa + 2.5) < eps] += 2 * eps
        ya = reference_forward(xa)
        grad_xa = reference_backward(xa, np.ones(shape))
        if dtype is not np.float16:
            check_numeric_gradient(hard_sigmoid, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_forward(hard_sigmoid, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
        check_symbolic_backward(hard_sigmoid, [xa], [np.ones(shape)], [grad_xa],
                                rtol=rtol, atol=atol, dtype=dtype)
def test_softsign():
    """Check softsign, x / (1 + |x|), forward and gradient."""
    shape = (3, 4)
    softsign = mx.sym.softsign(mx.symbol.Variable("x"))
    inp = np.random.uniform(low=-1.0, high=1.0, size=shape)
    expected_out = np.divide(inp, (1.0 + np.abs(inp)))
    # d/dx softsign(x) = 1 / (1 + |x|)**2
    expected_grad = np.divide(1.0, np.square((1.0 + np.abs(inp))))
    check_numeric_gradient(softsign, [inp], numeric_eps=1E-3)
    check_symbolic_forward(softsign, [inp], [expected_out])
    check_symbolic_backward(softsign, [inp], [np.ones(shape)], [expected_grad])
def test_binary_logic():
    """Check the six comparison ops: element-wise, against a scalar, and broadcasting."""
    def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True):
        """Compare `logic_sym` against `forward_gt` on random integer-valued floats.

        With `test_scalar` the scalar-on-the-left and scalar-on-the-right forms
        (scalar value 1) are checked as well. Backward is only invoked, its
        result is not inspected.
        """
        lhs = mx.symbol.Variable("x")
        rhs = mx.symbol.Variable("y")
        x_npy = np.random.randint(0, 4, size=x_shape).astype(np.float32)
        y_npy = np.random.randint(0, 4, size=y_shape).astype(np.float32)

        executor = logic_sym(lhs, rhs)._simple_bind(ctx=default_device(), x=x_shape, y=y_shape)
        result = executor.forward(is_train=True, x=x_npy, y=y_npy)[0]
        assert_almost_equal(result, forward_gt(x_npy, y_npy))
        executor.backward()
        if not test_scalar:
            return

        lscalar_exe = logic_sym(1, rhs)._simple_bind(ctx=default_device(), y=y_shape)
        rscalar_exe = logic_sym(lhs, 1)._simple_bind(ctx=default_device(), x=x_shape)
        lscalar_out = lscalar_exe.forward(is_train=True, y=y_npy)[0]
        rscalar_out = rscalar_exe.forward(is_train=True, x=x_npy)[0]
        assert_almost_equal(lscalar_out, forward_gt(1, y_npy))
        assert_almost_equal(rscalar_out, forward_gt(x_npy, 1))
        lscalar_exe.backward()
        rscalar_exe.backward()

    # Each row: (Python comparison, matching broadcasting operator).
    comparisons = [
        (lambda x, y: x == y, mx.sym.broadcast_equal),
        (lambda x, y: x > y, mx.sym.broadcast_greater),
        (lambda x, y: x >= y, mx.sym.broadcast_greater_equal),
        (lambda x, y: x < y, mx.sym.broadcast_lesser),
        (lambda x, y: x <= y, mx.sym.broadcast_lesser_equal),
        (lambda x, y: x != y, mx.sym.broadcast_not_equal),
    ]

    # Test the no-broadcasting binary logic ops + scalar logic ops
    for compare, _ in comparisons:
        _inner_test(forward_gt=compare, logic_sym=compare,
                    x_shape=(10, 10), y_shape=(10, 10))

    # Test the broadcasting binary logic ops
    for compare, broadcast_op in comparisons:
        _inner_test(forward_gt=compare, logic_sym=broadcast_op,
                    x_shape=(1, 10), y_shape=(10, 1), test_scalar=False)
def test_unary_logic():
    """Check logical_not through both the NDArray and the symbolic interface."""
    shape = (3, 4)
    inp_np = np.random.randint(-2, 2, size=shape).astype(np.float32)
    inp_nd = mx.nd.array(inp_np)

    def expected():
        return np.logical_not(inp_np).astype(inp_np.dtype)

    # NDArray call
    assert_almost_equal(mx.nd.logical_not(inp_nd), expected())

    # symbolic call
    sym = mx.sym.logical_not(data=mx.sym.Variable('x'))
    executor = sym._simple_bind(ctx=default_device(), x=shape)
    sym_out = executor.forward(is_train=True, x=inp_nd)[0]
    assert_almost_equal(sym_out, expected())
def test_embedding():
    """Check Embedding forward and weight gradient via an equivalent one-hot matmul."""
    vocab_size, embed_dim, batch = 10, 4, 24
    embed = mx.sym.Embedding(data=mx.sym.Variable("data"), input_dim=vocab_size,
                             output_dim=embed_dim, name="embed")
    executor = embed._simple_bind(default_device(),
                                  grad_req={'data': 'null', 'embed_weight': 'write'},
                                  data=(batch,))
    arg_names = embed.list_arguments()
    args_by_name = dict(zip(arg_names, executor.arg_arrays))
    grads_by_name = dict(zip(arg_names, executor.grad_arrays))

    indices = np.random.randint(low=0, high=vocab_size, size=batch)
    weight = np.random.uniform(-0.01, 0.01, args_by_name["embed_weight"].shape)
    # Row r of one_hot selects weight row indices[r].
    one_hot = np.zeros((batch, vocab_size))
    one_hot[np.arange(batch), indices] = 1.0

    # Non-zero atol required, as exposed by seed 781663739
    rtol = 1e-5
    atol = 1e-5

    # forward
    args_by_name["data"][:] = indices
    args_by_name["embed_weight"][:] = weight
    executor.forward(is_train=True)
    assert_almost_equal(executor.outputs[0], np.dot(one_hot, weight), rtol=rtol, atol=atol)

    # backward
    head_grad_np = np.random.uniform(-1, 1, executor.outputs[0].shape)
    head_grad = mx.nd.zeros(head_grad_np.shape)
    head_grad[:] = head_grad_np
    executor.backward([head_grad])
    assert_almost_equal(grads_by_name["embed_weight"], np.dot(one_hot.T, head_grad_np),
                        rtol=rtol, atol=atol)
# check ops handle duplicate input correctly.
def test_binary_op_duplicate_input():
    """Check ``data * data``, where the same symbol feeds both operands."""
    shape = (3, 4)
    value_np = np.ones(shape)
    value_np[:] = 5
    value_nd = mx.nd.array(value_np)

    input_grad = mx.nd.empty(shape)
    input_grad[:] = 3
    head_grad = mx.nd.empty(shape)
    head_grad[:] = 1

    data = mx.symbol.Variable('data')
    executor = (data * data)._bind(default_device(), args=[value_nd], args_grad=[input_grad])
    executor.forward(is_train=True)
    assert_almost_equal(executor.outputs[0], value_np * value_np)
    executor.backward(head_grad)
    # d/dx x*x = 2x with a head gradient of ones
    assert_almost_equal(input_grad, 2.0 * value_np)
def test_sign():
    """Check that sign matches np.sign and propagates a zero gradient."""
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    data_tmp = np.ones(shape)
    data_tmp[:]=5
    arr_data = mx.nd.array(data_tmp)
    # Pre-filled with a non-zero value; the final assertion therefore only
    # passes if backward writes zeros into this buffer.
    arr_grad = mx.nd.empty(shape)
    arr_grad[:]=3

    test = mx.sym.sign(data)
    exe_test = test._bind(default_device(), args=[arr_data], args_grad=[arr_grad])
    exe_test.forward(is_train=True)
    out = exe_test.outputs[0]
    npout = np.sign(data_tmp)
    assert_almost_equal(out, npout)

    out_grad = mx.nd.empty(shape)
    out_grad[:] = 2
    # The expected gradient is zero everywhere. The earlier
    # `out_grad.asnumpy()` value was overwritten at once and has been dropped.
    npout_grad = np.zeros(shape)
    exe_test.backward(out_grad)
    assert_almost_equal(arr_grad, npout_grad)
def test_round_ceil_floor():
    """Check the forward value of ``round(x) + ceil(x) + floor(x)`` at x = 5.543.

    Only the forward pass is exercised; the executor is bound without
    gradient buffers, so the previously allocated but unused gradient array
    has been removed.
    """
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    data_tmp = np.ones(shape)
    data_tmp[:]=5.543
    arr_data = mx.nd.array(data_tmp)

    test = mx.sym.round(data) + mx.sym.ceil(data) + mx.sym.floor(data)
    exe_test = test._bind(default_device(), args=[arr_data])
    exe_test.forward(is_train=True)
    out = exe_test.outputs[0]
    npout = np.round(data_tmp) + np.ceil(data_tmp) + np.floor(data_tmp)
    assert_almost_equal(out, npout)
def test_trunc():
    """Compare ``mx.sym.trunc`` with ``np.trunc`` on random values in [-5, 5)."""
    values = np.random.rand(3, 4) * 10 - 5
    nd_values = mx.nd.array(values)

    sym = mx.sym.trunc(mx.symbol.Variable('data'))
    executor = sym._bind(default_device(), args=[nd_values])
    executor.forward(is_train=True)
    result = executor.outputs[0]

    # 'trunc' is sensitive to the precision of the calculation. Force numpy to match mxnet's float32.
    # Repro issue with seed 1660190454
    expected = np.trunc(values.astype(np.float32))
    assert_almost_equal(result, expected)
def test_rsqrt_cos_sin():
    """Check forward and backward of ``rsqrt(x) + cos(x) + sin(x)`` against NumPy."""
    shape = (3, 4)
    x_np = np.full(shape, 5.0)
    x = mx.nd.array(x_np)
    x_grad = mx.nd.empty(shape)
    x_grad[:] = 3

    var = mx.symbol.Variable('data')
    sym = mx.sym.rsqrt(var) + mx.sym.cos(var) + mx.sym.sin(var)
    executor = sym._bind(default_device(), args=[x], args_grad=[x_grad])

    executor.forward(is_train=True)
    expected = 1 / np.sqrt(x_np) + np.cos(x_np) + np.sin(x_np)
    assert_almost_equal(executor.outputs[0], expected)

    head_grad = mx.nd.empty(shape)
    head_grad[:] = 2
    head_np = head_grad.asnumpy()
    # Derivatives: d/dx x**-0.5 = -1 / (2 x sqrt(x)), d/dx cos = -sin, d/dx sin = cos.
    d_rsqrt = -(1.0 / (2.0 * x_np * np.sqrt(x_np)))
    expected_grad = (head_np * d_rsqrt
                     + head_np * -1 * np.sin(x_np)
                     + head_np * np.cos(x_np))
    executor.backward(head_grad)
    assert_almost_equal(x_grad, expected_grad)
def test_maximum_minimum():
    """Check forward and backward of ``maximum(a, b) + minimum(a, b)``.

    The inputs are constant arrays (a = 2, b = 3), so the gradient routing of
    both operators is deterministic.
    """
    data1 = mx.symbol.Variable('data1')
    data2 = mx.symbol.Variable('data2')
    shape = (3, 4)
    # Constant inputs; previously these were random draws that were
    # overwritten with the same constants before use.
    data_tmp1 = np.full(shape, 2.0)
    data_tmp2 = np.full(shape, 3.0)

    arr_data1 = mx.nd.array(data_tmp1)
    arr_data2 = mx.nd.array(data_tmp2)
    arr_grad1 = mx.nd.empty(shape)
    arr_grad2 = mx.nd.empty(shape)

    test = mx.sym.maximum(data1, data2) + mx.sym.minimum(data1, data2)
    exe_test = test._bind(default_device(), args=[arr_data1, arr_data2],
                          args_grad=[arr_grad1, arr_grad2])
    exe_test.forward(is_train=True)
    npout = np.maximum(data_tmp1, data_tmp2) + np.minimum(data_tmp1, data_tmp2)
    assert_almost_equal(exe_test.outputs[0], npout)

    out_grad = mx.nd.empty(shape)
    out_grad[:] = 2
    exe_test.backward(out_grad)

    npout_grad = np.full(shape, 2.0)
    # mask1 selects where data1 wins the maximum, mask2 where it wins the minimum.
    mask1 = (data_tmp1 > data_tmp2).astype('float')
    mask2 = (data_tmp1 < data_tmp2).astype('float')
    # data1 receives the head gradient where it was selected; data2 receives the rest.
    npout_grad1 = npout_grad * mask1 + npout_grad * mask2
    npout_grad2 = (npout_grad - npout_grad * mask1) + (npout_grad - npout_grad * mask2)

    assert_almost_equal(arr_grad1, npout_grad1)
    assert_almost_equal(arr_grad2, npout_grad2)
def test_maximum_minimum_scalar():
    """Check ``maximum``/``minimum`` with a scalar operand on either side.

    The symbol is ``max(x, 3) + max(9, x) + min(5, x) + min(x, 4)`` evaluated
    on a constant input x = 2.
    """
    data1 = mx.symbol.Variable('data')
    shape = (3, 4)
    # Constant input; previously a random draw that was overwritten with 2 before use.
    data_tmp1 = np.full(shape, 2.0)
    arr_data1 = mx.nd.array(data_tmp1)
    arr_grad1 = mx.nd.empty(shape)

    test = (mx.sym.maximum(data1, 3) + mx.sym.maximum(9, data1)
            + mx.sym.minimum(5, data1) + mx.sym.minimum(data1, 4))
    exe_test = test._bind(default_device(), args=[arr_data1], args_grad=[arr_grad1])
    exe_test.forward(is_train=True)
    npout = (np.maximum(data_tmp1, 3) + np.maximum(9, data_tmp1)
             + np.minimum(5, data_tmp1) + np.minimum(data_tmp1, 4))
    assert_almost_equal(exe_test.outputs[0], npout)

    out_grad = mx.nd.empty(shape)
    out_grad[:] = 2
    exe_test.backward(out_grad)

    npout_grad = np.full(shape, 2.0)
    # One mask per term, in the order the terms appear in the symbol.
    mask1 = (data_tmp1 > 3).astype('float')
    mask2 = (9 > data_tmp1).astype('float')
    mask3 = (5 < data_tmp1).astype('float')
    mask4 = (data_tmp1 < 4).astype('float')
    # Terms 2 and 3 have the scalar on the left, so the variable receives the
    # gradient where the scalar was NOT selected (hence "grad - grad * mask").
    npout_grad1 = (npout_grad * mask1
                   + (npout_grad - npout_grad * mask2)
                   + (npout_grad - npout_grad * mask3)
                   + npout_grad * mask4)
    assert_almost_equal(arr_grad1, npout_grad1)
def test_abs():
    """Check ``mx.sym.abs`` forward against ``abs`` and its gradient against ``sign(x) * head_grad``."""
    shape = (3, 4)
    x_np = np.full(shape, 5.0)
    x = mx.nd.array(x_np)
    x_grad = mx.nd.empty(shape)
    x_grad[:] = 3

    sym = mx.sym.abs(mx.symbol.Variable('data'))
    executor = sym._bind(default_device(), args=[x], args_grad=[x_grad])

    executor.forward(is_train=True)
    assert_almost_equal(executor.outputs[0], abs(x_np))

    head_grad = mx.nd.empty(shape)
    head_grad[:] = 2
    expected_grad = head_grad.asnumpy() * np.sign(x_np)
    executor.backward(head_grad)
    assert_almost_equal(x_grad, expected_grad)
def check_deconvolution_forward_backward(input_shape, num_filter, kernel, stride, pad):
    """Check a conv --> deconv chain whose forward input and head gradient coincide.

    Configuration: input --> conv --> deconv --> output.
    The convolution and deconvolution use the same kernel/stride/pad and share
    one weight array, which ensures the output shape equals the input shape.
    If the input of forward() and the head gradient of backward() are the same
    array, then the forward output and the first argument gradient should be
    the same as well; this is asserted both for a plain bind and for
    ``grad_req="add"``.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the data input; ``input_shape[1]`` must equal ``num_filter``.
    num_filter : int
        Number of filters of both layers.
    kernel, stride, pad : tuple of int
        Passed unchanged to both ``Convolution`` and ``Deconvolution``.
    """
    assert input_shape[1] == num_filter
    data = mx.sym.Variable(name="data")
    conv = mx.sym.Convolution(
        data=data, kernel=kernel, stride=stride, pad=pad,
        num_filter=num_filter, no_bias="true", name="conv")
    deconv = mx.sym.Deconvolution(
        data=conv, kernel=kernel, stride=stride, pad=pad,
        num_filter=num_filter, no_bias="true", name="deconv")

    arg_shapes, _, _ = deconv.infer_shape(data=input_shape)
    input_data = mx.random.uniform(-5, 5, input_shape, ctx=mx.cpu()).copyto(default_device())
    # The head gradient is deliberately the very same array as the input.
    out_grad = input_data
    # A single weight array is bound to both layers.
    shared_weight = mx.random.normal(
        0, 1, (num_filter, input_shape[1]) + kernel, ctx=mx.cpu()).copyto(default_device())
    args = {
        "data": input_data,
        "conv_weight": shared_weight,
        "deconv_weight": shared_weight,
    }
    args_grad = [mx.nd.empty(s) for s in arg_shapes]

    exe = deconv._bind(default_device(), args=args, args_grad=args_grad)
    exe.forward(is_train=True)
    out = exe.outputs[0]
    exe.backward(out_grad)
    # args_grad[0] is the gradient of the first listed argument (presumably `data`).
    assert_almost_equal(out, args_grad[0], rtol=1E-3, atol=1e-3)

    # Same check with accumulation: gradients are added onto random initial content.
    args_grad_addto_npy = [np.random.normal(size=s) for s in arg_shapes]
    args_grad_addto = [mx.nd.array(ele) for ele in args_grad_addto_npy]
    exe = deconv._bind(default_device(), args=args, args_grad=args_grad_addto, grad_req="add")
    exe.forward(is_train=True)
    out = exe.outputs[0].asnumpy()
    exe.backward(out_grad)
    assert_almost_equal(out + args_grad_addto_npy[0], args_grad_addto[0].asnumpy(),
                        rtol=1e-3, atol=1e-3)
def check_deconvolution_gradient(input_shape, num_filter, pad):
    """Check that Convolution and Deconvolution produce the same weight gradient.

    Configuration A: input --> conv --> output.
    Configuration B: input --> deconv --> output.
    Both layers use stride 1 and kernel ``2 * pad + 1`` so that the input shape
    equals the output shape. During backward(), B is fed A's head gradient as
    its data and A's data as its head gradient; the weight gradients are then
    expected to match, both for a plain bind and for ``grad_req="add"``.
    """
    ndim = len(pad)
    stride = (1,) * ndim
    kernel = tuple(2 * np.array(pad) + 1)
    weight_shape = (num_filter, input_shape[1]) + kernel
    device = default_device()

    layer_kwargs = dict(kernel=kernel, stride=stride, pad=pad,
                        num_filter=num_filter, no_bias="true")
    conv = mx.sym.Convolution(
        data=mx.sym.Variable(name="data_conv"), name="conv", **layer_kwargs)
    deconv = mx.sym.Deconvolution(
        data=mx.sym.Variable(name="data_deconv"), name="deconv", **layer_kwargs)

    # Configuration A: run the convolution forward and backward.
    conv_input = mx.random.uniform(-5, 5, input_shape, ctx=mx.cpu()).copyto(device)
    weight = mx.random.normal(0, 1, weight_shape, ctx=mx.cpu()).copyto(device)
    conv_grads = [mx.nd.zeros(conv_input.shape), mx.nd.zeros(weight_shape)]
    conv_exe = conv._bind(device,
                          args={"data_conv": conv_input, "conv_weight": weight},
                          args_grad=conv_grads)
    conv_exe.forward(is_train=True)
    conv_head_grad = mx.random.normal(
        0, 2, conv_exe.outputs[0].shape, ctx=mx.cpu()).copyto(device)
    conv_exe.backward(conv_head_grad)

    # Configuration B: the deconvolution sees A's head gradient as data and shares A's weight.
    deconv_inputs = {'data_deconv': conv_head_grad, 'deconv_weight': weight}
    deconv_grads = [mx.nd.zeros(conv_head_grad.shape), mx.nd.zeros(weight_shape)]
    addto_init = [np.random.normal(size=conv_head_grad.shape),
                  np.random.normal(size=weight_shape)]
    addto_grads = [mx.nd.array(addto_init[0]), mx.nd.array(addto_init[1])]

    deconv_exe = deconv._bind(device, args=deconv_inputs, args_grad=deconv_grads)
    deconv_exe.forward(is_train=True)
    deconv_exe.backward(conv_input[:])
    assert_almost_equal(conv_grads[1], deconv_grads[1], rtol=1e-3, atol=1e-2)

    # Test AddTo
    addto_exe = deconv._bind(device, args=deconv_inputs,
                             args_grad=addto_grads, grad_req="add")
    addto_exe.forward(is_train=True)
    addto_exe.backward(conv_input[:])
    assert_almost_equal(conv_grads[1].asnumpy() + addto_init[1],
                        addto_grads[1].asnumpy(), rtol=1e-3, atol=1e-2)
def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, target_shape=None):
    """Check the inferred output shape of a 5-filter ``Deconvolution``.

    Parameters
    ----------
    input_shape : tuple of int
        Shape bound to the ``data`` input for shape inference.
    kernel, stride, pad, adj : tuple of int
        Passed unchanged to ``Deconvolution``.
    target_shape : tuple of int, optional
        When truthy it is forwarded to the operator and is the expected spatial
        output shape. When ``None`` the operator is built without it and the
        expected spatial extent is 8 along every kernel axis.
    """
    num_filter = 5
    data = mx.sym.Variable(name="data")
    deconv_kwargs = dict(data=data, kernel=kernel, stride=stride, pad=pad, adj=adj,
                         num_filter=num_filter)
    if target_shape:
        deconv_kwargs['target_shape'] = target_shape
    deconv = mx.sym.Deconvolution(**deconv_kwargs)

    _, out_shapes, _ = deconv.infer_shape(data=input_shape)

    default_target_size = 8
    if target_shape is None:
        target_shape = (default_target_size,) * len(kernel)
    assert out_shapes[0] == (input_shape[0], num_filter) + target_shape
@pytest.mark.serial
def test_deconvolution():
    """Run the deconvolution shape, forward/backward and gradient checks in 2D, then 1D.

    The 1D cases are the 2D cases with one spatial axis removed, so each
    spatial extent is stored once and repeated ``ndim`` times.
    """
    # (batch, channels), spatial size, num_filter, kernel, stride, pad
    forward_backward_cases = [
        ((1, 1), 5, 1, 3, 1, 1),
        ((32, 3), 28, 3, 3, 1, 1),
        ((10, 3), 403, 3, 7, 5, 2),
    ]
    # (batch, channels), spatial size, num_filter, pad
    gradient_cases = [
        ((1, 3), 5, 3, 1),
        ((5, 3), 100, 3, 3),
    ]

    for ndim in (2, 1):
        # pad and adj will be ignored here because target_shape is given.
        check_deconvolution_target_shape(
            input_shape=(2, 3) + (4,) * ndim,
            kernel=(3,) * ndim,
            stride=(2,) * ndim,
            target_shape=(8,) * ndim,
            pad=(99,) * ndim,
            adj=(101,) * ndim,
        )
        check_deconvolution_target_shape(
            input_shape=(2, 3) + (4,) * ndim,
            kernel=(3,) * ndim,
            stride=(2,) * ndim,
            pad=(1,) * ndim,
            adj=(1,) * ndim,
        )

        for lead, size, filters, kernel_x, stride_x, pad_x in forward_backward_cases:
            check_deconvolution_forward_backward(
                input_shape=lead + (size,) * ndim,
                num_filter=filters,
                kernel=(kernel_x,) * ndim,
                stride=(stride_x,) * ndim,
                pad=(pad_x,) * ndim,
            )

        for lead, size, filters, pad_x in gradient_cases:
            check_deconvolution_gradient(
                input_shape=lead + (size,) * ndim,
                num_filter=filters,
                pad=(pad_x,) * ndim,
            )
@pytest.mark.parametrize('shape,num_filter,num_group,kernel,pad', [
    ((1, 4, 15), 16, 2, (2,), (0,)),
    ((8, 4, 16), 16, 1, (3,), (1,)),
    ((1, 4, 15, 16), 16, 2, (2, 2), (0, 0)),
    ((8, 4, 16, 16), 16, 1, (3, 3), (1, 1)),
    ((1, 4, 3, 15, 16), 16, 2, (2, 2, 2), (0, 0, 0)),
    ((8, 4, 3, 16, 16), 16, 1, (3, 3, 3), (1, 1, 1))])
def test_deconvolution_forward_with_bias(shape, num_filter, num_group, kernel, pad):
    """Check that the biased deconvolution equals the bias-free one plus the bias.

    Two executors (with and without bias) are fed the same data and weights;
    the biased output must equal the unbiased output with the per-filter bias
    added along axis 1.
    """
    if len(kernel) == 3 and mx.current_context().device_type == 'gpu':
        pytest.skip('Skipping Conv3DTranspose tests for GPU')

    x = mx.sym.Variable('x')
    w = mx.sym.Variable('w')
    b = mx.sym.Variable('b')
    shared = dict(data=x, weight=w, num_filter=num_filter, num_group=num_group,
                  kernel=kernel, pad=pad)
    sym_plain = mx.sym.Deconvolution(no_bias=True, **shared)
    sym_biased = mx.sym.Deconvolution(bias=b, no_bias=False, **shared)

    exe_plain = sym_plain._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null')
    exe_biased = sym_biased._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null')

    # Shapes come from the biased executor, whose arguments are data, weight, bias.
    data = np.random.uniform(-5, 5, size=exe_biased.arg_arrays[0].shape)
    weights = np.random.normal(size=exe_biased.arg_arrays[1].shape)
    bias = np.random.normal(size=exe_biased.arg_arrays[2].shape)

    def run_forward(executor):
        # Load the shared inputs (bias only if the executor has a third argument) and infer.
        executor.arg_arrays[0][:] = data
        executor.arg_arrays[1][:] = weights
        if len(executor.arg_arrays) == 3:
            executor.arg_arrays[2][:] = bias
        return executor.forward(is_train=False)[0].asnumpy()

    out_plain = run_forward(exe_plain)
    out_biased = run_forward(exe_biased)

    # Expand the (num_filter,) bias to the full output shape, constant over batch and space.
    spatial = out_plain.shape[2:]
    per_filter = np.broadcast_to(bias, [np.prod(spatial)] + [num_filter]).T
    bias_full = np.broadcast_to(per_filter.reshape((num_filter,) + tuple(spatial)),
                                out_biased.shape)
    assert_almost_equal(out_plain + bias_full, out_biased)
def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
    """Check the gradient of nearest-neighbour ``UpSampling`` over several inputs.

    The outputs are fed back as head gradients; the gradient of input ``k`` is
    expected to equal that input times ``root_scale**2 * scale**(2*k)``.
    """
    names = ['arg_%d' % i for i in range(len(shapes))]
    inputs = {
        name: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_device())
        for name, shape in zip(names, shapes)
    }
    grads = {name: mx.nd.zeros(shape) for name, shape in zip(names, shapes)}

    variables = [mx.sym.Variable(name) for name in names]
    up = mx.sym.UpSampling(*variables, sample_type='nearest', scale=root_scale)
    exe = up._bind(default_device(), args=inputs, args_grad=grads)
    exe.forward(is_train=True)
    exe.backward(exe.outputs)

    for k, name in enumerate(names):
        expected = inputs[name].asnumpy() * root_scale**2 * scale**(2 * k)
        assert_allclose(expected, grads[name].asnumpy(), rtol=1e-4)
def check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter):
    """Run bilinear ``UpSampling`` forward/backward and check the output shape.

    ``weight_shape`` and ``scale`` are accepted but not used here; the weight
    shape is taken from shape inference instead.
    """
    def _fill_bilinear(buf, factor):
        # Fill `buf` (indexed as a 4-D array) with a bilinear kernel for `factor`.
        dims = buf.shape
        flat = np.zeros(np.prod(dims), dtype='float32')
        center = (2 * factor - 1 - factor % 2) / (2. * factor)
        for idx in range(np.prod(dims)):
            col = idx % dims[3]
            row = (idx // dims[3]) % dims[2]
            flat[idx] = (1 - abs(col / factor - center)) * (1 - abs(row / factor - center))
        buf[:] = flat.reshape(dims)
        return buf

    up = mx.sym.UpSampling(mx.sym.Variable("data"),
                           mx.sym.Variable('weight'),
                           sample_type='bilinear', scale=root_scale,
                           num_filter=num_filter, num_args=2)
    arg_shapes, out_shapes, _ = up.infer_shape(data=data_shape)

    data_nd = mx.random.uniform(-5, 5, data_shape, ctx=mx.cpu()).copyto(default_device())
    weight_np = _fill_bilinear(mx.ndarray.empty(arg_shapes[1]).asnumpy(), root_scale)
    arr = {'data': data_nd, 'weight': mx.nd.array(weight_np)}
    arr_grad = [mx.nd.empty(s) for s in arg_shapes]

    exe = up._bind(default_device(), args=arr, args_grad=arr_grad)
    exe.forward(is_train=True)
    out = exe.outputs[0].asnumpy()
    exe.backward(exe.outputs)

    # Only the shape is verified: spatial axes must grow by root_scale.
    target_shape = (data_shape[2] * root_scale, data_shape[3] * root_scale)
    assert out.shape == data_shape[:2] + target_shape
def test_nearest_upsampling():
    """Run the nearest-neighbour upsampling check over a grid of scales and input counts."""
    choices = [1, 2, 3]
    for root_scale, scale, num_shape, base in itertools.product(choices, repeat=4):
        shapes = []
        for i in range(num_shape):
            # Input i is `scale` times smaller per side than input i - 1.
            side = base * root_scale * scale**(num_shape - 1 - i)
            shapes.append((1, 3, side, side))
        check_nearest_upsampling_with_shape(shapes, scale, root_scale)
def test_bilinear_upsampling():
    """Run the bilinear upsampling check over a grid of scales, filter counts and sizes."""
    for root_scale in [2, 3]:
        for scale in [1, 2, 3]:
            for num_filter in [1, 2, 3]:
                for base in [1, 2, 3]:
                    # bilinear upsampling takes only 1 data and 1 weight
                    # multi input mode is not applicable
                    side = base * root_scale * scale
                    kernel = 2 * root_scale - root_scale % 2
                    check_bilinear_upsampling_with_shape(
                        (1, num_filter, side, side),
                        (1, num_filter, kernel, kernel),
                        scale, root_scale, num_filter)
def test_batchnorm_training():
    """Numerically check ``BatchNorm`` gradients.

    For each input shape, four flag combinations (fix_gamma x use_global_stats)
    are checked with the default channel axis, then again for every valid
    channel axis, negative indices included.
    """
    # Checked in this order for every configuration.
    flag_sets = [
        dict(fix_gamma=True),
        dict(fix_gamma=True, use_global_stats=True),
        dict(fix_gamma=False),
        dict(fix_gamma=False, use_global_stats=True),
    ]

    def check_batchnorm_training(stype):
        def to_nd(values):
            # NumPy array -> NDArray of the requested storage type.
            return mx.nd.array(values).tostype(stype)

        for shape in [(2, 3), (2, 3, 2, 2), (2, 8, 2, 2)]:
            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
            s = shape[1],
            gamma = np.ones(s)
            beta = np.ones(s)
            gamma[1] = 3
            beta[0] = 3
            rolling_mean = np.random.uniform(size=s)
            rolling_std = np.random.uniform(size=s)

            data = mx.symbol.Variable('data', stype=stype)
            in_location = [to_nd(data_tmp), to_nd(gamma), to_nd(beta)]
            mean_std = [to_nd(rolling_mean), to_nd(rolling_std)]

            for flags in flag_sets:
                test = mx.symbol.BatchNorm(data, **flags)
                check_numeric_gradient(test, in_location, mean_std,
                                       numeric_eps=1e-2, rtol=0.16, atol=1e-2)

            # Test varying channel axis
            dim = len(shape)
            for chaxis in range(-dim, dim):
                chaxis_true = dim + chaxis if chaxis < 0 else chaxis
                shapex = shape
                channel_count = shapex[chaxis_true]
                data_tmp = np.random.normal(-0.1, 0.1, size=shapex)

                gamma = np.ones(channel_count)
                beta = np.ones(channel_count)
                if channel_count > 1:
                    gamma[1] = 3
                beta[0] = 3

                in_location = [to_nd(data_tmp), to_nd(gamma), to_nd(beta)]
                xrolling_mean = np.random.uniform(size=channel_count)
                xrolling_std = np.random.uniform(size=channel_count)
                xmean_std = [to_nd(xrolling_mean), to_nd(xrolling_std)]

                for flags in flag_sets:
                    test = mx.symbol.BatchNorm(data, axis=chaxis, **flags)
                    check_numeric_gradient(test, in_location, xmean_std,
                                           numeric_eps=1e-2, rtol=0.2, atol=0.01)

    check_batchnorm_training('default')
# Compares the BatchNorm / SyncBatchNorm NDArray operators with a reference
# forward and backward pass written in elementary NDArray operations, for every
# axis and for combinations of gradient requests on data, gamma and beta.
# NOTE(review): the leading indentation of this block is not visible in this
# copy of the file; the comments below avoid asserting the exact nesting of the
# `if`/`else`/`with` bodies where that cannot be told from the text alone.
@xfail_when_nonstandard_decimal_separator
@pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm'])
@pytest.mark.parametrize('shape', [(4, 2), (4, 3, 4),
(4, 6, 4, 5), (4, 5, 6, 4, 5)])
@pytest.mark.parametrize('fix_gamma', [False, True])
@pytest.mark.parametrize('cudnn_off', [False, True])
@pytest.mark.parametrize('output_mean_var', [False, True])
def test_batchnorm(op_name, shape, fix_gamma, cudnn_off, output_mean_var):
# Resolve the operator under test from its parametrized name.
if op_name == 'BatchNorm':
op = mx.nd.BatchNorm
elif op_name == 'SyncBatchNorm':
op = mx.nd.contrib.SyncBatchNorm
else:
raise ValueError(f'Not supported {op_name}')
# Passed to the operator as `momentum` / `eps` and reused by the reference computation.
momentum = 0.9
epsilon = 1e-5
# Runs `num_iters` forward/backward passes for one (axis, grad_req) combination
# and compares outputs, running statistics and gradients with the reference.
def _test_batchnorm_impl(axis,
data_grad_req, gamma_grad_req, beta_grad_req):
kwargs = dict(output_mean_var=output_mean_var)
# SyncBatchNorm: returns early for axis != 1 and receives a `key` built from
# op, shape and axis. BatchNorm (the `else`) receives `axis` and `cudnn_off`.
if op_name == 'SyncBatchNorm':
if axis != 1:
return
key = str(op) + str(shape) + str(axis)
kwargs.update(dict(key=key))
if cudnn_off:
return
else:
kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
# Number of channels along the normalized axis.
nch = shape[axis]
# gamma is random with a gradient attached unless fix_gamma, in which case it is all ones.
if not fix_gamma:
bn_gamma = mx.nd.random.uniform(shape=(nch,))
bn_gamma.attach_grad(grad_req=gamma_grad_req)
else:
bn_gamma = mx.nd.ones(shape=(nch,))
bn_beta = mx.nd.random.uniform(shape=(nch,))
bn_beta.attach_grad(grad_req=beta_grad_req)
# bn_running_* are handed to the operator; running_* are the reference copies
# updated by hand below. Both start from mean 0 / variance 1.
bn_running_mean = mx.nd.zeros(nch)
bn_running_var = mx.nd.ones(nch)
running_mean = mx.nd.zeros(nch)
running_var = mx.nd.ones(nch)
num_iters = 10
# Shape that is 1 everywhere except along `axis`; used to broadcast per-channel parameters.
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]
data = mx.nd.random.uniform(shape=shape)
data.attach_grad(grad_req=data_grad_req)
# Reference gradients (accumulated across iterations when the request is 'add').
adX, adW, adb = 0, 0, 0
# True when at least one gradient is requested; gamma only counts when it is not fixed.
is_train = data_grad_req != 'null' or \
(not fix_gamma and gamma_grad_req != 'null') or \
beta_grad_req != 'null'
for _ in range(num_iters):
# Fresh input data is drawn unless data_grad_req is 'add'.
if data_grad_req != 'add':
data = mx.nd.random.uniform(shape=shape)
data.attach_grad(grad_req=data_grad_req)
ograd = mx.nd.random.uniform(shape=shape)
with mx.autograd.record():
output = op(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var,
momentum=momentum, eps=epsilon,
fix_gamma=fix_gamma, **kwargs)
# With output_mean_var the operator returns three arrays instead of one.
if output_mean_var:
output, output_mean, output_std = output
if is_train:
output.backward(ograd)
mx.nd.waitall()
# Reference forward pass: mean and variance over every axis except `axis`.
data_mean = data.mean(
axis=axis, exclude=True, keepdims=True)
data_var = (data - data_mean).square().mean(axis=axis,
exclude=True,
keepdims=True)
target_output = (data - data_mean) / \
(data_var + epsilon).sqrt() * \
bn_gamma.reshape(expand_shape) + \
bn_beta.reshape(expand_shape)
# squeeze data_mean and data_var
data_mean_flat = data_mean.squeeze()
data_var_flat = data_var.squeeze()
# Reference update of the running statistics with the same momentum.
running_mean = running_mean * momentum + \
data_mean_flat * (1 - momentum)
# m: number of elements that contribute to each channel's statistics.
m = np.prod(shape) / shape[axis]
# cudnn uses m-1 in the denominator of its sample variance calculation, not m
sample_var_adjust = 1.0 if cudnn_off or fix_gamma else m / (m-1)
running_var = running_var * momentum + \
data_var_flat * sample_var_adjust * (1 - momentum)
# Reference backward pass. dnx: gradient w.r.t. the normalized data,
# xsm: centred data, nd: 1 / sqrt(var + eps), nx: normalized data.
W = bn_gamma.reshape(expand_shape)
dnx = ograd * W
xsm = data - data_mean
nd = 1.0 / mx.nd.sqrt(data_var + epsilon)
nx = xsm * nd
dvar = (dnx * xsm).sum(axis=axis, keepdims=True,
exclude=True) * (-0.5) * mx.nd.power(nd, 3)
dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \
dvar * xsm.mean(axis=axis, keepdims=True,
exclude=True) * 2.0
dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
dW = (ograd * nx).sum(axis=axis, exclude=True)
db = ograd.sum(axis=axis, exclude=True)
# Replace the reference gradient, or accumulate it when the request is 'add'.
adX = dX if data_grad_req != 'add' else adX + dX
adW = dW if gamma_grad_req != 'add' else adW + dW
adb = db if beta_grad_req != 'add' else adb + db
atol, rtol = 5e-2, 5e-2
# Extra outputs: the mean is compared directly; the third output is compared
# with 1 / sqrt(var + eps) for BatchNorm and with the plain variance for SyncBatchNorm.
if output_mean_var:
assert_almost_equal(output_mean.asnumpy(),
data_mean_flat.asnumpy(),
atol=atol, rtol=rtol)
if op != mx.nd.contrib.SyncBatchNorm:
assert_almost_equal(output_std.asnumpy(),
(1.0 / (data_var_flat +
epsilon).sqrt()).asnumpy(),
atol=atol, rtol=rtol)
else:
assert_almost_equal(output_std.asnumpy(),
data_var_flat.asnumpy(),
atol=atol, rtol=rtol)
assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
atol=atol, rtol=rtol)
# Running statistics are compared under the is_train condition.
if is_train:
assert_almost_equal(bn_running_mean.asnumpy(
), running_mean.asnumpy(), atol=atol, rtol=rtol)
assert_almost_equal(bn_running_var.asnumpy(
), running_var.asnumpy(), atol=atol, rtol=rtol)
# Gradients are compared only for inputs whose request is not 'null'.
if data_grad_req != 'null':
assert_almost_equal(data.grad.asnumpy(),
adX.asnumpy(), atol=atol, rtol=rtol)
if not fix_gamma:
if gamma_grad_req != 'null':
assert_almost_equal(
bn_gamma.grad.asnumpy(), adW.asnumpy(),
atol=atol, rtol=rtol)
else:
# With fix_gamma the gamma array must still be all ones.
assert((bn_gamma.asnumpy() == 1).all())
if beta_grad_req != 'null':
assert_almost_equal(
bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
# Only 4-D shapes sweep all three gradient requests; other ranks use 'write' only.
grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
for data_grad_req in grad_reqs:
for gamma_grad_req in grad_reqs:
# A fixed gamma is only combined with gamma_grad_req == 'null'.
if fix_gamma and gamma_grad_req != 'null':
continue
for beta_grad_req in grad_reqs:
for axis in range(len(shape)):
_test_batchnorm_impl(axis,
data_grad_req, gamma_grad_req, beta_grad_req)
def test_groupnorm():
    """Check ``GroupNorm`` forward and backward against a NumPy reference.

    A random 4-D input (batch, channels, height, width) is generated once and
    checked in float16, float32 and float64. Reductions in the reference are
    accumulated in a wider type (see ``acc_types``) and cast back.
    """
    # Accumulation dtype used by the reference reductions for each input dtype.
    acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64'}

    def x_hat_helper(x, num_groups, eps):
        # Return the normalized input (reshaped to 5-D with a group axis) and
        # the per-(sample, group) mean and standard deviation.
        dtype = x.dtype
        dshape = x.shape
        assert len(dshape) == 4
        acc_type = acc_types[str(dtype)]
        new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), dshape[2], dshape[3])
        new_moments_shape = (dshape[0], num_groups, 1, 1, 1)
        data = x.reshape(new_shape)
        mean = np.mean(data, axis=(2, 3, 4), keepdims=False, dtype=acc_type).astype(dtype)
        std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + eps)
        x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape)
        return x_hat, mean, std

    def np_groupnorm(data, gamma, beta, num_groups, eps):
        # Reference forward pass; shapes are taken from `data` itself, not
        # from a variable of the enclosing scope.
        dshape = data.shape
        new_param_shape = (1, dshape[1], 1, 1)
        x_hat, mean, std = x_hat_helper(data, num_groups, eps)
        out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
        return out, mean, std

    def np_groupnorm_grad(ograd, data, gamma, num_groups, eps):
        # Reference backward pass; the moments are recomputed from `data`.
        x_hat, _, std = x_hat_helper(data, num_groups, eps)
        new_shape = x_hat.shape
        dshape = data.shape
        dtype = data.dtype
        new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
        acc_type = acc_types[str(dtype)]
        ograd = ograd.reshape(new_shape)
        std = std.reshape(new_moments_shape)
        # Parameter gradients reduce over batch and spatial axes, leaving (group, channel-in-group).
        beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
        gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
        x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
        ograd_mult = x_hat_grad / std
        red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
        data_grad = ograd_mult - red_out
        red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
        data_grad = data_grad - x_hat * red_out
        return data_grad.reshape(dshape), gamma_grad, beta_grad

    batch_size = random.randint(1, 8)
    num_groups = random.randint(2, 3)
    # The channel count is a multiple of the group count.
    num_channels = random.randint(2, 3) * num_groups
    height = random.randint(1, 5)
    width = random.randint(1, 5)
    dshape = (batch_size, num_channels, height, width)
    param_shape = (num_channels,)
    np_data = np.random.uniform(0.2, 1.0, dshape)
    np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
    np_beta = np.random.uniform(-1.0, 1.0, param_shape)
    data_sym = mx.sym.Variable("data")
    gamma_sym = mx.sym.Variable("gamma")
    beta_sym = mx.sym.Variable("beta")

    for dtype in [np.float16, np.float32, np.float64]:
        # float16 uses a larger eps and looser tolerances.
        is_half = dtype == np.float16
        eps = 1e-2 if is_half else 1e-5
        mx_data = mx.nd.array(np_data, dtype=dtype)
        mx_gamma = mx.nd.array(np_gamma, dtype=dtype)
        mx_beta = mx.nd.array(np_beta, dtype=dtype)

        np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype),
                                               np_gamma.astype(dtype),
                                               np_beta.astype(dtype),
                                               num_groups=num_groups,
                                               eps=eps)
        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
                                  num_groups=num_groups, eps=eps, output_mean_var=True)
        check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std],
                               rtol=1e-2 if is_half else 1e-3,
                               atol=5e-3 if is_half else 1e-4, dtype=dtype)

        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
                                  num_groups=num_groups, eps=eps, output_mean_var=False)
        np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
        np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd,
                                                                      np_data.astype(dtype),
                                                                      np_gamma.astype(dtype),
                                                                      num_groups, eps)
        check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta],
                                [mx.nd.array(np_ograd, dtype=np_ograd.dtype)],
                                [np_data_grad, np_gamma_grad, np_beta_grad],
                                rtol=1e-2 if is_half else 1e-3,
                                atol=5e-2 if is_half else 1e-4, dtype=dtype)
@pytest.mark.serial
def test_convolution_grouping():
    """Check that a grouped Convolution equals per-group convolutions concatenated.

    The reference network slices data, weight and bias per group, convolves
    each slice and concatenates the results; outputs and all gradients of the
    two executors are compared.
    """
    num_filter = 4
    for dim, num_group in itertools.product([1, 2, 3], [1, 2]):
        kernel = (3,) * dim
        shape = (1, 4) + (9,) * dim

        x = mx.sym.Variable('x')
        w = mx.sym.Variable('w')
        b = mx.sym.Variable('b')
        grouped = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter,
                                     num_group=num_group, kernel=kernel)

        x_parts = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1)
        w_parts = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0)
        b_parts = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0)
        per_group = [
            mx.sym.Convolution(data=x_parts[i], weight=w_parts[i], bias=b_parts[i],
                               num_filter=num_filter // num_group, kernel=kernel)
            for i in range(num_group)
        ]
        reference = mx.sym.Concat(*per_group)

        exe_grouped = grouped._simple_bind(default_device(), x=shape)
        exe_reference = reference._simple_bind(
            default_device(), x=shape,
            w=(num_filter, shape[1] // num_group) + kernel, b=(num_filter,))

        # Give both executors identical random arguments.
        for arg_grouped, arg_reference in zip(exe_grouped.arg_arrays, exe_reference.arg_arrays):
            arg_grouped[:] = np.random.normal(size=arg_grouped.shape).astype(
                effective_dtype(mx.nd.array([1.,])))
            arg_reference[:] = arg_grouped

        exe_grouped.forward(is_train=True)
        exe_grouped.backward(exe_grouped.outputs[0])
        exe_reference.forward(is_train=True)
        exe_reference.backward(exe_reference.outputs[0])

        results_grouped = exe_grouped.outputs + exe_grouped.grad_arrays
        results_reference = exe_reference.outputs + exe_reference.grad_arrays
        for got, want in zip(results_grouped, results_reference):
            assert_almost_equal(got, want)
@pytest.mark.skip(reason="Flaky test https://github.com/apache/incubator-mxnet/issues/14052")
def test_depthwise_convolution():
    """Check a depthwise Convolution (num_group == num_filter == channels).

    The reference is built from one single-filter convolution per channel,
    concatenated; outputs and gradients of both executors are compared.
    """
    grid = itertools.product(
        [1, 2],                # dim
        [1, 4, 16, 32, 64],    # num_base
        [3, 5],                # kernel_x
        [1, 2],                # stride_x
        [0, 1],                # pad_x
        [7, 32],               # in_size
    )
    for dim, num_base, kernel_x, stride_x, pad_x, in_size in grid:
        kernel = (kernel_x,) * dim
        stride = (stride_x,) * dim
        pad = (pad_x,) * dim
        num_filter = num_base
        num_group = num_base
        shape = (2, num_base) + (in_size,) * dim

        x = mx.sym.Variable('x')
        w = mx.sym.Variable('w')
        b = mx.sym.Variable('b')
        depthwise = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter,
                                       num_group=num_group, kernel=kernel,
                                       stride=stride, pad=pad)

        x_parts = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1)
        w_parts = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0)
        b_parts = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0)
        per_group = [
            mx.sym.Convolution(data=x_parts[i], weight=w_parts[i], bias=b_parts[i],
                               num_filter=num_filter // num_group, kernel=kernel,
                               stride=stride, pad=pad)
            for i in range(num_group)
        ]
        reference = mx.sym.Concat(*per_group)

        dev = default_device()
        exe_depthwise = depthwise._simple_bind(dev, x=shape)
        exe_reference = reference._simple_bind(
            dev, x=shape, w=(num_filter, shape[1] // num_group) + kernel,
            b=(num_filter,))

        # Give both executors identical random arguments.
        for arg_depthwise, arg_reference in zip(exe_depthwise.arg_arrays,
                                                exe_reference.arg_arrays):
            arg_depthwise[:] = np.random.normal(size=arg_depthwise.shape)
            arg_reference[:] = arg_depthwise

        exe_depthwise.forward(is_train=True)
        exe_depthwise.backward(exe_depthwise.outputs[0])
        exe_reference.forward(is_train=True)
        exe_reference.backward(exe_reference.outputs[0])

        results_depthwise = exe_depthwise.outputs + exe_depthwise.grad_arrays
        results_reference = exe_reference.outputs + exe_reference.grad_arrays
        for got, want in zip(results_depthwise, results_reference):
            assert_allclose(got, want, rtol=1e-3, atol=1e-3)
def test_convolution_independent_gradients():
    """Check that per-input ``grad_req`` settings of Convolution do not interfere.

    A baseline executor uses the same request for x, w and b. It is compared
    with executors using every combination of per-input requests: the forward
    outputs must always agree, 'null' gradients must stay zero, and an input
    whose request matches the baseline must produce the same gradient.
    """
    # NOTE(zixuanweeei): Flaky test tracked by https://github.com/apache/incubator-mxnet/issues/15603.
    # GPU context will be enabled after figuring out the possible issue tracked at
    # https://github.com/apache/incubator-mxnet/issues/15638.
    ctx = mx.cpu()
    tol = dict(rtol=1.0e-3, atol=1.0e-3)
    reqs = ["null", "write", "add"]
    var_names = ["x", "w", "b"]

    grid = itertools.product(
        [1, 2],          # dims
        [1, 8],          # num_bases
        [3, 5],          # kernel_xs
        [1, 2],          # stride_xs
        [0, 1],          # pad_xs
        [7, 32],         # in_sizes
        [True, False],   # no_biases
    )
    for dim, num_base, kernel_x, stride_x, pad_x, in_size, no_bias in grid:
        # Prepare params shape
        kernel = (kernel_x,) * dim
        stride = (stride_x,) * dim
        pad = (pad_x,) * dim
        num_filter = num_base
        x_shape = (2, num_base) + (in_size,) * dim
        w_shape = (num_filter, num_base) + kernel
        has_bias = not no_bias

        def zero_grads():
            # Fresh zero-initialised gradient buffers; 'b' is None without bias.
            return {
                'x': mx.nd.zeros(shape=x_shape, ctx=ctx),
                'w': mx.nd.zeros(shape=w_shape, ctx=ctx),
                'b': mx.nd.zeros(shape=(num_filter, ), ctx=ctx) if has_bias else None}

        # Symbols definition
        x = mx.sym.Variable('x')
        w = mx.sym.Variable('w')
        b = mx.sym.Variable('b') if has_bias else None
        conv = mx.sym.Convolution(x, w, b, num_filter=num_filter,
                                  kernel=kernel, stride=stride, pad=pad, no_bias=no_bias)

        for req_kind in reqs:
            # Binding args for conv with possible dependent gradients
            base_args = {
                'x': mx.nd.random.normal(shape=x_shape, ctx=ctx),
                'w': mx.nd.random.normal(shape=w_shape, ctx=ctx),
                'b': mx.nd.random.normal(shape=(num_filter, ), ctx=ctx) if has_bias else None}
            args1 = copy.deepcopy(base_args)
            grad1 = zero_grads()
            grad_req1 = dict.fromkeys(var_names, req_kind)
            exe1 = conv._bind(ctx, args1, args_grad=grad1, grad_req=grad_req1)
            exe1.forward(is_train=True)
            exe1.backward(exe1.outputs[0])

            for x_req, w_req, b_req in itertools.product(reqs, repeat=3):
                # Binding args for conv with independent gradients
                args2 = copy.deepcopy(base_args)  # Deepcopy the same params of `exe1`
                grad2 = zero_grads()
                grad_req2 = {"x": x_req, "w": w_req, "b": b_req}
                exe2 = conv._bind(ctx, args2, args_grad=grad2, grad_req=grad_req2)

                exe2.forward(is_train=True)
                np.testing.assert_allclose(exe1.outputs[0].asnumpy(),
                                           exe2.outputs[0].asnumpy(), **tol)

                exe2.backward(exe2.outputs[0])
                for var_name in var_names:
                    if no_bias and var_name == "b":
                        continue
                    if grad_req2[var_name] == "null":
                        # A 'null' request must leave the zero-filled buffer untouched.
                        untouched = grad2[var_name].asnumpy()
                        np.testing.assert_allclose(untouched,
                                                   np.zeros_like(untouched), **tol)
                    if grad_req2[var_name] != grad_req1[var_name]:
                        continue
                    np.testing.assert_allclose(args1[var_name].asnumpy(),
                                               args2[var_name].asnumpy(), **tol)
                    np.testing.assert_allclose(grad1[var_name].asnumpy(),
                                               grad2[var_name].asnumpy(), **tol)
def gen_broadcast_data(idx):
    """Return two random arrays whose shapes can be broadcast together.

    Parameters
    ----------
    idx : int
        Sample index. Indices below the length of the hand-written shape
        table select that entry; any other index draws a random pair of
        shapes instead.

    Returns
    -------
    list of numpy.ndarray
        ``[lhs, rhs]``, both filled by ``np.random.random``.
    """
    # Manually set test cases.
    # Stored as nested tuples rather than ``np.array``: the trailing entries
    # pair shapes of different lengths, and NumPy >= 1.24 raises ValueError
    # when asked to build an array from such ragged input.
    shape_pairs = (
        ((2, 5, 1, 30, 7), (1, 5, 448, 30, 1)),
        ((10, 49, 1, 77, 17), (10, 1, 2, 1, 17)),
        ((13, 2, 65, 2, 1), (13, 1, 65, 1, 225)),
        ((9, 434, 4, 2, 37), (9, 1, 4, 1, 37)),
        ((2, 52, 1, 4, 1), (1, 52, 60, 1, 37)),
        ((1, 23, 7, 122, 50), (2, 1, 7, 1, 50)),
        ((1, 17, 1, 5, 1), (22, 1, 2, 1, 28)),
        ((29, 1, 2, 1, 8), (29, 22, 1, 130, 1)),
        ((2, 36, 1, 427, 3), (1, 36, 11, 427, 1)),
        ((1, 2, 1, 100, 7), (1, 2, 448, 100, 1)),
        ((1, 2, 495, 77, 7), (1, 2, 1, 1, 7)),
        ((1, 43, 65, 2, 1), (1, 43, 65, 1, 225)),
        ((1, 92, 434, 2, 2), (1, 92, 1, 2, 2)),
        ((1, 92, 1, 4, 1), (1, 92, 134, 1, 17)),
        ((1, 53, 2, 122, 143), (1, 1, 2, 1, 143)),
        ((1, 179, 1, 87, 17), (1, 179, 1, 1, 17)),
        ((1, 1, 17, 5, 1), (1, 22, 1, 1, 28)),
        ((1, 2, 1, 1, 8), (1, 2, 52, 430, 1)),
        ((1, 163, 1, 22, 3), (1, 163, 116, 22, 1)),
        ((1, 1, 44, 30, 7), (1, 1, 44, 30, 1)),
        ((1, 1, 1, 1, 28), (1, 127, 1, 5, 28)),
        ((1, 2, 394, 38, 1), (1, 2, 394, 38, 16)),
        ((1, 10, 49, 77, 17), (1, 1, 1, 1, 17)),
        ((1, 431, 6, 2, 225), (1, 1, 6, 2, 225)),
        ((1, 15, 1, 28, 1), (1, 15, 1, 28, 463)),
        ((1, 129, 2, 48, 96), (1, 129, 2, 1, 1)),
        ((1, 1, 403, 17, 2), (1, 44, 403, 17, 2)),
        ((1, 1, 65, 2, 22), (1, 1, 65, 1, 1)),
        ((1, 24, 103, 17, 18), (1, 24, 1, 1, 1)),
        ((1, 1, 1, 1, 2), (1, 24, 194, 50, 1)),
        ((1, 1, 107, 84, 9), (1, 1, 1, 1, 1)),
        ((8, 1, 6, 1), (7, 1, 5)), ((5, 4), (1,)),
        ((256, 256, 3), (3,)), ((5, 4), (4,)),
        ((15, 3, 5), (3, 5)), ((15, 3, 5), (1, 5)),
        ((15, 3, 5), (3, 1)), ((1, 1, 1, 1), (1, 1)),
        ((15, 3), (4, 1, 3)), ((7, 1, 5), (8, 1, 6, 1)),
    )
    if idx < len(shape_pairs):
        l_shape, r_shape = shape_pairs[idx]
    else:
        # Generate random data that has ndim between 1-5 and all the shape dims between 1-5.
        # The order of the random draws is kept unchanged so seeded runs
        # still produce the same samples.
        ndim = np.random.randint(1, 6)
        shape = np.random.randint(1, 6, size=(ndim,))
        l_same_dim = np.random.randint(0, 5)
        r_same_dim = np.random.randint(0, 5)
        l_axis_flags = np.random.randint(0, 2, size=ndim)
        r_axis_flags = np.random.randint(0, 2, size=ndim)
        # One draw in five leaves an operand with the full (un-broadcast) shape.
        if l_same_dim == 4:
            l_axis_flags = np.ones(ndim)
        if r_same_dim == 4:
            r_axis_flags = np.ones(ndim)
        # A zero flag collapses that axis to length 1 so it gets broadcast.
        l_shape = shape.copy()
        r_shape = shape.copy()
        l_shape[l_axis_flags == 0] = 1
        r_shape[r_axis_flags == 0] = 1
    return [np.random.random(l_shape), np.random.random(r_shape)]
def gen_broadcast_data_int(idx):
    """Integer-valued counterpart of ``gen_broadcast_data``.

    Both arrays produced by ``gen_broadcast_data(idx)`` are multiplied by
    100, rounded, and converted to ``int``.

    Returns
    -------
    list of numpy.ndarray
        ``[lhs, rhs]`` with integer dtype.
    """
    lhs, rhs = gen_broadcast_data(idx)
    scaled = (np.round(lhs * 100), np.round(rhs * 100))
    return [values.astype(int) for values in scaled]
def gen_binary_data(dummy):
    """Return two random arrays that share one randomly chosen shape.

    Parameters
    ----------
    dummy : object
        Ignored; present so the function has the same call signature as
        the other data generators.

    Returns
    -------
    list of numpy.ndarray
        ``[lhs, rhs]``, both filled by ``np.random.random``.
    """
    # Rank and every dimension are drawn from 1..5 inclusive.
    rank = np.random.randint(1, 6)
    dims = np.random.randint(1, 6, size=(rank,))
    lhs = np.random.random(dims)
    rhs = np.random.random(dims)
    return [lhs, rhs]
def gen_binary_data_int(dummy):
    """Integer-valued counterpart of ``gen_binary_data``.

    Both arrays produced by ``gen_binary_data(dummy)`` are multiplied by
    100, rounded, and converted to ``int``.

    Returns
    -------
    list of numpy.ndarray
        ``[lhs, rhs]`` with integer dtype.
    """
    lhs, rhs = gen_binary_data(dummy)
    scaled = (np.round(lhs * 100), np.round(rhs * 100))
    return [values.astype(int) for values in scaled]
def check_binary_op_forward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5, mx_nd_func=None):
    """Check the forward output of a binary symbol against a NumPy baseline.

    For each of 200 samples the symbol is bound with inputs ``a`` and ``b``,
    run forward, and compared with ``baseline(a, b)``. When a mismatch is
    found, diagnostics are logged before the final assertion fails.

    Parameters
    ----------
    symbol : object
        Symbol with inputs named ``'a'`` and ``'b'``; it is bound through
        ``symbol._bind``.
    baseline : callable
        ``baseline(a, b)`` returning the expected NumPy result.
    gen_data : callable
        ``gen_data(i)`` returning the pair of NumPy inputs for sample ``i``.
    rtol, atol : float
        Relative and absolute tolerances used for every comparison.
    mx_nd_func : callable, optional
        When given, ``mx_nd_func(a, b)`` on NDArray inputs must also agree
        with the symbol's output.
    """
    # Imported locally: these are only needed for the failure diagnostics,
    # and ``logging`` is not imported explicitly at the top of this file.
    import binascii
    import logging
    import struct

    def ftohex(xs):
        # Hex of each value packed as a C double, to expose last-bit differences.
        return [binascii.hexlify(struct.pack('d', v)) for v in xs.flatten()]

    sample_num = 200
    for i in range(sample_num):
        d = gen_data(i)
        exe = symbol._bind(default_device(), args={'a': mx.nd.array(d[0]), 'b': mx.nd.array(d[1])})
        exe.forward(is_train=True)
        y = exe.outputs[0].asnumpy()
        x = baseline(d[0], d[1]).astype(y.dtype)
        if mx_nd_func is not None:
            d0 = mx.nd.array(d[0], dtype=d[0].dtype)
            d1 = mx.nd.array(d[1], dtype=d[1].dtype)
            assert_almost_equal(y, mx_nd_func(d0, d1).asnumpy(), rtol=rtol, atol=atol)
        idx = np.abs(x-y) > atol+rtol*np.abs(x)
        if idx.any():
            np.set_printoptions(precision=20)
            logging.error('found precision problem:')
            # Broadcast the operands to the output shape so the mismatch mask
            # can index them directly.
            a_full = np.broadcast_to(d[0], x.shape)
            b_full = np.broadcast_to(d[1], x.shape)
            logging.error('input a: {}'.format(a_full[idx]))
            logging.error('input b: {}'.format(b_full[idx]))
            logging.error("output x: {} {}".format(x.dtype, x))
            logging.error("output y: {} {}".format(y.dtype, y))
            logging.error('output x in baseline(a, b): {}'.format(x[idx]))
            logging.error('output y in symbol(a, b): {}'.format(y[idx]))
            logging.error('output x in baseline(a,b) hex: {}'.format(ftohex(x[idx])))
            logging.error('output y in symbol(a,b) hex: {}'.format(ftohex(y[idx])))
            logging.error('input a hex: {}'.format(ftohex(a_full[idx])))
            # This line used to be labelled 'input a hex' although it dumps operand b.
            logging.error('input b hex: {}'.format(ftohex(b_full[idx])))
            logging.error('diff: {}'.format(np.abs(x-y)[idx] - atol-rtol*np.abs(x)[idx]))
        assert_allclose(y, x, rtol=rtol, atol=atol)
def check_binary_op_backward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5):
    """Check the input gradients of a binary symbol against a NumPy baseline.

    For each of 200 samples a random head gradient with the broadcast output
    shape is generated, ``baseline`` supplies the expected gradients, and
    those are summed over the broadcast axes before being compared with the
    gradients written by the symbol's backward pass.

    Parameters
    ----------
    symbol : object
        Symbol with inputs named ``'a'`` and ``'b'``; it is bound through
        ``symbol._bind``.
    baseline : callable
        ``baseline(g_out, a, b)`` returning the pair of expected gradients.
    gen_data : callable
        ``gen_data(i)`` returning the pair of NumPy inputs for sample ``i``.
    rtol, atol : float
        Relative and absolute tolerances used for both comparisons.
    """
    def reduce_op(shape, x):
        """Sum ``x`` over its broadcast axes and return it with shape ``shape``."""
        if shape == x.shape:
            return x
        keepdims_shape = list(x.shape)
        # calculate difference between output and input ndims
        # to include cases where inputs' ndims are not equal
        ndim_diff = len(x.shape) - len(shape)
        # Leading axes exist only in the output: sum each one away.
        for axis in range(ndim_diff):
            keepdims_shape[axis] = 1
            x = np.sum(x, axis=axis).reshape(keepdims_shape)
        # Remaining axes: sum those whose length differs from the operand's.
        for axis, dim in enumerate(shape, start=ndim_diff):
            if x.shape[axis] != dim:
                keepdims_shape[axis] = 1
                x = np.sum(x, axis=axis).reshape(keepdims_shape)
        # Drop the leading singleton axes so the expected gradient has exactly
        # the operand's shape, whether or not the comparison helper broadcasts.
        return x.reshape(shape)

    sample_num = 200
    for i in range(sample_num):
        d = gen_data(i)
        # Head gradient with the broadcast output shape.
        out = np.random.random((d[0] + d[1]).shape)
        baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1])
        x_1 = reduce_op(d[0].shape, baseline_grad1)
        x_2 = reduce_op(d[1].shape, baseline_grad2)
        y_1 = mx.nd.empty(d[0].shape)
        y_2 = mx.nd.empty(d[1].shape)
        y = symbol._bind(default_device(), args={'a': mx.nd.array(d[0]), 'b': mx.nd.array(d[1])},
                         args_grad=[y_1, y_2])
        o = y.forward(is_train=True)
        y.backward([mx.nd.array(out, dtype=o[0].dtype)])
        assert_allclose(y_1.asnumpy(), x_1, rtol=rtol, atol=atol)
        assert_allclose(y_2.asnumpy(), x_2, rtol=rtol, atol=atol)
def test_binary_op():
    """Check forward results and gradients of the same-shape binary operators.

    Each case builds a symbol from the two variables ``a`` and ``b`` and runs
    it through ``check_binary_op_forward`` and ``check_binary_op_backward``
    with a NumPy reference implementation.
    """
    lhs = mx.sym.Variable('a')
    rhs = mx.sym.Variable('b')

    def zero_grads(g_out, a, b):
        # Reference gradients that are all zeros for both inputs.
        return (np.zeros_like(a), np.zeros_like(b))

    def check_add():
        sym = lhs + rhs
        check_binary_op_forward(sym, lambda a, b: a + b, gen_binary_data)
        check_binary_op_backward(sym, lambda g_out, a, b: (g_out, g_out), gen_binary_data)

    def check_sub():
        sym = lhs - rhs
        check_binary_op_forward(sym, lambda a, b: a - b, gen_binary_data)
        check_binary_op_backward(sym, lambda g_out, a, b: (g_out, - g_out), gen_binary_data)

    def check_mul():
        sym = lhs * rhs
        check_binary_op_forward(sym, lambda a, b: a * b, gen_binary_data)
        check_binary_op_backward(sym, lambda g_out, a, b: (g_out * b, g_out * a), gen_binary_data)

    def check_div():
        sym = lhs / rhs
        check_binary_op_forward(sym, lambda a, b: a / b, gen_binary_data)
        check_binary_op_backward(sym, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)),
                                 gen_binary_data)

    def check_mod():
        # The symbol computes in float64: this test was flaky with float32
        # (seeds 1688524483, 1768433044).
        sym = mx.sym.cast(lhs, dtype='float64') % mx.sym.cast(rhs, dtype='float64')
        # '%' is sensitive to precision, so the NumPy reference first rounds
        # its operands to float32 and the forward check is exact.
        check_binary_op_forward(sym, lambda a, b: np.float32(a) % np.float32(b),
                                gen_binary_data, rtol=0, atol=0)
        check_binary_op_backward(
            sym,
            lambda g_out, a, b: (g_out, - g_out * (np.float32(a) // np.float32(b))),
            gen_binary_data)

    def check_mod_int():
        sym = mx.sym.cast(lhs, dtype='int32') % mx.sym.cast(rhs, dtype='int32')
        check_binary_op_forward(sym, lambda a, b: a % b, gen_binary_data_int)
        check_binary_op_backward(sym, zero_grads, gen_binary_data_int)

    def check_pow():
        sym = lhs ** rhs
        check_binary_op_forward(sym, lambda a, b: a ** b, gen_binary_data)
        check_binary_op_backward(
            sym,
            lambda g_out, a, b: (g_out * a **(b - 1) * b, g_out * a ** b * np.log(a)),
            gen_binary_data)

    def check_not_equal():
        sym = lhs != rhs
        # '!=' is sensitive to the precision of the comparison, so the NumPy
        # reference compares float32 values (issue exposed with seed 1644387363).
        check_binary_op_forward(
            sym,
            lambda a, b: (np.float32(a) != np.float32(b)).astype(a.dtype),
            gen_binary_data)
        check_binary_op_backward(sym, zero_grads, gen_binary_data)

    # Run the cases in the original order so random draws are consumed identically.
    for run_case in (check_add, check_sub, check_mul, check_div,
                     check_mod, check_mod_int, check_pow, check_not_equal):
        run_case()
def test_broadcast_binary_op():
def check_bmaxmin_gradient(test_sym, x, y, delta, rtol, atol):
"""This function ensures that checking the numerical gradient of
broadcast_max/min is not crossing the boundary y=x where there
is no gradient definition at those singularities."""
x_max = np.max(x)
y = x_max + 2 * delta + np.random.random(y.shape)
check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
x_min = np.min(x)
y = x_min - 2 * delta - np.random.random(y.shape)
check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
def test_bplus(a, b):
c = mx.sym.broadcast_plus(a, b)
check_binary_op_forward(c, lambda a, b: a + b, gen_broadcast_data, mx_nd_func=mx.nd.add)
check_binary_op_backward(c, lambda g_out, a, b: (g_out, g_out), gen_broadcast_data)
def test_bminus(a, b):
c = mx.sym.broadcast_minus(a, b)
check_binary_op_forward(c, lambda a, b: a - b, gen_broadcast_data, mx_nd_func=mx.nd.subtract)
check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out), gen_broadcast_data)
def test_bmul(a, b):
c = mx.sym.broadcast_mul(a, b)
check_binary_op_forward(c, lambda a, b: a * b, gen_broadcast_data, mx_nd_func=mx.nd.multiply)
check_binary_op_backward(c, lambda g_out, a, b: (g_out * b, g_out * a), gen_broadcast_data)
def test_bdiv(a, b):
c = mx.sym.broadcast_div(a, b)
check_binary_op_forward(c, lambda a, b: a / b, gen_broadcast_data, mx_nd_func=mx.nd.divide)
check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_broadcast_data)
def test_bmod(a_, b_):
# Python and numpy operate only in double so to avoid numerical errors we have to use
# doubles as well. This was a flaky test before when using float32. seed 1688524483, 1768433044
a = mx.sym.cast(a_, dtype='float64')
b = mx.sym.cast(b_, dtype='float64')
# '%' is sensitive to the precision of the calculation. Force numpy to match mxnet's float32.
c = mx.sym.broadcast_mod(a, b)
check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data, atol=1, mx_nd_func=mx.nd.modulo)
check_binary_op_backward(c,
lambda g_out, a, b: (g_out, - g_out * (np.float32(a) // np.float32(b))), gen_binary_data)
def test_bmod_int(a, b):
c = mx.sym.broadcast_mod(mx.sym.cast(a, dtype='int32'), mx.sym.cast(b, dtype='int32'))
check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data_int, mx_nd_func=mx.nd.modulo)
check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_broadcast_data_int)
def tes