blob: b8297f0d83d2d06f243c0487faafc0f969002219 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>RNN Cell API — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<!-- Page-level configuration exposed as the global DOCUMENTATION_OPTIONS object.
     URL_ROOT is the relative path from this page up to the documentation root,
     and FILE_SUFFIX is the extension of generated pages.
     NOTE(review): presumably read by the doctools/searchtools scripts loaded below; confirm against those files. -->
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- Once the DOM is ready (jQuery ready shorthand), load the search index and initialise the Search object.
     The index path "/searchindex.js" is site-absolute, unlike the page-relative URL_ROOT declared above.
     NOTE(review): Search is not defined in this file; presumably it comes from searchtools_custom.js. Confirm. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<!-- Google Analytics loader snippet.
     Registers a global "ga" function that queues its arguments until the library arrives,
     records a load timestamp, and inserts an async script element for analytics.js
     before the first existing script element. It then creates a tracker for
     property UA-96378503-1 and sends a single pageview hit.
     Vendor snippet: keep the statement order unchanged. -->
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="index.html" rel="up" title="MXNet - Python API">
<link href="kvstore.html" rel="next" title="KVStore API"/>
<link href="module.html" rel="prev" title="Module API"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></span></nav>
<!-- getRootPath(): returns the hard-coded relative path "../../", the same value as DOCUMENTATION_OPTIONS.URL_ROOT in the head.
     NOTE(review): callers are not visible in this file; presumably theme scripts use it to build links. Confirm. -->
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Python Documents</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="index.html#table-of-contents">Table of contents</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="ndarray.html">NDArray API</a></li>
<li class="toctree-l3"><a class="reference internal" href="symbol.html">Symbol API</a></li>
<li class="toctree-l3"><a class="reference internal" href="module.html">Module API</a></li>
<li class="toctree-l3 current"><a class="current reference internal" href="">RNN Cell API</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#overview">Overview</a></li>
<li class="toctree-l4"><a class="reference internal" href="#the-rnn-module">The <code class="docutils literal"><span class="pre">rnn</span></code> module</a></li>
<li class="toctree-l4"><a class="reference internal" href="#api-reference">API Reference</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="kvstore.html">KVStore API</a></li>
<li class="toctree-l3"><a class="reference internal" href="io.html">Data Loading API</a></li>
<li class="toctree-l3"><a class="reference internal" href="optimization.html">Optimization: initialize and update weights</a></li>
<li class="toctree-l3"><a class="reference internal" href="callback.html">Callback API</a></li>
<li class="toctree-l3"><a class="reference internal" href="metric.html">Evaluation Metric API</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="rnn-cell-api">
<span id="rnn-cell-api"></span><h1>RNN Cell API<a class="headerlink" href="#rnn-cell-api" title="Permalink to this headline"></a></h1>
<div class="admonition warning">
<p class="first admonition-title">Warning</p>
<p class="last">This package is currently experimental and may change in the near future.</p>
</div>
<div class="section" id="overview">
<span id="overview"></span><h2>Overview<a class="headerlink" href="#overview" title="Permalink to this headline"></a></h2>
<p>The <code class="docutils literal"><span class="pre">rnn</span></code> module includes the recurrent neural network (RNN) cell APIs, a suite of tools for building an RNN’s symbolic graph.</p>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">The <cite>rnn</cite> module offers a higher-level interface, while <cite>symbol.RNN</cite> is a lower-level interface. The cell APIs in the <cite>rnn</cite> module are easier to use in most cases.</p>
</div>
</div>
<div class="section" id="the-rnn-module">
<span id="the-rnn-module"></span><h2>The <code class="docutils literal"><span class="pre">rnn</span></code> module<a class="headerlink" href="#the-rnn-module" title="Permalink to this headline"></a></h2>
<div class="section" id="cell-interfaces">
<span id="cell-interfaces"></span><h3>Cell interfaces<a class="headerlink" href="#cell-interfaces" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.__call__" title="mxnet.rnn.BaseRNNCell.__call__"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.__call__</span></code></a></td>
<td>Unroll the RNN for one time step.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unroll" title="mxnet.rnn.BaseRNNCell.unroll"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.unroll</span></code></a></td>
<td>Unroll an RNN cell across time steps.</td>
</tr>
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.reset" title="mxnet.rnn.BaseRNNCell.reset"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.reset</span></code></a></td>
<td>Reset before re-using the cell for another graph.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.begin_state" title="mxnet.rnn.BaseRNNCell.begin_state"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.begin_state</span></code></a></td>
<td>Initial state for this cell.</td>
</tr>
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unpack_weights" title="mxnet.rnn.BaseRNNCell.unpack_weights"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.unpack_weights</span></code></a></td>
<td>Unpack fused weight matrices into separate weight matrices.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="mxnet.rnn.BaseRNNCell.pack_weights"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.pack_weights</span></code></a></td>
<td>Pack separate weight matrices into a single packed weight.</td>
</tr>
</tbody>
</table>
<p>When working with the cell API, the precise input and output symbols
depend on the type of RNN you are using. Take Long Short-Term Memory (LSTM) for example:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="c1"># Shape of 'step_data' is (batch_size,).</span>
<span class="n">step_input</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'step_data'</span><span class="p">)</span>
<span class="c1"># First we embed our raw input data to be used as LSTM's input.</span>
<span class="n">embedded_step</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">step_input</span><span class="p">,</span> \
<span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span> \
<span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
<span class="c1"># Then we create an LSTM cell.</span>
<span class="n">lstm_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="c1"># Initialize its hidden and memory states.</span>
<span class="c1"># 'begin_state' method takes an initialization function, and uses 'zeros' by default.</span>
<span class="n">begin_state</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">begin_state</span><span class="p">()</span>
</pre></div>
</div>
<p>The LSTM cell and other non-fused RNN cells are callable. Calling the cell updates its state once. This transformation depends on both the current input and the previous states. See this <a class="reference external" href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">blog post</a> for a great introduction to LSTM and other RNNs.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Call the cell to get the output of one time step for a batch.</span>
<span class="n">output</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="p">(</span><span class="n">embedded_step</span><span class="p">,</span> <span class="n">begin_state</span><span class="p">)</span>
<span class="c1"># 'output' is lstm_t0_out_output of shape (batch_size, hidden_dim).</span>
<span class="c1"># 'states' has the recurrent states that will be carried over to the next step,</span>
<span class="c1"># which includes both the "hidden state" and the "cell state":</span>
<span class="c1"># Both 'lstm_t0_out_output' and 'lstm_t0_state_output' have shape (batch_size, hidden_dim).</span>
</pre></div>
</div>
<p>Most of the time our goal is to process a sequence of many steps. For this, we need to unroll the LSTM according to the sequence length.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Embed a sequence. 'seq_data' has the shape of (batch_size, sequence_length).</span>
<span class="n">seq_input</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'seq_data'</span><span class="p">)</span>
<span class="n">embedded_seq</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">seq_input</span><span class="p">,</span> \
<span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span> \
<span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
</pre></div>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">Remember to reset the cell when unrolling/stepping for a new sequence by calling <cite>lstm_cell.reset()</cite>.</p>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Note that when unrolling, if 'merge_outputs' is set to True, the 'outputs' is merged into a single symbol</span>
<span class="c1"># In the layout, 'N' represents batch size, 'T' represents sequence length, and 'C' represents the</span>
<span class="c1"># number of dimensions in hidden states.</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">layout</span><span class="o">=</span><span class="s1">'NTC'</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># 'outputs' is concat0_output of shape (batch_size, sequence_length, hidden_dim).</span>
<span class="c1"># The hidden state and cell state from the final time step are returned:</span>
<span class="c1"># Both 'lstm_t4_out_output' and 'lstm_t4_state_output' have shape (batch_size, hidden_dim).</span>
<span class="c1"># If merge_outputs is set to False, a list of symbols for each of the time steps is returned.</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">layout</span><span class="o">=</span><span class="s1">'NTC'</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="c1"># In this case, 'outputs' is a list of symbols. Each symbol is of shape (batch_size, hidden_dim).</span>
</pre></div>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">Loading and saving models that are built with the RNN cell API requires using
<cite>mx.rnn.load_rnn_checkpoint</cite>, <cite>mx.rnn.save_rnn_checkpoint</cite>, and <cite>mx.rnn.do_rnn_checkpoint</cite>.
The list of all the used cells should be provided as the first argument to those functions.</p>
</div>
</div>
<div class="section" id="basic-rnn-cells">
<span id="basic-rnn-cells"></span><h3>Basic RNN cells<a class="headerlink" href="#basic-rnn-cells" title="Permalink to this headline"></a></h3>
<p>The <code class="docutils literal"><span class="pre">rnn</span></code> module supports the following RNN cell types.</p>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.LSTMCell" title="mxnet.rnn.LSTMCell"><code class="xref py py-obj docutils literal"><span class="pre">LSTMCell</span></code></a></td>
<td>Long Short-Term Memory (LSTM) network cell.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.GRUCell" title="mxnet.rnn.GRUCell"><code class="xref py py-obj docutils literal"><span class="pre">GRUCell</span></code></a></td>
<td>Gated Recurrent Unit (GRU) network cell.</td>
</tr>
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.RNNCell" title="mxnet.rnn.RNNCell"><code class="xref py py-obj docutils literal"><span class="pre">RNNCell</span></code></a></td>
<td>Simple recurrent neural network cell.</td>
</tr>
</tbody>
</table>
</div>
<div class="section" id="modifier-cells">
<span id="modifier-cells"></span><h3>Modifier cells<a class="headerlink" href="#modifier-cells" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BidirectionalCell" title="mxnet.rnn.BidirectionalCell"><code class="xref py py-obj docutils literal"><span class="pre">BidirectionalCell</span></code></a></td>
<td>Bidirectional RNN cell.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.DropoutCell" title="mxnet.rnn.DropoutCell"><code class="xref py py-obj docutils literal"><span class="pre">DropoutCell</span></code></a></td>
<td>Apply dropout on input.</td>
</tr>
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.ZoneoutCell" title="mxnet.rnn.ZoneoutCell"><code class="xref py py-obj docutils literal"><span class="pre">ZoneoutCell</span></code></a></td>
<td>Apply Zoneout on base cell.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.ResidualCell" title="mxnet.rnn.ResidualCell"><code class="xref py py-obj docutils literal"><span class="pre">ResidualCell</span></code></a></td>
<td>Adds residual connection as described in Wu et al, 2016 (<a class="reference external" href="https://arxiv.org/abs/1609.08144">https://arxiv.org/abs/1609.08144</a>).</td>
</tr>
</tbody>
</table>
<p>A modifier cell takes in one or more cells and transforms the output of those cells.
<code class="docutils literal"><span class="pre">BidirectionalCell</span></code> is one example. It takes two cells for forward unroll and backward unroll
respectively. After unrolling, the outputs of the forward and backward pass are concatenated.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Bidirectional cell takes two RNN cells, for forward and backward pass respectively.</span>
<span class="c1"># Having different types of cells for forward and backward unrolling is allowed.</span>
<span class="n">bi_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">BidirectionalCell</span><span class="p">(</span>
<span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">GRUCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">75</span><span class="p">))</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">bi_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># The output feature is the concatenation of the forward and backward pass.</span>
<span class="c1"># Thus, the number of output dimensions is the sum of the dimensions of the two cells.</span>
<span class="c1"># 'outputs' is the symbol 'bi_out_output' of shape (batch_size, sequence_length, 125L)</span>
<span class="c1"># The states of the BidirectionalCell are a list of two lists, corresponding to the</span>
<span class="c1"># states of the forward and backward cells respectively.</span>
</pre></div>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">BidirectionalCell cannot be called or stepped, because the backward unroll requires the output of
future steps, and thus the whole sequence is required.</p>
</div>
<p>Dropout and zoneout are popular regularization techniques that can be applied to RNNs. The <code class="docutils literal"><span class="pre">rnn</span></code>
module provides <code class="docutils literal"><span class="pre">DropoutCell</span></code> and <code class="docutils literal"><span class="pre">ZoneoutCell</span></code> for regularization on the output and recurrent
states of RNN. <code class="docutils literal"><span class="pre">ZoneoutCell</span></code> takes one RNN cell in the constructor, and supports unrolling like
other cells.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">zoneout_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">ZoneoutCell</span><span class="p">(</span><span class="n">lstm_cell</span><span class="p">,</span> <span class="n">zoneout_states</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">zoneout_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">DropoutCell</span></code> performs dropout on the input sequence. It can be used in a stacked
multi-layer RNN setting, which we will cover next.</p>
<p>Residual connection is a useful technique for training deep neural models because it helps the
propagation of gradients by shortening the paths. <code class="docutils literal"><span class="pre">ResidualCell</span></code> provides such functionality for
RNN models.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">residual_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">ResidualCell</span><span class="p">(</span><span class="n">lstm_cell</span><span class="p">)</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">residual_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></div>
</div>
<p>The <code class="docutils literal"><span class="pre">outputs</span></code> are the element-wise sum of both the input and the output of the LSTM cell.</p>
</div>
<div class="section" id="multi-layer-cells">
<span id="multi-layer-cells"></span><h3>Multi-layer cells<a class="headerlink" href="#multi-layer-cells" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell" title="mxnet.rnn.SequentialRNNCell"><code class="xref py py-obj docutils literal"><span class="pre">SequentialRNNCell</span></code></a></td>
<td>Sequentially stacking multiple RNN cells.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell.add" title="mxnet.rnn.SequentialRNNCell.add"><code class="xref py py-obj docutils literal"><span class="pre">SequentialRNNCell.add</span></code></a></td>
<td>Append a cell into the stack.</td>
</tr>
</tbody>
</table>
<p>The <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> allows stacking multiple layers of RNN cells to improve the expressiveness
and performance of the model. Cells can be added to a <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> in order, from bottom to
top. When unrolling, the output of a lower-level cell is automatically passed to the cell above.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">stacked_rnn_cells</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">SequentialRNNCell</span><span class="p">()</span>
<span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">BidirectionalCell</span><span class="p">(</span>
<span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">)))</span>
<span class="c1"># Dropout the output of the bottom layer BidirectionalCell with a retention probability of 0.5.</span>
<span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">DropoutCell</span><span class="p">(</span><span class="mf">0.5</span><span class="p">))</span>
<span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">))</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># The output of SequentialRNNCell is the same as that of the last layer.</span>
<span class="c1"># In this case 'outputs' is the symbol 'concat6_output' of shape (batch_size, sequence_length, hidden_dim)</span>
<span class="c1"># The states of the SequentialRNNCell is a list of lists, with each list</span>
<span class="c1"># corresponding to the states of each of the added cells respectively.</span>
</pre></div>
</div>
</div>
<div class="section" id="fused-rnn-cell">
<span id="fused-rnn-cell"></span><h3>Fused RNN cell<a class="headerlink" href="#fused-rnn-cell" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.FusedRNNCell" title="mxnet.rnn.FusedRNNCell"><code class="xref py py-obj docutils literal"><span class="pre">FusedRNNCell</span></code></a></td>
<td>Fusing RNN layers across time step into one kernel.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.FusedRNNCell.unfuse" title="mxnet.rnn.FusedRNNCell.unfuse"><code class="xref py py-obj docutils literal"><span class="pre">FusedRNNCell.unfuse</span></code></a></td>
<td>Unfuse the fused RNN into a stack of RNN cells.</td>
</tr>
</tbody>
</table>
<p>The computation of an RNN for an input sequence consists of many GEMM and point-wise operations with
temporal dependencies. This could make the computation memory-bound, especially on GPU,
resulting in longer wall-time. By combining the computation of many small matrices into that of
larger ones and streaming the computation whenever possible, the ratio of computation to memory I/O
can be increased, which results in better performance on GPU. Such an optimization technique is called
“fusing”.
<a class="reference external" href="https://devblogs.nvidia.com/parallelforall/optimizing-recurrent-neural-networks-cudnn-5/">This post</a>
talks in greater detail.</p>
<p>The <code class="docutils literal"><span class="pre">rnn</span></code> module includes a <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>, which provides the optimized fused implementation.
The FusedRNNCell supports bidirectional RNNs and dropout.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">fused_lstm_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">FusedRNNCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> \
<span class="n">num_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> \
<span class="n">mode</span><span class="o">=</span><span class="s1">'lstm'</span><span class="p">,</span> \
<span class="n">bidirectional</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> \
<span class="n">dropout</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">fused_lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># The 'outputs' is the symbol 'lstm_rnn_output' that has the shape</span>
<span class="c1"># (batch_size, sequence_length, forward_backward_concat_dim)</span>
</pre></div>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last"><cite>FusedRNNCell</cite> only supports GPU. It cannot be called or stepped.</p>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">When <cite>dropout</cite> is set to non-zero in <cite>FusedRNNCell</cite>, the dropout is applied to the
output of all layers except the last layer. If there is only one layer in the <cite>FusedRNNCell</cite>, the
dropout rate is ignored.</p>
</div>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">Similar to <cite>BidirectionalCell</cite>, when <cite>bidirectional</cite> flag is set to <cite>True</cite>, the output
of <cite>FusedRNNCell</cite> is twice the size specified by <cite>num_hidden</cite>.</p>
</div>
<p>When training a deep, complex model <em>on multiple GPUs</em> it’s recommended to stack
fused RNN cells (one layer per cell) together instead of one with all layers.
The reason is that fused RNN cells don’t set gradients to be ready until the
computation for the entire layer is completed. Breaking a multi-layer fused RNN
cell into several one-layer ones allows gradients to be processed earlier. This
reduces communication overhead, especially with multiple GPUs.</p>
<p>The <code class="docutils literal"><span class="pre">unfuse()</span></code> method can be used to convert the <code class="docutils literal"><span class="pre">FusedRNNCell</span></code> into an equivalent
and CPU-compatible <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> that mirrors the settings of the <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">unfused_lstm_cell</span> <span class="o">=</span> <span class="n">fused_lstm_cell</span><span class="o">.</span><span class="n">unfuse</span><span class="p">()</span>
<span class="n">unfused_outputs</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">unfused_lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \
<span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \
<span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># The 'outputs' is the symbol 'lstm_bi_l2_out_output' that has the shape</span>
<span class="c1"># (batch_size, sequence_length, forward_backward_concat_dim)</span>
</pre></div>
</div>
</div>
<div class="section" id="rnn-checkpoint-methods-and-parameters">
<span id="rnn-checkpoint-methods-and-parameters"></span><h3>RNN checkpoint methods and parameters<a class="headerlink" href="#rnn-checkpoint-methods-and-parameters" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.save_rnn_checkpoint" title="mxnet.rnn.save_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">save_rnn_checkpoint</span></code></a></td>
<td>Save checkpoint for model using RNN cells.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.load_rnn_checkpoint" title="mxnet.rnn.load_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">load_rnn_checkpoint</span></code></a></td>
<td>Load model checkpoint from file.</td>
</tr>
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.do_rnn_checkpoint" title="mxnet.rnn.do_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">do_rnn_checkpoint</span></code></a></td>
<td>Make a callback to checkpoint Module to prefix every epoch.</td>
</tr>
</tbody>
</table>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><code class="xref py py-obj docutils literal"><span class="pre">RNNParams</span></code></a></td>
<td>Container for holding variables.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.RNNParams.get" title="mxnet.rnn.RNNParams.get"><code class="xref py py-obj docutils literal"><span class="pre">RNNParams.get</span></code></a></td>
<td>Get the variable given a name if one exists or create a new one if missing.</td>
</tr>
</tbody>
</table>
<p>The model parameters from the training with fused cell can be used for inference with unfused cell,
and vice versa. As the parameters of fused and unfused cells are organized differently, they need to
be converted first. <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>’s parameters are merged and flattened. In the fused example above,
the model has <code class="docutils literal"><span class="pre">lstm_parameters</span></code> of shape <code class="docutils literal"><span class="pre">(total_num_params,)</span></code>, whereas the
equivalent SequentialRNNCell’s parameters are separate:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="s1">'lstm_l0_i2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>
<span class="s1">'lstm_l0_i2h_bias'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,)</span>
<span class="s1">'lstm_l0_h2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
<span class="s1">'lstm_l0_h2h_bias'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,)</span>
<span class="s1">'lstm_r0_i2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>
<span class="o">...</span>
</pre></div>
</div>
<p>All cells in the <code class="docutils literal"><span class="pre">rnn</span></code> module support the method <code class="docutils literal"><span class="pre">unpack_weights()</span></code> for converting <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>
parameters to the unfused format and <code class="docutils literal"><span class="pre">pack_weights()</span></code> for fusing the parameters. The RNN-specific
checkpointing methods (<code class="docutils literal"><span class="pre">load_rnn_checkpoint,</span> <span class="pre">save_rnn_checkpoint,</span> <span class="pre">do_rnn_checkpoint</span></code>) handle the
conversion transparently based on the provided cells.</p>
</div>
<div class="section" id="i-o-utilities">
<span id="i-o-utilities"></span><h3>I/O utilities<a class="headerlink" href="#i-o-utilities" title="Permalink to this headline"></a></h3>
<table border="1" class="longtable docutils">
<colgroup>
<col width="10%"/>
<col width="90%"/>
</colgroup>
<tbody valign="top">
<tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BucketSentenceIter" title="mxnet.rnn.BucketSentenceIter"><code class="xref py py-obj docutils literal"><span class="pre">BucketSentenceIter</span></code></a></td>
<td>Simple bucketing iterator for language model.</td>
</tr>
<tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.encode_sentences" title="mxnet.rnn.encode_sentences"><code class="xref py py-obj docutils literal"><span class="pre">encode_sentences</span></code></a></td>
<td>Encode sentences and (optionally) build a mapping from string tokens to integer indices.</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="section" id="api-reference">
<span id="api-reference"></span><h2>API Reference<a class="headerlink" href="#api-reference" title="Permalink to this headline"></a></h2>
<script src="../../_static/js/auto_module_index.js" type="text/javascript"></script><dl class="class">
<dt id="mxnet.rnn.BaseRNNCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BaseRNNCell</code><span class="sig-paren">(</span><em>prefix=''</em>, <em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell" title="Permalink to this definition"></a></dt>
<dd><p>Abstract base class for RNN cells</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>prefix</strong> (<em>str, optional</em>) – Prefix for names of layers
(this prefix is also used for names of weights if <cite>params</cite> is None
i.e. if <cite>params</cite> are being created and not reused)</li>
<li><strong>params</strong> (<em>RNNParams, default None.</em>) – Container for weight sharing between cells.
A new RNNParams container is created if <cite>params</cite> is None.</li>
</ul>
</td>
</tr>
</tbody>
</table>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.__call__">
<code class="descname">__call__</code><span class="sig-paren">(</span><em>inputs</em>, <em>states</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.__call__" title="Permalink to this definition"></a></dt>
<dd><p>Unroll the RNN for one time step.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>inputs</strong> (<em>sym.Variable</em>) – input symbol, 2D, batch * num_units</li>
<li><strong>states</strong> (<em>list of sym.Variable</em>) – RNN state from previous step or the output of begin_state().</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple">
<li><strong>output</strong> (<em>Symbol</em>) –
Symbol corresponding to the output from the RNN when unrolling
for a single time step.</li>
<li><strong>states</strong> (<em>nested list of Symbol</em>) –
The new state of this RNN after this unrolling.
The type of this symbol is same as the output of begin_state().
This can be used as input state to the next time step
of this RNN.</li>
</ul>
</p>
</td>
</tr>
</tbody>
</table>
<div class="admonition seealso">
<p class="first admonition-title">See also</p>
<dl class="last docutils">
<dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.begin_state" title="mxnet.rnn.BaseRNNCell.begin_state"><code class="xref py py-meth docutils literal"><span class="pre">begin_state()</span></code></a></dt>
<dd>This function can provide the states for the first time step.</dd>
<dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unroll" title="mxnet.rnn.BaseRNNCell.unroll"><code class="xref py py-meth docutils literal"><span class="pre">unroll()</span></code></a></dt>
<dd>This function unrolls an RNN for a given number of (>=1) time steps.</dd>
</dl>
</div>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.reset">
<code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.reset" title="Permalink to this definition"></a></dt>
<dd><p>Reset before re-using the cell for another graph.</p>
</dd></dl>
<dl class="method">
<dt>
<code class="descname">__call__</code><span class="sig-paren">(</span><em>inputs</em>, <em>states</em><span class="sig-paren">)</span></dt>
<dd><p>Unroll the RNN for one time step.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>inputs</strong> (<em>sym.Variable</em>) – input symbol, 2D, batch * num_units</li>
<li><strong>states</strong> (<em>list of sym.Variable</em>) – RNN state from previous step or the output of begin_state().</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple">
<li><strong>output</strong> (<em>Symbol</em>) –
Symbol corresponding to the output from the RNN when unrolling
for a single time step.</li>
<li><strong>states</strong> (<em>nested list of Symbol</em>) –
The new state of this RNN after this unrolling.
The type of this symbol is same as the output of begin_state().
This can be used as input state to the next time step
of this RNN.</li>
</ul>
</p>
</td>
</tr>
</tbody>
</table>
<div class="admonition seealso">
<p class="first admonition-title">See also</p>
<dl class="last docutils">
<dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.begin_state" title="mxnet.rnn.BaseRNNCell.begin_state"><code class="xref py py-meth docutils literal"><span class="pre">begin_state()</span></code></a></dt>
<dd>This function can provide the states for the first time step.</dd>
<dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unroll" title="mxnet.rnn.BaseRNNCell.unroll"><code class="xref py py-meth docutils literal"><span class="pre">unroll()</span></code></a></dt>
<dd>This function unrolls an RNN for a given number of (>=1) time steps.</dd>
</dl>
</div>
</dd></dl>
<dl class="attribute">
<dt id="mxnet.rnn.BaseRNNCell.params">
<code class="descname">params</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.params" title="Permalink to this definition"></a></dt>
<dd><p>Parameters of this cell</p>
</dd></dl>
<dl class="attribute">
<dt id="mxnet.rnn.BaseRNNCell.state_info">
<code class="descname">state_info</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.state_info" title="Permalink to this definition"></a></dt>
<dd><p>shape and layout information of states</p>
</dd></dl>
<dl class="attribute">
<dt id="mxnet.rnn.BaseRNNCell.state_shape">
<code class="descname">state_shape</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.state_shape" title="Permalink to this definition"></a></dt>
<dd><p>shape(s) of states</p>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.begin_state">
<code class="descname">begin_state</code><span class="sig-paren">(</span><em>func=&lt;function zeros&gt;</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.begin_state" title="Permalink to this definition"></a></dt>
<dd><p>Initial state for this cell.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>func</strong> (<em>callable, default symbol.zeros</em>) – Function for creating initial state. Can be symbol.zeros,
symbol.uniform, symbol.Variable etc.
Use symbol.Variable if you want to directly
feed input as states.</li>
<li><strong>**kwargs</strong><p>more keyword arguments passed to func. For example
mean, std, dtype, etc.</p>
</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>states</strong>
Starting states for the first RNN step.</p>
</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">nested list of Symbol</p>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.unpack_weights">
<code class="descname">unpack_weights</code><span class="sig-paren">(</span><em>args</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.unpack_weights" title="Permalink to this definition"></a></dt>
<dd><p>Unpack fused weight matrices into separate
weight matrices.</p>
<p>For example, say you use a module object <cite>mod</cite> to run a network that has an lstm cell.
In <cite>mod.get_params()[0]</cite>, the lstm parameters are all represented as a single big vector.
<cite>cell.unpack_weights(mod.get_params()[0])</cite> will unpack this vector into a dictionary of
more readable lstm parameters - c, f, i, o gates for i2h (input to hidden) and
h2h (hidden to hidden) weights.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args</strong> (<em>dict of str -> NDArray</em>) – Dictionary containing packed weights.
usually from <cite>Module.get_params()[0]</cite>.</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>args</strong>
Dictionary with unpacked weights associated with
this cell.</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">dict of str -> NDArray</td>
</tr>
</tbody>
</table>
<div class="admonition seealso">
<p class="first admonition-title">See also</p>
<dl class="last docutils">
<dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="mxnet.rnn.BaseRNNCell.pack_weights"><code class="xref py py-meth docutils literal"><span class="pre">pack_weights()</span></code></a></dt>
<dd>Performs the reverse operation of this function.</dd>
</dl>
</div>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.pack_weights">
<code class="descname">pack_weights</code><span class="sig-paren">(</span><em>args</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="Permalink to this definition"></a></dt>
<dd><p>Pack separate weight matrices into a single packed
weight.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args</strong> (<em>dict of str -> NDArray</em>) – Dictionary containing unpacked weights.</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>args</strong>
Dictionary with packed weights associated with
this cell.</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">dict of str -> NDArray</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BaseRNNCell.unroll">
<code class="descname">unroll</code><span class="sig-paren">(</span><em>length</em>, <em>inputs</em>, <em>begin_state=None</em>, <em>layout='NTC'</em>, <em>merge_outputs=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.unroll" title="Permalink to this definition"></a></dt>
<dd><p>Unroll an RNN cell across time steps.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>length</strong> (<em>int</em>) – Number of steps to unroll.</li>
<li><strong>inputs</strong> (<em>Symbol, list of Symbol, or None</em>) – <p>If <cite>inputs</cite> is a single Symbol (usually the output
of Embedding symbol), it should have shape
(batch_size, length, ...) if layout == ‘NTC’,
or (length, batch_size, ...) if layout == ‘TNC’.</p>
<p>If <cite>inputs</cite> is a list of symbols (usually output of
previous unroll), they should all have shape
(batch_size, ...).</p>
</li>
<li><strong>begin_state</strong> (<em>nested list of Symbol, default None</em>) – Input states created by <cite>begin_state()</cite>
or output state of another cell.
Created from <cite>begin_state()</cite> if None.</li>
<li><strong>layout</strong> (<em>str, optional</em>) – <cite>layout</cite> of input symbol. Only used if inputs
is a single Symbol.</li>
<li><strong>merge_outputs</strong> (<em>bool, optional</em>) – If False, return outputs as a list of Symbols.
If True, concatenate output across time steps
and return a single symbol with shape
(batch_size, length, ...) if layout == ‘NTC’,
or (length, batch_size, ...) if layout == ‘TNC’.
If None, output whatever is faster.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><ul class="first last simple">
<li><strong>outputs</strong> (<em>list of Symbol or Symbol</em>) –
Symbol (if <cite>merge_outputs</cite> is True) or list of Symbols
(if <cite>merge_outputs</cite> is False) corresponding to the output from
the RNN from this unrolling.</li>
<li><strong>states</strong> (<em>nested list of Symbol</em>) –
The new state of this RNN after this unrolling.
The type of this symbol is the same as the output of begin_state().</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.LSTMCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">LSTMCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>prefix='lstm_'</em>, <em>params=None</em>, <em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.LSTMCell" title="Permalink to this definition"></a></dt>
<dd><p>Long Short-Term Memory (LSTM) network cell.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li>
<li><strong>prefix</strong> (<em>str, default 'lstm_'</em>) – Prefix for name of layers (and name of weight if params is None).</li>
<li><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</li>
<li><strong>forget_bias</strong> (<em>float, default 1.0</em>) – Bias added to the forget gate. Jozefowicz et al. 2015 recommend setting this to 1.0.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.GRUCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">GRUCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>prefix='gru_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.GRUCell" title="Permalink to this definition"></a></dt>
<dd><p>Gated Recurrent Unit (GRU) network cell.
Note: this is an implementation of the cuDNN version of GRUs
(slight modification compared to Cho et al. 2014).</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li>
<li><strong>prefix</strong> (<em>str, default 'gru_'</em>) – Prefix for name of layers (and name of weight if params is None).</li>
<li><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.RNNCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">RNNCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>activation='tanh'</em>, <em>prefix='rnn_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.RNNCell" title="Permalink to this definition"></a></dt>
<dd><p>Simple recurrent neural network cell.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li>
<li><strong>activation</strong> (<em>str or Symbol, default 'tanh'</em>) – Type of activation function. Options are ‘relu’ and ‘tanh’.</li>
<li><strong>prefix</strong> (<em>str, default 'rnn_'</em>) – Prefix for name of layers (and name of weight if params is None).</li>
<li><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.FusedRNNCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">FusedRNNCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>num_layers=1</em>, <em>mode='lstm'</em>, <em>bidirectional=False</em>, <em>dropout=0.0</em>, <em>get_next_state=False</em>, <em>forget_bias=1.0</em>, <em>prefix=None</em>, <em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.FusedRNNCell" title="Permalink to this definition"></a></dt>
<dd><p>Fusing RNN layers across time step into one kernel.
Improves speed but is less flexible. Currently only
supported if using cuDNN on GPU.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li>
<li><strong>num_layers</strong> (<em>int, default 1</em>) – Number of layers in the cell.</li>
<li><strong>mode</strong> (<em>str, default 'lstm'</em>) – Type of RNN. options are ‘rnn_relu’, ‘rnn_tanh’, ‘lstm’, ‘gru’.</li>
<li><strong>bidirectional</strong> (<em>bool, default False</em>) – Whether to use bidirectional unroll. The output dimension size is doubled if bidirectional.</li>
<li><strong>dropout</strong> (<em>float, default 0.</em>) – Fraction of the input that gets dropped out during training time.</li>
<li><strong>get_next_state</strong> (<em>bool, default False</em>) – Whether to return the states that can be used as starting states next time.</li>
<li><strong>forget_bias</strong> (<em>float, default 1.0</em>) – Bias added to the forget gate. Jozefowicz et al. 2015 recommend setting this to 1.0.</li>
<li><strong>prefix</strong> (<em>str, default '$mode_' such as 'lstm_'</em>) – Prefix for names of layers
(this prefix is also used for names of weights if <cite>params</cite> is None
i.e. if <cite>params</cite> are being created and not reused)</li>
<li><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</li>
</ul>
</td>
</tr>
</tbody>
</table>
<dl class="method">
<dt id="mxnet.rnn.FusedRNNCell.unfuse">
<code class="descname">unfuse</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.FusedRNNCell.unfuse" title="Permalink to this definition"></a></dt>
<dd><p>Unfuse the fused RNN into a stack of RNN cells.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><strong>cell</strong>
unfused cell that can be used for stepping, and can run on CPU.</td>
</tr>
<tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body"><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell" title="mxnet.rnn.SequentialRNNCell">SequentialRNNCell</a></td>
</tr>
</tbody>
</table>
</dd></dl>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.SequentialRNNCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">SequentialRNNCell</code><span class="sig-paren">(</span><em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.SequentialRNNCell" title="Permalink to this definition"></a></dt>
<dd><p>Sequentially stacking multiple RNN cells.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</td>
</tr>
</tbody>
</table>
<dl class="method">
<dt id="mxnet.rnn.SequentialRNNCell.add">
<code class="descname">add</code><span class="sig-paren">(</span><em>cell</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.SequentialRNNCell.add" title="Permalink to this definition"></a></dt>
<dd><p>Append a cell into the stack.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – The cell to be appended. During unroll, previous cell’s output (or raw inputs if
no previous cell) is used as the input to this cell.</td>
</tr>
</tbody>
</table>
</dd></dl>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.BidirectionalCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BidirectionalCell</code><span class="sig-paren">(</span><em>l_cell</em>, <em>r_cell</em>, <em>params=None</em>, <em>output_prefix='bi_'</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BidirectionalCell" title="Permalink to this definition"></a></dt>
<dd><p>Bidirectional RNN cell.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>l_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – cell for forward unrolling</li>
<li><strong>r_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – cell for backward unrolling</li>
<li><strong>params</strong> (<em>RNNParams, default None.</em>) – Container for weight sharing between cells.
A new RNNParams container is created if <cite>params</cite> is None.</li>
<li><strong>output_prefix</strong> (<em>str, default 'bi_'</em>) – prefix for name of output</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.DropoutCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">DropoutCell</code><span class="sig-paren">(</span><em>dropout</em>, <em>prefix='dropout_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.DropoutCell" title="Permalink to this definition"></a></dt>
<dd><p>Apply dropout on input.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>dropout</strong> (<em>float</em>) – Percentage of elements to drop out, which
is 1 - percentage to retain.</li>
<li><strong>prefix</strong> (<em>str, default 'dropout_'</em>) – Prefix for names of layers
(this prefix is also used for names of weights if <cite>params</cite> is None
i.e. if <cite>params</cite> are being created and not reused)</li>
<li><strong>params</strong> (<em>RNNParams, default None</em>) – Container for weight sharing between cells. Created if None.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.ZoneoutCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">ZoneoutCell</code><span class="sig-paren">(</span><em>base_cell</em>, <em>zoneout_outputs=0.0</em>, <em>zoneout_states=0.0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.ZoneoutCell" title="Permalink to this definition"></a></dt>
<dd><p>Apply Zoneout on base cell.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>base_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – Cell on whose states to perform zoneout.</li>
<li><strong>zoneout_outputs</strong> (<em>float, default 0.</em>) – Fraction of the output that gets dropped out during training time.</li>
<li><strong>zoneout_states</strong> (<em>float, default 0.</em>) – Fraction of the states that gets dropped out during training time.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.ResidualCell">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">ResidualCell</code><span class="sig-paren">(</span><em>base_cell</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.ResidualCell" title="Permalink to this definition"></a></dt>
<dd><p>Adds residual connection as described in Wu et al, 2016
(<a class="reference external" href="https://arxiv.org/abs/1609.08144">https://arxiv.org/abs/1609.08144</a>).</p>
<p>Output of the cell is output of the base cell plus input.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>base_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – Cell on whose outputs to add residual connection.</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.RNNParams">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">RNNParams</code><span class="sig-paren">(</span><em>prefix=''</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.RNNParams" title="Permalink to this definition"></a></dt>
<dd><p>Container for holding variables.
Used by RNN cells for parameter sharing between cells.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>prefix</strong> (<em>str</em>) – Names of all variables created by this container will
be prepended with prefix.</td>
</tr>
</tbody>
</table>
<dl class="method">
<dt id="mxnet.rnn.RNNParams.get">
<code class="descname">get</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.RNNParams.get" title="Permalink to this definition"></a></dt>
<dd><p>Get the variable given a name if one exists or create a new one if missing.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>name</strong> (<em>str</em>) – name of the variable</li>
<li><strong>**kwargs</strong> – <p>Additional arguments that are passed to symbol.Variable.</p>
</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
</dd></dl>
<dl class="class">
<dt id="mxnet.rnn.BucketSentenceIter">
<em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BucketSentenceIter</code><span class="sig-paren">(</span><em>sentences</em>, <em>batch_size</em>, <em>buckets=None</em>, <em>invalid_label=-1</em>, <em>data_name='data'</em>, <em>label_name='softmax_label'</em>, <em>dtype='float32'</em>, <em>layout='NT'</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter" title="Permalink to this definition"></a></dt>
<dd><p>Simple bucketing iterator for language model.
The label at each sequence step is the following token
in the sequence.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>sentences</strong> (<em>list of list of int</em>) – Encoded sentences.</li>
<li><strong>batch_size</strong> (<em>int</em>) – Batch size of the data.</li>
<li><strong>invalid_label</strong> (<em>int, optional</em>) – Key for invalid label, e.g. &lt;end-of-sentence&gt;. The default is -1.</li>
<li><strong>dtype</strong> (<em>str, optional</em>) – Data type of the encoding. The default data type is ‘float32’.</li>
<li><strong>buckets</strong> (<em>list of int, optional</em>) – Size of the data buckets. Automatically generated if None.</li>
<li><strong>data_name</strong> (<em>str, optional</em>) – Name of the data. The default name is ‘data’.</li>
<li><strong>label_name</strong> (<em>str, optional</em>) – Name of the label. The default name is ‘softmax_label’.</li>
<li><strong>layout</strong> (<em>str, optional</em>) – Format of data and label. ‘NT’ means (batch_size, length)
and ‘TN’ means (length, batch_size).</li>
</ul>
</td>
</tr>
</tbody>
</table>
<dl class="method">
<dt id="mxnet.rnn.BucketSentenceIter.reset">
<code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter.reset" title="Permalink to this definition"></a></dt>
<dd><p>Resets the iterator to the beginning of the data.</p>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.BucketSentenceIter.next">
<code class="descname">next</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter.next" title="Permalink to this definition"></a></dt>
<dd><p>Returns the next batch of data.</p>
</dd></dl>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.encode_sentences">
<code class="descclassname">rnn.</code><code class="descname">encode_sentences</code><span class="sig-paren">(</span><em>sentences</em>, <em>vocab=None</em>, <em>invalid_label=-1</em>, <em>invalid_key='\n'</em>, <em>start_label=0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.encode_sentences" title="Permalink to this definition"></a></dt>
<dd><p>Encode sentences and (optionally) build a mapping
from string tokens to integer indices. Unknown keys
will be added to vocabulary.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>sentences</strong> (<em>list of list of str</em>) – A list of sentences to encode. Each sentence
should be a list of string tokens.</li>
<li><strong>vocab</strong> (<em>None or dict of str -> int</em>) – Optional input Vocabulary</li>
<li><strong>invalid_label</strong> (<em>int, default -1</em>) – Index for invalid token, like &lt;end-of-sentence&gt;</li>
<li><strong>invalid_key</strong> (<em>str, default '\n'</em>) – Key for invalid token. Use '\n' for end
of sentence by default.</li>
<li><strong>start_label</strong> (<em>int</em>) – lowest index.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><ul class="first last simple">
<li><strong>result</strong> (<em>list of list of int</em>) –
encoded sentences</li>
<li><strong>vocab</strong> (<em>dict of str -> int</em>) –
result vocabulary</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.save_rnn_checkpoint">
<code class="descclassname">rnn.</code><code class="descname">save_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>epoch</em>, <em>symbol</em>, <em>arg_params</em>, <em>aux_params</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.save_rnn_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Save checkpoint for model using RNN cells.
Unpacks weight before saving.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>cells</strong> (<em>RNNCell or list of RNNCells</em>) – The RNN cells used by this symbol.</li>
<li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li>
<li><strong>epoch</strong> (<em>int</em>) – The epoch number of the model.</li>
<li><strong>symbol</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The input symbol</li>
<li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s weights.</li>
<li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li>
</ul>
</td>
</tr>
</tbody>
</table>
<p class="rubric">Notes</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li>
<li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li>
</ul>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.load_rnn_checkpoint">
<code class="descclassname">rnn.</code><code class="descname">load_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>epoch</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.load_rnn_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Load model checkpoint from file.
Pack weights after loading.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>cells</strong> (<em>RNNCell or list of RNNCells</em>) – The RNN cells used by this symbol.</li>
<li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li>
<li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><ul class="first last simple">
<li><strong>symbol</strong> (<em>Symbol</em>) –
The symbol configuration of computation network.</li>
<li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) –
Model parameter, dict of name to NDArray of net’s weights.</li>
<li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) –
Model parameter, dict of name to NDArray of net’s auxiliary states.</li>
</ul>
</td>
</tr>
</tbody>
</table>
<p class="rubric">Notes</p>
<ul class="simple">
<li>symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li>
<li>parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li>
</ul>
</dd></dl>
<dl class="method">
<dt id="mxnet.rnn.do_rnn_checkpoint">
<code class="descclassname">rnn.</code><code class="descname">do_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>period=1</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.do_rnn_checkpoint" title="Permalink to this definition"></a></dt>
<dd><p>Make a callback to checkpoint Module to prefix every epoch.
Unpacks weights used by cells before saving.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name"/>
<col class="field-body"/>
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>cells</strong> (<em>RNNCell or list of RNNCells</em>) – The RNN cells used by this symbol.</li>
<li><strong>prefix</strong> (<em>str</em>) – The file prefix to checkpoint to</li>
<li><strong>period</strong> (<em>int</em>) – How many epochs to wait before checkpointing. Default is 1.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>callback</strong>
The callback function that can be passed as iter_end_callback to fit.</p>
</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">function</p>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<script>auto_index("api-reference");</script></div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">RNN Cell API</a><ul>
<li><a class="reference internal" href="#overview">Overview</a></li>
<li><a class="reference internal" href="#the-rnn-module">The <code class="docutils literal"><span class="pre">rnn</span></code> module</a><ul>
<li><a class="reference internal" href="#cell-interfaces">Cell interfaces</a></li>
<li><a class="reference internal" href="#basic-rnn-cells">Basic RNN cells</a></li>
<li><a class="reference internal" href="#modifier-cells">Modifier cells</a></li>
<li><a class="reference internal" href="#multi-layer-cells">Multi-layer cells</a></li>
<li><a class="reference internal" href="#fused-rnn-cell">Fused RNN cell</a></li>
<li><a class="reference internal" href="#rnn-checkpoint-methods-and-parameters">RNN checkpoint methods and parameters</a></li>
<li><a class="reference internal" href="#i-o-utilities">I/O utilities</a></li>
</ul>
</li>
<li><a class="reference internal" href="#api-reference">API Reference</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
// Register a jQuery ready callback that sets the <body> element's CSS
// `visibility` to 'visible'.
// NOTE(review): presumably a stylesheet or earlier script hides the body until
// the page has loaded — confirm where the initial `visibility: hidden` is set.
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>