blob: f55f1de5e3055f911488ee11ed12991bddcff6cb [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Basics — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Global settings object for this page. It is declared before the _static
// scripts below are loaded; presumably those scripts read it (TODO confirm).
var DOCUMENTATION_OPTIONS = {
  URL_ROOT: "../../",      // same relative prefix the page uses for its _static/ assets
  VERSION: "",
  COLLAPSE_INDEX: false,
  FILE_SUFFIX: ".html",
  HAS_SOURCE: true,
  SOURCELINK_SUFFIX: ""
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript">
// When the DOM is ready, point Search at the index file and initialise it.
jQuery(document).ready(function () {
  Search.loadIndex("/searchindex.js");
  Search.init();
});
</script>
<script>
// Google Analytics loader snippet (analytics.js).
// The IIFE records the global name 'ga' in window.GoogleAnalyticsObject, installs a
// stub window.ga that pushes its arguments onto ga.q (unless ga already exists),
// stamps ga.l with the current time, and inserts an async <script> element for
// analytics.js immediately before the first <script> element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Issue a 'create' command for property UA-96378503-1.
ga('create', 'UA-96378503-1', 'auto');
// Issue a 'send' command for a pageview hit.
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></span></nav>
<script>
// Returns the relative path from this page to the documentation root.
function getRootPath() {
  return "../../";
}
</script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="basics">
<span id="basics"></span><h1>Basics<a class="headerlink" href="#basics" title="Permalink to this headline"></a></h1>
<p>This tutorial introduces the basic usage of the C++ package, using the classic handwritten digit
recognition dataset, <a class="reference external" href="http://yann.lecun.com/exdb/mnist/">MNIST</a>.</p>
<p>The following contents assume that the working directory is <code class="docutils literal"><span class="pre">/path/to/mxnet/cpp-package/example</span></code>.</p>
<div class="section" id="load-data">
<span id="load-data"></span><h2>Load Data<a class="headerlink" href="#load-data" title="Permalink to this headline"></a></h2>
<p>Before diving into the code, we need to fetch the MNIST data. You can either use the script <code class="docutils literal"><span class="pre">get_mnist.sh</span></code>,
or download the MNIST data yourself from LeCun’s <a class="reference external" href="http://yann.lecun.com/exdb/mnist/">website</a>
and decompress it into the <code class="docutils literal"><span class="pre">mnist_data</span></code> folder.</p>
<p>Apart from linking against the MXNet shared library, the C++ package itself is a header-only package,
which means all you need to do is include the header files. Among the header files,
<code class="docutils literal"><span class="pre">op.h</span></code> is special since it is generated dynamically. The generation should be done when
<a class="reference external" href="http://mxnet.io/get_started/build_from_source.html#build-the-c++-package">building the C++ package</a>.
After that, you also need to copy the shared library (<code class="docutils literal"><span class="pre">libmxnet.so</span></code> on Linux,
<code class="docutils literal"><span class="pre">libmxnet.dll</span></code> on Windows) from <code class="docutils literal"><span class="pre">/path/to/mxnet/lib</span></code> to the working directory.
We do not recommend using pre-built binaries: because MXNet is under heavy development,
the operator definitions in <code class="docutils literal"><span class="pre">op.h</span></code> may be incompatible with the pre-built version.</p>
<p>In order to use the functionality provided by the C++ package, we first include the general
header file <code class="docutils literal"><span class="pre">MxNetCpp.h</span></code> and specify the namespaces.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="cp">#include</span> <span class="cpf">"mxnet-cpp/MxNetCpp.h"</span><span class="cp"></span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">std</span><span class="p">;</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">cpp</span><span class="p">;</span>
</pre></div>
</div>
<p>Next we can use the data iterators to load the MNIST data (separated into training and validation sets).
The digits in MNIST are 2-dimensional arrays, so we set <code class="docutils literal"><span class="pre">flat</span></code> to true to flatten the data.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">auto</span> <span class="n">train_iter</span> <span class="o">=</span> <span class="n">MXDataIter</span><span class="p">(</span><span class="s">"MNISTIter"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"image"</span><span class="p">,</span> <span class="s">"./mnist_data/train-images-idx3-ubyte"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"./mnist_data/train-labels-idx1-ubyte"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"batch_size"</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"flat"</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">CreateDataIter</span><span class="p">();</span>
<span class="k">auto</span> <span class="n">val_iter</span> <span class="o">=</span> <span class="n">MXDataIter</span><span class="p">(</span><span class="s">"MNISTIter"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"image"</span><span class="p">,</span> <span class="s">"./mnist_data/t10k-images-idx3-ubyte"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"label"</span><span class="p">,</span> <span class="s">"./mnist_data/t10k-labels-idx1-ubyte"</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"batch_size"</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="p">.</span><span class="n">SetParam</span><span class="p">(</span><span class="s">"flat"</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">CreateDataIter</span><span class="p">();</span>
</pre></div>
</div>
<p>Now that the data have been successfully loaded, we can easily construct various models to identify
the digits with the help of the C++ package.</p>
</div>
<div class="section" id="multilayer-perceptron">
<span id="multilayer-perceptron"></span><h2>Multilayer Perceptron<a class="headerlink" href="#multilayer-perceptron" title="Permalink to this headline"></a></h2>
<p>If you are not familiar with multilayer perceptron, you can get some basic information
<a class="reference external" href="http://mxnet.io/tutorials/python/mnist.html#multilayer-perceptron">here</a>. We only focus on
the implementation in this tutorial.</p>
<p>Constructing a multilayer perceptron model is straightforward. Assume we store the hidden size
for each layer in <code class="docutils literal"><span class="pre">layers</span></code>, and each layer uses the
<a class="reference external" href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">ReLU</a> function as its activation.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">Symbol</span> <span class="nf">mlp</span><span class="p">(</span><span class="k">const</span> <span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">layers</span><span class="p">)</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">x</span> <span class="o">=</span> <span class="n">Symbol</span><span class="o">::</span><span class="n">Variable</span><span class="p">(</span><span class="s">"X"</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">label</span> <span class="o">=</span> <span class="n">Symbol</span><span class="o">::</span><span class="n">Variable</span><span class="p">(</span><span class="s">"label"</span><span class="p">);</span>
<span class="n">vector</span><span class="o">&lt;</span><span class="n">Symbol</span><span class="o">&gt;</span> <span class="n">weights</span><span class="p">(</span><span class="n">layers</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
<span class="n">vector</span><span class="o">&lt;</span><span class="n">Symbol</span><span class="o">&gt;</span> <span class="n">biases</span><span class="p">(</span><span class="n">layers</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
<span class="n">vector</span><span class="o">&lt;</span><span class="n">Symbol</span><span class="o">&gt;</span> <span class="n">outputs</span><span class="p">(</span><span class="n">layers</span><span class="p">.</span><span class="n">size</span><span class="p">());</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">i</span><span class="o">&lt;</span><span class="n">layers</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
<span class="n">weights</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">Symbol</span><span class="o">::</span><span class="n">Variable</span><span class="p">(</span><span class="s">"w"</span> <span class="o">+</span> <span class="n">to_string</span><span class="p">(</span><span class="n">i</span><span class="p">));</span>
<span class="n">biases</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">Symbol</span><span class="o">::</span><span class="n">Variable</span><span class="p">(</span><span class="s">"b"</span> <span class="o">+</span> <span class="n">to_string</span><span class="p">(</span><span class="n">i</span><span class="p">));</span>
<span class="n">Symbol</span> <span class="n">fc</span> <span class="o">=</span> <span class="n">FullyConnected</span><span class="p">(</span>
<span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="o">?</span> <span class="nl">x</span> <span class="p">:</span> <span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span>
<span class="n">weights</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">biases</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">layers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="p">);</span>
<span class="n">outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span> <span class="o">==</span> <span class="n">layers</span><span class="p">.</span><span class="n">size</span><span class="p">()</span><span class="o">-</span><span class="mi">1</span> <span class="o">?</span> <span class="nl">fc</span> <span class="p">:</span> <span class="n">Activation</span><span class="p">(</span><span class="n">fc</span><span class="p">,</span> <span class="n">ActivationActType</span><span class="o">::</span><span class="n">relu</span><span class="p">);</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">back</span><span class="p">(),</span> <span class="n">label</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
<p>The above function defines a multilayer perceptron model where hidden sizes are specified
by <code class="docutils literal"><span class="pre">layers</span></code>.</p>
<p>We now create and initialize the parameters after the model is constructed. MXNet can help
you to infer shapes of most of the parameters. Basically only the shape of data and label
is needed.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">std</span><span class="o">::</span><span class="n">map</span><span class="o">&lt;</span><span class="n">string</span><span class="p">,</span> <span class="n">NDArray</span><span class="o">&gt;</span> <span class="n">args</span><span class="p">;</span>
<span class="n">args</span><span class="p">[</span><span class="s">"X"</span><span class="p">]</span> <span class="o">=</span> <span class="n">NDArray</span><span class="p">(</span><span class="n">Shape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="o">*</span><span class="n">image_size</span><span class="p">),</span> <span class="n">ctx</span><span class="p">);</span>
<span class="n">args</span><span class="p">[</span><span class="s">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">NDArray</span><span class="p">(</span><span class="n">Shape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">ctx</span><span class="p">);</span>
<span class="c1">// Let MXNet infer the shapes of other parameters such as weights</span>
<span class="n">net</span><span class="p">.</span><span class="n">InferArgsMap</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">args</span><span class="p">,</span> <span class="n">args</span><span class="p">);</span>
<span class="c1">// Initialize all parameters with uniform distribution U(-0.01, 0.01)</span>
<span class="k">auto</span> <span class="n">initializer</span> <span class="o">=</span> <span class="n">Uniform</span><span class="p">(</span><span class="mf">0.01</span><span class="p">);</span>
<span class="k">for</span> <span class="p">(</span><span class="k">auto</span><span class="o">&amp;</span> <span class="nl">arg</span> <span class="p">:</span> <span class="n">args</span><span class="p">)</span> <span class="p">{</span>
<span class="c1">// arg.first is parameter name, and arg.second is the value</span>
<span class="n">initializer</span><span class="p">(</span><span class="n">arg</span><span class="p">.</span><span class="n">first</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">arg</span><span class="p">.</span><span class="n">second</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
<p>The rest is to train the model with an optimizer.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="c1">// Create sgd optimizer</span>
<span class="n">Optimizer</span><span class="o">*</span> <span class="n">opt</span> <span class="o">=</span> <span class="n">OptimizerRegistry</span><span class="o">::</span><span class="n">Find</span><span class="p">(</span><span class="s">"sgd"</span><span class="p">);</span>
<span class="n">opt</span><span class="o">-></span><span class="n">SetParam</span><span class="p">(</span><span class="s">"rescale_grad"</span><span class="p">,</span> <span class="mf">1.0</span><span class="o">/</span><span class="n">batch_size</span><span class="p">);</span>
<span class="c1">// Start training</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">iter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">iter</span> <span class="o">&lt;</span> <span class="n">max_epoch</span><span class="p">;</span> <span class="o">++</span><span class="n">iter</span><span class="p">)</span> <span class="p">{</span>
<span class="n">train_iter</span><span class="p">.</span><span class="n">Reset</span><span class="p">();</span>
<span class="k">while</span> <span class="p">(</span><span class="n">train_iter</span><span class="p">.</span><span class="n">Next</span><span class="p">())</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">data_batch</span> <span class="o">=</span> <span class="n">train_iter</span><span class="p">.</span><span class="n">GetDataBatch</span><span class="p">();</span>
<span class="c1">// Set data and label</span>
<span class="n">args</span><span class="p">[</span><span class="s">"X"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">data</span><span class="p">;</span>
<span class="n">args</span><span class="p">[</span><span class="s">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">label</span><span class="p">;</span>
<span class="c1">// Create executor by binding parameters to the model</span>
<span class="k">auto</span> <span class="o">*</span><span class="n">exec</span> <span class="o">=</span> <span class="n">net</span><span class="p">.</span><span class="n">SimpleBind</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="p">);</span>
<span class="c1">// Compute gradients</span>
<span class="n">exec</span><span class="o">-></span><span class="n">Forward</span><span class="p">(</span><span class="nb">true</span><span class="p">);</span>
<span class="n">exec</span><span class="o">-></span><span class="n">Backward</span><span class="p">();</span>
<span class="c1">// Update parameters</span>
<span class="n">exec</span><span class="o">-></span><span class="n">UpdateAll</span><span class="p">(</span><span class="n">opt</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">);</span>
<span class="c1">// Remember to free the memory</span>
<span class="k">delete</span> <span class="n">exec</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
<p>We also want to see how our model performs. The C++ package provides convenient APIs for
evaluation. Here we use accuracy as the metric. Inference is almost the same as training,
except that we don’t need gradients.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">Accuracy</span> <span class="n">acc</span><span class="p">;</span>
<span class="n">val_iter</span><span class="p">.</span><span class="n">Reset</span><span class="p">();</span>
<span class="k">while</span> <span class="p">(</span><span class="n">val_iter</span><span class="p">.</span><span class="n">Next</span><span class="p">())</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">data_batch</span> <span class="o">=</span> <span class="n">val_iter</span><span class="p">.</span><span class="n">GetDataBatch</span><span class="p">();</span>
<span class="n">args</span><span class="p">[</span><span class="s">"X"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">data</span><span class="p">;</span>
<span class="n">args</span><span class="p">[</span><span class="s">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">label</span><span class="p">;</span>
<span class="k">auto</span> <span class="o">*</span><span class="n">exec</span> <span class="o">=</span> <span class="n">net</span><span class="p">.</span><span class="n">SimpleBind</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="p">);</span>
<span class="c1">// Forward pass is enough as no gradient is needed when evaluating</span>
<span class="n">exec</span><span class="o">-></span><span class="n">Forward</span><span class="p">(</span><span class="nb">false</span><span class="p">);</span>
<span class="n">acc</span><span class="p">.</span><span class="n">Update</span><span class="p">(</span><span class="n">data_batch</span><span class="p">.</span><span class="n">label</span><span class="p">,</span> <span class="n">exec</span><span class="o">-></span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
<span class="k">delete</span> <span class="n">exec</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>You can find the complete code in <code class="docutils literal"><span class="pre">mlp_cpu.cpp</span></code>. Use <code class="docutils literal"><span class="pre">make</span> <span class="pre">mlp_cpu</span></code> to compile it,
and <code class="docutils literal"><span class="pre">./mlp_cpu</span></code> to run it.</p>
</div>
<div class="section" id="gpu-support">
<span id="gpu-support"></span><h2>GPU Support<a class="headerlink" href="#gpu-support" title="Permalink to this headline"></a></h2>
<p>It’s worth noting that changing the context from <code class="docutils literal"><span class="pre">Context::cpu()</span></code> to <code class="docutils literal"><span class="pre">Context::gpu()</span></code> is not enough:
because the data read by the data iterator are stored in main (CPU) memory, we cannot assign them directly to the
parameters. To bridge this gap, NDArray provides data synchronization functionality between
GPU and CPU. We will illustrate it by making the MLP code run on GPU.</p>
<p>In the previous code, data are used like</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">args</span><span class="p">[</span><span class="s">"X"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">data</span><span class="p">;</span>
<span class="n">args</span><span class="p">[</span><span class="s">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_batch</span><span class="p">.</span><span class="n">label</span><span class="p">;</span>
</pre></div>
</div>
<p>It will be problematic if other parameters are created in the context of GPU. We can use
<code class="docutils literal"><span class="pre">NDArray::CopyTo</span></code> to solve this problem.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="c1">// Data provided by DataIter are stored in memory, should be copied to GPU first.</span>
<span class="n">data_batch</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">CopyTo</span><span class="p">(</span><span class="o">&amp;</span><span class="n">args</span><span class="p">[</span><span class="s">"X"</span><span class="p">]);</span>
<span class="n">data_batch</span><span class="p">.</span><span class="n">label</span><span class="p">.</span><span class="n">CopyTo</span><span class="p">(</span><span class="o">&amp;</span><span class="n">args</span><span class="p">[</span><span class="s">"label"</span><span class="p">]);</span>
<span class="c1">// CopyTo is imperative, need to wait for it to complete.</span>
<span class="n">NDArray</span><span class="o">::</span><span class="n">WaitAll</span><span class="p">();</span>
</pre></div>
</div>
<p>By replacing the former code with the latter, we successfully port the code to GPU. You can find the complete code in <code class="docutils literal"><span class="pre">mlp_gpu.cpp</span></code>. Compilation is similar to the CPU version. (Note: the shared library must be built with GPU support enabled.)</p>
</div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Basics</a><ul>
<li><a class="reference internal" href="#load-data">Load Data</a></li>
<li><a class="reference internal" href="#multilayer-perceptron">Multilayer Perceptron</a></li>
<li><a class="reference internal" href="#gpu-support">GPU Support</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
// Once the DOM is ready, set the body's CSS visibility to 'visible'.
$(document).ready(function () {
  var pageBody = $('body');
  pageBody.css('visibility', 'visible');
});
</script>
</div></body>
</html>