blob: 12d91296890359a9dc599d66b4010b555866634e [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Text Classification Using a Convolutional Neural Network on MXNet — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="text-classification-using-a-convolutional-neural-network-on-mxnet">
<h1>Text Classification Using a Convolutional Neural Network on MXNet<a class="headerlink" href="#text-classification-using-a-convolutional-neural-network-on-mxnet" title="Permalink to this headline"></a></h1>
<p>This tutorial is based on Yoon Kim’s <a class="reference external" href="https://arxiv.org/abs/1408.5882">paper</a> on using convolutional neural networks for sentence sentiment classification.</p>
<p>For this tutorial, we will train a convolutional deep network model on movie review sentences from Rotten Tomatoes, each labeled with its sentiment. The result will be a model that can classify a sentence by its sentiment (with 1 being a purely positive sentiment, 0 being a purely negative sentiment, and 0.5 being neutral).</p>
<p>Our first step will be to fetch the labeled training data of positive and negative sentiment sentences, process it into vectors of word indices, and randomly split those vectors into train and test (dev) sets.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">urllib2</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">re</span>
<span class="kn">import</span> <span class="nn">itertools</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">Counter</span>
<span class="k">def</span> <span class="nf">clean_str</span><span class="p">(</span><span class="n">string</span><span class="p">):</span>
<span class="sd">"""</span>
<span class="sd"> Tokenization/string cleaning for all datasets except for SST.</span>
<span class="sd"> Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py</span>
<span class="sd"> """</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"[^A-Za-z0-9(),!?\'\`]"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\'s"</span><span class="p">,</span> <span class="s2">" </span><span class="se">\'</span><span class="s2">s"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\'ve"</span><span class="p">,</span> <span class="s2">" </span><span class="se">\'</span><span class="s2">ve"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"n\'t"</span><span class="p">,</span> <span class="s2">" n</span><span class="se">\'</span><span class="s2">t"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\'re"</span><span class="p">,</span> <span class="s2">" </span><span class="se">\'</span><span class="s2">re"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\'d"</span><span class="p">,</span> <span class="s2">" </span><span class="se">\'</span><span class="s2">d"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\'ll"</span><span class="p">,</span> <span class="s2">" </span><span class="se">\'</span><span class="s2">ll"</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">","</span><span class="p">,</span> <span class="s2">" , "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"!"</span><span class="p">,</span> <span class="s2">" ! "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\("</span><span class="p">,</span> <span class="s2">" \( "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\)"</span><span class="p">,</span> <span class="s2">" \) "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\?"</span><span class="p">,</span> <span class="s2">" \? "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="n">string</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\s{2,}"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="n">string</span><span class="p">)</span>
<span class="k">return</span> <span class="n">string</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">load_data_and_labels</span><span class="p">():</span>
<span class="sd">"""</span>
<span class="sd"> Loads MR polarity data from files, splits the data into words and generates labels.</span>
<span class="sd"> Returns split sentences and labels.</span>
<span class="sd"> """</span>
<span class="c1"># Pull sentences with positive sentiment</span>
<span class="n">pos_file</span> <span class="o">=</span> <span class="n">urllib2</span><span class="o">.</span><span class="n">urlopen</span><span class="p">(</span><span class="s1">'https://raw.githubusercontent.com/yoonkim/CNN_sentence/master/rt-polarity.pos'</span><span class="p">)</span>
<span class="c1"># Pull sentences with negative sentiment</span>
<span class="n">neg_file</span> <span class="o">=</span> <span class="n">urllib2</span><span class="o">.</span><span class="n">urlopen</span><span class="p">(</span><span class="s1">'https://raw.githubusercontent.com/yoonkim/CNN_sentence/master/rt-polarity.neg'</span><span class="p">)</span>
<span class="c1"># Load data from files</span>
<span class="n">positive_examples</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">pos_file</span><span class="o">.</span><span class="n">readlines</span><span class="p">())</span>
<span class="n">positive_examples</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">positive_examples</span><span class="p">]</span>
<span class="n">negative_examples</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">neg_file</span><span class="o">.</span><span class="n">readlines</span><span class="p">())</span>
<span class="n">negative_examples</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">negative_examples</span><span class="p">]</span>
<span class="c1"># Split by words</span>
<span class="n">x_text</span> <span class="o">=</span> <span class="n">positive_examples</span> <span class="o">+</span> <span class="n">negative_examples</span>
<span class="n">x_text</span> <span class="o">=</span> <span class="p">[</span><span class="n">clean_str</span><span class="p">(</span><span class="n">sent</span><span class="p">)</span> <span class="k">for</span> <span class="n">sent</span> <span class="ow">in</span> <span class="n">x_text</span><span class="p">]</span>
<span class="n">x_text</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">" "</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">x_text</span><span class="p">]</span>
<span class="c1"># Generate labels</span>
<span class="n">positive_labels</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">positive_examples</span><span class="p">]</span>
<span class="n">negative_labels</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">negative_examples</span><span class="p">]</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">positive_labels</span><span class="p">,</span> <span class="n">negative_labels</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">return</span> <span class="p">[</span><span class="n">x_text</span><span class="p">,</span> <span class="n">y</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">pad_sentences</span><span class="p">(</span><span class="n">sentences</span><span class="p">,</span> <span class="n">padding_word</span><span class="o">=</span><span class="s2">"&lt;/s&gt;"</span><span class="p">):</span>
<span class="sd">"""</span>
<span class="sd"> Pads all sentences to the same length. The length is defined by the longest sentence.</span>
<span class="sd"> Returns padded sentences.</span>
<span class="sd"> """</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">sentences</span><span class="p">)</span>
<span class="n">padded_sentences</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sentences</span><span class="p">)):</span>
<span class="n">sentence</span> <span class="o">=</span> <span class="n">sentences</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">num_padding</span> <span class="o">=</span> <span class="n">sequence_length</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">sentence</span><span class="p">)</span>
<span class="n">new_sentence</span> <span class="o">=</span> <span class="n">sentence</span> <span class="o">+</span> <span class="p">[</span><span class="n">padding_word</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_padding</span>
<span class="n">padded_sentences</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">new_sentence</span><span class="p">)</span>
<span class="k">return</span> <span class="n">padded_sentences</span>
<span class="k">def</span> <span class="nf">build_vocab</span><span class="p">(</span><span class="n">sentences</span><span class="p">):</span>
<span class="sd">"""</span>
<span class="sd"> Builds a vocabulary mapping from word to index based on the sentences.</span>
<span class="sd"> Returns vocabulary mapping and inverse vocabulary mapping.</span>
<span class="sd"> """</span>
<span class="c1"># Build vocabulary</span>
<span class="n">word_counts</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span><span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span><span class="o">*</span><span class="n">sentences</span><span class="p">))</span>
<span class="c1"># Mapping from index to word</span>
<span class="n">vocabulary_inv</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">word_counts</span><span class="o">.</span><span class="n">most_common</span><span class="p">()]</span>
<span class="c1"># Mapping from word to index</span>
<span class="n">vocabulary</span> <span class="o">=</span> <span class="p">{</span><span class="n">x</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">vocabulary_inv</span><span class="p">)}</span>
<span class="k">return</span> <span class="p">[</span><span class="n">vocabulary</span><span class="p">,</span> <span class="n">vocabulary_inv</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">build_input_data</span><span class="p">(</span><span class="n">sentences</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">vocabulary</span><span class="p">):</span>
<span class="sd">"""</span>
<span class="sd"> Maps sentences and labels to vectors based on a vocabulary.</span>
<span class="sd"> """</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">vocabulary</span><span class="p">[</span><span class="n">word</span><span class="p">]</span> <span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">sentence</span><span class="p">]</span> <span class="k">for</span> <span class="n">sentence</span> <span class="ow">in</span> <span class="n">sentences</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
<span class="k">return</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">]</span>
<span class="sd">"""</span>
<span class="sd">Loads and preprocessed data for the MR dataset.</span>
<span class="sd">Returns input vectors, labels, vocabulary, and inverse vocabulary.</span>
<span class="sd">"""</span>
<span class="c1"># Load and preprocess data</span>
<span class="n">sentences</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">load_data_and_labels</span><span class="p">()</span>
<span class="n">sentences_padded</span> <span class="o">=</span> <span class="n">pad_sentences</span><span class="p">(</span><span class="n">sentences</span><span class="p">)</span>
<span class="n">vocabulary</span><span class="p">,</span> <span class="n">vocabulary_inv</span> <span class="o">=</span> <span class="n">build_vocab</span><span class="p">(</span><span class="n">sentences_padded</span><span class="p">)</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">build_input_data</span><span class="p">(</span><span class="n">sentences_padded</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">vocabulary</span><span class="p">)</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">vocabulary</span><span class="p">)</span>
<span class="c1"># randomly shuffle data</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span>
<span class="n">shuffle_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">)))</span>
<span class="n">x_shuffled</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">shuffle_indices</span><span class="p">]</span>
<span class="n">y_shuffled</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="n">shuffle_indices</span><span class="p">]</span>
<span class="c1"># split train/dev set</span>
<span class="c1"># there are a total of 10662 labeled examples to train on</span>
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_dev</span> <span class="o">=</span> <span class="n">x_shuffled</span><span class="p">[:</span><span class="o">-</span><span class="mi">1000</span><span class="p">],</span> <span class="n">x_shuffled</span><span class="p">[</span><span class="o">-</span><span class="mi">1000</span><span class="p">:]</span>
<span class="n">y_train</span><span class="p">,</span> <span class="n">y_dev</span> <span class="o">=</span> <span class="n">y_shuffled</span><span class="p">[:</span><span class="o">-</span><span class="mi">1000</span><span class="p">],</span> <span class="n">y_shuffled</span><span class="p">[</span><span class="o">-</span><span class="mi">1000</span><span class="p">:]</span>
<span class="n">sentence_size</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span> <span class="s1">'Train/Dev split: </span><span class="si">%d</span><span class="s1">/</span><span class="si">%d</span><span class="s1">'</span> <span class="o">%</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y_train</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_dev</span><span class="p">))</span>
<span class="k">print</span> <span class="s1">'train shape:'</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span>
<span class="k">print</span> <span class="s1">'dev shape:'</span><span class="p">,</span> <span class="n">x_dev</span><span class="o">.</span><span class="n">shape</span>
<span class="k">print</span> <span class="s1">'vocab_size'</span><span class="p">,</span> <span class="n">vocab_size</span>
<span class="k">print</span> <span class="s1">'sentence max words'</span><span class="p">,</span> <span class="n">sentence_size</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span>Train/Dev split: 9662/1000
train shape: (9662, 56)
dev shape: (1000, 56)
vocab_size 18766
sentence max words 56
</pre></div>
</div>
<p>Now that we have prepared the training and test data by loading, vectorizing, and shuffling it, we can move on to defining the network architecture we want to train on that data.</p>
<p>We will first set up placeholders for the input and output of the network, and then define the first layer: an embedding layer, which learns to map each word (given as an index into the vocabulary) to a dense, lower-dimensional vector, so that distances between words correspond to how related they are (with respect to the sentiment they convey).</p>
<!-- Listing: imports mxnet, sets batch_size = 50, declares the 'data' and 'softmax_label' input variables, adds an Embedding layer (input_dim=vocab_size, output_dim=num_embed=300) and reshapes its output to (batch_size, 1, sentence_size, num_embed) as input for the convolution listing that follows. vocab_size and sentence_size are defined in the earlier data-preparation code. NOTE(review): the listings use Python 2 print-statement syntax. -->
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">sys</span><span class="o">,</span><span class="nn">os</span>
<span class="sd">'''</span>
<span class="sd">Define batch size and the place holders for network inputs and outputs</span>
<span class="sd">'''</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">50</span> <span class="c1"># the size of batches to train network with</span>
<span class="k">print</span> <span class="s1">'batch size'</span><span class="p">,</span> <span class="n">batch_size</span>
<span class="n">input_x</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span> <span class="c1"># placeholder for input data</span>
<span class="n">input_y</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'softmax_label'</span><span class="p">)</span> <span class="c1"># placeholder for output label</span>
<span class="sd">'''</span>
<span class="sd">Define the first network layer (embedding)</span>
<span class="sd">'''</span>
<span class="c1"># create embedding layer to learn representation of words in a lower dimensional subspace (much like word2vec)</span>
<span class="n">num_embed</span> <span class="o">=</span> <span class="mi">300</span> <span class="c1"># dimensions to embed words into</span>
<span class="k">print</span> <span class="s1">'embedding dimensions'</span><span class="p">,</span> <span class="n">num_embed</span>
<span class="n">embed_layer</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">input_x</span><span class="p">,</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">num_embed</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'vocab_embed'</span><span class="p">)</span>
<span class="c1"># reshape embedded data for next layer</span>
<span class="n">conv_input</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">embed_layer</span><span class="p">,</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">sentence_size</span><span class="p">,</span> <span class="n">num_embed</span><span class="p">))</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span>batch size 50
embedding dimensions 300
</pre></div>
</div>
<p>The next layer in the network performs convolutions over the ordered embedded word vectors in a sentence using multiple filter sizes, sliding over 3, 4 or 5 words at a time. This is the equivalent of looking at all 3-grams, 4-grams and 5-grams in a sentence and will allow us to understand how words contribute to sentiment in the context of those around them.</p>
<p>After each convolution, we add a max-pooling layer that takes the largest activation of each filter across all positions in the sentence, turning the convolution output into a feature vector.</p>
<p>Because each filter size produces a convolution output of a different shape, we create a separate convolution + pooling layer for each filter size, and then concatenate the pooled results into one big feature vector.</p>
<!-- Listing: for each filter size, a Convolution (kernel filter_size x num_embed, num_filter filters), a ReLU Activation and a max Pooling whose kernel spans all sentence_size - filter_size + 1 convolution positions. The four indented statements form the body of the for loop. The pooled outputs are concatenated along dim 1 and reshaped to (batch_size, num_filter * len(filter_list)). -->
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># create convolution + (max) pooling layer for each filter operation</span>
<span class="n">filter_list</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">]</span> <span class="c1"># the size of filters to use</span>
<span class="k">print</span> <span class="s1">'convolution filters'</span><span class="p">,</span> <span class="n">filter_list</span>
<span class="n">num_filter</span><span class="o">=</span><span class="mi">100</span>
<span class="n">pooled_outputs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">filter_size</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">filter_list</span><span class="p">):</span>
    <span class="n">convi</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">conv_input</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="n">filter_size</span><span class="p">,</span> <span class="n">num_embed</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="n">num_filter</span><span class="p">)</span>
    <span class="n">relui</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">convi</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span>
    <span class="n">pooli</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Pooling</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">relui</span><span class="p">,</span> <span class="n">pool_type</span><span class="o">=</span><span class="s1">'max'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="n">sentence_size</span> <span class="o">-</span> <span class="n">filter_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">pooled_outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pooli</span><span class="p">)</span>
<span class="c1"># combine all pooled outputs</span>
<span class="n">total_filters</span> <span class="o">=</span> <span class="n">num_filter</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">filter_list</span><span class="p">)</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Concat</span><span class="p">(</span><span class="o">*</span><span class="n">pooled_outputs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># reshape for next layer</span>
<span class="n">h_pool</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">concat</span><span class="p">,</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">total_filters</span><span class="p">))</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span>convolution filters [3, 4, 5]
</pre></div>
</div>
<p>Next, we add dropout regularization, which randomly disables a fraction of the neurons in the layer (50% here) to keep the model from overfitting. It works by preventing neurons from co-adapting, forcing them to learn individually useful features.</p>
<p>This matters for our model because the dataset is small: its vocabulary has around 20k words, but there are only around 10k examples, so a powerful model like this neural network is likely to overfit.</p>
<!-- Listing: when dropout is greater than 0.0, h_drop is a Dropout layer (p=dropout, here 0.5) applied to h_pool; otherwise h_drop is h_pool itself. The indented lines are the if and else bodies. -->
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># dropout layer</span>
<span class="n">dropout</span><span class="o">=</span><span class="mf">0.5</span>
<span class="k">print</span> <span class="s1">'dropout probability'</span><span class="p">,</span> <span class="n">dropout</span>
<span class="k">if</span> <span class="n">dropout</span> <span class="o">></span> <span class="mf">0.0</span><span class="p">:</span>
    <span class="n">h_drop</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">h_pool</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
    <span class="n">h_drop</span> <span class="o">=</span> <span class="n">h_pool</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span>dropout probability 0.5
</pre></div>
</div>
<p>Finally, we add a fully connected layer that maps the pooled features (after dropout) to one score per class; there are two classes here. We then apply a softmax function to these scores, yielding the probability of each class: 0 (negative sentiment) and 1 (positive).</p>
<!-- Listing: a FullyConnected layer on h_drop with num_hidden=num_label=2 and explicitly named weight/bias variables (cls_weight, cls_bias), followed by a SoftmaxOutput named 'softmax' whose label input is the 'softmax_label' variable. cnn is set to this final output symbol. -->
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># fully connected layer</span>
<span class="n">num_label</span><span class="o">=</span><span class="mi">2</span>
<span class="n">cls_weight</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'cls_weight'</span><span class="p">)</span>
<span class="n">cls_bias</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'cls_bias'</span><span class="p">)</span>
<span class="n">fc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">h_drop</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="n">cls_weight</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">cls_bias</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="n">num_label</span><span class="p">)</span>
<span class="c1"># softmax output</span>
<span class="n">sm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">input_y</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span>
<span class="c1"># set CNN pointer to the "back" of the network</span>
<span class="n">cnn</span> <span class="o">=</span> <span class="n">sm</span>
</pre></div>
</div>
<p>Now that we have defined our CNN model, we choose the device on our machine to train and run it on, allocate and initialize the model's parameters, and bind the model to its input data and labels.</p>
<p><em>If you run this code with ctx set to mx.gpu(0), make sure your machine has a GPU. Otherwise, set ctx to mx.cpu(0); note that training on the CPU is much slower.</em></p>
<!-- Listing: infers argument shapes from a (batch_size, sentence_size) data shape, allocates zero-filled argument arrays and gradient arrays on ctx (no gradient arrays for 'data' and 'softmax_label'), binds the symbol with grad_req='add', initializes every parameter with mx.initializer.Uniform(0.1) and collects (index, weight, grad, name) tuples in param_blocks. Because grad_req='add' accumulates gradients, the training loop resets each gradient to zero after its update. The indented lines are the bodies of the two for loops. -->
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">namedtuple</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="c1"># Define the structure of our CNN Model (as a named tuple)</span>
<span class="n">CNNModel</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span><span class="s2">"CNNModel"</span><span class="p">,</span> <span class="p">[</span><span class="s1">'cnn_exec'</span><span class="p">,</span> <span class="s1">'symbol'</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">,</span> <span class="s1">'label'</span><span class="p">,</span> <span class="s1">'param_blocks'</span><span class="p">])</span>
<span class="c1"># Define what device to train/test on</span>
<span class="n">ctx</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># If you have no GPU on your machine change this to</span>
<span class="c1"># ctx=mx.cpu(0)</span>
<span class="n">arg_names</span> <span class="o">=</span> <span class="n">cnn</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()</span>
<span class="n">input_shapes</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">input_shapes</span><span class="p">[</span><span class="s1">'data'</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">sentence_size</span><span class="p">)</span>
<span class="n">arg_shape</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">,</span> <span class="n">aux_shape</span> <span class="o">=</span> <span class="n">cnn</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">(</span><span class="o">**</span><span class="n">input_shapes</span><span class="p">)</span>
<span class="n">arg_arrays</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">arg_shape</span><span class="p">]</span>
<span class="n">args_grad</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">shape</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">arg_shape</span><span class="p">,</span> <span class="n">arg_names</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'softmax_label'</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">]:</span> <span class="c1"># input, output</span>
        <span class="k">continue</span>
    <span class="n">args_grad</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<span class="n">cnn_exec</span> <span class="o">=</span> <span class="n">cnn</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="n">arg_arrays</span><span class="p">,</span> <span class="n">args_grad</span><span class="o">=</span><span class="n">args_grad</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'add'</span><span class="p">)</span>
<span class="n">param_blocks</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">arg_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">arg_names</span><span class="p">,</span> <span class="n">cnn_exec</span><span class="o">.</span><span class="n">arg_arrays</span><span class="p">))</span>
<span class="n">initializer</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">arg_names</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'softmax_label'</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">]:</span> <span class="c1"># input, output</span>
        <span class="k">continue</span>
    <span class="n">initializer</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
    <span class="n">param_blocks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">args_grad</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">name</span><span class="p">)</span> <span class="p">)</span>
<span class="n">out_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">cnn</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">(),</span> <span class="n">cnn_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">))</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">cnn_exec</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="s1">'data'</span><span class="p">]</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">cnn_exec</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="s1">'softmax_label'</span><span class="p">]</span>
<span class="n">cnn_model</span><span class="o">=</span> <span class="n">CNNModel</span><span class="p">(</span><span class="n">cnn_exec</span><span class="o">=</span><span class="n">cnn_exec</span><span class="p">,</span> <span class="n">symbol</span><span class="o">=</span><span class="n">cnn</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">param_blocks</span><span class="o">=</span><span class="n">param_blocks</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now train and test our network. MXNet does much of the work for us: the executor's forward and backward methods compute the network's outputs and, automatically, the gradients of its parameters.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="sd">'''</span>
<span class="sd">Train the cnn_model using back prop</span>
<span class="sd">'''</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'rmsprop'</span>
<span class="n">max_grad_norm</span><span class="o">=</span><span class="mf">5.0</span>
<span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0005</span>
<span class="n">epoch</span><span class="o">=</span><span class="mi">50</span>
<span class="k">print</span> <span class="s1">'optimizer'</span><span class="p">,</span> <span class="n">optimizer</span>
<span class="k">print</span> <span class="s1">'maximum gradient'</span><span class="p">,</span> <span class="n">max_grad_norm</span>
<span class="k">print</span> <span class="s1">'learning rate (step size)'</span><span class="p">,</span> <span class="n">learning_rate</span>
<span class="k">print</span> <span class="s1">'epochs to train for'</span><span class="p">,</span> <span class="n">epoch</span>
<span class="c1"># create optimizer</span>
<span class="n">opt</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
<span class="n">opt</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span>
<span class="n">updater</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">get_updater</span><span class="p">(</span><span class="n">opt</span><span class="p">)</span>
<span class="c1"># create logging output</span>
<span class="n">logs</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
<span class="c1"># For each training epoch</span>
<span class="k">for</span> <span class="n">iteration</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epoch</span><span class="p">):</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">num_correct</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">num_total</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Over each batch of training data</span>
<span class="k">for</span> <span class="n">begin</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">):</span>
<span class="n">batchX</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">begin</span><span class="p">:</span><span class="n">begin</span><span class="o">+</span><span class="n">batch_size</span><span class="p">]</span>
<span class="n">batchY</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[</span><span class="n">begin</span><span class="p">:</span><span class="n">begin</span><span class="o">+</span><span class="n">batch_size</span><span class="p">]</span>
<span class="k">if</span> <span class="n">batchX</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">batch_size</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">batchX</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">label</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">batchY</span>
<span class="c1"># forward</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># backward</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="c1"># eval on training data</span>
<span class="n">num_correct</span> <span class="o">+=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">batchY</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">num_total</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batchY</span><span class="p">)</span>
<span class="c1"># update weights</span>
<span class="n">norm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">cnn_model</span><span class="o">.</span><span class="n">param_blocks</span><span class="p">:</span>
<span class="n">grad</span> <span class="o">/=</span> <span class="n">batch_size</span>
<span class="n">l2_norm</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
<span class="n">norm</span> <span class="o">+=</span> <span class="n">l2_norm</span> <span class="o">*</span> <span class="n">l2_norm</span>
<span class="n">norm</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">norm</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">cnn_model</span><span class="o">.</span><span class="n">param_blocks</span><span class="p">:</span>
<span class="k">if</span> <span class="n">norm</span> <span class="o">></span> <span class="n">max_grad_norm</span><span class="p">:</span>
<span class="n">grad</span> <span class="o">*=</span> <span class="p">(</span><span class="n">max_grad_norm</span> <span class="o">/</span> <span class="n">norm</span><span class="p">)</span>
<span class="n">updater</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
<span class="c1"># reset gradient to zero</span>
<span class="n">grad</span><span class="p">[:]</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="c1"># Decay learning rate for this epoch to ensure we are not "overshooting" optima</span>
<span class="k">if</span> <span class="n">iteration</span> <span class="o">%</span> <span class="mi">50</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">iteration</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="n">opt</span><span class="o">.</span><span class="n">lr</span> <span class="o">*=</span> <span class="mf">0.5</span>
<span class="k">print</span> <span class="o">>></span> <span class="n">logs</span><span class="p">,</span> <span class="s1">'reset learning rate to </span><span class="si">%g</span><span class="s1">'</span> <span class="o">%</span> <span class="n">opt</span><span class="o">.</span><span class="n">lr</span>
<span class="c1"># End of training loop for this epoch</span>
<span class="n">toc</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">train_time</span> <span class="o">=</span> <span class="n">toc</span> <span class="o">-</span> <span class="n">tic</span>
<span class="n">train_acc</span> <span class="o">=</span> <span class="n">num_correct</span> <span class="o">*</span> <span class="mi">100</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">num_total</span><span class="p">)</span>
<span class="c1"># Saving checkpoint to disk</span>
<span class="k">if</span> <span class="p">(</span><span class="n">iteration</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">prefix</span> <span class="o">=</span> <span class="s1">'cnn'</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'./</span><span class="si">%s</span><span class="s1">-symbol.json'</span> <span class="o">%</span> <span class="n">prefix</span><span class="p">)</span>
<span class="n">save_dict</span> <span class="o">=</span> <span class="p">{(</span><span class="s1">'arg:</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">k</span><span class="p">)</span> <span class="p">:</span><span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">arg_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">save_dict</span><span class="o">.</span><span class="n">update</span><span class="p">({(</span><span class="s1">'aux:</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">k</span><span class="p">)</span> <span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">aux_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()})</span>
<span class="n">param_name</span> <span class="o">=</span> <span class="s1">'./</span><span class="si">%s</span><span class="s1">-</span><span class="si">%04d</span><span class="s1">.params'</span> <span class="o">%</span> <span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">param_name</span><span class="p">,</span> <span class="n">save_dict</span><span class="p">)</span>
<span class="k">print</span> <span class="o">>></span> <span class="n">logs</span><span class="p">,</span> <span class="s1">'Saved checkpoint to </span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">param_name</span>
<span class="c1"># Evaluate model after this epoch on dev (test) set</span>
<span class="n">num_correct</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">num_total</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># For each test batch</span>
<span class="k">for</span> <span class="n">begin</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">x_dev</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">):</span>
<span class="n">batchX</span> <span class="o">=</span> <span class="n">x_dev</span><span class="p">[</span><span class="n">begin</span><span class="p">:</span><span class="n">begin</span><span class="o">+</span><span class="n">batch_size</span><span class="p">]</span>
<span class="n">batchY</span> <span class="o">=</span> <span class="n">y_dev</span><span class="p">[</span><span class="n">begin</span><span class="p">:</span><span class="n">begin</span><span class="o">+</span><span class="n">batch_size</span><span class="p">]</span>
<span class="k">if</span> <span class="n">batchX</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">batch_size</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">batchX</span>
<span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">num_correct</span> <span class="o">+=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">batchY</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">cnn_model</span><span class="o">.</span><span class="n">cnn_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">num_total</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batchY</span><span class="p">)</span>
<span class="n">dev_acc</span> <span class="o">=</span> <span class="n">num_correct</span> <span class="o">*</span> <span class="mi">100</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">num_total</span><span class="p">)</span>
<span class="k">print</span> <span class="o">>></span> <span class="n">logs</span><span class="p">,</span> <span class="s1">'Iter [</span><span class="si">%d</span><span class="s1">] Train: Time: </span><span class="si">%.3f</span><span class="s1">s, Training Accuracy: </span><span class="si">%.3f</span><span class="s1"> </span><span class="se">\</span>
<span class="s1"> --- Dev Accuracy thus far: </span><span class="si">%.3f</span><span class="s1">'</span> <span class="o">%</span> <span class="p">(</span><span class="n">iteration</span><span class="p">,</span> <span class="n">train_time</span><span class="p">,</span> <span class="n">train_acc</span><span class="p">,</span> <span class="n">dev_acc</span><span class="p">)</span>
</pre></div>
</div>
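<p>Now that we have trained the model, its learned state is saved in the current working directory. The network definition is in <code>cnn-symbol.json</code>, and the weights are in <code>cnn-NNNN.params</code> checkpoint files (written every 10 epochs, where <code>NNNN</code> is the zero-padded epoch index). We can load these files whenever we want and predict the sentiment of new sentences by running them through a forward pass of the trained model.</p>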
<div class="section" id="references">
<h2>References<a class="headerlink" href="#references" title="Permalink to this headline">¶</a></h2>
<div class="toctree-wrapper compound">
<ul>
<li class="toctree-l1"><a class="reference external" href="http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/">“Implementing a CNN for Text Classification in TensorFlow” blog post</a></li>
<li class="toctree-l1"><a class="reference external" href="https://arxiv.org/abs/1408.5882">Convolutional Neural Networks for Sentence Classification</a></li>
</ul>
</div>
</div>
<div class="section" id="next-steps">
<h2>Next Steps<a class="headerlink" href="#next-steps" title="Permalink to this headline">¶</a></h2>
<div class="toctree-wrapper compound">
<ul>
<li class="toctree-l1"><a class="reference external" href="http://mxnet.io/tutorials/index.html">MXNet tutorials index</a></li>
</ul>
</div>
</div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Text Classification Using a Convolutional Neural Network on MXNet</a><ul>
<li><a class="reference internal" href="#references">References</a></li>
<li><a class="reference internal" href="#next-steps">Next Steps</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
// Once the DOM is ready, make the <body> visible. The body is presumably
// hidden by a stylesheet until scripts have run (NOTE(review): confirm in mxnet.css).
$(function () {
  var $pageBody = $('body');
  $pageBody.css('visibility', 'visible');
});
</script>
</div></body>
</html>