<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Bucketing in MXNet — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../_static/underscore.js" type="text/javascript"></script>
<script src="../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../_static/doctools.js" type="text/javascript"></script>
<script src="../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></head>
<body role="document">
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../" id="logo"><img src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../get_started/install.html">Install</a>
<a class="main-nav-link" href="../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../get_started/install.html">Install</a></li>
<li><a href="../tutorials/index.html">Tutorials</a></li>
<li><a href="../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorials/index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="bucketing-in-mxnet">
<span id="bucketing-in-mxnet"></span><h1>Bucketing in MXNet<a class="headerlink" href="#bucketing-in-mxnet" title="Permalink to this headline"></a></h1>
<p>When we train recurrent neural networks (RNNs), we <em>unroll</em> the network in time.
For a single example of length T, we would unroll the network T steps.
In the unrolled view, the network is acyclic and thus we can think of it
as a feedforward neural network in which weights are tied across adjacent time steps.
The unrolled view allows us to train the network with the standard backpropagation algorithm.
In this setting, the algorithm is called <em>backpropagation through time</em>.</p>
<p>Things get complicated when we work with datasets where the example sequences have varying lengths.
In the <em>recurrent</em> view, each example can pass through the same architecture.
But in the <em>unrolled</em> view, each example requires a different number of unrollings, and thus corresponds to a <em>slightly</em> different feedforward network.
If we train on one example at a time, we can simply unroll the desired amount on each training iteration.
But things get more complicated when we try to perform mini-batch training.
A naive approach might be to pad all the sequences so that they appear to have the length of the longest example.
However, this could be wasteful because on shorter sequences, most of the computations are done on padded data.</p>
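<p>As a rough illustration of that waste, the following sketch counts how many time steps in a padded mini-batch hold padding rather than data. The sequence lengths are made up for the example and are not taken from any dataset.</p>

```python
def padding_waste(lengths, padded_length):
    """Return the fraction of time steps that hold padding rather than data."""
    total_steps = len(lengths) * padded_length
    real_steps = sum(lengths)
    return float(total_steps - real_steps) / total_steps

# Hypothetical sequence lengths for one mini-batch.
lengths = [3, 5, 4, 30]

# Naive approach: pad every sequence to the longest one in the batch.
waste = padding_waste(lengths, max(lengths))
print("fraction of padded steps: %.2f" % waste)
```

<p>With these lengths, a single long sequence forces most of the computation for the three short ones to be spent on padding.</p>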
<p>Borrowed from <a class="reference external" href="https://www.tensorflow.org/versions/r0.7/tutorials/seq2seq/index.html">TensorFlow’s sequence training code</a>,
<em>bucketing</em> offers an effective solution to make minibatches out of varying-length sequences.
Instead of unrolling the network to the maximum possible sequence length,
we unroll multiple instances of different lengths (e.g., length 5, 10, 20, 30).
During training, we use the most appropriate unrolled model
for each mini-batch of data.
Although the models <em>with different numbers of unrollings</em>
have different feedforward architectures,
their parameters are shared in time.
<em>MXNet</em> reuses the internal memory buffers among all executors.</p>
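<p>A minimal sketch of the bucket assignment described above, assuming each sequence goes to the smallest bucket that can hold it. The helper name <code class="docutils literal"><span class="pre">assign_bucket</span></code> and the choice to drop over-long sequences are assumptions made for this illustration, not part of the MXNet API.</p>

```python
import bisect

def assign_bucket(length, buckets):
    """Return the smallest bucket that can hold a sequence of this length,
    or None if the sequence is longer than every bucket."""
    idx = bisect.bisect_left(buckets, length)
    return buckets[idx] if idx != len(buckets) else None

buckets = [5, 10, 20, 30]            # must be sorted in ascending order
lengths = [3, 5, 4, 30, 12, 7, 31]   # hypothetical sequence lengths

grouped = {}
for n in lengths:
    key = assign_bucket(n, buckets)
    if key is None:
        continue                     # too long for any bucket; dropped in this sketch
    grouped.setdefault(key, []).append(n)

for key in sorted(grouped):
    print(key, grouped[key])
```

<p>Each group can then be padded only up to its own bucket length instead of the global maximum.</p>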
<p>For simple RNNs, you can use a for loop to explicitly
go over the input sequences and perform backpropagation through time
by maintaining the connection of the states and gradients across time steps.
This approach works with variable-length sequences, but it results in slow processing.
For more complicated models (e.g., translation that uses a sequence-to-sequence model), explicit unrolling is the easiest way. In this example, we introduce the MXNet APIs that allow us to implement bucketing.</p>
<div class="section" id="variable-length-sequence-training-for-ptb">
<span id="variable-length-sequence-training-for-ptb"></span><h2>Variable-length Sequence Training for PTB<a class="headerlink" href="#variable-length-sequence-training-for-ptb" title="Permalink to this headline"></a></h2>
<p>This walkthrough is based on the <a class="reference external" href="https://github.com/dmlc/mxnet/tree/master/example/rnn">PennTreeBank language model example</a>. If you are not familiar with it, see <a class="reference external" href="http://dmlc.ml/mxnet/2015/11/15/char-lstm-in-julia.html">this tutorial (in Julia)</a> first.</p>
<p>In this example, we use a simple architecture
consisting of a word-embedding layer
followed by two LSTM layers.
In the original example,
the model is unrolled explicitly in time for a fixed length of 32.</p>
<!-- In this tutorial, we show how to use bucketing to implement variable-length sequence training. -->
<p>To enable bucketing, MXNet needs to know how to construct a new unrolled symbolic architecture for a different sequence length. To achieve this, instead of constructing a model with a fixed <code class="docutils literal"><span class="pre">Symbol</span></code>, we use a callback function that generates a new <code class="docutils literal"><span class="pre">Symbol</span></code> for a given <em>bucket key</em>.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">BucketingModule</span><span class="p">(</span>
    <span class="n">sym_gen</span> <span class="o">=</span> <span class="n">sym_gen</span><span class="p">,</span>
    <span class="n">default_bucket_key</span> <span class="o">=</span> <span class="n">data_train</span><span class="o">.</span><span class="n">default_bucket_key</span><span class="p">,</span>
    <span class="n">context</span> <span class="o">=</span> <span class="n">contexts</span><span class="p">)</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">sym_gen</span></code> must be a function that takes one argument, <code class="docutils literal"><span class="pre">bucket_key</span></code>, and returns a <code class="docutils literal"><span class="pre">Symbol</span></code> for this bucket. We’ll use the sequence length as the bucket key. A bucket key could be anything. For example, in neural translation, because different combinations of input-output sequence lengths correspond to different unrolling, the bucket key could be a pair of lengths.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">sym_gen</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">lstm_unroll</span><span class="p">(</span><span class="n">num_lstm_layer</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">vocab</span><span class="p">),</span>
                       <span class="n">num_hidden</span><span class="o">=</span><span class="n">num_hidden</span><span class="p">,</span> <span class="n">num_embed</span><span class="o">=</span><span class="n">num_embed</span><span class="p">,</span>
                       <span class="n">num_label</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">vocab</span><span class="p">))</span>
</pre></div>
</div>
<p>The data iterator needs to report the <code class="docutils literal"><span class="pre">default_bucket_key</span></code>, which allows MXNet to do some parameter initialization before reading the data. Now the model is capable of training with different buckets by sharing the parameters and intermediate computation buffers between bucket executors.</p>
<p>To train, we still need to add some extra bits to our <code class="docutils literal"><span class="pre">DataIter</span></code>. Apart from reporting the <code class="docutils literal"><span class="pre">default_bucket_key</span></code> as mentioned previously, we also need to report the current <code class="docutils literal"><span class="pre">bucket_key</span></code> for each mini-batch. More specifically, the <code class="docutils literal"><span class="pre">DataBatch</span></code> object returned in each mini-batch by the <code class="docutils literal"><span class="pre">DataIter</span></code> should contain the following additional properties:</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">bucket_key</span></code>: The bucket key that corresponds to this batch of data.
In our example, it is the sequence length for this batch of data.
If the executors corresponding to this bucket key have not yet been created,
they will be constructed according to the symbolic model returned by <code class="docutils literal"><span class="pre">sym_gen</span></code> on this bucket key.
The executors will be cached for future use.
Note that generated <code class="docutils literal"><span class="pre">Symbol</span></code>s could be arbitrary,
but they should all have the same trainable parameters and auxiliary states.</li>
<li><code class="docutils literal"><span class="pre">provide_data</span></code>: The same information reported by the <code class="docutils literal"><span class="pre">DataIter</span></code> object.
Because each bucket now corresponds to a different architecture,
they could have different input data.
Also, make sure that the <code class="docutils literal"><span class="pre">provide_data</span></code> information returned by the <code class="docutils literal"><span class="pre">DataIter</span></code> object
is compatible with the architecture for <code class="docutils literal"><span class="pre">default_bucket_key</span></code>.</li>
<li><code class="docutils literal"><span class="pre">provide_label</span></code>: The same as <code class="docutils literal"><span class="pre">provide_data</span></code>.</li>
</ul>
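<p>The lazy construction and caching of executors can be pictured with a toy stand-in written in plain Python. The class <code class="docutils literal"><span class="pre">BucketCache</span></code>, the string "symbols", and the single shared parameter dictionary are all hypothetical. This is a sketch of the idea, not how MXNet implements it.</p>

```python
class BucketCache(object):
    """Toy stand-in for a bucketing module: builds one 'executor' per bucket
    key on first use and keeps it for later batches."""

    def __init__(self, sym_gen, default_bucket_key):
        self._sym_gen = sym_gen
        self._executors = {}
        self.shared_params = {"weight": 0.0}   # one parameter set for all buckets
        self.get(default_bucket_key)           # built up front, before any data

    def get(self, bucket_key):
        if bucket_key not in self._executors:
            symbol = self._sym_gen(bucket_key)
            self._executors[bucket_key] = (symbol, self.shared_params)
        return self._executors[bucket_key]

calls = []

def toy_sym_gen(seq_len):
    """Record each call so we can see how often a symbol is generated."""
    calls.append(seq_len)
    return "unrolled-%d" % seq_len

cache = BucketCache(toy_sym_gen, default_bucket_key=30)
for key in [10, 30, 10, 20, 10]:       # bucket keys of five incoming batches
    symbol, params = cache.get(key)

print(calls)
```

<p>The generator runs once per distinct bucket key, and every cached entry refers to the same parameter object.</p>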
<p>The <code class="docutils literal"><span class="pre">DataIter</span></code> is responsible for grouping the data into different buckets.
Assuming that randomization is enabled, at each iteration,
<code class="docutils literal"><span class="pre">DataIter</span></code> chooses a random bucket (according to a distribution balanced by the bucket sizes),
and then randomly chooses sequences from that bucket to form a mini-batch.
It also applies padding for sequences of different length within the mini-batch as necessary.</p>
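<p>The sampling and padding steps above might look roughly like the following. The function <code class="docutils literal"><span class="pre">next_batch</span></code>, the padding value, and the toy data are assumptions for illustration. For brevity, this sketch draws sequences with replacement, which a real iterator may not do.</p>

```python
import random

def next_batch(bucketed, batch_size, pad_value=0, rng=random):
    """Pick a bucket with probability proportional to its size, then draw
    `batch_size` sequences from it and pad them to the bucket length."""
    keys = sorted(bucketed)
    sizes = [len(bucketed[k]) for k in keys]
    # Weighted choice written out by hand: walk the buckets until the
    # random point falls inside one of them.
    pick = rng.random() * sum(sizes)
    for key, size in zip(keys, sizes):
        pick -= size
        if pick < 0:
            break
    chosen = [rng.choice(bucketed[key]) for _ in range(batch_size)]
    padded = [seq + [pad_value] * (key - len(seq)) for seq in chosen]
    return key, padded

# Hypothetical token-id sequences, already grouped by bucket length.
bucketed = {
    5: [[1, 2, 3], [4, 5, 6, 7, 8]],
    10: [[9, 10, 11, 12, 13, 14]],
}
rng = random.Random(0)
bucket_key, batch = next_batch(bucketed, batch_size=2, rng=rng)
print(bucket_key, batch)
```

<p>The returned key plays the role of the <code class="docutils literal"><span class="pre">bucket_key</span></code> that accompanies each mini-batch.</p>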
<p>For a full, working implementation of a <code class="docutils literal"><span class="pre">DataIter</span></code>
that reads text sequences as described above, see <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/example/rnn/lstm_bucketing.py">example/rnn/lstm_bucketing.py</a>.
In this example, you can use bucketing with a static configuration (e.g., <code class="docutils literal"><span class="pre">buckets</span> <span class="pre">=</span> <span class="pre">[10,</span> <span class="pre">20,</span> <span class="pre">30,</span> <span class="pre">40,</span> <span class="pre">50,</span> <span class="pre">60]</span></code>), or let MXNet generate buckets automatically according to the characteristics of the dataset (<code class="docutils literal"><span class="pre">buckets</span> <span class="pre">=</span> <span class="pre">[]</span></code>). The latter approach is implemented by adding a bucket as long as the number of sequences assigned to that bucket exceeds some minimum count. For more information, see <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/example/rnn/old/bucket_io.py#L43">default_gen_buckets()</a>.</p>
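<p>One way to picture automatic bucket generation is the sketch below. It is a simplified heuristic written for this page, with a hypothetical name <code class="docutils literal"><span class="pre">gen_buckets</span></code> and threshold parameter <code class="docutils literal"><span class="pre">min_count</span></code>. It is not a copy of the function linked above.</p>

```python
from collections import Counter

def gen_buckets(lengths, min_count):
    """Walk the observed lengths in ascending order and open a bucket each
    time at least `min_count` sequences have accumulated."""
    counts = Counter(lengths)
    buckets, pending = [], 0
    for length in sorted(counts):
        pending += counts[length]
        if pending >= min_count:
            buckets.append(length)
            pending = 0
    if pending > 0:
        buckets.append(max(counts))   # leftover long sequences share a final bucket
    return buckets

# Hypothetical sequence lengths observed in a dataset.
lengths = [3, 3, 4, 5, 5, 5, 9, 12, 12, 30]
buckets = gen_buckets(lengths, min_count=3)
print(buckets)
```

<p>Raising the threshold yields fewer, larger buckets, which trades more padding for fewer distinct unrolled networks.</p>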
</div>
<div class="section" id="beyond-sequence-training">
<span id="beyond-sequence-training"></span><h2>Beyond Sequence Training<a class="headerlink" href="#beyond-sequence-training" title="Permalink to this headline"></a></h2>
<p>In this example, we briefly explained how the bucketing API works.
However, the API is not limited to bucketing by sequence lengths.
The bucket key can be an arbitrary object, as long
as the architecture returned by <code class="docutils literal"><span class="pre">sym_gen</span></code> for each key
is compatible with the others (that is, has the same set of parameters).</p>
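<p>The compatibility requirement can be checked mechanically. In the sketch below, the generators return lists of parameter names instead of real <code class="docutils literal"><span class="pre">Symbol</span></code> objects, and the bucket key is a pair of lengths as in a translation model. All function and parameter names are hypothetical.</p>

```python
def seq2seq_param_names(bucket_key):
    """Hypothetical generator keyed by an (input_len, output_len) pair.
    For brevity it returns parameter names instead of a real Symbol."""
    input_len, output_len = bucket_key
    # Weights are tied across time steps, so names do not depend on the lengths.
    return ["embed_weight", "encoder_lstm_weight", "decoder_lstm_weight"]

def per_step_param_names(bucket_key):
    """A broken variant that creates a separate weight for every time step."""
    input_len, _ = bucket_key
    return ["step%d_weight" % t for t in range(input_len)]

def compatible(sym_gen, bucket_keys):
    """True if every bucket key yields the same set of parameter names."""
    name_sets = set(frozenset(sym_gen(key)) for key in bucket_keys)
    return len(name_sets) <= 1

keys = [(5, 10), (10, 15), (20, 25)]
print(compatible(seq2seq_param_names, keys))
print(compatible(per_step_param_names, keys))
```

<p>The first generator can be used for bucketing because its parameters are independent of the key. The second cannot, because each key would introduce parameters that the other buckets do not share.</p>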
</div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Bucketing in MXNet</a><ul>
<li><a class="reference internal" href="#variable-length-sequence-training-for-ptb">Variable-length Sequence Training for PTB</a></li>
<li><a class="reference internal" href="#beyond-sequence-training">Beyond Sequence Training</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../_static/js/search.js" type="text/javascript"></script>
<script src="../_static/js/navbar.js" type="text/javascript"></script>
<script src="../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>