| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: skip-file |
| from __future__ import print_function |
| from __future__ import division |
| import numpy as np |
| import mxnet as mx |
| import copy |
| import math |
| import random |
| import itertools |
| from distutils.version import LooseVersion |
| from numpy.testing import assert_allclose, assert_array_equal |
| from mxnet.test_utils import * |
| from mxnet.operator import * |
| from mxnet.base import py_str, MXNetError, _as_list |
| from common import assert_raises_cudnn_not_satisfied, assert_raises_cuda_not_satisfied, assertRaises |
| from common import xfail_when_nonstandard_decimal_separator, with_environment |
| import pytest |
| import os |
| |
| @assert_raises_cudnn_not_satisfied(min_version='5.1.10') |
| @pytest.mark.serial |
| def test_rnn_with_new_param(): |
| rnn_modes = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm'] |
| ngates_ = [1, 1, 3, 4] |
| num_layers, input_size, seq_len, batch_size, state_size = 3, 128, 5, 64, 8 |
| for bidirectional in [False, True]: |
| directions = 2 if bidirectional else 1 |
| for mode, ngates in zip(rnn_modes, ngates_): |
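            # Per-layer parameter count for the fused RNN operator: for each gate there is an
            # input-to-hidden weight matrix, a hidden-to-hidden weight matrix and two bias vectors;
            # layers after the first take input of size state_size * directions.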
| first_layer_size = (input_size * state_size + state_size * state_size + state_size * 2) * ngates |
| rest_layer_size = (state_size * directions * state_size + state_size * state_size + state_size * 2) \ |
| * ngates * (num_layers - 1) |
| param_size = (first_layer_size + rest_layer_size) * directions |
| sym = mx.sym.RNN(mode=mode, num_layers=num_layers, bidirectional=bidirectional, |
| state_outputs=False, state_size=state_size, name='rnn') |
| |
| bind_dict = { |
| 'rnn_data': mx.ndarray.random.uniform(low=-1, high=1, shape=(seq_len, batch_size, input_size)), |
                'rnn_parameters': mx.ndarray.random.uniform(low=-1, high=1, shape=(param_size,)),
| 'rnn_state': mx.ndarray.zeros(shape=(num_layers * directions, batch_size, state_size)) |
| } |
| if mode == 'lstm': |
| bind_dict['rnn_state_cell'] = mx.ndarray.zeros( |
| shape=(num_layers * directions, batch_size, state_size)) |
| |
| ex = sym._bind(default_device(), bind_dict) |
| ex.forward(is_train=True) |
| ex01 = ex.output_dict['rnn_output'].asnumpy() |
| ex.forward(is_train=False) |
| ex02 = ex.output_dict['rnn_output'].asnumpy() |
| assert_allclose(ex01, ex02, rtol=1e-2, atol=1e-4) |
            bind_dict['rnn_parameters'] = mx.ndarray.random.uniform(low=-1, high=1, shape=(param_size,))
| ex.copy_params_from(bind_dict) |
| ex.forward(is_train=True) |
| ex03 = ex.output_dict['rnn_output'].asnumpy() |
| ex.forward(is_train=False) |
| ex04 = ex.output_dict['rnn_output'].asnumpy() |
| assert_allclose(ex03, ex04, rtol=1e-2, atol=1e-4) |
| |
| |
| @pytest.mark.serial |
| def test_lstm_dropout(): |
| X = mx.sym.Variable('x') |
| Params = mx.sym.Variable('params') |
| HX = mx.sym.Variable('state') |
| CX = mx.sym.Variable('state_cell') |
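    # T: sequence length, N: batch size, I: input size, H: hidden (state) size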
| T, N, I, H = 300, 20, 800, 800 |
| rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX, |
| state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM') |
| exe = rnn._simple_bind(ctx=mx.cpu(), x=(T, N, I)) |
| out = exe.forward(is_train=True) |
| out[0].wait_to_read() |
| |
| @pytest.mark.serial |
| def test_gru_dropout(): |
| X = mx.sym.Variable('x') |
| Params = mx.sym.Variable('params') |
| HX = mx.sym.Variable('state') |
| T, N, I, H = 300, 20, 800, 800 |
| rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, |
| state_size=H, num_layers=5, mode='gru', p=0.5, state_outputs=True, name='GRU') |
| exe = rnn._simple_bind(ctx=mx.cpu(), x=(T, N, I)) |
| out = exe.forward(is_train=True) |
| out[0].wait_to_read() |
| |
| @pytest.mark.serial |
| def test_rnntanh_dropout(): |
| X = mx.sym.Variable('x') |
| Params = mx.sym.Variable('params') |
| HX = mx.sym.Variable('state') |
| T, N, I, H = 300, 20, 800, 800 |
| rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, |
| state_size=H, num_layers=5, mode='rnn_tanh', p=0.5, state_outputs=True, name='RNN_TANH') |
| exe = rnn._simple_bind(ctx=mx.cpu(), x=(T, N, I)) |
| out = exe.forward(is_train=True) |
| out[0].wait_to_read() |
| |
| @pytest.mark.serial |
| def test_rnnrelu_dropout(): |
| X = mx.sym.Variable('x') |
| Params = mx.sym.Variable('params') |
| HX = mx.sym.Variable('state') |
| T, N, I, H = 300, 20, 800, 800 |
| rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, |
| state_size=H, num_layers=5, mode='rnn_relu', p=0.5, state_outputs=True, name='RNN_RELU') |
| exe = rnn._simple_bind(ctx=mx.cpu(), x=(T, N, I)) |
| out = exe.forward(is_train=True) |
| out[0].wait_to_read() |
| |
| def test_RNN_float64(): |
| if default_device().device_type == 'gpu': |
| return |
    sym = mx.sym.RNN(
        mx.sym.Variable('in'),
        mx.sym.Variable('par'),
        mx.sym.Variable('s'),
        state_size=2,
        num_layers=1,
        mode='rnn_tanh'
    )
| |
| dtype = 'float64' |
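    # A single-layer rnn_tanh with input_size=2 and state_size=2 needs 12 parameters:
    # 4 input-to-hidden weights, 4 hidden-to-hidden weights and 2 + 2 bias terms.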
| explicit_grad = { |
| 'in': mx.nd.ones([2, 1, 2], dtype=dtype), |
| 'par': mx.nd.ones([12], dtype=dtype), |
| 's': mx.nd.ones([1, 1, 2], dtype=dtype) |
| } |
| |
| args_grad = explicit_grad |
| grad_req = 'write' |
| |
| ex = sym._bind(default_device(), |
| { |
| 'in': mx.nd.ones([2, 1, 2], dtype=dtype), |
| 'par': mx.nd.ones([12], dtype=dtype), |
| 's': mx.nd.ones([1, 1, 2], dtype=dtype) |
| }, |
| args_grad = args_grad, |
| grad_req = grad_req |
| ) |
| ex.forward() |
| ex.outputs[0].wait_to_read() |
| |
| def np_softmax(x, axis=-1, temperature=1.0): |
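    """Numerically stable reference softmax: subtract the max before exponentiating."""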
| x = x - np.max(x, axis=axis, keepdims=True) |
| x = np.exp(x/temperature) |
| x /= np.sum(x, axis=axis, keepdims=True) |
| return x |
| |
| |
| def check_elementwise_sum_with_shape(shape, n): |
| # forward |
| inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] |
| out = mx.symbol.ElementWiseSum(*inputs, name='esum') |
| arr = [mx.nd.empty(shape) for i in range(n)] |
| arr_grad = [mx.nd.empty(shape) for i in range(n)] |
| for i in range(n): |
| arr[i][:] = np.random.uniform(-10, 10, shape) |
| exec1 = out._bind(default_device(), |
| args=arr, |
| args_grad=arr_grad) |
| |
| exec1.forward(is_train=True) |
| out1 = exec1.outputs[0] |
| out = sum(a.asnumpy() for a in arr) |
| assert_almost_equal(out, out1, rtol=1e-5, atol=1e-5) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = np.random.uniform(-10, 10, shape) |
| # backward |
| exec1.backward([out_grad]) |
| for a in arr_grad: |
| assert_almost_equal(a, out_grad, rtol=1e-5, atol=1e-5) |
| |
| |
| @pytest.mark.serial |
| def test_elementwise_sum(): |
| nrepeat = 2 |
| maxdim = 4 |
| for _ in range(nrepeat): |
| for dim in range(1, maxdim): |
| shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) |
| check_elementwise_sum_with_shape(shape, np.random.randint(1, 8)) |
| |
| |
| def check_concat_with_shape(shapes, dimension, skip_second): |
    # If skip_second is True, the second argument will not receive a gradient.
    # This exercises the case reported in issue #1130.
| n = len(shapes) |
| # forward |
| target_dim = 0 |
| for shape in shapes: |
| target_dim += shape[dimension] |
| |
| inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] |
| out = mx.symbol.Concat(*inputs, name='conc',dim=dimension) |
| arr = [mx.nd.empty(shape) for shape in shapes] |
| for i in range(n): |
| arr[i][:] = shapes[i][dimension] |
| arr_np = [np.copy(narray.asnumpy()) for narray in arr] |
| arr_grad = [mx.nd.empty(shape) for shape in shapes] |
| dict_grad = {} |
| arg_names = out.list_arguments() |
| |
| for name, g in zip(arg_names, arr_grad): |
| if not skip_second or name != 'arg1': |
| dict_grad[name] = g |
| |
| args = out.list_arguments() |
| arg_shapes, out_shapes, aux_shapes = out.infer_shape(**dict(zip(args, shapes))) |
| out_grad = mx.nd.empty(out_shapes[0]) |
| exec1 = out._bind(default_device(), |
| args=arr, |
| args_grad=dict_grad) |
| exec1.forward(is_train=True) |
| out1 = exec1.outputs[0] |
| ret = np.concatenate([narray.asnumpy() for narray in arr], axis=dimension) |
| assert_almost_equal(out1, ret) |
| # backward |
| out1.copyto(out_grad) |
| out_grad[:] += 1 |
| exec1.backward([out_grad]) |
| |
| for i, name in enumerate(arg_names): |
| if not skip_second or name != 'arg1': |
| grad = dict_grad[name] |
| np_grad = arr_np[i] |
| assert_almost_equal(grad, np_grad + 1) |
| |
| |
| def test_concat(): |
| for dimension in range(4): |
| n = 2 |
| merge = [2, 3, 4, 5, 6] |
| a = 2 |
| b = 3 |
| c = 4 |
| # test 2D |
| if dimension<2: |
| for dim in range(2, 6): |
| shapes = [] |
| for i in range(dim): |
| if dimension == 0: |
| shapes.append((merge[i], a)) |
| elif dimension == 1: |
| shapes.append((a, merge[i])) |
| check_concat_with_shape(shapes,dimension,True) |
| check_concat_with_shape(shapes,dimension,False) |
| # Test negative dim |
| check_concat_with_shape(shapes, dimension - 2, True) |
| check_concat_with_shape(shapes, dimension - 2, False) |
| |
| #test 3D |
| if dimension<3: |
| for dim in range(2, 6): |
| shapes = [] |
| for i in range(dim): |
| if dimension == 0: |
| shapes.append((merge[i], a,b)) |
| elif dimension ==1: |
| shapes.append((a,merge[i],b)) |
| elif dimension ==2: |
| shapes.append((a,b,merge[i])) |
| check_concat_with_shape(shapes,dimension,True) |
| check_concat_with_shape(shapes,dimension,False) |
| # Test negative dim |
| check_concat_with_shape(shapes, dimension - 3, True) |
| check_concat_with_shape(shapes, dimension - 3, False) |
| # test 4D |
| for dim in range(2, 6): |
| shapes = [] |
| for i in range(dim): |
| if dimension == 0: |
| shapes.append((merge[i],a,b,c)) |
| elif dimension == 1: |
| shapes.append((a,merge[i],b,c)) |
| elif dimension ==2: |
| shapes.append((a,b,merge[i],c)) |
| elif dimension ==3: |
| shapes.append((a,b,c,merge[i])) |
| check_concat_with_shape(shapes,dimension,True) |
| check_concat_with_shape(shapes,dimension,False) |
| # Test negative dim |
| check_concat_with_shape(shapes, dimension - 4, True) |
| check_concat_with_shape(shapes, dimension - 4, False) |
| |
| def test_slice_channel(): |
| def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis): |
| ins = [] |
| if squeeze_axis: |
| shape = np.random.randint(2, 5, data_ndim).tolist() |
| shape[axis] = num_outputs |
| out_ele_shape = [ele for ele in shape] |
| del out_ele_shape[axis] |
| else: |
| shape = np.random.randint(1, 5, data_ndim).tolist() |
| shape[axis] *= num_outputs |
| out_ele_shape = [ele for ele in shape] |
| out_ele_shape[axis] //= num_outputs |
| data_npy = np.random.normal(size=shape) |
| out_grads_npy = [np.random.normal(size=out_ele_shape) for i in range(num_outputs)] |
| data = mx.sym.Variable('data') |
| sym = mx.sym.SliceChannel(data=data, num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis) |
| exe = sym._simple_bind(ctx=default_device(), data=data_npy.shape) |
| outputs = exe.forward(is_train=True, data=data_npy) |
| assert len(exe.outputs) == num_outputs |
| for i in range(num_outputs): |
            gt = data_npy.take(np.arange(i * shape[axis]/num_outputs,
                                         (i+1) * shape[axis]/num_outputs).astype(int), axis=axis)
| if squeeze_axis: |
| assert_almost_equal(outputs[i], gt.reshape(outputs[i].shape)) |
| else: |
| assert_almost_equal(outputs[i], gt) |
| # test backward |
| ograd = [mx.nd.array(ele, dtype=outputs[i].dtype) for i, ele in enumerate(out_grads_npy)] |
| exe.backward(out_grads=ograd) |
| if squeeze_axis: |
| assert_almost_equal(exe.grad_arrays[0], |
| np.concatenate([np.expand_dims(ele, axis=axis) for ele in out_grads_npy], |
| axis=axis)) |
| else: |
| assert_almost_equal(exe.grad_arrays[0], |
| np.concatenate(out_grads_npy, axis=axis)) |
| check_slice_channel(data_ndim=2, axis=1, num_outputs=3, squeeze_axis=True) |
| check_slice_channel(data_ndim=4, axis=2, num_outputs=3, squeeze_axis=False) |
| check_slice_channel(data_ndim=3, axis=-1, num_outputs=2, squeeze_axis=False) |
| check_slice_channel(data_ndim=5, axis=-2, num_outputs=3, squeeze_axis=True) |
| |
| |
| def test_python_op(): |
| X = mx.symbol.Variable('X') |
| op = mx.operator.NumpyOp() |
| s = op.get_symbol(X, name='numpy_op') |
| |
| x = mx.ndarray.ones((10))*10 |
| dx = mx.ndarray.zeros((10)) |
| dy = mx.ndarray.ones((10)) |
| exec1 = s._bind(default_device(), args=[x], args_grad = {'X': dx}) |
| exec1.forward(is_train=True) |
| assert_almost_equal(x, exec1.outputs[0]) |
| exec1.backward(dy) |
| assert_almost_equal(dy, dx) |
| |
| |
| def test_swapaxes(): |
| data = mx.symbol.Variable('data') |
| shape = (2, 3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[0] = 1 |
| data_tmp[1] = 2 |
| arr_data = mx.nd.array(data_tmp) |
| swap0 = mx.symbol.SwapAxis(data=data, dim1=0, dim2=2) |
| swap = mx.symbol.SwapAxis(data=swap0, dim1=1, dim2=2) |
| exe_c = swap._bind(default_device(), args=[arr_data]) |
| exe_c.forward(is_train=True) |
| out = exe_c.outputs[0] |
| |
| swap0_ = np.swapaxes(data_tmp, 0, 2) |
| swap_ = np.swapaxes(swap0_, 1, 2) |
| |
| assert_almost_equal(out, swap_) |
| |
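    # (shape, axis1, axis2) cases, including negative axes and a no-op swap of an axis with itself.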
| config = [((1, 1, 2), 0, 1), |
| ((1, 1, 2), -1, -2), |
| ((4, 5, 6, 7), 1, 1), |
| ((4, 5, 6, 7), 2, 3), |
| ((4, 5, 6, 7), -2, 2), |
| ((4, 5, 6, 7), -2, -3)] |
| |
| for shape, axis1, axis2 in config: |
| data_np = np.random.uniform(size=shape) |
| data_mx = mx.nd.array(data_np, dtype=data_np.dtype) |
| ret_np = np.swapaxes(data_np, axis1=axis1, axis2=axis2) |
| ret_mx = mx.symbol.SwapAxis(data, dim1=axis1, dim2=axis2) |
| exe_c = ret_mx._bind(default_device(), args=[data_mx]) |
| exe_c.forward(is_train=True) |
| out = exe_c.outputs[0] |
| assert_almost_equal(out, ret_np) |
| |
| |
| @xfail_when_nonstandard_decimal_separator |
| def test_scalarop(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape)*5 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:]=3 |
| |
| test = 2 / (4-((1+data+1)*2/5)-0.8-(data!=0)) |
| |
| npout_1 = (4-((1+data_tmp+1)*2/5)-0.8-(data_tmp!=0)) |
| npout = 2/npout_1 |
| |
| check_symbolic_forward(test, [data_tmp], [npout]) |
| |
    # Chain rule: d/dx (2 / D) = -2 * D' / D**2 with D = npout_1 and D' = -2/5,
    # then scaled by the output gradient of 2 passed to check_symbolic_backward below.
    npout_grad = 2. * 2 / 5
    npout_grad = 2 * npout_grad / (npout_1 * npout_1)
| |
| check_symbolic_backward(test, [data_tmp], [np.ones(shape)*2], [npout_grad]) |
| |
| |
| def test_scalar_pow(): |
| data = mx.symbol.Variable('data') |
| shape = (1, 1) |
| data_tmp = np.ones(shape) |
| test = data ** 2 |
| check_numeric_gradient(test, [data_tmp]) |
| check_symbolic_forward(test, [data_tmp], [data_tmp ** 2]) |
| check_symbolic_backward(test, [data_tmp], [np.ones(shape)], [2 * data_tmp]) |
| |
| |
| def test_symbol_pow(): |
| shape = (1, 1) |
| |
| data = mx.symbol.Variable('data') |
| data_tmp = np.ones(shape)*2 |
| |
| exp = mx.symbol.Variable('exp') |
| exp_tmp = np.ones(shape)*3 |
| |
| test = data**exp |
| |
| check_numeric_gradient(test, [data_tmp, exp_tmp]) |
| check_symbolic_forward(test, [data_tmp, exp_tmp], [data_tmp**exp_tmp]) |
| |
| data_dir = data_tmp**(exp_tmp - 1) * exp_tmp |
| exp_dir = data_tmp**(exp_tmp) * np.log(data_tmp) |
| check_symbolic_backward(test, [data_tmp, exp_tmp], [np.ones(shape)], [data_dir, exp_dir]) |
| |
| |
| def test_fully_connected(): |
| # Create data of given shape as a uniform distribution centered on 0.0 |
| def random_data(shape, dtype=np.float32): |
| return mx.nd.random.uniform(low=-0.5, |
| high=0.5, shape=shape, dtype=dtype) |
| data = mx.sym.var("data") |
| fc_weight = mx.sym.var("weight") |
| fc_bias = mx.sym.var("bias") |
| fc = mx.sym.FullyConnected(data=data, weight=fc_weight, bias=fc_bias, num_hidden=10, no_bias=False, name='fc') |
| |
| data = random_data(shape=(5, 5, 5, 13)) |
| fc_weight = random_data(shape=(10, 325)) |
| fc_bias = random_data(shape=(10)) |
| fc_bias2 = random_data(shape=(10, 1)) |
| |
| data_np = data.asnumpy().reshape(5, 325) |
| fc_weight_np = np.transpose(fc_weight.asnumpy()) |
| fc_bias_np = fc_bias.asnumpy() |
    res = np.dot(data_np, fc_weight_np) + fc_bias_np
| check_symbolic_forward(fc, {'data': data_np, 'weight': fc_weight.asnumpy(), 'bias': fc_bias_np}, {'fc_output': res}) |
| check_numeric_gradient(fc, {'data': data_np, 'weight': fc_weight.asnumpy(), 'bias': fc_bias_np}) |
| # TODO: Fix Bug #15032 when bias has ndim > 1 |
| #check_symbolic_forward(fc, {'data': data_np, 'weight': fc_weight.asnumpy(), 'bias': fc_bias2.asnumpy()}, {'fc_output': res}) |
| |
| |
| def test_pow_fn(): |
| shape = (3, 4) |
| exp = mx.symbol.Variable("exp") |
| x = np.ones(shape)*3 |
| for y in [mx.sym.pow(2, exp), mx.sym.power(2, exp)]: |
| check_numeric_gradient(y, [x], numeric_eps=1E-3) |
| check_symbolic_forward(y, [x], [2**x]) |
| check_symbolic_backward(y, [x], [np.ones(shape)], [np.log(2) * 2**x]) |
| |
| |
| def test_relu(): |
| def frelu(x): |
| return np.maximum(x, 0.0) |
| def frelu_grad(x): |
| return np.float32(1.0) * (x > np.float32(0.0)) |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.relu(x) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype('float32') |
| eps = 1e-4 |
| # Avoid finite difference method inaccuracies due to discontinuous gradient at the origin. |
| # Here we replace small problematic inputs with 1.0. Repro issue with seed 97264195. |
| xa[abs(xa) < eps] = 1.0 |
| ya = frelu(xa) |
| ga = frelu_grad(xa) |
| check_numeric_gradient(y, [xa], numeric_eps=eps) |
| check_symbolic_forward(y, [xa], [ya]) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ga]) |
| |
| |
# NOTE(haojin2): Skip the numeric-gradient checks for float16 due to precision issues;
# the analytical checks are still performed on every data type to verify correctness.
| def test_leaky_relu(): |
| def fleaky_relu(x, act_type, slope=0.25): |
| neg_indices = x < 0 |
| out = x.copy() |
| if act_type == 'elu': |
| out[neg_indices] = slope * np.expm1(out[neg_indices]) |
| elif act_type == 'leaky': |
| out[neg_indices] = slope * out[neg_indices] |
| return out |
| def fleaky_relu_grad(grad, x, y, act_type, slope=0.25): |
| neg_indices = x < 0 |
| out = np.ones(x.shape) |
| if act_type == 'elu': |
| out[neg_indices] = y[neg_indices] + slope |
| elif act_type == 'leaky': |
| out[neg_indices] = slope |
| return out * grad |
| for ndim in range(1, 4): |
| shape = rand_shape_nd(ndim) |
| x = mx.symbol.Variable("x") |
| slp = 0.25 |
| for dtype in [np.float16, np.float32, np.float64]: |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype) |
| eps = 1e-4 |
| rtol = 1e-2 |
| atol = 1e-3 |
| xa[abs(xa) < eps] = 1.0 |
| for act_type in ['elu', 'leaky']: |
| y = mx.symbol.LeakyReLU(data=x, slope=slp, act_type=act_type) |
| ya = fleaky_relu(xa, slope=slp, act_type=act_type) |
| ga = fleaky_relu_grad(np.ones(shape), xa, ya, slope=slp, act_type=act_type) |
| # Skip numeric check for float16 type to get rid of flaky behavior |
| if dtype is not np.float16: |
| check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa], [np.ones(shape, dtype=dtype)], [ga], rtol=rtol, atol=atol, dtype=dtype) |
| |
| |
# NOTE(haojin2): Skip the numeric-gradient checks for float16 due to precision issues;
# the analytical checks are still performed on every data type to verify correctness.
| def test_prelu(): |
| def fprelu(x, gamma): |
| pos_indices = x > 0 |
| out = x.copy() |
| if len(x.shape) == 4: |
| out = out.transpose(2,3,0,1) |
| out = np.multiply(out, gamma) |
| out = out.transpose(2,3,0,1) |
| else: |
| out = np.multiply(out, gamma) |
| out[pos_indices] = x[pos_indices] |
| return out |
| def fprelu_grad(x, y, gamma): |
| pos_indices = x > 0 |
| if len(x.shape) == 4: |
| grad_x = np.multiply(np.ones(x.shape).transpose(2,3,0,1), gamma) |
| grad_x = grad_x.transpose(2,3,0,1) |
| else: |
| grad_x = np.multiply(np.ones(x.shape), gamma) |
| grad_gam = np.zeros(gamma.shape) |
| copy_x = x.copy() |
| copy_x[pos_indices] = 0.0 |
| grad_x[pos_indices] = 1.0 |
| if len(gamma.shape) > 1 and len(x.shape) != 4: |
| grad_gam = copy_x |
| elif len(gamma.shape) > 1 and len(x.shape) == 4: |
| grad_gam = np.sum(copy_x, axis=(2,3)) |
| elif gamma.shape[0] == 1: |
| grad_gam = np.sum(np.sum(copy_x)) |
| elif gamma.shape[0] > 1 and len(x.shape) != 4: |
| grad_gam = np.sum(copy_x, axis=0) |
| elif gamma.shape[0] > 1 and len(x.shape) == 4: |
| grad_gam = np.sum(copy_x, axis=(0,2,3)) |
| return (grad_x, grad_gam) |
| x = mx.symbol.Variable("x") |
| gamma = mx.symbol.Variable("gamma") |
| for shape in [(3,4), (3,4,4,5)]: |
| for dtype in [np.float16, np.float32, np.float64]: |
| for gam in [np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype)]: |
| gam_full = np.array([gam, gam, gam]) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape).astype(dtype) |
| rtol = 1e-2 |
| atol = 1e-3 |
| eps = 1e-4 |
| xa[abs(xa) < eps] = 1.0 |
| y = mx.symbol.LeakyReLU(data=x, gamma=gamma, act_type='prelu') |
| ya = fprelu(xa, gam) |
| ya_full = fprelu(xa, gam_full) |
| g_xa, g_gam = fprelu_grad(xa, ya, gamma=gam) |
| g_xa_full, g_gam_full = fprelu_grad(xa, ya_full, gamma=gam_full) |
| # Skip numeric check for float16 type to get rid of flaky behavior |
| if dtype is not np.float16: |
| check_numeric_gradient(y, [xa, gam], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_numeric_gradient(y, [xa, gam_full], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa, gam], [ya], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa, gam], [np.ones(ya.shape, dtype=dtype)], |
| [g_xa, g_gam], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa, gam_full], [ya_full], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa, gam_full], [np.ones(ya_full.shape, dtype=dtype)], |
| [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype) |
| |
| def test_selu(): |
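    # Fixed SELU constants: alpha and lambda (the output scale).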
| alpha = 1.6732632423543772848170429916717 |
| lamb = 1.0507009873554804934193349852946 |
| def fselu(x): |
| neg_indices = x < 0 |
| out = x.copy() |
| out[neg_indices] = alpha * np.expm1(out[neg_indices]) |
| return out * lamb |
| def fselu_grad(grad, x, y): |
| neg_indices = x < 0 |
| out = np.ones(x.shape).astype(x.dtype) |
| out[neg_indices] = y[neg_indices] + alpha |
| return out * lamb |
| |
| shape = (3, 4) |
| x = mx.sym.Variable("x") |
| y = mx.sym.LeakyReLU(data=x, act_type="selu") |
| for dtype in [np.float16, np.float32, np.float64]: |
| xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype) |
| eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4) |
| if dtype is np.float16: |
| xa /= 10.0 |
| xa[abs(xa) < eps] = 0.01 |
| ya = fselu(xa) |
| ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya) |
| check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa], [np.ones(shape, dtype=dtype)], [ga], rtol=rtol, atol=atol, dtype=dtype) |
| |
| |
| def test_gelu(): |
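    # Reference for the tanh approximation of GELU:
    # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), with ROOT_TWO_OVER_PI = sqrt(2/pi).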
| CUBE_CONSTANT = 0.044715 |
| ROOT_TWO_OVER_PI = 0.7978845608028654 |
| def g(x): |
| return ROOT_TWO_OVER_PI * (x + CUBE_CONSTANT * np.power(x, 3)) |
| def g_grad(x): |
| return ROOT_TWO_OVER_PI * (1.0 + 3.0 * CUBE_CONSTANT * np.power(x, 2)) |
| def f(x): |
| return 1.0 + np.tanh(g(x)) |
| def f_grad(x): |
| return (1.0 - np.tanh(g(x)) * np.tanh(g(x))) * g_grad(x) |
| def fgelu(x): |
| return 0.5 * x * f(x) |
| def fgelu_grad(grad, x, y): |
| return grad * (y / x + y * (1 - np.tanh(g(x))) * g_grad(x)) |
| |
| shape = (3, 4) |
| x = mx.sym.Variable("x") |
| y = mx.sym.LeakyReLU(data=x, act_type="gelu") |
| for dtype in [np.float16, np.float32, np.float64]: |
| xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype) |
| eps, rtol, atol = (7.5e-4, 2e-2, 1e-3) if dtype is np.float16 else (1e-4, 1e-3, 1e-5) |
| if dtype is np.float16: |
| xa /= 10.0 |
| xa[abs(xa) < eps] = 0.01 |
| ya = fgelu(xa) |
| ga = fgelu_grad(np.ones(shape).astype(dtype), xa, ya) |
| check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype) |
| |
| |
| def test_sigmoid(): |
| def fsigmoid(a): |
| return np.divide(1.0, (1.0 + np.exp(-a))) |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.sigmoid(x) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape) |
| ya = fsigmoid(xa) |
| check_numeric_gradient(y, [xa], numeric_eps=1E-3) |
| check_symbolic_forward(y, [xa], [ya]) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)]) |
| |
| def test_log_sigmoid(): |
| def flog_sigmoid(a): |
| return np.log(np.divide(1.0, np.add(1.0, np.exp(-a)))) |
| def flog_sigmoid_grad(a): |
| return np.divide(1.0, np.add(1.0, np.exp(a))) |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.log_sigmoid(x) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape) |
| ya = flog_sigmoid(xa) |
| ya_grad = flog_sigmoid_grad(xa) |
| check_numeric_gradient(y, [xa], numeric_eps=1E-3) |
| check_symbolic_forward(y, [xa], [ya]) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad]) |
| |
| def test_mish(): |
| def fmish(a): |
| return a * np.tanh(np.log1p(np.exp(a))) |
| def fmish_grad(a): |
| softrelu = np.log1p(np.exp(a)) |
| tanh = np.tanh(softrelu) |
| sigmoid = np.divide(1.0, (1.0 + np.exp(-a))) |
| return tanh + a * sigmoid * (1.0 - tanh * tanh) |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.mish(x) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape) |
| ya = fmish(xa) |
| ya_grad = fmish_grad(xa) |
| check_numeric_gradient(y, [xa], numeric_eps=1E-3) |
| check_symbolic_forward(y, [xa], [ya]) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad]) |
| |
| def test_shape_array(): |
| for i in range(1,6): |
| shape = rand_shape_nd(i) |
| x = mx.sym.var('x') |
| y = mx.sym.shape_array(x) |
| xa = mx.nd.array(np.random.ranf(shape)) |
| xg = mx.nd.empty(xa.shape) |
| ya = np.shape(xa) |
| yg = mx.nd.ones(ya) |
| exe = y._bind(ctx=default_device(), args={'x': xa}, |
| args_grad={'x': xg}) |
| exe.forward(is_train=True) |
| exe.backward([yg]) |
| yo = exe.outputs[0].asnumpy() |
| same(yo, ya) |
| assert_almost_equal(xg, np.zeros_like(xg.asnumpy())) |
| |
| def test_size_array(): |
| for i in range(1,6): |
| shape = rand_shape_nd(i) |
| x = mx.sym.var('x') |
| y = mx.sym.size_array(x) |
| xa = mx.nd.array(np.random.ranf(shape)) |
| xg = mx.nd.empty(xa.shape) |
| ya = np.size(xa) |
| yg = mx.nd.ones(ya) |
| exe = y._bind(ctx=default_device(), args={'x': xa}, |
| args_grad={'x': xg}) |
| exe.forward(is_train=True) |
| exe.backward([yg]) |
| yo = exe.outputs[0].asnumpy() |
| same(yo, ya) |
| assert_almost_equal(xg, np.zeros_like(xg.asnumpy())) |
| |
| def test_hard_sigmoid(): |
| def fhardsigmoid(a, alpha=0.2, beta=0.5): |
| return np.maximum(np.zeros(a.shape, dtype=a.dtype), |
| np.minimum(np.ones(a.shape, dtype=a.dtype), alpha*a+beta)) |
| def fhardsigmoid_grad(a, out_grad, alpha=0.2, beta=0.5): |
| orig_out = fhardsigmoid(a, alpha, beta) |
| res = out_grad * alpha |
| res[orig_out <= 0.0] = 0.0 |
| res[orig_out >= 1.0] = 0.0 |
| return res |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.hard_sigmoid(x) |
| for dtype in [np.float16, np.float32, np.float64]: |
| if dtype is np.float16: |
| rtol = 1e-2 |
| else: |
| rtol = 1e-3 |
| atol = 1e-3 |
| eps = 1e-3 |
| xa = np.random.uniform(low=-3.0,high=3.0,size=shape).astype(dtype) |
        # The function is not differentiable at x = 2.5 and x = -2.5, so nudge nearby inputs away from those points.
| xa[abs(xa-2.5) < eps] -= 2 * eps |
| xa[abs(xa+2.5) < eps] += 2 * eps |
| ya = fhardsigmoid(xa) |
| grad_xa = fhardsigmoid_grad(xa, np.ones(shape)) |
| if dtype is not np.float16: |
| check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [grad_xa], rtol=rtol, atol=atol, dtype=dtype) |
| |
| def test_softsign(): |
| def fsoftsign(a): |
| return np.divide(a, (1.0 + np.abs(a))) |
| def fsoftsign_grad(a): |
| return np.divide(1.0, np.square((1.0 + np.abs(a)))) |
| shape = (3, 4) |
| x = mx.symbol.Variable("x") |
| y = mx.sym.softsign(x) |
| xa = np.random.uniform(low=-1.0,high=1.0,size=shape) |
| ya = fsoftsign(xa) |
| ya_grad = fsoftsign_grad(xa) |
| check_numeric_gradient(y, [xa], numeric_eps=1E-3) |
| check_symbolic_forward(y, [xa], [ya]) |
| check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad]) |
| |
| def test_binary_logic(): |
| def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True): |
| x = mx.symbol.Variable("x") |
| y = mx.symbol.Variable("y") |
| z = logic_sym(x, y) |
| x_npy = np.random.randint(0, 4, size=x_shape).astype(np.float32) |
| y_npy = np.random.randint(0, 4, size=y_shape).astype(np.float32) |
| exe = z._simple_bind(ctx=default_device(), x=x_shape, y=y_shape) |
| mx_out = exe.forward(is_train=True, x=x_npy, y=y_npy)[0] |
| assert_almost_equal(mx_out, forward_gt(x_npy, y_npy)) |
| exe.backward() |
| if test_scalar: |
| z_lscalar = logic_sym(1, y) |
| z_rscalar = logic_sym(x, 1) |
| exe_lscalar = z_lscalar._simple_bind(ctx=default_device(), y=y_shape) |
| exe_rscalar = z_rscalar._simple_bind(ctx=default_device(), x=x_shape) |
| mx_lscalar_out = exe_lscalar.forward(is_train=True, y=y_npy)[0] |
| mx_rscalar_out = exe_rscalar.forward(is_train=True, x=x_npy)[0] |
| assert_almost_equal(mx_lscalar_out, forward_gt(1, y_npy)) |
| assert_almost_equal(mx_rscalar_out, forward_gt(x_npy, 1)) |
| exe_lscalar.backward() |
| exe_rscalar.backward() |
| # Test the no-broadcasting binary logic ops + scalar logic ops |
| _inner_test(forward_gt=lambda x, y: x == y, |
| logic_sym=lambda x, y: x == y, x_shape=(10, 10), y_shape=(10, 10)) |
| _inner_test(forward_gt=lambda x, y: x > y, |
| logic_sym=lambda x, y: x > y, x_shape=(10, 10), y_shape=(10, 10)) |
| _inner_test(forward_gt=lambda x, y: x >= y, |
| logic_sym=lambda x, y: x >= y, x_shape=(10, 10), y_shape=(10, 10)) |
| _inner_test(forward_gt=lambda x, y: x < y, |
| logic_sym=lambda x, y: x < y, x_shape=(10, 10), y_shape=(10, 10)) |
| _inner_test(forward_gt=lambda x, y: x <= y, |
| logic_sym=lambda x, y: x <= y, x_shape=(10, 10), y_shape=(10, 10)) |
| _inner_test(forward_gt=lambda x, y: x != y, |
| logic_sym=lambda x, y: x != y, x_shape=(10, 10), y_shape=(10, 10)) |
| # Test the broadcasting binary logic ops |
| _inner_test(forward_gt=lambda x, y: x == y, |
| logic_sym=lambda x, y: mx.sym.broadcast_equal(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| _inner_test(forward_gt=lambda x, y: x > y, |
| logic_sym=lambda x, y: mx.sym.broadcast_greater(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| _inner_test(forward_gt=lambda x, y: x >= y, |
| logic_sym=lambda x, y: mx.sym.broadcast_greater_equal(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| _inner_test(forward_gt=lambda x, y: x < y, |
| logic_sym=lambda x, y: mx.sym.broadcast_lesser(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| _inner_test(forward_gt=lambda x, y: x <= y, |
| logic_sym=lambda x, y: mx.sym.broadcast_lesser_equal(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| _inner_test(forward_gt=lambda x, y: x != y, |
| logic_sym=lambda x, y: mx.sym.broadcast_not_equal(x, y), |
| x_shape=(1, 10), y_shape=(10, 1), test_scalar=False) |
| |
| |
| def test_unary_logic(): |
| def reference(a, dtype): |
| return np.logical_not(a).astype(dtype) |
| shape = (3, 4) |
| xa = np.random.randint(-2, 2, size=shape).astype(np.float32) |
| mx_xa = mx.nd.array(xa) |
| mx_out = mx.nd.logical_not(mx_xa) |
| assert_almost_equal(mx_out, reference(xa, dtype=xa.dtype)) |
| x = mx.sym.Variable('x') |
| y = mx.sym.logical_not(data=x) |
| exe = y._simple_bind(ctx=default_device(), x=shape) |
| sym_out = exe.forward(is_train=True, x=mx_xa)[0] |
| assert_almost_equal(sym_out, reference(xa, dtype=xa.dtype)) |
| |
| |
| def test_embedding(): |
| in_dim = 10 |
| out_dim = 4 |
| batch = 24 |
| |
| data = mx.sym.Variable("data") |
| embed = mx.sym.Embedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") |
| exe_test = embed._simple_bind(default_device(), grad_req={'data': 'null', 'embed_weight': 'write'}, data=(batch,)) |
| arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) |
| grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) |
| np_data = np.random.randint(low=0, high=in_dim, size=batch) |
| np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) |
| np_onehot = np.zeros((batch, in_dim)) |
| np_onehot[np.arange(batch), np_data] = 1.0 |
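    # An Embedding lookup is equivalent to multiplying a one-hot indicator matrix by the weight matrix.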
| # forward |
| arg_map["data"][:] = np_data |
| arg_map["embed_weight"][:] = np_weight |
| exe_test.forward(is_train=True) |
| # Non-zero atol required, as exposed by seed 781663739 |
| rtol = 1e-5 |
| atol = 1e-5 |
| assert_almost_equal(exe_test.outputs[0], np.dot(np_onehot, np_weight), rtol=rtol, atol=atol) |
| # backward |
| np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) |
| grad = mx.nd.zeros(np_grad.shape) |
| grad[:] = np_grad |
| exe_test.backward([grad]) |
| assert_almost_equal(grad_map["embed_weight"], np.dot(np_onehot.T, np_grad), rtol=rtol, atol=atol) |
| |
| |
# Check that ops handle duplicate inputs correctly.
| def test_binary_op_duplicate_input(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[:] = 5 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:] = 3 |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 1 |
| square = data * data |
| exe_square = square._bind(default_device(), args=[arr_data], args_grad=[arr_grad]) |
| exe_square.forward(is_train=True) |
| assert_almost_equal(exe_square.outputs[0], data_tmp * data_tmp) |
| exe_square.backward(out_grad) |
| assert_almost_equal(arr_grad, 2.0 * data_tmp) |
| |
| |
| def test_sign(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[:]=5 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:]=3 |
| |
| test = mx.sym.sign(data) |
| exe_test = test._bind(default_device(), args=[arr_data], args_grad=[arr_grad]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = np.sign(data_tmp) |
| assert_almost_equal(out, npout) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 2 |
    npout_grad = np.zeros(shape)  # the gradient of sign() is zero everywhere
| exe_test.backward(out_grad) |
| assert_almost_equal(arr_grad, npout_grad) |
| |
| |
| def test_round_ceil_floor(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[:]=5.543 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:]= 2 |
| |
| test = mx.sym.round(data) + mx.sym.ceil(data) + mx.sym.floor(data) |
| exe_test = test._bind(default_device(), args=[arr_data]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = np.round(data_tmp) + np.ceil(data_tmp) + np.floor(data_tmp) |
| assert_almost_equal(out, npout) |
| |
| |
| def test_trunc(): |
| data_tmp = np.random.rand(3, 4) * 10 - 5 |
| arr_data = mx.nd.array(data_tmp) |
| data = mx.symbol.Variable('data') |
| test = mx.sym.trunc(data) |
| |
| exe_test = test._bind(default_device(), args=[arr_data]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| # 'trunc' is sensitive to the precision of the calculation. Force numpy to match mxnet's float32. |
| # Repro issue with seed 1660190454 |
| npout = np.trunc(np.float32(data_tmp)) |
| |
| assert_almost_equal(out, npout) |
| |
| |
| def test_rsqrt_cos_sin(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[:]=5 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:]=3 |
| |
| test = mx.sym.rsqrt(data) + mx.sym.cos(data) + mx.sym.sin(data) |
| exe_test = test._bind(default_device(), args=[arr_data], args_grad=[arr_grad]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = 1/ np.sqrt(data_tmp) + np.cos(data_tmp) + np.sin(data_tmp) |
| assert_almost_equal(out, npout) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 2 |
| npout_grad = out_grad.asnumpy() |
| npout_grad = npout_grad * -(1.0 / (2.0 * data_tmp * np.sqrt(data_tmp))) + npout_grad * -1 * np.sin(data_tmp) + npout_grad * np.cos(data_tmp) |
| exe_test.backward(out_grad) |
| assert_almost_equal(arr_grad, npout_grad) |
| |
| |
| def test_maximum_minimum(): |
| data1 = mx.symbol.Variable('data1') |
| data2 = mx.symbol.Variable('data2') |
| shape = (3, 4) |
| data_tmp1 = np.random.rand(3,4) |
| data_tmp2 = np.random.rand(3,4) |
| data_tmp1[:] = 2 |
| data_tmp2[:] = 3 |
| |
| arr_data1 = mx.nd.array(data_tmp1) |
| arr_data2 = mx.nd.array(data_tmp2) |
| |
| arr_grad1 = mx.nd.empty(shape) |
| arr_grad2 = mx.nd.empty(shape) |
| |
| test = mx.sym.maximum(data1,data2) + mx.sym.minimum(data1,data2) |
| exe_test = test._bind(default_device(), args=[arr_data1,arr_data2], args_grad=[arr_grad1,arr_grad2]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = np.maximum(data_tmp1,data_tmp2) + np.minimum(data_tmp1,data_tmp2) |
| assert_almost_equal(out, npout) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 2 |
| exe_test.backward(out_grad) |
| |
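    # maximum() routes the gradient to the larger input and minimum() to the smaller one,
    # so each input receives the output gradient from exactly one of the two terms.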
| npout_grad = np.ones(shape) |
| npout_grad[:] = 2 |
| mask1 = (data_tmp1 > data_tmp2).astype('float') |
| mask2 = (data_tmp1 < data_tmp2).astype('float') |
| npout_grad1 = npout_grad * mask1 + npout_grad * mask2 |
| npout_grad2 = (npout_grad - npout_grad * mask1) + (npout_grad - npout_grad * mask2) |
| |
| assert_almost_equal(arr_grad1, npout_grad1) |
| assert_almost_equal(arr_grad2, npout_grad2) |
| |
| |
| def test_maximum_minimum_scalar(): |
| data1 = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp1 = np.random.rand(3,4) |
| data_tmp1[:] = 2 |
| |
| arr_data1 = mx.nd.array(data_tmp1) |
| arr_grad1 = mx.nd.empty(shape) |
| |
| test = mx.sym.maximum(data1,3) + mx.sym.maximum(9,data1) + mx.sym.minimum(5,data1) + mx.sym.minimum(data1,4) |
| exe_test = test._bind(default_device(), args=[arr_data1], args_grad=[arr_grad1]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = np.maximum(data_tmp1,3) + np.maximum(9,data_tmp1) + np.minimum(5,data_tmp1) + np.minimum(data_tmp1,4) |
| assert_almost_equal(out, npout) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 2 |
| exe_test.backward(out_grad) |
| |
| npout_grad = np.ones(shape) |
| npout_grad[:] = 2 |
| mask1 = (data_tmp1 > 3).astype('float') |
| mask2 = (9 > data_tmp1).astype('float') |
| mask3 = (5 < data_tmp1).astype('float') |
| mask4 = (data_tmp1 < 4).astype('float') |
| npout_grad1 = npout_grad * mask1 + (npout_grad - npout_grad * mask2) + (npout_grad - npout_grad * mask3) + npout_grad * mask4 |
| |
| assert_almost_equal(arr_grad1, npout_grad1) |
| |
| |
| def test_abs(): |
| data = mx.symbol.Variable('data') |
| shape = (3, 4) |
| data_tmp = np.ones(shape) |
| data_tmp[:]=5 |
| arr_data = mx.nd.array(data_tmp) |
| arr_grad = mx.nd.empty(shape) |
| arr_grad[:]=3 |
| |
| test = mx.sym.abs(data) |
| exe_test = test._bind(default_device(), args=[arr_data], args_grad=[arr_grad]) |
| exe_test.forward(is_train=True) |
| out = exe_test.outputs[0] |
| npout = abs(data_tmp) |
| assert_almost_equal(out, npout) |
| |
| out_grad = mx.nd.empty(shape) |
| out_grad[:] = 2 |
| npout_grad = out_grad.asnumpy() |
| npout_grad = npout_grad * np.sign(data_tmp) |
| exe_test.backward(out_grad) |
| assert_almost_equal(arr_grad, npout_grad) |
| |
| |
| def check_deconvolution_forward_backward(input_shape, num_filter, kernel, stride, pad): |
| """configure A: input --> conv --> deconv --> output. |
| the convolution and deconvoluiton has similar parameter which ensure |
| the input shape is the same as output, and the same weights between conv |
| and deconv; |
| If the input value of forward() and backwrad() is the same, then |
| the output value of them should also the same; |
| """ |
| assert input_shape[1] == num_filter |
| data = mx.sym.Variable(name="data") |
| conv = mx.sym.Convolution( |
| data=data, kernel=kernel, stride=stride, pad=pad, |
| num_filter=num_filter, no_bias = "true", name = "conv") |
| deconv = mx.sym.Deconvolution( |
| data=conv, kernel=kernel, stride=stride, pad=pad, |
| num_filter=num_filter, no_bias = "true", name = "deconv") |
| |
| arg_names = deconv.list_arguments() |
| arg_shapes, out_shapes, _ = deconv.infer_shape(data=input_shape) |
| input_data = mx.random.uniform(-5, 5, input_shape, ctx=mx.cpu()).copyto(default_device()) |
| out_grad = input_data |
| args = {} |
| args["data"] = input_data |
| args['conv_weight'] = args['deconv_weight'] = mx.random.normal(0, 1, |
| (num_filter, input_shape[1]) + kernel, ctx=mx.cpu()).copyto(default_device()) |
| args_grad = [mx.nd.empty(s) for s in arg_shapes] |
| |
| exe = deconv._bind(default_device(), args=args, args_grad=args_grad) |
| exe.forward(is_train=True) |
| out = exe.outputs[0] |
| exe.backward(out_grad) |
| assert_almost_equal(out, args_grad[0], rtol=1E-3, atol=1e-3) |
| |
| args_grad_addto_npy = [np.random.normal(size=s) for s in arg_shapes] |
| args_grad_addto = [mx.nd.array(ele) for ele in args_grad_addto_npy] |
| exe = deconv._bind(default_device(), args=args, args_grad=args_grad_addto, grad_req="add") |
| exe.forward(is_train=True) |
| out = exe.outputs[0].asnumpy() |
| exe.backward(out_grad) |
| assert_almost_equal(out + args_grad_addto_npy[0], args_grad_addto[0].asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| |
| def check_deconvolution_gradient(input_shape, num_filter, pad): |
| """configure A: input --> conv --> output. |
| configure B: input --> deconv --> output |
| the convolution and deconvoluiton has similar parameter which ensure |
| the input shape is the same as output; |
| During backward(), if the input of A equals output of B, and the output |
| of A equals input of B, then the grad of weight should be the same; |
| """ |
| ndim = len(pad) |
| stride = (1,) * ndim |
| kernel = tuple(2 * np.array(pad) + 1) |
| data_conv = mx.sym.Variable(name="data_conv") |
| conv = mx.sym.Convolution( |
| data=data_conv, kernel=kernel, stride=stride, pad=pad, |
| num_filter=num_filter, no_bias = "true", name = "conv") |
| data_deconv = mx.sym.Variable(name="data_deconv") |
| deconv = mx.sym.Deconvolution( |
| data=data_deconv, kernel=kernel, stride=stride, pad=pad, |
| num_filter=num_filter, no_bias = "true", name = "deconv") |
| |
| conv_data = mx.random.uniform(-5, 5, input_shape, ctx=mx.cpu()).copyto(default_device()) |
| conv_args = {} |
| conv_args["data_conv"] = conv_data |
| conv_args['conv_weight'] = \ |
| mx.random.normal(0, 1,(num_filter, input_shape[1]) + kernel, ctx=mx.cpu()).copyto(default_device()) |
| conv_args_grad = [mx.nd.zeros(conv_data.shape), |
| mx.nd.zeros((num_filter, input_shape[1]) + kernel)] |
| exe_conv = conv._bind(default_device(), args=conv_args, args_grad=conv_args_grad) |
| exe_conv.forward(is_train=True) |
| conv_out_grad = mx.random.normal(0, 2, exe_conv.outputs[0].shape, ctx=mx.cpu()).copyto(default_device()) |
| exe_conv.backward(conv_out_grad) |
| |
| deconv_data = conv_out_grad |
| deconv_args = {} |
| deconv_args['data_deconv'] = deconv_data |
| deconv_args['deconv_weight'] = conv_args['conv_weight'] |
| deconv_args_grad = [mx.nd.zeros(deconv_data.shape), |
| mx.nd.zeros((num_filter, input_shape[1]) + kernel)] |
| deconv_addto_args_grad_npy = [np.random.normal(size=deconv_data.shape), |
| np.random.normal(size=(num_filter, input_shape[1]) + kernel)] |
| deconv_addto_args_grad = [mx.nd.array(deconv_addto_args_grad_npy[0]), |
| mx.nd.array(deconv_addto_args_grad_npy[1])] |
| exe_deconv = deconv._bind(default_device(), args=deconv_args, args_grad=deconv_args_grad) |
| exe_deconv.forward(is_train=True) |
| deconv_out_grad = conv_data[:] |
| exe_deconv.backward(deconv_out_grad) |
| assert_almost_equal(conv_args_grad[1], deconv_args_grad[1], rtol=1e-3, atol=1e-2) |
| # Test AddTo |
| exe_deconv_addto = deconv._bind(default_device(), args=deconv_args, |
| args_grad=deconv_addto_args_grad, |
| grad_req="add") |
| exe_deconv_addto.forward(is_train=True) |
| deconv_out_grad = conv_data[:] |
| exe_deconv_addto.backward(deconv_out_grad) |
| assert_almost_equal(conv_args_grad[1].asnumpy() + deconv_addto_args_grad_npy[1], |
| deconv_addto_args_grad[1].asnumpy(), rtol=1e-3, atol=1e-2) |
| |
| |
| def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, target_shape=None): |
| data = mx.sym.Variable(name="data") |
| if target_shape: |
| deconv = mx.sym.Deconvolution( |
| data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5, |
| target_shape = target_shape) |
| else: |
| deconv = mx.sym.Deconvolution( |
| data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5) |
| arg_names = deconv.list_arguments() |
| arg_shapes, out_shapes, _ = deconv.infer_shape(data=input_shape) |
| default_target_size = 8 |
| if target_shape is None: |
| target_shape = (default_target_size,) * len(kernel) |
| assert out_shapes[0] == (input_shape[0], 5) + target_shape |
| |
| |
| @pytest.mark.serial |
| def test_deconvolution(): |
| # 2D |
| check_deconvolution_target_shape( |
| input_shape = (2,3,4,4), |
| kernel = (3,3), |
| stride = (2,2), |
| target_shape = (8,8), |
| pad = (99,99), # will be ignored |
| adj = (101,101), # will be ignored |
| ) |
| check_deconvolution_target_shape( |
| input_shape = (2,3,4,4), |
| kernel = (3,3), |
| stride = (2,2), |
| pad = (1,1), |
| adj = (1,1), |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (1,1,5,5), |
| num_filter = 1, |
| kernel = (3,3), |
| stride = (1,1), |
| pad = (1,1) |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (32,3,28,28), |
| num_filter = 3, |
| kernel = (3,3), |
| stride = (1,1), |
| pad = (1,1) |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (10, 3, 403, 403), |
| num_filter = 3, |
| kernel = (7,7), |
| stride = (5,5), |
| pad = (2,2) |
| ) |
| check_deconvolution_gradient( |
| input_shape = (1,3,5,5), |
| num_filter = 3, |
| pad = (1,1) |
| ) |
| check_deconvolution_gradient( |
| input_shape = (5,3,100,100), |
| num_filter = 3, |
| pad = (3,3) |
| ) |
| # 1D |
| check_deconvolution_target_shape( |
| input_shape = (2,3,4), |
| kernel = (3,), |
| stride = (2,), |
| target_shape = (8,), |
| pad = (99,), # will be ignored |
| adj = (101,), # will be ignored |
| ) |
| check_deconvolution_target_shape( |
| input_shape = (2,3,4), |
| kernel = (3,), |
| stride = (2,), |
| pad = (1,), |
| adj = (1,), |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (1,1,5), |
| num_filter = 1, |
| kernel = (3,), |
| stride = (1,), |
| pad = (1,) |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (32,3,28), |
| num_filter = 3, |
| kernel = (3,), |
| stride = (1,), |
| pad = (1,) |
| ) |
| check_deconvolution_forward_backward( |
| input_shape = (10, 3, 403), |
| num_filter = 3, |
| kernel = (7,), |
| stride = (5,), |
| pad = (2,) |
| ) |
| check_deconvolution_gradient( |
| input_shape = (1,3,5), |
| num_filter = 3, |
| pad = (1,) |
| ) |
| check_deconvolution_gradient( |
| input_shape = (5,3,100), |
| num_filter = 3, |
| pad = (3,) |
| ) |
| |
| @pytest.mark.parametrize('shape,num_filter,num_group,kernel,pad', [ |
| ((1, 4, 15), 16, 2, (2,), (0,)), |
| ((8, 4, 16), 16, 1, (3,), (1,)), |
| |
| ((1, 4, 15, 16), 16, 2, (2, 2), (0, 0)), |
| ((8, 4, 16, 16), 16, 1, (3, 3), (1, 1)), |
| |
| ((1, 4, 3, 15, 16), 16, 2, (2, 2, 2), (0, 0, 0)), |
| ((8, 4, 3, 16, 16), 16, 1, (3, 3, 3), (1, 1, 1))]) |
| def test_deconvolution_forward_with_bias(shape, num_filter, num_group, kernel, pad): |
| """Check if deconvolution forward can work well with bias=True |
| """ |
| if len(kernel) == 3 and mx.current_context().device_type == 'gpu': |
| pytest.skip('Skipping Conv3DTranspose tests for GPU') |
| |
| x = mx.sym.Variable('x') |
| w = mx.sym.Variable('w') |
| b = mx.sym.Variable('b') |
| y_nb = mx.sym.Deconvolution(data=x, weight=w, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=True, pad=pad) |
| y_b = mx.sym.Deconvolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=False, pad=pad) |
| |
| exe_nb = y_nb._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null') |
| exe_b = y_b._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null') |
| |
| data = np.random.uniform(-5, 5, size=exe_b.arg_arrays[0].shape) |
| weights = np.random.normal(size=exe_b.arg_arrays[1].shape) |
| bias = np.random.normal(size=exe_b.arg_arrays[2].shape) |
| |
| def exe_forward(exe): |
| exe.arg_arrays[0][:] = data |
| exe.arg_arrays[1][:] = weights |
| if len(exe.arg_arrays) == 3: |
| exe.arg_arrays[2][:] = bias |
| return exe.forward(is_train=False)[0].asnumpy() |
| |
| out_nb = exe_forward(exe_nb) |
| out_b = exe_forward(exe_b) |
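    # Broadcast the per-channel bias across the spatial dimensions of the no-bias output.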
| bias = np.broadcast_to(bias, [np.prod(out_nb.shape[2:])] + [num_filter]).T |
| bias = np.broadcast_to(bias.reshape((num_filter, *out_nb.shape[2:])), out_b.shape) |
| assert_almost_equal(out_nb + bias, out_b) |
| |
| |
| def check_nearest_upsampling_with_shape(shapes, scale, root_scale): |
| arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_device()) for i, shape in zip(range(len(shapes)), shapes)} |
| arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)} |
| |
| up = mx.sym.UpSampling(*[mx.sym.Variable('arg_%d'%i) for i in range(len(shapes))], sample_type='nearest', scale=root_scale) |
| exe = up._bind(default_device(), args=arr, args_grad=arr_grad) |
| exe.forward(is_train=True) |
| exe.backward(exe.outputs) |
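    # Nearest-neighbour upsampling's backward pass sums the output gradient over each upsampled block,
    # so feeding the forward output back as the gradient scales each input by the square of its
    # effective upsampling factor.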
| for k in range(len(shapes)): |
| name = 'arg_%d'%k |
| assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) |
| |
| |
| def check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter): |
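    # Fill `arr` with a standard bilinear interpolation kernel for upsampling factor f.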
| def _init_bilinear(arr, f): |
| weight = np.zeros(np.prod(arr.shape), dtype='float32') |
| shape = arr.shape |
| c = (2 * f - 1 - f % 2) / (2. * f) |
| for i in range(np.prod(shape)): |
| x = i % shape[3] |
| y = (i // shape[3]) % shape[2] |
| weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) |
| arr[:] = weight.reshape(shape) |
| return arr |
| |
| up = mx.sym.UpSampling(mx.sym.Variable("data"), |
| mx.sym.Variable('weight'), sample_type='bilinear', scale=root_scale, |
| num_filter=num_filter, num_args=2) |
| arg_shapes, out_shapes, _ = up.infer_shape(data=data_shape) |
| arr = {'data': mx.random.uniform(-5, 5, data_shape, ctx=mx.cpu()).copyto(default_device()), |
| 'weight': mx.nd.array(_init_bilinear(mx.ndarray.empty(arg_shapes[1]).asnumpy(), root_scale))} |
| |
| arr_grad = [mx.nd.empty(s) for s in arg_shapes] |
| exe = up._bind(default_device(), args=arr, args_grad=arr_grad) |
| exe.forward(is_train=True) |
| out = exe.outputs[0].asnumpy() |
| exe.backward(exe.outputs) |
| target_shape = (data_shape[2] * root_scale, data_shape[3] * root_scale) |
| assert out.shape == data_shape[:2] + target_shape |
| |
| |
| def test_nearest_upsampling(): |
| for root_scale in [1,2,3]: |
| for scale in [1,2,3]: |
| for num_shape in [1,2,3]: |
| for base in [1,2,3]: |
| shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] |
| check_nearest_upsampling_with_shape(shapes, scale, root_scale) |
| |
| |
| def test_bilinear_upsampling(): |
| rootscale = [2,3] |
| scales = [1,2,3] |
| filters = [1,2,3] |
| bases = [1,2,3] |
| for params in itertools.product(rootscale, scales, filters, bases): |
| root_scale, scale, num_filter, base = params |
        # Bilinear upsampling takes exactly one data input and one weight input;
        # the multi-input mode is not applicable here.
| dimension = base*root_scale*scale |
| kernel = 2 * root_scale - root_scale % 2 |
| data_shape = (1, num_filter, dimension, dimension) |
| weight_shape = (1, num_filter, kernel, kernel) |
| check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter) |
| |
| def test_batchnorm_training(): |
| def check_batchnorm_training(stype): |
| for shape in [(2, 3), (2, 3, 2, 2), (2, 8, 2, 2)]: |
| data_tmp = np.random.normal(-0.1, 0.1, size=shape) |
            s = (shape[1],)
| gamma = np.ones(s) |
| beta = np.ones(s) |
| gamma[1] = 3 |
| beta[0] = 3 |
| |
| rolling_mean = np.random.uniform(size=s) |
| rolling_std = np.random.uniform(size=s) |
| |
| data = mx.symbol.Variable('data', stype=stype) |
| in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype), |
| mx.nd.array(beta).tostype(stype)] |
| mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)] |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=True) |
| check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True) |
| check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=False) |
| check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True) |
| check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2) |
| |
| # Test varying channel axis |
| dim = len(shape) |
| for chaxis in range(-dim, dim): |
| chaxis_true = chaxis |
| if chaxis < 0: |
| chaxis_true = dim + chaxis |
| |
| shapex = shape |
| |
| channel_count = shapex[chaxis_true] |
| data_tmp = np.random.normal(-0.1, 0.1, size=shapex) |
| |
| gamma = np.ones(channel_count) |
| beta = np.ones(channel_count) |
| if channel_count > 1: |
| gamma[1] = 3 |
| beta[0] = 3 |
| |
| in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype), |
| mx.nd.array(beta).tostype(stype)] |
| |
| xrolling_mean = np.random.uniform(size=channel_count) |
| xrolling_std = np.random.uniform(size=channel_count) |
| xmean_std = [mx.nd.array(xrolling_mean).tostype(stype), |
| mx.nd.array(xrolling_std).tostype(stype)] |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis) |
| check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis) |
| check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis) |
| check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01) |
| |
| test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis) |
| check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01) |
| |
| check_batchnorm_training('default') |
| |
| |
| @xfail_when_nonstandard_decimal_separator |
| @pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm']) |
| @pytest.mark.parametrize('shape', [(4, 2), (4, 3, 4), |
| (4, 6, 4, 5), (4, 5, 6, 4, 5)]) |
| @pytest.mark.parametrize('fix_gamma', [False, True]) |
| @pytest.mark.parametrize('cudnn_off', [False, True]) |
| @pytest.mark.parametrize('output_mean_var', [False, True]) |
| def test_batchnorm(op_name, shape, fix_gamma, cudnn_off, output_mean_var): |
| if op_name == 'BatchNorm': |
| op = mx.nd.BatchNorm |
| elif op_name == 'SyncBatchNorm': |
| op = mx.nd.contrib.SyncBatchNorm |
| else: |
| raise ValueError(f'Not supported {op_name}') |
| momentum = 0.9 |
| epsilon = 1e-5 |
| |
| def _test_batchnorm_impl(axis, |
| data_grad_req, gamma_grad_req, beta_grad_req): |
| kwargs = dict(output_mean_var=output_mean_var) |
| if op_name == 'SyncBatchNorm': |
| if axis != 1: |
| return |
| key = str(op) + str(shape) + str(axis) |
| kwargs.update(dict(key=key)) |
| if cudnn_off: |
| return |
| else: |
| kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) |
| nch = shape[axis] |
| |
| if not fix_gamma: |
| bn_gamma = mx.nd.random.uniform(shape=(nch,)) |
| bn_gamma.attach_grad(grad_req=gamma_grad_req) |
| else: |
| bn_gamma = mx.nd.ones(shape=(nch,)) |
| |
| bn_beta = mx.nd.random.uniform(shape=(nch,)) |
| bn_beta.attach_grad(grad_req=beta_grad_req) |
| |
| bn_running_mean = mx.nd.zeros(nch) |
| bn_running_var = mx.nd.ones(nch) |
| |
| running_mean = mx.nd.zeros(nch) |
| running_var = mx.nd.ones(nch) |
| num_iters = 10 |
| expand_shape = [1] * len(shape) |
| expand_shape[axis] = shape[axis] |
| data = mx.nd.random.uniform(shape=shape) |
| data.attach_grad(grad_req=data_grad_req) |
| adX, adW, adb = 0, 0, 0 |
| is_train = data_grad_req != 'null' or \ |
| (not fix_gamma and gamma_grad_req != 'null') or \ |
| beta_grad_req != 'null' |
| for _ in range(num_iters): |
| if data_grad_req != 'add': |
| data = mx.nd.random.uniform(shape=shape) |
| data.attach_grad(grad_req=data_grad_req) |
| ograd = mx.nd.random.uniform(shape=shape) |
| with mx.autograd.record(): |
| output = op(data, bn_gamma, bn_beta, |
| bn_running_mean, bn_running_var, |
| momentum=momentum, eps=epsilon, |
| fix_gamma=fix_gamma, **kwargs) |
| if output_mean_var: |
| output, output_mean, output_std = output |
| if is_train: |
| output.backward(ograd) |
| mx.nd.waitall() |
| |
| data_mean = data.mean( |
| axis=axis, exclude=True, keepdims=True) |
| data_var = (data - data_mean).square().mean(axis=axis, |
| exclude=True, |
| keepdims=True) |
| |
| target_output = (data - data_mean) / \ |
| (data_var + epsilon).sqrt() * \ |
| bn_gamma.reshape(expand_shape) + \ |
| bn_beta.reshape(expand_shape) |
| |
| # squeeze data_mean and data_var |
| data_mean_flat = data_mean.squeeze() |
| data_var_flat = data_var.squeeze() |
| |
| running_mean = running_mean * momentum + \ |
| data_mean_flat * (1 - momentum) |
| |
| m = np.prod(shape) / shape[axis] |
| # cudnn uses m-1 in the denominator of its sample variance calculation, not m |
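| # i.e. cudnn tracks the unbiased estimate (np.var(x, ddof=1) == np.var(x, ddof=0) * m / (m - 1)),
| # so the reference running variance below is scaled by m / (m - 1) to match.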
| sample_var_adjust = 1.0 if cudnn_off or fix_gamma else m / (m-1) |
| running_var = running_var * momentum + \ |
| data_var_flat * sample_var_adjust * (1 - momentum) |
| |
| W = bn_gamma.reshape(expand_shape) |
| dnx = ograd * W |
| xsm = data - data_mean |
| nd = 1.0 / mx.nd.sqrt(data_var + epsilon) |
| nx = xsm * nd |
| dvar = (dnx * xsm).sum(axis=axis, keepdims=True, |
| exclude=True) * (-0.5) * mx.nd.power(nd, 3) |
| dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \ |
| dvar * xsm.mean(axis=axis, keepdims=True, |
| exclude=True) * 2.0 |
| dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) |
| dW = (ograd * nx).sum(axis=axis, exclude=True) |
| db = ograd.sum(axis=axis, exclude=True) |
| adX = dX if data_grad_req != 'add' else adX + dX |
| adW = dW if gamma_grad_req != 'add' else adW + dW |
| adb = db if beta_grad_req != 'add' else adb + db |
| |
| atol, rtol = 5e-2, 5e-2 |
| |
| if output_mean_var: |
| assert_almost_equal(output_mean.asnumpy(), |
| data_mean_flat.asnumpy(), |
| atol=atol, rtol=rtol) |
| if op != mx.nd.contrib.SyncBatchNorm: |
| assert_almost_equal(output_std.asnumpy(), |
| (1.0 / (data_var_flat + |
| epsilon).sqrt()).asnumpy(), |
| atol=atol, rtol=rtol) |
| else: |
| assert_almost_equal(output_std.asnumpy(), |
| data_var_flat.asnumpy(), |
| atol=atol, rtol=rtol) |
| assert_almost_equal(output.asnumpy(), target_output.asnumpy(), |
| atol=atol, rtol=rtol) |
| if is_train: |
| assert_almost_equal(bn_running_mean.asnumpy( |
| ), running_mean.asnumpy(), atol=atol, rtol=rtol) |
| assert_almost_equal(bn_running_var.asnumpy( |
| ), running_var.asnumpy(), atol=atol, rtol=rtol) |
| |
| if data_grad_req != 'null': |
| assert_almost_equal(data.grad.asnumpy(), |
| adX.asnumpy(), atol=atol, rtol=rtol) |
| if not fix_gamma: |
| if gamma_grad_req != 'null': |
| assert_almost_equal( |
| bn_gamma.grad.asnumpy(), adW.asnumpy(), |
| atol=atol, rtol=rtol) |
| else: |
| assert((bn_gamma.asnumpy() == 1).all()) |
| if beta_grad_req != 'null': |
| assert_almost_equal( |
| bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) |
| |
| grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] |
| for data_grad_req in grad_reqs: |
| for gamma_grad_req in grad_reqs: |
| if fix_gamma and gamma_grad_req != 'null': |
| continue |
| for beta_grad_req in grad_reqs: |
| for axis in range(len(shape)): |
| _test_batchnorm_impl(axis, |
| data_grad_req, gamma_grad_req, beta_grad_req) |
| |
| |
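| # The hand-rolled backward pass above follows the textbook batch-norm gradients.
| # As a hedged, standalone sketch (not used by any test; `_np_batchnorm_backward_2d`
| # is a name introduced here purely for illustration), the same math in plain NumPy
| # for a 2-D input of shape (m, C) with channels on axis 1 (statistics reduced over
| # axis 0) would look like this:
| def _np_batchnorm_backward_2d(x, gamma, ograd, eps=1e-5):
|     mean = x.mean(axis=0, keepdims=True)
|     var = ((x - mean) ** 2).mean(axis=0, keepdims=True)
|     nd = 1.0 / np.sqrt(var + eps)   # inverse standard deviation, shape (1, C)
|     xsm = x - mean                  # centered input
|     nx = xsm * nd                   # normalized input
|     m = x.shape[0]                  # elements reduced per channel
|     dnx = ograd * gamma
|     dvar = (dnx * xsm).sum(axis=0, keepdims=True) * (-0.5) * nd ** 3
|     dmean = -nd * dnx.sum(axis=0, keepdims=True) - dvar * xsm.mean(axis=0, keepdims=True) * 2.0
|     dx = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
|     dgamma = (ograd * nx).sum(axis=0)
|     dbeta = ograd.sum(axis=0)
|     return dx, dgamma, dbeta
|
|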
| def test_groupnorm(): |
| acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64'} |
| def x_hat_helper(x, num_groups, eps): |
| dtype = x.dtype |
| dshape = x.shape |
| assert len(dshape) == 4 |
| acc_type = acc_types[str(dtype)] |
| new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), dshape[2], dshape[3]) |
| new_moments_shape = (dshape[0], num_groups, 1, 1, 1) |
| data = x.reshape(new_shape) |
| mean = np.mean(data, axis=(2, 3, 4), keepdims=False, dtype=acc_type).astype(dtype) |
| std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + eps) |
| x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape) |
| return x_hat, mean, std |
| |
| def np_groupnorm(data, gamma, beta, num_groups, eps): |
| new_param_shape = (1, data.shape[1], 1, 1)
| x_hat, mean, std = x_hat_helper(data, num_groups, eps)
| out = x_hat.reshape(data.shape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
| return out, mean, std |
| |
| def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): |
| x_hat, mean, std = x_hat_helper(data, num_groups, eps) |
| new_shape = x_hat.shape |
| dshape = data.shape |
| dtype = data.dtype |
| new_moments_shape = (new_shape[0], num_groups, 1, 1, 1) |
| new_param_shape = (1, dshape[1], 1, 1) |
| acc_type = acc_types[str(dtype)] |
| ograd = ograd.reshape(new_shape) |
| data = data.reshape(new_shape) |
| gamma = gamma.reshape(new_param_shape) |
| beta = beta.reshape(new_param_shape) |
| mean = mean.reshape(new_moments_shape) |
| std = std.reshape(new_moments_shape) |
| beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten() |
| gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten() |
| x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1) |
| ograd_mult = x_hat_grad / std |
| red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype) |
| data_grad = ograd_mult - red_out |
| red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype) |
| data_grad = data_grad - x_hat * red_out |
| return data_grad.reshape(dshape), gamma_grad, beta_grad |
| |
| |
| batch_size = random.randint(1, 8) |
| num_groups = random.randint(2, 3) |
| num_channels = random.randint(2, 3) * num_groups |
| height = random.randint(1, 5) |
| width = random.randint(1, 5) |
| dshape = (batch_size, num_channels, height, width) |
| param_shape = (num_channels,) |
| temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width) |
| np_data = np.random.uniform(0.2, 1.0, dshape) |
| np_gamma = np.random.uniform(-1.0, 1.0, param_shape) |
| np_beta = np.random.uniform(-1.0, 1.0, param_shape) |
| data_sym = mx.sym.Variable("data") |
| gamma_sym = mx.sym.Variable("gamma") |
| beta_sym = mx.sym.Variable("beta") |
| for dtype in [np.float16, np.float32, np.float64]: |
| eps = 1e-2 if dtype == np.float16 else 1e-5 |
| mx_data = mx.nd.array(np_data, dtype=dtype) |
| mx_gamma = mx.nd.array(np_gamma, dtype=dtype) |
| mx_beta = mx.nd.array(np_beta, dtype=dtype) |
| np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype), |
| np_gamma.astype(dtype), |
| np_beta.astype(dtype), |
| num_groups=num_groups, |
| eps=eps) |
| mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, |
| num_groups=num_groups, eps=eps, output_mean_var=True) |
| check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], |
| rtol=1e-2 if dtype == np.float16 else 1e-3, |
| atol=5e-3 if dtype == np.float16 else 1e-4, dtype=dtype) |
| mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, |
| num_groups=num_groups, eps=eps, output_mean_var=False) |
| np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype) |
| np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd, |
| np_data.astype(dtype), |
| np_gamma.astype(dtype), |
| np_beta.astype(dtype), |
| np_mean, np_std, |
| num_groups, eps) |
| check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd, dtype=np_ograd.dtype)], |
| [np_data_grad, np_gamma_grad, np_beta_grad], |
| rtol=1e-2 if dtype == np.float16 else 1e-3, |
| atol=5e-2 if dtype == np.float16 else 1e-4, dtype=dtype) |
| |
| |
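| # GroupNorm above is checked against a NumPy reference that reshapes
| # (N, C, H, W) -> (N, num_groups, C // num_groups, H, W) and normalizes each
| # (sample, group) slice over its last three axes.  A minimal imperative usage
| # sketch (assuming the ndarray-level op mirrors the symbol API used above):
| #
| #     out = mx.nd.GroupNorm(data=mx.nd.array(np_data, dtype=np.float32),
| #                           gamma=mx.nd.array(np_gamma, dtype=np.float32),
| #                           beta=mx.nd.array(np_beta, dtype=np.float32),
| #                           num_groups=num_groups, eps=1e-5)
|
|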
| @pytest.mark.serial |
| def test_convolution_grouping(): |
| for dim in [1, 2, 3]: |
| num_filter = 4 |
| for num_group in [1, 2]: |
| kernel = (3,) * dim |
| shape = (1, 4) + (9,) * dim |
| |
| x = mx.sym.Variable('x') |
| w = mx.sym.Variable('w') |
| b = mx.sym.Variable('b') |
| y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel) |
| xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) |
| wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) |
| bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) |
| y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], |
| num_filter=num_filter//num_group, kernel=kernel) |
| for i in range(num_group)]) |
| |
| exe1 = y1._simple_bind(default_device(), x=shape) |
| exe2 = y2._simple_bind(default_device(), x=shape, w=(num_filter, shape[1]//num_group) + kernel, b=(num_filter,)) |
| for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): |
| arr1[:] = np.random.normal(size=arr1.shape).astype(effective_dtype(mx.nd.array([1.,]))) |
| arr2[:] = arr1 |
| exe1.forward(is_train=True) |
| exe1.backward(exe1.outputs[0]) |
| exe2.forward(is_train=True) |
| exe2.backward(exe2.outputs[0]) |
| |
| for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): |
| assert_almost_equal(arr1, arr2) |
| |
| |
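| # In the grouping test above, a single Convolution with num_group=g is compared
| # against an explicit construction: the input channels and the weight/bias are
| # each split into g slices, every slice is convolved independently, and the
| # per-group outputs are concatenated along the channel axis.  For num_group=g the
| # fused op's weight has shape (num_filter, in_channels // g) + kernel, which is
| # why exe2 is bound with w=(num_filter, shape[1] // num_group) + kernel.
|
|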
| @pytest.mark.skip(reason="Flaky test https://github.com/apache/incubator-mxnet/issues/14052") |
| def test_depthwise_convolution(): |
| for dim in [1,2]: |
| for num_base in [1, 4, 16, 32, 64]: |
| for kernel_x in [3, 5]: |
| for stride_x in [1, 2]: |
| for pad_x in [0, 1]: |
| for in_size in [7, 32]: |
| kernel = (kernel_x,) * dim |
| stride = (stride_x,) * dim |
| pad = (pad_x,) * dim |
| num_filter = num_base |
| num_group = num_base |
| shape = (2, num_base) + (in_size,) * dim |
| |
| x = mx.sym.Variable('x') |
| w = mx.sym.Variable('w') |
| b = mx.sym.Variable('b') |
| y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, |
| kernel=kernel, stride=stride, pad=pad) |
| xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) |
| wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) |
| bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) |
| y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], |
| num_filter=num_filter//num_group, kernel=kernel, |
| stride=stride, pad=pad) |
| for i in range(num_group)]) |
| |
| dev = default_device() |
| exe1 = y1._simple_bind(dev, x=shape) |
| exe2 = y2._simple_bind(dev, x=shape, w=(num_filter, shape[1]//num_group)+kernel, |
| b=(num_filter,)) |
| for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): |
| arr1[:] = np.random.normal(size=arr1.shape) |
| arr2[:] = arr1 |
| exe1.forward(is_train=True) |
| exe1.backward(exe1.outputs[0]) |
| exe2.forward(is_train=True) |
| exe2.backward(exe2.outputs[0]) |
| |
| for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): |
| assert_allclose(arr1, arr2, rtol=1e-3, atol=1e-3) |
| |
| |
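| # Depthwise convolution is the special case of grouped convolution where
| # num_group == num_filter == number of input channels (num_base above), so each
| # input channel is convolved with its own kernel and no channels are mixed.
|
|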
| def test_convolution_independent_gradients(): |
| # NOTE(zixuanweeei): Flaky test tracked by https://github.com/apache/incubator-mxnet/issues/15603. |
| # GPU context will be enabled after figuring out the possible issue tracked at |
| # https://github.com/apache/incubator-mxnet/issues/15638. |
| ctx = mx.cpu() |
| atol = 1.0e-3 |
| rtol = 1.0e-3 |
| reqs = ["null", "write", "add"] |
| var_names = ["x", "w", "b"] |
| dims = [1, 2] |
| num_bases = [1, 8] |
| kernel_xs = [3, 5] |
| stride_xs = [1, 2] |
| pad_xs = [0, 1] |
| in_sizes = [7, 32] |
| no_biases = [True, False] |
| for dim, num_base, kernel_x, stride_x, pad_x , in_size, no_bias in \ |
| itertools.product(dims, num_bases, kernel_xs, stride_xs, pad_xs, in_sizes, no_biases): |
| # Prepare params shape |
| kernel = (kernel_x,) * dim |
| stride = (stride_x,) * dim |
| pad = (pad_x,) * dim |
| num_filter = num_base |
| x_shape = (2, num_base) + (in_size,) * dim |
| w_shape = (num_filter, num_base) + kernel |
| |
| # Symbols definition |
| x = mx.sym.Variable('x') |
| w = mx.sym.Variable('w') |
| b = mx.sym.Variable('b') if not no_bias else None |
| conv = mx.sym.Convolution(x, w, b, num_filter=num_filter, |
| kernel=kernel, stride=stride, pad=pad, no_bias=no_bias) |
| |
| for req_kind in reqs: |
| # Binding args for conv with possible dependent gradients |
| base_args = { |
| 'x': mx.nd.random.normal(shape=x_shape, ctx=ctx), |
| 'w': mx.nd.random.normal(shape=w_shape, ctx=ctx), |
| 'b': mx.nd.random.normal(shape=(num_filter, ), ctx=ctx) if not no_bias else None} |
| args1 = copy.deepcopy(base_args) |
| grad1 = { |
| 'x': mx.nd.zeros(shape=x_shape, ctx=ctx), |
| 'w': mx.nd.zeros(shape=w_shape, ctx=ctx), |
| 'b': mx.nd.zeros(shape=(num_filter, ), ctx=ctx) if not no_bias else None} |
| |
| grad_req1 = [req_kind] * 3 |
| grad_req1 = dict(zip(var_names, grad_req1)) |
| |
| exe1 = conv._bind(ctx, args1, args_grad=grad1, grad_req=grad_req1) |
| exe1.forward(is_train=True) |
| exe1.backward(exe1.outputs[0]) |
| |
| for x_req, w_req, b_req in itertools.product(reqs, repeat=3): |
| # Binding args for conv with independent gradients |
| args2 = copy.deepcopy(base_args) # Deepcopy the same params of `exe1` |
| grad2 = { |
| 'x': mx.nd.zeros(shape=x_shape, ctx=ctx), |
| 'w': mx.nd.zeros(shape=w_shape, ctx=ctx), |
| 'b': mx.nd.zeros(shape=(num_filter, ), ctx=ctx) if not no_bias else None} |
| grad_req2 = {"x": x_req, "w": w_req, "b": b_req} |
| exe2 = conv._bind(ctx, args2, args_grad=grad2, grad_req=grad_req2) |
| |
| exe2.forward(is_train=True) |
| np.testing.assert_allclose(exe1.outputs[0].asnumpy(), |
| exe2.outputs[0].asnumpy(), rtol=rtol, atol=atol) |
| |
| exe2.backward(exe2.outputs[0]) |
| for var_name in var_names: |
| if var_name == "b" and no_bias: |
| continue |
| if grad_req2[var_name] == "null": |
| exe2_var_grad = grad2[var_name].asnumpy() |
| np.testing.assert_allclose(exe2_var_grad, |
| np.zeros_like(exe2_var_grad), rtol=rtol, atol=atol) |
| if grad_req2[var_name] != grad_req1[var_name]: |
| continue |
| np.testing.assert_allclose(args1[var_name].asnumpy(), |
| args2[var_name].asnumpy(), rtol=rtol, atol=atol) |
| np.testing.assert_allclose(grad1[var_name].asnumpy(), |
| grad2[var_name].asnumpy(), rtol=rtol, atol=atol) |
| |
| |
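| # The independent-gradients test relies on the three gradient request modes
| # accepted at bind time: 'null' (skip the gradient entirely), 'write' (overwrite
| # the gradient buffer on each backward pass) and 'add' (accumulate into it).
| # exe1 uses one mode for all inputs while exe2 mixes them, and the results must
| # agree wherever the modes coincide.
|
|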
| def gen_broadcast_data(idx): |
| # Manually set test cases |
| binary_op_data_shape = np.array( |
| [[[2, 5, 1, 30, 7], [1, 5, 448, 30, 1]], |
| [[10, 49, 1, 77, 17], [10, 1, 2, 1, 17]], |
| [[13, 2, 65, 2, 1], [13, 1, 65, 1, 225]], |
| [[9, 434, 4, 2, 37], [9, 1, 4, 1, 37]], |
| [[2, 52, 1, 4, 1], [1, 52, 60, 1, 37]], |
| [[1, 23, 7, 122, 50], [2, 1, 7, 1, 50]], |
| [[1, 17, 1, 5, 1], [22, 1, 2, 1, 28]], |
| [[29, 1, 2, 1, 8], [29, 22, 1, 130, 1]], |
| [[2, 36, 1, 427, 3], [1, 36, 11, 427, 1]], |
| [[1, 2, 1, 100, 7], [1, 2, 448, 100, 1]], |
| [[1, 2, 495, 77, 7], [1, 2, 1, 1, 7]], |
| [[1, 43, 65, 2, 1], [1, 43, 65, 1, 225]], |
| [[1, 92, 434, 2, 2], [1, 92, 1, 2, 2]], |
| [[1, 92, 1, 4, 1], [1, 92, 134, 1, 17]], |
| [[1, 53, 2, 122, 143], [1, 1, 2, 1, 143]], |
| [[1, 179, 1, 87, 17], [1, 179, 1, 1, 17]], |
| [[1, 1, 17, 5, 1], [1, 22, 1, 1, 28]], |
| [[1, 2, 1, 1, 8], [1, 2, 52, 430, 1]], |
| [[1, 163, 1, 22, 3], [1, 163, 116, 22, 1]], |
| [[1, 1, 44, 30, 7], [1, 1, 44, 30, 1]], |
| [[1, 1, 1, 1, 28], [1, 127, 1, 5, 28]], |
| [[1, 2, 394, 38, 1], [1, 2, 394, 38, 16]], |
| [[1, 10, 49, 77, 17], [1, 1, 1, 1, 17]], |
| [[1, 431, 6, 2, 225], [1, 1, 6, 2, 225]], |
| [[1, 15, 1, 28, 1], [1, 15, 1, 28, 463]], |
| [[1, 129, 2, 48, 96], [1, 129, 2, 1, 1]], |
| [[1, 1, 403, 17, 2], [1, 44, 403, 17, 2]], |
| [[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]], |
| [[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]], |
| [[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]], |
| [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]], |
| [[8, 1, 6, 1], [7, 1, 5]], [[5, 4], [1]], |
| [[256, 256, 3], [3]], [[5, 4], [4]], |
| [[15, 3, 5], [3, 5]], [[15, 3, 5], [1, 5]], |
| [[15, 3, 5], [3, 1]], [[1,1,1,1], [1,1]], |
| [[15,3], [4, 1, 3]], [[7, 1, 5], [8, 1, 6, 1]]]) |
| if idx < binary_op_data_shape.shape[0]: |
| l_shape = binary_op_data_shape[idx][0] |
| r_shape = binary_op_data_shape[idx][1] |
| else: |
| # Generate random data with ndim between 1 and 5 and every shape dim between 1 and 5
| ndim = np.random.randint(1, 6) |
| shape = np.random.randint(1, 6, size=(ndim,)) |
| l_same_dim = np.random.randint(0, 5) |
| r_same_dim = np.random.randint(0, 5) |
| l_axis_flags = np.random.randint(0, 2, size=ndim) |
| r_axis_flags = np.random.randint(0, 2, size=ndim) |
| if l_same_dim == 4: |
| l_axis_flags = np.ones(ndim) |
| if r_same_dim == 4: |
| r_axis_flags = np.ones(ndim) |
| l_shape = shape.copy() |
| r_shape = shape.copy() |
| l_shape[np.where(l_axis_flags == 0)] = 1 |
| r_shape[np.where(r_axis_flags == 0)] = 1 |
| return [np.random.random(l_shape), np.random.random(r_shape)] |
| |
| |
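| # Every pair produced by gen_broadcast_data is broadcast-compatible; for example
| # (a hedged illustration of the first fixed case):
| #
| #     lhs, rhs = gen_broadcast_data(0)   # shapes (2, 5, 1, 30, 7) and (1, 5, 448, 30, 1)
| #     np.broadcast(lhs, rhs).shape       # -> (2, 5, 448, 30, 7)
|
|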
| def gen_broadcast_data_int(idx): |
| d = gen_broadcast_data(idx) |
| return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] |
| |
| |
| def gen_binary_data(dummy): |
| ndim = np.random.randint(1, 6) |
| shape = np.random.randint(1, 6, size=(ndim,)) |
| return [np.random.random(shape), np.random.random(shape)] |
| |
| |
| def gen_binary_data_int(dummy): |
| d = gen_binary_data(dummy) |
| return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] |
| |
| |
| def check_binary_op_forward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5, mx_nd_func=None): |
| sample_num = 200 |
| for i in range(sample_num): |
| d = gen_data(i) |
| y = symbol._bind(default_device(), args={'a': mx.nd.array(d[0]), 'b': mx.nd.array(d[1])}) |
| y.forward(is_train=True) |
| y = y.outputs[0].asnumpy() |
| x = baseline(d[0], d[1]).astype(y.dtype) |
| |
| if mx_nd_func is not None: |
| d0 = mx.nd.array(d[0], dtype=d[0].dtype) |
| d1 = mx.nd.array(d[1], dtype=d[1].dtype) |
| assert_almost_equal(y, mx_nd_func(d0, d1).asnumpy(), rtol=rtol, atol=atol) |
| idx = np.abs(x-y) > atol+rtol*np.abs(x) |
| if idx.any(): |
| import binascii |
| np.set_printoptions(precision=20) |
| logging.error('found precision problem:') |
| d[0] = np.broadcast_to(d[0], x.shape) |
| d[1] = np.broadcast_to(d[1], x.shape) |
| logging.error('input a: {}'.format(d[0][idx])) |
| logging.error('input b: {}'.format(d[1][idx])) |
| logging.error("output x: {} {}".format(x.dtype, x)) |
| logging.error("output y: {} {}".format(y.dtype, y)) |
| def ftohex(xs): |
| import struct |
| return list(map(lambda x: binascii.hexlify(struct.pack('d', x)), xs.flatten())) |
| logging.error('output x in baseline(a, b): {}'.format(x[idx])) |
| logging.error('output y in symbol(a, b): {}'.format(y[idx])) |
| logging.error('output x in baseline(a,b) hex: {}'.format(ftohex(x[idx]))) |
| logging.error('output y in symbol(a,b) hex: {}'.format(ftohex(y[idx]))) |
| logging.error('input a hex: {}'.format(ftohex(d[0][idx]))) |
| logging.error('input b hex: {}'.format(ftohex(d[1][idx])))
| |
| logging.error('diff: {}'.format(np.abs(x-y)[idx] - atol-rtol*np.abs(x)[idx])) |
| assert_allclose(y, x, rtol=rtol, atol=atol) |
| |
| |
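| # Typical usage of the forward checker, mirroring the tests further below:
| #
| #     a = mx.sym.Variable('a')
| #     b = mx.sym.Variable('b')
| #     check_binary_op_forward(mx.sym.broadcast_plus(a, b), lambda x, y: x + y,
| #                             gen_broadcast_data, mx_nd_func=mx.nd.add)
|
|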
| def check_binary_op_backward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5): |
| sample_num = 200 |
| for i in range(sample_num): |
| d = gen_data(i) |
| out = np.random.random((d[0] + d[1]).shape) |
| |
| def reduce_op(shape, x): |
| if shape == x.shape: |
| return x |
| keepdims_shape = list(x.shape) |
| # calculate difference between output and input ndims |
| # to include cases where inputs' ndims are not equal |
| ndim_diff = len(x.shape) - len(shape) |
| for i in range(ndim_diff): |
| keepdims_shape[i] = 1 |
| x = np.sum(x, axis=i).reshape(keepdims_shape) |
| for i in range(len(shape)): |
| if x.shape[ndim_diff + i] != shape[i]: |
| keepdims_shape[ndim_diff + i] = 1 |
| x = np.sum(x, axis=ndim_diff + i).reshape(keepdims_shape) |
| return x |
| |
| baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1]) |
| x_1 = reduce_op(d[0].shape, baseline_grad1) |
| x_2 = reduce_op(d[1].shape, baseline_grad2) |
| y_1 = mx.nd.empty(d[0].shape) |
| y_2 = mx.nd.empty(d[1].shape) |
| y = symbol._bind(default_device(), args={'a': mx.nd.array(d[0]), 'b': mx.nd.array(d[1])}, |
| args_grad=[y_1, y_2]) |
| o = y.forward(is_train=True) |
| y.backward([mx.nd.array(out, dtype=o[0].dtype)]) |
| assert_allclose(y_1.asnumpy(), x_1, rtol=rtol, atol=atol) |
| assert_allclose(y_2.asnumpy(), x_2, rtol=rtol, atol=atol) |
| |
| |
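| # In the backward checker, reduce_op folds the full-shape reference gradient back
| # onto each (possibly broadcast) input by summing over the broadcast axes, e.g. a
| # gradient of shape (2, 3) for an input of shape (1, 3) is reduced with
| # np.sum(..., axis=0, keepdims=True) before being compared against the gradients
| # that MXNet writes into y_1 / y_2.
|
|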
| def test_binary_op(): |
| a = mx.sym.Variable('a') |
| b = mx.sym.Variable('b') |
| |
| def test_bplus(a, b): |
| c = a + b |
| check_binary_op_forward(c, lambda a, b: a + b, gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out, g_out), gen_binary_data) |
| |
| def test_bminus(a, b): |
| c = a - b |
| check_binary_op_forward(c, lambda a, b: a - b, gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out), gen_binary_data) |
| |
| def test_bmul(a, b): |
| c = a * b |
| check_binary_op_forward(c, lambda a, b: a * b, gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out * b, g_out * a), gen_binary_data) |
| |
| def test_bdiv(a, b): |
| c = a / b |
| check_binary_op_forward(c, lambda a, b: a / b, gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_binary_data) |
| |
| def test_bmod(a, b): |
| # Python and numpy compute '%' in double precision, so cast to doubles here as well to
| # avoid numerical errors. This test was flaky with float32 (seeds 1688524483, 1768433044).
| #c = a % b |
| c = mx.sym.cast(a, dtype='float64') % mx.sym.cast(b, dtype='float64') |
| # '%' is sensitive to the precision of the calculation. Force numpy to match mxnet's float32. |
| check_binary_op_forward(c, lambda a, b: np.float32(a) % np.float32(b), gen_binary_data, rtol=0, atol=0) |
| check_binary_op_backward(c, |
| lambda g_out, a, b: (g_out, - g_out * (np.float32(a) // np.float32(b))), gen_binary_data) |
| |
| def test_bmod_int(a, b): |
| c = mx.sym.cast(a, dtype='int32') % mx.sym.cast(b, dtype='int32') |
| check_binary_op_forward(c, lambda a, b: a % b, gen_binary_data_int) |
| check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_binary_data_int) |
| |
| def test_bpow(a, b): |
| c = a ** b |
| check_binary_op_forward(c, lambda a, b: a ** b, gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out * a **(b - 1) * b, |
| g_out * a ** b * np.log(a)), gen_binary_data) |
| |
| def test_bneq(a, b): |
| c = a != b |
| # '!=' is sensitive to the precision of the comparison. Force numpy to match mxnet's float32. |
| # Issue exposed with seed 1644387363 |
| check_binary_op_forward(c, lambda a, b: (np.float32(a) != np.float32(b)).astype(a.dtype), gen_binary_data) |
| check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_binary_data) |
| |
| test_bplus(a, b) |
| test_bminus(a, b) |
| test_bmul(a, b) |
| test_bdiv(a, b) |
| test_bmod(a, b) |
| test_bmod_int(a, b) |
| test_bpow(a, b) |
| test_bneq(a, b) |
| |
| def test_broadcast_binary_op(): |
| def check_bmaxmin_gradient(test_sym, x, y, delta, rtol, atol): |
| """This function ensures that checking the numerical gradient of |
| broadcast_max/min is not crossing the boundary y=x where there |
| is no gradient definition at those sigularities.""" |
| x_max = np.max(x) |
| y = x_max + 2 * delta + np.random.random(y.shape) |
| check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol) |
| |
| x_min = np.min(x) |
| y = x_min - 2 * delta - np.random.random(y.shape) |
| check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol) |
| |
| a = mx.sym.Variable('a') |
| b = mx.sym.Variable('b') |
| |
| def test_bplus(a, b): |
| c = mx.sym.broadcast_plus(a, b) |
| check_binary_op_forward(c, lambda a, b: a + b, gen_broadcast_data, mx_nd_func=mx.nd.add) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out, g_out), gen_broadcast_data) |
| |
| def test_bminus(a, b): |
| c = mx.sym.broadcast_minus(a, b) |
| check_binary_op_forward(c, lambda a, b: a - b, gen_broadcast_data, mx_nd_func=mx.nd.subtract) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out), gen_broadcast_data) |
| |
| def test_bmul(a, b): |
| c = mx.sym.broadcast_mul(a, b) |
| check_binary_op_forward(c, lambda a, b: a * b, gen_broadcast_data, mx_nd_func=mx.nd.multiply) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out * b, g_out * a), gen_broadcast_data) |
| |
| def test_bdiv(a, b): |
| c = mx.sym.broadcast_div(a, b) |
| check_binary_op_forward(c, lambda a, b: a / b, gen_broadcast_data, mx_nd_func=mx.nd.divide) |
| check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_broadcast_data) |
| |
| def test_bmod(a_, b_): |
| # Python and numpy compute '%' in double precision, so cast to doubles here as well to
| # avoid numerical errors. This test was flaky with float32 (seeds 1688524483, 1768433044).
| a = mx.sym.cast(a_, dtype='float64') |
| b = mx.sym.cast(b_, dtype='float64') |
| # '%' is sensitive to the precision of the calculation, so the forward check uses a loose absolute tolerance.
| c = mx.sym.broadcast_mod(a, b) |
| check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data, atol=1, mx_nd_func=mx.nd.modulo) |
| check_binary_op_backward(c, |
| lambda g_out, a, b: (g_out, - g_out * (np.float32(a) // np.float32(b))), gen_binary_data) |
| |
| def test_bmod_int(a, b): |
| c = mx.sym.broadcast_mod(mx.sym.cast(a, dtype='int32'), mx.sym.cast(b, dtype='int32')) |
| check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data_int, mx_nd_func=mx.nd.modulo) |
| check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_broadcast_data_int) |
| |
| def test_bpow(a, b):
|