| import mxnet as mx |
| |
| def test_module_layout(): |
| sym = mx.sym.Variable('data') |
| sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC') |
| |
| dshape = (3, 8, 7) |
| mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)]) |
| mod.bind(data_shapes=[mx.io.DataDesc('data', dshape, layout='TNC')]) |
| mod.init_params() |
| mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)], |
| label=None)) |
| mod.backward([mx.nd.ones(dshape)]) |
| assert mod.get_outputs()[0].shape == dshape |
| |
| hdshape = (3, 4, 7) |
| for x in mod.get_outputs(merge_multi_context=False)[0]: |
| assert x.shape == hdshape |
| |
| def test_save_load(): |
| def dict_equ(a, b): |
| assert set(a) == set(b) |
| for k in a: |
| assert (a[k].asnumpy() == b[k].asnumpy()).all() |
| |
| sym = mx.sym.Variable('data') |
| sym = mx.sym.FullyConnected(sym, num_hidden=100) |
| |
| # single device |
| mod = mx.mod.Module(sym, ('data',)) |
| mod.bind(data_shapes=[('data', (10, 10))]) |
| mod.init_params() |
| mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9}) |
| mod.update() |
| mod.save_checkpoint('test', 0, save_optimizer_states=True) |
| |
| mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',)) |
| mod2.bind(data_shapes=[('data', (10, 10))]) |
| mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9}) |
| assert mod._symbol.tojson() == mod2._symbol.tojson() |
| dict_equ(mod.get_params()[0], mod2.get_params()[0]) |
| dict_equ(mod._updater.states, mod2._updater.states) |
| |
| # multi device |
| mod = mx.mod.Module(sym, ('data',), context=[mx.cpu(0), mx.cpu(1)]) |
| mod.bind(data_shapes=[('data', (10, 10))]) |
| mod.init_params() |
| mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9}) |
| mod.update() |
| mod.save_checkpoint('test', 0, save_optimizer_states=True) |
| |
| mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',)) |
| mod2.bind(data_shapes=[('data', (10, 10))]) |
| mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9}) |
| assert mod._symbol.tojson() == mod2._symbol.tojson() |
| dict_equ(mod.get_params()[0], mod2.get_params()[0]) |
| dict_equ(mod._kvstore._updater.states, mod2._updater.states) |
| |
| if __name__ == '__main__': |
| test_save_load() |
| test_module_layout() |