blob: 8e8ebfa9c153cb97fef24d60497a81bde6382704 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import warnings
import mxnet as mx
import numpy as np
import pytest
def test_print_summary():
    """Smoke-test ``mx.viz.print_summary`` with and without input shapes.

    Builds a small Embedding -> Convolution -> BatchNorm -> ReLU ->
    max-Pooling -> FullyConnected -> FullyConnected -> SliceChannel network.
    It then prints the network's summary twice: first with no shape
    information, then with the shape of the ``data`` input supplied. The
    test passes as long as neither call raises.
    """
    data = mx.sym.Variable('data')
    # Explicit bias variable for fc1, exercising per-variable attributes.
    fc1_bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)

    embedded = mx.symbol.Embedding(data=data, name='emb1',
                                   input_dim=100, output_dim=28)
    conv = mx.symbol.Convolution(data=embedded, name='conv1', num_filter=32,
                                 kernel=(3, 3), stride=(2, 2))
    normed = mx.symbol.BatchNorm(data=conv, name="bn1")
    activated = mx.symbol.Activation(data=normed, name='relu1',
                                     act_type="relu")
    pooled = mx.symbol.Pooling(data=activated, name='mp1', kernel=(2, 2),
                               stride=(2, 2), pool_type='max')
    hidden = mx.sym.FullyConnected(data=pooled, bias=fc1_bias, name='fc1',
                                   num_hidden=10, lr_mult=0)
    logits = mx.sym.FullyConnected(data=hidden, name='fc2', num_hidden=10,
                                   wd_mult=0.5)
    slices = mx.symbol.SliceChannel(data=logits, num_outputs=10,
                                    name="slice_1", squeeze_axis=0)

    # Summary without shape inference.
    mx.viz.print_summary(slices)

    # Summary with the input shape provided for 'data'.
    data_shapes = {"data": (1, 3, 28)}
    mx.viz.print_summary(slices, shape=data_shapes)
def graphviz_exists():
    """Report whether the optional ``graphviz`` package can be imported.

    Returns:
        bool: ``True`` if ``import graphviz`` succeeds, ``False`` if it
        raises ``ImportError``.
    """
    try:
        import graphviz  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
@pytest.mark.skipif(not graphviz_exists(), reason="Skip test_plot_network as Graphviz could not be imported")
def test_plot_network():
    """Check that ``mx.viz.plot_network`` warns about duplicate variable names.

    The network below deliberately names two FullyConnected layers ``'fc'``.
    Plotting it is expected to emit exactly one warning about repeated
    variable names, and that warning must mention ``fc``.
    """
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128)
    net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    # Reusing the name 'fc' is intentional: it should trigger the warning.
    net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10)
    with warnings.catch_warnings(record=True) as caught:
        # Without "always", an active "ignore" filter (e.g. from pytest's
        # -W / filterwarnings configuration) or an earlier emission from the
        # same call site can keep the warning out of ``caught``, making the
        # test fail spuriously.
        warnings.simplefilter("always")
        mx.viz.plot_network(net, shape={'data': (100, 200)},
                            dtype={'data': np.float32},
                            node_attrs={"fixedsize": "false"})
    # "always" can also surface unrelated warnings (e.g. deprecations from
    # dependencies), so only count the warning this test is about.
    duplicate_name_warnings = [
        record for record in caught
        if "There are multiple variables with the same name in your graph"
        in str(record.message)
    ]
    assert len(duplicate_name_warnings) == 1
    assert "fc" in str(duplicate_name_warnings[0].message)