<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="Logistic regression using Gluon API explained" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="Logistic regression using Gluon API explained" property="og:description"/>
<title>Logistic regression using Gluon API explained — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></meta></meta></meta></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../gluon/index.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">Tutorials</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="../../faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/0.11.0/example">Examples</a></li>
<li><a class="main-nav-link" href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="../../api/python/gluon/model_zoo.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="../../api/python/contrib/onnx.html">ONNX</a></li>
</li></ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet">Github</a></li>
<li><a class="main-nav-link" href="../../community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="../../community/ecosystem.html">Ecosystem</a></li>
<li><a class="main-nav-link" href="../../community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(0.11.0)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/>master</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/1.2.1/index.html>1.2.1</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/1.1.0/index.html>1.1.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="../../install/index.html">Install</a></li>
<li><a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet" tabindex="-1">Github</a></li>
<li><a href="../../community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="../../community/ecosystem.html" tabindex="-1">Ecosystem</a></li>
<li><a href="../../community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/clojure/index.html" tabindex="-1">Clojure</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="../../tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="../../faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="../../architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/0.11.0/example" tabindex="-1">Examples</a></li>
<li><a href="../../api/python/gluon/model_zoo.html" tabindex="-1">Gluon Model Zoo</a></li>
</ul>
</li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(0.11.0)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/>master</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/1.2.1/index.html>1.2.1</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/1.1.0/index.html>1.1.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/contribute.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="logistic-regression-using-gluon-api-explained">
<span id="logistic-regression-using-gluon-api-explained"></span><h1>Logistic regression using Gluon API explained<a class="headerlink" href="#logistic-regression-using-gluon-api-explained" title="Permalink to this headline"></a></h1>
<p>Logistic regression is one of the first models that newcomers to deep learning implement. This tutorial shows how to do logistic regression using the Gluon API.</p>
<p>Before anything else, let’s import the packages required for this tutorial.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">nd</span><span class="p">,</span> <span class="n">autograd</span><span class="p">,</span> <span class="n">gluon</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon</span> <span class="kn">import</span> <span class="n">nn</span><span class="p">,</span> <span class="n">Trainer</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">ArrayDataset</span>
<span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">12345</span><span class="p">)</span> <span class="c1"># Added for reproducibility</span>
</pre></div>
</div>
<p>In this tutorial we will use a fake dataset that contains 10 features drawn from a normal distribution with a mean of 0 and a standard deviation of 1, and a class label, which can be either 0 or 1. The size of the dataset is arbitrary. The function below generates such a dataset. The class label <code class="docutils literal"><span class="pre">y</span></code> is generated by deterministic logic (it is 1 when the sum of the features exceeds 3), so the network has a pattern to look for. The boundary of 3 is selected so that positive examples are less frequent than negative ones, but not too rare.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_random_data</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">ctx</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">3</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span>
</pre></div>
</div>
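<p>To see why a boundary of 3 gives an imbalanced but not extreme dataset, the expected share of positive examples can be estimated analytically. The sketch below uses only the Python standard library and assumes the 10 features are independent, so their sum is normally distributed with variance 10; the helper name <code class="docutils literal"><span class="pre">standard_normal_cdf</span></code> is introduced here for illustration. Under that assumption roughly 17% of the examples fall into class 1.</p>

```python
import math

n_features = 10
boundary = 3.0

# The sum of 10 independent N(0, 1) features is N(0, 10),
# so its standard deviation is sqrt(10).
sum_std = math.sqrt(n_features)


def standard_normal_cdf(z):
    """Cumulative distribution function of the standard normal distribution."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


# Probability that the sum of the features exceeds the boundary.
positive_fraction = 1.0 - standard_normal_cdf(boundary / sum_std)
print(f"expected share of class 1: {positive_fraction:.3f}")
```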
<p>Also, let’s define a set of hyperparameters that we are going to use later. Since our model is simple and the dataset is small, we are going to use the CPU for calculations. Feel free to change it to a GPU for a more advanced scenario.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">ctx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="n">train_data_size</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">val_data_size</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">10</span>
</pre></div>
</div>
<div class="section" id="working-with-data">
<span id="working-with-data"></span><h2>Working with data<a class="headerlink" href="#working-with-data" title="Permalink to this headline"></a></h2>
<p>To work with data, Apache MXNet provides <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.Dataset">Dataset</a> and <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.DataLoader">DataLoader</a> classes. The former provides indexed access to the data; the latter shuffles and batches the data. To learn more about working with data in Gluon, please refer to the <a class="reference external" href="https://mxnet.incubator.apache.org/tutorials/gluon/datasets.html">Gluon Datasets and Dataloaders</a> tutorial.</p>
<p>Below we define training and validation datasets, which we are going to use in the tutorial.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_ground_truth_class</span> <span class="o">=</span> <span class="n">get_random_data</span><span class="p">(</span><span class="n">train_data_size</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<span class="n">train_dataset</span> <span class="o">=</span> <span class="n">ArrayDataset</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_ground_truth_class</span><span class="p">)</span>
<span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">val_x</span><span class="p">,</span> <span class="n">val_ground_truth_class</span> <span class="o">=</span> <span class="n">get_random_data</span><span class="p">(</span><span class="n">val_data_size</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<span class="n">val_dataset</span> <span class="o">=</span> <span class="n">ArrayDataset</span><span class="p">(</span><span class="n">val_x</span><span class="p">,</span> <span class="n">val_ground_truth_class</span><span class="p">)</span>
<span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></div>
</div>
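<p>As a rough mental model of what the two classes do, the following NumPy-only sketch imitates indexed access, shuffling and batching. It is not how Gluon implements them; <code class="docutils literal"><span class="pre">iterate_batches</span></code> is a hypothetical helper written for this illustration. With 1000 examples and a batch size of 10, one pass over the data yields 100 batches.</p>

```python
import numpy as np

rng = np.random.default_rng(12345)

# Stand-in for the training data: 1000 examples with 10 features each.
features = rng.normal(0.0, 1.0, size=(1000, 10))
labels = (features.sum(axis=1) > 3).astype(np.float32)


def iterate_batches(features, labels, batch_size, shuffle, rng):
    """Yield (features, labels) mini-batches, optionally in shuffled order."""
    indices = np.arange(len(features))
    if shuffle:
        rng.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # Indexed access (the Dataset role), then grouping into a batch.
        yield features[batch_idx], labels[batch_idx]


batches = list(iterate_batches(features, labels, batch_size=10, shuffle=True, rng=rng))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```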
</div>
<div class="section" id="defining-and-training-the-model">
<span id="defining-and-training-the-model"></span><h2>Defining and training the model<a class="headerlink" href="#defining-and-training-the-model" title="Permalink to this headline"></a></h2>
<p>The only requirement for logistic regression is that the last layer of the network must be a single neuron. Apache MXNet allows us to do so by using a <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/nn.html#mxnet.gluon.nn.Dense">Dense</a> layer and setting the number of units to 1. The rest of the network can be arbitrarily complex.</p>
<p>Below, we define a model which has an input layer of 10 neurons, a couple of inner layers of 10 neurons each, and an output layer of 1 neuron. We stack the layers using a <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/gluon.html#mxnet.gluon.nn.HybridSequential">HybridSequential</a> block and initialize the parameters of the network using <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/optimization/optimization.html#mxnet.initializer.Xavier">Xavier</a> initialization.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="k">with</span> <span class="n">net</span><span class="o">.</span><span class="n">name_scope</span><span class="p">():</span>
    <span class="n">net</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">))</span> <span class="c1"># input layer</span>
    <span class="n">net</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">))</span> <span class="c1"># inner layer 1</span>
    <span class="n">net</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">))</span> <span class="c1"># inner layer 2</span>
    <span class="n">net</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># output layer: notice, it must have only 1 neuron</span>
<span class="n">net</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Xavier</span><span class="p">())</span>
</pre></div>
</div>
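<p>To make the shapes concrete, here is a NumPy-only sketch of the same stack of layers: three layers of 10 units with ReLU followed by a single linear output unit. It is an illustration, not the Gluon implementation. The uniform range <code class="docutils literal"><span class="pre">sqrt(6</span> <span class="pre">/</span> <span class="pre">(fan_in</span> <span class="pre">+</span> <span class="pre">fan_out))</span></code> is the usual Glorot/Xavier formula and is assumed here; the exact defaults of <code class="docutils literal"><span class="pre">mx.init.Xavier</span></code> may differ. The names <code class="docutils literal"><span class="pre">layer_sizes</span></code> and <code class="docutils literal"><span class="pre">forward</span></code> are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(12345)

# 10 input features, three layers of 10 units, then 1 output unit.
layer_sizes = [10, 10, 10, 10, 1]

# One (weight, bias) pair per dense layer. The initialization range is an
# assumption based on the common Glorot/Xavier uniform formula.
params = []
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    weight = rng.uniform(-limit, limit, size=(fan_in, fan_out))
    bias = np.zeros(fan_out)
    params.append((weight, bias))


def forward(x, params):
    """Apply ReLU after every layer except the last, which stays linear."""
    for weight, bias in params[:-1]:
        x = np.maximum(x @ weight + bias, 0.0)
    weight, bias = params[-1]
    return x @ weight + bias


batch = rng.normal(size=(4, 10))
logits = forward(batch, params)
n_params = sum(w.size + b.size for w, b in params)
print(logits.shape, n_params)
```

<p>The output has one raw score (logit) per example, and the network holds 3 × (10 × 10 + 10) + (10 + 1) = 341 trainable parameters.</p>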
<p>After defining the model, we need to define a few more things: our loss, our trainer and our metric.</p>
<p>A loss function is used to calculate how much the output of the network differs from the ground truth. Because the classes in logistic regression are either 0 or 1, we use <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/loss.html#mxnet.gluon.loss.SigmoidBinaryCrossEntropyLoss">SigmoidBinaryCrossEntropyLoss</a>. Notice that we do not specify the <code class="docutils literal"><span class="pre">from_sigmoid</span></code> argument in the code, so it keeps its default value of <code class="docutils literal"><span class="pre">False</span></code>: the loss applies the sigmoid itself, and the output of the neuron doesn’t need to go through a sigmoid during training. At inference, however, we have to pass the output through a sigmoid ourselves. You can learn more about cross entropy on <a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy">Wikipedia</a>.</p>
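<p>The idea of computing the loss directly from the raw output can be sketched in NumPy. The function below, <code class="docutils literal"><span class="pre">sigmoid_bce_from_logits</span></code>, is a hypothetical helper, not the library code; it uses a common numerically stable rewrite of binary cross entropy and is compared against the textbook formula applied to sigmoid outputs.</p>

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_bce_from_logits(logits, labels):
    """Binary cross entropy computed directly from raw network outputs."""
    # Algebraically equal to -[y*log(s) + (1-y)*log(1-s)] with s = sigmoid(logits),
    # but it never takes the log of a value that has underflowed to 0.
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


logits = np.array([-2.0, 0.5, 3.0])
labels = np.array([0.0, 1.0, 1.0])

stable = sigmoid_bce_from_logits(logits, labels)

# Textbook formula on probabilities, for comparison.
probs = sigmoid(logits)
naive = -(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))
print(stable, naive)
```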
<p>The Trainer object allows us to specify the method of training to be used. For our tutorial we use <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/optimization/optimization.html#mxnet.optimizer.SGD">Stochastic Gradient Descent (SGD)</a>. For more information on SGD refer to <a class="reference external" href="https://gluon.mxnet.io/chapter06_optimization/gd-sgd-scratch.html">the following tutorial</a>. We also need to parametrize it with a learning rate, which scales the weight updates. Weight decay, which is used for regularization, can be passed in the same way, although the code below sets only the learning rate.</p>
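<p>A single SGD update can be written out in a few lines of NumPy. This is a sketch of plain SGD without momentum, assuming that the gradient of the summed batch loss is normalized by the batch size before the update; <code class="docutils literal"><span class="pre">sgd_step</span></code> and its arguments are names chosen for this illustration, not part of the Gluon API.</p>

```python
import numpy as np


def sgd_step(weight, grad_sum, learning_rate, batch_size, weight_decay=0.0):
    """One plain SGD update (no momentum) on a single parameter array."""
    # grad_sum is the gradient of the summed batch loss, so it is
    # divided by batch_size to obtain the mean gradient.
    mean_grad = grad_sum / batch_size
    # Weight decay adds a pull towards zero proportional to the weight itself.
    return weight - learning_rate * (mean_grad + weight_decay * weight)


weight = np.array([1.0, -2.0])
grad_sum = np.array([10.0, -5.0])

updated = sgd_step(weight, grad_sum, learning_rate=0.1, batch_size=10)
updated_with_decay = sgd_step(weight, grad_sum, learning_rate=0.1, batch_size=10,
                              weight_decay=0.01)
print(updated, updated_with_decay)
```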
<p>A metric helps us estimate how good our model is in terms of the problem we are trying to solve. While the loss function matters more for the training process, a metric is usually the value we are trying to improve and maximize. We can also use more than one metric to measure various aspects of our model. In our example, we use <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/model.html#mxnet.metric.Accuracy">Accuracy</a> and <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/model.html#mxnet.metric.F1">F1 score</a> as measurements of the success of our model.</p>
<p>Below we define these objects.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">loss</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">SigmoidBinaryCrossEntropyLoss</span><span class="p">()</span>
<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="n">net</span><span class="o">.</span><span class="n">collect_params</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
                  <span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">})</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span>
<span class="n">f1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">F1</span><span class="p">()</span>
</pre></div>
</div>
<p>The next step is to define the training function, in which we iterate over all batches of training data, execute the forward pass on each batch, and calculate the training loss. At the end of each iteration, we add the loss of the batch to a single variable, because we calculate the loss per batch but want to display it per epoch.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_model</span><span class="p">():</span>
    <span class="n">cumulative_train_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataloader</span><span class="p">):</span>
        <span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">record</span><span class="p">():</span>
            <span class="c1"># Do forward pass on a batch of training data</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
            <span class="c1"># Calculate loss for the training data batch</span>
            <span class="n">loss_result</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span>
        <span class="c1"># Calculate gradients</span>
        <span class="n">loss_result</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="c1"># Update parameters of the network</span>
        <span class="n">trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
        <span class="c1"># sum losses of every batch</span>
        <span class="n">cumulative_train_loss</span> <span class="o">+=</span> <span class="n">nd</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">loss_result</span><span class="p">)</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">cumulative_train_loss</span>
</pre></div>
</div>
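<p>The same loop structure (forward pass, loss, gradients, parameter update, loss accumulation) can be followed end to end in a NumPy-only sketch. To keep it short, it trains plain logistic regression with a single linear layer rather than the multi-layer network above, and it computes the gradient by hand instead of using autograd; all variable names are chosen for this illustration.</p>

```python
import numpy as np

rng = np.random.default_rng(12345)
features = rng.normal(size=(1000, 10))
labels = (features.sum(axis=1) > 3).astype(np.float64)

weights = np.zeros(10)
bias = 0.0
learning_rate, batch_size, epochs = 0.1, 10, 5

epoch_losses = []
for epoch in range(epochs):
    order = rng.permutation(len(features))
    cumulative_loss = 0.0
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        x, y = features[idx], labels[idx]
        # Forward pass: one raw score (logit) per example.
        logits = x @ weights + bias
        # Per-example sigmoid cross entropy, computed from the logits.
        losses = np.maximum(logits, 0) - logits * y + np.log1p(np.exp(-np.abs(logits)))
        # The gradient of each loss w.r.t. its logit is sigmoid(logit) - y.
        error = 1.0 / (1.0 + np.exp(-logits)) - y
        # SGD update with the gradient averaged over the batch.
        weights -= learning_rate * (x.T @ error) / len(idx)
        bias -= learning_rate * error.mean()
        # Sum the losses of every batch to report one value per epoch.
        cumulative_loss += losses.sum()
    epoch_losses.append(cumulative_loss)
    print(f"epoch {epoch}: mean loss {cumulative_loss / len(features):.4f}")
```

<p>Dividing the cumulative loss by the number of training examples, as in the last line, gives an average loss per example that is comparable across dataset sizes.</p>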
</div>
<div class="section" id="validating-the-model">
<span id="validating-the-model"></span><h2>Validating the model<a class="headerlink" href="#validating-the-model" title="Permalink to this headline"></a></h2>
<p>Our validation function is very similar to the training one. The main difference is that we want to calculate the accuracy of the model. We use the <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/model.html#mxnet.metric.Accuracy">Accuracy metric</a> to do so.</p>
<p>The <code class="docutils literal"><span class="pre">Accuracy</span></code> metric requires 2 arguments: 1) a vector of ground-truth classes and 2) a vector or matrix of predictions. When the predictions are of the same shape as the vector of ground-truth classes, the <code class="docutils literal"><span class="pre">Accuracy</span></code> class assumes that the prediction vector contains predicted classes. So, it converts the vector to <code class="docutils literal"><span class="pre">Int32</span></code> and compares each item of the ground-truth classes to the prediction vector.</p>
<p>Because of this behaviour, you will get an unexpected result if you just apply the <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.sigmoid">Sigmoid</a> function to the network output and pass the result to the <code class="docutils literal"><span class="pre">Accuracy</span></code> metric. As mentioned before, we need to apply the <code class="docutils literal"><span class="pre">Sigmoid</span></code> function to the output of the neuron to get the probability of belonging to class 1. But the <code class="docutils literal"><span class="pre">Sigmoid</span></code> function produces output in the range [0, 1], and the cast to <code class="docutils literal"><span class="pre">Int32</span></code> truncates every number below 1 to 0, even one as high as 0.99. To avoid this, the validation function contains a custom bit of code that:</p>
<ol class="simple">
<li>Calculates the probability using the <code class="docutils literal"><span class="pre">Sigmoid</span></code> function</li>
<li>Subtracts a threshold from the sigmoid output. The threshold is usually 0.5, but it can be higher if you want to be more certain before assigning an item to class 1.</li>
<li>Uses the <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.ceil">mx.nd.ceil</a> function, which rounds each difference up to the nearest integer, so that zero and negative differences become 0 and positive differences become 1</li>
</ol>
<p>After these transformations, we can pass the result to the <code class="docutils literal"><span class="pre">Accuracy.update()</span></code> method and expect it to behave correctly.</p>
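<p>The three steps above can be sketched in plain Python. This is only an illustration under stated assumptions: <code class="docutils literal"><span class="pre">math.ceil</span></code> stands in for <code class="docutils literal"><span class="pre">mx.nd.ceil</span></code>, the helper names <code class="docutils literal"><span class="pre">sigmoid</span></code> and <code class="docutils literal"><span class="pre">to_class</span></code> are hypothetical, and the neuron outputs are made up.</p>

```python
import math


def sigmoid(z):
    """Logistic function: maps a raw network output to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def to_class(probability, threshold=0.5):
    """Turn a probability into class 0 or 1 via ceil(probability - threshold)."""
    return int(math.ceil(probability - threshold))


raw_outputs = [-3.0, -0.2, 0.4, 4.6]  # made-up neuron outputs
probabilities = [sigmoid(z) for z in raw_outputs]

# Casting the probabilities straight to integers truncates all of them to 0
truncated = [int(p) for p in probabilities]

# Subtracting the threshold first and rounding up gives the intended classes
classes = [to_class(p) for p in probabilities]

print(truncated)
print(classes)
```

<p>Note that in this sketch a probability exactly equal to the threshold ends up in class 0, because the difference is 0 and rounding up leaves it at 0.</p>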
<p>For the <code class="docutils literal"><span class="pre">F1</span></code> metric to work, we must pass the probabilities of belonging to both classes instead of a single probability per example. Because of that, just before calling <code class="docutils literal"><span class="pre">f1.update()</span></code> we:</p>
<ol class="simple">
<li>Reshape the predictions into a single vector</li>
<li>Stack two vectors together: the probabilities of belonging to class 0 (1 - <code class="docutils literal"><span class="pre">prediction</span></code>) and the probabilities of belonging to class 1.</li>
</ol>
<p>Then we pass this stacked matrix to the <code class="docutils literal"><span class="pre">F1</span></code> metric.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">validate_model</span><span class="p">(</span><span class="n">threshold</span><span class="p">):</span>
    <span class="n">cumulative_val_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">val_data</span><span class="p">,</span> <span class="n">val_ground_truth_class</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">val_dataloader</span><span class="p">):</span>
        <span class="c1"># Do forward pass on a batch of validation data</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">val_data</span><span class="p">)</span>
        <span class="c1"># Similar to cumulative training loss, calculate cumulative validation loss</span>
        <span class="n">cumulative_val_loss</span> <span class="o">+=</span> <span class="n">nd</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">val_ground_truth_class</span><span class="p">))</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
        <span class="c1"># Apply sigmoid to the existing output to get the probability of class 1</span>
        <span class="n">prediction</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">()</span>
        <span class="c1"># Convert probabilities to classes</span>
        <span class="n">predicted_classes</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">prediction</span> <span class="o">-</span> <span class="n">threshold</span><span class="p">)</span>
        <span class="c1"># Update validation accuracy</span>
        <span class="n">accuracy</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">val_ground_truth_class</span><span class="p">,</span> <span class="n">predicted_classes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
        <span class="c1"># Stack the probabilities of both classes; the F1 metric works only with this layout</span>
        <span class="n">prediction</span> <span class="o">=</span> <span class="n">prediction</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">probabilities</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">prediction</span><span class="p">,</span> <span class="n">prediction</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">f1</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">val_ground_truth_class</span><span class="p">,</span> <span class="n">probabilities</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">cumulative_val_loss</span>
</pre></div>
</div>
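<p>The two-column layout built for <code class="docutils literal"><span class="pre">F1</span></code> in the function above can be illustrated without MXNet. The sketch below is a plain-Python approximation, not the library's implementation: it assumes the metric picks the class with the highest probability in each row, and the helper names, labels and probabilities are all made up for the example.</p>

```python
def stack_class_probabilities(positive_probabilities):
    """Build rows of [P(class 0), P(class 1)] from P(class 1) values."""
    return [[1.0 - p, p] for p in positive_probabilities]


def predicted_class(row):
    """Index of the largest probability in a row, i.e. the predicted class."""
    return row.index(max(row))


def f1_score(labels, predictions):
    """F1 for class 1: harmonic mean of precision and recall."""
    pairs = list(zip(labels, predictions))
    true_pos = sum(1 for y, p in pairs if y == 1 and p == 1)
    false_pos = sum(1 for y, p in pairs if y == 0 and p == 1)
    false_neg = sum(1 for y, p in pairs if y == 1 and p == 0)
    if true_pos == 0:
        return 0.0
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return 2 * precision * recall / (precision + recall)


labels = [0, 1, 1, 0, 1]                   # made-up ground truth
positive = [0.10, 0.80, 0.30, 0.60, 0.90]  # made-up sigmoid outputs

rows = stack_class_probabilities(positive)
predictions = [predicted_class(row) for row in rows]

print(predictions)
print(round(f1_score(labels, predictions), 2))
```

<p>Each row of <code class="docutils literal"><span class="pre">rows</span></code> has one column per class, which is the same shape that <code class="docutils literal"><span class="pre">mx.nd.stack(1</span> <span class="pre">-</span> <span class="pre">prediction,</span> <span class="pre">prediction,</span> <span class="pre">axis=1)</span></code> produces for a batch.</p>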
</div>
<div class="section" id="putting-it-all-together">
<span id="putting-it-all-together"></span><h2>Putting it all together<a class="headerlink" href="#putting-it-all-together" title="Permalink to this headline"></a></h2>
<p>Using the functions defined above, we can finally write our main training loop.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">threshold</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="n">avg_train_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">()</span> <span class="o">/</span> <span class="n">train_data_size</span>
    <span class="n">avg_val_loss</span> <span class="o">=</span> <span class="n">validate_model</span><span class="p">(</span><span class="n">threshold</span><span class="p">)</span> <span class="o">/</span> <span class="n">val_data_size</span>
    <span class="k">print</span><span class="p">(</span><span class="s2">"Epoch: </span><span class="si">%s</span><span class="s2">, Training loss: </span><span class="si">%.2f</span><span class="s2">, Validation loss: </span><span class="si">%.2f</span><span class="s2">, Validation accuracy: </span><span class="si">%.2f</span><span class="s2">, F1 score: </span><span class="si">%.2f</span><span class="s2">"</span> <span class="o">%</span>
          <span class="p">(</span><span class="n">e</span><span class="p">,</span> <span class="n">avg_train_loss</span><span class="p">,</span> <span class="n">avg_val_loss</span><span class="p">,</span> <span class="n">accuracy</span><span class="o">.</span><span class="n">get</span><span class="p">()[</span><span class="mi">1</span><span class="p">],</span> <span class="n">f1</span><span class="o">.</span><span class="n">get</span><span class="p">()[</span><span class="mi">1</span><span class="p">]))</span>
    <span class="c1"># Reset accuracy so that the next epoch's accuracy is calculated from a blank state</span>
    <span class="c1"># Note: f1 is not reset here, so the printed F1 score covers all epochs so far</span>
    <span class="n">accuracy</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span>Epoch: 0, Training loss: 0.43, Validation loss: 0.36, Validation accuracy: 0.85, F1 score: 0.00 <!--notebook-skip-line-->
Epoch: 1, Training loss: 0.22, Validation loss: 0.14, Validation accuracy: 0.96, F1 score: 0.35 <!--notebook-skip-line-->
Epoch: 2, Training loss: 0.09, Validation loss: 0.11, Validation accuracy: 0.97, F1 score: 0.48 <!--notebook-skip-line-->
Epoch: 3, Training loss: 0.07, Validation loss: 0.09, Validation accuracy: 0.96, F1 score: 0.53 <!--notebook-skip-line-->
Epoch: 4, Training loss: 0.06, Validation loss: 0.09, Validation accuracy: 0.97, F1 score: 0.58 <!--notebook-skip-line-->
Epoch: 5, Training loss: 0.04, Validation loss: 0.12, Validation accuracy: 0.97, F1 score: 0.59 <!--notebook-skip-line-->
Epoch: 6, Training loss: 0.05, Validation loss: 0.09, Validation accuracy: 0.99, F1 score: 0.62 <!--notebook-skip-line-->
Epoch: 7, Training loss: 0.05, Validation loss: 0.10, Validation accuracy: 0.97, F1 score: 0.62 <!--notebook-skip-line-->
Epoch: 8, Training loss: 0.05, Validation loss: 0.12, Validation accuracy: 0.95, F1 score: 0.63 <!--notebook-skip-line-->
Epoch: 9, Training loss: 0.04, Validation loss: 0.09, Validation accuracy: 0.98, F1 score: 0.65 <!--notebook-skip-line-->
</pre></div>
</div>
<p>In our case, we reached an accuracy of 0.98 and an F1 score of 0.65.</p>
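<p>The loop above resets <code class="docutils literal"><span class="pre">accuracy</span></code> after every epoch. To show why that matters for an accumulating metric, here is a minimal sketch with a hypothetical <code class="docutils literal"><span class="pre">RunningAccuracy</span></code> class. It only mimics the update/get/reset pattern and is not the MXNet implementation; the labels and predictions are made up.</p>

```python
class RunningAccuracy:
    """Hypothetical accumulating metric that mimics update/get/reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, predictions):
        self.correct += sum(1 for y, p in zip(labels, predictions) if y == p)
        self.total += len(labels)

    def get(self):
        return self.correct / self.total


metric = RunningAccuracy()

# "Epoch 0": half of the predictions are right
metric.update([0, 1, 1, 0], [0, 1, 0, 1])
epoch0 = metric.get()

# Without reset(), the value for "epoch 1" is blended with epoch 0
metric.update([0, 1, 1, 0], [0, 1, 1, 0])
blended = metric.get()

# With reset(), the value reflects epoch 1 only
metric.reset()
metric.update([0, 1, 1, 0], [0, 1, 1, 0])
epoch1 = metric.get()

print(epoch0, blended, epoch1)
```

<p>Any metric that is not reset between epochs keeps counting the batches from earlier epochs, so its value lags behind the current state of the model.</p>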
</div>
<div class="section" id="tip-1-use-only-one-neuron-in-the-output-layer">
<span id="tip-1-use-only-one-neuron-in-the-output-layer"></span><h2>Tip 1: Use only one neuron in the output layer<a class="headerlink" href="#tip-1-use-only-one-neuron-in-the-output-layer" title="Permalink to this headline"></a></h2>
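<p>Although there are two classes, the output layer should have only one neuron, because <code class="docutils literal"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code> accepts only one value per example as input: the sigmoid of that value is the probability of class 1, and one minus it is the probability of class 0.</p>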
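<p>To see why a single output is enough, the sketch below computes binary cross-entropy from one raw output in plain Python. It uses the common numerically stable formulation and is meant as an illustration only, not as the code of <code class="docutils literal"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code>; the function names and the sample output value are assumptions.</p>

```python
import math


def sigmoid(z):
    """Logistic function: probability of class 1 for raw output z."""
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_bce(logit, label):
    """Binary cross-entropy from one raw output (logit) and a 0/1 label.

    Uses the numerically stable form max(z, 0) - z * y + log(1 + exp(-|z|)).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logit = 2.0  # made-up output of the single neuron
p_class1 = sigmoid(logit)
p_class0 = 1.0 - p_class1

print(round(p_class1, 3), round(p_class0, 3))
print(round(sigmoid_bce(logit, 1), 3))  # loss if the true class is 1
print(round(sigmoid_bce(logit, 0), 3))  # loss if the true class is 0
```

<p>Both class probabilities, and therefore the loss for either label, follow from the single value <code class="docutils literal"><span class="pre">logit</span></code>.</p>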
</div>
<div class="section" id="tip-2-encode-classes-as-0-and-1">
<span id="tip-2-encode-classes-as-0-and-1"></span><h2>Tip 2: Encode classes as 0 and 1<a class="headerlink" href="#tip-2-encode-classes-as-0-and-1" title="Permalink to this headline"></a></h2>
<p>For <code class="docutils literal"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code> to work, the classes must be encoded as 0 and 1. In some datasets the class encoding might be different, such as -1 and 1 or 1 and 2. If your dataset looks like this, you need to re-encode the data before using <code class="docutils literal"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code>.</p>
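<p>A re-encoding step could look like the following sketch. The helper <code class="docutils literal"><span class="pre">encode_binary_labels</span></code> and its <code class="docutils literal"><span class="pre">positive_label</span></code> parameter are hypothetical names chosen for this example, and the function is assumed to be applied to the label column of the whole dataset.</p>

```python
def encode_binary_labels(labels, positive_label):
    """Map an arbitrary two-class encoding to 0/1.

    `positive_label` is the original value that should become class 1;
    the other value becomes class 0.
    """
    distinct = set(labels)
    if len(distinct) != 2 or positive_label not in distinct:
        raise ValueError("expected exactly two classes, one of them positive_label")
    return [1 if label == positive_label else 0 for label in labels]


print(encode_binary_labels([-1, 1, 1, -1], positive_label=1))
print(encode_binary_labels([1, 2, 2, 1], positive_label=2))
```

<p>Both calls produce the same 0/1 encoding, which can then be passed to the loss as the ground-truth classes.</p>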
</div>
<div class="section" id="tip-3-use-sigmoidbinarycrossentropyloss-instead-of-logisticregressionoutput">
<span id="tip-3-use-sigmoidbinarycrossentropyloss-instead-of-logisticregressionoutput"></span><h2>Tip 3: Use SigmoidBinaryCrossEntropyLoss instead of LogisticRegressionOutput<a class="headerlink" href="#tip-3-use-sigmoidbinarycrossentropyloss-instead-of-logisticregressionoutput" title="Permalink to this headline"></a></h2>
<p>MXNet has two options to calculate logistic regression loss: <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/gluon/loss.html#mxnet.gluon.loss.SigmoidBinaryCrossEntropyLoss">SigmoidBinaryCrossEntropyLoss</a> and <a class="reference external" href="https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.LogisticRegressionOutput">LogisticRegressionOutput</a>. <code class="docutils literal"><span class="pre">LogisticRegressionOutput</span></code> is designed to be an output layer for the Module API and is not meant to be used with the Gluon API.</p>
</div>
<div class="section" id="conclusion">
<span id="conclusion"></span><h2>Conclusion<a class="headerlink" href="#conclusion" title="Permalink to this headline"></a></h2>
<p>In this tutorial I explained some potential pitfalls to be aware of. When doing logistic regression using the Gluon API, remember to:</p>
<ol class="simple">
<li>Use only one neuron in the output layer</li>
<li>Encode class labels as 0 or 1</li>
<li>Use <code class="docutils literal"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code></li>
<li>Convert probabilities to classes before calculating Accuracy</li>
</ol>
<div class="btn-group" role="group">
<div class="download-btn"><a download="logistic_regression_explained.ipynb" href="logistic_regression_explained.ipynb"><span class="glyphicon glyphicon-download-alt"></span> logistic_regression_explained.ipynb</a></div></div></div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Logistic regression using Gluon API explained</a><ul>
<li><a class="reference internal" href="#working-with-data">Working with data</a></li>
<li><a class="reference internal" href="#defining-and-training-the-model">Defining and training the model</a></li>
<li><a class="reference internal" href="#validating-the-model">Validating the model</a></li>
<li><a class="reference internal" href="#putting-it-all-together">Putting it all together</a></li>
<li><a class="reference internal" href="#tip-1-use-only-one-neuron-in-the-output-layer">Tip 1: Use only one neuron in the output layer</a></li>
<li><a class="reference internal" href="#tip-2-encode-classes-as-0-and-1">Tip 2: Encode classes as 0 and 1</a></li>
<li><a class="reference internal" href="#tip-3-use-sigmoidbinarycrossentropyloss-instead-of-logisticregressionoutput">Tip 3: Use SigmoidBinaryCrossEntropyLoss instead of LogisticRegressionOutput</a></li>
<li><a class="reference internal" href="#conclusion">Conclusion</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>