This topic shows how to use mx.viz.plot_network to visualize neural networks built with MXNet. mx.viz.plot_network represents a network as a computation graph of nodes: input nodes, where the computation starts, and output nodes, where the result can be read.
To visualize a network, you need Jupyter Notebook and Graphviz (both the Graphviz software and its graphviz Python package). Make sure you have followed the installation instructions for these dependencies as well as for MXNet itself.
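Before going further, it can help to confirm that everything is installed. The following is a minimal sketch using only the Python standard library; it assumes the module names mxnet, graphviz and notebook (the classic Jupyter Notebook package), and that the Graphviz `dot` executable on your PATH is what renders the images. The function name check_viz_dependencies is hypothetical, introduced only for this example.

```python
import importlib.util
import shutil


def check_viz_dependencies():
    """Hypothetical helper: report which visualization dependencies are available."""
    return {
        # Python packages: MXNet, the graphviz bindings, and Jupyter Notebook
        "mxnet": importlib.util.find_spec("mxnet") is not None,
        "graphviz": importlib.util.find_spec("graphviz") is not None,
        "notebook": importlib.util.find_spec("notebook") is not None,
        # The Graphviz `dot` executable is needed to turn the graph into an image
        "dot": shutil.which("dot") is not None,
    }


for name, ok in check_viz_dependencies().items():
    print(f"{name:10s} {'found' if ok else 'MISSING'}")
```

Any entry reported as MISSING points to the installation step that still needs to be done.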
mx.viz.plot_network takes as input a Symbol containing your network definition and, optionally, node_attrs, a dictionary of attributes that control how the nodes are drawn in the graph (for example, their shape). It then generates a computation graph.
We will now visualize a sample neural network for linear matrix factorization. First, start Jupyter Notebook:
$ jupyter notebook
Then run the following code in a new notebook:
import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')
# Set dummy dimensions
k = 64
max_user = 100
max_item = 50
# user feature lookup
user = mx.symbol.Embedding(data=user, input_dim=max_user, output_dim=k)
# item feature lookup
item = mx.symbol.Embedding(data=item, input_dim=max_item, output_dim=k)
# predict by the inner product, which is elementwise product and then sum
net = user * item
net = mx.symbol.sum_axis(data=net, axis=1)
net = mx.symbol.Flatten(data=net)
# loss layer
net = mx.symbol.LinearRegressionOutput(data=net, label=score)
# Visualize your network
mx.viz.plot_network(net)
You should see a computation graph similar to the one below: