blob: 3c28e24c0158ebb7aa7d64d28dd9d737c2fc109e [file] [log] [blame]
import mxnet as mx
def test_module_layout():
sym = mx.sym.Variable('data')
sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
dshape = (3, 8, 7)
mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)])
mod.bind(data_shapes=[mx.io.DataDesc('data', dshape, layout='TNC')])
mod.init_params()
mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)],
label=None))
mod.backward([mx.nd.ones(dshape)])
assert mod.get_outputs()[0].shape == dshape
hdshape = (3, 4, 7)
for x in mod.get_outputs(merge_multi_context=False)[0]:
assert x.shape == hdshape
def test_save_load():
def dict_equ(a, b):
assert set(a) == set(b)
for k in a:
assert (a[k].asnumpy() == b[k].asnumpy()).all()
sym = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(sym, num_hidden=100)
# single device
mod = mx.mod.Module(sym, ('data',))
mod.bind(data_shapes=[('data', (10, 10))])
mod.init_params()
mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
mod.update()
mod.save_checkpoint('test', 0, save_optimizer_states=True)
mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',))
mod2.bind(data_shapes=[('data', (10, 10))])
mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
assert mod._symbol.tojson() == mod2._symbol.tojson()
dict_equ(mod.get_params()[0], mod2.get_params()[0])
dict_equ(mod._updater.states, mod2._updater.states)
# multi device
mod = mx.mod.Module(sym, ('data',), context=[mx.cpu(0), mx.cpu(1)])
mod.bind(data_shapes=[('data', (10, 10))])
mod.init_params()
mod.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
mod.update()
mod.save_checkpoint('test', 0, save_optimizer_states=True)
mod2 = mx.mod.Module.load('test', 0, load_optimizer_states=True, data_names=('data',))
mod2.bind(data_shapes=[('data', (10, 10))])
mod2.init_optimizer(optimizer_params={'learning_rate':0.1, 'momentum':0.9})
assert mod._symbol.tojson() == mod2._symbol.tojson()
dict_equ(mod.get_params()[0], mod2.get_params()[0])
dict_equ(mod._kvstore._updater.states, mod2._updater.states)
if __name__ == '__main__':
test_save_load()
test_module_layout()