| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import functools

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
import pytest

from mxnet.ndarray import zeros_like
from mxnet.autograd import *
from mxnet.test_utils import *
from mxnet.test_utils import environment

from common import xfail_when_nonstandard_decimal_separator
| |
| |
| def grad_and_loss(func, argnum=None): |
| """Return function that computes both gradient of arguments and loss value. |
| |
| Parameters |
| ---------- |
| func: a python function |
| The forward (loss) function. |
    argnum: an int or a list of int
        The index (or indices) of the argument(s) to compute gradients for.
        If None, gradients are computed for all arguments.
| |
| Returns |
| ------- |
    grad_and_loss_func: a python function
        A function that computes both the gradients of the arguments and the loss value.
| """ |
| @functools.wraps(func) |
| def wrapped(*args): |
| """Wrapped function.""" |
| variables = args |
| if argnum is not None: |
| argnum_ = argnum if isinstance(argnum, list) else [argnum] |
| variables = [args[i] for i in argnum_] |
| for x in variables: |
            assert isinstance(x, NDArray), "autograd input must be an NDArray."
| grads = [zeros_like(x) for x in variables] |
| mark_variables(variables, grads) |
| with record(): |
| outputs = func(*args) |
| backward([outputs] if isinstance(outputs, NDArray) else outputs) |
| return grads, outputs |
| return wrapped |
| |
| def grad(func, argnum=None): |
| """Return function that computes gradient of arguments. |
| |
| Parameters |
| ---------- |
| func: a python function |
| The forward (loss) function. |
    argnum: an int or a list of int
        The index (or indices) of the argument(s) to compute gradients for.
        If None, gradients are computed for all arguments.
| |
| Returns |
| ------- |
    grad_func: a python function
        A function that computes the gradients of the arguments.
| |
| Examples |
| -------- |
    >>> # autograd supports dynamic graphs: the recorded graph may
    >>> # change on every call
    >>> import random
    >>> def func(x):
    ...     r = random.randint(0, 1)
    ...     if r % 2:
    ...         return x**2
    ...     else:
    ...         return x/3
    >>> # use `grad(func)` to get the gradient function
    >>> for x in range(10):
    ...     grad_func = grad(func)
    ...     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    ...     grad_vals = grad_func(inputs)
| """ |
| grad_with_loss_func = grad_and_loss(func, argnum) |
| @functools.wraps(grad_with_loss_func) |
| def wrapped(*args): |
| return grad_with_loss_func(*args)[0] |
| return wrapped |
| |
| def autograd_assert(*args, **kwargs): |
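    """Check autograd gradients of ``func`` against the analytical gradients
    returned by ``grad_func``, and check that the recorded forward output
    matches a plain forward pass."""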
| func = kwargs["func"] |
| grad_f = kwargs["grad_func"] |
| argnum = kwargs["argnum"] if 'argnum' in kwargs else None |
| |
| grad_func = grad_and_loss(func, argnum) |
| grad_vals, output = grad_func(*args) |
| res = func(*args) |
| assert same(output.asnumpy(), res.asnumpy()) |
| grad_res = grad_f(*args) |
| assert len(grad_vals) == len(grad_res) |
| for a, b in zip(grad_vals, grad_res): |
| assert same(a.asnumpy(), b.asnumpy()) |
| |
| @xfail_when_nonstandard_decimal_separator |
| def test_unary_func(): |
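    # Check gradients of elementwise unary ops (exp, x/2, x**2) against their
    # closed-form gradients, for dense and sparse inputs.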
| def check_unary_func(x): |
| f_exp = lambda x: nd.exp(x) |
| f_exp_grad = lambda x: [nd.exp(x)] |
| autograd_assert(x, func=f_exp, grad_func=f_exp_grad) |
| f_half = lambda x: x/2 |
| f_half_grad = lambda x: [nd.ones(x.shape) * 0.5] |
| autograd_assert(x, func=f_half, grad_func=f_half_grad) |
| f_square = lambda x: x**2 |
| f_square_grad = lambda x: [2*x] |
| autograd_assert(x, func=f_square, grad_func=f_square_grad) |
| uniform = nd.uniform(shape=(4, 5)) |
| stypes = ['default', 'row_sparse', 'csr'] |
| with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): |
| for stype in stypes: |
| check_unary_func(uniform.tostype(stype)) |
| |
| def test_binary_func(): |
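    # Check gradients of binary ops (add, mul, x + x*y) against closed-form
    # gradients, for every combination of input storage types.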
| def check_binary_func(x, y): |
| f_add = lambda x, y: x+y |
| f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)] |
| autograd_assert(x, y, func=f_add, grad_func=f_add_grad) |
| f_mul = lambda x, y: x*y |
| f_mul_grad = lambda x, y: [y, x] |
| autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad) |
| f_compose = lambda x, y: x+x*y |
| f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x] |
| autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad) |
| uniform_x = nd.uniform(shape=(4, 5)) |
| uniform_y = nd.uniform(shape=(4, 5)) |
| stypes = ['default', 'row_sparse', 'csr'] |
| with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'): |
| for stype_x in stypes: |
| for stype_y in stypes: |
| x = uniform_x.tostype(stype_x) |
| y = uniform_y.tostype(stype_y) |
| check_binary_func(x, y) |
| |
| |
| def test_operator_with_state(): |
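    # Smoke test: differentiate through FullyConnected; there are no numerical
    # assertions yet (see the TODO below).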
| def f_fc(a, b, weight, bias): |
| x = a*b |
| fc = nd.FullyConnected( |
| x, weight, bias, num_hidden=32) |
| return fc |
| |
| a = nd.uniform(shape=(64, 50)) |
| b = nd.uniform(shape=(64, 50)) |
| weight = nd.uniform(shape=(32, 50)) |
| bias = nd.uniform(shape=(32, )) |
| |
| grad_func = grad_and_loss(f_fc) |
| grad_vals, outputs = grad_func(a, b, weight, bias) |
    # TODO: add assertions on grad_vals and outputs
| |
| def test_argnum(): |
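    # `mode` is a plain Python argument, so argnum=[0, 1] restricts gradient
    # computation to the two NDArray inputs.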
| def f_with_mode(a, b, mode): |
| if mode: |
| return a+b |
| else: |
| return a*b |
| |
| a = nd.uniform(shape=(3, 2)) |
| b = nd.uniform(shape=(3, 2)) |
| f_add_grad = lambda x, y, mode: [nd.ones(x.shape), nd.ones(y.shape)] |
| f_mul_grad = lambda x, y, mode: [y, x] |
| autograd_assert(a, b, True, |
| argnum=[0, 1], func=f_with_mode, grad_func=f_add_grad) |
| autograd_assert(a, b, False, |
| argnum=[0, 1], func=f_with_mode, grad_func=f_mul_grad) |
| |
| |
| def test_training(): |
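    # Dropout is active inside record() (train mode by default) and acts as
    # the identity inside pause() (predict mode by default).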
| x = nd.ones((10, 10)) |
| with record(): |
| y = nd.Dropout(x, p=0.5) |
| assert not (y.asnumpy() == x.asnumpy()).all() |
| with pause(): |
| y = nd.Dropout(x, p=0.5) |
| assert (y.asnumpy() == x.asnumpy()).all() |
| |
| |
| def test_out_grads(): |
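    # backward() accepts one head gradient per output; a None entry defaults
    # to a gradient of ones.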
| x = nd.ones((3, 5)) |
| dx = nd.zeros_like(x) |
| mark_variables([x], [dx]) |
| da = None |
| db = nd.array([1,2,3,4,5]) |
| dc = nd.array([5,4,3,2,1]) |
| |
| with record(): |
| a, b, c = nd.split(x, axis=0, num_outputs=3, squeeze_axis=True) |
| backward([a, b, c], [da, db, dc]) |
| |
| assert (dx.asnumpy() == np.array( |
| [[1,1,1,1,1], |
| [1,2,3,4,5], |
| [5,4,3,2,1]])).all() |
| |
| |
| def test_detach_updated_grad(): |
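    # detach() cuts the recorded graph so gradients stop flowing back to x;
    # _fresh_grad tracks whether backward() updated a variable's gradient.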
| x = nd.ones((2, 2)) |
| dx = nd.zeros_like(x) |
| y = nd.ones_like(x) |
| dy = nd.zeros_like(x) |
| mark_variables([x, y], [dx, dy]) |
    assert not x._fresh_grad
    assert not y._fresh_grad
| |
| with record(): |
| x2 = x + 2 |
| y2 = x2 + y |
| y2.backward() |
| assert (dx.asnumpy() == 1).all() |
    assert x._fresh_grad
    assert y._fresh_grad
| |
| dx[:] = 0 |
| x._fresh_grad = False |
| y._fresh_grad = False |
    assert not x._fresh_grad
    assert not y._fresh_grad
| with record(): |
| x2 = x + 2 |
| x2 = x2.detach() |
| y2 = x2 + y |
| y2.backward() |
| assert (dx.asnumpy() == 0).all() |
    assert y._fresh_grad
    assert not x._fresh_grad
| |
| |
| def test_retain_graph(): |
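    # With grad_reqs='add' gradients accumulate across backward passes, and
    # backward() can only be called again if the graph was retained.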
| x = mx.nd.ones((2, 2)) |
| dx = mx.nd.zeros((2, 2)) |
| mark_variables([x], [dx], grad_reqs='add') |
| with record(): |
| y = x + 1 |
| y.backward(retain_graph=False) |
| assert (dx.asnumpy() == 1).all() |
| |
| dx[:] = 0 |
| with record(): |
| y = x + 1 |
| y.backward(retain_graph=True) |
| y.backward(retain_graph=False) |
| assert (dx.asnumpy() == 2).all() |
| |
| # The following sequence should throw an exception. We discard the expected |
| # stderr stack trace output for this operation to keep the test logs clean. |
| with discard_stderr(): |
| try: |
| with record(): |
| y = x + 1 |
| y.backward() |
| y.backward() |
| except Exception: |
| return |
| |
| raise AssertionError( |
| "differentiating the same graph twice without retain_graph should fail") |
| |
| |
| def test_attach_grad(): |
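    # attach_grad() allocates a gradient buffer that is populated by backward()
    # and exposed via .grad; arrays without one report grad as None.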
| def check_attach_grad(x): |
| assert x.grad is None |
| x.attach_grad() |
| with record(): |
| y = x * 2 |
| assert y.grad is None |
| y.backward(out_grad=mx.nd.ones_like(y).tostype(x.stype)) |
| assert (x.grad.asnumpy() == 2).all() |
| zeros = mx.nd.zeros((10, 10)) |
| stypes = ['default', 'row_sparse', 'csr'] |
| for stype in stypes: |
| x = zeros.tostype(stype) |
| check_attach_grad(x) |
| |
| |
| def test_is_train(): |
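    # record()/pause() control recording, while train_mode()/predict_mode()
    # control whether operators such as Dropout behave as in training.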
| x = mx.nd.ones((10, 10)) |
| x.attach_grad() |
| with record(train_mode=True): |
| assert is_recording() |
| assert is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0 |
| y.backward() |
| assert (x.grad.asnumpy() == y.asnumpy()).all() |
| |
| with predict_mode(): |
| assert is_recording() |
| assert not is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert (y.asnumpy() == x.asnumpy()).all() |
| y.backward(train_mode=False) |
| assert (x.grad.asnumpy() == x.asnumpy()).all() |
| |
| with record(train_mode=False): |
| assert is_recording() |
| assert not is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert (y.asnumpy() == x.asnumpy()).all() |
| y.backward(train_mode=False) |
| assert (x.grad.asnumpy() == x.asnumpy()).all() |
| |
| with train_mode(): |
| assert is_recording() |
| assert is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0 |
| y.backward() |
| assert (x.grad.asnumpy() == y.asnumpy()).all() |
| |
| assert not is_recording() |
| assert not is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert (y.asnumpy() == x.asnumpy()).all() |
| |
| with train_mode(): |
| assert not is_recording() |
| assert is_training() |
| y = mx.nd.Dropout(x, p=0.5) |
| assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0 |
| |
| @pytest.mark.garbage_expected |
| def test_function(): |
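    # Custom differentiable operator via autograd.Function: gradients from the
    # user-defined backward() must match those of the equivalent native ops.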
| class func(Function): |
| def forward(self, x, y): |
| m = x / y |
| n = x * y |
| self.save_for_backward(x, y) |
| return m, n |
| |
| def backward(self, dm, dn): |
| x, y = self.saved_tensors |
| dx = dm/y + dn*y |
| dy = dn*x - dm * x / y / y |
| return dx, dy |
| |
| f = func() |
| x = mx.nd.random.uniform(shape=(10,)) |
| x.attach_grad() |
| y = mx.nd.random.uniform(shape=(10,)) |
| y.attach_grad() |
| with record(): |
| m, n = f(x, y) |
| backward([m, n]) |
| |
| dx1 = x.grad.asnumpy() |
| dy1 = y.grad.asnumpy() |
| |
| with record(): |
| backward([x/y, x*y]) |
| |
| # Non-zero atol required, as exposed by seed 630179191 |
| atol = 1e-6 |
| assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol) |
| assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol) |
| |
| |
| @pytest.mark.garbage_expected |
| def test_function1(): |
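    # Repeatedly applying a custom Function inside record() should run cleanly
    # even when no gradient buffer is attached to the input (see comment below).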
| class Foo(mx.autograd.Function): |
| def __init__(self): |
| super(Foo, self).__init__() |
| |
| def forward(self, X): |
            return X + 1
| |
| def backward(self, dY): |
| return dY |
| |
| with mx.autograd.record(): |
| X = mx.nd.zeros((3, 4)) |
| #X.attach_grad() # uncommenting this line works |
| for _ in range(5): |
| f = Foo() |
| X = f(X) |
| X.wait_to_read() |
| |
| |
| @pytest.mark.garbage_expected |
| @use_np |
| def test_np_function(): |
| class func(Function): |
| def forward(self, x, y): |
| m = x / y |
| n = x * y |
| self.save_for_backward(x, y) |
| return m, n |
| |
| def backward(self, dm, dn): |
| x, y = self.saved_tensors |
| dx = dm/y + dn*y |
| dy = dn*x - dm * x / y / y |
| return dx, dy |
| |
| f = func() |
| x = mx.np.random.uniform(size=(10,)) |
| x.attach_grad() |
| y = mx.np.random.uniform(size=(10,)) |
| y.attach_grad() |
| with record(): |
| m, n = f(x, y) |
| backward([m, n]) |
| |
| dx1 = x.grad.asnumpy() |
| dy1 = y.grad.asnumpy() |
| |
| with record(): |
| backward([x/y, x*y]) |
| |
| # Non-zero atol required, as exposed by seed 630179191 |
| atol = 1e-6 |
| assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol) |
| assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol) |
| |
| |
| @pytest.mark.garbage_expected |
| @use_np |
| def test_np_function1(): |
| class Foo(mx.autograd.Function): |
| def __init__(self): |
| super(Foo, self).__init__() |
| |
| def forward(self, X): |
            return X + 1
| |
| def backward(self, dY): |
| return dY |
| |
| with mx.autograd.record(): |
| X = mx.np.zeros((3, 4)) |
| #X.attach_grad() # uncommenting this line works |
| for _ in range(5): |
| f = Foo() |
| X = f(X) |
| X.wait_to_read() |
| |
| |
| @pytest.mark.garbage_expected |
| def test_get_symbol(): |
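    # get_symbol() returns the recorded computation history as a Symbol whose
    # arguments are the recorded input variables.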
| x = mx.nd.ones((1,)) |
| x.attach_grad() |
| with record(): |
| y = x*x + 2*x - 1 |
| assert len(get_symbol(y).list_arguments()) == 1 |
| |
| z = mx.nd.ones((1,)) |
| z.attach_grad() |
| with record(): |
| y = x*x + 2*z - 1 |
| assert len(get_symbol(y).list_arguments()) == 2 |
| |
| @pytest.mark.garbage_expected |
| def test_grad_with_stype(): |
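    # attach_grad(stype=...) controls the storage type of the gradient buffer;
    # by default it matches the array's own storage type.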
| def check_grad_with_stype(array_stype, grad_stype, expected_stype): |
| x = mx.nd.zeros((1, 1), stype=array_stype) |
| x.attach_grad(stype=grad_stype) |
| # check grad attached |
| assert x.grad.stype == expected_stype |
| y = x.detach() |
| # check array detached |
| assert y.stype == array_stype |
| |
| stypes = ['default', 'csr', 'row_sparse'] |
| for stype in stypes: |
| # check the default stype of the gradient (same as the array stype) |
| check_grad_with_stype(stype, None, stype) |
| for grad_stype in stypes: |
| # check the stype of the gradient when provided |
| check_grad_with_stype(stype, grad_stype, grad_stype) |
| |
| @pytest.mark.garbage_expected |
| def test_sparse_dot_grad(): |
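    # dot(csr, rhs) should produce a row_sparse gradient for rhs, whether rhs
    # is stored as row_sparse or dense (with a row_sparse grad buffer attached).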
| def check_sparse_dot_grad(rhs): |
| lhs = rand_ndarray((2, 8), 'csr') |
| with mx.autograd.record(): |
| y = mx.nd.dot(lhs, rhs) |
| y.backward() |
| grad = rhs.grad |
| grad_np = np.dot(lhs.asnumpy().T, np.ones((lhs.shape[0], rhs.shape[1]))) |
| assert grad.stype == 'row_sparse' |
| assert_almost_equal(grad.asnumpy(), grad_np) |
| |
| # check grad with row_sparse weight |
| shape = (8, 3) |
| rsp = mx.nd.ones(shape).tostype('row_sparse') |
| rsp.attach_grad() |
| check_sparse_dot_grad(rsp) |
| |
| # check grad with dense weight |
| dns = mx.nd.ones(shape) |
| dns.attach_grad(stype='row_sparse') |
| check_sparse_dot_grad(dns) |
| |
| def test_gradient(): |
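    # Higher-order gradients: create_graph=True records the gradient
    # computation itself, so dx can be differentiated again.
    # z = exp(x) + x  =>  dz/dx = exp(x) + 1 = 3.71828... at x = 1,
    # and d(dz/dx)/dx = exp(x) = 2.71828... at x = 1.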
| x = mx.nd.ones((1,)) |
| x.attach_grad() |
| |
| with mx.autograd.record(): |
| z = mx.nd.elemwise_add(mx.nd.exp(x), x) |
| dx, = mx.autograd.grad(z, [x], create_graph=True) |
| assert abs(dx.asscalar() - 3.71828175) < 1e-7 |
| dx.backward() |
| assert abs(x.grad.asscalar() - 2.71828175) < 1e-7 |
| |
| def test_retain_grad_drop_grad(): |
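    # attach_grad() on intermediate results retains their gradients during
    # backward(); drop_grad() releases the buffer, so after the next backward()
    # the dropped arrays report grad as None.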
| x = nd.array([1,2,3,4]) |
| x.attach_grad() |
| y = nd.array([5,6,7,8]) |
| y.attach_grad() |
| |
| with mx.autograd.record(): |
| u = x * y |
| z = u * x |
| |
| u.attach_grad() |
| z.attach_grad() |
| out_grad = nd.array([10, 10, 10, 10]) |
| z.backward(out_grad, retain_graph=True) |
| |
| assert (u.grad == out_grad * x).asnumpy().all() |
| assert (z.grad == out_grad).asnumpy().all() |
| assert (x.grad == out_grad * 2 * x * y).asnumpy().all() |
| assert (y.grad == out_grad * x*x).asnumpy().all() |
| |
| u.drop_grad() |
| z.drop_grad() |
| y.drop_grad() |
| out_grad = nd.array([0.1, 0.1, 0.1, 0.1]) |
| z.backward(out_grad) |
| |
| assert u.grad is None and z.grad is None and y.grad is None |
| assert (x.grad == out_grad * 2 * x * y).asnumpy().all() |
| |
| def test_retain_grad_drop_grad_gluon(): |
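    # Same retain/drop gradient behaviour, but for an intermediate value
    # exposed from inside a Gluon block's forward().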
| class CompBlock(mx.gluon.HybridBlock): |
| def __init__(self): |
| super().__init__() |
| self.marked_var = None |
| def forward(self, a, b): |
| out1 = a*b |
| out2 = out1 * a |
| self.marked_var = out1 |
| return out2 |
| x = mx.np.array([1,2,3,4]) |
| y = mx.np.array([5,6,7,8]) |
| x.attach_grad() |
| y.attach_grad() |
| block2 = CompBlock() |
| block2.initialize() |
| # block2.hybridize() |
| with mx.autograd.record(): |
| z = block2(x, y) |
| u = block2.marked_var |
| u.attach_grad() |
| z.attach_grad() |
| z.backward(retain_graph=True) |
| |
| assert (u.grad == x).all() |
| assert (z.grad == mx.np.array([1,1,1,1])).all() |
| assert (x.grad == 2 * x * y).all() |
| assert (y.grad == x*x).all() |
| |
| u.drop_grad() |
| z.drop_grad() |
| y.drop_grad() |
| z.backward() |
| |
| assert u.grad is None and z.grad is None and y.grad is None |
| assert (x.grad == 2 * x * y).all() |