| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <title>Model API — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../../_static/basic.css" rel="stylesheet" type="text/css"/> |
| <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"/> |
| <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"> |
| <script type="text/javascript"> |
| var DOCUMENTATION_OPTIONS = { |
| URL_ROOT: '../../', |
| VERSION: '', |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '' |
| }; |
| </script> |
| <script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script> |
| <script src="../../_static/underscore.js" type="text/javascript"></script> |
| <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../../_static/doctools.js" type="text/javascript"></script> |
| <script src="../../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new |
| Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| |
| </script> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"> |
| </link></link></head> |
| <body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document"> |
| <div class="content-block"><div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="../../install/index.html">Install</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../../gluon/index.html">About</a></li> |
| <li><a class="main-nav-link" href="http://gluon.mxnet.io">Tutorials</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li> |
| <li><a class="main-nav-link" href="../../api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-docs"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs"> |
| <li><a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a></li> |
| <li><a class="main-nav-link" href="../../faq/index.html">FAQ</a></li> |
| <li><a class="main-nav-link" href="../../architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.1.0/example">Examples</a></li> |
| <li><a class="main-nav-link" href="../../api/python/gluon/model_zoo.html">Gluon Model Zoo</a></li> |
| </ul> |
| </span> |
| <a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a> |
| <span id="dropdown-menu-position-anchor-community"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community"> |
| <li><a class="main-nav-link" href="../../community/index.html">Community</a></li> |
| <li><a class="main-nav-link" href="../../community/contribute.html">Contribute</a></li> |
| <li><a class="main-nav-link" href="../../community/powered_by.html">Powered By</a></li> |
| <li><a class="main-nav-link" href="http://discuss.mxnet.io">Discuss</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/>1.1.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></span></nav> |
| <script> function getRootPath(){ return "../../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu" id="burgerMenu"> |
| <li><a href="../../install/index.html">Install</a></li> |
| <li><a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a></li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">Community</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../../community/index.html" tabindex="-1">Community</a></li> |
| <li><a href="../../community/contribute.html" tabindex="-1">Contribute</a></li> |
| <li><a href="../../community/powered_by.html" tabindex="-1">Powered By</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../../api/python/index.html" tabindex="-1">Python</a> |
| </li> |
| <li><a href="../../api/scala/index.html" tabindex="-1">Scala</a> |
| </li> |
| <li><a href="../../api/r/index.html" tabindex="-1">R</a> |
| </li> |
| <li><a href="../../api/julia/index.html" tabindex="-1">Julia</a> |
| </li> |
| <li><a href="../../api/c++/index.html" tabindex="-1">C++</a> |
| </li> |
| <li><a href="../../api/perl/index.html" tabindex="-1">Perl</a> |
| </li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">Docs</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../../tutorials/index.html" tabindex="-1">Tutorials</a></li> |
| <li><a href="../../faq/index.html" tabindex="-1">FAQ</a></li> |
| <li><a href="../../architecture/index.html" tabindex="-1">Architecture</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/tree/1.1.0/example" tabindex="-1">Examples</a></li> |
| <li><a href="../../api/python/gluon/model_zoo.html" tabindex="-1">Gluon Model Zoo</a></li> |
| </ul> |
| </li> |
| <li><a href="../../architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/>1.1.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"/> |
| <input name="area" type="hidden" value="default"/> |
| </form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| </div> |
| </div> |
| </div> |
| <script type="text/javascript"> |
| $('body').css('background', 'white'); |
| </script> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="index.html">Python Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../r/index.html">R Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../julia/index.html">Julia Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../c++/index.html">C++ Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../scala/index.html">Scala Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../perl/index.html">Perl Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">HowTo Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../community/index.html">Community</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="page-tracker"></div> |
| <div class="section" id="model-api"> |
| <span id="model-api"></span><h1>Model API<a class="headerlink" href="#model-api" title="Permalink to this headline">¶</a></h1> |
| <p>The model API provides a simplified way to train neural networks using common best practices. |
| It’s a thin wrapper built on top of the <a class="reference internal" href="ndarray/ndarray.html"><em>ndarray</em></a> and <a class="reference internal" href="symbol/symbol.html"><em>symbolic</em></a> |
| modules that makes neural network training easy.</p> |
| <p>Topics:</p> |
| <ul class="simple"> |
| <li><a class="reference external" href="#train-the-model">Train the Model</a></li> |
| <li><a class="reference external" href="#save-the-model">Save the Model</a></li> |
| <li><a class="reference external" href="#periodic-checkpointing">Periodic Checkpointing</a></li> |
| <li><a class="reference external" href="#use-multiple-devices">Use Multiple Devices</a></li> |
| <li><a class="reference external" href="#initializer-api-reference">Initializer API Reference</a></li> |
| <li><a class="reference external" href="#evaluation-metric-api-reference">Evaluation Metric API Reference</a></li> |
| <li><a class="reference external" href="#optimizer-api-reference">Optimizer API Reference</a></li> |
| <li><a class="reference external" href="#model-api-reference">Model API Reference</a></li> |
| </ul> |
| <div class="section" id="train-the-model"> |
| <span id="train-the-model"></span><h2>Train the Model<a class="headerlink" href="#train-the-model" title="Permalink to this headline">¶</a></h2> |
| <p>To train a model, perform two steps: configure the network as a symbol, |
| then pass that symbol to <code class="docutils literal"><span class="pre">model.FeedForward.create</span></code> to create the model. |
| The following example creates a two-layer neural network.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span> |
| <span class="c1"># configure a two-layer neural network</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span> |
| <span class="n">fc1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc1'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span> |
| <span class="n">act1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">fc1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'relu1'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span> |
| <span class="n">fc2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">act1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc2'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span> |
| <span class="n">softmax</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">fc2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'sm'</span><span class="p">)</span> |
| <span class="c1"># create a model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">num_epoch</span><span class="o">=</span><span class="n">num_epoch</span><span class="p">,</span> |
| <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>You can also use the scikit-learn style: construct the model first, then call its <code class="docutils literal"><span class="pre">fit</span></code> function to train it.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="c1"># create a model in the two-step scikit-learn style</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">num_epoch</span><span class="o">=</span><span class="n">num_epoch</span><span class="p">,</span> |
| <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span> |
| |
| <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>For more information, see <a class="reference external" href="#model-api-reference">Model API Reference</a>.</p> |
| </div> |
| <div class="section" id="save-the-model"> |
| <span id="save-the-model"></span><h2>Save the Model<a class="headerlink" href="#save-the-model" title="Permalink to this headline">¶</a></h2> |
| <p>After the job is done, save your work. |
| To save the model, you can directly pickle it with Python. |
| We also provide <code class="docutils literal"><span class="pre">save</span></code> and <code class="docutils literal"><span class="pre">load</span></code> functions.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="c1"># save a model to mymodel-symbol.json and mymodel-0100.params</span> |
| <span class="n">prefix</span> <span class="o">=</span> <span class="s1">'mymodel'</span> |
| <span class="n">iteration</span> <span class="o">=</span> <span class="mi">100</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span> |
| |
| <span class="c1"># load model back</span> |
| <span class="n">model_loaded</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The advantage of the <code class="docutils literal"><span class="pre">save</span></code> and <code class="docutils literal"><span class="pre">load</span></code> functions is that they are language agnostic. |
| You should also be able to save to and load from cloud storage, such as Amazon S3 and HDFS, directly.</p> |
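| <p>The file names in the comment above suggest a simple naming pattern. The following is a minimal sketch of that pattern; <code class="docutils literal"><span class="pre">checkpoint_filenames</span></code> is a hypothetical helper written for illustration, not part of MXNet, and the four-digit zero padding is inferred from the example name <code class="docutils literal"><span class="pre">mymodel-0100.params</span></code>.</p> |

```python
def checkpoint_filenames(prefix, iteration):
    """Return the (symbol, params) file names for a saved model.

    Assumption: the pattern is inferred from the example above
    ('mymodel-symbol.json' and 'mymodel-0100.params'). This helper is
    hypothetical and is not an MXNet function.
    """
    symbol_file = '%s-symbol.json' % prefix
    # The iteration number is zero-padded to four digits.
    params_file = '%s-%04d.params' % (prefix, iteration)
    return symbol_file, params_file


print(checkpoint_filenames('mymodel', 100))
```

| <p>Knowing the pattern helps when you need to locate or copy the two files that belong to one saved model.</p> |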
| </div> |
| <div class="section" id="periodic-checkpointing"> |
| <span id="periodic-checkpointing"></span><h2>Periodic Checkpointing<a class="headerlink" href="#periodic-checkpointing" title="Permalink to this headline">¶</a></h2> |
| <p>We recommend checkpointing your model after each iteration. |
| To do this, pass the checkpoint callback <code class="docutils literal"><span class="pre">do_checkpoint(path)</span></code> to the training function. |
| The training process then automatically saves a checkpoint to the specified location after |
| each iteration.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="n">prefix</span><span class="o">=</span><span class="s1">'models/chkpt'</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">epoch_end_callback</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">do_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">),</span> |
| <span class="o">...</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>You can load the model checkpoint later using <code class="docutils literal"><span class="pre">FeedForward.load</span></code>.</p> |
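| <p>To illustrate how a checkpoint callback fits into a training loop, here is a minimal pure-Python sketch. It does not use MXNet: <code class="docutils literal"><span class="pre">make_checkpoint_callback</span></code>, <code class="docutils literal"><span class="pre">train</span></code>, and <code class="docutils literal"><span class="pre">save_fn</span></code> are hypothetical names, and the real callback may receive different arguments.</p> |

```python
def make_checkpoint_callback(prefix, save_fn):
    """Build a callback that saves a checkpoint after each pass over the data.

    Hypothetical stand-in for a checkpoint callback: `save_fn` is an
    assumed hook that receives (prefix, iteration) and writes the files.
    """
    def _callback(iteration, state):
        # The loop counts from 0, so the first checkpoint is numbered 1.
        save_fn(prefix, iteration + 1)
    return _callback


def train(num_iterations, end_callback=None):
    """Toy training loop that only shows when the callback fires."""
    state = {'updates': 0}
    for iteration in range(num_iterations):
        state['updates'] += 1  # placeholder for one pass over the data
        if end_callback is not None:
            end_callback(iteration, state)
    return state


saved = []
train(3, make_checkpoint_callback('models/chkpt',
                                  lambda prefix, it: saved.append((prefix, it))))
print(saved)
```

| <p>Because the callback is a closure over the prefix, the training loop itself does not need to know where checkpoints are written.</p> |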
| </div> |
| <div class="section" id="use-multiple-devices"> |
| <span id="use-multiple-devices"></span><h2>Use Multiple Devices<a class="headerlink" href="#use-multiple-devices" title="Permalink to this headline">¶</a></h2> |
| <p>Set <code class="docutils literal"><span class="pre">ctx</span></code> to the list of devices that you want to train on.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="n">devices</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_device</span><span class="p">)]</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">ctx</span><span class="o">=</span><span class="n">devices</span><span class="p">,</span> |
| <span class="o">...</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>Training occurs in parallel on the GPUs that you specify.</p> |
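| <p>One common way to train in parallel is data parallelism, where each batch is divided among the devices. The sketch below shows only that splitting step in plain Python. It assumes an even split and is an illustration, not MXNet's scheduling code; <code class="docutils literal"><span class="pre">split_batch</span></code> is a hypothetical helper.</p> |

```python
def split_batch(batch, num_device):
    """Split `batch` into `num_device` nearly equal contiguous slices.

    Assumption: an even split across devices. This illustrates data
    parallelism in general and is not taken from MXNet.
    """
    base, extra = divmod(len(batch), num_device)
    slices = []
    start = 0
    for i in range(num_device):
        # The first `extra` devices take one additional example.
        stop = start + base + (1 if i < extra else 0)
        slices.append(batch[start:stop])
        start = stop
    return slices


print(split_batch(list(range(10)), 4))
```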
| <blockquote> |
| <div><script src="../../_static/js/auto_module_index.js" type="text/javascript"></script></div></blockquote> |
| </div> |
| <div class="section" id="initializer-api-reference"> |
| <span id="initializer-api-reference"></span><h2>Initializer API Reference<a class="headerlink" href="#initializer-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.initializer"></span><p>Weight initializer.</p> |
| <dl class="class"> |
| <dt id="mxnet.initializer.InitDesc"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">InitDesc</code><a class="reference internal" href="../../_modules/mxnet/initializer.html#InitDesc"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.InitDesc" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Descriptor for the initialization pattern.</p> |
| <dl class="docutils"> |
| <dt>name <span class="classifier-delimiter">:</span> <span class="classifier">str</span></dt> |
| <dd>Name of variable.</dd> |
| <dt>attrs <span class="classifier-delimiter">:</span> <span class="classifier">dict of str to str</span></dt> |
| <dd>Attributes of this variable taken from <code class="docutils literal"><span class="pre">Symbol.attr_dict</span></code>.</dd> |
| <dt>global_init <span class="classifier-delimiter">:</span> <span class="classifier">Initializer</span></dt> |
| <dd>Global initializer to fallback to.</dd> |
| </dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Initializer"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Initializer</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Initializer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Initializer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class of an initializer.</p> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.set_verbosity"> |
| <code class="descname">set_verbosity</code><span class="sig-paren">(</span><em>verbose=False</em>, <em>print_func=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Initializer.set_verbosity"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Initializer.set_verbosity" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Switch on/off verbose mode</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>verbose</strong> (<em>bool</em>) – switch on/off verbose mode</li> |
| <li><strong>print_func</strong> (<em>function</em>) – A function that computes statistics of initialized arrays. |
| Takes an <cite>NDArray</cite> and returns a <cite>str</cite>. Defaults to mean |
| absolute value <code class="docutils literal"><span class="pre">str((|x|/size(x)).asscalar())</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.dumps"> |
| <code class="descname">dumps</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Initializer.dumps"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Initializer.dumps" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Saves the initializer to string</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">JSON formatted string that describes the initializer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">str</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create initializer and retrieve its parameters</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["normal", {"sigma": 0.5}]'</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Xavier</span><span class="p">(</span><span class="n">factor_type</span><span class="o">=</span><span class="s2">"in"</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mf">2.34</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["xavier", {"rnd_type": "uniform", "magnitude": 2.34, "factor_type": "in"}]'</span> |
| </pre></div> |
| </div> |
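| <p>Because the returned string is JSON, it can be parsed with the standard <code class="docutils literal"><span class="pre">json</span></code> module. The sketch below parses the two strings shown above; it assumes the two-element <code class="docutils literal"><span class="pre">[name,</span> <span class="pre">kwargs]</span></code> layout seen in those examples and does not require MXNet.</p> |

```python
import json

# Strings copied from the examples above.
dumped_normal = '["normal", {"sigma": 0.5}]'
dumped_xavier = ('["xavier", {"rnd_type": "uniform", '
                 '"magnitude": 2.34, "factor_type": "in"}]')

for dumped in (dumped_normal, dumped_xavier):
    # Each string decodes to a two-element list: a name and its keyword arguments.
    name, kwargs = json.loads(dumped)
    print(name, sorted(kwargs.items()))
```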
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>desc</em>, <em>arr</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Initializer.__call__"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Initializer.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize an array</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>desc</strong> (<a class="reference internal" href="optimization/optimization.html#mxnet.initializer.InitDesc" title="mxnet.initializer.InitDesc"><em>InitDesc</em></a>) – Initialization pattern descriptor.</li> |
| <li><strong>arr</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The array to be initialized.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.initializer.register"> |
| <code class="descclassname">mxnet.initializer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#register"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a custom initializer.</p> |
| <p>Custom initializers can be created by extending <cite>mx.init.Initializer</cite> and implementing the |
| required functions like <cite>_init_weight</cite> and <cite>_init_bias</cite>. The created initializer must be |
| registered using <cite>mx.init.register</cite> before it can be called by name.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>klass</strong> (<em>class</em>) – A subclass of <cite>mx.init.Initializer</cite> that needs to be registered as a custom initializer.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create and register a custom initializer that</span> |
| <span class="gp">... </span><span class="c1"># initializes weights to 0.1 and biases to 1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="nd">@mx.init.register</span> |
| <span class="gp">... </span><span class="nd">@alias</span><span class="p">(</span><span class="s1">'myinit'</span><span class="p">)</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">CustomInit</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Initializer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="nb">super</span><span class="p">(</span><span class="n">CustomInit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_weight</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mf">0.1</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">1</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="c1"># 'module' is an instance of 'mxnet.module.Module'</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="s2">"custominit"</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params("myinit")</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params(CustomInit())</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Load"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Load</code><span class="sig-paren">(</span><em>param</em>, <em>default_init=None</em>, <em>verbose=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Load"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Load" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes variables by loading data from file or dict.</p> |
| <p><strong>Note</strong> Load drops the <code class="docutils literal"><span class="pre">arg:</span></code> or <code class="docutils literal"><span class="pre">aux:</span></code> prefix from each name and |
| initializes the variables whose names match once the prefix is dropped.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>param</strong> (<em>str or dict of str->`NDArray`</em>) – Parameter file or dict mapping name to NDArray.</li> |
| <li><strong>default_init</strong> (<a class="reference internal" href="optimization/optimization.html#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Default initializer when name is not found in <cite>param</cite>.</li> |
| <li><strong>verbose</strong> (<em>bool</em>) – Flag for enabling logging of source when initializing.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
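| <p>The prefix handling and the fallback to <cite>default_init</cite> can be illustrated with plain dictionaries. This is a minimal sketch under stated assumptions: the helper names <cite>strip_prefix</cite> and <cite>resolve</cite> are hypothetical, Python lists stand in for NDArrays, and raising <cite>KeyError</cite> when no default is given is a choice made here, not documented library behavior.</p> |

```python
def strip_prefix(params):
    """Return a copy of params with a leading 'arg:' or 'aux:' removed from each key."""
    cleaned = {}
    for name, value in params.items():
        if name.startswith(("arg:", "aux:")):
            name = name[4:]  # both prefixes are four characters long
        cleaned[name] = value
    return cleaned


def resolve(cleaned, name, default_init=None):
    """Return the saved value for name, or fall back to default_init(name)."""
    if name in cleaned:
        return cleaned[name]
    if default_init is None:
        raise KeyError("no saved value and no default_init for %r" % name)
    return default_init(name)


# Hypothetical saved parameters; lists stand in for NDArrays.
saved = {"arg:fc1_weight": [0.1, 0.2], "aux:bn_moving_mean": [0.0]}
cleaned = strip_prefix(saved)

print(sorted(cleaned))
print(resolve(cleaned, "fc1_bias", default_init=lambda name: [0.0]))
```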
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Mixed"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Mixed</code><span class="sig-paren">(</span><em>patterns</em>, <em>initializers</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Mixed"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Mixed" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes parameters using multiple initializers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>patterns</strong> (<em>list of str</em>) – List of regular expressions matching parameter names.</li> |
| <li><strong>initializers</strong> (<em>list of Initializer</em>) – List of initializers corresponding to <cite>patterns</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize biases to zero</span> |
| <span class="gp">... </span><span class="c1"># and every other parameter to random values with uniform distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Mixed</span><span class="p">([</span><span class="s1">'bias'</span><span class="p">,</span> <span class="s1">'.*'</span><span class="p">],</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Zero</span><span class="p">(),</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)])</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="go">>>></span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected1_weight</span> |
| <span class="go">[[ 0.0097627 0.01856892 0.04303787]]</span> |
| <span class="go">fullyconnected1_bias</span> |
| <span class="go">[ 0.]</span> |
| </pre></div> |
| </div> |
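| <p>The pairing of patterns with initializers can be sketched as a simple dispatch loop. This is an illustration only, with several assumptions: the class <cite>MixedSketch</cite> is hypothetical, patterns are tried in order and the first match wins, matching uses <cite>re.search</cite> (a match anywhere in the name), and plain functions operating on Python lists stand in for initializers and NDArrays. The library's own matching rule may differ.</p> |

```python
import re


class MixedSketch:
    """Hypothetical dispatcher: the first pattern that matches a name picks the initializer."""

    def __init__(self, patterns, initializers):
        if len(patterns) != len(initializers):
            raise ValueError("patterns and initializers must have the same length")
        self._rules = [(re.compile(p), init) for p, init in zip(patterns, initializers)]

    def __call__(self, name, arr):
        for regex, init in self._rules:
            if regex.search(name):  # assumption: match anywhere in the name
                init(name, arr)
                return
        raise ValueError("no pattern matches parameter %r" % name)


def zero(name, arr):
    """Fill arr with zeros in place."""
    arr[:] = [0.0] * len(arr)


def constant_tenth(name, arr):
    """Fill arr with 0.1 in place (stands in for a random initializer)."""
    arr[:] = [0.1] * len(arr)


init = MixedSketch(["bias", ".*"], [zero, constant_tenth])
weight, bias = [None] * 3, [None]
init("fullyconnected1_weight", weight)
init("fullyconnected1_bias", bias)
print(weight, bias)
```

| <p>Placing the catch-all pattern <cite>'.*'</cite> last matters in this sketch: if it came first, it would match every name and the <cite>'bias'</cite> rule would never be reached.</p> |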
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Zero"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Zero</code><a class="reference internal" href="../../_modules/mxnet/initializer.html#Zero"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Zero" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to zero.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to zero.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Zero</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0. 0. 0.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.One"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">One</code><a class="reference internal" href="../../_modules/mxnet/initializer.html#One"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.One" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to one.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to one.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">One</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 1. 1. 1.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Constant"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Constant</code><span class="sig-paren">(</span><em>value</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Constant"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Constant" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes the weights to a scalar value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>value</strong> (<em>float</em>) – Fill value.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Uniform"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Uniform</code><span class="sig-paren">(</span><em>scale=0.07</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Uniform"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Uniform" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values uniformly sampled from a given range.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>scale</strong> (<em>float, optional</em>) – The bound on the range of the generated random values. |
| Values are generated from the range [-<cite>scale</cite>, <cite>scale</cite>]. |
| Default scale is 0.07.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values uniformly sampled between -0.1 and 0.1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0.01360891 -0.02144304 0.08511933]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Normal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Normal</code><span class="sig-paren">(</span><em>sigma=0.01</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Normal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Normal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values sampled from a normal distribution |
| with a mean of zero and standard deviation of <cite>sigma</cite>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>sigma</strong> (<em>float, optional</em>) – Standard deviation of the normal distribution. |
| Default standard deviation is 0.01.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values sampled from a normal distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[-0.3214761 -0.12660924 0.53789419]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Orthogonal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Orthogonal</code><span class="sig-paren">(</span><em>scale=1.414</em>, <em>rand_type='uniform'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Orthogonal"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Orthogonal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights as an orthogonal matrix.</p> |
| <p>This initializer implements <em>Exact solutions to the nonlinear dynamics of |
| learning in deep linear neural networks</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1312.6120">https://arxiv.org/abs/1312.6120</a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>scale</strong> (<em>float, optional</em>) – Scaling factor of the weight.</li> |
| <li><strong>rand_type</strong> (<em>str, optional</em>) – Distribution of the random numbers used to initialize the weight: “uniform” or “normal”.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Xavier"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Xavier</code><span class="sig-paren">(</span><em>rnd_type='uniform'</em>, <em>factor_type='avg'</em>, <em>magnitude=3</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#Xavier"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Xavier" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns an initializer performing “Xavier” initialization for weights.</p> |
| <p>This initializer is designed to keep the scale of gradients roughly the same |
| in all layers.</p> |
| <p>With the defaults, where <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with random numbers in the range |
| of <span class="math">\([-c, c]\)</span>, where <span class="math">\(c = \sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>. |
| <span class="math">\(n_{in}\)</span> is the number of neurons feeding into weights, and <span class="math">\(n_{out}\)</span> is |
| the number of neurons the result is fed to.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'in'</span></code>, |
| then <span class="math">\(c = \sqrt{\frac{3.}{n_{in}}}\)</span>. |
| Similarly, when <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'out'</span></code>, <span class="math">\(c = \sqrt{\frac{3.}{n_{out}}}\)</span>.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'gaussian'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with numbers from normal distribution with |
| a standard deviation of <span class="math">\(\sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rnd_type</strong> (<em>str, optional</em>) – Random generator type, can be <code class="docutils literal"><span class="pre">'gaussian'</span></code> or <code class="docutils literal"><span class="pre">'uniform'</span></code>.</li> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>magnitude</strong> (<em>float, optional</em>) – Scale of random number.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
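| <p>The formulas above can be checked with a few lines of arithmetic. This is a minimal sketch, not the library code: the function name <cite>xavier_scale</cite> is hypothetical, and it assumes that the constant 3 in the formulas is the default value of <cite>magnitude</cite>. The returned value is the bound <span class="math">\(c\)</span> for <code class="docutils literal"><span class="pre">'uniform'</span></code> and the standard deviation for <code class="docutils literal"><span class="pre">'gaussian'</span></code>.</p> |

```python
import math


def xavier_scale(n_in, n_out, factor_type="avg", magnitude=3.0):
    """Return sqrt(magnitude / factor) for the chosen factor_type."""
    if factor_type == "avg":
        factor = 0.5 * (n_in + n_out)
    elif factor_type == "in":
        factor = n_in
    elif factor_type == "out":
        factor = n_out
    else:
        raise ValueError("factor_type must be 'avg', 'in', or 'out'")
    return math.sqrt(magnitude / factor)


# A layer with 12 inputs and 48 outputs.
print(xavier_scale(12, 48))          # sqrt(3 / 30)
print(xavier_scale(12, 48, "in"))    # sqrt(3 / 12)
print(xavier_scale(12, 48, "out"))   # sqrt(3 / 48)
```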
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.MSRAPrelu"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">MSRAPrelu</code><span class="sig-paren">(</span><em>factor_type='avg'</em>, <em>slope=0.25</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#MSRAPrelu"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.MSRAPrelu" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes the weight according to an MSRA paper.</p> |
| <p>This initializer implements <em>Delving Deep into Rectifiers: Surpassing |
| Human-Level Performance on ImageNet Classification</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1502.01852">https://arxiv.org/abs/1502.01852</a>.</p> |
| <p>This initializer was proposed for networks that use ReLU activations; |
| it makes some changes on top of the Xavier method.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>slope</strong> (<em>float, optional</em>) – Initial slope of any PReLU (or similar) nonlinearities.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
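| <p>As a rough illustration of how <cite>slope</cite> could enter the scaling, the sketch below follows the rule from the cited paper, where the weight variance is proportional to <span class="math">\(2 / (1 + slope^2)\)</span>. Treat this as an assumption about the change made on top of Xavier rather than a statement of the library's code; the function name <cite>msra_std</cite> is hypothetical.</p> |

```python
import math


def msra_std(n_in, n_out, factor_type="avg", slope=0.25):
    """Return the standard deviation sqrt(magnitude / factor), with magnitude = 2 / (1 + slope^2)."""
    # Assumption: the PReLU slope only changes the magnitude used by the Xavier-style rule.
    magnitude = 2.0 / (1.0 + slope ** 2)
    if factor_type == "avg":
        factor = 0.5 * (n_in + n_out)
    elif factor_type == "in":
        factor = n_in
    elif factor_type == "out":
        factor = n_out
    else:
        raise ValueError("factor_type must be 'avg', 'in', or 'out'")
    return math.sqrt(magnitude / factor)


# slope=0 corresponds to a plain ReLU: the variance is 2 / n_in.
print(msra_std(8, 8, factor_type="in", slope=0.0))
print(msra_std(8, 8, factor_type="in", slope=0.25))
```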
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Bilinear"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Bilinear</code><a class="reference internal" href="../../_modules/mxnet/initializer.html#Bilinear"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.Bilinear" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights for upsampling layers.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.LSTMBias"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">LSTMBias</code><span class="sig-paren">(</span><em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#LSTMBias"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.LSTMBias" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes all biases of an LSTMCell to 0.0, except for |
| the forget gate, whose bias is set to a custom value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>forget_bias</strong> (<em>float, default 1.0</em>) – Bias for the forget gate. Jozefowicz et al. (2015) recommend |
| setting this to 1.0.</td> |
| </tr> |
| </tbody> |
| </table> |
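| <p>The resulting bias layout can be sketched with a plain Python list. This sketch assumes the four gate biases are concatenated into one vector of length <cite>4 * num_hidden</cite> and that the forget gate occupies the second block; both the layout and the function name <cite>lstm_bias</cite> are assumptions made for illustration.</p> |

```python
def lstm_bias(num_hidden, forget_bias=1.0):
    """Return a bias vector of zeros with the forget-gate block set to forget_bias."""
    bias = [0.0] * (4 * num_hidden)
    # Assumption: the forget gate is the second of the four concatenated gate blocks.
    for i in range(num_hidden, 2 * num_hidden):
        bias[i] = forget_bias
    return bias


print(lstm_bias(2))
```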
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.FusedRNN"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">FusedRNN</code><span class="sig-paren">(</span><em>init</em>, <em>num_hidden</em>, <em>num_layers</em>, <em>mode</em>, <em>bidirectional=False</em>, <em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/initializer.html#FusedRNN"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.initializer.FusedRNN" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes parameters for fused RNN layers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>init</strong> (<a class="reference internal" href="optimization/optimization.html#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Initializer applied to unpacked weights. Falls back to the global |
| initializer if None.</li> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Should match the corresponding argument passed to FusedRNNCell.</li> |
| <li><strong>num_layers</strong> (<em>int</em>) – Should match the corresponding argument passed to FusedRNNCell.</li> |
| <li><strong>mode</strong> (<em>str</em>) – Should match the corresponding argument passed to FusedRNNCell.</li> |
| <li><strong>bidirectional</strong> (<em>bool</em>) – Should match the corresponding argument passed to FusedRNNCell.</li> |
| <li><strong>forget_bias</strong> (<em>float</em>) – Should match the corresponding argument passed to FusedRNNCell.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <script>auto_index("initializer-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="evaluation-metric-api-reference"> |
| <span id="evaluation-metric-api-reference"></span><h2>Evaluation Metric API Reference<a class="headerlink" href="#evaluation-metric-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.metric"></span><p>Online evaluation metric module.</p> |
| <dl class="class"> |
| <dt id="mxnet.metric.EvalMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">EvalMetric</code><span class="sig-paren">(</span><em>name</em>, <em>output_names=None</em>, <em>label_names=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Base class for all evaluation metrics.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">This is a base class that provides common metric interfaces. |
| One should not use this class directly, but instead create new metric |
| classes that extend it.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get_config"> |
| <code class="descname">get_config</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.get_config"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.get_config" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Saves the configuration of the metric. The metric can be recreated |
| from this configuration with <code class="docutils literal"><span class="pre">metric.create(**config)</span></code>.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.update_dict"> |
| <code class="descname">update_dict</code><span class="sig-paren">(</span><em>label</em>, <em>pred</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.update_dict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.update_dict" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result with named labels and predictions.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>label</strong> (<em>OrderedDict of str -> NDArray</em>) – Name-to-array mapping for labels.</li> |
| <li><strong>pred</strong> (<em>OrderedDict of str -> NDArray</em>) – Name-to-array mapping of predicted outputs.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the internal evaluation result to initial state.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.get"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gets the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><ul class="simple"> |
| <li><strong>names</strong> (<em>list of str</em>) – |
| Name of the metrics.</li> |
| <li><strong>values</strong> (<em>list of float</em>) – |
| Value of the evaluations.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get_name_value"> |
| <code class="descname">get_name_value</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#EvalMetric.get_name_value"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.EvalMetric.get_name_value" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns zipped name and value pairs.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">A (name, value) tuple list.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">list of tuples</td> |
| </tr> |
| </tbody> |
| </table> |
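| <p>As a rough illustration of the pairing described here, the sketch below zips a list of names with a list of values in plain Python. It does not call mxnet; the helper name <code class="docutils literal"><span class="pre">name_value_pairs</span></code> and the wrapping of bare scalars into lists are assumptions made for the example.</p> |

```python
def name_value_pairs(names, values):
    """Pair metric names with values (hypothetical helper, not the mxnet API)."""
    # Assumption: a single metric may report a bare name and a bare value,
    # so wrap scalars in one-element lists before zipping.
    if not isinstance(names, list):
        names = [names]
    if not isinstance(values, list):
        values = [values]
    return list(zip(names, values))


pairs = name_value_pairs(["accuracy", "f1"], [0.75, 0.8])
single = name_value_pairs("accuracy", 0.75)
```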
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.metric.create"> |
| <code class="descclassname">mxnet.metric.</code><code class="descname">create</code><span class="sig-paren">(</span><em>metric</em>, <em>*args</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#create"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates an evaluation metric from a metric name, an instance of <cite>EvalMetric</cite>, |
| a list of these, or a custom metric function.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>metric</strong> (<em>str, callable, EvalMetric, or list</em>) – <p>Specifies the metric to create. |
| This argument must be one of the following:</p> |
| <ul> |
| <li>Name of a metric.</li> |
| <li>An instance of <cite>EvalMetric</cite>.</li> |
| <li>A list, each element of which is a metric or a metric name.</li> |
| <li>An evaluation function that computes custom metric for a given batch of |
| labels and predictions.</li> |
| </ul> |
| </li> |
| <li><strong>*args</strong> – <p>Additional positional arguments passed to the metric constructor. |
| Only used when metric is a str.</p> |
| </li> |
| <li><strong>**kwargs</strong> – <p>Additional keyword arguments passed to the metric constructor. |
| Only used when metric is a str.</p> |
| </li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> |
| <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span> |
| <span class="gp">>>> </span><span class="k">def</span> <span class="nf">custom_metric</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">):</span> |
| <span class="gp">... </span>    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">label</span> <span class="o">-</span> <span class="n">pred</span><span class="p">))</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">metric1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'acc'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">metric2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">custom_metric</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">metric3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">([</span><span class="n">metric1</span><span class="p">,</span> <span class="n">metric2</span><span class="p">,</span> <span class="s1">'rmse'</span><span class="p">])</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CompositeEvalMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CompositeEvalMetric</code><span class="sig-paren">(</span><em>metrics=None</em>, <em>name='composite'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Manages multiple evaluation metrics.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>metrics</strong> (<em>list of EvalMetric</em>) – List of child metrics.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics_1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics_2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">F1</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CompositeEvalMetric</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">child_metric</span> <span class="ow">in</span> <span class="p">[</span><span class="n">eval_metrics_1</span><span class="p">,</span> <span class="n">eval_metrics_2</span><span class="p">]:</span> |
| <span class="gp">... </span>    <span class="n">eval_metrics</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">child_metric</span><span class="p">)</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">(['accuracy', 'f1'], [0.6666666666666666, 0.8])</span> |
| </pre></div> |
| </div> |
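| <p>The composite pattern used in the example above can be sketched in plain Python as follows. This is only an illustration under stated assumptions: <code class="docutils literal"><span class="pre">MeanTracker</span></code> and <code class="docutils literal"><span class="pre">Composite</span></code> are hypothetical stand-ins that mirror the method names documented here, they do not use mxnet, and the actual implementation may differ.</p> |

```python
class MeanTracker:
    """Hypothetical child metric: running fraction of exact matches."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        # Return to the initial state: nothing seen yet.
        self.total = 0.0
        self.count = 0

    def update(self, labels, preds):
        # Toy rule: a sample scores 1.0 when label and prediction are equal.
        for label, pred in zip(labels, preds):
            self.total += 1.0 if label == pred else 0.0
            self.count += 1

    def get(self):
        # Assumption: report NaN until at least one sample has been seen.
        value = self.total / self.count if self.count else float("nan")
        return self.name, value


class Composite:
    """Hypothetical container: forwards calls to children, collects results."""

    def __init__(self):
        self.metrics = []

    def add(self, metric):
        self.metrics.append(metric)

    def update(self, labels, preds):
        for metric in self.metrics:
            metric.update(labels, preds)

    def reset(self):
        for metric in self.metrics:
            metric.reset()

    def get(self):
        names, values = [], []
        for metric in self.metrics:
            name, value = metric.get()
            names.append(name)
            values.append(value)
        return names, values


combo = Composite()
combo.add(MeanTracker("first"))
combo.add(MeanTracker("second"))
combo.update(labels=[0, 1, 1], preds=[1, 1, 1])
result = combo.get()  # one name and one value per child, in insertion order
```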
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.add"> |
| <code class="descname">add</code><span class="sig-paren">(</span><em>metric</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric.add"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.add" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Adds a child metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>metric</strong> – A metric instance.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.get_metric"> |
| <code class="descname">get_metric</code><span class="sig-paren">(</span><em>index</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric.get_metric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.get_metric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns a child metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>index</strong> (<em>int</em>) – Index of child metric in the list of metrics.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the internal evaluation result to initial state.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CompositeEvalMetric.get"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><ul class="simple"> |
| <li><strong>names</strong> (<em>list of str</em>) – |
| Name of the metrics.</li> |
| <li><strong>values</strong> (<em>list of float</em>) – |
| Value of the evaluations.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Accuracy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Accuracy</code><span class="sig-paren">(</span><em>axis=1</em>, <em>name='accuracy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Accuracy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Accuracy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes accuracy classification score.</p> |
| <p>The accuracy score is defined as</p> |
| <div class="math"> |
| \[\text{accuracy}(y, \hat{y}) = \frac{1}{n} \sum_{i=0}^{n-1} |
| \text{1}(\hat{y_i} == y_i)\]</div> |
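| <p>To make the formula concrete, here is a minimal plain-Python sketch (no mxnet involved) in which n is the number of samples, y_i is the true class index, and the predicted class is taken as the index of the largest score in each row. The helper names <code class="docutils literal"><span class="pre">argmax</span></code> and <code class="docutils literal"><span class="pre">accuracy</span></code> are assumptions for illustration only.</p> |

```python
def argmax(scores):
    """Index of the largest score in a list (first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(labels, preds):
    """Fraction of samples whose predicted class equals the label."""
    # Sum of the indicator 1(y_hat_i == y_i) over all samples.
    correct = sum(1 for y, row in zip(labels, preds) if argmax(row) == y)
    # Assumes at least one sample; an empty input would divide by zero.
    return correct / len(labels)


preds = [[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]]
labels = [0, 1, 1]
score = accuracy(labels, preds)  # 2 of the 3 predictions match
```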
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>axis</strong> (<em>int, default=1</em>) – The axis that represents classes.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">acc</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('accuracy', 0.6666666666666666)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.Accuracy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Accuracy.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Accuracy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data with class indices as values, one per sample.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Prediction values for samples. Each prediction value can either be the class index, |
| or a vector of likelihoods for all classes.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.TopKAccuracy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">TopKAccuracy</code><span class="sig-paren">(</span><em>top_k=1</em>, <em>name='top_k_accuracy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#TopKAccuracy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.TopKAccuracy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes top-k prediction accuracy.</p> |
| <p><cite>TopKAccuracy</cite> differs from <cite>Accuracy</cite> in that it considers a prediction |
| to be correct as long as the ground truth label is among the top K |
| predicted labels.</p> |
| <p>If <cite>top_k</cite> = <code class="docutils literal"><span class="pre">1</span></code>, then <cite>TopKAccuracy</cite> is identical to <cite>Accuracy</cite>.</p> |
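| <p>The rule described above can be sketched in plain Python as follows. This is an illustration only, not the mxnet implementation: the function name <code class="docutils literal"><span class="pre">top_k_accuracy</span></code>, the toy scores, and the tie-breaking by original index order are all assumptions.</p> |

```python
def top_k_accuracy(labels, scores, k):
    """Fraction of samples whose label is among the k highest-scoring classes."""
    hits = 0
    for label, row in zip(labels, scores):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


scores = [[0.1, 0.5, 0.4], [0.7, 0.2, 0.1], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
top1 = top_k_accuracy(labels, scores, k=1)  # same as plain accuracy
top2 = top_k_accuracy(labels, scores, k=2)  # second sample now counts as a hit
```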
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>top_k</strong> (<em>int</em>) – Number of highest-scoring predictions checked for the target label.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
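| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> |
| <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span> |
| <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">999</span><span class="p">)</span> |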
| <span class="gp">>>> </span><span class="n">top_k</span> <span class="o">=</span> <span class="mi">3</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">6</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">TopKAccuracy</span><span class="p">(</span><span class="n">top_k</span><span class="o">=</span><span class="n">top_k</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">acc</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('top_k_accuracy', 0.3)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.TopKAccuracy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#TopKAccuracy.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.TopKAccuracy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.F1"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">F1</code><span class="sig-paren">(</span><em>name='f1'</em>, <em>output_names=None</em>, <em>label_names=None</em>, <em>average='macro'</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#F1"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.F1" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes the F1 score of a binary classification problem.</p> |
| <p>The F1 score is the harmonic mean of precision and recall, |
| where the best value is 1.0 and the worst value is 0.0. The formula for the F1 score is:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">F1</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">precision</span> <span class="o">*</span> <span class="n">recall</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">precision</span> <span class="o">+</span> <span class="n">recall</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The formulas for precision and recall are:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">precision</span> <span class="o">=</span> <span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="n">true_positives</span> <span class="o">+</span> <span class="n">false_positives</span><span class="p">)</span> |
| <span class="n">recall</span> <span class="o">=</span> <span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="n">true_positives</span> <span class="o">+</span> <span class="n">false_negatives</span><span class="p">)</span> |
| </pre></div> |
| </div> |
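| <p>As a worked instance of the formulas above, the following plain-Python sketch counts true positives, false positives, and false negatives for binary labels and then combines them into an F1 score. The function name <code class="docutils literal"><span class="pre">f1_score</span></code> and the choice to return 0.0 when a denominator is zero are assumptions for this example, not documented mxnet behaviour.</p> |

```python
def f1_score(labels, predicted):
    """F1 for binary labels, where class 1 is treated as the positive class."""
    pairs = list(zip(labels, predicted))
    true_positives = sum(1 for y, p in pairs if y == 1 and p == 1)
    false_positives = sum(1 for y, p in pairs if y == 0 and p == 1)
    false_negatives = sum(1 for y, p in pairs if y == 1 and p == 0)

    # Assumption: treat an undefined ratio (zero denominator) as 0.0.
    denom_p = true_positives + false_positives
    denom_r = true_positives + false_negatives
    precision = true_positives / denom_p if denom_p else 0.0
    recall = true_positives / denom_r if denom_r else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)


labels = [0, 1, 1]
predicted = [1, 1, 1]  # predicted class index for each sample
score = f1_score(labels, predicted)  # precision 2/3, recall 1
```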
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">This F1 score only supports binary classification.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| <li><strong>average</strong> (<em>str, default 'macro'</em>) – Strategy to be used for aggregating across mini-batches. |
| “macro”: average the F1 scores for each batch. |
| “micro”: compute a single F1 score across all batches.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
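| <p>The difference between the two <code class="docutils literal"><span class="pre">average</span></code> strategies can be seen in a small plain-Python sketch. The per-batch counts below are made-up numbers, and <code class="docutils literal"><span class="pre">f1_from_counts</span></code> is a hypothetical helper; the sketch only illustrates the two aggregation rules as described in the parameter list and does not call mxnet.</p> |

```python
def f1_from_counts(tp, fp, fn):
    """F1 from true-positive, false-positive and false-negative counts."""
    # Assumption: an undefined ratio (zero denominator) is treated as 0.0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)


# Hypothetical (tp, fp, fn) counts for two mini-batches.
batches = [(2, 1, 0), (1, 0, 2)]

# "macro": compute F1 per batch, then average the scores.
macro = sum(f1_from_counts(*counts) for counts in batches) / len(batches)

# "micro": pool the counts over all batches, then compute a single F1.
total_tp = sum(counts[0] for counts in batches)
total_fp = sum(counts[1] for counts in batches)
total_fn = sum(counts[2] for counts in batches)
micro = f1_from_counts(total_tp, total_fp, total_fn)
```

| <p>With these counts the per-batch scores are 0.8 and 0.5, so the macro average is 0.65, while pooling the counts gives a micro score of about 0.667. The two strategies generally disagree whenever batches differ in their mix of errors.</p> |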
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">f1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">F1</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">f1</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">f1</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('f1', 0.8)</span> |
| </pre></div> |
| </div> |
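| <p>The two <code>average</code> strategies described above can be illustrated without MXNet. The following is a minimal pure-Python sketch, not the library's implementation: the helper names <code>f1_from_counts</code> and <code>count_batch</code> are hypothetical, class 1 is assumed to be the positive class, and the second batch is invented for illustration. The first batch is the one from the example above and gives an F1 score of about 0.8.</p> |

```python
def f1_from_counts(tp, fp, fn):
    """Binary F1 from true-positive, false-positive and false-negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def count_batch(probs, labels):
    """Count (tp, fp, fn) for class 1, predicting the class with the larger score."""
    tp = fp = fn = 0
    for row, label in zip(probs, labels):
        pred = 1 if row[1] > row[0] else 0
        if pred == 1 and label == 1:
            tp += 1
        elif pred == 1 and label == 0:
            fp += 1
        elif pred == 0 and label == 1:
            fn += 1
    return tp, fp, fn


# Batch 1 is the example above; batch 2 is made up for illustration.
batch1 = count_batch([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]], [0, 1, 1])
batch2 = count_batch([[0.9, 0.1], [0.2, 0.8]], [1, 1])

f1_batch1 = f1_from_counts(*batch1)
f1_batch2 = f1_from_counts(*batch2)

# "macro": average the per-batch F1 scores.
macro_f1 = (f1_batch1 + f1_batch2) / 2

# "micro": pool the counts over all batches, then compute a single F1 score.
pooled = tuple(a + b for a, b in zip(batch1, batch2))
micro_f1 = f1_from_counts(*pooled)

print(batch1, f1_batch1)
print(macro_f1, micro_f1)
```

| <p>In this made-up setting the two strategies disagree (roughly 0.733 for macro against 0.75 for micro), because macro weights every batch equally while micro weights every sample equally.</p> |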
| <dl class="method"> |
| <dt id="mxnet.metric.F1.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#F1.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.F1.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.F1.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#F1.reset"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.F1.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the internal evaluation result to initial state.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Perplexity"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Perplexity</code><span class="sig-paren">(</span><em>ignore_label</em>, <em>axis=-1</em>, <em>name='perplexity'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Perplexity"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Perplexity" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes perplexity.</p> |
| <p>Perplexity is a measurement of how well a probability distribution |
| or model predicts a sample. A low perplexity indicates the model |
| is good at predicting the sample.</p> |
| <p>The perplexity of a model q is defined as</p> |
| <div class="math"> |
| \[b^{\big(-\frac{1}{N} \sum_{i=1}^N \log_b q(x_i) \big)} |
| = \exp \big(-\frac{1}{N} \sum_{i=1}^N \log q(x_i)\big)\]</div> |
| <p>where we let <cite>b = e</cite>.</p> |
| <p><span class="math">\(q(x_i)\)</span> is the probability that the model assigns to the ground-truth |
| label of sample <span class="math">\(x_i\)</span>.</p> |
| <p>For example, suppose we have three samples <span class="math">\(x_1, x_2, x_3\)</span> whose labels |
| are <span class="math">\([0, 1, 1]\)</span>, |
| and our model predicts <span class="math">\(q(x_1) = p(y_1 = 0 \mid x_1) = 0.3\)</span>, |
| <span class="math">\(q(x_2) = 1.0\)</span>, and |
| <span class="math">\(q(x_3) = 0.6\)</span>. The perplexity of model q is then |
| <span class="math">\(\exp\big(-(\log 0.3 + \log 1.0 + \log 0.6) / 3\big) = 1.77109762852\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>ignore_label</strong> (<em>int or None</em>) – Index of the invalid label to ignore when |
| counting. The signature above declares no default value for this argument. |
| If set to <cite>None</cite>, all entries are included.</li> |
| <li><strong>axis</strong> (<em>int (default -1)</em>) – The axis of the prediction along which softmax was |
| computed. By default, the last axis |
| is used.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">perp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Perplexity</span><span class="p">(</span><span class="n">ignore_label</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">perp</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">perp</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('perplexity', 1.7710976285155853)</span> |
| </pre></div> |
| </div> |
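| <p>The definition above can be checked by hand. The following is a minimal pure-Python sketch, not the library's implementation: the function name <code>perplexity</code> is hypothetical, and it assumes that entries whose label equals <code>ignore_label</code> are simply left out of both the sum and the count.</p> |

```python
import math


def perplexity(probs, labels, ignore_label=None):
    """exp of the mean negative log-probability assigned to the true labels.

    Entries whose label equals ignore_label are skipped (an assumption
    made for this sketch about how ignoring works).
    """
    total_log_prob = 0.0
    count = 0
    for row, label in zip(probs, labels):
        if ignore_label is not None and label == ignore_label:
            continue
        total_log_prob += math.log(row[label])
        count += 1
    return math.exp(-total_log_prob / count)


probs = [[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]]
labels = [0, 1, 1]

all_entries = perplexity(probs, labels)               # uses q = 0.3, 1.0, 0.6
skip_zero = perplexity(probs, labels, ignore_label=0)  # uses q = 1.0, 0.6 only

print(all_entries)
print(skip_zero)
```

| <p>With all three samples the result is about 1.7711, in line with the worked example; ignoring label 0 drops the least confident prediction and lowers the value to about 1.291.</p> |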
| <dl class="method"> |
| <dt id="mxnet.metric.Perplexity.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Perplexity.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Perplexity.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.Perplexity.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Perplexity.get"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Perplexity.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">The name of the metric and the evaluation result.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">Tuple of (str, float)</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.MAE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">MAE</code><span class="sig-paren">(</span><em>name='mae'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#MAE"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.MAE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Mean Absolute Error (MAE) loss.</p> |
| <p>The mean absolute error is given by</p> |
| <div class="math"> |
| \[\frac{\sum_i^n |y_i - \hat{y}_i|}{n}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">mean_absolute_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">MAE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">mean_absolute_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">mean_absolute_error</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('mae', 0.5)</span> |
| </pre></div> |
| </div> |
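| <p>The formula can be verified on the same numbers with plain Python. This is a sketch for checking the arithmetic, not the library's implementation; the function name is hypothetical.</p> |

```python
def mean_absolute_error(labels, preds):
    """Average of |y_i - y_hat_i| over all n samples."""
    return sum(abs(y - y_hat) for y, y_hat in zip(labels, preds)) / len(labels)


preds = [3, -0.5, 2, 7]
labels = [2.5, 0.0, 2, 8]

# Absolute errors are 0.5, 0.5, 0 and 1, so the mean is 2 / 4.
mae = mean_absolute_error(labels, preds)
print(mae)  # 0.5
```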
| <dl class="method"> |
| <dt id="mxnet.metric.MAE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#MAE.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.MAE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.MSE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">MSE</code><span class="sig-paren">(</span><em>name='mse'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#MSE"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.MSE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Mean Squared Error (MSE) loss.</p> |
| <p>The mean squared error is given by</p> |
| <div class="math"> |
| \[\frac{\sum_i^n (y_i - \hat{y}_i)^2}{n}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">mean_squared_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">MSE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">mean_squared_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">mean_squared_error</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('mse', 0.375)</span> |
| </pre></div> |
| </div> |
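| <p>As a cross-check of the formula, the same value can be computed in plain Python. The function name is hypothetical and the code is only a sketch of the arithmetic, not the library's implementation.</p> |

```python
def mean_squared_error(labels, preds):
    """Average of (y_i - y_hat_i)^2 over all n samples."""
    return sum((y - y_hat) ** 2 for y, y_hat in zip(labels, preds)) / len(labels)


preds = [3, -0.5, 2, 7]
labels = [2.5, 0.0, 2, 8]

# Squared errors are 0.25, 0.25, 0 and 1, so the mean is 1.5 / 4.
mse = mean_squared_error(labels, preds)
print(mse)  # 0.375
```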
| <dl class="method"> |
| <dt id="mxnet.metric.MSE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#MSE.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.MSE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.RMSE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">RMSE</code><span class="sig-paren">(</span><em>name='rmse'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#RMSE"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.RMSE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Root Mean Squared Error (RMSE) loss.</p> |
| <p>The root mean squared error is given by</p> |
| <div class="math"> |
| \[\sqrt{\frac{\sum_i^n (y_i - \hat{y}_i)^2}{n}}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">root_mean_squared_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">RMSE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">root_mean_squared_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">root_mean_squared_error</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('rmse', 0.612372457981)</span> |
| </pre></div> |
| </div> |
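| <p>The value above can be reproduced in plain Python as the square root of the mean squared error. This is a sketch with a hypothetical function name, not the library's implementation. In double precision the result agrees with the printed value to about seven significant digits; the remaining difference is presumably a single-precision effect, which is an assumption rather than something this page states.</p> |

```python
import math


def root_mean_squared_error(labels, preds):
    """Square root of the average of (y_i - y_hat_i)^2 over all n samples."""
    mse = sum((y - y_hat) ** 2 for y, y_hat in zip(labels, preds)) / len(labels)
    return math.sqrt(mse)


preds = [3, -0.5, 2, 7]
labels = [2.5, 0.0, 2, 8]

# The mean squared error of this data is 0.375, so this is sqrt(0.375).
rmse = root_mean_squared_error(labels, preds)
print(rmse)
```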
| <dl class="method"> |
| <dt id="mxnet.metric.RMSE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#RMSE.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.RMSE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CrossEntropy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CrossEntropy</code><span class="sig-paren">(</span><em>eps=1e-12</em>, <em>name='cross-entropy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CrossEntropy"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CrossEntropy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Cross Entropy loss.</p> |
| <p>The cross entropy over a batch of sample size <span class="math">\(N\)</span> is given by</p> |
| <div class="math"> |
| \[-\sum_{n=1}^{N}\sum_{k=1}^{K}t_{nk}\log (y_{nk}),\]</div> |
| <p>where <span class="math">\(t_{nk}=1\)</span> if and only if sample <span class="math">\(n\)</span> belongs to class <span class="math">\(k\)</span>. |
| <span class="math">\(y_{nk}\)</span> denotes the probability of sample <span class="math">\(n\)</span> belonging to |
| class <span class="math">\(k\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>eps</strong> (<em>float</em>) – Cross entropy loss is undefined when a predicted value is 0 or 1, |
| so this small constant is added to the predicted values.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">ce</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CrossEntropy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">ce</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">ce</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('cross-entropy', 0.57159948348999023)</span> |
| </pre></div> |
| </div> |
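| <p>The double sum in the formula collapses to one term per sample, because <span class="math">\(t_{nk}\)</span> is non-zero only for the true class. The plain-Python sketch below (hypothetical function name, not the library's implementation) evaluates that sum for the example data. Dividing it by <span class="math">\(N\)</span> reproduces the printed value up to rounding, which suggests that the reported figure is a per-sample average; that reading is inferred from the numbers only.</p> |

```python
import math


def cross_entropy_sum(probs, labels, eps=1e-12):
    """-sum_n sum_k t_nk * log(y_nk + eps) with one-hot t_nk, so only the
    true-class probability of each sample contributes."""
    return -sum(math.log(row[label] + eps) for row, label in zip(probs, labels))


probs = [[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]]
labels = [0, 1, 1]

total = cross_entropy_sum(probs, labels)  # sum over the batch, as in the formula
mean = total / len(labels)                # per-sample average

print(total)
print(mean)
```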
| <dl class="method"> |
| <dt id="mxnet.metric.CrossEntropy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CrossEntropy.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CrossEntropy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.NegativeLogLikelihood"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">NegativeLogLikelihood</code><span class="sig-paren">(</span><em>eps=1e-12</em>, <em>name='nll-loss'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#NegativeLogLikelihood"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.NegativeLogLikelihood" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes the negative log-likelihood loss.</p> |
| <p>The negative log-likelihood loss over a batch of sample size <span class="math">\(N\)</span> is given by</p> |
| <div class="math"> |
| \[-\sum_{n=1}^{N}\sum_{k=1}^{K}t_{nk}\log (y_{nk}),\]</div> |
| <p>where <span class="math">\(K\)</span> is the number of classes, <span class="math">\(y_{nk}\)</span> is the predicted probability of the |
| <span class="math">\(k\)</span>-th class for the <span class="math">\(n\)</span>-th sample, and <span class="math">\(t_{nk}=1\)</span> if and only if sample |
| <span class="math">\(n\)</span> belongs to class <span class="math">\(k\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>eps</strong> (<em>float</em>) – Negative log-likelihood loss is undefined when a predicted value is 0, |
| so this small constant is added to the predicted values.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Names of the predictions to use when updating with update_dict. |
| By default, all predictions are included.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Names of the labels to use when updating with update_dict. |
| By default, all labels are included.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">nll_loss</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">NegativeLogLikelihood</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">nll_loss</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="n">nll_loss</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('nll-loss', 0.57159948348999023)</span> |
| </pre></div> |
| </div> |
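| <p>The value in the example above can be checked by hand. The following is a minimal pure-Python sketch of the formula, not the library implementation; the helper name <cite>negative_log_likelihood</cite> and the default <cite>eps=1e-12</cite> are assumptions made for illustration.</p> |

```python
import math


def negative_log_likelihood(labels, preds, eps=1e-12):
    """Mean of -log(p + eps), where p is the probability predicted for the true class.

    Hypothetical helper for illustration; labels are class indices and
    preds is a list of per-class probability rows.
    """
    total = 0.0
    for label, row in zip(labels, preds):
        # eps keeps the logarithm finite when the predicted probability is 0.
        total += -math.log(row[label] + eps)
    return total / len(labels)


preds = [[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]]
labels = [0, 1, 1]
nll = negative_log_likelihood(labels, preds)
print(round(nll, 4))  # 0.5716
```

| <p>Only the probability assigned to the true class of each sample contributes, so the result is the mean of <span class="math">\(-\log 0.3\)</span>, <span class="math">\(-\log 1\)</span> and <span class="math">\(-\log 0.6\)</span>.</p> |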
| <dl class="method"> |
| <dt id="mxnet.metric.NegativeLogLikelihood.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#NegativeLogLikelihood.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.NegativeLogLikelihood.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.PearsonCorrelation"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">PearsonCorrelation</code><span class="sig-paren">(</span><em>name='pearsonr'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#PearsonCorrelation"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.PearsonCorrelation" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Pearson correlation.</p> |
| <p>The Pearson correlation is given by</p> |
| <div class="math"> |
| \[\frac{\mathrm{cov}(y, \hat{y})}{\sigma_{y}\sigma_{\hat{y}}}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">pr</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">PearsonCorrelation</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">pr</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="n">pr</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('pearson-correlation', 0.42163704544016178)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.PearsonCorrelation.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#PearsonCorrelation.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.PearsonCorrelation.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Loss"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Loss</code><span class="sig-paren">(</span><em>name='loss'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Loss"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Loss" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for directly printing loss.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Torch"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Torch</code><span class="sig-paren">(</span><em>name='torch'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Torch"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Torch" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for torch criterions.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Caffe"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Caffe</code><span class="sig-paren">(</span><em>name='caffe'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#Caffe"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.Caffe" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for caffe criterions.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CustomMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CustomMetric</code><span class="sig-paren">(</span><em>feval</em>, <em>name=None</em>, <em>allow_extra_outputs=False</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CustomMetric"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CustomMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes a customized evaluation metric.</p> |
| <p>The <cite>feval</cite> function can return a <cite>tuple</cite> of (sum_metric, num_inst) or return |
| an <cite>int</cite> sum_metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>feval</strong> (<em>callable(label, pred)</em>) – Customized evaluation function.</li> |
| <li><strong>name</strong> (<em>str, optional</em>) – Name of this metric instance for display (the default is None).</li> |
| <li><strong>allow_extra_outputs</strong> (<em>bool, optional</em>) – If True, the prediction outputs may contain extra outputs. |
| This is useful for RNNs, where the states are also returned |
| as outputs so that they can be fed forward (the default is False).</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">feval</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="p">:</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CustomMetric</span><span class="p">(</span><span class="n">feval</span><span class="o">=</span><span class="n">feval</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">get</span><span class="p">())</span> |
| <span class="go">('custom(&lt;lambda&gt;)', 6.0)</span> |
| </pre></div> |
| </div> |
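| <p>The two return conventions for <cite>feval</cite> can be sketched as follows. This is a pure-Python illustration under the assumption that a scalar result is counted as a single instance while a tuple supplies its own instance count; <cite>RunningMetric</cite>, <cite>mean_of_sum</cite> and <cite>abs_error</cite> are hypothetical names and do not belong to mxnet.</p> |

```python
class RunningMetric:
    """Hypothetical stand-in for the accumulation logic; not the mxnet class."""

    def __init__(self, feval):
        self.feval = feval
        self.sum_metric = 0.0
        self.num_inst = 0

    def update(self, labels, preds):
        for label, pred in zip(labels, preds):
            result = self.feval(label, pred)
            if isinstance(result, tuple):
                # feval returned (sum_metric, num_inst).
                batch_sum, batch_count = result
            else:
                # Assumption: a scalar counts as one instance.
                batch_sum, batch_count = result, 1
            self.sum_metric += batch_sum
            self.num_inst += batch_count

    def get(self):
        return self.sum_metric / self.num_inst


def mean_of_sum(label, pred):
    """Scalar convention: mean of (label + pred), as in the example above."""
    return sum(l + p for l, p in zip(label, pred)) / len(label)


def abs_error(label, pred):
    """Tuple convention: (summed absolute error, number of samples)."""
    return sum(abs(l - p) for l, p in zip(label, pred)), len(label)


labels = [[2.5, 0.0, 2.0, 8.0]]
preds = [[3.0, -0.5, 2.0, 7.0]]

scalar_metric = RunningMetric(mean_of_sum)
scalar_metric.update(labels, preds)
print(scalar_metric.get())  # 6.0

tuple_metric = RunningMetric(abs_error)
tuple_metric.update(labels, preds)
print(tuple_metric.get())  # 0.5
```

| <p>Returning a tuple lets the metric be averaged per sample across batches of different sizes, whereas a scalar is averaged per call.</p> |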
| <dl class="method"> |
| <dt id="mxnet.metric.CustomMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#CustomMetric.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.CustomMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.metric.np"> |
| <code class="descclassname">mxnet.metric.</code><code class="descname">np</code><span class="sig-paren">(</span><em>numpy_feval</em>, <em>name=None</em>, <em>allow_extra_outputs=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/metric.html#np"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.metric.np" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates a custom evaluation metric that receives its inputs as numpy arrays.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>numpy_feval</strong> (<em>callable(label, pred)</em>) – Custom evaluation function that receives labels and predictions for a minibatch |
| as numpy arrays and returns the corresponding custom metric as a floating point number.</li> |
| <li><strong>name</strong> (<em>str, optional</em>) – Name of the custom metric.</li> |
| <li><strong>allow_extra_outputs</strong> (<em>bool, optional</em>) – Whether prediction output is allowed to have extra outputs. This is useful in cases |
| like RNN where states are also part of output which can then be fed back to the RNN |
| in the next step. By default, extra outputs are not allowed.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">Custom metric that evaluates <cite>numpy_feval</cite> on the provided labels and predictions.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">CustomMetric</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">custom_metric</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">label</span><span class="o">-</span><span class="n">pred</span><span class="p">))</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">metric</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">np</span><span class="p">(</span><span class="n">custom_metric</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <script>auto_index("evaluation-metric-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="optimizer-api-reference"> |
| <span id="optimizer-api-reference"></span><h2>Optimizer API Reference<a class="headerlink" href="#optimizer-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.optimizer"></span><p>Weight updating functions.</p> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Optimizer"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Optimizer</code><span class="sig-paren">(</span><em>rescale_grad=1.0</em>, <em>param_idx2name=None</em>, <em>wd=0.0</em>, <em>clip_gradient=None</em>, <em>learning_rate=0.01</em>, <em>lr_scheduler=None</em>, <em>sym=None</em>, <em>begin_num_update=0</em>, <em>multi_precision=False</em>, <em>param_dict=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class inherited by all optimizers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rescale_grad</strong> (<em>float, optional</em>) – Multiply the gradient by <cite>rescale_grad</cite> before updating. Often |
| chosen to be <code class="docutils literal"><span class="pre">1.0/batch_size</span></code>.</li> |
| <li><strong>param_idx2name</strong> (<em>dict from int to string, optional</em>) – A dictionary that maps int index to string name.</li> |
| <li><strong>clip_gradient</strong> (<em>float, optional</em>) – Clip the gradient by projecting onto the box <code class="docutils literal"><span class="pre">[-clip_gradient,</span> <span class="pre">clip_gradient]</span></code>.</li> |
| <li><strong>learning_rate</strong> (<em>float</em>) – The initial learning rate.</li> |
| <li><strong>lr_scheduler</strong> (<em>LRScheduler, optional</em>) – The learning rate scheduler.</li> |
| <li><strong>wd</strong> (<em>float, optional</em>) – The weight decay (or L2 regularization) coefficient. Modifies objective |
| by adding a penalty for having large weights.</li> |
| <li><strong>sym</strong> (<em>Symbol, optional</em>) – The Symbol this optimizer is applying to.</li> |
| <li><strong>begin_num_update</strong> (<em>int, optional</em>) – The initial number of updates.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default), |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients |
| in 32-bit precision even if actual weights used in the model have lower precision. |
| Turning this on can improve convergence and accuracy when training with float16.</li> |
| <li><strong>learning_rate</strong> (<em>property</em>) – The current learning rate of the optimizer. Given an Optimizer object |
| <cite>optimizer</cite>, its learning rate can be accessed as <code class="docutils literal"><span class="pre">optimizer.learning_rate</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.register"> |
| <em class="property">static </em><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.register"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.create_optimizer"> |
| <em class="property">static </em><code class="descname">create_optimizer</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.create_optimizer"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.SGD'&gt;</span> |
| <span class="gp">>>> </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.Adam'&gt;</span> |
| </pre></div> |
| </div> |
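| <p>The interplay between <cite>register</cite> and <cite>create_optimizer</cite> resembles a name-keyed registry. The sketch below shows one way such a registry could work in plain Python, assuming names are stored lower-cased to obtain the case-insensitive lookup described above; <cite>_REGISTRY</cite>, <cite>register</cite> and <cite>create</cite> here are illustrative stand-ins, not the mxnet functions.</p> |

```python
_REGISTRY = {}  # hypothetical module-level table: lower-cased name -> class


def register(klass):
    """Store the class under its lower-cased name and return it unchanged."""
    _REGISTRY[klass.__name__.lower()] = klass
    return klass


def create(name, **kwargs):
    """Look the class up case-insensitively and instantiate it with kwargs."""
    key = name.lower()
    if key not in _REGISTRY:
        raise ValueError("Cannot find optimizer %s" % name)
    return _REGISTRY[key](**kwargs)


@register
class MyOptimizer:
    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate


optim = create("MYOPTIMIZER", learning_rate=0.1)
print(type(optim).__name__)  # MyOptimizer
print(optim.learning_rate)   # 0.1
```

| <p>Because the decorator returns the class unchanged, a registered class can still be instantiated directly.</p> |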
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.create_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates auxiliary state for a given weight.</p> |
| <p>Some optimizers require additional states, e.g. momentum, in addition |
| to gradients in order to update weights. This function creates state |
| for a given weight which will be used in <cite>update</cite>. This function is |
| called only once for each weight.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>index</strong> (<em>int</em>) – A unique index to identify the weight.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The weight.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>state</strong> – |
| The state associated with the weight.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">any obj</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.create_state_multi_precision"> |
| <code class="descname">create_state_multi_precision</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.create_state_multi_precision"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_state_multi_precision" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates auxiliary state for a given weight, including an FP32 high-precision |
| copy if the original weight is FP16.</p> |
| <p>This method is provided to perform automatic mixed precision training |
| for optimizers that do not support it themselves.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>index</strong> (<em>int</em>) – A unique index to identify the weight.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The weight.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>state</strong> – |
| The state associated with the weight.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">any obj</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
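| <p>Why a high-precision copy helps can be illustrated without any framework. The sketch below uses the standard <cite>struct</cite> module to round values to half and single precision; the helpers <cite>to_fp16</cite> and <cite>to_fp32</cite> and the step size are assumptions chosen for illustration, not part of the optimizer API.</p> |

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


step = 1e-4                 # update smaller than the FP16 spacing near 1.0
fp16_only = to_fp16(1.0)    # weight stored and updated in FP16
master = to_fp32(1.0)       # FP32 copy kept alongside the FP16 weight

for _ in range(10):
    fp16_only = to_fp16(fp16_only + step)  # each update is rounded away
    master = to_fp32(master + step)        # small updates accumulate

fp16_from_master = to_fp16(master)         # cast back for use in the model
print(fp16_only)                # 1.0
print(fp16_from_master > 1.0)   # True
```

| <p>Updating the FP16 weight directly loses every step, while accumulating in the FP32 copy and casting back preserves the combined effect.</p> |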
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the given parameter using the corresponding gradient and state.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>index</strong> (<em>int</em>) – The unique index of the parameter into the individual learning |
| rates and weight decays. Learning rates and weight decay |
| may be set via <cite>set_lr_mult()</cite> and <cite>set_wd_mult()</cite>, respectively.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The parameter to be updated.</li> |
| <li><strong>grad</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The gradient of the objective with respect to this parameter.</li> |
| <li><strong>state</strong> (<em>any obj</em>) – The state returned by <cite>create_state()</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
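| <p>A minimal sketch of how <cite>create_state()</cite> and <cite>update()</cite> fit together, assuming plain Python lists in place of NDArray. The class <code class="docutils literal"><span class="pre">ToyMomentumSGD</span></code> and its fields are hypothetical and are not part of mxnet.</p> |

```python
class ToyMomentumSGD:
    """Hypothetical stand-in for an Optimizer subclass; not mxnet code."""

    def __init__(self, lr=0.1, momentum=0.9):
        self.lr = lr
        self.momentum = momentum

    def create_state(self, index, weight):
        # One momentum slot per weight element, initialised to zero.
        return [0.0] * len(weight)

    def update(self, index, weight, grad, state):
        # Update weight and state in place, element by element.
        for i in range(len(weight)):
            state[i] = self.momentum * state[i] + self.lr * grad[i]
            weight[i] -= state[i]


opt = ToyMomentumSGD(lr=0.1, momentum=0.9)
weight = [1.0, 2.0]
state = opt.create_state(0, weight)         # created once per index
opt.update(0, weight, [0.5, -0.5], state)   # passed back on every update
```

| <p>The state object is opaque to the caller: it is created once per parameter index and handed back unchanged on each call, which is why its type is documented only as "any obj".</p> |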
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.update_multi_precision"> |
| <code class="descname">update_multi_precision</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.update_multi_precision"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.update_multi_precision" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the given parameter using the corresponding gradient and state. |
| Mixed precision version.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>index</strong> (<em>int</em>) – The unique index of the parameter into the individual learning |
| rates and weight decays. Learning rates and weight decay |
| may be set via <cite>set_lr_mult()</cite> and <cite>set_wd_mult()</cite>, respectively.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The parameter to be updated.</li> |
| <li><strong>grad</strong> (<a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The gradient of the objective with respect to this parameter.</li> |
| <li><strong>state</strong> (<em>any obj</em>) – The state returned by <cite>create_state()</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
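| <p>The following sketch illustrates why a higher-precision copy of the weights can matter when the model weights are float16. It uses only the Python standard library; a Python float (64-bit) stands in for the 32-bit master copy, and the helper <code class="docutils literal"><span class="pre">to_fp16</span></code> is a hypothetical name, not an mxnet function.</p> |

```python
import struct


def to_fp16(x):
    # Round a Python float to the nearest IEEE 754 half-precision value.
    return struct.unpack('e', struct.pack('e', x))[0]


step = 1e-4  # lr * grad, held constant for illustration

# Naive: the weight is stored in float16 and updated directly.
naive = to_fp16(1.0)
for _ in range(10):
    naive = to_fp16(naive - step)

# Mixed precision: update a higher-precision master copy, then cast.
master = 1.0
for _ in range(10):
    master -= step
mixed = to_fp16(master)
```

| <p>In the naive loop each step is smaller than half the float16 spacing near 1.0, so every update rounds away and the weight never moves. The master copy accumulates the steps and the float16 view eventually changes.</p> |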
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_learning_rate"> |
| <code class="descname">set_learning_rate</code><span class="sig-paren">(</span><em>lr</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.set_learning_rate"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_learning_rate" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets a new learning rate of the optimizer.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>lr</strong> (<em>float</em>) – The new learning rate of the optimizer.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_scale"> |
| <code class="descname">set_lr_scale</code><span class="sig-paren">(</span><em>args_lrscale</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.set_lr_scale"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_scale" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Sets lr scale. Use set_lr_mult instead.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_mult"> |
| <code class="descname">set_lr_mult</code><span class="sig-paren">(</span><em>args_lr_mult</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.set_lr_mult"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual learning rate multiplier for each parameter.</p> |
| <p>If you specify a learning rate multiplier for a parameter, then |
| the learning rate for the parameter will be set as the product of |
| the global learning rate <cite>self.lr</cite> and its multiplier.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default learning rate multiplier of a <cite>Variable</cite> |
| can be set with the <cite>lr_mult</cite> argument in its constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_lr_mult</strong> (<em>dict of str/int to float</em>) – <p>For each of its key-value entries, the learning rate multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_lr_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
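| <p>A small sketch of the rule described above: the effective learning rate is the product of the global rate and the per-parameter multiplier. The function <code class="docutils literal"><span class="pre">effective_lr</span></code> and the parameter names are hypothetical, and the fallback multiplier of 1.0 for parameters without an entry is an assumption.</p> |

```python
def effective_lr(global_lr, lr_mult, name):
    # Assumption: parameters without an entry use a multiplier of 1.0.
    return global_lr * lr_mult.get(name, 1.0)


lr_mult = {'fc1_weight': 0.1, 'fc2_weight': 2.0}
rates = {
    name: effective_lr(0.01, lr_mult, name)
    for name in ('fc1_weight', 'fc2_weight', 'fc2_bias')
}
```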
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_wd_mult"> |
| <code class="descname">set_wd_mult</code><span class="sig-paren">(</span><em>args_wd_mult</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Optimizer.set_wd_mult"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_wd_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual weight decay multiplier for each parameter.</p> |
| <p>By default, if <cite>param_idx2name</cite> was provided in the |
| constructor, the weight decay multiplier is set as 0 for all |
| parameters whose names don’t end with <code class="docutils literal"><span class="pre">_weight</span></code> or |
| <code class="docutils literal"><span class="pre">_gamma</span></code>.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default weight decay multiplier for a <cite>Variable</cite> |
| can be set with its <cite>wd_mult</cite> argument in the constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_wd_mult</strong> (<em>dict of str/int to float</em>) – <p>For each of its key-value entries, the weight decay multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_wd_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
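| <p>The default rule above can be sketched as follows. The helpers <code class="docutils literal"><span class="pre">default_wd_mult</span></code> and <code class="docutils literal"><span class="pre">effective_wd</span></code> are hypothetical names, and treating a missing entry as a multiplier of 1.0 is an assumption.</p> |

```python
def default_wd_mult(param_names):
    # Names not ending in _weight or _gamma get a multiplier of 0.0.
    wd_mult = {}
    for name in param_names:
        if not name.endswith(('_weight', '_gamma')):
            wd_mult[name] = 0.0
    return wd_mult


def effective_wd(global_wd, wd_mult, name):
    # Assumption: parameters without an entry use a multiplier of 1.0.
    return global_wd * wd_mult.get(name, 1.0)


names = ['fc1_weight', 'fc1_bias', 'bn1_gamma', 'bn1_beta']
wd_mult = default_wd_mult(names)
decays = {name: effective_wd(1e-4, wd_mult, name) for name in names}
```

| <p>With this rule, biases and batch-norm offsets such as <code class="docutils literal"><span class="pre">bn1_beta</span></code> are not decayed unless a multiplier is set for them explicitly.</p> |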
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.register"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>lazy_update=True</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#SGD"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.SGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The SGD optimizer with momentum and weight decay.</p> |
| <p>If the storage types of weight and grad are both <code class="docutils literal"><span class="pre">row_sparse</span></code>, and <code class="docutils literal"><span class="pre">lazy_update</span></code> is True, <strong>lazy updates</strong> are applied by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">grad</span><span class="o">.</span><span class="n">indices</span><span class="p">:</span> |
| <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">*</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">[</span><span class="n">row</span><span class="p">],</span> <span class="n">clip_gradient</span><span class="p">)</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> |
| <span class="n">state</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">momentum</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">state</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> |
| <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">-</span> <span class="n">state</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| <p>The lazy update only updates the momentum for the weights whose row_sparse |
| gradient indices appear in the current batch, rather than updating it for all |
| indices. Compared with the standard update, it can substantially improve |
| model training throughput for some applications. However, its semantics |
| differ slightly from those of the standard update, and it |
| may lead to different empirical results.</p> |
| <p>Otherwise, <strong>standard updates</strong> are applied by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">rescaled_grad</span> <span class="o">=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">*</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">rescaled_grad</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="n">state</span> |
| </pre></div> |
| </div> |
| <p>For details of the update algorithm see |
| <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.sgd_update" title="mxnet.ndarray.sgd_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_update</span></code></a> and <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.sgd_mom_update" title="mxnet.ndarray.sgd_mom_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_mom_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>lazy_update</strong> (<em>bool, optional</em>) – Default is True. If True, lazy updates are applied if the storage types of weight and grad are both <code class="docutils literal"><span class="pre">row_sparse</span></code>.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default); |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients in 32-bit precision even if the actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
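| <p>To make the difference between the lazy and standard updates concrete, here is a sketch in plain Python that follows the two formulas above with one scalar per row. It assumes <code class="docutils literal"><span class="pre">rescale_grad</span></code> is 1 and no gradient clipping; the function names and the dict-based sparse gradient are hypothetical.</p> |

```python
def sgd_standard(weight, grad, state, lr, momentum, wd=0.0):
    # grad is dense: one entry per row, so every row's state is touched.
    for row in range(len(weight)):
        rescaled = lr * grad[row] + wd * weight[row]
        state[row] = momentum * state[row] + rescaled
        weight[row] -= state[row]


def sgd_lazy(weight, sparse_grad, state, lr, momentum, wd=0.0):
    # sparse_grad maps row index -> gradient; all other rows are skipped.
    for row, g in sparse_grad.items():
        rescaled = lr * g + wd * weight[row]
        state[row] = momentum * state[row] + rescaled
        weight[row] -= state[row]


# Row 0 has a gradient in this batch; row 1 does not.
w_std, s_std = [1.0, 1.0], [0.5, 0.5]
sgd_standard(w_std, [1.0, 0.0], s_std, lr=0.1, momentum=0.9)

w_lazy, s_lazy = [1.0, 1.0], [0.5, 0.5]
sgd_lazy(w_lazy, {0: 1.0}, s_lazy, lr=0.1, momentum=0.9)
```

| <p>Row 0 ends up identical in both cases. Row 1 differs: the standard update still decays its momentum and moves the weight, while the lazy update leaves both untouched.</p> |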
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Signum"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Signum</code><span class="sig-paren">(</span><em>learning_rate=0.01</em>, <em>momentum=0.9</em>, <em>wd_lh=0.0</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Signum"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Signum" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Signum optimizer that takes the sign of gradient or momentum.</p> |
| <p>The optimizer updates the weight by:</p> |
| <blockquote> |
| <div>rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight |
| state = momentum * state + (1-momentum)*rescaled_grad |
| weight = (1 - lr * wd_lh) * weight - lr * sign(state)</div></blockquote> |
| <p>See the original paper at: <a class="reference external" href="https://jeremybernste.in/projects/amazon/signum.pdf">https://jeremybernste.in/projects/amazon/signum.pdf</a></p> |
| <p>For details of the update algorithm see |
| <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.signsgd_update" title="mxnet.ndarray.signsgd_update"><code class="xref py py-class docutils literal"><span class="pre">signsgd_update</span></code></a> and <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.signum_update" title="mxnet.ndarray.signum_update"><code class="xref py py-class docutils literal"><span class="pre">signum_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>wd_lh</strong> (<em>float, optional</em>) – The amount of decoupled weight decay regularization, see details in the original paper at: <a class="reference external" href="https://arxiv.org/abs/1711.05101">https://arxiv.org/abs/1711.05101</a></li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
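| <p>A scalar sketch of the Signum update rule stated above, written in plain Python. The function <code class="docutils literal"><span class="pre">signum_step</span></code> and the helper <code class="docutils literal"><span class="pre">sign</span></code> are hypothetical names; clipping is modelled as clamping to [-clip_gradient, clip_gradient], which is an assumption about <code class="docutils literal"><span class="pre">clip</span></code>.</p> |

```python
def sign(x):
    # Returns -1, 0, or 1.
    return (x > 0) - (x < 0)


def signum_step(weight, grad, state, lr=0.01, momentum=0.9,
                wd=0.0, wd_lh=0.0, rescale_grad=1.0, clip_gradient=None):
    g = grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    rescaled_grad = rescale_grad * g + wd * weight
    state = momentum * state + (1 - momentum) * rescaled_grad
    weight = (1 - lr * wd_lh) * weight - lr * sign(state)
    return weight, state


w_big, _ = signum_step(1.0, grad=250.0, state=0.0)
w_small, _ = signum_step(1.0, grad=0.001, state=0.0)
```

| <p>Because only the sign of the state is used, both calls move the weight by the same amount, <code class="docutils literal"><span class="pre">lr</span></code>, regardless of the gradient magnitude.</p> |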
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.FTML"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">FTML</code><span class="sig-paren">(</span><em>beta1=0.6</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#FTML"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.FTML" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The FTML optimizer.</p> |
| <p>This class implements the optimizer described in |
| <em>FTML - Follow the Moving Leader in Deep Learning</em>, |
| available at <a class="reference external" href="http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf">http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – 0 &lt; beta1 &lt; 1. Generally close to 0.5.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – 0 &lt; beta2 &lt; 1. Generally close to 1.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.LBSGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">LBSGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>multi_precision=False</em>, <em>warmup_strategy='linear'</em>, <em>warmup_epochs=5</em>, <em>batch_scale=1</em>, <em>updates_per_epoch=32</em>, <em>begin_epoch=0</em>, <em>num_epochs=60</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#LBSGD"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.LBSGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Large Batch SGD optimizer with momentum and weight decay.</p> |
| <p>The optimizer updates the weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">*</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="n">state</span> |
| </pre></div> |
| </div> |
| <p>For details of the update algorithm see <code class="xref py py-class docutils literal"><span class="pre">lbsgd_update</span></code> and |
| <code class="xref py py-class docutils literal"><span class="pre">lbsgd_mom_update</span></code>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default); |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients in 32-bit precision even if the actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.</li> |
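| <li><strong>warmup_strategy</strong> (<em>str, optional</em>) – One of <code class="docutils literal"><span class="pre">'linear'</span></code>, <code class="docutils literal"><span class="pre">'power2'</span></code>, <code class="docutils literal"><span class="pre">'sqrt'</span></code>, or <code class="docutils literal"><span class="pre">'lars'</span></code>. Default is <code class="docutils literal"><span class="pre">'linear'</span></code>.</li> |
| <li><strong>warmup_epochs</strong> (<em>unsigned int, optional</em>) – Default is 5.</li> |
| <li><strong>batch_scale</strong> (<em>unsigned int, optional</em>) – Default is 1 (same as batch size * numworkers).</li> |
| <li><strong>updates_per_epoch</strong> (<em>int, optional</em>) – Default is 32. The default might not reflect the true number of batches per epoch. Used for warmup.</li> |
| <li><strong>begin_epoch</strong> (<em>unsigned int, optional</em>) – The starting epoch. Default is 0.</li> |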
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.DCASGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">DCASGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>lamda=0.04</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#DCASGD"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.DCASGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The DCASGD optimizer.</p> |
| <p>This class implements the optimizer described in <em>Asynchronous Stochastic Gradient Descent |
| with Delay Compensation for Distributed Deep Learning</em>, |
| available at <a class="reference external" href="https://arxiv.org/abs/1609.08326">https://arxiv.org/abs/1609.08326</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>lamda</strong> (<em>float, optional</em>) – Scale DC value.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.NAG"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">NAG</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#NAG"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.NAG" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Nesterov accelerated SGD.</p> |
| <p>This optimizer updates each weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">grad</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="p">(</span><span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="n">grad</span> <span class="o">+</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default); |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients in 32-bit precision even if the actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
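| <p>A scalar sketch of the two-line NAG rule above, transcribed directly into plain Python. The function name <code class="docutils literal"><span class="pre">nag_step</span></code> is hypothetical, and gradient rescaling and clipping are left out because the formula above does not mention them.</p> |

```python
def nag_step(weight, grad, state, lr, momentum, wd=0.0):
    # Accumulate the gradient (plus weight decay) into the momentum state.
    state = momentum * state + grad + wd * weight
    # Step using the gradient plus a look-ahead momentum term.
    weight = weight - lr * (grad + momentum * state)
    return weight, state


w, s = nag_step(weight=1.0, grad=0.5, state=0.0, lr=0.1, momentum=0.9)
```

| <p>Compared with plain momentum, the weight step uses <code class="docutils literal"><span class="pre">grad</span> <span class="pre">+</span> <span class="pre">momentum</span> <span class="pre">*</span> <span class="pre">state</span></code> rather than the state alone, which is the look-ahead correction that characterises Nesterov momentum.</p> |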
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGLD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGLD</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#SGLD"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.SGLD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Stochastic Gradient Riemannian Langevin Dynamics.</p> |
| <p>This class implements the optimizer described in the paper <em>Stochastic Gradient |
| Riemannian Langevin Dynamics on the Probability Simplex</em>, available at |
| <a class="reference external" href="https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf">https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf</a>.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.ccSGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">ccSGD</code><span class="sig-paren">(</span><em>*args</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#ccSGD"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.ccSGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Same as <cite>SGD</cite>. Left here for backward compatibility.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>lazy_update=True</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Adam"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Adam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Adam optimizer.</p> |
| <p>This class implements the optimizer described in <em>Adam: A Method for |
| Stochastic Optimization</em>, available at <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>If the storage types of weight and grad are both <code class="docutils literal"><span class="pre">row_sparse</span></code>, and <code class="docutils literal"><span class="pre">lazy_update</span></code> is True, <strong>lazy updates</strong> are applied by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">grad</span><span class="o">.</span><span class="n">indices</span><span class="p">:</span> |
|     <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">],</span> <span class="n">clip_gradient</span><span class="p">)</span> |
|     <span class="n">m</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">beta1</span> <span class="o">*</span> <span class="n">m</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta1</span><span class="p">)</span> <span class="o">*</span> <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> |
|     <span class="n">v</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">beta2</span> <span class="o">*</span> <span class="n">v</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> |
|     <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">m</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">/</span> <span class="p">(</span><span class="n">sqrt</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="n">row</span><span class="p">])</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The lazy update updates the mean (<code class="docutils literal"><span class="pre">m</span></code>) and variance (<code class="docutils literal"><span class="pre">v</span></code>) only for the rows whose indices appear in the |
| row_sparse gradient of the current batch, rather than for all rows. |
| Compared with the standard update, it can substantially improve model training |
| throughput for some applications. However, its semantics differ slightly from those of |
| the standard update, so it may lead to different empirical results.</p> |
| <p>Otherwise, <strong>standard updates</strong> are applied by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">rescaled_grad</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> |
| <span class="n">m</span> <span class="o">=</span> <span class="n">beta1</span> <span class="o">*</span> <span class="n">m</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta1</span><span class="p">)</span> <span class="o">*</span> <span class="n">rescaled_grad</span> |
| <span class="n">v</span> <span class="o">=</span> <span class="n">beta2</span> <span class="o">*</span> <span class="n">v</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">rescaled_grad</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">m</span> <span class="o">/</span> <span class="p">(</span><span class="n">sqrt</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>For details of the update algorithm, see <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.adam_update" title="mxnet.ndarray.adam_update"><code class="xref py py-class docutils literal"><span class="pre">adam_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>lazy_update</strong> (<em>bool, optional</em>) – Default is True. If True, lazy updates are applied if the storage types of weight and grad are both <code class="docutils literal"><span class="pre">row_sparse</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
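| <p>As a rough illustration of the difference between the lazy and standard update rules above, the following pure-Python sketch treats each row as a single float and passes the gradient as a dict from row index to value. The names <code class="docutils literal"><span class="pre">clip</span></code>, <code class="docutils literal"><span class="pre">adam_step</span></code> and <code class="docutils literal"><span class="pre">demo</span></code> are hypothetical helpers, not part of MXNet, and the sketch applies no bias correction, mirroring the pseudocode above.</p> |

```python
import math


def clip(value, bound):
    """Clamp value into [-bound, bound]; a bound of None disables clipping."""
    if bound is None:
        return value
    return max(-bound, min(bound, value))


def adam_step(weight, grad, m, v, rows, learning_rate=0.001, beta1=0.9,
              beta2=0.999, epsilon=1e-8, rescale_grad=1.0, wd=0.0,
              clip_gradient=None):
    """Update the given rows in place; grad maps row index -> gradient."""
    for row in rows:
        # Rows missing from the sparse gradient contribute a zero gradient.
        g = clip(grad.get(row, 0.0) * rescale_grad + wd * weight[row],
                 clip_gradient)
        m[row] = beta1 * m[row] + (1 - beta1) * g
        v[row] = beta2 * v[row] + (1 - beta2) * g ** 2
        weight[row] -= learning_rate * m[row] / (math.sqrt(v[row]) + epsilon)


def demo():
    grad = {0: 0.1}  # row 1 has no gradient in this batch

    # Lazy update: only the rows present in the gradient are visited.
    lazy_w, lazy_m, lazy_v = [1.0, 1.0], [0.5, 0.5], [0.25, 0.25]
    adam_step(lazy_w, grad, lazy_m, lazy_v, rows=grad.keys())

    # Standard update: every row is visited, even with a zero gradient.
    std_w, std_m, std_v = [1.0, 1.0], [0.5, 0.5], [0.25, 0.25]
    adam_step(std_w, grad, std_m, std_v, rows=range(len(std_w)))

    return lazy_w, lazy_m, std_w, std_m
```

| <p>In this sketch the lazy update leaves the weight and moments of row 1 untouched, whereas the standard update still decays its mean and moves its weight. This is the semantic difference described above.</p> |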
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaGrad"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaGrad</code><span class="sig-paren">(</span><em>eps=1e-07</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#AdaGrad"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.AdaGrad" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaGrad optimizer.</p> |
| <p>This class implements the AdaGrad optimizer described in <em>Adaptive Subgradient |
| Methods for Online Learning and Stochastic Optimization</em>, available at |
| <a class="reference external" href="http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>eps</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</td> |
| </tr> |
| </tbody> |
| </table> |
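| <p>The idea behind AdaGrad can be sketched in a few lines of pure Python. This is a minimal illustration of the textbook rule, not MXNet's implementation: the function name <code class="docutils literal"><span class="pre">adagrad_step</span></code> and its <code class="docutils literal"><span class="pre">learning_rate</span></code> default are hypothetical, placing <code class="docutils literal"><span class="pre">eps</span></code> inside the square root is an assumption, and weight decay, gradient rescaling and clipping are left out.</p> |

```python
import math


def adagrad_step(weight, grad, history, learning_rate=0.01, eps=1e-7):
    """One in-place AdaGrad step over lists of floats."""
    for i, g in enumerate(grad):
        # Accumulate the squared gradient seen so far for this coordinate.
        history[i] += g * g
        # Coordinates with a large accumulated history take smaller steps.
        weight[i] -= learning_rate * g / math.sqrt(history[i] + eps)
    return weight
```

| <p>Because <code class="docutils literal"><span class="pre">history</span></code> only grows, repeating the same gradient yields progressively smaller steps.</p> |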
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.RMSProp"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">RMSProp</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>gamma1=0.9</em>, <em>gamma2=0.9</em>, <em>epsilon=1e-08</em>, <em>centered=False</em>, <em>clip_weights=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#RMSProp"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.RMSProp" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The RMSProp optimizer.</p> |
| <p>Two versions of RMSProp are implemented:</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=False</span></code>, we follow |
| <a class="reference external" href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf</a> by |
| Tieleman &amp; Hinton, 2012. |
| For details of the update algorithm see <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.rmsprop_update" title="mxnet.ndarray.rmsprop_update"><code class="xref py py-class docutils literal"><span class="pre">rmsprop_update</span></code></a>.</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=True</span></code>, we follow <a class="reference external" href="http://arxiv.org/pdf/1308.0850v5.pdf">http://arxiv.org/pdf/1308.0850v5.pdf</a> (38)-(45) |
| by Alex Graves, 2013. |
| For details of the update algorithm see <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.rmspropalex_update" title="mxnet.ndarray.rmspropalex_update"><code class="xref py py-class docutils literal"><span class="pre">rmspropalex_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>gamma1</strong> (<em>float, optional</em>) – Decay factor of the moving average over past squared gradients.</li> |
| <li><strong>gamma2</strong> (<em>float, optional</em>) – A “momentum” factor. Only used if <code class="docutils literal"><span class="pre">centered=True</span></code>.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>centered</strong> (<em>bool, optional</em>) – Flag to control which version of RMSProp to use. |
| <code class="docutils literal"><span class="pre">True</span></code> will use Graves’s version of <cite>RMSProp</cite>, |
| <code class="docutils literal"><span class="pre">False</span></code> will use Tieleman &amp; Hinton’s version of <cite>RMSProp</cite>.</li> |
| <li><strong>clip_weights</strong> (<em>float, optional</em>) – Clips weights into range <code class="docutils literal"><span class="pre">[-clip_weights,</span> <span class="pre">clip_weights]</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
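| <p>For orientation, the non-centered variant (<code class="docutils literal"><span class="pre">centered=False</span></code>) can be sketched as follows. This is a pure-Python illustration of the commonly stated Tieleman &amp; Hinton rule rather than MXNet's code: the name <code class="docutils literal"><span class="pre">rmsprop_step</span></code> is hypothetical, and placing <code class="docutils literal"><span class="pre">epsilon</span></code> inside the square root is an assumption.</p> |

```python
import math


def rmsprop_step(weight, grad, n, learning_rate=0.001, gamma1=0.9,
                 epsilon=1e-8, clip_weights=None):
    """One in-place, non-centered RMSProp step over lists of floats."""
    for i, g in enumerate(grad):
        # Moving average of the squared gradient, decayed by gamma1.
        n[i] = gamma1 * n[i] + (1 - gamma1) * g * g
        # Scale the step by the root of the running average.
        weight[i] -= learning_rate * g / math.sqrt(n[i] + epsilon)
        if clip_weights is not None:
            # Keep the weight inside [-clip_weights, clip_weights].
            weight[i] = max(-clip_weights, min(clip_weights, weight[i]))
    return weight
```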
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaDelta"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaDelta</code><span class="sig-paren">(</span><em>rho=0.9</em>, <em>epsilon=1e-05</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#AdaDelta"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.AdaDelta" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaDelta optimizer.</p> |
| <p>This class implements AdaDelta, an optimizer described in <em>ADADELTA: An adaptive |
| learning rate method</em>, available at <a class="reference external" href="https://arxiv.org/abs/1212.5701">https://arxiv.org/abs/1212.5701</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rho</strong> (<em>float</em>) – Decay rate for both squared gradients and delta.</li> |
| <li><strong>epsilon</strong> (<em>float</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
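| <p>The role of <code class="docutils literal"><span class="pre">rho</span></code> and <code class="docutils literal"><span class="pre">epsilon</span></code> can be seen in a small pure-Python sketch of the rule from the ADADELTA paper. The name <code class="docutils literal"><span class="pre">adadelta_step</span></code> and the accumulator names are hypothetical, and weight decay, gradient rescaling and clipping are omitted, so treat it as an illustration rather than MXNet's implementation.</p> |

```python
import math


def adadelta_step(weight, grad, acc_g, acc_delta, rho=0.9, epsilon=1e-5):
    """One in-place AdaDelta step over lists of floats."""
    for i, g in enumerate(grad):
        # Running average of squared gradients.
        acc_g[i] = rho * acc_g[i] + (1 - rho) * g * g
        # Step size is the ratio of the two running RMS values.
        delta = (math.sqrt(acc_delta[i] + epsilon)
                 / math.sqrt(acc_g[i] + epsilon) * g)
        # Running average of squared updates, decayed by the same rho.
        acc_delta[i] = rho * acc_delta[i] + (1 - rho) * delta * delta
        weight[i] -= delta
    return weight
```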
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Ftrl"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Ftrl</code><span class="sig-paren">(</span><em>lamda1=0.01</em>, <em>learning_rate=0.1</em>, <em>beta=1</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Ftrl"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Ftrl" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Ftrl optimizer.</p> |
| <p>Referenced from <em>Ad Click Prediction: a View from the Trenches</em>, available at |
| <a class="reference external" href="http://dl.acm.org/citation.cfm?id=2488200">http://dl.acm.org/citation.cfm?id=2488200</a>.</p> |
| <dl class="docutils"> |
| <dt>eta :</dt> |
| <dd><div class="first last math"> |
| \[\eta_{t,i} = \frac{\mathrm{learning\_rate}}{\beta + \sqrt{\sum_{s=1}^{t} g_{s,i}^2}}\]</div> |
| </dd> |
| </dl> |
| <p>The optimizer updates the weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">rescaled_grad</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span> <span class="o">*</span> <span class="n">rescale_grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> |
| <span class="n">z</span> <span class="o">+=</span> <span class="n">rescaled_grad</span> <span class="o">-</span> <span class="p">(</span><span class="n">sqrt</span><span class="p">(</span><span class="n">n</span> <span class="o">+</span> <span class="n">rescaled_grad</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">n</span><span class="p">))</span> <span class="o">*</span> <span class="n">weight</span> <span class="o">/</span> <span class="n">learning_rate</span> |
| <span class="n">n</span> <span class="o">+=</span> <span class="n">rescaled_grad</span><span class="o">**</span><span class="mi">2</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="p">(</span><span class="n">sign</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="o">*</span> <span class="n">lamda1</span> <span class="o">-</span> <span class="n">z</span><span class="p">)</span> <span class="o">/</span> <span class="p">((</span><span class="n">beta</span> <span class="o">+</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">n</span><span class="p">))</span> <span class="o">/</span> <span class="n">learning_rate</span> <span class="o">+</span> <span class="n">wd</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="nb">abs</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">lamda1</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>If the storage types of weight, state and grad are all <code class="docutils literal"><span class="pre">row_sparse</span></code>, <strong>sparse updates</strong> are applied by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">grad</span><span class="o">.</span><span class="n">indices</span><span class="p">:</span> |
|     <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">*</span> <span class="n">rescale_grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> |
|     <span class="n">z</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+=</span> <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">-</span> <span class="p">(</span><span class="n">sqrt</span><span class="p">(</span><span class="n">n</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">n</span><span class="p">[</span><span class="n">row</span><span class="p">]))</span> <span class="o">*</span> <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">/</span> <span class="n">learning_rate</span> |
|     <span class="n">n</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+=</span> <span class="n">rescaled_grad</span><span class="p">[</span><span class="n">row</span><span class="p">]</span><span class="o">**</span><span class="mi">2</span> |
|     <span class="n">weight</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">sign</span><span class="p">(</span><span class="n">z</span><span class="p">[</span><span class="n">row</span><span class="p">])</span> <span class="o">*</span> <span class="n">lamda1</span> <span class="o">-</span> <span class="n">z</span><span class="p">[</span><span class="n">row</span><span class="p">])</span> <span class="o">/</span> <span class="p">((</span><span class="n">beta</span> <span class="o">+</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">n</span><span class="p">[</span><span class="n">row</span><span class="p">]))</span> <span class="o">/</span> <span class="n">learning_rate</span> <span class="o">+</span> <span class="n">wd</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="nb">abs</span><span class="p">(</span><span class="n">z</span><span class="p">[</span><span class="n">row</span><span class="p">])</span> <span class="o">&gt;</span> <span class="n">lamda1</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The sparse update updates <code class="docutils literal"><span class="pre">z</span></code> and <code class="docutils literal"><span class="pre">n</span></code> only for the rows whose indices appear in the |
| row_sparse gradient of the current batch, rather than for all |
| rows. Compared with the standard update, it can substantially |
| improve model training throughput for some applications. However, its |
| semantics differ slightly from those of the standard update, so it |
| may lead to different empirical results.</p> |
| <p>For details of the update algorithm, see <a class="reference internal" href="ndarray/ndarray.html#mxnet.ndarray.ftrl_update" title="mxnet.ndarray.ftrl_update"><code class="xref py py-class docutils literal"><span class="pre">ftrl_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>lamda1</strong> (<em>float, optional</em>) – L1 regularization coefficient.</li> |
| <li><strong>learning_rate</strong> (<em>float, optional</em>) – The initial learning rate.</li> |
| <li><strong>beta</strong> (<em>float, optional</em>) – Per-coordinate learning rate correlation parameter.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
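| <p>The dense update rule above can be followed coordinate by coordinate in pure Python. The sketch below is only an illustration of that pseudocode: the name <code class="docutils literal"><span class="pre">ftrl_step</span></code> is hypothetical, weights and states are plain lists of floats, and the final multiplication by <code class="docutils literal"><span class="pre">(abs(z)</span> <span class="pre">&gt;</span> <span class="pre">lamda1)</span></code> is written as an explicit branch.</p> |

```python
import math


def ftrl_step(weight, grad, z, n, learning_rate=0.1, lamda1=0.01, beta=1.0,
              wd=0.0, rescale_grad=1.0, clip_gradient=None):
    """One in-place FTRL step over lists of floats."""
    for i in range(len(weight)):
        g = grad[i] * rescale_grad
        if clip_gradient is not None:
            g = max(-clip_gradient, min(clip_gradient, g))
        # Update z using the weight from before this step, then n.
        z[i] += g - (math.sqrt(n[i] + g * g) - math.sqrt(n[i])) * weight[i] / learning_rate
        n[i] += g * g
        if abs(z[i]) > lamda1:
            sign = 1.0 if z[i] > 0 else -1.0
            weight[i] = (sign * lamda1 - z[i]) / (
                (beta + math.sqrt(n[i])) / learning_rate + wd)
        else:
            # The L1 term zeroes out coordinates with a small |z|.
            weight[i] = 0.0
    return weight
```

| <p>Coordinates whose accumulated <code class="docutils literal"><span class="pre">z</span></code> stays within <code class="docutils literal"><span class="pre">lamda1</span></code> in magnitude are set exactly to zero, which is how the L1 coefficient produces sparse weights.</p> |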
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adamax"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adamax</code><span class="sig-paren">(</span><em>learning_rate=0.002</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Adamax"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Adamax" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaMax optimizer.</p> |
| <p>It is a variant of Adam based on the infinity norm, described in Section 7 of |
| <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
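| <p>A minimal pure-Python sketch of the AdaMax rule from Section 7 of the Adam paper is given below. The name <code class="docutils literal"><span class="pre">adamax_step</span></code>, the explicit step counter <code class="docutils literal"><span class="pre">t</span></code> and the guard against a zero norm are assumptions made for illustration; weight decay, gradient rescaling and clipping are omitted.</p> |

```python
def adamax_step(weight, grad, m, u, t, learning_rate=0.002, beta1=0.9,
                beta2=0.999):
    """One in-place AdaMax step; t is the 1-based update count."""
    # Bias-corrected step size for the first moment estimate.
    step_size = learning_rate / (1 - beta1 ** t)
    for i, g in enumerate(grad):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        # Exponentially weighted infinity norm of past gradients.
        u[i] = max(beta2 * u[i], abs(g))
        if u[i] > 0:  # skip coordinates that have only seen zero gradients
            weight[i] -= step_size * m[i] / u[i]
    return weight
```

| <p>On the first step the update has magnitude <code class="docutils literal"><span class="pre">learning_rate</span></code> whatever the scale of the gradient, because the first moment and the infinity norm scale together.</p> |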
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Nadam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Nadam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>schedule_decay=0.004</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Nadam"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Nadam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Nesterov Adam optimizer.</p> |
| <p>Much as Adam is essentially RMSprop with momentum, |
| Nadam is essentially RMSprop with Nesterov momentum. It is described |
| at <a class="reference external" href="http://cs229.stanford.edu/proj2015/054_report.pdf">http://cs229.stanford.edu/proj2015/054_report.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>schedule_decay</strong> (<em>float, optional</em>) – Exponential decay rate for the momentum schedule.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Test"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Test</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Test"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Test" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Test optimizer.</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Test.create_state"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Test.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates a state to duplicate weight.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Test.update"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Test.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Performs w += rescale_grad * grad.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.create"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">create</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">&gt;&gt;&gt; </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.SGD'&gt;</span> |
| <span class="gp">&gt;&gt;&gt; </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">&gt;&gt;&gt; </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.Adam'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Updater"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Updater"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updater for kvstore.</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>index</em>, <em>grad</em>, <em>weight</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Updater.__call__"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Updater.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates weight given gradient and index.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.sync_state_context"> |
| <code class="descname">sync_state_context</code><span class="sig-paren">(</span><em>state</em>, <em>context</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Updater.sync_state_context"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Updater.sync_state_context" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Synchronizes the state with the given context.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.set_states"> |
| <code class="descname">set_states</code><span class="sig-paren">(</span><em>states</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Updater.set_states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Updater.set_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets updater states.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.get_states"> |
| <code class="descname">get_states</code><span class="sig-paren">(</span><em>dump_optimizer=False</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#Updater.get_states"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.Updater.get_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gets updater states.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>dump_optimizer</strong> (<em>bool, default False</em>) – Whether to also save the optimizer itself. This would also save optimizer |
| information such as learning rate and weight decay schedules.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.get_updater"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">get_updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/optimizer.html#get_updater"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.optimizer.get_updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns a closure of the updater needed for kvstore.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>optimizer</strong> (<a class="reference internal" href="optimization/optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><em>Optimizer</em></a>) – The optimizer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>updater</strong> – |
| The closure of the updater.</td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">function</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <script>auto_index("optimizer-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="model-api-reference"> |
| <span id="model-api-reference"></span><h2>Model API Reference<a class="headerlink" href="#model-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.model"></span><p>MXNet model module</p> |
| <dl class="attribute"> |
| <dt id="mxnet.model.BatchEndParam"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">BatchEndParam</code><a class="headerlink" href="#mxnet.model.BatchEndParam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>alias of <code class="xref py py-class docutils literal"><span class="pre">BatchEndParams</span></code></p> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.model.save_checkpoint"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">save_checkpoint</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em>, <em>symbol</em>, <em>arg_params</em>, <em>aux_params</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#save_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.save_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Save the model symbol and parameters to checkpoint files.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – The epoch number of the model.</li> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol/symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The input Symbol.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li> |
| <li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li> |
| </ul> |
| </dd></dl> |
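| <p>The naming scheme in the notes above can be sketched in plain Python. This is only an illustration: <code class="docutils literal"><span class="pre">checkpoint_paths</span></code> is a hypothetical helper, not part of MXNet, and the zero-padding of the epoch number to four digits is an assumption about how <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> is expanded.</p> |

```python
def checkpoint_paths(prefix, epoch):
    """Return (symbol_path, params_path) for a checkpoint.

    Hypothetical helper, not part of mxnet. Assumes the epoch number is
    zero-padded to four digits in the parameter file name.
    """
    symbol_path = "%s-symbol.json" % prefix
    params_path = "%s-%04d.params" % (prefix, epoch)
    return symbol_path, params_path


if __name__ == "__main__":
    # Show the two file names a checkpoint with this prefix would use.
    print(checkpoint_paths("mymodel", 3))
```

| <p>Under that assumption, the symbol file name does not depend on the epoch, so successive checkpoints with the same prefix would share one symbol file and differ only in their parameter files.</p> |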
| <dl class="function"> |
| <dt id="mxnet.model.load_checkpoint"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">load_checkpoint</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#load_checkpoint"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.load_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Load model checkpoint from file.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>symbol</strong> (<em>Symbol</em>) – |
| The symbol configuration of computation network.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – |
| Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – |
| Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li>Symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li> |
| <li>Parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li> |
| </ul> |
| </dd></dl> |
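| <p>A rough sketch of how a single parameter file might be separated into <code class="docutils literal"><span class="pre">arg_params</span></code> and <code class="docutils literal"><span class="pre">aux_params</span></code>. It assumes the file holds one flat dictionary whose keys carry an <code class="docutils literal"><span class="pre">arg:</span></code> or <code class="docutils literal"><span class="pre">aux:</span></code> prefix; that layout and the helper name <code class="docutils literal"><span class="pre">split_params</span></code> are assumptions made for illustration, and plain Python values stand in for NDArrays.</p> |

```python
def split_params(save_dict):
    """Split a flat {"arg:name": value, "aux:name": value} dict in two.

    Hypothetical helper, not part of mxnet. Keys without a recognised
    prefix are ignored.
    """
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        # partition never raises, even when the key has no colon.
        kind, _, name = key.partition(":")
        if kind == "arg":
            arg_params[name] = value
        elif kind == "aux":
            aux_params[name] = value
    return arg_params, aux_params


# Plain lists stand in for NDArray values here.
loaded = {
    "arg:fc1_weight": [0.1, 0.2],
    "arg:fc1_bias": [0.0],
    "aux:bn_moving_mean": [0.5],
}
arg_params, aux_params = split_params(loaded)
```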
| <dl class="class"> |
| <dt id="mxnet.model.FeedForward"> |
| <em class="property">class </em><code class="descclassname">mxnet.model.</code><code class="descname">FeedForward</code><span class="sig-paren">(</span><em>symbol</em>, <em>ctx=None</em>, <em>num_epoch=None</em>, <em>epoch_size=None</em>, <em>optimizer='sgd'</em>, <em>initializer=<mxnet.initializer.uniform object=""></mxnet.initializer.uniform></em>, <em>numpy_batch_size=128</em>, <em>arg_params=None</em>, <em>aux_params=None</em>, <em>allow_extra_params=False</em>, <em>begin_epoch=0</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Model class of MXNet for training and predicting feedforward nets. |
| This class is designed for supervised networks with a single data input and a single output.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol/symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The symbol configuration of computation network.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction. |
| To use multi GPU training, pass in a list of gpu contexts.</li> |
| <li><strong>num_epoch</strong> (<em>int, optional</em>) – Training parameter, number of training epochs.</li> |
| <li><strong>epoch_size</strong> (<em>int, optional</em>) – Number of batches in an epoch. By default, it is set to |
| <code class="docutils literal"><span class="pre">ceil(num_train_examples</span> <span class="pre">/</span> <span class="pre">batch_size)</span></code>.</li> |
| <li><strong>optimizer</strong> (<em>str or Optimizer, optional</em>) – Training parameter, name or optimizer object for training.</li> |
| <li><strong>initializer</strong> (<em>initializer function, optional</em>) – Training parameter, the initialization scheme used.</li> |
| <li><strong>numpy_batch_size</strong> (<em>int, optional</em>) – The batch size of training data. |
| Only needed when input array is numpy.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray, optional</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray, optional</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| <li><strong>allow_extra_params</strong> (<em>boolean, optional</em>) – Whether to allow extra parameters that are not needed by the symbol |
| to be passed in <code class="docutils literal"><span class="pre">aux_params</span></code> and <code class="docutils literal"><span class="pre">arg_params</span></code>. |
| If this is True, no error is raised when <code class="docutils literal"><span class="pre">aux_params</span></code> and <code class="docutils literal"><span class="pre">arg_params</span></code> |
| contain more parameters than needed.</li> |
| <li><strong>begin_epoch</strong> (<em>int, optional</em>) – The beginning training epoch.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – The additional keyword arguments passed to optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
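| <p>The default for <code class="docutils literal"><span class="pre">epoch_size</span></code> described above is a ceiling division, so a final partial batch still counts as a batch. A small worked example in plain Python; the function name <code class="docutils literal"><span class="pre">default_epoch_size</span></code> is hypothetical and chosen only for illustration.</p> |

```python
import math


def default_epoch_size(num_train_examples, batch_size):
    """Number of batches per epoch: ceil(num_train_examples / batch_size).

    Hypothetical helper, not part of mxnet.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # float() keeps the division exact on both Python 2 and Python 3.
    return int(math.ceil(num_train_examples / float(batch_size)))


if __name__ == "__main__":
    # 1000 examples at 128 per batch: 7 full batches plus one partial batch.
    print(default_epoch_size(1000, 128))
```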
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.predict"> |
| <code class="descname">predict</code><span class="sig-paren">(</span><em>X</em>, <em>num_batch=None</em>, <em>return_data=False</em>, <em>reset=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.predict"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.predict" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Run the prediction. Prediction always uses only one device.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>X</strong> (<em>mxnet.DataIter</em>) – </li> |
| <li><strong>num_batch</strong> (<em>int or None</em>) – The number of batches to run. Goes through all batches if <code class="docutils literal"><span class="pre">None</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>y</strong> – |
| The predicted value of the output.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.score"> |
| <code class="descname">score</code><span class="sig-paren">(</span><em>X</em>, <em>eval_metric='acc'</em>, <em>num_batch=None</em>, <em>batch_end_callback=None</em>, <em>reset=True</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.score"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.score" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Run the model given an input and calculate the score |
| as assessed by an evaluation metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>X</strong> (<em>mxnet.DataIter</em>) – </li> |
| <li><strong>eval_metric</strong> (<em>metric.metric</em>) – The metric for calculating score.</li> |
| <li><strong>num_batch</strong> (<em>int or None</em>) – The number of batches to run. Goes through all batches if <code class="docutils literal"><span class="pre">None</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>s</strong> – |
| The final score.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">float</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.fit"> |
| <code class="descname">fit</code><span class="sig-paren">(</span><em>X</em>, <em>y=None</em>, <em>eval_data=None</em>, <em>eval_metric='acc'</em>, <em>epoch_end_callback=None</em>, <em>batch_end_callback=None</em>, <em>kvstore='local'</em>, <em>logger=None</em>, <em>work_load_list=None</em>, <em>monitor=None</em>, <em>eval_end_callback=<mxnet.callback.logvalidationmetricscallback object=""></mxnet.callback.logvalidationmetricscallback></em>, <em>eval_batch_end_callback=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.fit"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.fit" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Fit the model.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>X</strong> (<em>DataIter, or numpy.ndarray/NDArray</em>) – Training data. If <cite>X</cite> is a <cite>DataIter</cite>, the name or (if name not available) |
| the position of its outputs should match the corresponding variable |
| names defined in the symbolic graph.</li> |
| <li><strong>y</strong> (<em>numpy.ndarray/NDArray, optional</em>) – Training set label. |
| If X is <code class="docutils literal"><span class="pre">numpy.ndarray</span></code> or <cite>NDArray</cite>, <cite>y</cite> is required to be set. |
| While y can be 1D or 2D (with 2nd dimension as 1), its first dimension must be |
| the same as <cite>X</cite>, i.e. the number of data points and labels should be equal.</li> |
| <li><strong>eval_data</strong> (<em>DataIter or numpy.ndarray/list/NDArray pair</em>) – If eval_data is numpy.ndarray/list/NDArray pair, |
| it should be <code class="docutils literal"><span class="pre">(valid_data,</span> <span class="pre">valid_label)</span></code>.</li> |
| <li><strong>eval_metric</strong> (<em>metric.EvalMetric or str or callable</em>) – The evaluation metric. This could be the name of evaluation metric |
| or a custom evaluation function that returns statistics |
| based on a minibatch.</li> |
| <li><strong>epoch_end_callback</strong> (<em>callable(epoch, symbol, arg_params, aux_states)</em>) – A callback that is invoked at end of each epoch. |
| This can be used to checkpoint model each epoch.</li> |
| <li><strong>batch_end_callback</strong> (<em>callable(epoch)</em>) – A callback that is invoked at end of each batch for purposes of printing.</li> |
| <li><strong>kvstore</strong> (<em>KVStore or str, optional</em>) – The KVStore or a string kvstore type: ‘local’, ‘dist_sync’, ‘dist_async’. |
| Defaults to ‘local’; there is often no need to change it on a single machine.</li> |
| <li><strong>logger</strong> (<em>logging logger, optional</em>) – When not specified, default logger will be used.</li> |
| <li><strong>work_load_list</strong> (<em>list of float or int, optional</em>) – The list of work load for different devices, |
| in the same order as <cite>ctx</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p>KVStore behavior:</p> |
| <ul class="last simple"> |
| <li>‘local’: multiple devices on a single machine; the best type is chosen automatically.</li> |
| <li>‘dist_sync’: multiple machines communicating via BSP (bulk synchronous parallel).</li> |
| <li>‘dist_async’: multiple machines with asynchronous communication.</li> |
| </ul> |
| </div> |
| </dd></dl> |
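| <p>The <code class="docutils literal"><span class="pre">epoch_end_callback</span></code> signature listed above, <code class="docutils literal"><span class="pre">callable(epoch,</span> <span class="pre">symbol,</span> <span class="pre">arg_params,</span> <span class="pre">aux_states)</span></code>, can be illustrated with a plain-Python wrapper that runs an action only every few epochs. This is a sketch: <code class="docutils literal"><span class="pre">make_periodic_callback</span></code> is a hypothetical name, the epoch index is assumed to be 0-based, and a list append stands in for the real checkpointing work.</p> |

```python
def make_periodic_callback(period, action):
    """Wrap `action` so that it only runs every `period` epochs.

    Hypothetical helper, not part of mxnet. The returned function has the
    epoch-end callback signature (epoch, symbol, arg_params, aux_states).
    """
    period = max(1, int(period))

    def _callback(epoch, symbol, arg_params, aux_states):
        # Assumes a 0-based epoch index, so epoch + 1 epochs have finished.
        if (epoch + 1) % period == 0:
            action(epoch + 1, symbol, arg_params, aux_states)

    return _callback


# Stand-in for real checkpointing: record which epochs would be saved.
saved = []
callback = make_periodic_callback(2, lambda e, sym, arg, aux: saved.append(e))
for epoch in range(5):
    callback(epoch, None, {}, {})
```

| <p>In real training, the action would typically save a checkpoint instead of appending to a list.</p> |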
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.save"> |
| <code class="descname">save</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch=None</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.save"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.save" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Save the model checkpoint to a file. |
| You can also use <cite>pickle</cite> to do the job if you only work in Python. |
| The advantage of <cite>load</cite> and <cite>save</cite> (as compared to <cite>pickle</cite>) is that |
| the resulting file can be loaded from other MXNet language bindings. |
| One can also directly <cite>load</cite>/<cite>save</cite> from/to cloud storage (S3, HDFS).</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li> |
| <li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li> |
| </ul> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.model.FeedForward.load"> |
| <em class="property">static </em><code class="descname">load</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em>, <em>ctx=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.load"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.load" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Load model checkpoint from file.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Other parameters for model, including <cite>num_epoch</cite>, optimizer and <cite>numpy_batch_size</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>model</strong> – |
| The loaded model that can be used for prediction.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="#mxnet.model.FeedForward" title="mxnet.model.FeedForward">FeedForward</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li>Symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li> |
| <li>Parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li> |
| </ul> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.model.FeedForward.create"> |
| <em class="property">static </em><code class="descname">create</code><span class="sig-paren">(</span><em>symbol</em>, <em>X</em>, <em>y=None</em>, <em>ctx=None</em>, <em>num_epoch=None</em>, <em>epoch_size=None</em>, <em>optimizer='sgd'</em>, <em>initializer=<mxnet.initializer.uniform object=""></mxnet.initializer.uniform></em>, <em>eval_data=None</em>, <em>eval_metric='acc'</em>, <em>epoch_end_callback=None</em>, <em>batch_end_callback=None</em>, <em>kvstore='local'</em>, <em>logger=None</em>, <em>work_load_list=None</em>, <em>eval_end_callback=<mxnet.callback.logvalidationmetricscallback object=""></mxnet.callback.logvalidationmetricscallback></em>, <em>eval_batch_end_callback=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/mxnet/model.html#FeedForward.create"><span class="viewcode-link">[source]</span></a><a class="headerlink" href="#mxnet.model.FeedForward.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Functional style to create a model. |
| This function is more consistent with functional |
| languages such as R, where mutation is not allowed.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol/symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The symbol configuration of a computation network.</li> |
| <li><strong>X</strong> (<a class="reference internal" href="io/io.html#mxnet.io.DataIter" title="mxnet.io.DataIter"><em>DataIter</em></a>) – Training data.</li> |
| <li><strong>y</strong> (<em>numpy.ndarray, optional</em>) – If <cite>X</cite> is a <code class="docutils literal"><span class="pre">numpy.ndarray</span></code>, <cite>y</cite> must be set.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction. |
| To use multi-GPU training, pass in a list of GPU contexts.</li> |
| <li><strong>num_epoch</strong> (<em>int, optional</em>) – The number of training epochs.</li> |
| <li><strong>epoch_size</strong> (<em>int, optional</em>) – Number of batches in an epoch. By default, it is set to |
| <code class="docutils literal"><span class="pre">ceil(num_train_examples</span> <span class="pre">/</span> <span class="pre">batch_size)</span></code>.</li> |
| <li><strong>optimizer</strong> (<em>str or Optimizer, optional</em>) – The name of the chosen optimizer, or an optimizer object, used for training.</li> |
| <li><strong>initializer</strong> (<em>initializer function, optional</em>) – The initialization scheme used.</li> |
| <li><strong>eval_data</strong> (<em>DataIter or numpy.ndarray pair</em>) – If <cite>eval_data</cite> is a <code class="docutils literal"><span class="pre">numpy.ndarray</span></code> pair, it should |
| be (<cite>valid_data</cite>, <cite>valid_label</cite>).</li> |
| <li><strong>eval_metric</strong> (<em>metric.EvalMetric or str or callable</em>) – The evaluation metric. Can be the name of an evaluation metric |
| or a custom evaluation function that returns statistics |
| based on a minibatch.</li> |
| <li><strong>epoch_end_callback</strong> (<em>callable(epoch, symbol, arg_params, aux_states)</em>) – A callback that is invoked at end of each epoch. |
| This can be used to checkpoint model each epoch.</li> |
| <li><strong>batch_end_callback</strong> (<em>callable(epoch)</em>) – A callback that is invoked at end of each batch for print purposes.</li> |
| <li><strong>kvstore</strong> (<em>KVStore or str, optional</em>) – The KVStore or a string kvstore type: ‘local’, ‘dist_sync’, ‘dist_async’. |
| Defaults to ‘local’; there is often no need to change it on a single machine.</li> |
| <li><strong>logger</strong> (<em>logging logger, optional</em>) – When not specified, default logger will be used.</li> |
| <li><strong>work_load_list</strong> (<em>list of float or int, optional</em>) – The list of work load for different devices, |
| in the same order as <cite>ctx</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <script>auto_index("model-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="next-steps"> |
| <span id="next-steps"></span><h2>Next Steps<a class="headerlink" href="#next-steps" title="Permalink to this headline">¶</a></h2> |
| <ul class="simple"> |
| <li>See <a class="reference internal" href="symbol/symbol.html"><em>Symbolic API</em></a> for operations on NDArrays that assemble neural networks from layers.</li> |
| <li>See <a class="reference internal" href="io/io.html"><em>IO Data Loading API</em></a> for parsing and loading data.</li> |
| <li>See <a class="reference internal" href="ndarray/ndarray.html"><em>NDArray API</em></a> for vector/matrix/tensor operations.</li> |
| <li>See <a class="reference internal" href="kvstore/kvstore.html"><em>KVStore API</em></a> for multi-GPU and multi-host distributed training.</li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">Model API</a><ul> |
| <li><a class="reference internal" href="#train-the-model">Train the Model</a></li> |
| <li><a class="reference internal" href="#save-the-model">Save the Model</a></li> |
| <li><a class="reference internal" href="#periodic-checkpointing">Periodic Checkpointing</a></li> |
| <li><a class="reference internal" href="#use-multiple-devices">Use Multiple Devices</a></li> |
| <li><a class="reference internal" href="#initializer-api-reference">Initializer API Reference</a></li> |
| <li><a class="reference internal" href="#evaluation-metric-api-reference">Evaluation Metric API Reference</a></li> |
| <li><a class="reference internal" href="#optimizer-api-reference">Optimizer API Reference</a></li> |
| <li><a class="reference internal" href="#model-api-reference">Model API Reference</a></li> |
| <li><a class="reference internal" href="#next-steps">Next Steps</a></li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div><div class="footer"> |
| <div class="section-disclaimer"> |
| <div class="container"> |
| <div> |
| <img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/> |
| <p> |
| Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p> |
| <p> |
| "Copyright © 2017-2018, The Apache Software Foundation |
| Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation." |
| </p> |
| </div> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| </div> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../../_static/js/search.js" type="text/javascript"></script> |
| <script src="../../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../../_static/js/copycode.js" type="text/javascript"></script> |
| <script src="../../_static/js/page.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </body> |
| </html> |