| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <title>Optimization: initialize and update weights — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../../_static/basic.css" rel="stylesheet" type="text/css"/> |
| <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"/> |
| <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript"> |
| var DOCUMENTATION_OPTIONS = { /* page-level options object, declared as a global before the _static scripts that follow are loaded; presumably read by doctools.js / searchtools_custom.js (confirm) */ |
| URL_ROOT: '../../', /* same relative prefix this page uses for its _static/ assets */ |
| VERSION: '', /* left empty on this page */ |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '' |
| }; |
| </script> |
| <script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script> |
| <script src="../../_static/underscore.js" type="text/javascript"></script> |
| <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../../_static/doctools.js" type="text/javascript"></script> |
| <script src="../../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ /* record the global name ('ga') on window and install a stub function unless window.ga already exists */ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new /* the stub appends each call's arguments to ga.q; ga.l is set to the current time as a number */ |
| Date();a=s.createElement(o), /* create a new script element */ |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) /* point it at analytics.js, mark it async, and insert it before the first existing script element */ |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); /* create call for this property ID (queued on ga.q if only the stub is present) */ |
| ga('send', 'pageview'); /* send call for a pageview hit */ |
| |
| </script> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/jquery.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/underscore.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/doctools.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> --> |
| <!-- --> |
| <link href="index.html" rel="up" title="MXNet - Python API"/> |
| <link href="callback.html" rel="next" title="Callback API"/> |
| <link href="image.html" rel="prev" title="Image API"/> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> |
| </head> |
| <body role="document"><div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../../" id="logo"><img src="../../_static/mxnet.png" alt="MXNet"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="../../get_started/install.html">Install</a> |
| <a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a> |
| <ul class="dropdown-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../../gluon/index.html">About</a></li> |
| <li><a class="main-nav-link" href="http://gluon.mxnet.io/">Tutorials</a></li> |
| </ul> |
| </span> |
| <a class="main-nav-link" href="../../how_to/index.html">How To</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li> |
| <li><a class="main-nav-link" href="../../api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li> |
| </ul> |
| </span> |
| <a class="main-nav-link" href="../../architecture/index.html">Architecture</a> |
| <!-- <a class="main-nav-link" href="../../community/index.html">Community</a> --> |
| <a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(1.0.0)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/">1.1.0</a></li><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/versions/1.0.0/index.html">1.0.0</a></li><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/versions/0.12.1/index.html">0.12.1</a></li><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/versions/0.12.0/index.html">0.12.0</a></li><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/versions/0.11.0/index.html">0.11.0</a></li><li><a class="main-nav-link" href="https://mxnet.incubator.apache.org/versions/master/index.html">master</a></li></ul></span></nav> |
| <!-- getRootPath() returns the hard-coded "../../" prefix, the same relative prefix this page uses for its other site links. --> <script> function getRootPath(){ return "../../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu dropdown-menu-right" id="burgerMenu"> |
| <li><a href="../../get_started/install.html">Install</a></li> |
| <li><a href="../../tutorials/index.html">Tutorials</a></li> |
| <li><a href="../../how_to/index.html">How To</a></li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../../api/python/index.html" tabindex="-1">Python</a> |
| </li> |
| <li><a href="../../api/scala/index.html" tabindex="-1">Scala</a> |
| </li> |
| <li><a href="../../api/r/index.html" tabindex="-1">R</a> |
| </li> |
| <li><a href="../../api/julia/index.html" tabindex="-1">Julia</a> |
| </li> |
| <li><a href="../../api/c++/index.html" tabindex="-1">C++</a> |
| </li> |
| <li><a href="../../api/perl/index.html" tabindex="-1">Perl</a> |
| </li> |
| </ul> |
| </li> |
| <li><a href="../../architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(1.0.0)</a><ul class="dropdown-menu"><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/">1.1.0</a></li><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/versions/1.0.0/index.html">1.0.0</a></li><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/versions/0.12.1/index.html">0.12.1</a></li><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/versions/0.12.0/index.html">0.12.0</a></li><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/versions/0.11.0/index.html">0.11.0</a></li><li><a tabindex="-1" href="https://mxnet.incubator.apache.org/versions/master/index.html">master</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"/> |
| <input name="area" type="hidden" value="default"/> |
| </form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| <!-- <div id="lang-select-wrap"> --> |
| <!-- <label id="lang-select-label"> --> |
| <!-- <\!-- <i class="fa fa-globe"></i> -\-> --> |
| <!-- <span></span> --> |
| <!-- </label> --> |
| <!-- <select id="lang-select"> --> |
| <!-- <option value="en">Eng</option> --> |
| <!-- <option value="zh">中文</option> --> |
| <!-- </select> --> |
| <!-- </div> --> |
| <!-- <a id="mobile-nav-toggle"> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| </a> --> |
| </div> |
| </div> |
| </div> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul class="current"> |
| <li class="toctree-l1 current"><a class="reference internal" href="index.html">Python Documents</a><ul class="current"> |
| <li class="toctree-l2 current"><a class="reference internal" href="index.html#table-of-contents">Table of contents</a><ul class="current"> |
| <li class="toctree-l3"><a class="reference internal" href="ndarray.html">NDArray API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="symbol.html">Symbol API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="module.html">Module API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="autograd.html">Autograd Package</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="gluon.html">Gluon Package</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="rnn.html">RNN Cell API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="kvstore.html">KVStore API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="io.html">Data Loading API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="image.html">Image API</a></li> |
| <li class="toctree-l3 current"><a class="current reference internal" href="">Optimization: initialize and update weights</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#overview">Overview</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#the-mxnet-initializer-package">The <code class="docutils literal"><span class="pre">mxnet.initializer</span></code> package</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#the-mxnet-optimizer-package">The <code class="docutils literal"><span class="pre">mxnet.optimizer</span></code> package</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#the-mxnet-lr-scheduler-package">The <code class="docutils literal"><span class="pre">mxnet.lr_scheduler</span></code> package</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#implement-a-new-algorithm">Implement a new algorithm</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#api-reference">API Reference</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="callback.html">Callback API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="metric.html">Evaluation Metric API</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| <li class="toctree-l1"><a class="reference internal" href="../r/index.html">R Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../julia/index.html">Julia Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../c++/index.html">C++ Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../scala/index.html">Scala Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../perl/index.html">Perl Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="section" id="optimization-initialize-and-update-weights"> |
| <span id="optimization-initialize-and-update-weights"></span><h1>Optimization: initialize and update weights<a class="headerlink" href="#optimization-initialize-and-update-weights" title="Permalink to this headline">¶</a></h1> |
| <div class="section" id="overview"> |
| <span id="overview"></span><h2>Overview<a class="headerlink" href="#overview" title="Permalink to this headline">¶</a></h2> |
| <p>This document summarizes the APIs used to initialize and update the model weights |
| during training</p> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#module-mxnet.initializer" title="mxnet.initializer"><code class="xref py py-obj docutils literal"><span class="pre">mxnet.initializer</span></code></a></td> |
| <td>Weight initializer.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#module-mxnet.optimizer" title="mxnet.optimizer"><code class="xref py py-obj docutils literal"><span class="pre">mxnet.optimizer</span></code></a></td> |
| <td>Weight updating functions.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#module-mxnet.lr_scheduler" title="mxnet.lr_scheduler"><code class="xref py py-obj docutils literal"><span class="pre">mxnet.lr_scheduler</span></code></a></td> |
| <td>Scheduling learning rate.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p>and how to develop a new optimization algorithm in MXNet.</p> |
| <p>Assume there is a pre-defined <code class="docutils literal"><span class="pre">Symbol</span></code> and a <code class="docutils literal"><span class="pre">Module</span></code> has been created for |
| it:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">label</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'softmax_label'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">fc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">loss</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">fc</span><span class="p">,</span> <span class="n">label</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">data_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'data'</span><span class="p">,</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span><span class="mi">20</span><span class="p">))],</span> <span class="n">label_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'softmax_label'</span><span class="p">,</span> <span class="p">(</span><span class="mi">128</span><span class="p">,))])</span> |
| </pre></div> |
| </div> |
| <p>Next we can initialize the weights with values sampled uniformly from |
| <code class="docutils literal"><span class="pre">[-1,1]</span></code>:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">mod</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>Then we will train a model with standard SGD, decreasing the learning rate |
| by multiplying it by 0.9 every 100 batches.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">lr_sch</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">FactorScheduler</span><span class="p">(</span><span class="n">step</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">mod</span><span class="o">.</span><span class="n">init_optimizer</span><span class="p">(</span> |
| <span class="gp">... </span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span> <span class="n">optimizer_params</span><span class="o">=</span><span class="p">((</span><span class="s1">'learning_rate'</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">),</span> <span class="p">(</span><span class="s1">'lr_scheduler'</span><span class="p">,</span> <span class="n">lr_sch</span><span class="p">)))</span> |
| </pre></div> |
| </div> |
| <p>Finally run <code class="docutils literal"><span class="pre">mod.fit(...)</span></code> to start training.</p> |
| </div> |
| <div class="section" id="the-mxnet-initializer-package"> |
| <span id="the-mxnet-initializer-package"></span><h2>The <code class="docutils literal"><span class="pre">mxnet.initializer</span></code> package<a class="headerlink" href="#the-mxnet-initializer-package" title="Permalink to this headline">¶</a></h2> |
| <p>The base class <code class="docutils literal"><span class="pre">Initializer</span></code> defines the default behaviors to initialize |
| various parameters, such as setting the bias to 1, except for the weight. Other classes |
| then define how to initialize the weight.</p> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><code class="xref py py-obj docutils literal"><span class="pre">Initializer</span></code></a></td> |
| <td>The base class of an initializer.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Uniform" title="mxnet.initializer.Uniform"><code class="xref py py-obj docutils literal"><span class="pre">Uniform</span></code></a></td> |
| <td>Initializes weights with random values uniformly sampled from a given range.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.Normal" title="mxnet.initializer.Normal"><code class="xref py py-obj docutils literal"><span class="pre">Normal</span></code></a></td> |
| <td>Initializes weights with random values sampled from a normal distribution with a mean of zero and standard deviation of <cite>sigma</cite>.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Load" title="mxnet.initializer.Load"><code class="xref py py-obj docutils literal"><span class="pre">Load</span></code></a></td> |
| <td>Initializes variables by loading data from file or dict.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.Mixed" title="mxnet.initializer.Mixed"><code class="xref py py-obj docutils literal"><span class="pre">Mixed</span></code></a></td> |
| <td>Initialize parameters using multiple initializers.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Zero" title="mxnet.initializer.Zero"><code class="xref py py-obj docutils literal"><span class="pre">Zero</span></code></a></td> |
| <td>Initializes weights to zero.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.One" title="mxnet.initializer.One"><code class="xref py py-obj docutils literal"><span class="pre">One</span></code></a></td> |
| <td>Initializes weights to one.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Constant" title="mxnet.initializer.Constant"><code class="xref py py-obj docutils literal"><span class="pre">Constant</span></code></a></td> |
| <td>Initializes the weights to a scalar value.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.Orthogonal" title="mxnet.initializer.Orthogonal"><code class="xref py py-obj docutils literal"><span class="pre">Orthogonal</span></code></a></td> |
| <td>Initialize weight as orthogonal matrix.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Xavier" title="mxnet.initializer.Xavier"><code class="xref py py-obj docutils literal"><span class="pre">Xavier</span></code></a></td> |
| <td>Returns an initializer performing “Xavier” initialization for weights.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.MSRAPrelu" title="mxnet.initializer.MSRAPrelu"><code class="xref py py-obj docutils literal"><span class="pre">MSRAPrelu</span></code></a></td> |
| <td>Initialize the weight according to a MSRA paper.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.initializer.Bilinear" title="mxnet.initializer.Bilinear"><code class="xref py py-obj docutils literal"><span class="pre">Bilinear</span></code></a></td> |
| <td>Initialize weight for upsampling layers.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.initializer.FusedRNN" title="mxnet.initializer.FusedRNN"><code class="xref py py-obj docutils literal"><span class="pre">FusedRNN</span></code></a></td> |
| <td>Initialize parameters for fused rnn layers.</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| <div class="section" id="the-mxnet-optimizer-package"> |
| <span id="the-mxnet-optimizer-package"></span><h2>The <code class="docutils literal"><span class="pre">mxnet.optimizer</span></code> package<a class="headerlink" href="#the-mxnet-optimizer-package" title="Permalink to this headline">¶</a></h2> |
| <p>The base class <code class="docutils literal"><span class="pre">Optimizer</span></code> accepts commonly shared arguments such as |
| <code class="docutils literal"><span class="pre">learning_rate</span></code> and defines the interface. Each other class in this package |
| implements one weight updating function.</p> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-obj docutils literal"><span class="pre">Optimizer</span></code></a></td> |
| <td>The base class inherited by all optimizers.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.optimizer.SGD" title="mxnet.optimizer.SGD"><code class="xref py py-obj docutils literal"><span class="pre">SGD</span></code></a></td> |
| <td>The SGD optimizer with momentum and weight decay.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.optimizer.NAG" title="mxnet.optimizer.NAG"><code class="xref py py-obj docutils literal"><span class="pre">NAG</span></code></a></td> |
| <td>Nesterov accelerated SGD.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.optimizer.RMSProp" title="mxnet.optimizer.RMSProp"><code class="xref py py-obj docutils literal"><span class="pre">RMSProp</span></code></a></td> |
| <td>The RMSProp optimizer.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.optimizer.Adam" title="mxnet.optimizer.Adam"><code class="xref py py-obj docutils literal"><span class="pre">Adam</span></code></a></td> |
| <td>The Adam optimizer.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.optimizer.AdaGrad" title="mxnet.optimizer.AdaGrad"><code class="xref py py-obj docutils literal"><span class="pre">AdaGrad</span></code></a></td> |
| <td>AdaGrad optimizer.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.optimizer.AdaDelta" title="mxnet.optimizer.AdaDelta"><code class="xref py py-obj docutils literal"><span class="pre">AdaDelta</span></code></a></td> |
| <td>The AdaDelta optimizer.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.optimizer.DCASGD" title="mxnet.optimizer.DCASGD"><code class="xref py py-obj docutils literal"><span class="pre">DCASGD</span></code></a></td> |
| <td>The DCASGD optimizer.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.optimizer.SGLD" title="mxnet.optimizer.SGLD"><code class="xref py py-obj docutils literal"><span class="pre">SGLD</span></code></a></td> |
| <td>Stochastic Gradient Riemannian Langevin Dynamics.</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| <div class="section" id="the-mxnet-lr-scheduler-package"> |
| <span id="the-mxnet-lr-scheduler-package"></span><h2>The <code class="docutils literal"><span class="pre">mxnet.lr_scheduler</span></code> package<a class="headerlink" href="#the-mxnet-lr-scheduler-package" title="Permalink to this headline">¶</a></h2> |
| <p>The base class <code class="docutils literal"><span class="pre">LRScheduler</span></code> defines the interface, while other classes |
| implement various schemes to change the learning rate during training.</p> |
| <table border="1" class="longtable docutils"> |
| <colgroup> |
| <col width="10%"/> |
| <col width="90%"/> |
| </colgroup> |
| <tbody valign="top"> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.lr_scheduler.LRScheduler" title="mxnet.lr_scheduler.LRScheduler"><code class="xref py py-obj docutils literal"><span class="pre">LRScheduler</span></code></a></td> |
| <td>Base class of a learning rate scheduler.</td> |
| </tr> |
| <tr class="row-even"><td><a class="reference internal" href="#mxnet.lr_scheduler.FactorScheduler" title="mxnet.lr_scheduler.FactorScheduler"><code class="xref py py-obj docutils literal"><span class="pre">FactorScheduler</span></code></a></td> |
| <td>Reduce the learning rate by a factor for every <em>n</em> steps.</td> |
| </tr> |
| <tr class="row-odd"><td><a class="reference internal" href="#mxnet.lr_scheduler.MultiFactorScheduler" title="mxnet.lr_scheduler.MultiFactorScheduler"><code class="xref py py-obj docutils literal"><span class="pre">MultiFactorScheduler</span></code></a></td> |
| <td>Reduce the learning rate according to a given list of steps.</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| <div class="section" id="implement-a-new-algorithm"> |
| <span id="implement-a-new-algorithm"></span><h2>Implement a new algorithm<a class="headerlink" href="#implement-a-new-algorithm" title="Permalink to this headline">¶</a></h2> |
| <p>Most classes listed in this document are implemented in Python by using <code class="docutils literal"><span class="pre">NDArray</span></code>. |
| So implementing new weight updating or initialization functions is |
| straightforward.</p> |
| <p>For <code class="docutils literal"><span class="pre">initializer</span></code>, create a subclass of <code class="docutils literal"><span class="pre">Initializer</span></code> and define the |
| <code class="docutils literal"><span class="pre">_init_weight</span></code> method. We can also change the default behaviors to initialize |
| other parameters such as <code class="docutils literal"><span class="pre">_init_bias</span></code>. See |
| <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/python/mxnet/initializer.py"><code class="docutils literal"><span class="pre">initializer.py</span></code></a> |
| for examples.</p> |
| <p>For <code class="docutils literal"><span class="pre">optimizer</span></code>, create a subclass of <code class="docutils literal"><span class="pre">Optimizer</span></code> |
| and implement two methods <code class="docutils literal"><span class="pre">create_state</span></code> and <code class="docutils literal"><span class="pre">update</span></code>. Also add |
| <code class="docutils literal"><span class="pre">@mx.optimizer.Optimizer.register</span></code> before this class. See |
| <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/python/mxnet/optimizer.py"><code class="docutils literal"><span class="pre">optimizer.py</span></code></a> |
| for examples.</p> |
| <p>For <code class="docutils literal"><span class="pre">lr_scheduler</span></code>, create a subclass of <code class="docutils literal"><span class="pre">LRScheduler</span></code> and then implement the |
| <code class="docutils literal"><span class="pre">__call__</span></code> method. See |
| <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/python/mxnet/lr_scheduler.py"><code class="docutils literal"><span class="pre">lr_scheduler.py</span></code></a> |
| for examples.</p> |
| </div> |
| <div class="section" id="api-reference"> |
| <span id="api-reference"></span><h2>API Reference<a class="headerlink" href="#api-reference" title="Permalink to this headline">¶</a></h2> |
| <script src="../../_static/js/auto_module_index.js" type="text/javascript"></script><span class="target" id="module-mxnet.optimizer"></span><p>Weight updating functions.</p> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Optimizer"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Optimizer</code><span class="sig-paren">(</span><em>rescale_grad=1.0</em>, <em>param_idx2name=None</em>, <em>wd=0.0</em>, <em>clip_gradient=None</em>, <em>learning_rate=0.01</em>, <em>lr_scheduler=None</em>, <em>sym=None</em>, <em>begin_num_update=0</em>, <em>param_dict=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class inherited by all optimizers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rescale_grad</strong> (<em>float, optional</em>) – Multiply the gradient with <cite>rescale_grad</cite> before updating. Often |
| chosen to be <code class="docutils literal"><span class="pre">1.0/batch_size</span></code>.</li> |
| <li><strong>param_idx2name</strong> (<em>dict from int to string, optional</em>) – A dictionary that maps int index to string name.</li> |
| <li><strong>clip_gradient</strong> (<em>float, optional</em>) – Clip the gradient by projecting onto the box <code class="docutils literal"><span class="pre">[-clip_gradient,</span> <span class="pre">clip_gradient]</span></code>.</li> |
| <li><strong>learning_rate</strong> (<em>float, optional</em>) – The initial learning rate.</li> |
| <li><strong>lr_scheduler</strong> (<em>LRScheduler, optional</em>) – The learning rate scheduler.</li> |
| <li><strong>wd</strong> (<em>float, optional</em>) – The weight decay (or L2 regularization) coefficient. Modifies objective |
| by adding a penalty for having large weights.</li> |
| <li><strong>sym</strong> (<em>Symbol, optional</em>) – The Symbol this optimizer is applying to.</li> |
| <li><strong>begin_num_update</strong> (<em>int, optional</em>) – The initial number of updates.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.register"> |
| <em class="property">static </em><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.create_optimizer"> |
| <em class="property">static </em><code class="descname">create_optimizer</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.SGD'&gt;</span> |
| <span class="gp">>>> </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.Adam'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates auxiliary state for a given weight.</p> |
| <p>Some optimizers require additional states, e.g. momentum, in addition |
| to gradients in order to update weights. This function creates state |
| for a given weight which will be used in <cite>update</cite>. This function is |
| called only once for each weight.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>index</strong> (<em>int</em>) – A unique index to identify the weight.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The weight.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>state</strong> – |
| The state associated with the weight.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">any obj</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the given parameter using the corresponding gradient and state.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>index</strong> (<em>int</em>) – The unique index of the parameter into the individual learning |
| rates and weight decays. Learning rates and weight decay |
| may be set via <cite>set_lr_mult()</cite> and <cite>set_wd_mult()</cite>, respectively.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The parameter to be updated.</li> |
| <li><strong>grad</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The gradient of the objective with respect to this parameter.</li> |
| <li><strong>state</strong> (<em>any obj</em>) – The state returned by <cite>create_state()</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_scale"> |
| <code class="descname">set_lr_scale</code><span class="sig-paren">(</span><em>args_lrscale</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_scale" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Sets lr scale. Use set_lr_mult instead.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_mult"> |
| <code class="descname">set_lr_mult</code><span class="sig-paren">(</span><em>args_lr_mult</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual learning rate multiplier for each parameter.</p> |
| <p>If you specify a learning rate multiplier for a parameter, then |
| the learning rate for the parameter will be set as the product of |
| the global learning rate <cite>self.lr</cite> and its multiplier.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default learning rate multiplier of a <cite>Variable</cite> |
| can be set with <cite>lr_mult</cite> argument in the constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_lr_mult</strong> (<em>dict of str/int to float</em>) – <p>For each of its key-value entries, the learning rate multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_lr_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_wd_mult"> |
| <code class="descname">set_wd_mult</code><span class="sig-paren">(</span><em>args_wd_mult</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_wd_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual weight decay multiplier for each parameter.</p> |
| <p>By default, if <cite>param_idx2name</cite> was provided in the |
| constructor, the weight decay multiplier is set as 0 for all |
| parameters whose names don’t end with <code class="docutils literal"><span class="pre">_weight</span></code> or |
| <code class="docutils literal"><span class="pre">_gamma</span></code>.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default weight decay multiplier for a <cite>Variable</cite> |
| can be set with its <cite>wd_mult</cite> argument in the constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_wd_mult</strong> (<em>dict of string/int to float</em>) – <p>For each of its key-value entries, the weight decay multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_wd_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.register"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>multi_precision=False</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.SGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The SGD optimizer with momentum and weight decay.</p> |
| <p>The optimizer updates the weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">*</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="n">state</span> |
| </pre></div> |
| </div> |
| <p>Sparse updating is supported. For details of the update algorithm see |
| <a class="reference internal" href="ndarray.html#mxnet.ndarray.sgd_update" title="mxnet.ndarray.sgd_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_update</span></code></a> and <a class="reference internal" href="ndarray.html#mxnet.ndarray.sgd_mom_update" title="mxnet.ndarray.sgd_mom_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_mom_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default), |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients |
| in 32-bit precision even if actual weights used in the model have lower precision. |
| Turning this on can improve convergence and accuracy when training with float16.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.DCASGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">DCASGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>lamda=0.04</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.DCASGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The DCASGD optimizer.</p> |
| <p>This class implements the optimizer described in <em>Asynchronous Stochastic Gradient Descent |
| with Delay Compensation for Distributed Deep Learning</em>, |
| available at <a class="reference external" href="https://arxiv.org/abs/1609.08326">https://arxiv.org/abs/1609.08326</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>lamda</strong> (<em>float, optional</em>) – Scale DC value.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.NAG"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">NAG</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.NAG" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Nesterov accelerated SGD.</p> |
| <p>This optimizer updates each weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">grad</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="p">(</span><span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="n">grad</span> <span class="o">+</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>This optimizer accepts the same arguments as <a class="reference internal" href="#mxnet.optimizer.SGD" title="mxnet.optimizer.SGD"><code class="xref py py-class docutils literal"><span class="pre">SGD</span></code></a>.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGLD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGLD</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.SGLD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Stochastic Gradient Riemannian Langevin Dynamics.</p> |
| <p>This class implements the optimizer described in the paper <em>Stochastic Gradient |
| Riemannian Langevin Dynamics on the Probability Simplex</em>, available at |
| <a class="reference external" href="https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf">https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf</a>.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.ccSGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">ccSGD</code><span class="sig-paren">(</span><em>*args</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.ccSGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Same as <cite>SGD</cite>. Left here for backward compatibility.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Adam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Adam optimizer.</p> |
| <p>This class implements the optimizer described in <em>Adam: A Method for |
| Stochastic Optimization</em>, available at <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <p>For details of the update algorithm, see <code class="xref py py-class docutils literal"><span class="pre">ndarray.adam_update</span></code>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaGrad"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaGrad</code><span class="sig-paren">(</span><em>eps=1e-07</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.AdaGrad" title="Permalink to this definition">¶</a></dt> |
| <dd><p>AdaGrad optimizer.</p> |
| <p>This class implements the AdaGrad optimizer described in <em>Adaptive Subgradient |
| Methods for Online Learning and Stochastic Optimization</em>, and available at |
| <a class="reference external" href="http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>eps</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.RMSProp"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">RMSProp</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>gamma1=0.9</em>, <em>gamma2=0.9</em>, <em>epsilon=1e-08</em>, <em>centered=False</em>, <em>clip_weights=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.RMSProp" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The RMSProp optimizer.</p> |
| <p>Two versions of RMSProp are implemented:</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=False</span></code>, we follow |
| <a class="reference external" href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf</a> by |
| Tieleman &amp; Hinton, 2012. |
| For details of the update algorithm see <a class="reference internal" href="ndarray.html#mxnet.ndarray.rmsprop_update" title="mxnet.ndarray.rmsprop_update"><code class="xref py py-class docutils literal"><span class="pre">rmsprop_update</span></code></a>.</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=True</span></code>, we follow <a class="reference external" href="http://arxiv.org/pdf/1308.0850v5.pdf">http://arxiv.org/pdf/1308.0850v5.pdf</a> (38)-(45) |
| by Alex Graves, 2013. |
| For details of the update algorithm see <a class="reference internal" href="ndarray.html#mxnet.ndarray.rmspropalex_update" title="mxnet.ndarray.rmspropalex_update"><code class="xref py py-class docutils literal"><span class="pre">rmspropalex_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>gamma1</strong> (<em>float, optional</em>) – A decay factor of moving average over past squared gradient.</li> |
| <li><strong>gamma2</strong> (<em>float, optional</em>) – A “momentum” factor. Only used if <code class="docutils literal"><span class="pre">centered=True</span></code>.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>centered</strong> (<em>bool, optional</em>) – Flag to control which version of RMSProp to use. |
| <code class="docutils literal"><span class="pre">True</span></code> will use Graves’s version of <cite>RMSProp</cite>, |
| <code class="docutils literal"><span class="pre">False</span></code> will use Tieleman &amp; Hinton’s version of <cite>RMSProp</cite>.</li> |
| <li><strong>clip_weights</strong> (<em>float, optional</em>) – Clips weights into range <code class="docutils literal"><span class="pre">[-clip_weights,</span> <span class="pre">clip_weights]</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaDelta"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaDelta</code><span class="sig-paren">(</span><em>rho=0.9</em>, <em>epsilon=1e-05</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.AdaDelta" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaDelta optimizer.</p> |
| <p>This class implements AdaDelta, an optimizer described in <em>ADADELTA: An adaptive |
| learning rate method</em>, available at <a class="reference external" href="https://arxiv.org/abs/1212.5701">https://arxiv.org/abs/1212.5701</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rho</strong> (<em>float</em>) – Decay rate for both squared gradients and delta.</li> |
| <li><strong>epsilon</strong> (<em>float</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Ftrl"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Ftrl</code><span class="sig-paren">(</span><em>lamda1=0.01</em>, <em>learning_rate=0.1</em>, <em>beta=1</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Ftrl" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Ftrl optimizer.</p> |
| <p>Referenced from <em>Ad Click Prediction: a View from the Trenches</em>, available at |
| <a class="reference external" href="http://dl.acm.org/citation.cfm?id=2488200">http://dl.acm.org/citation.cfm?id=2488200</a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>lamda1</strong> (<em>float, optional</em>) – L1 regularization coefficient.</li> |
| <li><strong>learning_rate</strong> (<em>float, optional</em>) – The initial learning rate.</li> |
| <li><strong>beta</strong> (<em>float, optional</em>) – Per-coordinate learning rate correlation parameter.</li> |
| <li><strong>eta</strong> – <div class="math"> |
| \[\eta_{t,i} = \frac{learningrate}{\beta+\sqrt{\sum_{s=1}^t g_{s,i}^2}}\] |
| </div> |
| </li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adamax"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adamax</code><span class="sig-paren">(</span><em>learning_rate=0.002</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Adamax" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaMax optimizer.</p> |
| <p>It is a variant of Adam based on the infinity norm, described in Section 7 of |
| <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Nadam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Nadam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>schedule_decay=0.004</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Nadam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Nesterov Adam optimizer.</p> |
| <p>Much like Adam is essentially RMSprop with momentum, |
| Nadam is essentially Adam with Nesterov momentum. It is described |
| at <a class="reference external" href="https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ">https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>schedule_decay</strong> (<em>float, optional</em>) – Exponential decay rate for the momentum schedule</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Test"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Test</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Test optimizer</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates a state to duplicate weight.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Performs w += rescale_grad * grad.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.create"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">create</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.SGD'&gt;</span> |
| <span class="gp">>>> </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.Adam'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Updater"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updater for kvstore.</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>index</em>, <em>grad</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates weight given gradient and index.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.set_states"> |
| <code class="descname">set_states</code><span class="sig-paren">(</span><em>states</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.set_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets updater states.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.get_states"> |
| <code class="descname">get_states</code><span class="sig-paren">(</span><em>dump_optimizer=False</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.get_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gets updater states.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>dump_optimizer</strong> (<em>bool, default False</em>) – Whether to also save the optimizer itself. This would also save optimizer |
| information such as learning rate and weight decay schedules.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.get_updater"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">get_updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.get_updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns a closure of the updater needed for kvstore.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>optimizer</strong> (<a class="reference internal" href="#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><em>Optimizer</em></a>) – The optimizer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>updater</strong> – |
| The closure of the updater.</td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">function</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <span class="target" id="module-mxnet.lr_scheduler"></span><p>Scheduling learning rate.</p> |
| <dl class="class"> |
| <dt id="mxnet.lr_scheduler.LRScheduler"> |
| <em class="property">class </em><code class="descclassname">mxnet.lr_scheduler.</code><code class="descname">LRScheduler</code><span class="sig-paren">(</span><em>base_lr=0.01</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.lr_scheduler.LRScheduler" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Base class of a learning rate scheduler.</p> |
| <p>A scheduler returns a new learning rate based on the number of updates that have |
| been performed.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>base_lr</strong> (<em>float, optional</em>) – The initial learning rate.</td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.lr_scheduler.LRScheduler.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>num_update</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.lr_scheduler.LRScheduler.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Return a new learning rate.</p> |
| <p>The <code class="docutils literal"><span class="pre">num_update</span></code> is the upper bound of the number of updates applied to |
| every weight.</p> |
| <p>Assume the optimizer has updated the <em>i</em>-th weight <em>k_i</em> times, namely |
| <code class="docutils literal"><span class="pre">optimizer.update(i,</span> <span class="pre">weight_i)</span></code> has been called <em>k_i</em> times. Then:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span>num_update = max([k_i for all i]) |
| </pre></div> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>num_update</strong> (<em>int</em>) – the maximal number of updates applied to a weight.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.lr_scheduler.FactorScheduler"> |
| <em class="property">class </em><code class="descclassname">mxnet.lr_scheduler.</code><code class="descname">FactorScheduler</code><span class="sig-paren">(</span><em>step</em>, <em>factor=1</em>, <em>stop_factor_lr=1e-08</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.lr_scheduler.FactorScheduler" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Reduce the learning rate by a factor for every <em>n</em> steps.</p> |
| <p>It returns a new learning rate by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">base_lr</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">factor</span><span class="p">,</span> <span class="n">floor</span><span class="p">(</span><span class="n">num_update</span><span class="o">/</span><span class="n">step</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>step</strong> (<em>int</em>) – Changes the learning rate for every n updates.</li> |
| <li><strong>factor</strong> (<em>float, optional</em>) – The factor to change the learning rate.</li> |
| <li><strong>stop_factor_lr</strong> (<em>float, optional</em>) – Stop updating the learning rate if it is less than this value.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.lr_scheduler.MultiFactorScheduler"> |
| <em class="property">class </em><code class="descclassname">mxnet.lr_scheduler.</code><code class="descname">MultiFactorScheduler</code><span class="sig-paren">(</span><em>step</em>, <em>factor=1</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.lr_scheduler.MultiFactorScheduler" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Reduce the learning rate according to a given list of steps.</p> |
| <p>Assume there exists <em>k</em> such that:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">step</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">num_update</span> <span class="ow">and</span> <span class="n">num_update</span> <span class="o">&lt;</span> <span class="n">step</span><span class="p">[</span><span class="n">k</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| <p>Then calculate the new learning rate by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">base_lr</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">factor</span><span class="p">,</span> <span class="n">k</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>step</strong> (<em>list of int</em>) – The list of steps to schedule a change</li> |
| <li><strong>factor</strong> (<em>float</em>) – The factor to change the learning rate.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <span class="target" id="module-mxnet.initializer"></span><p>Weight initializer.</p> |
| <dl class="class"> |
| <dt id="mxnet.initializer.InitDesc"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">InitDesc</code><a class="headerlink" href="#mxnet.initializer.InitDesc" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Descriptor for the initialization pattern.</p> |
| <dl class="docutils"> |
| <dt>name <span class="classifier-delimiter">:</span> <span class="classifier">str</span></dt> |
| <dd>Name of variable.</dd> |
| <dt>attrs <span class="classifier-delimiter">:</span> <span class="classifier">dict of str to str</span></dt> |
| <dd>Attributes of this variable taken from <code class="docutils literal"><span class="pre">Symbol.attr_dict</span></code>.</dd> |
| <dt>global_init <span class="classifier-delimiter">:</span> <span class="classifier">Initializer</span></dt> |
| <dd>Global initializer to fall back to.</dd> |
| </dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Initializer"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Initializer</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class of an initializer.</p> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.set_verbosity"> |
| <code class="descname">set_verbosity</code><span class="sig-paren">(</span><em>verbose=False</em>, <em>print_func=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.set_verbosity" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Switch on/off verbose mode</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>verbose</strong> (<em>bool</em>) – switch on/off verbose mode</li> |
| <li><strong>print_func</strong> (<em>function</em>) – A function that computes statistics of initialized arrays. |
| Takes an <cite>NDArray</cite> and returns an <cite>str</cite>. Defaults to mean |
| absolute value <code class="docutils literal"><span class="pre">str((|x|/size(x)).asscalar())</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.dumps"> |
| <code class="descname">dumps</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.dumps" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Saves the initializer to string</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">JSON formatted string that describes the initializer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">str</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create initializer and retrieve its parameters</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["normal", {"sigma": 0.5}]'</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Xavier</span><span class="p">(</span><span class="n">factor_type</span><span class="o">=</span><span class="s2">"in"</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mf">2.34</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["xavier", {"rnd_type": "uniform", "magnitude": 2.34, "factor_type": "in"}]'</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>desc</em>, <em>arr</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize an array</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>desc</strong> (<a class="reference internal" href="#mxnet.initializer.InitDesc" title="mxnet.initializer.InitDesc"><em>InitDesc</em></a>) – Initialization pattern descriptor.</li> |
| <li><strong>arr</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The array to be initialized.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.initializer.register"> |
| <code class="descclassname">mxnet.initializer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a custom initializer.</p> |
| <p>Custom initializers can be created by extending <cite>mx.init.Initializer</cite> and implementing the |
| required functions like <cite>_init_weight</cite> and <cite>_init_bias</cite>. The created initializer must be |
| registered using <cite>mx.init.register</cite> before it can be called by name.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>klass</strong> (<em>class</em>) – A subclass of <cite>mx.init.Initializer</cite> that needs to be registered as a custom initializer.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create and register a custom initializer that</span> |
| <span class="gp">... </span><span class="c1"># initializes weights to 0.1 and biases to 1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="nd">@mx.init.register</span> |
| <span class="gp">... </span><span class="nd">@alias</span><span class="p">(</span><span class="s1">'myinit'</span><span class="p">)</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">CustomInit</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Initializer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="nb">super</span><span class="p">(</span><span class="n">CustomInit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_weight</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mf">0.1</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">1</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="c1"># Module is an instance of 'mxnet.module.Module'</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="s2">"custominit"</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params("myinit")</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params(CustomInit())</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Load"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Load</code><span class="sig-paren">(</span><em>param</em>, <em>default_init=None</em>, <em>verbose=False</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Load" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes variables by loading data from file or dict.</p> |
| <p><strong>Note</strong> Load will drop <code class="docutils literal"><span class="pre">arg:</span></code> or <code class="docutils literal"><span class="pre">aux:</span></code> from name and |
| initialize the variables that match with the prefix dropped.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>param</strong> (<em>str or dict of str to NDArray</em>) – Parameter file or dict mapping name to NDArray.</li> |
| <li><strong>default_init</strong> (<a class="reference internal" href="#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Default initializer when name is not found in <cite>param</cite>.</li> |
| <li><strong>verbose</strong> (<em>bool</em>) – Flag for enabling logging of source when initializing.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Mixed"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Mixed</code><span class="sig-paren">(</span><em>patterns</em>, <em>initializers</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Mixed" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize parameters using multiple initializers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>patterns</strong> (<em>list of str</em>) – List of regular expressions matching parameter names.</li> |
| <li><strong>initializers</strong> (<em>list of Initializer</em>) – List of initializers corresponding to <cite>patterns</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize biases to zero</span> |
| <span class="gp">... </span><span class="c1"># and every other parameter to random values with uniform distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Mixed</span><span class="p">([</span><span class="s1">'bias'</span><span class="p">,</span> <span class="s1">'.*'</span><span class="p">],</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Zero</span><span class="p">(),</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)])</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="go">>>></span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected1_weight</span> |
| <span class="go">[[ 0.0097627 0.01856892 0.04303787]]</span> |
| <span class="go">fullyconnected1_bias</span> |
| <span class="go">[ 0.]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Zero"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Zero</code><a class="headerlink" href="#mxnet.initializer.Zero" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to zero.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to zero.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Zero</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0. 0. 0.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.One"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">One</code><a class="headerlink" href="#mxnet.initializer.One" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to one.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to one.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">One</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 1. 1. 1.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Constant"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Constant</code><span class="sig-paren">(</span><em>value</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Constant" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes the weights to a scalar value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>value</strong> (<em>float</em>) – Fill value.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Uniform"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Uniform</code><span class="sig-paren">(</span><em>scale=0.07</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Uniform" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values uniformly sampled from a given range.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>scale</strong> (<em>float, optional</em>) – The bound on the range of the generated random values. |
| Values are generated from the range [-<cite>scale</cite>, <cite>scale</cite>]. |
| Default scale is 0.07.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values uniformly sampled between -0.1 and 0.1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0.01360891 -0.02144304 0.08511933]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Normal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Normal</code><span class="sig-paren">(</span><em>sigma=0.01</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Normal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values sampled from a normal distribution |
| with a mean of zero and standard deviation of <cite>sigma</cite>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>sigma</strong> (<em>float, optional</em>) – Standard deviation of the normal distribution. |
| Default standard deviation is 0.01.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values sampled from a normal distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[-0.3214761 -0.12660924 0.53789419]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Orthogonal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Orthogonal</code><span class="sig-paren">(</span><em>scale=1.414</em>, <em>rand_type='uniform'</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Orthogonal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights as an orthogonal matrix.</p> |
| <p>This initializer implements <em>Exact solutions to the nonlinear dynamics of |
| learning in deep linear neural networks</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1312.6120">https://arxiv.org/abs/1312.6120</a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>scale</strong> (<em>float, optional</em>) – Scaling factor of the weight.</li> |
| <li><strong>rand_type</strong> (<em>string, optional</em>) – Use “uniform” or “normal” random numbers to initialize the weight.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Xavier"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Xavier</code><span class="sig-paren">(</span><em>rnd_type='uniform'</em>, <em>factor_type='avg'</em>, <em>magnitude=3</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Xavier" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns an initializer performing “Xavier” initialization for weights.</p> |
| <p>This initializer is designed to keep the scale of gradients roughly the same |
| in all layers.</p> |
| <p>By default, <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with random numbers in the range |
| of <span class="math">\([-c, c]\)</span>, where <span class="math">\(c = \sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>. |
| <span class="math">\(n_{in}\)</span> is the number of neurons feeding into weights, and <span class="math">\(n_{out}\)</span> is |
| the number of neurons the result is fed to.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'in'</span></code>, |
| then <span class="math">\(c = \sqrt{\frac{3.}{n_{in}}}\)</span>. |
| Similarly, when <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'out'</span></code>, <span class="math">\(c = \sqrt{\frac{3.}{n_{out}}}\)</span>.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'gaussian'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with numbers from normal distribution with |
| a standard deviation of <span class="math">\(\sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rnd_type</strong> (<em>str, optional</em>) – Random generator type, can be <code class="docutils literal"><span class="pre">'gaussian'</span></code> or <code class="docutils literal"><span class="pre">'uniform'</span></code>.</li> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>magnitude</strong> (<em>float, optional</em>) – Scale of random number.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.MSRAPrelu"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">MSRAPrelu</code><span class="sig-paren">(</span><em>factor_type='avg'</em>, <em>slope=0.25</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.MSRAPrelu" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes the weights according to an MSRA paper.</p> |
| <p>This initializer implements <em>Delving Deep into Rectifiers: Surpassing |
| Human-Level Performance on ImageNet Classification</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1502.01852">https://arxiv.org/abs/1502.01852</a>.</p> |
| <p>This initializer is proposed for initialization related to ReLU activation; |
| it makes some changes on top of the Xavier method.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>slope</strong> (<em>float, optional</em>) – initial slope of any PReLU (or similar) nonlinearities.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Bilinear"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Bilinear</code><a class="headerlink" href="#mxnet.initializer.Bilinear" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize weight for upsampling layers.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.LSTMBias"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">LSTMBias</code><span class="sig-paren">(</span><em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.LSTMBias" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes all biases of an LSTMCell to 0.0 except for |
| the forget gate, whose bias is set to a custom value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>forget_bias</strong> (<em>float, default 1.0</em>) – Bias for the forget gate. Jozefowicz et al. (2015) recommend |
| setting this to 1.0.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.FusedRNN"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">FusedRNN</code><span class="sig-paren">(</span><em>init</em>, <em>num_hidden</em>, <em>num_layers</em>, <em>mode</em>, <em>bidirectional=False</em>, <em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.FusedRNN" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes parameters for fused RNN layers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>init</strong> (<a class="reference internal" href="#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Initializer applied to unpacked weights. Falls back to the global |
| initializer if None.</li> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>num_layers</strong> (<em>int</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>mode</strong> (<em>str</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>bidirectional</strong> (<em>bool</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>forget_bias</strong> (<em>float</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <script>auto_index("api-reference");</script></div> |
| </div> |
| <div class="container"> |
| <div class="footer"> |
| <p> </p> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">Optimization: initialize and update weights</a><ul> |
| <li><a class="reference internal" href="#overview">Overview</a></li> |
| <li><a class="reference internal" href="#the-mxnet-initializer-package">The <code class="docutils literal"><span class="pre">mxnet.initializer</span></code> package</a></li> |
| <li><a class="reference internal" href="#the-mxnet-optimizer-package">The <code class="docutils literal"><span class="pre">mxnet.optimizer</span></code> package</a></li> |
| <li><a class="reference internal" href="#the-mxnet-lr-scheduler-package">The <code class="docutils literal"><span class="pre">mxnet.lr_scheduler</span></code> package</a></li> |
| <li><a class="reference internal" href="#implement-a-new-algorithm">Implement a new algorithm</a></li> |
| <li><a class="reference internal" href="#api-reference">API Reference</a></li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../../_static/js/search.js" type="text/javascript"></script> |
| <script src="../../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../../_static/js/copycode.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </div></body> |
| </html> |