blob: 0d7e67dced2d09c1fda5b42d9becdf22d71f7c33 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import mxnet as mx
from common import models
import pickle as pkl
def test_attr_basic():
with mx.AttrScope(group='4', data='great'):
data = mx.symbol.Variable('data',
attr={'dtype':'data',
'group': '1',
'force_mirroring': 'True'},
lr_mult=1)
gdata = mx.symbol.Variable('data2')
assert gdata.attr('group') == '4'
assert data.attr('group') == '1'
assert data.attr('lr_mult') == '1'
assert data.attr('__lr_mult__') == '1'
assert data.attr('force_mirroring') == 'True'
assert data.attr('__force_mirroring__') == 'True'
data2 = pkl.loads(pkl.dumps(data))
assert data.attr('dtype') == data2.attr('dtype')
def test_operator():
data = mx.symbol.Variable('data')
with mx.AttrScope(__group__='4', __data__='great'):
fc1 = mx.symbol.Activation(data, act_type='relu')
with mx.AttrScope(__init_bias__='0.0'):
fc2 = mx.symbol.FullyConnected(fc1, num_hidden=10, name='fc2')
assert fc1.attr('__data__') == 'great'
assert fc2.attr('__data__') == 'great'
assert fc2.attr('__init_bias__') == '0.0'
fc2copy = pkl.loads(pkl.dumps(fc2))
assert fc2copy.tojson() == fc2.tojson()
fc2weight = fc2.get_internals()['fc2_weight']
def contain(x, y):
for k, v in x.items():
if k not in y:
return False
if isinstance(y[k], dict):
if not isinstance(v, dict):
return False
if not contain(v, y[k]):
return False
elif y[k] != v:
return False
return True
def test_list_attr():
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
num_filter=1, attr={'__mood__': 'so so', 'wd_mult': 'x'})
assert contain({'__mood__': 'so so', 'wd_mult': 'x', '__wd_mult__': 'x'}, op.list_attr())
def test_attr_dict():
data = mx.sym.Variable('data', attr={'mood': 'angry'})
op = mx.sym.Convolution(data=data, name='conv', kernel=(1, 1),
num_filter=1, attr={'__mood__': 'so so'}, lr_mult=1)
assert contain({
'data': {'mood': 'angry'},
'conv_weight': {'__mood__': 'so so'},
'conv': {'kernel': '(1, 1)', '__mood__': 'so so', 'num_filter': '1', 'lr_mult': '1', '__lr_mult__': '1'},
'conv_bias': {'__mood__': 'so so'}}, op.attr_dict())
if __name__ == '__main__':
test_attr_basic()
test_operator()
test_list_attr()
test_attr_dict()