# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import os
import re
import mxnet as mx
import numpy as np
from common import assertRaises, models
from mxnet.base import NotImplementedForSymbol
from mxnet.test_utils import discard_stderr
import pickle as pkl
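
# Unit tests for the mxnet.symbol API: graph composition, copying,
# (de)serialization, shape/type inference, and memory-planning behavior.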
def test_symbol_basic():
    mlist = []
    mlist.append(models.mlp2())
    for m in mlist:
        m.list_arguments()
        m.list_outputs()

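# A Symbol is a graph node with no concrete value, so truth-value evaluation
# is ambiguous and bool() must raise NotImplementedForSymbol.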
def test_symbol_bool():
    x = mx.symbol.Variable('x')
    assertRaises(NotImplementedForSymbol, bool, x)

def test_symbol_compose():
    data = mx.symbol.Variable('data')
    net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    net1 = mx.symbol.FullyConnected(data=net1, name='fc2', num_hidden=100)
    assert net1.list_arguments() == ['data',
                                     'fc1_weight', 'fc1_bias',
                                     'fc2_weight', 'fc2_bias']

    net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
    net2 = mx.symbol.Activation(data=net2, act_type='relu')
    net2 = mx.symbol.FullyConnected(data=net2, name='fc4', num_hidden=20)
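
    # Calling a symbol composes graphs: net1's output is plugged into net2's
    # free 'fc3_data' input, producing one combined graph.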
    composed = net2(fc3_data=net1, name='composed')
    multi_out = mx.symbol.Group([composed, net1])
    assert len(multi_out.list_outputs()) == 2
    assert len(multi_out) == 2

def test_symbol_copy():
    data = mx.symbol.Variable('data')
    data_2 = copy.deepcopy(data)
    data_3 = copy.copy(data)
    assert data.tojson() == data_2.tojson()
    assert data.tojson() == data_3.tojson()

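# get_internals() exposes every intermediate output of a graph as a
# multi-output symbol whose entries are addressable by name, e.g. 'fc1_output'.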
def test_symbol_internal():
    data = mx.symbol.Variable('data')
    oldfc = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100)
    assert net1.list_arguments() == ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias']

    internal = net1.get_internals()
    fc1 = internal['fc1_output']
    assert fc1.list_arguments() == oldfc.list_arguments()

def test_symbol_children():
    data = mx.symbol.Variable('data')
    oldfc = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100)
    assert net1.get_children().list_outputs() == ['fc1_output', 'fc2_weight', 'fc2_bias']
    assert len(net1.get_children()) == 3
    assert net1.get_children().get_children().list_outputs() == ['data', 'fc1_weight', 'fc1_bias']
    assert len(net1.get_children().get_children()) == 3
    assert net1.get_children()['fc2_weight'].list_arguments() == ['fc2_weight']
    assert net1.get_children()['fc2_weight'].get_children() is None
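
    # get_children() returns the output entries a node consumes: Concat sees
    # the three SliceChannel outputs, while SliceChannel itself sees only 'data'.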
    data = mx.sym.Variable('data')
    sliced = mx.sym.SliceChannel(data, num_outputs=3, name='slice')
    concat = mx.sym.Concat(*list(sliced))
    assert concat.get_children().list_outputs() == \
        ['slice_output0', 'slice_output1', 'slice_output2']
    assert sliced.get_children().list_outputs() == ['data']

def test_symbol_pickle():
    mlist = [models.mlp2(), models.conv()]
    data = pkl.dumps(mlist)
    mlist2 = pkl.loads(data)
    for x, y in zip(mlist, mlist2):
        assert x.tojson() == y.tojson()

def test_symbol_saveload():
    sym = models.mlp2()
    fname = 'tmp_sym.json'
    sym.save(fname)
    data2 = mx.symbol.load(fname)
    # compare via tojson(), which serializes nodes in a deterministic order
    assert sym.tojson() == data2.tojson()
    os.remove(fname)

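# infer_type propagates dtypes through the graph from the given inputs;
# infer_type_partial leaves unresolved entries as None instead of raising.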
def test_symbol_infer_type():
    data = mx.symbol.Variable('data')
    f32data = mx.symbol.Cast(data=data, dtype='float32')
    fc1 = mx.symbol.FullyConnected(data=f32data, name='fc1', num_hidden=128)
    mlp = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')

    arg, out, aux = mlp.infer_type(data=np.float16)
    assert arg == [np.float16, np.float32, np.float32, np.float32]
    assert out == [np.float32]
    assert aux == []

    # partial infer type
    arg, out, aux = mlp.infer_type_partial()
    assert arg == [None, np.float32, np.float32, np.float32]
    assert out == [np.float32]
    assert aux == []

def test_symbol_infer_shape():
    num_hidden = 128
    num_dim = 64
    num_sample = 10

    data = mx.symbol.Variable('data')
    prev = mx.symbol.Variable('prevstate')
    x2h = mx.symbol.FullyConnected(data=data, name='x2h', num_hidden=num_hidden)
    h2h = mx.symbol.FullyConnected(data=prev, name='h2h', num_hidden=num_hidden)
    out = mx.symbol.Activation(data=mx.sym.elemwise_add(x2h, h2h), name='out', act_type='relu')

    # shape inference will fail because information is not available for h2h
    ret = out.infer_shape(data=(num_sample, num_dim))
    assert ret == (None, None, None)

    arg, out_shapes, aux_shapes = out.infer_shape_partial(data=(num_sample, num_dim))
    arg_shapes = dict(zip(out.list_arguments(), arg))
    assert arg_shapes['data'] == (num_sample, num_dim)
    assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
    assert arg_shapes['h2h_weight'] == ()

    # now we can do full shape inference
    state_shape = out_shapes[0]
    arg, out_shapes, aux_shapes = out.infer_shape(data=(num_sample, num_dim), prevstate=state_shape)
    arg_shapes = dict(zip(out.list_arguments(), arg))
    assert arg_shapes['data'] == (num_sample, num_dim)
    assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
    assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden)
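
    # In symbol shape semantics a dimension of 0 means "unknown";
    # infer_shape_partial fills in what it can and leaves 0 elsewhere.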
    # Partial shape inference with some unknown dimensions
    data_shape = (1, 0, 0, 0)
    data = mx.sym.Variable('data', shape=data_shape)
    weight = mx.sym.Variable('weight')
    cdata = mx.sym.cast(data, dtype='float16')
    cweight = mx.sym.cast(weight, dtype='float16')
    test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), num_filter=64,
                              stride=(2, 2), no_bias=True, kernel=(7, 7))
    arg, _, _ = test.infer_shape_partial()
    arg_shapes = dict(zip(test.list_arguments(), arg))
    assert arg_shapes['data'] == data_shape
    assert arg_shapes['weight'] == (64, 0, 7, 7)

def test_symbol_infer_shape_var():
    """Test specifying shape information when constructing a variable."""
    shape = (2, 3)
    a = mx.symbol.Variable('a', shape=shape)
    b = mx.symbol.Variable('b')
    c = mx.symbol.elemwise_add(a, b)
    arg_shapes, out_shapes, aux_shapes = c.infer_shape()
    assert arg_shapes[0] == shape
    assert arg_shapes[1] == shape
    assert out_shapes[0] == shape

    overwrite_shape = (5, 6)
    arg_shapes, out_shapes, aux_shapes = c.infer_shape(a=overwrite_shape)
    assert arg_shapes[0] == overwrite_shape
    assert arg_shapes[1] == overwrite_shape
    assert out_shapes[0] == overwrite_shape

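# Fluent methods (e.g. data.sum(...)) must behave exactly like the equivalent
# free functions (mx.symbol.sum(data, ...)); each pair below is checked for
# identical outputs and, for the ops listed in has_grad, identical gradients.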
def test_symbol_fluent():
    has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
                    'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
                    'broadcast_axes', 'broadcast_like', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
                    'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 'argmin',
                    'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
                    'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
                    'exp', 'expm1', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
                    'square', 'reciprocal', 'reshape_like', 'cbrt', 'rcbrt', 'relu', 'sigmoid',
                    'softmax', 'log_softmax', 'softmin', 'rint', 'ceil', 'floor', 'trunc', 'fix'])

    def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
        with mx.name.NameManager():
            data = mx.symbol.Variable('data')
            regular = getattr(mx.symbol, func)(data, name=func + '0', **kwargs)
            fluent = getattr(data, func)(**kwargs)
            check_symbol_consistency(regular, fluent,
                                     {'ctx': mx.context.current_context(), 'data': shape},
                                     skip_grad=func not in has_grad,
                                     equal_nan=equal_nan)

    for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
        check_fluent_regular(func, {})

    for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
                 'arcsinh', 'arctanh', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax', 'softmin']:
        check_fluent_regular(func, {}, equal_nan=True)

    for func in ['expand_dims', 'flip', 'sort', 'topk', 'argsort', 'argmax', 'argmin']:
        check_fluent_regular(func, {'axis': 1})

    check_fluent_regular('one_hot', {'depth': 15})
    check_fluent_regular('tile', {'reps': (1, 2)})
    check_fluent_regular('repeat', {'repeats': 3})
    check_fluent_regular('transpose', {'axes': (1, 0, 2)})
    check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 6))
    check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6))
    check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
    check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.sym.zeros((3, 3))})
    check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
    check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
    check_fluent_regular('broadcast_like', {'rhs': mx.sym.ones((1, 5)), 'lhs_axes': (0,), 'rhs_axes': (1,)}, shape=(1, 9))
    check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0, 0, 0, 0, 3, 0, 0, 4)}, shape=(5, 17, 2, 3))
    check_fluent_regular('reshape_like', {'rhs': mx.sym.ones((30, 17))}, shape=(5, 17, 2, 3))

    for func in ['sum', 'nansum', 'prod', 'nanprod', 'mean', 'max', 'min', 'norm']:
        check_fluent_regular(func, {'axis': (1, 2)})

    check_fluent_regular('reshape', {'shape': (17, 1, 5)})
    check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)})
    check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4))

def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False):
    assert sym1.list_arguments() == sym2.list_arguments()
    assert sym1.list_auxiliary_states() == sym2.list_auxiliary_states()
    assert sym1.list_outputs() == sym2.list_outputs()
    mx.test_utils.check_consistency([sym1, sym2], ctx_list=[ctx, ctx],
                                    grad_req='null' if skip_grad else 'write',
                                    equal_nan=equal_nan)

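# Backward-compatibility check: rebuild the network that produced
# save_000800.json (saved by an older MXNet, 0.8.0 judging by the file name)
# and verify that attributes and numerical behavior match the loaded symbol.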
def test_load_000800():
    with mx.AttrScope(ctx_group='stage1'):
        data = mx.symbol.Variable('data', lr_mult=0.2)
        weight = mx.sym.Variable(name='fc1_weight', lr_mult=1.2)
        fc1 = mx.symbol.FullyConnected(data=data, weight=weight, name='fc1', num_hidden=128, wd_mult=0.3)
        act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')

    set_stage1 = set(act1.list_arguments())
    with mx.AttrScope(ctx_group='stage2'):
        fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64, lr_mult=0.01)
        act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu')
        fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
        fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0')
        sym1 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    sym2 = mx.sym.load(os.path.join(curr_path, 'save_000800.json'))

    attr1 = sym1.attr_dict()
    attr2 = sym2.attr_dict()
    for k, v1 in attr1.items():
        assert k in attr2, k
        v2 = attr2[k]
        for kk, vv1 in v1.items():
            if kk.startswith('__') and kk.endswith('__'):
                assert kk in v2 and v2[kk] == vv1, k + str(v1) + str(v2)

    check_symbol_consistency(sym1, sym2,
                             {'ctx': mx.cpu(0), 'group2ctx': {'stage1': mx.cpu(1), 'stage2': mx.cpu(2)}, 'data': (1, 200)})

def test_blockgrad():
    a = mx.sym.Variable('a')
    b = mx.sym.BlockGrad(2 * a)
    exe = b.simple_bind(ctx=mx.cpu(), a=(10, 10))

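# When gradients are not needed (grad_req='null', or the output passes through
# stop_gradient), the memory planner should skip backward buffers, so the
# total allocation reported in debug_str() drops accordingly.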
def test_zero_prop():
    data = mx.symbol.Variable('data')
    for i in range(10):
        data = data * data

    exe = data.simple_bind(ctx=mx.cpu(), data=(10, 3, 256, 256))
    big = int(re.search(r'Total (\d+) MB allocated', exe.debug_str()).group(1))

    exe = data.simple_bind(ctx=mx.cpu(), data=(10, 3, 256, 256), grad_req='null')
    small1 = int(re.search(r'Total (\d+) MB allocated', exe.debug_str()).group(1))

    data = mx.sym.stop_gradient(data)
    exe = data.simple_bind(ctx=mx.cpu(), data=(10, 3, 256, 256))
    small2 = int(re.search(r'Total (\d+) MB allocated', exe.debug_str()).group(1))

    assert big > small2
    assert small1 == small2

def test_zero_prop2():
    x = mx.sym.Variable('x')
    idx = mx.sym.Variable('idx')
    y = mx.sym.batch_take(x, idx)
    z = mx.sym.stop_gradient(y)
    exe = z.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                        type_dict={'x': np.float32, 'idx': np.int32})
    exe.forward()
    exe.backward()

    # The following bind() should throw an exception. We discard the expected
    # stderr output for this operation only, to keep the test logs clean.
    with discard_stderr():
        try:
            y.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                          type_dict={'x': np.float32, 'idx': np.int32})
        except Exception:
            return

    assert False

def test_simple_bind_incomplete_shape_inference_in_one_forward_pass():
    r"""This is a special case that results in shape inference
    failure after moving simple_bind logic from frontend to backend.
    Added here for testing against a network similar to the following one.

    Network diagram:

    weight --> abs_op --> sum_op --\
         \                          \--> add_op
    data ----> fc_op  --> sum_op ---/

    Given data's shape, if the shape inference starts from the weight node,
    then the node entries of abs_op and sum_op are unknown in the
    forward pass. Therefore, there are several unknown shapes after the
    first forward pass is done. Now the backward inference pass starts with
    the assumption that there are no unknown-shape node entries in the forward
    pass and, consequently, leads to a CHECK_EQ failure.
    """
    data_shape = (5, 13)
    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True, name='fc')
    modified_weight = mx.sym.abs(fc.get_internals()['fc_weight'])
    net = mx.sym.sum(modified_weight) + mx.sym.sum(fc)
    net.simple_bind(ctx=mx.cpu(), data=data_shape)

def test_simple_bind_gradient_graph_possible_with_cycle():
    """This is a special case that resulted in a cycle in the gradient graph
    before the bug was fixed. With the following symbol, the node entries
    passed into function AggregateGradient(std::vector<nnvm::NodeEntry>&& v)
    are outputs of the same node. Therefore, adding a node to the
    control_deps of itself must be skipped.
    See GitHub issue:
    https://github.com/apache/incubator-mxnet/issues/8029
    for more details."""
    data = mx.symbol.Variable('data')
    res = data + data + data + data + data + data + data + data
    res.simple_bind(ctx=mx.cpu(), data=(1,))

if __name__ == '__main__':
    import nose
    nose.runmodule()