# blob: 4cdecaf9146f190457b295516eac4ef1c00462b6 [file] [log] [blame]
import os
import mxnet as mx
from common import models
import pickle as pkl
def test_attr_basic():
    """Check attribute handling on variables created inside an ``AttrScope``.

    Two variables are built under a scope that sets ``group`` and ``data``:
    one with its own ``attr`` dict and ``lr_mult``, one with nothing extra.
    The test then asserts that:

    * the plain variable reports the scope's ``group`` value,
    * the explicitly supplied ``group`` wins over the scope's value,
    * ``lr_mult`` and ``force_mirroring`` can be read both under their plain
      names and under their double-underscore names,
    * the ``dtype`` attribute is the same after a pickle round trip.
    """
    explicit_attrs = {'dtype': 'data',
                      'group': '1',
                      'force_mirroring': 'True'}

    with mx.AttrScope(group='4', data='great'):
        explicit_var = mx.symbol.Variable('data', attr=explicit_attrs, lr_mult=1)
        scoped_var = mx.symbol.Variable('data2')

    # Only the scope supplied a group for this variable.
    assert scoped_var.attr('group') == '4'

    # Checked in the same order as the individual assertions they replace.
    expected_pairs = (
        ('group', '1'),
        ('lr_mult', '1'),
        ('__lr_mult__', '1'),
        ('force_mirroring', 'True'),
        ('__force_mirroring__', 'True'),
    )
    for attr_name, attr_value in expected_pairs:
        assert explicit_var.attr(attr_name) == attr_value

    restored_var = pkl.loads(pkl.dumps(explicit_var))
    assert explicit_var.attr('dtype') == restored_var.attr('dtype')
def test_operator():
    """Check that nested ``AttrScope`` blocks attach attributes to operators.

    An ``Activation`` is created under an outer scope and a
    ``FullyConnected`` under an additional inner scope. The test asserts
    that both operators report the outer ``__data__`` value, that the inner
    operator also reports ``__init_bias__``, and that pickling and
    unpickling the inner operator yields the same ``tojson()`` output.
    """
    source = mx.symbol.Variable('data')

    with mx.AttrScope(__group__='4', __data__='great'):
        activation = mx.symbol.Activation(source, act_type='relu')
        with mx.AttrScope(__init_bias__='0.0'):
            dense = mx.symbol.FullyConnected(activation, num_hidden=10, name='fc2')

    assert activation.attr('__data__') == 'great'
    assert dense.attr('__data__') == 'great'
    assert dense.attr('__init_bias__') == '0.0'

    round_tripped = pkl.loads(pkl.dumps(dense))
    assert round_tripped.tojson() == dense.tojson()

    # The looked-up symbol is never used afterwards; presumably this is a
    # smoke check that the 'fc2_weight' internal can be retrieved.
    internals = dense.get_internals()
    _ = internals['fc2_weight']
def contain(x, y):
    """Return True when every entry of ``x`` is matched by ``y``.

    For each key of ``x`` the key must exist in ``y``. When the value in
    ``y`` is a dict, the value in ``x`` must also be a dict and must itself
    be contained in it (checked recursively). Otherwise the two values must
    compare equal. Extra keys in ``y`` are ignored.
    """
    for key, wanted in x.items():
        if key not in y:
            return False

        found = y[key]
        if not isinstance(found, dict):
            # Leaf on the ``y`` side: plain equality decides.
            if found != wanted:
                return False
            continue

        # ``y`` holds a nested dict here, so ``x`` must hold one too and it
        # has to be a subset of the nested dict.
        if not isinstance(wanted, dict):
            return False
        if not contain(wanted, found):
            return False
    return True
def test_list_attr():
    """Check that ``list_attr()`` on a convolution includes its own attributes.

    The convolution is created with ``__mood__`` and ``wd_mult`` attributes;
    the test asserts that ``list_attr()`` contains both of them plus a
    ``__wd_mult__`` entry carrying the same value as ``wd_mult``.
    """
    source = mx.sym.Variable('data', attr={'mood': 'angry'})
    conv = mx.sym.Convolution(
        data=source,
        name='conv',
        kernel=(1, 1),
        num_filter=1,
        attr={'__mood__': 'so so', 'wd_mult': 'x'},
    )

    expected = {'__mood__': 'so so', 'wd_mult': 'x', '__wd_mult__': 'x'}
    assert contain(expected, conv.list_attr())
def test_attr_dict():
    """Check the per-node attribute mapping returned by ``attr_dict()``.

    A convolution is built on top of a variable that carries a ``mood``
    attribute. The test asserts that ``attr_dict()`` contains the expected
    entries for the input variable, for the convolution itself, and for the
    ``conv_weight`` and ``conv_bias`` nodes.
    """
    source = mx.sym.Variable('data', attr={'mood': 'angry'})
    conv = mx.sym.Convolution(
        data=source,
        name='conv',
        kernel=(1, 1),
        num_filter=1,
        attr={'__mood__': 'so so'},
        lr_mult=1,
    )

    expected_conv_attrs = {
        'kernel': '(1, 1)',
        '__mood__': 'so so',
        'num_filter': '1',
        'lr_mult': '1',
        '__lr_mult__': '1',
    }
    expected = {
        'data': {'mood': 'angry'},
        'conv_weight': {'__mood__': 'so so'},
        'conv': expected_conv_attrs,
        'conv_bias': {'__mood__': 'so so'},
    }
    assert contain(expected, conv.attr_dict())
if __name__ == '__main__':
    # When executed as a script, run every test once, in definition order.
    for run_test in (test_attr_basic, test_operator, test_list_attr, test_attr_dict):
        run_test()