blob: e9c054fbe0737a48541a6a704c14e8027f500a69 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Module - Neural network training and inference — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css"/>
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css"/>
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<!-- Page settings (relative docs root, file suffix, ...) for the Sphinx JavaScript helpers loaded below. -->
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<!-- cdn.mathjax.org was retired in 2017; load MathJax 2.x from cdnjs instead. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- Resolve the search index against URL_ROOT rather than the domain root, so search
     keeps working when the docs are served under a sub-path (e.g. a versioned build). -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex(DOCUMENTATION_OPTIONS.URL_ROOT + "searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="../index.html" rel="up" title="Tutorials"/>
<link href="data.html" rel="next" title="Iterators - Loading data"/>
<link href="symbol.html" rel="prev" title="Symbol - Neural network graphs and auto-differentiation"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Versions(master)<span class="caret"></span></a>
<!-- NOTE(review): this id duplicates the API menu's id above (ids must be unique);
     kept unchanged because mxnet.css may style #package-dropdown-menu. -->
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li>
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li>
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li>
</ul>
</span>
</nav>
<!-- Relative path from this page to the documentation root (same value as URL_ROOT). -->
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a aria-label="Menu" class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li class="dropdown-submenu" id="dropdown-menu-position-anchor-version-mobile" style="position: relative">
<a href="#" tabindex="-1">Versions(master)</a>
<ul class="dropdown-menu">
<li><a href="http://mxnet.incubator.apache.org/test/" tabindex="-1">v0.10.14</a></li>
<li><a href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html" tabindex="-1">0.10</a></li>
<li><a href="http://mxnet.incubator.apache.org/test/versions/master/index.html" tabindex="-1">master</a></li>
</ul>
</li>
</ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i aria-hidden="true" class="glyphicon glyphicon-search"></i>
<input aria-label="Search" class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="../index.html#python">Python</a><ul class="current">
<li class="toctree-l3 current"><a class="reference internal" href="../index.html#basics">Basics</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="ndarray.html">NDArray - Imperative tensor operations on CPU/GPU</a></li>
<li class="toctree-l4"><a class="reference internal" href="symbol.html">Symbol - Neural network graphs and auto-differentiation</a></li>
<!-- aria-current exposes the "you are here" state that the .current class only shows visually. -->
<li class="toctree-l4 current"><a aria-current="page" class="current reference internal" href="">Module - Neural network training and inference</a></li>
<li class="toctree-l4"><a class="reference internal" href="data.html">Iterators - Loading data</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../index.html#training-and-inference">Training and Inference</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../index.html#contributing-tutorials">Contributing Tutorials</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="module-neural-network-training-and-inference">
<span id="module-neural-network-training-and-inference"></span><h1>Module - Neural network training and inference<a class="headerlink" href="#module-neural-network-training-and-inference" title="Permalink to this headline"></a></h1>
<p>Training a neural network involves quite a few steps. One needs to specify how
to feed input training data, initialize model parameters, perform forward and
backward passes through the network, update weights based on computed gradients,
checkpoint the model, etc. During prediction, one ends up repeating most of these
steps. All this can be quite daunting to newcomers and experienced developers
alike.</p>
<p>Luckily, MXNet modularizes commonly used code for training and inference in
the <code class="docutils literal"><span class="pre">module</span></code> (<code class="docutils literal"><span class="pre">mod</span></code> for short) package. <code class="docutils literal"><span class="pre">Module</span></code> provides both high-level and
intermediate-level interfaces for executing predefined networks. One can use
both interfaces interchangeably. We will show the usage of both interfaces in
this tutorial.</p>
<div class="section" id="prerequisites">
<span id="prerequisites"></span><h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>To complete this tutorial, we need:</p>
<ul class="simple">
<li>MXNet. See the instructions for your operating system in <a class="reference external" href="http://mxnet.io/get_started/install.html">Setup and Installation</a>.</li>
<li><a class="reference external" href="http://jupyter.org/index.html">Jupyter Notebook</a> and <a class="reference external" href="http://docs.python-requests.org/en/master/">Python Requests</a> packages.</li>
</ul>
<div class="highlight-python"><div class="highlight"><pre><span></span>pip install jupyter requests
</pre></div>
</div>
</div>
<div class="section" id="preliminary">
<span id="preliminary"></span><h2>Preliminary<a class="headerlink" href="#preliminary" title="Permalink to this headline"></a></h2>
<p>In this tutorial we will demonstrate <code class="docutils literal"><span class="pre">module</span></code> usage by training a
<a class="reference external" href="https://en.wikipedia.org/wiki/Multilayer_perceptron">Multilayer Perceptron</a> (MLP)
on the <a class="reference external" href="https://archive.ics.uci.edu/ml/datasets/letter+recognition">UCI letter recognition</a>
dataset.</p>
<p>The following code downloads the dataset and creates an 80:20 train:test
split. It also initializes a training data iterator to return a batch of 32
training examples each time. A separate iterator is also created for test data.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">logging</span>
<span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">()</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">logging</span><span class="o">.</span><span class="n">INFO</span><span class="p">)</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="s1">'http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data'</span><span class="p">)</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">genfromtxt</span><span class="p">(</span><span class="n">fname</span><span class="p">,</span> <span class="n">delimiter</span><span class="o">=</span><span class="s1">','</span><span class="p">)[:,</span><span class="mi">1</span><span class="p">:]</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">ord</span><span class="p">(</span><span class="n">l</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">','</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span><span class="o">-</span><span class="nb">ord</span><span class="p">(</span><span class="s1">'A'</span><span class="p">)</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">open</span><span class="p">(</span><span class="n">fname</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)])</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">ntrain</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">*</span><span class="mf">0.8</span><span class="p">)</span>
<span class="n">train_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">data</span><span class="p">[:</span><span class="n">ntrain</span><span class="p">,</span> <span class="p">:],</span> <span class="n">label</span><span class="p">[:</span><span class="n">ntrain</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">val_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="n">ntrain</span><span class="p">:,</span> <span class="p">:],</span> <span class="n">label</span><span class="p">[</span><span class="n">ntrain</span><span class="p">:],</span> <span class="n">batch_size</span><span class="p">)</span>
</pre></div>
</div>
<p>Next, we define the network.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc1'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'relu1'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc2'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">26</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">viz</span><span class="o">.</span><span class="n">plot_network</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="creating-a-module">
<span id="creating-a-module"></span><h2>Creating a Module<a class="headerlink" href="#creating-a-module" title="Permalink to this headline"></a></h2>
<p>Now we are ready to introduce modules. The most commonly used module class is
<code class="docutils literal"><span class="pre">Module</span></code>. We can construct a module by specifying the following parameters:</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">symbol</span></code>: the network definition</li>
<li><code class="docutils literal"><span class="pre">context</span></code>: the device (or a list of devices) to use for execution</li>
<li><code class="docutils literal"><span class="pre">data_names</span></code> : the list of input data variable names</li>
<li><code class="docutils literal"><span class="pre">label_names</span></code> : the list of input label variable names</li>
</ul>
<p>For <code class="docutils literal"><span class="pre">net</span></code>, we have only one data input, named <code class="docutils literal"><span class="pre">data</span></code>, and one label, named <code class="docutils literal"><span class="pre">softmax_label</span></code>.
The label name is generated automatically from the name <code class="docutils literal"><span class="pre">softmax</span></code> that we specified for the <code class="docutils literal"><span class="pre">SoftmaxOutput</span></code> operator.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">net</span><span class="p">,</span>
<span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span>
<span class="n">data_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'data'</span><span class="p">],</span>
<span class="n">label_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'softmax_label'</span><span class="p">])</span>
</pre></div>
</div>
</div>
<div class="section" id="intermediate-level-interface">
<span id="intermediate-level-interface"></span><h2>Intermediate-level Interface<a class="headerlink" href="#intermediate-level-interface" title="Permalink to this headline"></a></h2>
<p>We have created a module. Now let us see how to run training and inference using the module’s intermediate-level APIs. These APIs give developers the flexibility to do step-by-step
computation by running <code class="docutils literal"><span class="pre">forward</span></code> and <code class="docutils literal"><span class="pre">backward</span></code> passes. They are also useful for debugging.</p>
<p>To train a module, we need to perform the following steps:</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">bind</span></code> : Prepares environment for the computation by allocating memory.</li>
<li><code class="docutils literal"><span class="pre">init_params</span></code> : Assigns and initializes parameters.</li>
<li><code class="docutils literal"><span class="pre">init_optimizer</span></code> : Initializes optimizers. Defaults to <code class="docutils literal"><span class="pre">sgd</span></code>.</li>
<li><code class="docutils literal"><span class="pre">metric.create</span></code> : Creates evaluation metric from input metric name.</li>
<li><code class="docutils literal"><span class="pre">forward</span></code> : Forward computation.</li>
<li><code class="docutils literal"><span class="pre">update_metric</span></code> : Evaluates and accumulates evaluation metric on outputs of the last forward computation.</li>
<li><code class="docutils literal"><span class="pre">backward</span></code> : Backward computation.</li>
<li><code class="docutils literal"><span class="pre">update</span></code> : Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.</li>
</ul>
<p>This can be used as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># allocate memory given the input data and label shapes</span>
<span class="n">mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">data_shapes</span><span class="o">=</span><span class="n">train_iter</span><span class="o">.</span><span class="n">provide_data</span><span class="p">,</span> <span class="n">label_shapes</span><span class="o">=</span><span class="n">train_iter</span><span class="o">.</span><span class="n">provide_label</span><span class="p">)</span>
<span class="c1"># initialize parameters by uniform random numbers</span>
<span class="n">mod</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">initializer</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="n">scale</span><span class="o">=.</span><span class="mi">1</span><span class="p">))</span>
<span class="c1"># use SGD with learning rate 0.1 to train</span>
<span class="n">mod</span><span class="o">.</span><span class="n">init_optimizer</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span> <span class="n">optimizer_params</span><span class="o">=</span><span class="p">((</span><span class="s1">'learning_rate'</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">),</span> <span class="p">))</span>
<span class="c1"># use accuracy as the metric</span>
<span class="n">metric</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'acc'</span><span class="p">)</span>
<span class="c1"># train 5 epochs, i.e. go over the data iter five times (one pass per epoch)</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
    <span class="n">train_iter</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
    <span class="n">metric</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">train_iter</span><span class="p">:</span>
        <span class="n">mod</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>       <span class="c1"># compute predictions</span>
        <span class="n">mod</span><span class="o">.</span><span class="n">update_metric</span><span class="p">(</span><span class="n">metric</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">label</span><span class="p">)</span>  <span class="c1"># accumulate prediction accuracy</span>
        <span class="n">mod</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>                          <span class="c1"># compute gradients</span>
        <span class="n">mod</span><span class="o">.</span><span class="n">update</span><span class="p">()</span>                            <span class="c1"># update parameters</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">'Epoch </span><span class="si">%d</span><span class="s1">, Training </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">metric</span><span class="o">.</span><span class="n">get</span><span class="p">()))</span>
</pre></div>
</div>
<p>To learn more about these APIs, visit <a class="reference external" href="http://mxnet.io/api/python/module.html">Module API</a>.</p>
</div>
<div class="section" id="high-level-interface">
<span id="high-level-interface"></span><h2>High-level Interface<a class="headerlink" href="#high-level-interface" title="Permalink to this headline"></a></h2>
<div class="section" id="train">
<span id="train"></span><h3>Train<a class="headerlink" href="#train" title="Permalink to this headline"></a></h3>
<p>Module also provides high-level APIs for training, predicting and evaluating for
user convenience. Instead of doing all the steps mentioned in the above section,
one can simply call the <a class="reference external" href="http://mxnet.io/api/python/module.html#mxnet.module.BaseModule.fit">fit API</a>,
which internally executes the same steps.</p>
<p>To fit a module, call the <code class="docutils literal"><span class="pre">fit</span></code> function as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># reset train_iter to the beginning</span>
<span class="n">train_iter</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="c1"># create a module</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">net</span><span class="p">,</span>
<span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span>
<span class="n">data_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'data'</span><span class="p">],</span>
<span class="n">label_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'softmax_label'</span><span class="p">])</span>
<span class="c1"># fit the module</span>
<span class="n">mod</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span>
<span class="n">eval_data</span><span class="o">=</span><span class="n">val_iter</span><span class="p">,</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span><span class="mf">0.1</span><span class="p">},</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
</pre></div>
</div>
<p>By default, the <code class="docutils literal"><span class="pre">fit</span></code> function sets <code class="docutils literal"><span class="pre">eval_metric</span></code> to <code class="docutils literal"><span class="pre">acc</span></code> (accuracy), <code class="docutils literal"><span class="pre">optimizer</span></code> to <code class="docutils literal"><span class="pre">sgd</span></code>,
and <code class="docutils literal"><span class="pre">optimizer_params</span></code> to <code class="docutils literal"><span class="pre">(('learning_rate',</span> <span class="pre">0.01),)</span></code>.</p>
</div>
<div class="section" id="predict-and-evaluate">
<h3>Predict and Evaluate<a class="headerlink" href="#predict-and-evaluate" title="Permalink to this headline"></a></h3>
<p>To predict with a module, we can call <code class="docutils literal"><span class="pre">predict()</span></code>. It collects and
returns all the prediction results.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">y</span> <span class="o">=</span> <span class="n">mod</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">val_iter</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">4000</span><span class="p">,</span> <span class="mi">26</span><span class="p">)</span>
</pre></div>
</div>
<p>If we do not need the prediction outputs but just need to evaluate on a test
set, we can call the <code class="docutils literal"><span class="pre">score()</span></code> function. It runs prediction on the input validation
dataset and evaluates the performance according to the given metric(s).</p>
<p>It can be used as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">score</span> <span class="o">=</span> <span class="n">mod</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">val_iter</span><span class="p">,</span> <span class="p">[</span><span class="s1">'acc'</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s2">"Accuracy score is </span><span class="si">%f</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">score</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
</pre></div>
</div>
<p>Some of the other metrics that can be used are <code class="docutils literal"><span class="pre">top_k_acc</span></code> (top-k accuracy),
<code class="docutils literal"><span class="pre">F1</span></code>, <code class="docutils literal"><span class="pre">RMSE</span></code>, <code class="docutils literal"><span class="pre">MSE</span></code>, <code class="docutils literal"><span class="pre">MAE</span></code>, and <code class="docutils literal"><span class="pre">ce</span></code> (cross-entropy). To learn more about the metrics,
visit <a class="reference external" href="http://mxnet.io/api/python/metric.html">Evaluation metric</a>.</p>
<p>One can vary the number of epochs, the learning rate, and other optimizer parameters to change the score,
and tune these parameters to get the best score.</p>
</div>
<div class="section" id="save-and-load">
<h3>Save and Load<a class="headerlink" href="#save-and-load" title="Permalink to this headline"></a></h3>
<p>We can save the module parameters after each training epoch by using a checkpoint callback.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># construct a callback function to save checkpoints</span>
<span class="n">model_prefix</span> <span class="o">=</span> <span class="s1">'mx_mlp'</span>
<span class="n">checkpoint</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">do_checkpoint</span><span class="p">(</span><span class="n">model_prefix</span><span class="p">)</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">net</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span> <span class="n">num_epoch</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">epoch_end_callback</span><span class="o">=</span><span class="n">checkpoint</span><span class="p">)</span>
</pre></div>
</div>
<p>To load the saved module parameters, call the <code class="docutils literal"><span class="pre">load_checkpoint</span></code> function with the model
prefix and the epoch number to load (epoch 3 in the example below). It loads the Symbol and the associated
parameters. We can then set the loaded parameters into the module.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="n">model_prefix</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">sym</span><span class="o">.</span><span class="n">tojson</span><span class="p">()</span> <span class="o">==</span> <span class="n">net</span><span class="o">.</span><span class="n">tojson</span><span class="p">()</span>
<span class="c1"># assign the loaded parameters to the module</span>
<span class="n">mod</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
</pre></div>
</div>
<p>Alternatively, if we just want to resume training from a saved checkpoint, instead of
calling <code class="docutils literal"><span class="pre">set_params()</span></code>, we can directly call <code class="docutils literal"><span class="pre">fit()</span></code> and pass it the loaded
parameters, so that <code class="docutils literal"><span class="pre">fit()</span></code> starts from those parameters instead of
initializing them randomly from scratch. We also set the <code class="docutils literal"><span class="pre">begin_epoch</span></code> parameter so that
<code class="docutils literal"><span class="pre">fit()</span></code> knows we are resuming from a previously saved epoch.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">sym</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="n">arg_params</span><span class="p">,</span>
<span class="n">aux_params</span><span class="o">=</span><span class="n">aux_params</span><span class="p">,</span>
<span class="n">begin_epoch</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
</pre></div>
</div>
<div class="btn-group" role="group">
<div class="download_btn"><a download="module_python.ipynb" href="module_python.ipynb"><span class="glyphicon glyphicon-download-alt"></span> module_python.ipynb</a></div></div></div>
</div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Module - Neural network training and inference</a><ul>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#preliminary">Preliminary</a></li>
<li><a class="reference internal" href="#creating-a-module">Creating a Module</a></li>
<li><a class="reference internal" href="#intermediate-level-interface">Intermediate-level Interface</a></li>
<li><a class="reference internal" href="#high-level-interface">High-level Interface</a><ul>
<li><a class="reference internal" href="#train">Train</a></li>
<li><a class="reference internal" href="#predict-and-evaluate">Predict and Evaluate</a></li>
<li><a class="reference internal" href="#save-and-load">Save and Load</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
// Make <body> visible once the DOM is ready.
// NOTE(review): presumably the stylesheet starts <body> hidden to avoid a
// flash of unstyled content; confirm against mxnet.css.
// `$(handler)` is jQuery's recommended shorthand for a DOM-ready callback.
// The original `$('body').ready(handler)` form is equivalent but deprecated
// as of jQuery 3.0.
$(function () {
  $('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>