blob: 372ad3553c06600ee70339818acf3724eeee1617 [file] [log] [blame]
import mxnet as mx
import numpy as np
def test_default_init():
    """Check that PReLU's gamma comes out as 0.25 when no initializer is attached.

    Builds a single PReLU layer, binds it for a (10, 10) input, runs
    ``init_params()`` with its defaults, and inspects the learned slope.
    """
    net = mx.sym.LeakyReLU(data=mx.sym.Variable('data'), act_type='prelu')
    module = mx.mod.Module(net)
    module.bind(data_shapes=[('data', (10, 10))])
    module.init_params()
    arg_params, _aux_params = module.get_params()
    # 'data' is an input, not a parameter, so the symbol's auto-created gamma
    # should be the first (presumably only) entry in arg_params. Its name is not
    # used because the auto-generated prefix depends on how many LeakyReLU
    # symbols were created earlier in the process.
    gamma_values = list(arg_params.values())[0].asnumpy()
    assert (gamma_values == 0.25).all()
def test_variable_init():
    """Check that an initializer attached to a Variable is honoured.

    The PReLU slope is supplied as an explicit ``gamma`` Variable carrying
    ``mx.init.One()``. After ``init_params()`` every element must equal 1.
    """
    inp = mx.sym.Variable('data')
    slope = mx.sym.Variable('gamma', init=mx.init.One())
    net = mx.sym.LeakyReLU(data=inp, gamma=slope, act_type='prelu')
    module = mx.mod.Module(net)
    module.bind(data_shapes=[('data', (10, 10))])
    module.init_params()
    arg_params, _aux_params = module.get_params()
    # The explicit 'gamma' Variable is the only non-data argument of the symbol.
    slope_values = list(arg_params.values())[0].asnumpy()
    assert (slope_values == 1).all()
def test_aux_init():
    """Check the initial auxiliary states of a BatchNorm layer.

    After ``init_params()`` the running variance ('bn_moving_var') must be
    all ones and the running mean ('bn_moving_mean') all zeros.
    """
    bn = mx.sym.BatchNorm(data=mx.sym.Variable('data'), name='bn')
    module = mx.mod.Module(bn)
    module.bind(data_shapes=[('data', (10, 10, 3, 3))])
    module.init_params()
    # get_params() returns (arg_params, aux_params); the moving statistics
    # live in the auxiliary dict.
    _arg_params, aux_params = module.get_params()
    assert (aux_params['bn_moving_var'].asnumpy() == 1).all()
    assert (aux_params['bn_moving_mean'].asnumpy() == 0).all()
if __name__ == '__main__':
    # When executed directly as a script, run every test in declaration order.
    for _test in (test_default_init, test_variable_init, test_aux_init):
        _test()