blob: 9b65985784cc8002aef77133a4c799d8045507e4 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Handwritten Digit Recognition — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Page-level options for the Sphinx client-side scripts; presumably consumed by
// doctools.js and searchtools_custom.js, which are loaded right after this block.
var DOCUMENTATION_OPTIONS = {};
// Relative prefix from this page back to the documentation root (the same "../../"
// prefix that this page's own links use).
DOCUMENTATION_OPTIONS.URL_ROOT = '../../';
DOCUMENTATION_OPTIONS.VERSION = '';
DOCUMENTATION_OPTIONS.COLLAPSE_INDEX = false;
// Extension appended to generated page links.
DOCUMENTATION_OPTIONS.FILE_SUFFIX = '.html';
DOCUMENTATION_OPTIONS.HAS_SOURCE = true;
DOCUMENTATION_OPTIONS.SOURCELINK_SUFFIX = '';
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- On DOM ready, load the search index and initialise the search UI
     (Search is presumably provided by searchtools_custom.js, loaded above).
     NOTE(review): "/searchindex.js" is an absolute path, while the other assets on this
     page use the "../../" relative root; confirm it resolves when the docs are served
     under a sub-path such as the /test/versions/... URLs listed in the version menu. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
// Standard Google Analytics (analytics.js) loader snippet; keep it unmodified.
// It installs a queueing `ga` stub on window (calls are pushed onto ga.q until the
// library arrives) and asynchronously inserts a script element for analytics.js
// before the first script element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Create the tracker for property UA-96378503-1 and record a page view.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../index.html" rel="up" title="Tutorials">
<link href="predict_image.html" rel="next" title="Predict with pre-trained models"/>
<link href="linear-regression.html" rel="prev" title="Linear Regression"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></span></nav>
<script>
// Returns the relative path from this page back to the documentation root
// (the same "../../" prefix used by the navigation links on this page).
function getRootPath() {
  var rootPath = "../../";
  return rootPath;
}
</script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="../index.html#python">Python</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../index.html#basics">Basics</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html#training-and-inference">Training and Inference</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="linear-regression.html">Linear Regression</a></li>
<li class="toctree-l4 current"><a class="current reference internal" href="">Handwritten Digit Recognition</a></li>
<li class="toctree-l4"><a class="reference internal" href="predict_image.html">Predict with pre-trained models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../vision/large_scale_classification.html">Large Scale Image Classification</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../index.html#contributing-tutorials">Contributing Tutorials</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="handwritten-digit-recognition">
<h1>Handwritten Digit Recognition<a class="headerlink" href="#handwritten-digit-recognition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, we’ll give you a step by step walk-through of how to build a hand-written digit classifier using the <a class="reference external" href="https://en.wikipedia.org/wiki/MNIST_database">MNIST</a> dataset. For someone new to deep learning, this exercise is arguably the “Hello World” equivalent.</p>
<p>MNIST is a widely used dataset for the hand-written digit classification task. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task at hand is to train a model using the 60,000 training images and subsequently test its classification accuracy on the 10,000 test images.</p>
<p><img alt="Sample images of handwritten digits from the MNIST dataset" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/mnist.png"/></p>
<p><strong>Figure 1:</strong> Sample images from the MNIST dataset.</p>
<div class="section" id="prerequisites">
<h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>To complete this tutorial, we need:</p>
<ul class="simple">
<li>MXNet. See the instructions for your operating system in <a class="reference external" href="http://mxnet.io/get_started/install.html">Setup and Installation</a>.</li>
<li><a class="reference external" href="http://docs.python-requests.org/en/master/">Python Requests</a> and <a class="reference external" href="http://jupyter.org/index.html">Jupyter Notebook</a>.</li>
</ul>
<div class="highlight-python"><div class="highlight"><pre><span></span>$ pip install requests jupyter
</pre></div>
</div>
</div>
<div class="section" id="loading-data">
<h2>Loading Data<a class="headerlink" href="#loading-data" title="Permalink to this headline"></a></h2>
<p>Before we define the model, let’s first fetch the <a class="reference external" href="http://yann.lecun.com/exdb/mnist/">MNIST</a> dataset.</p>
<p>The following source code downloads and loads the images and the corresponding labels into memory.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="n">mnist</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">get_mnist</span><span class="p">()</span>
</pre></div>
</div>
<p>After running the above source code, the entire MNIST dataset should be fully loaded into memory. Note that for large datasets it is not feasible to pre-load the entire dataset first as we did here. What is needed is a mechanism by which we can quickly and efficiently stream data directly from the source. MXNet data iterators come to the rescue here by providing exactly that. Data iterators are the mechanism by which we feed input data into an MXNet training algorithm; they are simple to initialize and use, and they are optimized for speed. During training, we typically process training samples in small batches, and over the entire training lifetime we end up processing each training example multiple times. In this tutorial, we’ll configure the data iterator to feed examples in batches of 100. Keep in mind that each example consists of a 28x28 grayscale image and its corresponding label.</p>
<p>Image batches are commonly represented by a 4-D array with shape <code class="docutils literal"><span class="pre">(batch_size,</span> <span class="pre">num_channels,</span> <span class="pre">height,</span> <span class="pre">width)</span></code>, the NCHW layout that MXNet uses by default. For the MNIST dataset, since the images are grayscale, there is only one color channel. Also, the images are 28x28 pixels, and so each image has height and width equal to 28. Therefore, the shape of the input is <code class="docutils literal"><span class="pre">(batch_size,</span> <span class="pre">1,</span> <span class="pre">28,</span> <span class="pre">28)</span></code>. Another important consideration is the order of input samples. When feeding training examples, it is critical that we don’t feed samples with the same label in succession. Doing so can slow down training.
Data iterators take care of this by randomly shuffling the inputs. Note that we only need to shuffle the training data. The order does not matter for test data.</p>
<p>The following source code initializes the data iterators for the MNIST dataset. Note that we initialize two iterators: one for train data and one for test data.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">train_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'train_data'</span><span class="p">],</span> <span class="n">mnist</span><span class="p">[</span><span class="s1">'train_label'</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">val_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'test_data'</span><span class="p">],</span> <span class="n">mnist</span><span class="p">[</span><span class="s1">'test_label'</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="training">
<h2>Training<a class="headerlink" href="#training" title="Permalink to this headline"></a></h2>
<p>We will cover a couple of approaches for performing the handwritten digit recognition task. The first approach makes use of a traditional deep neural network architecture called Multilayer Perceptron (MLP). We’ll discuss its drawbacks and use that as a motivation to introduce a second, more advanced approach called Convolutional Neural Network (CNN) that has proven to work very well for image classification tasks.</p>
<div class="section" id="multilayer-perceptron">
<h3>Multilayer Perceptron<a class="headerlink" href="#multilayer-perceptron" title="Permalink to this headline"></a></h3>
<p>The first approach makes use of a <a class="reference external" href="https://en.wikipedia.org/wiki/Multilayer_perceptron">Multilayer Perceptron</a> to solve this problem. We’ll define the MLP using MXNet’s symbolic interface. We begin by creating a placeholder variable for the input data. When working with an MLP, we need to flatten our 28x28 images into a flat 1-D structure of 784 (28 * 28) raw pixel values. The order of pixel values in the flattened vector does not matter as long as we are being consistent about how we do this across all images.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="c1"># Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span>
</pre></div>
</div>
<p>One might wonder if we are discarding valuable information by flattening. That is indeed true and we’ll cover this more when we talk about convolutional neural networks where we preserve the input shape. For now, we’ll go ahead and work with flattened images.</p>
<p>MLPs contain several fully connected layers. A fully connected layer, or FC layer for short, is one where each neuron in the layer is connected to every neuron in its preceding layer. From a linear algebra perspective, an FC layer applies an <a class="reference external" href="https://en.wikipedia.org/wiki/Affine_transformation">affine transform</a> to the <em>n x m</em> input matrix <em>X</em> and outputs a matrix <em>Y</em> of size <em>n x k</em>, where <em>n</em> is the number of examples in the batch, <em>m</em> is the number of input features, and <em>k</em> is the number of neurons in the FC layer. <em>k</em> is also referred to as the hidden size. The output <em>Y</em> is computed according to the equation <em>Y = X W + b</em>, where the bias <em>b</em> is added to every row of <em>X W</em>. The FC layer has two learnable parameters, the <em>m x k</em> weight matrix <em>W</em> and the <em>1 x k</em> bias vector <em>b</em>.</p>
<p>In an MLP, the outputs of most FC layers are fed into an activation function, which applies an element-wise non-linearity. This step is critical and it gives neural networks the ability to classify inputs that are not linearly separable. Common choices for activation functions are sigmoid, tanh, and <a class="reference external" href="https://en.wikipedia.org/wiki/Rectifier_%28neural_networks%29">rectified linear unit</a> (ReLU). In this example, we’ll use the ReLU activation function which has several desirable properties and is typically considered a default choice.</p>
<p>The following code declares two fully connected layers with 128 and 64 neurons, respectively. Each FC layer is followed by a ReLU activation layer, which performs an element-wise ReLU transformation on that FC layer’s output.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># The first fully-connected layer and the corresponding activation function</span>
<span class="n">fc1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
<span class="n">act1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc1</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
<span class="c1"># The second fully-connected layer and the corresponding activation function</span>
<span class="n">fc2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">act1</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="mi">64</span><span class="p">)</span>
<span class="n">act2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc2</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
</pre></div>
</div>
<p>The last fully connected layer often has its hidden size equal to the number of output classes in the dataset. The activation function for this layer will be the softmax function. The Softmax layer maps its input to a probability score for each class of output. During the training stage, a loss function computes the <a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy">cross entropy</a> between the probability distribution (softmax output) predicted by the network and the true probability distribution given by the label.</p>
<p>The following source code declares the final fully connected layer of size 10, one output for each of the 10 digit classes. The output from this layer is fed into a <code class="docutils literal"><span class="pre">SoftmaxOutput</span></code> layer that performs softmax and cross-entropy loss computation in one go. Note that loss computation only happens during training.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># MNIST has 10 classes</span>
<span class="n">fc3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">act2</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c1"># Softmax with cross entropy loss</span>
<span class="n">mlp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span>
</pre></div>
</div>
<p><img alt="Diagram of the MLP network architecture used for MNIST" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mlp_mnist.png"/></p>
<p><strong>Figure 2:</strong> MLP network architecture for MNIST.</p>
<p>Now that both the data iterator and neural network are defined, we can commence training. Here we’ll employ the <code class="docutils literal"><span class="pre">module</span></code> feature in MXNet which provides a high-level abstraction for running training and inference on predefined networks. The module API allows the user to specify appropriate parameters that control how the training proceeds.</p>
<p>The following source code initializes a module to train the MLP network we defined above. For our training, we will make use of the stochastic gradient descent (SGD) optimizer. In particular, we’ll be using mini-batch SGD. Standard SGD processes training data one example at a time. In practice, this is very slow, and one can speed up the process by processing examples in small batches. In this case, our batch size will be 100 (set earlier when we created the data iterators), which is a reasonable choice. Another parameter we select here is the learning rate, which controls the step size the optimizer takes in search of a solution. We’ll pick a learning rate of 0.1, again a reasonable choice. Settings such as batch size and learning rate are usually referred to as hyper-parameters. The values we choose for them can have a great impact on training performance. For the purpose of this tutorial, we’ll start with some reasonable and safe values. In other tutorials, we’ll discuss how one might go about finding a combination of hyper-parameters for optimal model performance.</p>
<p>Typically, one runs the training until convergence, which means that we have learned a good set of model parameters (weights + biases) from the training data. For the purpose of this tutorial, we’ll run training for 10 epochs and stop. An epoch is one full pass over the entire training data.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">logging</span>
<span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">()</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span><span class="p">)</span> <span class="c1"># logging to stdout</span>
<span class="c1"># create a trainable module on CPU</span>
<span class="n">mlp_model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">mlp</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
<span class="n">mlp_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span> <span class="c1"># train data</span>
<span class="n">eval_data</span><span class="o">=</span><span class="n">val_iter</span><span class="p">,</span> <span class="c1"># validation data</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span> <span class="c1"># use SGD to train</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span><span class="mf">0.1</span><span class="p">},</span> <span class="c1"># use fixed learning rate</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span> <span class="c1"># report accuracy during training</span>
<span class="n">batch_end_callback</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">Speedometer</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span> <span class="c1"># output progress for each 100 data batches</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="c1"># train for at most 10 dataset passes</span>
</pre></div>
</div>
</div>
<div class="section" id="prediction">
<h3>Prediction<a class="headerlink" href="#prediction" title="Permalink to this headline"></a></h3>
<p>After the above training completes, we can evaluate the trained model by running predictions on test data. The following source code computes the prediction probability scores for each test image. <em>prob[i][j]</em> is the probability that the <em>i</em>-th test image contains the <em>j</em>-th output class.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">test_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'test_data'</span><span class="p">],</span> <span class="bp">None</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="n">prob</span> <span class="o">=</span> <span class="n">mlp_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_iter</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">prob</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
</pre></div>
</div>
<p>Since the dataset also has labels for all test images, we can compute the accuracy metric as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">test_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'test_data'</span><span class="p">],</span> <span class="n">mnist</span><span class="p">[</span><span class="s1">'test_label'</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="c1"># predict accuracy of mlp</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span>
<span class="n">mlp_model</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">test_iter</span><span class="p">,</span> <span class="n">acc</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">acc</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">></span> <span class="mf">0.96</span>
</pre></div>
</div>
<p>If everything went well, we should see an accuracy value that is around 0.96, which means that we are able to accurately predict the digit in 96% of test images. This is a pretty good result. But as we will see in the next part of this tutorial, we can do a lot better than that.</p>
</div>
<div class="section" id="convolutional-neural-network">
<h3>Convolutional Neural Network<a class="headerlink" href="#convolutional-neural-network" title="Permalink to this headline"></a></h3>
<p>Earlier, we briefly touched on a drawback of the MLP: we need to discard the input image’s original shape and flatten it into a vector before we can feed it to the MLP’s first fully connected layer. It turns out this is an important issue, because flattening ignores the fact that pixels in an image have natural spatial correlations along the horizontal and vertical axes. A convolutional neural network (CNN) addresses this problem by using a more structured weight representation. Instead of flattening the image and doing a simple matrix-matrix multiplication, it employs one or more convolutional layers that each perform a 2-D convolution on the input image.</p>
<p>A single convolutional layer consists of one or more filters that each play the role of a feature detector. During training, a CNN learns appropriate representations (parameters) for these filters. As in the MLP, the output of the convolutional layer is transformed by applying a non-linearity. Besides the convolutional layer, another key component of a CNN is the pooling layer. A pooling layer makes the CNN more robust to small translations: a digit remains the same even when it is shifted left/right/up/down by a few pixels. A pooling layer reduces an <em>n x m</em> patch to a single value, which makes the network less sensitive to exact spatial location. In the network below, a pooling layer follows each conv (+ activation) layer.</p>
<p>The following source code defines a convolutional neural network architecture called LeNet. LeNet is a popular network known to work well on digit classification tasks. We will use a slightly different version from the original LeNet implementation, replacing the sigmoid activations with tanh activations for the neurons.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="c1"># first conv layer</span>
<span class="n">conv1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span><span class="mi">5</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">tanh1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">conv1</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span>
<span class="n">pool1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Pooling</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh1</span><span class="p">,</span> <span class="n">pool_type</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
<span class="c1"># second conv layer</span>
<span class="n">conv2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">pool1</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span><span class="mi">5</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="n">tanh2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">conv2</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span>
<span class="n">pool2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Pooling</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh2</span><span class="p">,</span> <span class="n">pool_type</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
<span class="c1"># first fullc layer</span>
<span class="n">flatten</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">pool2</span><span class="p">)</span>
<span class="n">fc1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">flatten</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">tanh3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc1</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span>
<span class="c1"># second fullc</span>
<span class="n">fc2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh3</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c1"># softmax loss</span>
<span class="n">lenet</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span>
</pre></div>
</div>
<p><img alt="First convolution and pooling layer in LeNet" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/conv_mnist.png"/></p>
<p><strong>Figure 3:</strong> First conv + pooling layer in LeNet.</p>
<p>Now we train LeNet with the same hyper-parameters as before. If a GPU is available, we recommend using it: it greatly speeds up computation, since LeNet is more complex and compute-intensive than the previous multilayer perceptron. To do so, we only need to change <code class="docutils literal"><span class="pre">mx.cpu()</span></code> to <code class="docutils literal"><span class="pre">mx.gpu()</span></code> and MXNet takes care of the rest. Just like before, we’ll stop training after 10 epochs.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># create a trainable module on CPU (use mx.gpu() instead to train on GPU 0)</span>
<span class="n">lenet_model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">lenet</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
<span class="c1"># train with the same hyper-parameters as before</span>
<span class="n">lenet_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span>
<span class="n">eval_data</span><span class="o">=</span><span class="n">val_iter</span><span class="p">,</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span><span class="mf">0.1</span><span class="p">},</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span>
<span class="n">batch_end_callback</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">Speedometer</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="id1">
<h3>Prediction<a class="headerlink" href="#id1" title="Permalink to this headline"></a></h3>
<p>Finally, we’ll use the trained LeNet model to generate predictions for the test data.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">test_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'test_data'</span><span class="p">],</span> <span class="bp">None</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="n">prob</span> <span class="o">=</span> <span class="n">lenet_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_iter</span><span class="p">)</span>
<span class="n">test_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">mnist</span><span class="p">[</span><span class="s1">'test_data'</span><span class="p">],</span> <span class="n">mnist</span><span class="p">[</span><span class="s1">'test_label'</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="c1"># predict accuracy for lenet</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span>
<span class="n">lenet_model</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">test_iter</span><span class="p">,</span> <span class="n">acc</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">acc</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">></span> <span class="mf">0.98</span>
</pre></div>
</div>
<p>If all went well, we should see a higher accuracy metric for predictions made using LeNet. With the CNN, we should be able to correctly predict around 98% of all test images.</p>
</div>
</div>
<div class="section" id="summary">
<h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline"></a></h2>
<p>In this tutorial, we have learned how to use MXNet to solve a standard computer vision problem: classifying images of handwritten digits. We have seen how to quickly and easily build, train and evaluate models such as MLP and CNN with MXNet.</p>
<div class="btn-group" role="group">
<div class="download_btn"><a download="mnist_python.ipynb" href="mnist_python.ipynb"><span class="glyphicon glyphicon-download-alt"></span> mnist_python.ipynb</a></div></div></div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Handwritten Digit Recognition</a><ul>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#loading-data">Loading Data</a></li>
<li><a class="reference internal" href="#training">Training</a><ul>
<li><a class="reference internal" href="#multilayer-perceptron">Multilayer Perceptron</a></li>
<li><a class="reference internal" href="#prediction">Prediction</a></li>
<li><a class="reference internal" href="#convolutional-neural-network">Convolutional Neural Network</a></li>
<li><a class="reference internal" href="#id1">Prediction</a></li>
</ul>
</li>
<li><a class="reference internal" href="#summary">Summary</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>