| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import mxnet as mx |
| from mxnet import gluon |
| from mxnet.gluon import nn |
| from mxnet.test_utils import assert_almost_equal |
| from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID |
| from common import setup_module, with_seed, assertRaises, teardown, assert_raises_cudnn_disabled |
| import numpy as np |
| from numpy.testing import assert_array_equal |
| from nose.tools import raises, assert_raises |
| from copy import deepcopy |
| import warnings |
| import json |
| import unittest |
| |
| @with_seed() |
| def test_parameter(): |
| p = gluon.Parameter('weight', shape=(10, 10)) |
| p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)]) |
| assert len(p.list_data()) == 2 |
| assert len(p.list_grad()) == 2 |
| assert p.data(mx.cpu(1)).context == mx.cpu(1) |
| assert p.data(mx.cpu(0)).shape == (10, 10) |
| assert p.var().name == 'weight' |
| assert p.grad(mx.cpu(0)).stype == 'default' |
| assert p.data(mx.cpu(0)).stype == 'default' |
| |
| p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)]) |
| assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)] |
| |
| @with_seed() |
| @raises(AssertionError) |
| def test_invalid_parameter_stype(): |
| p = gluon.Parameter('weight', shape=(10, 10), stype='invalid') |
| |
| @with_seed() |
| @raises(AssertionError) |
| def test_invalid_parameter_grad_stype(): |
| p = gluon.Parameter('weight', shape=(10, 10), grad_stype='invalid') |
| |
| @with_seed() |
| def test_sparse_parameter(): |
| p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse') |
| p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)]) |
| row_id = mx.nd.arange(0, 10, ctx=mx.cpu(1)) |
| assert len(p.list_grad()) == 2 |
    # getting row_sparse data without a trainer raises an exception
| assertRaises(RuntimeError, p.list_row_sparse_data, row_id) |
| trainer = mx.gluon.Trainer([p], 'sgd') |
| assert len(p.list_row_sparse_data(row_id)) == 2 |
| weight = p.row_sparse_data(row_id) |
| assert weight.context == mx.cpu(1) |
| assert weight.shape == (10, 10) |
| assert weight.stype == 'row_sparse' |
| assert p.var().name == 'weight' |
| assert p.var().attr('__storage_type__') == str(_STORAGE_TYPE_STR_TO_ID['row_sparse']) |
| assert p.grad(mx.cpu(0)).stype == 'row_sparse' |
| |
| p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)]) |
| assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)] |
| |
| @with_seed() |
| def test_parameter_invalid_access(): |
| # cannot call data on row_sparse parameters |
| p0 = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse') |
| p0.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)]) |
| assertRaises(RuntimeError, p0.data) |
| assertRaises(RuntimeError, p0.list_data) |
| row_id = mx.nd.arange(0, 10) |
| # cannot call row_sparse_data on dense parameters |
| p1 = gluon.Parameter('weight', shape=(10, 10)) |
| p1.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)]) |
| assertRaises(RuntimeError, p1.row_sparse_data, row_id.copyto(mx.cpu(0))) |
| assertRaises(RuntimeError, p1.list_row_sparse_data, row_id) |
| |
| @with_seed() |
| def test_paramdict(): |
| ctx = mx.cpu(1) |
| params0 = gluon.ParameterDict('net_') |
| params0.get('w0', shape=(10, 10)) |
| params0.get('w1', shape=(10, 10), stype='row_sparse') |
| all_row_ids = mx.nd.arange(0, 10, ctx=ctx) |
| # check param names |
| assert list(params0.keys()) == ['net_w0', 'net_w1'] |
| params0.initialize(ctx=ctx) |
| trainer0 = mx.gluon.Trainer(params0, 'sgd') |
| prev_w0 = params0.get('w0').data(ctx) |
| prev_w1 = params0.get('w1').row_sparse_data(all_row_ids) |
| # save params |
| params0.save('test_paramdict.params') |
| |
| # load params |
| params1 = gluon.ParameterDict('net_') |
| params1.get('w0', shape=(10, 10)) |
| params1.get('w1', shape=(10, 10), stype='row_sparse') |
| params1.load('test_paramdict.params', ctx) |
| trainer1 = mx.gluon.Trainer(params1, 'sgd') |
| |
| # compare the values before and after save/load |
| cur_w0 = params1.get('w0').data(ctx) |
| cur_w1 = params1.get('w1').row_sparse_data(all_row_ids) |
| mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) |
| mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) |
| |
| # create a new param dict with dense params, and load from the checkpoint |
| # of sparse & dense params |
| params2 = gluon.ParameterDict('net_') |
| params2.get('w0', shape=(10, 10)) |
| params2.get('w1', shape=(10, 10)) |
| params2.load('test_paramdict.params', ctx) |
| |
| # compare the values before and after save/load |
| cur_w0 = params2.get('w0').data(ctx) |
| cur_w1 = params2.get('w1').data(ctx) |
| mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) |
| mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) |
| |
| |
| @with_seed() |
| def test_parameter_row_sparse_data(): |
| ctx0 = mx.cpu(1) |
| ctx1 = mx.cpu(2) |
| dim0 = 4 |
| x = gluon.Parameter('x', shape=(dim0, 2), stype='row_sparse') |
| x.initialize(init='xavier', ctx=[ctx0, ctx1]) |
| trainer = gluon.Trainer([x], 'sgd') |
| x_param = x._data[0].copy() |
| assert x_param.stype == 'row_sparse' |
| row_id_0 = mx.nd.array([0,1], ctx=ctx0) |
| retained_0 = x.row_sparse_data(row_id_0) |
| retained_target_0 = mx.nd.sparse.retain(x_param, row_id_0.as_in_context(ctx0)) |
| mx.test_utils.assert_almost_equal(retained_0.asnumpy(), retained_target_0.asnumpy()) |
| assert retained_0.context == ctx0 |
| row_id_1 = mx.nd.arange(0, dim0, ctx=ctx1) |
| retained_1 = x.row_sparse_data(row_id_1) |
| retained_target_1 = x_param |
| mx.test_utils.assert_almost_equal(retained_1.asnumpy(), retained_target_1.asnumpy()) |
| assert retained_1.context == ctx1 |
| row_id_2 = mx.nd.array([0,1,2]) |
| retained_2 = x.list_row_sparse_data(row_id_2) |
| retained_target_2 = mx.nd.sparse.retain(x_param, row_id_2.as_in_context(ctx0)) |
| mx.test_utils.assert_almost_equal(retained_2[0].asnumpy(), retained_target_2.asnumpy()) |
| |
| |
| @with_seed() |
| def test_constant(): |
| class Test(gluon.HybridBlock): |
| def __init__(self, **kwargs): |
| super(Test, self).__init__(**kwargs) |
| self.value = np.asarray([[1,2], [3,4]]) |
| self.const = self.params.get_constant('const', self.value) |
| |
| def hybrid_forward(self, F, x, const): |
| return x + const |
| |
| test = Test() |
| test.initialize() |
| trainer = gluon.Trainer(test.collect_params(), 'sgd', |
| {'learning_rate': 1.0, 'momentum': 0.5}) |
| |
| with mx.autograd.record(): |
| x = mx.nd.ones((2,2)) |
| x.attach_grad() |
| y = test(x) |
| y.backward() |
| |
| trainer.step(1) |
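    # constants use grad_req 'null', so the update step must leave the value
    # unchanged, while d(x + const)/dx is all ones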
| |
| assert (test.const.data().asnumpy() == test.value).all() |
| assert (x.grad.asnumpy() == 1).all() |
| |
| |
| @with_seed() |
| def test_parameter_sharing(): |
| class Net(gluon.Block): |
| def __init__(self, in_units=0, **kwargs): |
| super(Net, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.dense0 = nn.Dense(5, in_units=in_units) |
| self.dense1 = nn.Dense(5, in_units=in_units) |
| |
| def forward(self, x): |
| return self.dense1(self.dense0(x)) |
| |
| net1 = Net(prefix='net1_', in_units=5) |
| net2 = Net(prefix='net2_', params=net1.collect_params()) |
| net1.collect_params().initialize() |
| net2(mx.nd.zeros((3, 5))) |
| |
| net1.save_parameters('net1.params') |
| |
| net3 = Net(prefix='net3_') |
| net3.load_parameters('net1.params', mx.cpu()) |
| |
| net4 = Net(prefix='net4_') |
| net5 = Net(prefix='net5_', in_units=5, params=net4.collect_params()) |
| net4.collect_params().initialize() |
| net5(mx.nd.zeros((3, 5))) |
| |
| net4.save_parameters('net4.params') |
| |
| net6 = Net(prefix='net6_') |
| net6.load_parameters('net4.params', mx.cpu()) |
| |
| |
| @with_seed() |
| def test_parameter_str(): |
| class Net(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Net, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.dense0 = nn.Dense(10, in_units=5, use_bias=False) |
| |
| net = Net(prefix='net1_') |
| lines = str(net.collect_params()).splitlines() |
| |
| assert lines[0] == 'net1_ (' |
| assert 'net1_dense0_weight' in lines[1] |
| assert '(10, 5)' in lines[1] |
| assert 'float32' in lines[1] |
| assert lines[2] == ')' |
| |
| |
| @with_seed() |
def test_collect_parameters():
| net = nn.HybridSequential(prefix="test_") |
| with net.name_scope(): |
| net.add(nn.Conv2D(10, 3)) |
| net.add(nn.Dense(10, activation='relu')) |
| assert set(net.collect_params().keys()) == \ |
| set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias']) |
| assert set(net.collect_params('.*weight').keys()) == \ |
| set(['test_conv0_weight', 'test_dense0_weight']) |
| assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \ |
| set(['test_conv0_bias', 'test_dense0_bias']) |
| |
| @with_seed() |
| def test_basic(): |
| model = nn.Sequential() |
| model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False)) |
| model.add(nn.Dropout(0.5)) |
| model.add(nn.Dense(64, activation='tanh', in_units=256), |
| nn.Dense(32, in_units=64)) |
| model.add(nn.Activation('relu')) |
| |
| # symbol |
| x = mx.sym.var('data') |
| y = model(x) |
| assert len(y.list_arguments()) == 7 |
| |
| # ndarray |
| model.collect_params().initialize(mx.init.Xavier(magnitude=2.24)) |
| x = model(mx.nd.zeros((32, 2, 10))) |
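    # with flatten=False the first Dense keeps the leading axes: (32, 2, 10) -> (32, 2, 128);
    # the next Dense flattens to 2*128 = 256 inputs -> (32, 64) -> (32, 32)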
| assert x.shape == (32, 32) |
| x.wait_to_read() |
| |
| model.collect_params().setattr('grad_req', 'null') |
| assert list(model.collect_params().values())[0]._grad is None |
| model.collect_params().setattr('grad_req', 'write') |
| assert list(model.collect_params().values())[0]._grad is not None |
| |
| |
| @with_seed() |
| def test_dense(): |
| model = nn.Dense(128, activation='tanh', in_units=10, flatten=False, prefix='test_') |
| inputs = mx.sym.Variable('data') |
| outputs = model(inputs) |
| assert set(model.collect_params().keys()) == set(['test_weight', 'test_bias']) |
| assert outputs.list_outputs() == ['test_tanh_fwd_output'] |
| args, outs, auxs = outputs.infer_shape(data=(2, 3, 10)) |
| assert outs == [(2, 3, 128)] |
| |
| model = nn.Dense(128, activation='relu', in_units=30, flatten=True, prefix='test2_') |
| inputs = mx.sym.Variable('data') |
| outputs = model(inputs) |
| assert set(model.collect_params().keys()) == set(['test2_weight', 'test2_bias']) |
| assert outputs.list_outputs() == ['test2_relu_fwd_output'] |
| args, outs, auxs = outputs.infer_shape(data=(17, 2, 5, 3)) |
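    # flatten=True collapses all non-batch axes, so 2*5*3 = 30 inputs map to (17, 128);
    # the flatten=False case above transformed only the last axis instead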
| assert outs == [(17, 128)] |
| |
| |
| @with_seed() |
| def test_symbol_block(): |
| model = nn.HybridSequential() |
| model.add(nn.Dense(128, activation='tanh')) |
| model.add(nn.Dropout(0.5)) |
| model.add(nn.Dense(64, activation='tanh'), |
| nn.Dense(32, in_units=64)) |
| model.add(nn.Activation('relu')) |
| |
| model.initialize() |
| |
| inputs = mx.sym.var('data') |
| outputs = model(inputs).get_internals() |
| |
| smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params()) |
| |
| assert len(smodel(mx.nd.zeros((16, 10)))) == 14 |
| |
| out = smodel(mx.sym.var('in')) |
| assert len(out) == len(outputs.list_outputs()) |
| |
| class Net(nn.HybridBlock): |
| def __init__(self, model): |
| super(Net, self).__init__() |
| self.model = model |
| |
| def hybrid_forward(self, F, x): |
| out = self.model(x) |
| return F.add_n(*[i.sum() for i in out]) |
| |
| net = Net(smodel) |
| net.hybridize() |
| assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray) |
| |
| inputs = mx.sym.var('data') |
| outputs = model(inputs) |
| smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params()) |
| net = Net(smodel) |
| net.hybridize() |
| assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray) |
| |
| @with_seed() |
| @raises(AssertionError) |
| def test_sparse_symbol_block(): |
| data = mx.sym.var('data') |
| weight = mx.sym.var('weight', stype='row_sparse') |
| bias = mx.sym.var('bias') |
| out = mx.sym.broadcast_add(mx.sym.dot(data, weight), bias) |
    # an exception is expected when creating a SymbolBlock w/ sparse param
| net = gluon.SymbolBlock(out, data) |
| |
| @with_seed() |
| @raises(RuntimeError) |
| def test_sparse_hybrid_block(): |
| params = gluon.ParameterDict('net_') |
| params.get('weight', shape=(5,5), stype='row_sparse', dtype='float32') |
| params.get('bias', shape=(5), dtype='float32') |
| net = gluon.nn.Dense(5, params=params) |
| net.initialize() |
| x = mx.nd.ones((2,5)) |
| # an exception is expected when forwarding a HybridBlock w/ sparse param |
| y = net(x) |
| |
| @with_seed() |
| def check_layer_forward(layer, dshape): |
| print("checking layer {}\nshape: {}.".format(layer, dshape)) |
| layer.collect_params().initialize() |
| x = mx.nd.ones(shape=dshape) |
| x.attach_grad() |
| with mx.autograd.record(): |
| out = layer(x) |
| out.backward() |
| |
| np_out = out.asnumpy() |
| np_dx = x.grad.asnumpy() |
| |
| layer.hybridize() |
| |
| x = mx.nd.ones(shape=dshape) |
| x.attach_grad() |
| with mx.autograd.record(): |
| out = layer(x) |
| out.backward() |
| |
| mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-5, atol=1e-6) |
| mx.test_utils.assert_almost_equal(np_dx, x.grad.asnumpy(), rtol=1e-5, atol=1e-6) |
| |
| @unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11506") |
| @with_seed() |
| def test_conv(): |
| layers1d = [ |
| nn.Conv1D(16, 3, in_channels=4), |
| nn.Conv1D(16, 3, groups=2, in_channels=4), |
| nn.Conv1D(16, 3, strides=3, groups=2, in_channels=4), |
| ] |
| for layer in layers1d: |
| check_layer_forward(layer, (1, 4, 10)) |
| |
| |
| layers2d = [ |
| nn.Conv2D(16, (3, 4), in_channels=4), |
| nn.Conv2D(16, (5, 4), in_channels=4), |
| nn.Conv2D(16, (3, 4), groups=2, in_channels=4), |
| nn.Conv2D(16, (3, 4), strides=4, in_channels=4), |
| nn.Conv2D(16, (3, 4), dilation=4, in_channels=4), |
| nn.Conv2D(16, (3, 4), padding=4, in_channels=4), |
| ] |
| for layer in layers2d: |
| check_layer_forward(layer, (1, 4, 20, 20)) |
| |
| |
| layers3d = [ |
| nn.Conv3D(16, (1, 8, 4), in_channels=4, activation='relu'), |
| nn.Conv3D(16, (5, 4, 3), in_channels=4), |
| nn.Conv3D(16, (3, 3, 3), groups=2, in_channels=4), |
| nn.Conv3D(16, 4, strides=4, in_channels=4), |
| nn.Conv3D(16, (3, 3, 3), padding=4, in_channels=4), |
| ] |
| for layer in layers3d: |
| check_layer_forward(layer, (1, 4, 10, 10, 10)) |
| |
| |
| layer = nn.Conv2D(16, (3, 3), layout='NHWC', in_channels=4) |
| # check_layer_forward(layer, (1, 10, 10, 4)) |
| |
| layer = nn.Conv3D(16, (3, 3, 3), layout='NDHWC', in_channels=4) |
| # check_layer_forward(layer, (1, 10, 10, 10, 4)) |
| |
| |
| @with_seed() |
| def test_deconv(): |
| # layers1d = [ |
| # nn.Conv1DTranspose(16, 3, in_channels=4), |
| # nn.Conv1DTranspose(16, 3, groups=2, in_channels=4), |
| # nn.Conv1DTranspose(16, 3, strides=3, groups=2, in_channels=4), |
| # ] |
| # for layer in layers1d: |
| # check_layer_forward(layer, (1, 4, 10)) |
| |
| |
| layers2d = [ |
| nn.Conv2DTranspose(16, (3, 4), in_channels=4), |
| nn.Conv2DTranspose(16, (5, 4), in_channels=4), |
| nn.Conv2DTranspose(16, (3, 4), groups=2, in_channels=4), |
| nn.Conv2DTranspose(16, (3, 4), strides=4, in_channels=4), |
| nn.Conv2DTranspose(16, (3, 4), dilation=4, in_channels=4), |
| # nn.Conv2DTranspose(16, (3, 4), padding=4, in_channels=4), |
| nn.Conv2DTranspose(16, (3, 4), strides=4, output_padding=3, in_channels=4), |
| ] |
| for layer in layers2d: |
| check_layer_forward(layer, (1, 4, 20, 20)) |
| |
| |
| # layers3d = [ |
| # nn.Conv3DTranspose(16, (1, 8, 4), in_channels=4), |
| # nn.Conv3DTranspose(16, (5, 4, 3), in_channels=4), |
| # nn.Conv3DTranspose(16, (3, 3, 3), groups=2, in_channels=4), |
| # nn.Conv3DTranspose(16, 4, strides=4, in_channels=4), |
| # nn.Conv3DTranspose(16, (3, 3, 3), padding=4, in_channels=4), |
| # ] |
| # for layer in layers3d: |
| # check_layer_forward(layer, (1, 4, 10, 10, 10)) |
| # |
| # |
| # layer = nn.Conv2DTranspose(16, (3, 3), layout='NHWC', in_channels=4) |
| # # check_layer_forward(layer, (1, 10, 10, 4)) |
| # |
| # layer = nn.Conv3DTranspose(16, (3, 3, 3), layout='NDHWC', in_channels=4) |
| # # check_layer_forward(layer, (1, 10, 10, 10, 4)) |
| |
| |
| @with_seed() |
| def test_pool(): |
| layers1d = [ |
| nn.MaxPool1D(), |
| nn.MaxPool1D(3), |
| nn.MaxPool1D(3, 2), |
| nn.AvgPool1D(), |
| nn.AvgPool1D(count_include_pad=False), |
| nn.GlobalAvgPool1D(), |
| ] |
| for layer in layers1d: |
| check_layer_forward(layer, (1, 2, 10)) |
| |
| |
| layers2d = [ |
| nn.MaxPool2D(), |
| nn.MaxPool2D((3, 3)), |
| nn.MaxPool2D(3, 2), |
| nn.AvgPool2D(), |
| nn.AvgPool2D(count_include_pad=False), |
| nn.GlobalAvgPool2D(), |
| ] |
| for layer in layers2d: |
| check_layer_forward(layer, (1, 2, 10, 10)) |
| |
| layers3d = [ |
| nn.MaxPool3D(), |
| nn.MaxPool3D((3, 3, 3)), |
| nn.MaxPool3D(3, 2), |
| nn.AvgPool3D(), |
| nn.AvgPool3D(count_include_pad=False), |
| nn.GlobalAvgPool3D(), |
| ] |
| for layer in layers3d: |
| check_layer_forward(layer, (1, 2, 10, 10, 10)) |
| |
| # test ceil_mode |
| x = mx.nd.zeros((2, 2, 10, 10)) |
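    # pool_size 3 with the default stride (equal to pool_size) gives
    # floor((10 - 3)/3) + 1 = 3 outputs per spatial dim, or 4 with ceil_mode=True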
| |
| layer = nn.MaxPool2D(3, ceil_mode=False) |
| layer.collect_params().initialize() |
| assert (layer(x).shape==(2, 2, 3, 3)) |
| |
| layer = nn.MaxPool2D(3, ceil_mode=True) |
| layer.collect_params().initialize() |
| assert (layer(x).shape==(2, 2, 4, 4)) |
| |
| |
| @with_seed() |
| def test_batchnorm(): |
| layer = nn.BatchNorm(in_channels=10) |
| check_layer_forward(layer, (2, 10, 10, 10)) |
| |
| |
| @with_seed() |
| def test_instancenorm(): |
| layer = nn.InstanceNorm(in_channels=10) |
| check_layer_forward(layer, (2, 10, 10, 10)) |
| |
| @with_seed() |
| def test_layernorm(): |
| layer = nn.LayerNorm(in_channels=10) |
| check_layer_forward(layer, (2, 10, 10, 10)) |
| |
| |
| @with_seed() |
| def test_reflectionpad(): |
| layer = nn.ReflectionPad2D(3) |
| check_layer_forward(layer, (2, 3, 24, 24)) |
| |
| |
| @with_seed() |
| def test_reshape(): |
| x = mx.nd.ones((2, 4, 10, 10)) |
| layer = nn.Conv2D(10, 2, in_channels=4) |
| layer.collect_params().initialize() |
| with mx.autograd.record(): |
| x = layer(x) |
| x = x.reshape((-1,)) |
| x = x + 10 |
| x.backward() |
| |
| |
| @with_seed() |
| def test_slice(): |
| x = mx.nd.ones((5, 4, 10, 10)) |
| layer = nn.Conv2D(10, 2, in_channels=4) |
| layer.collect_params().initialize() |
| with mx.autograd.record(): |
| x = layer(x) |
| x = x[1:3] |
| x = x + 10 |
| x.backward() |
| |
| |
| @with_seed() |
| def test_at(): |
| x = mx.nd.ones((5, 4, 10, 10)) |
| layer = nn.Conv2D(10, 2, in_channels=4) |
| layer.collect_params().initialize() |
| with mx.autograd.record(): |
| x = layer(x) |
| x = x[1] |
| x = x + 10 |
| x.backward() |
| |
| |
| @with_seed() |
| def test_deferred_init(): |
| x = mx.nd.ones((5, 4, 10, 10)) |
| layer = nn.Conv2D(10, 2) |
| layer.collect_params().initialize() |
| layer(x) |
| |
| |
| def check_split_data(x, num_slice, batch_axis, **kwargs): |
| res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs) |
| assert len(res) == num_slice |
| mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(), |
| x.asnumpy()) |
| |
| |
| @with_seed() |
| def test_split_data(): |
| x = mx.nd.random.uniform(shape=(128, 33, 64)) |
| |
| check_split_data(x, 8, 0) |
| check_split_data(x, 3, 1) |
| check_split_data(x, 4, 1, even_split=False) |
| check_split_data(x, 15, 1, even_split=False) |
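    # 33 is not divisible by 4, so the default even_split=True should raise ValueError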
| try: |
| check_split_data(x, 4, 1) |
| except ValueError: |
| return |
| assert False, "Should have failed" |
| |
| |
| @with_seed() |
| def test_flatten(): |
| flatten = nn.Flatten() |
| x = mx.nd.zeros((3,4,5,6)) |
| assert flatten(x).shape == (3, 4*5*6) |
| x = mx.nd.zeros((3,6)) |
| assert flatten(x).shape == (3, 6) |
| x = mx.nd.zeros((3,)) |
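    # Flatten reshapes to (batch, -1), so a 1-D input comes out as (3, 1)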
| assert flatten(x).shape == (3, 1) |
| |
| @with_seed() |
| def test_block_attr_hidden(): |
| b = gluon.Block() |
| |
| # regular attributes can change types |
| b.a = None |
| b.a = 1 |
| |
| |
| @raises(TypeError) |
| @with_seed() |
| def test_block_attr_block(): |
| b = gluon.Block() |
| |
    # Block attributes can't be re-assigned to a different type
| b.b = gluon.Block() |
| b.b = (2,) |
| |
| |
| @raises(TypeError) |
| @with_seed() |
| def test_block_attr_param(): |
| b = gluon.Block() |
| |
    # Parameter attributes can't be re-assigned to a different type
    b.b = gluon.Parameter('b', shape=(2,))
| b.b = (2,) |
| |
| |
| @with_seed() |
| def test_block_attr_regular(): |
| b = gluon.Block() |
| |
    # setting a Block attribute also registers it in _children
| b.c = gluon.Block() |
| c2 = gluon.Block() |
| b.c = c2 |
| assert b.c is c2 and list(b._children.values())[0] is c2 |
| |
| |
| @with_seed() |
| def test_block_attr_list_of_block(): |
| class Model1(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Model1, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.layers = [nn.Dense(i * 10) for i in range(6)] |
| |
| class Model2(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Model2, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.layers = dict() |
| self.layers['a'] = [nn.Dense(10), nn.Dense(10)] |
| |
| class Model3(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Model3, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.layers = nn.Sequential() |
| self.layers.add(*[nn.Dense(i * 10) for i in range(6)]) |
| |
| class Model4(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Model4, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.data = {'a': '4', 'b': 123} |
| |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter('always') |
| model = Model1() |
| model.collect_params() |
| assert len(w) > 0 |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter('always') |
| model = Model2() |
| model.collect_params() |
| assert len(w) > 0 |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter('always') |
| model = Model3() |
| model.collect_params() |
| assert len(w) == 0 |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter('always') |
| model = Model4() |
| model.collect_params() |
| assert len(w) == 0 |
| |
| def check_sequential(net): |
| dense1 = gluon.nn.Dense(10) |
| net.add(dense1) |
| dense2 = gluon.nn.Dense(10) |
| net.add(dense2) |
| dense3 = gluon.nn.Dense(10) |
| net.add(dense3) |
| |
| assert net[1] is dense2 |
| assert net[-1] is dense3 |
| slc = net[1:3] |
| assert len(slc) == 2 and slc[0] is dense2 and slc[1] is dense3 |
| assert isinstance(slc, type(net)) |
| |
| @with_seed() |
| def test_sequential(): |
| check_sequential(gluon.nn.Sequential()) |
| check_sequential(gluon.nn.HybridSequential()) |
| |
| @with_seed() |
| def test_sequential_warning(): |
| with warnings.catch_warnings(record=True) as w: |
| # The following line permits the test to pass if run multiple times |
| warnings.simplefilter('always') |
| b = gluon.nn.Sequential() |
| b.add(gluon.nn.Dense(20)) |
| b.hybridize() |
| assert len(w) == 1 |
| |
| |
| @with_seed() |
| def test_global_norm_clip(): |
| stypes = ['default', 'row_sparse'] |
| def check_global_norm_clip(stype): |
| x1 = mx.nd.ones((3,3)).tostype(stype) |
| x2 = mx.nd.ones((4,4)).tostype(stype) |
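        # global norm = sqrt(3*3 + 4*4) = 5, so clipping to max_norm 1.0 scales both arrays by 1/5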
| norm = gluon.utils.clip_global_norm([x1, x2], 1.0) |
| assert norm == 5.0 |
| assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) |
| assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) |
| |
| x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype) |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter("always") |
| gluon.utils.clip_global_norm([x1, x3], 2.0) |
| assert len(w) == 1 |
| |
| for stype in stypes: |
| check_global_norm_clip(stype) |
| |
| @with_seed() |
| def test_embedding(): |
| def check_embedding(sparse_grad): |
| layer = gluon.nn.Embedding(10, 100, sparse_grad=sparse_grad) |
| layer.initialize() |
| x = mx.nd.array([3,4,2,0,1]) |
| with mx.autograd.record(): |
| y = layer(x) |
| y.backward() |
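        # indices 0-4 are each looked up exactly once, so their gradient rows are all
        # ones while rows 5-9 stay zero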
| assert (layer.weight.grad().asnumpy()[:5] == 1).all() |
| assert (layer.weight.grad().asnumpy()[5:] == 0).all() |
| |
| def check_embedding_large_input(sparse_grad): |
        embedding = mx.gluon.nn.Embedding(10, 1, sparse_grad=sparse_grad)
        embedding.initialize()
        embedding.hybridize()
        shape = (20481,)
        with mx.autograd.record():
            emb_in = embedding(mx.nd.ones(shape))
            loss = emb_in.sum()
        loss.backward()
        # index 1 is looked up 20481 times, so its gradient row sums to 20481;
        # the dense view works for both the sparse and the dense gradient
        assert embedding.weight.grad().asnumpy().sum() == 20481
| |
| check_embedding(True) |
| check_embedding(False) |
| check_embedding_large_input(True) |
| check_embedding_large_input(False) |
| |
| @unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11616") |
| @with_seed() |
| def test_export(): |
| ctx = mx.context.current_context() |
| model = gluon.model_zoo.vision.resnet18_v1( |
| prefix='resnet', ctx=ctx, pretrained=True) |
| model.hybridize() |
| data = mx.nd.random.normal(shape=(1, 3, 32, 32)) |
| out = model(data) |
| |
| model.export('gluon') |
| |
| module = mx.mod.Module.load('gluon', 0, label_names=None, context=ctx) |
| module.bind(data_shapes=[('data', data.shape)]) |
| module.forward(mx.io.DataBatch([data], None), is_train=False) |
| mod_out, = module.get_outputs() |
| |
| assert_almost_equal(out.asnumpy(), mod_out.asnumpy()) |
| |
| model2 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', ctx=ctx) |
| model2.collect_params().load('gluon-0000.params', ctx) |
| out2 = model2(data) |
| |
| assert_almost_equal(out.asnumpy(), out2.asnumpy()) |
| |
| @with_seed() |
| def test_import(): |
| ctx = mx.context.current_context() |
| net1 = gluon.model_zoo.vision.resnet18_v1( |
| prefix='resnet', ctx=ctx, pretrained=True) |
| net1.hybridize() |
| data = mx.nd.random.normal(shape=(1, 3, 32, 32)) |
| out1 = net1(data) |
| |
| net1.export('net1', epoch=1) |
| |
| net2 = gluon.SymbolBlock.imports( |
| 'net1-symbol.json', ['data'], 'net1-0001.params', ctx) |
| out2 = net2(data) |
| |
| assert_almost_equal(out1.asnumpy(), out2.asnumpy()) |
| |
| @with_seed() |
| def test_hybrid_stale_cache(): |
| net = mx.gluon.nn.HybridSequential() |
| with net.name_scope(): |
| net.add(mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False)) |
| |
| net.hybridize() |
| net.initialize() |
| net(mx.nd.ones((2,3,5))) |
| |
| net.add(mx.gluon.nn.Flatten()) |
| assert net(mx.nd.ones((2,3,5))).shape == (2, 30) |
| |
| net = mx.gluon.nn.HybridSequential() |
| with net.name_scope(): |
| net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', |
| bias_initializer='ones', flatten=False) |
| net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', |
| bias_initializer='ones', flatten=False) |
| net.hybridize() |
| net.initialize() |
| net(mx.nd.ones((2,3,5))) |
| |
| net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', |
| bias_initializer='ones', flatten=True) |
| net.initialize() |
| assert net(mx.nd.ones((2,3,5))).shape == (2, 10) |
| |
| |
| @with_seed() |
| def test_lambda(): |
| net1 = mx.gluon.nn.HybridSequential() |
| net1.add(nn.Activation('tanh'), |
| nn.LeakyReLU(0.1)) |
| |
| net2 = mx.gluon.nn.HybridSequential() |
| op3 = lambda F, x, *args: F.LeakyReLU(x, *args, slope=0.1) |
| net2.add(nn.HybridLambda('tanh'), |
| nn.HybridLambda(op3)) |
| |
| op4 = lambda x: mx.nd.LeakyReLU(x, slope=0.1) |
| net3 = mx.gluon.nn.Sequential() |
| net3.add(nn.Lambda('tanh'), |
| nn.Lambda(op4)) |
| |
| input_data = mx.nd.random.uniform(shape=(2, 3, 5, 7)) |
| out1, out2, out3 = net1(input_data), net2(input_data), net3(input_data) |
| assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3, atol=1e-3) |
| assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| |
| @with_seed() |
| def test_fill_shape_deferred(): |
| net = nn.HybridSequential() |
| with net.name_scope(): |
| net.add(nn.Conv2D(64, kernel_size=2, padding=1), |
| nn.BatchNorm(), |
| nn.Dense(10)) |
| net.hybridize() |
| net.initialize() |
| net(mx.nd.ones((2,3,5,7))) |
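    # deferred shapes are now known: the conv sees 3 input channels, BatchNorm 64
    # channels, and the Dense sees 64 * 6 * 8 = 3072 inputs
    # (kernel 2 with padding 1 on a 5x7 input gives a 6x8 feature map)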
| assert net[0].weight.shape[1] == 3, net[0].weight.shape[1] |
| assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0] |
| assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1] |
| |
| |
| @with_seed() |
| def test_dtype(): |
| net = mx.gluon.model_zoo.vision.resnet18_v1() |
| net.initialize() |
| net.cast('float64') |
| with mx.autograd.record(): |
| y = net(mx.nd.ones((16, 3, 32, 32), dtype='float64')) |
| y.backward() |
| |
| net = mx.gluon.model_zoo.vision.resnet18_v1() |
| net.initialize() |
| net.hybridize() |
| net(mx.nd.ones((16, 3, 32, 32), dtype='float32')) |
| |
| net.cast('float64') |
| net(mx.nd.ones((16, 3, 32, 32), dtype='float64')) |
| |
| mx.nd.waitall() |
| |
| class Net(gluon.Block): |
| def __init__(self, in_dim, output_dim): |
| super(Net, self).__init__() |
| with self.name_scope(): |
| self.embed = gluon.nn.Embedding(input_dim=in_dim, output_dim=output_dim,dtype=np.float64) |
| self.dense = gluon.nn.Dense(2, dtype=np.float64) |
| |
| def forward(self, x): |
| e = self.embed(x) |
| assert(e.dtype == np.float64) |
| y = self.dense(e) |
| assert(y.dtype == np.float64) |
| return y |
| |
| net = Net(5, 10) |
| net.initialize() |
| out = net(mx.nd.ones((3,), dtype=np.float64)) |
| mx.nd.waitall() |
| |
| @with_seed() |
| def test_fill_shape_load(): |
| ctx = mx.context.current_context() |
| net1 = nn.HybridSequential() |
| with net1.name_scope(): |
| net1.add(nn.Conv2D(64, kernel_size=2, padding=1), |
| nn.BatchNorm(), |
| nn.Dense(10)) |
| net1.hybridize() |
| net1.initialize(ctx=ctx) |
| net1(mx.nd.ones((2,3,5,7), ctx)) |
| net1.save_parameters('net_fill.params') |
| |
| net2 = nn.HybridSequential() |
| with net2.name_scope(): |
| net2.add(nn.Conv2D(64, kernel_size=2, padding=1), |
| nn.BatchNorm(), |
| nn.Dense(10)) |
| net2.hybridize() |
| net2.initialize() |
| net2.load_parameters('net_fill.params', ctx) |
| assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1] |
| assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0] |
| assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1] |
| |
| |
| @with_seed() |
| def test_inline(): |
| net = mx.gluon.nn.HybridSequential() |
| with net.name_scope(): |
| net.add(mx.gluon.nn.Dense(10)) |
| net.add(mx.gluon.nn.Dense(10)) |
| net.add(mx.gluon.nn.Dense(10)) |
| |
| net.initialize() |
| net.hybridize(inline_limit=3) |
| with mx.autograd.record(): |
| y = net(mx.nd.zeros((1,10))) |
| |
| len_1 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes']) |
| y.backward() |
| |
| net.hybridize(inline_limit=0) |
| with mx.autograd.record(): |
| y = net(mx.nd.zeros((1,10))) |
| |
| len_2 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes']) |
| y.backward() |
| |
| assert len_1 == len_2 + 2 |
| |
| |
| @with_seed() |
| def test_activations(): |
| point_to_validate = mx.nd.array([-0.1, 0.1] * 3) |
| |
| swish = mx.gluon.nn.Swish() |
| def swish_test(x): |
| return x * mx.nd.sigmoid(x) |
| |
| for test_point, ref_point in zip(swish_test(point_to_validate), swish(point_to_validate)): |
| assert test_point == ref_point |
| |
| elu = mx.gluon.nn.ELU() |
| def elu_test(x): |
| def elu(x): |
| return 1.0 * (mx.nd.exp(x) - 1) if x < 0 else x |
| return [elu(x_i) for x_i in x] |
| |
| for test_point, ref_point in zip(elu_test(point_to_validate), elu(point_to_validate)): |
| assert test_point == ref_point |
| |
    selu = mx.gluon.nn.SELU()
    def selu_test(x):
        def selu(x):
            scale, alpha = 1.0507009873554804934193349852946, 1.6732632423543772848170429916717
            return scale * x if x >= 0 else scale * alpha * (mx.nd.exp(x) - 1)
        return [selu(x_i) for x_i in x]

    for test_point, ref_point in zip(selu_test(point_to_validate), selu(point_to_validate)):
        assert_almost_equal(test_point.asnumpy(), ref_point.asnumpy())
| |
| prelu = mx.gluon.nn.PReLU() |
| prelu.initialize() |
| x = point_to_validate.reshape((1, 3, 2)) |
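    # PReLU's alpha parameter is initialized to a constant 0.25 by default,
    # which the reference below relies on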
| assert_almost_equal(prelu(x).asnumpy(), mx.nd.where(x >= 0, x, 0.25 * x).asnumpy()) |
| |
| @with_seed() |
| def test_dropout(): |
| def get_slice(x, axis, idx): |
| ix = () |
| for i in range(x.ndim): |
| if i == axis: |
| ix += (idx,) |
| else: |
| ix += (slice(None, None, None),) |
| return x[ix] |
| |
| def check_dropout_axes(ratio, shape, axes): |
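        # Dropout with `axes` shares one mask along those axes, so after broadcasting
        # a compact input, every slice taken along a dropped axis should be identical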
| compactshape = list(shape) |
| for axis in axes: |
| compactshape[axis] = 1 |
| compactx = mx.random.uniform(shape=tuple(compactshape)) |
| broadcastx = compactx.broadcast_to(shape) |
| dropouty = mx.gluon.nn.Dropout(rate=ratio, axes=axes)(broadcastx) |
| for axis in axes: |
| target = get_slice(dropouty, axis, 0).asnumpy() |
| for i in range(1, shape[axis]): |
| assert(get_slice(dropouty, axis, i).asnumpy() == target).all() |
| |
| nshape = (10, 10, 10, 10) |
| with mx.autograd.train_mode(): |
| check_dropout_axes(0.25, nshape, axes = (0,)) |
| check_dropout_axes(0.25, nshape, axes = (1,)) |
| check_dropout_axes(0.25, nshape, axes = (2,)) |
| check_dropout_axes(0.25, nshape, axes = (3,)) |
| check_dropout_axes(0.25, nshape, axes = (0, 1)) |
| check_dropout_axes(0.25, nshape, axes = (0, 2)) |
| check_dropout_axes(0.25, nshape, axes = (0, 3)) |
| check_dropout_axes(0.25, nshape, axes = (1, 2)) |
| check_dropout_axes(0.25, nshape, axes = (1, 3)) |
| check_dropout_axes(0.25, nshape, axes = (2, 3)) |
| check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) |
| check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) |
| check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) |
| |
| @with_seed() |
| def test_req(): |
| data = mx.nd.random.uniform(shape=(1,3,224,224)) |
| label = mx.nd.random.uniform(shape=(1)) |
| label[:] = 1 |
| loss = gluon.loss.SoftmaxCrossEntropyLoss() |
| |
| net = nn.HybridSequential() |
| net1 = nn.HybridSequential() |
| net1.add(nn.Dense(4)) |
| net2 = nn.HybridSequential() |
| net2.add(nn.Dense(3)) |
| net2.add(nn.Dense(2)) |
| net.add(net1) |
| net.add(net2) |
| net.initialize() |
| |
| net.hybridize() |
| |
| for v in net.collect_params().values(): |
| v.grad_req = 'add' |
| |
| net.collect_params().zero_grad() |
| with mx.autograd.record(): |
| pred = net(data) |
| l = loss(pred, label) |
| l.backward() |
| grad = net[0][0].weight.grad().mean().asnumpy() |
        # run backward a second time to check that grad_req = 'add' accumulates gradients
| pred = net(data) |
| l = loss(pred, label) |
| l.backward() |
| |
| grad_double = net[0][0].weight.grad().mean().asnumpy() |
| assert_almost_equal(grad * 2, grad_double) |
| |
| |
| @with_seed() |
| def test_save_load(): |
| net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True) |
| net.save_parameters('test_save_load.params') |
| |
| net = mx.gluon.model_zoo.vision.get_resnet(1, 18) |
| net.output = mx.gluon.nn.Dense(1000) |
| |
| net.load_parameters('test_save_load.params') |
| |
| class Network(gluon.Block): |
| def __init__(self, **kwargs): |
| super(Network, self).__init__(**kwargs) |
| with self.name_scope(): |
| self.encoders = gluon.nn.Sequential() |
| with self.encoders.name_scope(): |
| for _ in range(2): |
| lstm = mx.gluon.rnn.LSTM(200, 1, bidirectional=True) |
| self.encoders.add(lstm) |
| |
| def forward(self, x): |
| for i in range(2): |
| x = self.encoders[i](x) |
| return x |
| net = Network() |
| net.initialize(mx.init.Xavier(), ctx=mx.cpu()) |
| net.hybridize() |
| x = np.random.rand(32, 10, 10) |
| x = mx.nd.array(x).as_in_context(mx.cpu()) |
| net(x) |
| net.save_parameters('tmp.params') |
| net2 = Network() |
| net2.load_parameters('tmp.params') |
| |
| @with_seed() |
| def test_symbol_block_save_load(): |
| class Net(gluon.HybridBlock): |
| def __init__(self): |
| super(Net, self).__init__() |
| with self.name_scope(): |
| backbone = gluon.model_zoo.vision.resnet18_v1() |
| data = mx.sym.var('data') |
| featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0'] |
| out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames] |
| internals = backbone(data).get_internals() |
| outs = [internals[out_name] for out_name in out_names] |
| self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params()) |
| self.body = nn.Conv2D(3, 1) |
| |
| def hybrid_forward(self, F, x): |
| x = self.body(x) |
| return self.backbone(x) |
| |
| net1 = Net() |
| net1.initialize(mx.init.Normal()) |
| net1.hybridize() |
| net1(mx.nd.random.normal(shape=(1, 3, 32, 32))) |
| net1.save_parameters('./test_symbol_block_save_load.params') |
| |
| net2 = Net() |
| net2.load_parameters('./test_symbol_block_save_load.params', ctx=mx.cpu()) |
| |
| |
| @with_seed() |
| def test_hybrid_multi_context(): |
| net = mx.gluon.model_zoo.vision.get_resnet(1, 18) |
| net.initialize(ctx=[mx.cpu(0), mx.cpu(1)]) |
| net.hybridize() |
| net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy() |
| |
| @with_seed() |
| def test_zero_grad(): |
| data = mx.nd.random.uniform(shape=(3,3)) |
| net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_') |
| net.initialize() |
| with mx.autograd.record(): |
| l = net(data) |
| l.backward() |
| net.collect_params().zero_grad() |
| grad = net.collect_params()['test_zero_grad_weight'].grad() |
| assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0) |
| |
| def check_hybrid_static_memory(**kwargs): |
| x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) |
| x.attach_grad() |
| |
| net1 = gluon.model_zoo.vision.get_resnet( |
| 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) |
| net2 = gluon.model_zoo.vision.get_resnet( |
| 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) |
| net2.hybridize(**kwargs) |
| net1(x) |
| net2(x) |
| |
| def test(net, x): |
| with mx.autograd.record(): |
| y = net(x) + net(x) |
| y.backward() |
| |
| grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'} |
| |
| return y, grads |
| |
| y1, grads1 = test(net1, x) |
| y2, grads2 = test(net2, x) |
| |
| assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5) |
| for key in grads1: |
| assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5) |
| |
| @with_seed() |
| def test_hybrid_static_memory(): |
| check_hybrid_static_memory() |
| check_hybrid_static_memory(static_alloc=True) |
| check_hybrid_static_memory(static_alloc=True, static_shape=True) |
| |
| def check_hybrid_static_memory_switching(**kwargs): |
| net = gluon.model_zoo.vision.get_resnet( |
| 1, 18, pretrained=True, ctx=mx.context.current_context()) |
| net.hybridize(**kwargs) |
| |
| x = mx.nd.random.uniform(shape=(4, 3, 32, 32)) |
| net(x) |
| with mx.autograd.record(): |
| y = net(x) |
| y.backward() |
| x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) |
| net(x) |
| with mx.autograd.record(): |
| y = net(x) |
| y.backward() |
| mx.nd.waitall() |
| |
| @with_seed() |
| def test_hybrid_static_memory_switching(): |
| check_hybrid_static_memory_switching() |
| check_hybrid_static_memory_switching(static_alloc=True) |
| check_hybrid_static_memory_switching(static_alloc=True, static_shape=True) |
| |
| @with_seed() |
| def test_hook(): |
| global hook_call_count |
| hook_call_count = 0 |
| global pre_hook_call_count |
| pre_hook_call_count = 0 |
| |
| def call_hook(block, x, y): |
| global hook_call_count |
| hook_call_count += 1 |
| |
| def call_pre_hook(block, x): |
| global pre_hook_call_count |
| pre_hook_call_count += 1 |
| |
| block = nn.Dense(10) |
| block.initialize() |
| handle = block.register_forward_hook(call_hook) |
| pre_handle = block.register_forward_pre_hook(call_pre_hook) |
| block(mx.nd.ones((3, 5))) |
| |
| assert hook_call_count == 1 |
| assert pre_hook_call_count == 1 |
| |
| handle.detach() |
| block(mx.nd.ones((3, 5))) |
| |
| assert hook_call_count == 1 |
| assert pre_hook_call_count == 2 |
| |
| pre_handle.detach() |
| block(mx.nd.ones((3, 5))) |
| assert hook_call_count == 1 |
| assert pre_hook_call_count == 2 |
| |
| |
| @with_seed() |
| def test_apply(): |
| global called_blocks |
| called_blocks = [] |
| |
| def record_name(block): |
| global called_blocks |
| called_blocks.append(block.name) |
| |
| block = nn.HybridSequential(prefix='test_') |
| with block.name_scope(): |
| block.add(nn.Dense(10)) |
| block.add(nn.Dropout(0.5)) |
| block.apply(record_name) |
| |
| assert called_blocks == ['test_dense0', 'test_dropout0', 'test'] |
| |
| |
| @with_seed() |
| @assert_raises_cudnn_disabled() |
| def test_summary(): |
| net = gluon.model_zoo.vision.resnet50_v1() |
| net.initialize() |
| net.summary(mx.nd.ones((32, 3, 224, 224))) |
| |
| net2 = nn.Sequential() |
| with net2.name_scope(): |
| net2.add(nn.Embedding(40, 30)) |
| net2.add(gluon.rnn.LSTM(30)) |
| net2.add(nn.Dense(40, flatten=False, params=net2[0].params)) |
| net2.initialize() |
| net2.summary(mx.nd.ones((80, 32))) |
| |
| net3 = gluon.rnn.LSTM(30) |
| net3.initialize() |
| begin_state = net3.begin_state(32) |
| net3.summary(mx.nd.ones((80, 32, 5)), begin_state) |
| |
| net.hybridize() |
| assert_raises(AssertionError, net.summary, mx.nd.ones((32, 3, 224, 224))) |
| |
| |
| @with_seed() |
| def test_legacy_save_params(): |
| net = gluon.nn.HybridSequential(prefix='') |
| with net.name_scope(): |
| net.add(gluon.nn.Conv2D(10, (3, 3))) |
| net.add(gluon.nn.Dense(50)) |
| net.initialize() |
| net(mx.nd.ones((1,1,50,50))) |
| a = net(mx.sym.var('data')) |
| a.save('test.json') |
| net.save_params('test.params') |
| model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()), |
| inputs=mx.sym.var('data')) |
| model.load_params('test.params', ctx=mx.cpu()) |
| |
| |
| @with_seed() |
| def test_sparse_hybrid_block_grad(): |
| class Embedding(mx.gluon.HybridBlock): |
| def __init__(self, num_tokens, embedding_size): |
| super(Embedding, self).__init__() |
| self.num_tokens = num_tokens |
| |
| with self.name_scope(): |
| self.embedding = mx.gluon.nn.Embedding( |
| num_tokens, embedding_size, sparse_grad=True) |
| |
| def hybrid_forward(self, F, words): |
| emb = self.embedding(words) |
| return emb + F.ones_like(emb) |
| |
| embedding = Embedding(20, 3) |
| embedding.initialize() |
| embedding.hybridize() |
| |
| with mx.autograd.record(): |
| emb0 = embedding(mx.nd.arange(10)).sum() |
| emb1 = embedding(mx.nd.arange(10)).sum() |
| loss = emb0 + emb1 |
| loss.backward() |
| grad = embedding.embedding.weight.grad().asnumpy() |
| assert (grad[:10] == 2).all() |
| assert (grad[10:] == 0).all() |
| |
| @with_seed() |
| def test_sparse_hybrid_block(): |
| class Linear(mx.gluon.HybridBlock): |
| def __init__(self, units): |
| super(Linear, self).__init__() |
| with self.name_scope(): |
| self.w = self.params.get('w', shape=(units, units)) |
| |
| def hybrid_forward(self, F, x, w): |
| return F.dot(x, w) |
| |
| class SparseBlock(mx.gluon.HybridBlock): |
| def __init__(self, units): |
| super(SparseBlock, self).__init__() |
| with self.name_scope(): |
| self.net = Linear(units) |
| |
| def hybrid_forward(self, F, x): |
| return self.net(x) * x |
| |
| block = SparseBlock(2) |
| block.initialize() |
| block.hybridize() |
| x = mx.nd.ones((2,2)).tostype('csr') |
| with mx.autograd.record(): |
| z = block(x) + block(x) |
| z.backward() |
| assert (block.net.w.grad().asnumpy() == 4).all() |
| |
| def test_hybrid_static_memory_recording(): |
| net = gluon.model_zoo.vision.get_resnet( |
| 1, 18, pretrained=True, ctx=mx.context.current_context()) |
| net.hybridize(static_alloc=True) |
| |
| x = mx.nd.random.uniform(shape=(1, 3, 32, 32)) |
| with mx.autograd.record(True): |
| net(x) |
| net(x) |
| |
| |
| def test_share_inputs_outputs(): |
| class TestIOBackward(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestIOBackward, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, in1, in2): |
| return in1 + in2 |
| |
| class TestIOForward(gluon.HybridBlock): |
| def __init__(self, prefix=None, params=None): |
| super(TestIOForward, self).__init__(prefix=prefix, params=params) |
| |
| def hybrid_forward(self, F, in1): |
| return in1 |
| |
| d1 = mx.nd.arange(10) |
| d2 = mx.nd.arange(10) |
| |
| params=[{'inline_limit':0}, |
| {'inline_limit':0, 'static_alloc':True}, |
| {'inline_limit':0, 'static_alloc':True, 'static_shape':True}] |
| # Test the case that inputs and outputs of a forward graph share NDArrays. |
| for param in params: |
| t = TestIOForward() |
| t.hybridize(**param) |
| for i in range(5): |
| d1.attach_grad() |
| out_grad = mx.nd.random.uniform(shape=(10)) |
| res = t(d1) |
| assert_almost_equal(res.asnumpy(), d1.asnumpy()) |
| |
| param = deepcopy(params[2]) |
| param['param_indices'] = (1) |
| param['data_indices'] = (0) |
| params.append(param) |
| # Test the case that inputs and outputs of a backward graph share NDArrays. |
| for param in params: |
| t = TestIOBackward() |
| t.hybridize(**param) |
| for i in range(5): |
| d1.attach_grad() |
| d2.attach_grad() |
| out_grad = mx.nd.random.uniform(shape=(10)) |
| with mx.autograd.record(): |
| res = t(d1, d2) |
| res.backward(out_grad=out_grad) |
| assert_almost_equal(out_grad.asnumpy(), d1.grad.asnumpy()) |
| assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy()) |
| |
| |
| def test_grad_graph_change(): |
| class Model(mx.gluon.HybridBlock): |
| def hybrid_forward(self, F, array, index): |
| row = array.take(index) |
| return row, index |
| array = mx.nd.arange(3) |
| index = mx.nd.array([2]) |
| array.attach_grad() |
| model = Model() |
| model.hybridize(inline_limit=0) |
| with mx.autograd.record(train_mode=True): |
| row, _ = model(array, index) |
| row.backward() |
| |
| |
| if __name__ == '__main__': |
| import nose |
| nose.runmodule() |