blob: 0e63beaf402feed4add653328e1dc0decbbda485 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Linear Regression — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Page-level settings for the Sphinx theme scripts loaded below
// (presumably consumed by doctools.js / searchtools_custom.js — confirm).
// URL_ROOT is the relative path from this page back to the documentation root.
var DOCUMENTATION_OPTIONS = {
  URL_ROOT: "../../",
  VERSION: "",
  COLLAPSE_INDEX: false,
  FILE_SUFFIX: ".html",
  HAS_SOURCE: true,
  SOURCELINK_SUFFIX: ""
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript">
  // Once the DOM is ready, load the search index and initialise search.
  // The index path is built from DOCUMENTATION_OPTIONS.URL_ROOT (the relative
  // path to the docs root defined earlier in <head>) instead of the server-absolute
  // "/searchindex.js". This resolves to the same file when the docs are served from
  // the server root, and still finds the right index when they are hosted under a
  // sub-path (e.g. a versioned copy).
  jQuery(function () {
    Search.loadIndex(DOCUMENTATION_OPTIONS.URL_ROOT + "searchindex.js");
    Search.init();
  });
</script>
<script>
// Google Analytics (analytics.js) loader snippet, kept verbatim from the vendor.
// It defines a global `ga` function that queues its arguments in `ga.q` until the
// real library loads, records a load timestamp in `ga.l`, and inserts an async
// <script> for analytics.js before the first <script> element on the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queued commands: create a tracker for property UA-96378503-1, then send a pageview.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../index.html" rel="up" title="Tutorials">
<link href="mnist.html" rel="next" title="Handwritten Digit Recognition"/>
<link href="../basic/data.html" rel="prev" title="Iterators - Loading data"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></span></nav>
<script>
  // Relative path from this page back to the documentation root
  // (same value as the "../../" prefix used by the page's other links).
  function getRootPath() {
    var rootPath = "../../";
    return rootPath;
  }
</script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="../index.html#python">Python</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../index.html#basics">Basics</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html#training-and-inference">Training and Inference</a><ul class="current">
<li class="toctree-l4 current"><a class="current reference internal" href="">Linear Regression</a></li>
<li class="toctree-l4"><a class="reference internal" href="mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l4"><a class="reference internal" href="predict_image.html">Predict with pre-trained models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../vision/large_scale_classification.html">Large Scale Image Classification</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../index.html#contributing-tutorials">Contributing Tutorials</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="linear-regression">
<h1>Linear Regression<a class="headerlink" href="#linear-regression" title="Permalink to this headline"></a></h1>
<p>In this tutorial we’ll walk through how one can implement <em>linear regression</em> using MXNet APIs.</p>
<p>The function we are trying to learn is: <em>y = x<sub>1</sub> + 2x<sub>2</sub></em>, where <em>(x<sub>1</sub>,x<sub>2</sub>)</em> are input features and <em>y</em> is the corresponding label.</p>
<div class="section" id="prerequisites">
<h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>To complete this tutorial, we need:</p>
<ul class="simple">
<li>MXNet. See the instructions for your operating system in <a class="reference external" href="http://mxnet.io/get_started/install.html">Setup and Installation</a>.</li>
<li><a class="reference external" href="http://jupyter.org/index.html">Jupyter Notebook</a>.</li>
</ul>
<div class="highlight-python"><div class="highlight"><pre><span></span>$ pip install jupyter
</pre></div>
</div>
<p>To begin, the following code imports the necessary packages we’ll need for this exercise.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
</pre></div>
</div>
</div>
<div class="section" id="preparing-the-data">
<h2>Preparing the Data<a class="headerlink" href="#preparing-the-data" title="Permalink to this headline"></a></h2>
<p>In MXNet, data is fed to a model through <strong>Data Iterators</strong>. Here we will illustrate
how to encode a dataset into an iterator that MXNet can use. The data used in this example is made up of 2D data points with corresponding real-valued labels.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1">#Training data</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">[</span><span class="mi">100</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="n">train_label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">train_data</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">train_data</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">)])</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span>
<span class="c1">#Evaluation Data</span>
<span class="n">eval_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">7</span><span class="p">,</span><span class="mi">2</span><span class="p">],[</span><span class="mi">6</span><span class="p">,</span><span class="mi">10</span><span class="p">],[</span><span class="mi">12</span><span class="p">,</span><span class="mi">2</span><span class="p">]])</span>
<span class="n">eval_label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">11</span><span class="p">,</span><span class="mi">26</span><span class="p">,</span><span class="mi">16</span><span class="p">])</span>
</pre></div>
</div>
<p>Once we have the data ready, we need to put it into an iterator and specify
parameters such as <code class="docutils literal"><span class="pre">batch_size</span></code> and <code class="docutils literal"><span class="pre">shuffle</span></code>. <code class="docutils literal"><span class="pre">batch_size</span></code> specifies the number
of examples shown to the model each time we update its parameters and <code class="docutils literal"><span class="pre">shuffle</span></code>
tells the iterator to randomize the order in which examples are shown to the model.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">train_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span><span class="n">train_label</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span><span class="n">label_name</span><span class="o">=</span><span class="s1">'lin_reg_label'</span><span class="p">)</span>
<span class="n">eval_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">eval_data</span><span class="p">,</span> <span class="n">eval_label</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</pre></div>
</div>
<p>In the above example, we have made use of <code class="docutils literal"><span class="pre">NDArrayIter</span></code>, which is useful for iterating
over both numpy ndarrays and MXNet NDArrays. In general, there are different types of iterators in
MXNet and you can use one based on the type of data you are processing.
Documentation for iterators can be found <a class="reference external" href="http://mxnet.io/api/python/io.html">here</a>.</p>
</div>
<div class="section" id="mxnet-classes">
<h2>MXNet Classes<a class="headerlink" href="#mxnet-classes" title="Permalink to this headline"></a></h2>
<ol class="simple">
<li><strong>IO:</strong> As we have already seen, the IO class works on the data and carries out
operations such as feeding data in batches and shuffling.</li>
<li><strong>Symbol:</strong> The actual MXNet neural network is composed using symbols. MXNet has
different types of symbols, including variable placeholders for input data,
neural network layers, and operators that manipulate NDArrays.</li>
<li><strong>Module:</strong> The module class in MXNet is used to define the overall computation.
It is initialized with the model (symbol) we want to train and the names of its data and label inputs;
the training data itself, the optimization algorithm and parameters such as the learning rate
are supplied later, when the model is trained with <code class="docutils literal"><span class="pre">fit()</span></code>.</li>
</ol>
</div>
<div class="section" id="defining-the-model">
<h2>Defining the Model<a class="headerlink" href="#defining-the-model" title="Permalink to this headline"></a></h2>
<p>MXNet uses <strong>Symbols</strong> for defining a model. Symbols are the building blocks
and make up various components of the model. Symbols are used to define:</p>
<ol class="simple">
<li><strong>Variables:</strong> A variable is a placeholder for future data. This symbol is used
to define a spot which will be filled with training data/labels in the future
when we commence training.</li>
<li><strong>Neural Network Layers:</strong> The layers of a network or any other type of model are
also defined by Symbols. Such a symbol takes one or more previous symbols as
inputs, performs some transformations on them, and creates one or more outputs.
One such example is the <code class="docutils literal"><span class="pre">FullyConnected</span></code> symbol which specifies a fully connected
layer of a neural network.</li>
<li><strong>Outputs:</strong> Output symbols are MXNet’s way of defining a loss. They are
suffixed with the word “Output” (e.g. the <code class="docutils literal"><span class="pre">SoftmaxOutput</span></code> layer). You can also
<a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/docs/tutorials/r/CustomLossFunction.md#how-to-use-your-own-loss-function">create your own loss function</a>.
Some examples of existing losses are: <code class="docutils literal"><span class="pre">LinearRegressionOutput</span></code>, which computes
the l2-loss between its input symbol and the labels provided to it;
<code class="docutils literal"><span class="pre">SoftmaxOutput</span></code>, which computes the categorical cross-entropy.</li>
</ol>
<p>The ones described above and other symbols are chained together with the output of
one symbol serving as input to the next to build the network topology. More information
about the different types of symbols can be found <a class="reference external" href="http://mxnet.io/api/python/symbol.html">here</a>.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">X</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="n">Y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'lin_reg_label'</span><span class="p">)</span>
<span class="n">fully_connected_layer</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc1'</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">lro</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LinearRegressionOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fully_connected_layer</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">Y</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"lro"</span><span class="p">)</span>
</pre></div>
</div>
<p>The above network uses the following layers:</p>
<ol class="simple">
<li><code class="docutils literal"><span class="pre">FullyConnected</span></code>: The fully connected symbol represents a fully connected layer
of a neural network (without any activation being applied), which in essence,
is just a linear regression on the input attributes. It takes the following
parameters:<ul>
<li><code class="docutils literal"><span class="pre">data</span></code>: Input to the layer (specifies the symbol whose output should be fed here)</li>
<li><code class="docutils literal"><span class="pre">num_hidden</span></code>: Number of hidden neurons in the layer, which is the same as the dimensionality
of the layer’s output</li>
</ul>
</li>
<li><code class="docutils literal"><span class="pre">LinearRegressionOutput</span></code>: Output layers in MXNet compute training loss, which is
the measure of inaccuracy in the model’s predictions. The goal of training is to minimize the
training loss. In our example, the <code class="docutils literal"><span class="pre">LinearRegressionOutput</span></code> layer computes the <em>l2</em> loss against
its input and the labels provided to it. The parameters to this layer are:<ul>
<li><code class="docutils literal"><span class="pre">data</span></code>: Input to this layer (specifies the symbol whose output should be fed here)</li>
<li><code class="docutils literal"><span class="pre">label</span></code>: The training labels against which we will compare the input to the layer for calculation of l2 loss</li>
</ul>
</li>
</ol>
<p><strong>Note on naming convention:</strong> the label variable’s name should be the same as the
<code class="docutils literal"><span class="pre">label_name</span></code> parameter passed to your training data iterator. The default value of
this is <code class="docutils literal"><span class="pre">softmax_label</span></code>, but we have updated it to <code class="docutils literal"><span class="pre">lin_reg_label</span></code> in this
tutorial as you can see in <code class="docutils literal"><span class="pre">Y</span> <span class="pre">=</span> <span class="pre">mx.symbol.Variable('lin_reg_label')</span></code> and
<code class="docutils literal"><span class="pre">train_iter</span> <span class="pre">=</span> <span class="pre">mx.io.NDArrayIter(...,</span> <span class="pre">label_name='lin_reg_label')</span></code>.</p>
<p>Finally, the network is passed to a <em>Module</em>, where we specify the symbol
whose output needs to be minimized (in our case, <code class="docutils literal"><span class="pre">lro</span></code>, the output of the <code class="docutils literal"><span class="pre">LinearRegressionOutput</span></code> layer)
together with the names of its data and label inputs. The learning rate and the number of epochs
we want to train our model for are passed later, to <code class="docutils literal"><span class="pre">fit()</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span>
<span class="n">symbol</span> <span class="o">=</span> <span class="n">lro</span> <span class="p">,</span>
<span class="n">data_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'data'</span><span class="p">],</span>
<span class="n">label_names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'lin_reg_label'</span><span class="p">]</span><span class="c1"># network structure</span>
<span class="p">)</span>
</pre></div>
</div>
<p>We can visualize the network we created by plotting it:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">mx</span><span class="o">.</span><span class="n">viz</span><span class="o">.</span><span class="n">plot_network</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">lro</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="training-the-model">
<h2>Training the model<a class="headerlink" href="#training-the-model" title="Permalink to this headline"></a></h2>
<p>Once we have defined the model structure, the next step is to train the
parameters of the model to fit the training data. This is accomplished using the
<code class="docutils literal"><span class="pre">fit()</span></code> function of the <code class="docutils literal"><span class="pre">Module</span></code> class.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_iter</span><span class="p">,</span> <span class="n">eval_iter</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span><span class="mf">0.005</span><span class="p">,</span> <span class="s1">'momentum'</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">},</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
<span class="n">batch_end_callback</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">Speedometer</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="section" id="using-a-trained-model-testing-and-inference">
<h2>Using a trained model: (Testing and Inference)<a class="headerlink" href="#using-a-trained-model-testing-and-inference" title="Permalink to this headline"></a></h2>
<p>Once we have a trained model, we can do a couple of things with it: we can either
use it for inference or evaluate it on test data. The former (inference) is shown below:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">eval_iter</span><span class="p">)</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
</pre></div>
</div>
<p>We can also evaluate our model according to some metric. In this example, we are
evaluating our model’s mean squared error (MSE) on the evaluation data.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">metric</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">MSE</span><span class="p">()</span>
<span class="n">model</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">eval_iter</span><span class="p">,</span> <span class="n">metric</span><span class="p">)</span>
</pre></div>
</div>
<p>Let us try adding some noise to the evaluation labels (here, a constant offset of 0.1) and see how the MSE changes:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">eval_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">7</span><span class="p">,</span><span class="mi">2</span><span class="p">],[</span><span class="mi">6</span><span class="p">,</span><span class="mi">10</span><span class="p">],[</span><span class="mi">12</span><span class="p">,</span><span class="mi">2</span><span class="p">]])</span>
<span class="n">eval_label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">11.1</span><span class="p">,</span><span class="mf">26.1</span><span class="p">,</span><span class="mf">16.1</span><span class="p">])</span> <span class="c1">#Adding 0.1 to each of the values</span>
<span class="n">eval_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">eval_data</span><span class="p">,</span> <span class="n">eval_label</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">eval_iter</span><span class="p">,</span> <span class="n">metric</span><span class="p">)</span>
</pre></div>
</div>
<p>We can also create a custom metric and use it to evaluate a model. More
information on metrics can be found in the <a class="reference external" href="http://mxnet.io/api/python/model.html#evaluation-metric-api-reference">API documentation</a>.</p>
<div class="btn-group" role="group">
<div class="download_btn"><a download="linear-regression_python.ipynb" href="linear-regression_python.ipynb"><span class="glyphicon glyphicon-download-alt"></span> linear-regression_python.ipynb</a></div></div></div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Linear Regression</a><ul>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#preparing-the-data">Preparing the Data</a></li>
<li><a class="reference internal" href="#mxnet-classes">MXNet Classes</a></li>
<li><a class="reference internal" href="#defining-the-model">Defining the Model</a></li>
<li><a class="reference internal" href="#training-the-model">Training the model</a></li>
<li><a class="reference internal" href="#using-a-trained-model-testing-and-inference">Using a trained model: (Testing and Inference)</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
$(function () {
  // Runs on DOM ready: make the body visible (it is presumably hidden by theme CSS
  // until scripts have loaded — confirm in mxnet.css).
  $('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>