| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <meta content="RNN Cell API" property="og:title"> |
| <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image"> |
| <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url"> |
| <meta content="RNN Cell API" property="og:description"/> |
| <title>RNN Cell API — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../../_static/basic.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript"> |
| var DOCUMENTATION_OPTIONS = { |
| URL_ROOT: '../../', |
| VERSION: '', |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '.txt' |
| }; |
| </script> |
| <script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script> |
| <script src="../../_static/underscore.js" type="text/javascript"></script> |
| <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../../_static/doctools.js" type="text/javascript"></script> |
| <script src="../../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/0.11.0/searchindex.js"); Search.init();}); </script> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new |
| Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| |
| </script> |
| <link href="../../genindex.html" rel="index" title="Index"> |
| <link href="../../search.html" rel="search" title="Search"/> |
| <link href="index.html" rel="up" title="MXNet - Python API"/> |
| <link href="kvstore.html" rel="next" title="KVStore API"/> |
| <link href="gluon.html" rel="prev" title="Gluon Package"/> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> |
| </link></link></link></meta></meta></meta></head> |
| <body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document"> |
| <div class="content-block"><div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="/versions/0.11.0/get_started/install.html">Install</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li> |
| <li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li> |
| <li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> |
| <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-docs"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs"> |
| <li><a class="main-nav-link" href="/versions/0.11.0/how_to/faq.html">FAQ</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a> |
| <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example">Examples</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/model_zoo/index.html">Model Zoo</a></li> |
| <li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li> |
| </li></ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-community"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community"> |
| <li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li> |
| <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0">Github</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/community/contribute.html">Contribute</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">0.11.0<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7/">1.7</a></li><li><a href=/versions/1.6/>1.6</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav> |
| <script> function getRootPath(){ return "../../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu" id="burgerMenu"> |
| <li><a href="/versions/0.11.0/get_started/install.html">Install</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a></li> |
| <li class="dropdown-submenu dropdown"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li> |
| <li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li> |
| <li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> |
| <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a> |
| <ul class="dropdown-menu"> |
| <li><a href="/versions/0.11.0/how_to/faq.html" tabindex="-1">FAQ</a></li> |
| <li><a href="/versions/0.11.0/tutorials/index.html" tabindex="-1">Tutorials</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example" tabindex="-1">Examples</a></li> |
| <li><a href="/versions/0.11.0/architecture/index.html" tabindex="-1">Architecture</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li> |
| <li><a href="/versions/0.11.0/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li> |
| <li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu dropdown"> |
| <a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a> |
| <ul class="dropdown-menu"> |
| <li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0" tabindex="-1">Github</a></li> |
| <li><a href="/versions/0.11.0/community/contribute.html" tabindex="-1">Contribute</a></li> |
| </ul> |
| </li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">0.11.0</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6/>1.6</a></li><li><a tabindex="-1" href=/versions/1.5.0/>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/>0.11.0</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"> |
| <input name="area" type="hidden" value="default"/> |
| </input></form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| </div> |
| </div> |
| </div> |
| <script type="text/javascript"> |
| $('body').css('background', 'white'); |
| </script> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul class="current"> |
| <li class="toctree-l1 current"><a class="reference internal" href="index.html">Python Documents</a><ul class="current"> |
| <li class="toctree-l2 current"><a class="reference internal" href="index.html#table-of-contents">Table of contents</a><ul class="current"> |
| <li class="toctree-l3"><a class="reference internal" href="ndarray.html">NDArray API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="symbol.html">Symbol API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="module.html">Module API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="autograd.html">Autograd Package</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="gluon.html">Gluon Package</a></li> |
| <li class="toctree-l3 current"><a class="current reference internal" href="#">RNN Cell API</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#overview">Overview</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#the-rnn-module">The <code class="docutils literal"><span class="pre">rnn</span></code> module</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#api-reference">API Reference</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="kvstore.html">KVStore API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="io.html">Data Loading API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="image.html">Image API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="optimization.html">Optimization: initialize and update weights</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="callback.html">Callback API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="metric.html">Evaluation Metric API</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li class="toctree-l1"><a class="reference internal" href="../r/index.html">R Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../julia/index.html">Julia Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../c++/index.html">C++ Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../scala/index.html">Scala Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../perl/index.html">Perl Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="page-tracker"></div> |
| <div class="section" id="rnn-cell-api"> |
| <span id="rnn-cell-api"></span><h1>RNN Cell API<a class="headerlink" href="#rnn-cell-api" title="Permalink to this headline">¶</a></h1> |
| <div class="admonition warning"> |
| <p class="first admonition-title">Warning</p> |
| <p class="last">This package is currently experimental and may change in the near future.</p> |
| </div> |
| <div class="section" id="overview"> |
| <span id="overview"></span><h2>Overview<a class="headerlink" href="#overview" title="Permalink to this headline">¶</a></h2> |
| <p>The <code class="docutils literal"><span class="pre">rnn</span></code> module includes the recurrent neural network (RNN) cell APIs, a suite of tools for building an RNN’s symbolic graph.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The <cite>rnn</cite> module offers a higher-level interface, while <cite>symbol.RNN</cite> is a lower-level interface. The cell APIs in the <cite>rnn</cite> module are easier to use in most cases.</p> |
| </div> |
| </div> |
| <div class="section" id="the-rnn-module"> |
| <span id="the-rnn-module"></span><h2>The <code class="docutils literal"><span class="pre">rnn</span></code> module<a class="headerlink" href="#the-rnn-module" title="Permalink to this headline">¶</a></h2> |
| <div class="section" id="cell-interfaces"> |
| <span id="cell-interfaces"></span><h3>Cell interfaces<a class="headerlink" href="#cell-interfaces" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.__call__" title="mxnet.rnn.BaseRNNCell.__call__"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.__call__</span></code></a></td> |
| <td>Unroll the RNN for one time step.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unroll" title="mxnet.rnn.BaseRNNCell.unroll"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.unroll</span></code></a></td> |
| <td>Unroll an RNN cell across time steps.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.reset" title="mxnet.rnn.BaseRNNCell.reset"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.reset</span></code></a></td> |
| <td>Reset before re-using the cell for another graph.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.begin_state" title="mxnet.rnn.BaseRNNCell.begin_state"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.begin_state</span></code></a></td> |
| <td>Initial state for this cell.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unpack_weights" title="mxnet.rnn.BaseRNNCell.unpack_weights"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.unpack_weights</span></code></a></td> |
| <td>Unpack fused weight matrices into separate weight matrices.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="mxnet.rnn.BaseRNNCell.pack_weights"><code class="xref py py-obj docutils literal"><span class="pre">BaseRNNCell.pack_weights</span></code></a></td> |
| <td>Pack separate weight matrices into a single packed weight.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>When working with the cell API, the precise input and output symbols |
| depend on the type of RNN you are using. Take Long Short-Term Memory (LSTM) for example:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> |
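| <span class="c1"># Example sizes. These values are illustrative; set them to match your own data.</span> |
| <span class="n">input_dim</span> <span class="o">=</span> <span class="mi">1000</span>  <span class="c1"># vocabulary size</span> |
| <span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">16</span>  <span class="c1"># embedding size</span> |
| <span class="n">sequence_length</span> <span class="o">=</span> <span class="mi">5</span>  <span class="c1"># number of time steps</span> |
| <span class="c1"># Shape of 'step_data' is (batch_size,).</span> |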
| <span class="n">step_input</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'step_data'</span><span class="p">)</span> |
| |
| <span class="c1"># First we embed our raw input data to be used as LSTM's input.</span> |
| <span class="n">embedded_step</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">step_input</span><span class="p">,</span> \ |
| <span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span> \ |
| <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> |
| |
| <span class="c1"># Then we create an LSTM cell.</span> |
| <span class="n">lstm_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span> |
| <span class="c1"># Initialize its hidden and memory states.</span> |
| <span class="c1"># 'begin_state' method takes an initialization function, and uses 'zeros' by default.</span> |
| <span class="n">begin_state</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">begin_state</span><span class="p">()</span> |
| </pre></div> |
| </div> |
| <p>The LSTM cell and other non-fused RNN cells are callable. Calling the cell updates its state once. This transformation depends on both the current input and the previous states. See this <a class="reference external" href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">blog post</a> for a great introduction to LSTMs and other RNNs.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Call the cell to get the output of one time step for a batch.</span> |
| <span class="n">output</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="p">(</span><span class="n">embedded_step</span><span class="p">,</span> <span class="n">begin_state</span><span class="p">)</span> |
| |
| <span class="c1"># 'output' is lstm_t0_out_output of shape (batch_size, hidden_dim).</span> |
| |
| <span class="c1"># 'states' holds the recurrent states that are carried over to the next step,</span> |
| <span class="c1"># namely the "hidden state" and the "cell state".</span> |
| <span class="c1"># Both 'lstm_t0_out_output' and 'lstm_t0_state_output' have shape (batch_size, hidden_dim).</span> |
| </pre></div> |
| </div> |
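| <p>To make the single-step update concrete, here is a minimal NumPy sketch of what one LSTM step computes for a batch. It is not the MXNet implementation: the helper <code class="docutils literal"><span class="pre">lstm_step</span></code>, the weight names, the gate ordering, and the sizes are all assumptions chosen for illustration.</p> |

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, W_x, W_h, b):
    """One LSTM time step for a batch (hypothetical helper, not an MXNet API).

    x:      (batch_size, embed_dim)
    h_prev: (batch_size, hidden_dim)  previous hidden state
    c_prev: (batch_size, hidden_dim)  previous cell state
    """
    # All four gates are computed with one matrix product, then split.
    gates = x @ W_x + h_prev @ W_h + b
    i, f, g, o = np.split(gates, 4, axis=1)  # gate order is an assumption
    c_next = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h_next = sigmoid(o) * np.tanh(c_next)
    # The step output is the new hidden state; both states are carried forward.
    return h_next, [h_next, c_next]


batch_size, embed_dim, hidden_dim = 4, 8, 50  # illustrative sizes
rng = np.random.default_rng(0)
x = rng.standard_normal((batch_size, embed_dim))
W_x = 0.1 * rng.standard_normal((embed_dim, 4 * hidden_dim))
W_h = 0.1 * rng.standard_normal((hidden_dim, 4 * hidden_dim))
b = np.zeros(4 * hidden_dim)

# Zero initial states, mirroring the default behaviour described for 'begin_state'.
h0 = np.zeros((batch_size, hidden_dim))
c0 = np.zeros((batch_size, hidden_dim))

output, states = lstm_step(x, h0, c0, W_x, W_h, b)
print(output.shape, [s.shape for s in states])
```

| <p>As in the cell API, the step returns an output of shape (batch_size, hidden_dim) together with a list of two states of the same shape, which would be passed back in at the next step.</p> |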
| <p>Most of the time our goal is to process a sequence of many steps. For this, we need to unroll the LSTM according to the sequence length.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Embed a sequence. 'seq_data' has the shape of (batch_size, sequence_length).</span> |
| <span class="n">seq_input</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'seq_data'</span><span class="p">)</span> |
| <span class="n">embedded_seq</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">seq_input</span><span class="p">,</span> \ |
| <span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span> \ |
| <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">Remember to reset the cell when unrolling/stepping for a new sequence by calling <cite>lstm_cell.reset()</cite>.</p> |
| </div> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># When unrolling with 'merge_outputs' set to True, the per-step outputs are merged into a single symbol.</span> |
| <span class="c1"># In the layout, 'N' represents batch size, 'T' represents sequence length, and 'C' represents the</span> |
| <span class="c1"># number of dimensions in hidden states.</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">layout</span><span class="o">=</span><span class="s1">'NTC'</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| <span class="c1"># 'outputs' is concat0_output of shape (batch_size, sequence_length, hidden_dim).</span> |
| <span class="c1"># The hidden state and cell state from the final time step are returned:</span> |
| <span class="c1"># Both 'lstm_t4_out_output' and 'lstm_t4_state_output' have shape (batch_size, hidden_dim).</span> |
| |
| <span class="c1"># If merge_outputs is set to False, a list of symbols for each of the time steps is returned.</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">layout</span><span class="o">=</span><span class="s1">'NTC'</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> |
| <span class="c1"># In this case, 'outputs' is a list of symbols. Each symbol is of shape (batch_size, hidden_dim).</span> |
| </pre></div> |
| </div> |
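| <p>The effect of <code class="docutils literal"><span class="pre">merge_outputs</span></code> can be illustrated with a small NumPy sketch. It uses a plain tanh recurrence instead of an LSTM to stay short, and the helper <code class="docutils literal"><span class="pre">unroll</span></code>, its arguments, and the sizes are assumptions for illustration, not the MXNet implementation.</p> |

```python
import numpy as np


def unroll(inputs, h0, W_x, W_h, merge_outputs):
    """Apply a simple tanh RNN step across time (hypothetical helper).

    inputs uses the 'NTC' layout: (batch_size, sequence_length, features).
    """
    h = h0
    outputs = []
    for t in range(inputs.shape[1]):
        # One step: the new state depends on the current input and the previous state.
        h = np.tanh(inputs[:, t, :] @ W_x + h @ W_h)
        outputs.append(h)
    if merge_outputs:
        # Stack the per-step outputs along the time axis: (N, T, C).
        outputs = np.stack(outputs, axis=1)
    return outputs, h


batch_size, sequence_length, embed_dim, hidden_dim = 4, 5, 8, 50  # illustrative sizes
rng = np.random.default_rng(0)
seq = rng.standard_normal((batch_size, sequence_length, embed_dim))
W_x = 0.1 * rng.standard_normal((embed_dim, hidden_dim))
W_h = 0.1 * rng.standard_normal((hidden_dim, hidden_dim))
h0 = np.zeros((batch_size, hidden_dim))

merged, final_state = unroll(seq, h0, W_x, W_h, merge_outputs=True)
steps, _ = unroll(seq, h0, W_x, W_h, merge_outputs=False)
print(merged.shape, len(steps), steps[0].shape)
```

| <p>With merging, the result is one array of shape (batch_size, sequence_length, hidden_dim); without it, the result is a list with one (batch_size, hidden_dim) entry per time step. Each call here starts again from <code class="docutils literal"><span class="pre">h0</span></code>, which plays the role that resetting the cell plays in the cell API.</p> |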
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">Loading and saving models that are built with the RNN cell API requires using |
| <cite>mx.rnn.load_rnn_checkpoint</cite>, <cite>mx.rnn.save_rnn_checkpoint</cite>, and <cite>mx.rnn.do_rnn_checkpoint</cite>. |
| A list of all the cells used should be provided as the first argument to those functions.</p> |
| </div> |
| </div> |
| <div class="section" id="basic-rnn-cells"> |
| <span id="basic-rnn-cells"></span><h3>Basic RNN cells<a class="headerlink" href="#basic-rnn-cells" title="Permalink to this headline">¶</a></h3> |
| <p>The <code class="docutils literal"><span class="pre">rnn</span></code> module supports the following RNN cell types.</p> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.LSTMCell" title="mxnet.rnn.LSTMCell"><code class="xref py py-obj docutils literal"><span class="pre">LSTMCell</span></code></a></td> |
| <td>Long Short-Term Memory (LSTM) network cell.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.GRUCell" title="mxnet.rnn.GRUCell"><code class="xref py py-obj docutils literal"><span class="pre">GRUCell</span></code></a></td> |
| <td>Gated Recurrent Unit (GRU) network cell.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.RNNCell" title="mxnet.rnn.RNNCell"><code class="xref py py-obj docutils literal"><span class="pre">RNNCell</span></code></a></td> |
| <td>Simple recurrent neural network cell.</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| <div class="section" id="modifier-cells"> |
| <span id="modifier-cells"></span><h3>Modifier cells<a class="headerlink" href="#modifier-cells" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BidirectionalCell" title="mxnet.rnn.BidirectionalCell"><code class="xref py py-obj docutils literal"><span class="pre">BidirectionalCell</span></code></a></td> |
| <td>Bidirectional RNN cell.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.DropoutCell" title="mxnet.rnn.DropoutCell"><code class="xref py py-obj docutils literal"><span class="pre">DropoutCell</span></code></a></td> |
| <td>Apply dropout on input.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.ZoneoutCell" title="mxnet.rnn.ZoneoutCell"><code class="xref py py-obj docutils literal"><span class="pre">ZoneoutCell</span></code></a></td> |
| <td>Apply Zoneout on base cell.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.ResidualCell" title="mxnet.rnn.ResidualCell"><code class="xref py py-obj docutils literal"><span class="pre">ResidualCell</span></code></a></td> |
| <td>Adds residual connection as described in Wu et al, 2016 (<a class="reference external" href="https://arxiv.org/abs/1609.08144">https://arxiv.org/abs/1609.08144</a>).</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>A modifier cell takes in one or more cells and transforms the output of those cells. |
| <code class="docutils literal"><span class="pre">BidirectionalCell</span></code> is one example. It takes two cells for forward unroll and backward unroll |
| respectively. After unrolling, the outputs of the forward and backward pass are concatenated.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Bidirectional cell takes two RNN cells, for forward and backward pass respectively.</span> |
| <span class="c1"># Having different types of cells for forward and backward unrolling is allowed.</span> |
| <span class="n">bi_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">BidirectionalCell</span><span class="p">(</span> |
| <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">),</span> |
| <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">GRUCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">75</span><span class="p">))</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">bi_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| <span class="c1"># The output feature is the concatenation of the forward and backward pass.</span> |
| <span class="c1"># Thus, the number of output dimensions is the sum of the dimensions of the two cells.</span> |
| <span class="c1"># 'outputs' is the symbol 'bi_out_output' of shape (batch_size, sequence_length, 125)</span> |
| |
| <span class="c1"># The states of the BidirectionalCell are a list of two lists, corresponding to the</span> |
| <span class="c1"># states of the forward and backward cells respectively.</span> |
| </pre></div> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">BidirectionalCell cannot be called or stepped, because the backward unroll requires the output of |
| future steps, and thus the whole sequence is required.</p> |
| </div> |
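| <p>Why the whole sequence is required can be seen in a small, framework-free sketch. The helper names below (<code class="docutils literal"><span class="pre">make_cumsum_cell</span></code>, <code class="docutils literal"><span class="pre">run</span></code>, <code class="docutils literal"><span class="pre">bidirectional</span></code>) are hypothetical and are not part of the <code class="docutils literal"><span class="pre">rnn</span></code> module, and the toy cell is a running sum standing in for a real RNN cell. It only illustrates the idea of running one cell forward, one cell backward, and concatenating the time-aligned outputs.</p> |

```python
def make_cumsum_cell(width):
    """Return a toy cell: a running sum of the first input feature, repeated `width` times."""
    def cell(x, state):
        out = [state[0] + x[0]] * width
        return out, out
    return cell


def run(cell, seq, state):
    """Step `cell` over `seq` in order and collect the per-step outputs."""
    outputs = []
    for x in seq:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs


def bidirectional(fwd_cell, bwd_cell, seq, fwd_state, bwd_state):
    """Concatenate forward outputs with time-aligned backward outputs."""
    fwd = run(fwd_cell, seq, fwd_state)
    # The backward cell reads the sequence from the last step to the first,
    # so its outputs are reversed again to line up with the forward ones.
    bwd = run(bwd_cell, seq[::-1], bwd_state)[::-1]
    # List concatenation per step: feature width is the sum of both widths.
    return [f + b for f, b in zip(fwd, bwd)]


seq = [[1.0], [2.0], [3.0]]
outputs = bidirectional(make_cumsum_cell(2), make_cumsum_cell(3),
                        seq, [0.0] * 2, [0.0] * 3)
for step in outputs:
    print(len(step), step)
```

| <p>In this sketch the backward half of the very first step already depends on every input in the sequence, which is why a single step cannot be computed in isolation.</p> |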
| <p>Dropout and zoneout are popular regularization techniques that can be applied to RNNs. The <code class="docutils literal"><span class="pre">rnn</span></code> |
| module provides <code class="docutils literal"><span class="pre">DropoutCell</span></code> and <code class="docutils literal"><span class="pre">ZoneoutCell</span></code> for regularization on the output and recurrent |
| states of an RNN. <code class="docutils literal"><span class="pre">ZoneoutCell</span></code> takes one RNN cell in the constructor and supports unrolling like |
| other cells.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">zoneout_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">ZoneoutCell</span><span class="p">(</span><span class="n">lstm_cell</span><span class="p">,</span> <span class="n">zoneout_states</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">zoneout_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| </pre></div> |
| </div> |
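| <p>As a rough illustration of the zoneout idea, the following framework-free sketch keeps each unit of the previous state with probability <code class="docutils literal"><span class="pre">p</span></code> and otherwise takes the newly computed value. The function name <code class="docutils literal"><span class="pre">zoneout</span></code> and its arguments are assumptions made for this example; the actual <code class="docutils literal"><span class="pre">ZoneoutCell</span></code> implementation may differ in detail, for instance in how it behaves at inference time.</p> |

```python
import random


def zoneout(prev_state, new_state, p, rng):
    """Per unit, keep the previous value with probability p, else take the new one."""
    return [prev if rng.random() < p else new
            for prev, new in zip(prev_state, new_state)]


rng = random.Random(0)      # seeded so the sketch is repeatable
prev = [0.0] * 8            # state from the previous time step
new = [1.0] * 8             # state the cell just computed

mixed = zoneout(prev, new, p=0.5, rng=rng)
print(mixed)

# p = 0 never keeps the old state; p = 1 always keeps it.
always_new = zoneout(prev, new, p=0.0, rng=rng)
always_old = zoneout(prev, new, p=1.0, rng=rng)
```

| <p>Unlike dropout, which sets units to zero, this rule carries old values forward, so information from earlier steps can survive unchanged.</p> |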
| <p><code class="docutils literal"><span class="pre">DropoutCell</span></code> performs dropout on the input sequence. It can be used in a stacked |
| multi-layer RNN setting, which we will cover next.</p> |
| <p>Residual connections are a useful technique for training deep neural models because they help |
| gradients propagate by shortening the paths. <code class="docutils literal"><span class="pre">ResidualCell</span></code> provides this functionality for |
| RNN models.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">residual_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">ResidualCell</span><span class="p">(</span><span class="n">lstm_cell</span><span class="p">)</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">residual_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The <code class="docutils literal"><span class="pre">outputs</span></code> are the element-wise sum of both the input and the output of the LSTM cell.</p> |
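| <p>The element-wise sum can be sketched without any framework. In the example below, <code class="docutils literal"><span class="pre">residual_step</span></code> and <code class="docutils literal"><span class="pre">toy_cell</span></code> are hypothetical names used only for illustration, and the toy cell is a stand-in for a real RNN cell. The sketch assumes the wrapped cell's output has the same size as its input, since an element-wise sum is otherwise undefined.</p> |

```python
def residual_step(cell, x, state):
    """Run one step of the wrapped cell and add the input to its output."""
    out, new_state = cell(x, state)
    if len(out) != len(x):
        raise ValueError("residual connection needs matching input and output sizes")
    return [o + xi for o, xi in zip(out, x)], new_state


def toy_cell(x, state):
    # Hypothetical stand-in for an RNN cell: output = 0.5 * x + state.
    out = [0.5 * xi + si for xi, si in zip(x, state)]
    return out, out


x = [1.0, 2.0, 3.0]
state = [0.0, 0.0, 0.0]
y, new_state = residual_step(toy_cell, x, state)
print(y)          # cell output plus the original input
print(new_state)  # the state is passed through unchanged by the wrapper
```

| <p>Only the output is modified in this sketch; the recurrent state comes straight from the wrapped cell.</p> |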
| </div> |
| <div class="section" id="multi-layer-cells"> |
| <span id="multi-layer-cells"></span><h3>Multi-layer cells<a class="headerlink" href="#multi-layer-cells" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell" title="mxnet.rnn.SequentialRNNCell"><code class="xref py py-obj docutils literal"><span class="pre">SequentialRNNCell</span></code></a></td> |
| <td>Sequentially stacking multiple RNN cells.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell.add" title="mxnet.rnn.SequentialRNNCell.add"><code class="xref py py-obj docutils literal"><span class="pre">SequentialRNNCell.add</span></code></a></td> |
| <td>Append a cell into the stack.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>The <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> allows stacking multiple layers of RNN cells to improve the expressiveness |
| and performance of the model. Cells can be added to a <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> in order, from bottom to |
| top. When unrolling, the output of a lower-level cell is automatically passed to the cell above.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">stacked_rnn_cells</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">SequentialRNNCell</span><span class="p">()</span> |
| <span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">BidirectionalCell</span><span class="p">(</span> |
| <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">),</span> |
| <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">)))</span> |
| |
| <span class="c1"># Dropout the output of the bottom layer BidirectionalCell with a retention probability of 0.5.</span> |
| <span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">DropoutCell</span><span class="p">(</span><span class="mf">0.5</span><span class="p">))</span> |
| |
| <span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">))</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">stacked_rnn_cells</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| |
| <span class="c1"># The output of SequentialRNNCell is the same as that of the last layer.</span> |
| <span class="c1"># In this case 'outputs' is the symbol 'concat6_output' of shape (batch_size, sequence_length, hidden_dim)</span> |
| <span class="c1"># The states of the SequentialRNNCell are a list of lists, with each list</span> |
| <span class="c1"># corresponding to the states of each of the added cells respectively.</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="fused-rnn-cell"> |
| <span id="fused-rnn-cell"></span><h3>Fused RNN cell<a class="headerlink" href="#fused-rnn-cell" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.FusedRNNCell" title="mxnet.rnn.FusedRNNCell"><code class="xref py py-obj docutils literal"><span class="pre">FusedRNNCell</span></code></a></td> |
| <td>Fusing RNN layers across time steps into one kernel.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.FusedRNNCell.unfuse" title="mxnet.rnn.FusedRNNCell.unfuse"><code class="xref py py-obj docutils literal"><span class="pre">FusedRNNCell.unfuse</span></code></a></td> |
| <td>Unfuse the fused RNN into a stack of RNN cells.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>The computation of an RNN for an input sequence consists of many GEMM and point-wise operations with |
| temporal dependencies. This can make the computation memory-bound, especially on GPU, |
| resulting in longer wall-clock time. By combining the computation of many small matrices into that of |
| larger ones and streaming the computation whenever possible, the ratio of computation to memory I/O |
| can be increased, which results in better performance on GPU. This optimization technique is called |
| “fusing”. |
| <a class="reference external" href="https://devblogs.nvidia.com/parallelforall/optimizing-recurrent-neural-networks-cudnn-5/">This post</a> |
| discusses it in greater detail.</p> |
| <p>The <code class="docutils literal"><span class="pre">rnn</span></code> module includes a <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>, which provides the optimized fused implementation. |
| The FusedRNNCell supports bidirectional RNNs and dropout.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">fused_lstm_cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">FusedRNNCell</span><span class="p">(</span><span class="n">num_hidden</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> \ |
| <span class="n">num_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> \ |
| <span class="n">mode</span><span class="o">=</span><span class="s1">'lstm'</span><span class="p">,</span> \ |
| <span class="n">bidirectional</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> \ |
| <span class="n">dropout</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="n">outputs</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">fused_lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| <span class="c1"># The 'outputs' is the symbol 'lstm_rnn_output' that has the shape</span> |
| <span class="c1"># (batch_size, sequence_length, forward_backward_concat_dim)</span> |
| </pre></div> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last"><cite>FusedRNNCell</cite> supports GPU only. It cannot be called or stepped.</p> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">When <cite>dropout</cite> is set to non-zero in <cite>FusedRNNCell</cite>, the dropout is applied to the |
| output of all layers except the last layer. If there is only one layer in the <cite>FusedRNNCell</cite>, the |
| dropout rate is ignored.</p> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">Similar to <cite>BidirectionalCell</cite>, when the <cite>bidirectional</cite> flag is set to <cite>True</cite>, the output |
| of <cite>FusedRNNCell</cite> is twice the size specified by <cite>num_hidden</cite>.</p> |
| </div> |
| <p>When training a deep, complex model <em>on multiple GPUs</em>, it’s recommended to stack |
| fused RNN cells (one layer per cell) together instead of using one cell with all layers. |
| The reason is that fused RNN cells don’t set gradients to be ready until the |
| computation for the entire layer is completed. Breaking a multi-layer fused RNN |
| cell into several one-layer ones allows gradients to be processed earlier. This |
| reduces communication overhead, especially with multiple GPUs.</p> |
| <p>The <code class="docutils literal"><span class="pre">unfuse()</span></code> method can be used to convert the <code class="docutils literal"><span class="pre">FusedRNNCell</span></code> into an equivalent |
| and CPU-compatible <code class="docutils literal"><span class="pre">SequentialRNNCell</span></code> that mirrors the settings of the <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">unfused_lstm_cell</span> <span class="o">=</span> <span class="n">fused_lstm_cell</span><span class="o">.</span><span class="n">unfuse</span><span class="p">()</span> |
| <span class="n">unfused_outputs</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">unfused_lstm_cell</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> \ |
| <span class="n">inputs</span><span class="o">=</span><span class="n">embedded_seq</span><span class="p">,</span> \ |
| <span class="n">merge_outputs</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> |
| <span class="c1"># 'unfused_outputs' is the symbol 'lstm_bi_l2_out_output' that has the shape</span> |
| <span class="c1"># (batch_size, sequence_length, forward_backward_concat_dim)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="rnn-checkpoint-methods-and-parameters"> |
| <span id="rnn-checkpoint-methods-and-parameters"></span><h3>RNN checkpoint methods and parameters<a class="headerlink" href="#rnn-checkpoint-methods-and-parameters" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.save_rnn_checkpoint" title="mxnet.rnn.save_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">save_rnn_checkpoint</span></code></a></td> |
| <td>Save checkpoint for model using RNN cells.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.load_rnn_checkpoint" title="mxnet.rnn.load_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">load_rnn_checkpoint</span></code></a></td> |
| <td>Load model checkpoint from file.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.do_rnn_checkpoint" title="mxnet.rnn.do_rnn_checkpoint"><code class="xref py py-obj docutils literal"><span class="pre">do_rnn_checkpoint</span></code></a></td> |
| <td>Make a callback to checkpoint Module to prefix every epoch.</td> |
| </tr> |
| </tbody> |
| </table> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><code class="xref py py-obj docutils literal"><span class="pre">RNNParams</span></code></a></td> |
| <td>Container for holding variables.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.RNNParams.get" title="mxnet.rnn.RNNParams.get"><code class="xref py py-obj docutils literal"><span class="pre">RNNParams.get</span></code></a></td> |
| <td>Get the variable given a name if one exists or create a new one if missing.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>The model parameters from training with a fused cell can be used for inference with an unfused cell, |
| and vice versa. As the parameters of fused and unfused cells are organized differently, they need to |
| be converted first. <code class="docutils literal"><span class="pre">FusedRNNCell</span></code>’s parameters are merged and flattened. In the fused example above, |
| the model has <code class="docutils literal"><span class="pre">lstm_parameters</span></code> of shape <code class="docutils literal"><span class="pre">(total_num_params,)</span></code>, whereas the |
| equivalent SequentialRNNCell’s parameters are separate:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="s1">'lstm_l0_i2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span> |
| <span class="s1">'lstm_l0_i2h_bias'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,)</span> |
| <span class="s1">'lstm_l0_h2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span> |
| <span class="s1">'lstm_l0_h2h_bias'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,)</span> |
| <span class="s1">'lstm_r0_i2h_weight'</span><span class="p">:</span> <span class="p">(</span><span class="n">out_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span> |
| <span class="o">...</span> |
| </pre></div> |
| </div> |
| <p>All cells in the <code class="docutils literal"><span class="pre">rnn</span></code> module support the method <code class="docutils literal"><span class="pre">unpack_weights()</span></code> for converting <code class="docutils literal"><span class="pre">FusedRNNCell</span></code> |
| parameters to the unfused format and <code class="docutils literal"><span class="pre">pack_weights()</span></code> for fusing the parameters. The RNN-specific |
| checkpointing methods (<code class="docutils literal"><span class="pre">load_rnn_checkpoint,</span> <span class="pre">save_rnn_checkpoint,</span> <span class="pre">do_rnn_checkpoint</span></code>) handle the |
| conversion transparently based on the provided cells.</p> |
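| <p>The difference between the two parameter organizations can be illustrated with a framework-free sketch that flattens named matrices into a single vector and slices them back out. The functions <code class="docutils literal"><span class="pre">pack</span></code> and <code class="docutils literal"><span class="pre">unpack</span></code>, the parameter names, and the ordering below are assumptions made for this example only; the real layout used by <code class="docutils literal"><span class="pre">FusedRNNCell</span></code> is defined by the library and is not reproduced here.</p> |

```python
def pack(params, order):
    """Flatten named row-major matrices into one list, in a fixed order."""
    flat = []
    for name in order:
        for row in params[name]:
            flat.extend(row)
    return flat


def unpack(flat, shapes, order):
    """Inverse of pack: slice the flat list back into named matrices."""
    params, pos = {}, 0
    for name in order:
        rows, cols = shapes[name]
        params[name] = [flat[pos + r * cols: pos + (r + 1) * cols]
                        for r in range(rows)]
        pos += rows * cols
    return params


# Hypothetical parameter names and shapes, chosen only for illustration.
order = ["i2h_weight", "h2h_weight"]
shapes = {"i2h_weight": (2, 3), "h2h_weight": (2, 2)}
separate = {
    "i2h_weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "h2h_weight": [[7.0, 8.0], [9.0, 10.0]],
}

flat = pack(separate, order)            # one vector of length 2*3 + 2*2
restored = unpack(flat, shapes, order)  # back to separate named matrices
print(len(flat), restored == separate)
```

| <p>The round trip only works because both directions agree on the order and the shapes, which is the reason the conversion is tied to the cell that owns the parameters.</p> |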
| </div> |
| <div class="section" id="i-o-utilities"> |
| <span id="i-o-utilities"></span><h3>I/O utilities<a class="headerlink" href="#i-o-utilities" title="Permalink to this headline">¶</a></h3> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.rnn.BucketSentenceIter" title="mxnet.rnn.BucketSentenceIter"><code class="xref py py-obj docutils literal"><span class="pre">BucketSentenceIter</span></code></a></td> |
| <td>Simple bucketing iterator for language model.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.rnn.encode_sentences" title="mxnet.rnn.encode_sentences"><code class="xref py py-obj docutils literal"><span class="pre">encode_sentences</span></code></a></td> |
| <td>Encode sentences and (optionally) build a mapping from string tokens to integer indices.</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| </div> |
| <div class="section" id="api-reference"> |
| <span id="api-reference"></span><h2>API Reference<a class="headerlink" href="#api-reference" title="Permalink to this headline">¶</a></h2> |
| <script src="../../_static/js/auto_module_index.js" type="text/javascript"></script><dl class="class"> |
| <dt id="mxnet.rnn.BaseRNNCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BaseRNNCell</code><span class="sig-paren">(</span><em>prefix=''</em>, <em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Abstract base class for RNN cells</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>optional</em>) – Prefix for names of layers |
| (this prefix is also used for names of weights if <cite>params</cite> is None |
| i.e. if <cite>params</cite> are being created and not reused)</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None.</em>) – Container for weight sharing between cells. |
| A new RNNParams container is created if <cite>params</cite> is None.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>inputs</em>, <em>states</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.__call__"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Unroll the RNN for one time step.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>inputs</strong> (<em>sym.Variable</em>) – input symbol, 2D, batch * num_units</li> |
| <li><strong>states</strong> (<em>list of sym.Variable</em>) – RNN state from previous step or the output of begin_state().</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>output</strong> (<em>Symbol</em>) – Symbol corresponding to the output from the RNN when unrolling |
| for a single time step.</li> |
| <li><strong>states</strong> (<em>nested list of Symbol</em>) – The new state of this RNN after this unrolling. |
| The type of this symbol is same as the output of begin_state(). |
| This can be used as input state to the next time step |
| of this RNN.</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <div class="admonition seealso"> |
| <p class="first admonition-title">See also</p> |
| <dl class="last docutils"> |
| <dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.begin_state" title="mxnet.rnn.BaseRNNCell.begin_state"><code class="xref py py-meth docutils literal"><span class="pre">begin_state()</span></code></a></dt> |
| <dd>This function can provide the states for the first time step.</dd> |
| <dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.unroll" title="mxnet.rnn.BaseRNNCell.unroll"><code class="xref py py-meth docutils literal"><span class="pre">unroll()</span></code></a></dt> |
| <dd>This function unrolls an RNN for a given number of (>=1) time steps.</dd> |
| </dl> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Reset before re-using the cell for another graph.</p> |
| </dd></dl> |
| <dl class="attribute"> |
| <dt id="mxnet.rnn.BaseRNNCell.params"> |
| <code class="descname">params</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.params" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Parameters of this cell</p> |
| </dd></dl> |
| <dl class="attribute"> |
| <dt id="mxnet.rnn.BaseRNNCell.state_info"> |
| <code class="descname">state_info</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.state_info" title="Permalink to this definition">¶</a></dt> |
| <dd><p>shape and layout information of states</p> |
| </dd></dl> |
| <dl class="attribute"> |
| <dt id="mxnet.rnn.BaseRNNCell.state_shape"> |
| <code class="descname">state_shape</code><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.state_shape" title="Permalink to this definition">¶</a></dt> |
| <dd><p>shape(s) of states</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.begin_state"> |
| <code class="descname">begin_state</code><span class="sig-paren">(</span><em>func=<function zeros></em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.begin_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.begin_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initial state for this cell.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>func</strong> (<em>callable</em><em>, </em><em>default symbol.zeros</em>) – Function for creating initial state. Can be symbol.zeros, |
| symbol.uniform, symbol.Variable etc. |
| Use symbol.Variable if you want to directly |
| feed input as states.</li> |
| <li><strong>**kwargs</strong> – more keyword arguments passed to func. For example |
| mean, std, dtype, etc.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>states</strong> – Starting states for the first RNN step.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">nested list of Symbol</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.unpack_weights"> |
| <code class="descname">unpack_weights</code><span class="sig-paren">(</span><em>args</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.unpack_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.unpack_weights" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Unpack fused weight matrices into separate |
| weight matrices.</p> |
| <p>For example, say you use a module object <cite>mod</cite> to run a network that has an lstm cell. |
| In <cite>mod.get_params()[0]</cite>, the lstm parameters are all represented as a single big vector. |
| <cite>cell.unpack_weights(mod.get_params()[0])</cite> will unpack this vector into a dictionary of |
| more readable lstm parameters - c, f, i, o gates for i2h (input to hidden) and |
| h2h (hidden to hidden) weights.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args</strong> (<em>dict of str -> NDArray</em>) – Dictionary containing packed weights, |
| usually from <cite>Module.get_params()[0]</cite>.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>args</strong> – Dictionary with unpacked weights associated with |
| this cell.</td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">dict of str -> NDArray</td> |
| </tr> |
| </tbody> |
| </table> |
| <div class="admonition seealso"> |
| <p class="first admonition-title">See also</p> |
| <dl class="last docutils"> |
| <dt><a class="reference internal" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="mxnet.rnn.BaseRNNCell.pack_weights"><code class="xref py py-meth docutils literal"><span class="pre">pack_weights()</span></code></a></dt> |
| <dd>Performs the reverse operation of this function.</dd> |
| </dl> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.pack_weights"> |
| <code class="descname">pack_weights</code><span class="sig-paren">(</span><em>args</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.pack_weights"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.pack_weights" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Pack separate weight matrices into a single packed |
| weight.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args</strong> (<em>dict of str -> NDArray</em>) – Dictionary containing unpacked weights.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>args</strong> – Dictionary with packed weights associated with |
| this cell.</td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">dict of str -> NDArray</td> |
| </tr> |
| </tbody> |
| </table> |
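| <p>As a rough illustration of what packing and unpacking mean, the sketch below uses plain Python lists in place of NDArrays. The key names (<code>lstm_i2h_weight</code>, <code>lstm_i2h_i_weight</code>, and so on), the gate order, and the helper functions are assumptions made for this example; they are not taken from the MXNet source.</p> |

```python
# Illustrative only: plain lists stand in for NDArrays, and the
# key names and gate order below are assumptions for this sketch.
GATES = ["i", "f", "c", "o"]

def unpack_weights(args, prefix="lstm_", num_hidden=2):
    """Split one packed i2h weight into one entry per gate."""
    out = dict(args)
    packed = out.pop(prefix + "i2h_weight")
    for k, gate in enumerate(GATES):
        rows = packed[k * num_hidden:(k + 1) * num_hidden]
        out["%si2h_%s_weight" % (prefix, gate)] = rows
    return out

def pack_weights(args, prefix="lstm_"):
    """Reverse of unpack_weights: concatenate the per-gate entries."""
    out = dict(args)
    packed = []
    for gate in GATES:
        packed.extend(out.pop("%si2h_%s_weight" % (prefix, gate)))
    out[prefix + "i2h_weight"] = packed
    return out

packed = {"lstm_i2h_weight": [[float(r)] for r in range(8)]}
unpacked = unpack_weights(packed)
restored = pack_weights(unpacked)
```

| <p>Packing after unpacking returns the original dictionary, which is the round-trip property the two methods are described as having.</p> |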
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BaseRNNCell.unroll"> |
| <code class="descname">unroll</code><span class="sig-paren">(</span><em>length</em>, <em>inputs</em>, <em>begin_state=None</em>, <em>layout='NTC'</em>, <em>merge_outputs=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BaseRNNCell.unroll"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BaseRNNCell.unroll" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Unroll an RNN cell across time steps.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>length</strong> (<em>int</em>) – Number of steps to unroll.</li> |
| <li><strong>inputs</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a><em>, </em><em>list of Symbol</em><em>, or </em><em>None</em>) – <p>If <cite>inputs</cite> is a single Symbol (usually the output |
| of Embedding symbol), it should have shape |
| (batch_size, length, ...) if layout == ‘NTC’, |
| or (length, batch_size, ...) if layout == ‘TNC’.</p> |
| <p>If <cite>inputs</cite> is a list of symbols (usually output of |
| previous unroll), they should all have shape |
| (batch_size, ...).</p> |
| </li> |
| <li><strong>begin_state</strong> (<em>nested list of Symbol</em><em>, </em><em>default None</em>) – Input states created by <cite>begin_state()</cite> |
| or output state of another cell. |
| Created from <cite>begin_state()</cite> if None.</li> |
| <li><strong>layout</strong> (<em>str</em><em>, </em><em>optional</em>) – <cite>layout</cite> of input symbol. Only used if inputs |
| is a single Symbol.</li> |
| <li><strong>merge_outputs</strong> (<em>bool</em><em>, </em><em>optional</em>) – If False, return outputs as a list of Symbols. |
| If True, concatenate output across time steps |
| and return a single symbol with shape |
| (batch_size, length, ...) if layout == ‘NTC’, |
| or (length, batch_size, ...) if layout == ‘TNC’. |
| If None, output whatever is faster.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>outputs</strong> (<em>list of Symbol or Symbol</em>) – Symbol (if <cite>merge_outputs</cite> is True) or list of Symbols |
| (if <cite>merge_outputs</cite> is False) corresponding to the output from |
| the RNN from this unrolling.</li> |
| <li><strong>states</strong> (<em>nested list of Symbol</em>) – The new state of this RNN after this unrolling. |
| It has the same type as the output of begin_state().</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
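| <p>The relationship between <cite>layout</cite>, <cite>merge_outputs</cite>, and the returned values can be sketched with a toy cell on Python floats. This is not an MXNet Symbol graph: <code>toy_cell</code> and this stand-alone <code>unroll</code> are hypothetical helpers, and the sketch defaults <cite>merge_outputs</cite> to False for simplicity.</p> |

```python
# Illustrative only: a toy "cell" on Python floats, not an MXNet Symbol graph.
def toy_cell(x, state):
    """One step: the new state (and output) is a running sum."""
    new_state = state + x
    return new_state, new_state

def unroll(length, inputs, begin_state=0.0, layout="NTC", merge_outputs=False):
    """Apply toy_cell for `length` steps to every sequence in the batch.

    inputs is nested as [batch][time] for 'NTC' or [time][batch] for 'TNC'.
    """
    if layout == "TNC":
        # Transpose to [batch][time] so the loop below handles one layout.
        inputs = [list(seq) for seq in zip(*inputs)]
    batch = len(inputs)
    states = [begin_state] * batch
    outputs = []  # one entry per time step, each of length batch
    for t in range(length):
        step_out = []
        for n in range(batch):
            out, states[n] = toy_cell(inputs[n][t], states[n])
            step_out.append(out)
        outputs.append(step_out)
    if merge_outputs and layout == "NTC":
        # Merge the per-step list into [batch][time] to match the input layout.
        outputs = [list(seq) for seq in zip(*outputs)]
    return outputs, states

ntc = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]   # batch_size=2, length=3
step_list, final = unroll(3, ntc)
merged, _ = unroll(3, ntc, merge_outputs=True)
```

| <p>Here <code>step_list</code> holds one batch-sized entry per time step, while <code>merged</code> has the same (batch_size, length) nesting as the input.</p> |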
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.LSTMCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">LSTMCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>prefix='lstm_'</em>, <em>params=None</em>, <em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#LSTMCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.LSTMCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Long Short-Term Memory (LSTM) network cell.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>default 'lstm_'</em>) – Prefix for name of layers (and name of weight if params is None).</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</li> |
| <li><strong>forget_bias</strong> (<em>float</em><em>, </em><em>default 1.0</em>) – Bias added to the forget gate. Jozefowicz et al. (2015) recommend setting this to 1.0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
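| <p>To show why a positive <cite>forget_bias</cite> matters, here is a minimal scalar LSTM step in plain Python. It is a sketch under assumptions: the bias is added directly to the forget-gate pre-activation, the weight layout <code>w</code> is hypothetical, and how MXNet applies the value internally is not shown here.</p> |

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h, c, w, forget_bias=1.0):
    """One scalar LSTM step. `w` maps each gate to (w_x, w_h, bias)."""
    def pre(gate):
        w_x, w_h, b = w[gate]
        return w_x * x + w_h * h + b
    i = sigmoid(pre("i"))                    # input gate
    f = sigmoid(pre("f") + forget_bias)      # forget gate, shifted by forget_bias
    g = math.tanh(pre("c"))                  # candidate cell value
    o = sigmoid(pre("o"))                    # output gate
    c_next = f * c + i * g
    h_next = o * math.tanh(c_next)
    return h_next, c_next

# With all-zero weights, the forget gate alone decides how much of c survives.
zero_w = {gate: (0.0, 0.0, 0.0) for gate in "ifco"}
_, c_biased = lstm_step(0.0, 0.0, 1.0, zero_w, forget_bias=1.0)
_, c_plain = lstm_step(0.0, 0.0, 1.0, zero_w, forget_bias=0.0)
```

| <p>With a bias of 1.0 the cell retains more of its previous state than with a bias of 0.0, which is the effect the recommendation is aiming for early in training.</p> |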
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.GRUCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">GRUCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>prefix='gru_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#GRUCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.GRUCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gated Recurrent Unit (GRU) network cell. |
| Note: this is an implementation of the cuDNN version of GRUs |
| (slight modification compared to Cho et al. 2014).</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>default 'gru_'</em>) – Prefix for name of layers (and name of weight if params is None).</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
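| <p>The modification mentioned above concerns where the reset gate is applied when computing the candidate state. The scalar sketch below contrasts the two placements as they are commonly described for Cho et al. and for cuDNN; the function names and scalar weights are assumptions for illustration, and the sketch does not reproduce MXNet's implementation.</p> |

```python
import math

def candidate_cho(x, h, r, w, u, b):
    """Cho et al. style: reset gate scales h before the recurrent weight."""
    return math.tanh(w * x + u * (r * h) + b)

def candidate_cudnn(x, h, r, w, u, b_w, b_r):
    """cuDNN style: reset gate scales the recurrent term, including its bias."""
    return math.tanh(w * x + b_w + r * (u * h + b_r))

x, h, r = 0.5, -0.3, 0.25
ref = candidate_cho(x, h, r, w=1.0, u=2.0, b=0.1)
same = candidate_cudnn(x, h, r, w=1.0, u=2.0, b_w=0.1, b_r=0.0)
shifted = candidate_cudnn(x, h, r, w=1.0, u=2.0, b_w=0.1, b_r=0.4)
```

| <p>In this scalar case the two forms agree when the recurrent bias is zero and differ once it is non-zero. With matrices they also differ because the gate is applied after, rather than before, the recurrent matrix product.</p> |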
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.RNNCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">RNNCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>activation='tanh'</em>, <em>prefix='rnn_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#RNNCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.RNNCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Simple recurrent neural network cell.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li> |
| <li><strong>activation</strong> (<em>str</em><em> or </em><a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a><em>, </em><em>default 'tanh'</em>) – Type of activation function. Options are ‘relu’ and ‘tanh’.</li> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>default 'rnn_'</em>) – Prefix for name of layers (and name of weight if params is None).</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.FusedRNNCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">FusedRNNCell</code><span class="sig-paren">(</span><em>num_hidden</em>, <em>num_layers=1</em>, <em>mode='lstm'</em>, <em>bidirectional=False</em>, <em>dropout=0.0</em>, <em>get_next_state=False</em>, <em>forget_bias=1.0</em>, <em>prefix=None</em>, <em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#FusedRNNCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.FusedRNNCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Fuses RNN layers across time steps into one kernel. |
| Improves speed but is less flexible. Currently only |
| supported when using cuDNN on GPU.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Number of units in output symbol.</li> |
| <li><strong>num_layers</strong> (<em>int</em><em>, </em><em>default 1</em>) – Number of layers in the cell.</li> |
| <li><strong>mode</strong> (<em>str</em><em>, </em><em>default 'lstm'</em>) – Type of RNN. options are ‘rnn_relu’, ‘rnn_tanh’, ‘lstm’, ‘gru’.</li> |
| <li><strong>bidirectional</strong> (<em>bool</em><em>, </em><em>default False</em>) – Whether to use bidirectional unroll. The output dimension size is doubled if bidirectional.</li> |
| <li><strong>dropout</strong> (<em>float</em><em>, </em><em>default 0.</em>) – Fraction of the input that gets dropped out during training time.</li> |
| <li><strong>get_next_state</strong> (<em>bool</em><em>, </em><em>default False</em>) – Whether to return the states that can be used as starting states next time.</li> |
| <li><strong>forget_bias</strong> (<em>float</em><em>, </em><em>default 1.0</em>) – Bias added to the forget gate. Jozefowicz et al. (2015) recommend setting this to 1.0.</li> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>default '$mode_', such as 'lstm_'</em>) – Prefix for names of layers |
| (this prefix is also used for names of weights if <cite>params</cite> is None |
| i.e. if <cite>params</cite> are being created and not reused)</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.rnn.FusedRNNCell.unfuse"> |
| <code class="descname">unfuse</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#FusedRNNCell.unfuse"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.FusedRNNCell.unfuse" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Unfuse the fused RNN into a stack of RNN cells.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><strong>cell</strong> – unfused cell that can be used for stepping, and can run on CPU.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body"><a class="reference internal" href="#mxnet.rnn.SequentialRNNCell" title="mxnet.rnn.SequentialRNNCell">SequentialRNNCell</a></td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.SequentialRNNCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">SequentialRNNCell</code><span class="sig-paren">(</span><em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#SequentialRNNCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.SequentialRNNCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sequentially stacks multiple RNN cells.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.rnn.SequentialRNNCell.add"> |
| <code class="descname">add</code><span class="sig-paren">(</span><em>cell</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#SequentialRNNCell.add"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.SequentialRNNCell.add" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Append a cell into the stack.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – The cell to be appended. During unroll, previous cell’s output (or raw inputs if |
| no previous cell) is used as the input to this cell.</td> |
| </tr> |
| </tbody> |
| </table> |
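| <p>The chaining described above, where each cell receives the previous cell's output, can be sketched as follows. <code>ToyCell</code> and <code>ToySequential</code> are hypothetical stand-ins written for this example, not MXNet classes.</p> |

```python
class ToyCell:
    """Stand-in for an RNN cell: scales its input, state is passed through."""
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x, state):
        return x * self.scale, state

class ToySequential:
    """Feeds each cell's output to the next cell, in insertion order."""
    def __init__(self):
        self.cells = []

    def add(self, cell):
        self.cells.append(cell)

    def __call__(self, x, states):
        next_states = []
        for cell, state in zip(self.cells, states):
            x, new_state = cell(x, state)   # previous output becomes input
            next_states.append(new_state)
        return x, next_states

stack = ToySequential()
stack.add(ToyCell(2.0))
stack.add(ToyCell(3.0))
out, states = stack(1.5, [None, None])
```

| <p>The raw input reaches only the first cell; the stack keeps one state entry per added cell.</p> |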
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.BidirectionalCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BidirectionalCell</code><span class="sig-paren">(</span><em>l_cell</em>, <em>r_cell</em>, <em>params=None</em>, <em>output_prefix='bi_'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#BidirectionalCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BidirectionalCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Bidirectional RNN cell.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>l_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – cell for forward unrolling</li> |
| <li><strong>r_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – cell for backward unrolling</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None.</em>) – Container for weight sharing between cells. |
| A new RNNParams container is created if <cite>params</cite> is None.</li> |
| <li><strong>output_prefix</strong> (<em>str</em><em>, </em><em>default 'bi_'</em>) – Prefix for name of output.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
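| <p>A minimal sketch of bidirectional unrolling, assuming a toy running-sum cell on floats: the forward cell reads the sequence in order, the backward cell reads it reversed, and the backward outputs are flipped so that both line up per time step. The helper names are hypothetical, and a tuple stands in for concatenating the two outputs.</p> |

```python
def run(cell, seq):
    """Unroll a toy cell over a sequence, returning the output at each step."""
    state, outputs = 0.0, []
    for x in seq:
        state = cell(x, state)
        outputs.append(state)
    return outputs

def bidirectional(l_cell, r_cell, seq):
    """Pair forward outputs with time-aligned backward outputs."""
    forward = run(l_cell, seq)
    backward = run(r_cell, seq[::-1])[::-1]   # reverse back to input order
    return list(zip(forward, backward))

def running_sum(x, state):
    return state + x

paired = bidirectional(running_sum, running_sum, [1.0, 2.0, 3.0])
```

| <p>Each paired entry carries information from both directions, which is why the output feature size doubles when the two parts are concatenated.</p> |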
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.DropoutCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">DropoutCell</code><span class="sig-paren">(</span><em>dropout</em>, <em>prefix='dropout_'</em>, <em>params=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#DropoutCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.DropoutCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Apply dropout on input.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>dropout</strong> (<em>float</em>) – Fraction of elements to drop out, which |
| is 1 minus the fraction to retain.</li> |
| <li><strong>prefix</strong> (<em>str</em><em>, </em><em>default 'dropout_'</em>) – Prefix for names of layers |
| (this prefix is also used for names of weights if <cite>params</cite> is None |
| i.e. if <cite>params</cite> are being created and not reused)</li> |
| <li><strong>params</strong> (<a class="reference internal" href="#mxnet.rnn.RNNParams" title="mxnet.rnn.RNNParams"><em>RNNParams</em></a><em>, </em><em>default None</em>) – Container for weight sharing between cells. Created if None.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.ZoneoutCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">ZoneoutCell</code><span class="sig-paren">(</span><em>base_cell</em>, <em>zoneout_outputs=0.0</em>, <em>zoneout_states=0.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#ZoneoutCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.ZoneoutCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Apply Zoneout on base cell.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>base_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – Cell on whose states to perform zoneout.</li> |
| <li><strong>zoneout_outputs</strong> (<em>float</em><em>, </em><em>default 0.</em>) – Fraction of the output that gets dropped out during training time.</li> |
| <li><strong>zoneout_states</strong> (<em>float</em><em>, </em><em>default 0.</em>) – Fraction of the states that gets dropped out during training time.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
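| <p>Zoneout differs from dropout in that a selected unit keeps its value from the previous time step instead of being zeroed. The sketch below shows that idea on plain lists; the <code>zoneout</code> helper and its <code>training</code> flag are assumptions for illustration rather than MXNet's API.</p> |

```python
import random

def zoneout(prev, new, p, rng, training=True):
    """With probability p, each unit keeps its previous value."""
    if not training:
        return list(new)
    return [old if rng.random() < p else cur for old, cur in zip(prev, new)]

prev_state = [1.0, 2.0, 3.0]
new_state = [4.0, 5.0, 6.0]
never = zoneout(prev_state, new_state, 0.0, random.Random(0))
always = zoneout(prev_state, new_state, 1.0, random.Random(0))
mixed = zoneout(prev_state, new_state, 0.5, random.Random(0))
```

| <p>A rate of 0.0 always takes the new values, a rate of 1.0 freezes the state, and anything in between mixes old and new values unit by unit.</p> |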
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.ResidualCell"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">ResidualCell</code><span class="sig-paren">(</span><em>base_cell</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#ResidualCell"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.ResidualCell" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Adds a residual connection as described in Wu et al., 2016 |
| (<a class="reference external" href="https://arxiv.org/abs/1609.08144">https://arxiv.org/abs/1609.08144</a>).</p> |
| <p>Output of the cell is output of the base cell plus input.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>base_cell</strong> (<a class="reference internal" href="#mxnet.rnn.BaseRNNCell" title="mxnet.rnn.BaseRNNCell"><em>BaseRNNCell</em></a>) – Cell on whose outputs to add residual connection.</td> |
| </tr> |
| </tbody> |
| </table> |
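| <p>The residual rule stated above (output of the base cell plus the input) can be sketched as a wrapper around any step function. The <code>residual</code> and <code>halve</code> helpers are hypothetical, and the sketch assumes the base cell's output has the same length as its input.</p> |

```python
def residual(base_cell):
    """Wrap a cell so its output is the base output plus the cell input."""
    def wrapped(x, state):
        out, new_state = base_cell(x, state)
        return [o + xi for o, xi in zip(out, x)], new_state
    return wrapped

def halve(x, state):
    """Toy base cell: halves every input element, passes state through."""
    return [0.5 * xi for xi in x], state

res_cell = residual(halve)
out, state = res_cell([2.0, 4.0], "s0")
```

| <p>Only the output is modified; the state returned by the base cell is passed through unchanged.</p> |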
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.RNNParams"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">RNNParams</code><span class="sig-paren">(</span><em>prefix=''</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#RNNParams"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.RNNParams" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Container for holding variables. |
| Used by RNN cells for parameter sharing between cells.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>prefix</strong> (<em>str</em>) – Names of all variables created by this container will |
| be prepended with prefix.</td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.rnn.RNNParams.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/rnn_cell.html#RNNParams.get"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.RNNParams.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Get the variable given a name if one exists or create a new one if missing.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – name of the variable</li> |
| <li><strong>**kwargs</strong> – Additional keyword arguments passed to symbol.Variable.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
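| <p>The get-or-create behaviour is what makes weight sharing work: two cells asking the same container for the same name receive the same variable. A minimal sketch, in which a dict stands in for <code>symbol.Variable</code> and <code>ToyParams</code> is a hypothetical class:</p> |

```python
class ToyParams:
    """Get-or-create container; a dict stands in for symbol.Variable."""
    def __init__(self, prefix=""):
        self._prefix = prefix
        self._params = {}

    def get(self, name, **kwargs):
        full_name = self._prefix + name
        if full_name not in self._params:
            # Created once; later calls return the same object.
            self._params[full_name] = dict(name=full_name, **kwargs)
        return self._params[full_name]

shared = ToyParams(prefix="lstm_")
w1 = shared.get("i2h_weight", shape=(4, 2))
w2 = shared.get("i2h_weight")
```

| <p>In this sketch the keyword arguments only take effect on the first call, when the variable is created.</p> |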
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.rnn.BucketSentenceIter"> |
| <em class="property">class </em><code class="descclassname">mxnet.rnn.</code><code class="descname">BucketSentenceIter</code><span class="sig-paren">(</span><em>sentences</em>, <em>batch_size</em>, <em>buckets=None</em>, <em>invalid_label=-1</em>, <em>data_name='data'</em>, <em>label_name='softmax_label'</em>, <em>dtype='float32'</em>, <em>layout='NT'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/io.html#BucketSentenceIter"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Simple bucketing iterator for language model. |
| The label at each sequence step is the following token |
| in the sequence.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>sentences</strong> (<em>list of list of int</em>) – Encoded sentences.</li> |
| <li><strong>batch_size</strong> (<em>int</em>) – Batch size of the data.</li> |
| <li><strong>invalid_label</strong> (<em>int</em><em>, </em><em>optional</em>) – Key for invalid label, e.g. &lt;end-of-sentence&gt;. The default is -1.</li> |
| <li><strong>dtype</strong> (<em>str</em><em>, </em><em>optional</em>) – Data type of the encoding. The default data type is ‘float32’.</li> |
| <li><strong>buckets</strong> (<em>list of int</em><em>, </em><em>optional</em>) – Size of the data buckets. Automatically generated if None.</li> |
| <li><strong>data_name</strong> (<em>str</em><em>, </em><em>optional</em>) – Name of the data. The default name is ‘data’.</li> |
| <li><strong>label_name</strong> (<em>str</em><em>, </em><em>optional</em>) – Name of the label. The default name is ‘softmax_label’.</li> |
| <li><strong>layout</strong> (<em>str</em><em>, </em><em>optional</em>) – Format of data and label. ‘NT’ means (batch_size, length) |
| and ‘TN’ means (length, batch_size).</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
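| <p>A rough sketch of the bucketing idea: each sentence goes into the smallest bucket that fits it and is padded with <cite>invalid_label</cite>, and the label at each step is the next token. The helper names are hypothetical, and details such as skipping sentences longer than the largest bucket are assumptions of this sketch.</p> |

```python
import bisect

def bucket_sentences(sentences, buckets, invalid_label=-1):
    """Group sentences by the smallest bucket that fits, padding the rest."""
    buckets = sorted(buckets)
    data = {size: [] for size in buckets}
    for sent in sentences:
        k = bisect.bisect_left(buckets, len(sent))
        if k == len(buckets):
            continue  # longer than the largest bucket: skipped in this sketch
        size = buckets[k]
        data[size].append(sent + [invalid_label] * (size - len(sent)))
    return data

def next_token_labels(row, invalid_label=-1):
    """Label at step t is the token at step t + 1; the last step has none."""
    return row[1:] + [invalid_label]

grouped = bucket_sentences([[5, 6], [7, 8, 9], [1, 2, 3, 4, 5, 6]], buckets=[2, 4])
labels = next_token_labels(grouped[4][0])
```

| <p>Grouping by length keeps padding small, since every batch drawn from one bucket shares a single sequence length.</p> |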
| <dl class="method"> |
| <dt id="mxnet.rnn.BucketSentenceIter.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/io.html#BucketSentenceIter.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the iterator to the beginning of the data.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.BucketSentenceIter.next"> |
| <code class="descname">next</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/rnn/io.html#BucketSentenceIter.next"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.rnn.BucketSentenceIter.next" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns the next batch of data.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.encode_sentences"> |
| <code class="descclassname">rnn.</code><code class="descname">encode_sentences</code><span class="sig-paren">(</span><em>sentences</em>, <em>vocab=None</em>, <em>invalid_label=-1</em>, <em>invalid_key='\n'</em>, <em>start_label=0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.encode_sentences" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Encode sentences and (optionally) build a mapping |
| from string tokens to integer indices. Unknown keys |
| will be added to vocabulary.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>sentences</strong> (<em>list of list of str</em>) – A list of sentences to encode. Each sentence |
| should be a list of string tokens.</li> |
| <li><strong>vocab</strong> (<em>None</em><em> or </em><em>dict of str -> int</em>) – Optional input Vocabulary</li> |
| <li><strong>invalid_label</strong> (<em>int</em><em>, </em><em>default -1</em>) – Index for invalid token, like &lt;end-of-sentence&gt;</li> |
| <li><strong>invalid_key</strong> (<em>str</em><em>, </em><em>default '\n'</em>) – Key for invalid token. Uses ‘\n’ for end |
| of sentence by default.</li> |
| <li><strong>start_label</strong> (<em>int</em><em>, </em><em>default 0</em>) – Lowest index to use for encoded tokens.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>result</strong> (<em>list of list of int</em>) – encoded sentences</li> |
| <li><strong>vocab</strong> (<em>dict of str -> int</em>) – result vocabulary</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
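| <p>The encoding step can be sketched in plain Python as below. This is an illustration of the described behaviour, not the MXNet source: the function name <code>encode_sentences_sketch</code> is hypothetical, and the rule for skipping indices that are already taken is an assumption of the sketch.</p> |

```python
def encode_sentences_sketch(sentences, vocab=None, invalid_label=-1,
                            invalid_key="\n", start_label=0):
    """Map string tokens to integer ids, growing the vocabulary as needed."""
    if vocab is None:
        vocab = {invalid_key: invalid_label}
    next_id = start_label
    encoded = []
    for sent in sentences:
        ids = []
        for token in sent:
            if token not in vocab:
                # Never hand out invalid_label or an id already in use.
                while next_id == invalid_label or next_id in vocab.values():
                    next_id += 1
                vocab[token] = next_id
            ids.append(vocab[token])
        encoded.append(ids)
    return encoded, vocab

result, vocab = encode_sentences_sketch([["the", "cat"], ["the", "dog"]])
```

| <p>Repeated tokens reuse their existing index, so the returned vocabulary can be passed back in to encode further sentences consistently.</p> |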
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.save_rnn_checkpoint"> |
| <code class="descclassname">rnn.</code><code class="descname">save_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>epoch</em>, <em>symbol</em>, <em>arg_params</em>, <em>aux_params</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.save_rnn_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Save checkpoint for model using RNN cells. |
| Unpacks weights before saving.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>cells</strong> (<a class="reference internal" href="#mxnet.rnn.RNNCell" title="mxnet.rnn.RNNCell"><em>RNNCell</em></a><em> or </em><em>list of RNNCells</em>) – The RNN cells used by this symbol.</li> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – The epoch number of the model.</li> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The input symbol</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li> |
| <li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li> |
| </ul> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.load_rnn_checkpoint"> |
| <code class="descclassname">rnn.</code><code class="descname">load_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>epoch</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.load_rnn_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Load model checkpoint from file. |
| Packs weights after loading.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>cells</strong> (<a class="reference internal" href="#mxnet.rnn.RNNCell" title="mxnet.rnn.RNNCell"><em>RNNCell</em></a><em> or </em><em>list of RNNCells</em>) – The RNN cells used by this symbol.</li> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>symbol</strong> (<em>Symbol</em>) – The symbol configuration of computation network.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li>symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li> |
| <li>parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li> |
| </ul> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.rnn.do_rnn_checkpoint"> |
| <code class="descclassname">rnn.</code><code class="descname">do_rnn_checkpoint</code><span class="sig-paren">(</span><em>cells</em>, <em>prefix</em>, <em>period=1</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.rnn.do_rnn_checkpoint" title="Permalink to this definition">¶</a></dt> |
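| <dd><p>Make a callback that checkpoints the Module to <code class="docutils literal"><span class="pre">prefix</span></code> every <code class="docutils literal"><span class="pre">period</span></code> epochs. |
| The weights used by <code class="docutils literal"><span class="pre">cells</span></code> are unpacked before saving.</p> |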
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>cells</strong> (<a class="reference internal" href="gluon.html#mxnet.gluon.rnn.RNNCell" title="mxnet.gluon.rnn.RNNCell"><em>RNNCell</em></a><em> or </em><em>list of RNNCells</em>) – The RNN cells used by this symbol.</li> |
| <li><strong>prefix</strong> (<em>str</em>) – The file prefix to checkpoint to.</li> |
| <li><strong>period</strong> (<em>int</em>) – Number of epochs between checkpoints. Default is 1.</li> |
| </ul> |
| </td> |
| </tr> |
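| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>callback</strong> – The callback function that can be passed as <code class="docutils literal"><span class="pre">epoch_end_callback</span></code> to <code class="docutils literal"><span class="pre">fit</span></code>.</p> |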
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">function</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
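| <p>The effect of <code class="docutils literal"><span class="pre">period</span></code> can be sketched in plain Python without MXNet. This is a minimal illustration, not the library's implementation: <code class="docutils literal"><span class="pre">make_periodic_callback</span></code> and <code class="docutils literal"><span class="pre">save_fn</span></code> are hypothetical names, and the sketch assumes the callback receives a zero-based epoch index followed by the symbol and the two parameter dicts.</p> |

```python
def make_periodic_callback(save_fn, period=1):
    """Build an epoch-end callback that calls save_fn every `period` epochs.

    Hypothetical stand-in for illustration; save_fn takes the place of the
    real checkpoint-writing step.
    """
    period = int(max(1, period))  # guard against zero or negative periods

    def _callback(epoch, sym=None, arg_params=None, aux_params=None):
        # epoch is assumed zero-based, so epoch + 1 counts completed epochs
        if (epoch + 1) % period == 0:
            save_fn(epoch + 1, sym, arg_params, aux_params)

    return _callback


# Simulate five training epochs with a checkpoint every second epoch.
saved = []
callback = make_periodic_callback(lambda epoch, *rest: saved.append(epoch), period=2)
for epoch in range(5):
    callback(epoch)
print(saved)  # [2, 4]
```

| <p>In actual training, the callback returned by <code class="docutils literal"><span class="pre">do_rnn_checkpoint(cells,</span> <span class="pre">prefix,</span> <span class="pre">period)</span></code> would be handed to <code class="docutils literal"><span class="pre">fit</span></code> as its epoch-end callback in place of the stand-in above.</p> |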
| </dd></dl> |
| <script>auto_index("api-reference");</script></div> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">RNN Cell API</a><ul> |
| <li><a class="reference internal" href="#overview">Overview</a></li> |
| <li><a class="reference internal" href="#the-rnn-module">The <code class="docutils literal"><span class="pre">rnn</span></code> module</a><ul> |
| <li><a class="reference internal" href="#cell-interfaces">Cell interfaces</a></li> |
| <li><a class="reference internal" href="#basic-rnn-cells">Basic RNN cells</a></li> |
| <li><a class="reference internal" href="#modifier-cells">Modifier cells</a></li> |
| <li><a class="reference internal" href="#multi-layer-cells">Multi-layer cells</a></li> |
| <li><a class="reference internal" href="#fused-rnn-cell">Fused RNN cell</a></li> |
| <li><a class="reference internal" href="#rnn-checkpoint-methods-and-parameters">RNN checkpoint methods and parameters</a></li> |
| <li><a class="reference internal" href="#i-o-utilities">I/O utilities</a></li> |
| </ul> |
| </li> |
| <li><a class="reference internal" href="#api-reference">API Reference</a></li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div><div class="footer"> |
| <div class="section-disclaimer"> |
| <div class="container"> |
| <div> |
| <img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/> |
| <p> |
| Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p> |
| <p> |
| Copyright © 2017-2018, The Apache Software Foundation. |
| Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation. |
| </p> |
| </div> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| </div> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../../_static/js/search.js" type="text/javascript"></script> |
| <script src="../../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../../_static/js/copycode.js" type="text/javascript"></script> |
| <script src="../../_static/js/page.js" type="text/javascript"></script> |
| <script src="../../_static/js/docversion.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </body> |
| </html> |