<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Bucketing in MXNet — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../_static/basic.css" rel="stylesheet" type="text/css"/>
<link href="../_static/pygments.css" rel="stylesheet" type="text/css"/>
<link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../_static/underscore.js" type="text/javascript"></script>
<script src="../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../_static/doctools.js" type="text/javascript"></script>
<script src="../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../install/index.html">Install</a>
<a class="main-nav-link" href="../tutorials/index.html">Learn</a>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="../community/index.html">Community</a></li>
<li><a class="main-nav-link" href="../community/collaborator.html">Collaborator</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../api/perl/index.html">Perl</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="../tutorials/index.html">Tutorials</a></li>
<li><a class="main-nav-link" href="../how_to/index.html">How To</a></li>
<li><a class="main-nav-link" href="../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/master/example">Examples</a></li>
<li><a class="main-nav-link" href="../model_zoo/index.html">Model Zoo</a></li>
</ul>
</span>
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/>1.0.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="../install/index.html">Install</a></li>
<li><a class="main-nav-link" href="../tutorials/index.html">Learn</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="../community/index.html" tabindex="-1">Community</a></li>
<li><a href="../community/collaborator.html" tabindex="-1">Collaborator</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="../tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="../how_to/index.html" tabindex="-1">How To</a></li>
<li><a href="../architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/master/example" tabindex="-1">Examples</a></li>
<li><a href="../model_zoo/index.html" tabindex="-1">Model Zoo</a></li>
</ul>
</li>
<li><a href="../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/>1.0.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorials/index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/index.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="bucketing-in-mxnet">
<span id="bucketing-in-mxnet"></span><h1>Bucketing in MXNet<a class="headerlink" href="#bucketing-in-mxnet" title="Permalink to this headline"></a></h1>
<p>When we train recurrent neural networks (RNNs), we <em>unroll</em> the network in time.
For a single example of length T, we unroll the network T steps.
In the unrolled view, the network is acyclic, so we can treat it
as a feedforward neural network whose weights are shared across all time steps.
The unrolled view lets us train the network with the standard backpropagation algorithm;
in this setting, the algorithm is called <em>backpropagation through time</em> (BPTT).</p>
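<p>To make the unrolled view concrete, here is a minimal, framework-free sketch in plain Python (not MXNet code). It assumes a toy scalar RNN, h<sub>t</sub> = tanh(w·h<sub>t-1</sub> + u·x<sub>t</sub>) with h<sub>0</sub> = 0, and a squared-error loss on the final state; the function names are hypothetical. The same weights <code class="docutils literal"><span class="pre">w</span></code> and <code class="docutils literal"><span class="pre">u</span></code> are reused at every unrolled step, and the backward loop sums their gradients over all steps, which is backpropagation through time.</p>

```python
import math


def forward(w, u, xs):
    """Unroll a scalar tanh RNN over xs and return [h_0, h_1, ..., h_T]."""
    hs = [0.0]  # h_0 = 0
    for x in xs:  # one unrolled step per input, same w and u at every step
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs


def loss_and_grads(w, u, xs, y):
    """Return 0.5 * (h_T - y)**2 and its gradients w.r.t. w and u via BPTT."""
    hs = forward(w, u, xs)
    loss = 0.5 * (hs[-1] - y) ** 2
    dw = du = 0.0
    dh = hs[-1] - y  # dL/dh_T
    # Walk the unrolled graph backwards, summing gradients of the shared weights.
    for t in range(len(xs), 0, -1):
        da = dh * (1.0 - hs[t] ** 2)  # back through tanh at step t
        dw += da * hs[t - 1]
        du += da * xs[t - 1]
        dh = da * w  # gradient flowing into h_{t-1}
    return loss, dw, du


print(loss_and_grads(0.8, 0.3, [0.5, -1.0], y=0.2))
print(loss_and_grads(0.8, 0.3, [0.1, 0.2, 0.3, 0.4, 0.5], y=0.2))
```

<p>Because <code class="docutils literal"><span class="pre">forward</span></code> simply loops over its input, the same two weights serve sequences of any length; only the number of unrolled steps changes.</p>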
<p>Complications arise when the sequences in a dataset have varying lengths.
In the <em>recurrent</em> view, every example passes through the same architecture.
In the <em>unrolled</em> view, however, each example requires a different number of unrolled steps, and thus corresponds to a <em>slightly</em> different feedforward network.
If we train on one example at a time, we can simply unroll the required number of steps on each training iteration.
Mini-batch training is harder.
A naive approach is to pad every sequence to the length of the longest example.
However, this can be wasteful: for short sequences, most of the computation is spent on padding.</p>
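<p>The cost of naive padding is easy to quantify. The plain-Python sketch below (the function name and the sequence lengths are made up for illustration, not PTB statistics) computes the fraction of unrolled time steps spent on padding when every sequence in a mini-batch is padded to the longest one.</p>

```python
def padding_waste(lengths, padded_len=None):
    """Fraction of unrolled time steps spent on padding when every sequence
    is padded to padded_len (by default, the longest sequence)."""
    if padded_len is None:
        padded_len = max(lengths)
    total_steps = padded_len * len(lengths)
    real_steps = sum(lengths)
    return (total_steps - real_steps) / total_steps


lengths = [5, 8, 12, 40]  # made-up sequence lengths
print(f"{padding_waste(lengths):.1%} of the unrolled steps are padding")
```

<p>With these lengths, 95 of the 160 unrolled steps process nothing but padding.</p>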
<p><em>Bucketing</em>, an idea borrowed from <a class="reference external" href="https://www.tensorflow.org/versions/r0.7/tutorials/seq2seq/index.html">TensorFlow’s sequence training code</a>,
is an effective way to build mini-batches from variable-length sequences.
Instead of unrolling the network to the maximum possible sequence length,
we unroll multiple instances of different lengths (e.g., lengths 5, 10, 20, and 30).
During training, we use the most appropriate unrolled model
for each mini-batch, typically the shortest one that can hold every sequence in the batch.
Although the models <em>with different numbers of unrolled steps</em>
have different feedforward architectures,
they all share the same parameters.
<em>MXNet</em> also reuses internal memory buffers among the executors of all buckets.</p>
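<p>The core of bucketing, choosing a bucket for each sequence, can be sketched in a few lines of plain Python. In this sketch (hypothetical helper names; not MXNet code), each sequence goes to the smallest bucket that is at least as long as it and is padded only up to that bucket's length. Skipping sequences longer than the largest bucket is an assumed policy; a real iterator might instead truncate them or add a larger bucket.</p>

```python
import bisect
from collections import defaultdict


def assign_buckets(lengths, buckets):
    """Map each bucket length to the indices of the sequences it holds.

    Each sequence goes to the smallest bucket that is at least as long as it.
    Sequences longer than the largest bucket are skipped (one possible policy).
    """
    buckets = sorted(buckets)
    assignment = defaultdict(list)
    for i, n in enumerate(lengths):
        j = bisect.bisect_left(buckets, n)  # index of the first bucket >= n
        if j < len(buckets):
            assignment[buckets[j]].append(i)
    return dict(assignment)


def bucketed_padding_waste(lengths, buckets):
    """Fraction of unrolled steps spent on padding when padding only to the bucket length."""
    assignment = assign_buckets(lengths, buckets)
    total_steps = sum(b * len(idx) for b, idx in assignment.items())
    real_steps = sum(lengths[i] for idx in assignment.values() for i in idx)
    return (total_steps - real_steps) / total_steps


lengths = [5, 8, 12, 40]  # made-up sequence lengths
buckets = [5, 10, 20, 30, 40]
print(assign_buckets(lengths, buckets))
print(f"{bucketed_padding_waste(lengths, buckets):.1%} of the steps are padding")
```

<p>For these lengths, padding everything to 40 would spend 95 of 160 steps on padding; with the buckets, only 10 of 75 steps are padding.</p>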
<p>For simple RNNs, you can use a for loop to step explicitly
through the input sequence and perform backpropagation through time
by carrying the states and gradients from one time step to the next.
This approach handles variable-length sequences, but it is slow.
For more complicated models (e.g., sequence-to-sequence models for translation), explicit unrolling is the easiest approach. In this example, we introduce the MXNet APIs that allow us to implement bucketing.</p>
<div class="section" id="variable-length-sequence-training-for-ptb">
<span id="variable-length-sequence-training-for-ptb"></span><h2>Variable-length Sequence Training for PTB<a class="headerlink" href="#variable-length-sequence-training-for-ptb" title="Permalink to this headline"></a></h2>
<p>We build on the <a class="reference external" href="https://github.com/dmlc/mxnet/tree/master/example/rnn">Penn Treebank (PTB) language model example</a>. If you are not familiar with this example, see <a class="reference external" href="http://dmlc.ml/mxnet/2015/11/15/char-lstm-in-julia.html">this tutorial (in Julia)</a> first.</p>
<p>In this example, we use a simple architecture
consisting of a word-embedding layer
followed by two LSTM layers.
In the original example,
the model is unrolled explicitly in time for a fixed length of 32.</p>
<!-- In this tutorial, we show how to use bucketing to implement variable-length sequence training. -->
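<p>To enable bucketing, MXNet needs to know how to construct a new unrolled symbolic architecture for a different sequence length. To achieve this, instead of constructing a model with a fixed <code class="docutils literal"><span class="pre">Symbol</span></code>, we use a callback function that generates a new <code class="docutils literal"><span class="pre">Symbol</span></code> for a given <em>bucket key</em>.</p>
```python
import mxnet as mx

# sym_gen, data_train, and contexts are defined elsewhere in the example.
model = mx.mod.BucketingModule(
    sym_gen=sym_gen,
    default_bucket_key=data_train.default_bucket_key,
    context=contexts)
```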
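<p><code class="docutils literal"><span class="pre">sym_gen</span></code> must be a function that takes one argument, the bucket key, and returns a triple <code class="docutils literal"><span class="pre">(symbol,</span> <span class="pre">data_names,</span> <span class="pre">label_names)</span></code> for that bucket. We use the sequence length as the bucket key, but a bucket key can be anything. For example, in neural machine translation, different combinations of input and output sequence lengths correspond to different unrollings, so the bucket key could be a pair of lengths.</p>
```python
def sym_gen(seq_len):
    # Unroll the two-layer LSTM for this bucket's sequence length.
    sym = lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                      num_hidden=num_hidden, num_embed=num_embed,
                      num_label=len(vocab))
    # BucketingModule expects (symbol, data_names, label_names). The names
    # below are assumptions: they must match the input variables created in
    # lstm_unroll and the names the DataIter reports in provide_data/provide_label.
    data_names = ('data',)
    label_names = ('softmax_label',)
    return sym, data_names, label_names
```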
<p>The data iterator needs to report the <code class="docutils literal"><span class="pre">default_bucket_key</span></code>, which allows MXNet to initialize the parameters before reading any data. With this in place, the model can train on different buckets, sharing the parameters and intermediate computation buffers among the bucket executors.</p><p>To train, we still need a few additions to our <code class="docutils literal"><span class="pre">DataIter</span></code>. Besides reporting the <code class="docutils literal"><span class="pre">default_bucket_key</span></code> as mentioned above, it must also report the current <code class="docutils literal"><span class="pre">bucket_key</span></code> for each mini-batch. Specifically, each <code class="docutils literal"><span class="pre">DataBatch</span></code> object returned by the <code class="docutils literal"><span class="pre">DataIter</span></code> should carry the following additional attributes:</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">bucket_key</span></code>: The bucket key that corresponds to this batch of data.
In our example, it is the sequence length for this batch.
If the executors for this bucket key have not been created yet,
they are constructed from the symbolic model that <code class="docutils literal"><span class="pre">sym_gen</span></code> returns for this bucket key.
The executors are then cached for future use.
Note that the generated <code class="docutils literal"><span class="pre">Symbol</span></code>s can differ arbitrarily,
but they must all have the same trainable parameters and auxiliary states.</li>
<li><code class="docutils literal"><span class="pre">provide_data</span></code>: The names and shapes of the input data for this batch, in the same format as the <code class="docutils literal"><span class="pre">provide_data</span></code> reported by the <code class="docutils literal"><span class="pre">DataIter</span></code> object.
Because each bucket corresponds to a different architecture,
different buckets can have differently shaped inputs.
Also, make sure that the <code class="docutils literal"><span class="pre">provide_data</span></code> information reported by the <code class="docutils literal"><span class="pre">DataIter</span></code> object itself
is compatible with the architecture for <code class="docutils literal"><span class="pre">default_bucket_key</span></code>.</li>
<li><code class="docutils literal"><span class="pre">provide_label</span></code>: The same as <code class="docutils literal"><span class="pre">provide_data</span></code>, but for the labels.</li>
</ul>
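<p>The interplay between these attributes and executor caching can be sketched without MXNet. In the plain-Python toy below, <code class="docutils literal"><span class="pre">ToyBatch</span></code>, <code class="docutils literal"><span class="pre">ToyBucketingModel</span></code>, and <code class="docutils literal"><span class="pre">toy_sym_gen</span></code> are hypothetical names, and each "architecture" is just a dictionary holding an input shape and a list of parameter names. This is not how <code class="docutils literal"><span class="pre">BucketingModule</span></code> is implemented internally; it only mirrors the behavior described above: an executor is built the first time a bucket key appears, cached afterwards, checked against the batch's <code class="docutils literal"><span class="pre">provide_data</span></code>, and every bucket shares one parameter store.</p>

```python
from collections import namedtuple

# Hypothetical stand-in for the batch objects described above, reduced to
# the fields that matter for bucketing. This is NOT mx.io.DataBatch.
ToyBatch = namedtuple(
    "ToyBatch", ["data", "label", "bucket_key", "provide_data", "provide_label"])


class ToyBucketingModel:
    """Hypothetical model that builds one 'executor' per bucket key on first
    use, caches it, and keeps a single parameter store shared by all buckets."""

    def __init__(self, sym_gen, default_bucket_key):
        self.sym_gen = sym_gen
        self.params = {}     # shared by every bucket
        self.executors = {}  # bucket_key -> cached architecture
        self._executor_for(default_bucket_key)  # set up the default bucket first

    def _executor_for(self, bucket_key):
        if bucket_key not in self.executors:
            arch = self.sym_gen(bucket_key)
            # Every generated architecture must declare the same parameters.
            if self.params and set(arch["params"]) != set(self.params):
                raise ValueError(f"bucket {bucket_key!r} declares different parameters")
            for name in arch["params"]:
                self.params.setdefault(name, 0.0)
            self.executors[bucket_key] = arch
        return self.executors[bucket_key]

    def forward(self, batch):
        arch = self._executor_for(batch.bucket_key)
        if dict(batch.provide_data)["data"] != arch["data_shape"]:
            raise ValueError("provide_data does not match this bucket's architecture")
        return arch


def toy_sym_gen(seq_len, batch_size=2):
    # The input shape depends on the bucket key; the parameter names do not.
    return {"data_shape": (batch_size, seq_len),
            "params": ["embed_weight", "lstm_weight", "pred_weight"]}


model = ToyBucketingModel(toy_sym_gen, default_bucket_key=10)
batch = ToyBatch(data=[[1] * 5, [2] * 5], label=[[0] * 5, [0] * 5], bucket_key=5,
                 provide_data=[("data", (2, 5))],
                 provide_label=[("softmax_label", (2, 5))])
model.forward(batch)  # builds and caches the executor for bucket 5
model.forward(batch)  # reuses the cached executor
print(sorted(model.executors), sorted(model.params))
```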
<p>The <code class="docutils literal"><span class="pre">DataIter</span></code> is responsible for grouping the data into buckets.
When shuffling is enabled, at each iteration the
<code class="docutils literal"><span class="pre">DataIter</span></code> picks a random bucket (with probability weighted by the bucket sizes)
and then randomly chooses sequences from that bucket to form a mini-batch.
It also pads sequences of different lengths within the mini-batch to a common length as necessary.</p>
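<p>One way an iterator could implement this sampling policy is sketched below in plain Python. It assumes the data is already grouped into a dictionary that maps a bucket length to a list of token-id lists; <code class="docutils literal"><span class="pre">make_batches</span></code> and <code class="docutils literal"><span class="pre">PAD</span></code> are hypothetical names. Instead of drawing a bucket independently at each step, it shuffles a plan with one entry per full batch, so each step draws a bucket with probability proportional to the batches it still has left. Dropping incomplete trailing batches is also an assumption of this sketch.</p>

```python
import random

PAD = 0  # hypothetical padding token id


def make_batches(seqs_by_bucket, batch_size, rng):
    """Yield (bucket_len, batch) pairs in random bucket order.

    One entry per full batch is shuffled, so each step draws a bucket with
    probability proportional to the batches it still has left.
    Incomplete trailing batches are dropped.
    """
    plan = []
    for bucket_len, seqs in seqs_by_bucket.items():
        order = list(range(len(seqs)))
        rng.shuffle(order)  # random sequences within the bucket
        for start in range(0, len(order) - batch_size + 1, batch_size):
            plan.append((bucket_len, order[start:start + batch_size]))
    rng.shuffle(plan)  # random bucket at each iteration
    for bucket_len, idx in plan:
        seqs = seqs_by_bucket[bucket_len]
        # Pad every sequence in the batch up to the bucket length.
        batch = [seqs[i] + [PAD] * (bucket_len - len(seqs[i])) for i in idx]
        yield bucket_len, batch


data = {5: [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10], [1]],
        10: [[3] * 7, [2] * 10]}
for bucket_len, batch in make_batches(data, batch_size=2, rng=random.Random(0)):
    print(bucket_len, batch)
```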
<p>For a complete, working implementation of a <code class="docutils literal"><span class="pre">DataIter</span></code>
that reads text sequences as described above, see <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/example/rnn/lstm_bucketing.py">example/rnn/lstm_bucketing.py</a>.
In this example, you can use bucketing with a static configuration (e.g., <code class="docutils literal"><span class="pre">buckets</span> <span class="pre">=</span> <span class="pre">[10,</span> <span class="pre">20,</span> <span class="pre">30,</span> <span class="pre">40,</span> <span class="pre">50,</span> <span class="pre">60]</span></code>), or let the buckets be generated automatically from the characteristics of the dataset (<code class="docutils literal"><span class="pre">buckets</span> <span class="pre">=</span> <span class="pre">[]</span></code>). The latter approach adds a bucket whenever the number of sequences assigned to it reaches a minimum count. For more information, see <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/example/rnn/old/bucket_io.py#L43">default_gen_buckets()</a>.</p>
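<p>The automatic strategy can be sketched as a greedy scan over sequence lengths. This is a minimal illustration of the heuristic described above, not a copy of <code class="docutils literal"><span class="pre">default_gen_buckets()</span></code>; the name <code class="docutils literal"><span class="pre">gen_buckets</span></code> is hypothetical, and using the batch size as <code class="docutils literal"><span class="pre">min_count</span></code> would be a natural choice but is our assumption. Lengths are visited in increasing order, and a bucket is closed once it has collected at least <code class="docutils literal"><span class="pre">min_count</span></code> sequences; any leftover long sequences go into a final bucket sized to the longest sequence.</p>

```python
from collections import Counter


def gen_buckets(lengths, min_count):
    """Greedy bucket generation over sorted sequence lengths."""
    counts = Counter(lengths)
    buckets, pending = [], 0
    for length in sorted(counts):
        pending += counts[length]  # sequences waiting for a bucket
        if pending >= min_count:
            buckets.append(length)  # this bucket holds all pending sequences
            pending = 0
    if pending > 0:  # leftover sequences go into a final bucket
        buckets.append(max(lengths))
    return buckets


print(gen_buckets([3, 3, 3, 5, 5, 7, 7, 7, 7, 9], min_count=3))
```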
</div>
<div class="section" id="beyond-sequence-training">
<span id="beyond-sequence-training"></span><h2>Beyond Sequence Training<a class="headerlink" href="#beyond-sequence-training" title="Permalink to this headline"></a></h2>
<p>In this example, we briefly explained how the bucketing API works.
However, the API is not limited to bucketing by sequence length.
The bucket key can be an arbitrary (hashable) object, as long
as every architecture that <code class="docutils literal"><span class="pre">sym_gen</span></code> returns
has the same set of parameters as the architecture for <code class="docutils literal"><span class="pre">default_bucket_key</span></code>.</p>
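<p>As an illustration of non-scalar keys, the plain-Python sketch below uses a <code class="docutils literal"><span class="pre">(source_len,</span> <span class="pre">target_len)</span></code> tuple as the bucket key, as one might for translation. The generator <code class="docutils literal"><span class="pre">pair_sym_gen</span></code>, the parameter names, and the dictionary "architecture" are all hypothetical. The input shapes vary with the key while the parameter set stays fixed, and <code class="docutils literal"><span class="pre">all_compatible</span></code> checks exactly the condition stated above.</p>

```python
PARAM_NAMES = frozenset({"src_embed_weight", "tgt_embed_weight",
                         "encoder_weight", "decoder_weight", "pred_weight"})


def pair_sym_gen(bucket_key):
    """Hypothetical generator keyed by a (source_len, target_len) tuple."""
    src_len, tgt_len = bucket_key
    return {"inputs": {"source": (src_len,), "target": (tgt_len,)},
            "params": PARAM_NAMES}  # the same parameters for every key


def all_compatible(sym_gen, keys, default_key):
    """True if every bucket declares exactly the default bucket's parameters."""
    reference = sym_gen(default_key)["params"]
    return all(sym_gen(key)["params"] == reference for key in keys)


keys = [(10, 12), (20, 25), (40, 45)]
cache = {key: pair_sym_gen(key) for key in keys}  # tuple keys work as dict keys
print(all_compatible(pair_sym_gen, keys, default_key=(40, 45)))
```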
</div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Bucketing in MXNet</a><ul>
<li><a class="reference internal" href="#variable-length-sequence-training-for-ptb">Variable-length Sequence Training for PTB</a></li>
<li><a class="reference internal" href="#beyond-sequence-training">Beyond Sequence Training</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../_static/js/search.js" type="text/javascript"></script>
<script src="../_static/js/navbar.js" type="text/javascript"></script>
<script src="../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../_static/js/copycode.js" type="text/javascript"></script>
<script src="../_static/js/page.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>