| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <title>Model API — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../../_static/basic.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/> |
| <!-- Defines the global DOCUMENTATION_OPTIONS settings object for this page; presumably read by the Sphinx helper scripts loaded below (doctools.js, searchtools_custom.js) — TODO confirm. URL_ROOT uses the same "../../" prefix as the _static assets referenced on this page. --> <script type="text/javascript"> |
| var DOCUMENTATION_OPTIONS = { |
| URL_ROOT: '../../', |
| VERSION: '', |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '' |
| }; |
| </script> |
| <script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script> |
| <script src="../../_static/underscore.js" type="text/javascript"></script> |
| <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../../_static/doctools.js" type="text/javascript"></script> |
| <script src="../../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <!-- Once the DOM is ready, asks Search to load its index from "/searchindex.js" and then initialises it. NOTE(review): the index path is site-absolute while other assets on this page use the relative "../../" prefix — confirm this is intended for the deployed root. --> <script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script> |
| <!-- Google Analytics (analytics.js) loader: installs a global ga() function that queues its arguments, inserts an async script element for analytics.js before the first script on the page, then creates tracker UA-96378503-1 and sends a pageview hit. --> <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new |
| Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| |
| </script> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/jquery.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/underscore.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../../_static/doctools.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> --> |
| <!-- --> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> |
| </head> |
| <body role="document"><!-- Previous Navbar Layout |
| <div class="navbar navbar-default navbar-fixed-top"> |
| <div class="container"> |
| <div class="navbar-header"> |
| <button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar"> |
| <span class="sr-only">Toggle navigation</span> |
| <span class="icon-bar"></span> |
| <span class="icon-bar"></span> |
| <span class="icon-bar"></span> |
| </button> |
| <a href="../../" class="navbar-brand"> |
| <img src="http://data.mxnet.io/theme/mxnet.png"> |
| </a> |
| </div> |
| <div id="navbar" class="navbar-collapse collapse"> |
| <ul id="navbar" class="navbar navbar-left"> |
| |
| <li> <a href="../../get_started/index.html">Get Started</a> </li> |
| |
| <li> <a href="../../tutorials/index.html">Tutorials</a> </li> |
| |
| <li> <a href="../../how_to/index.html">How To</a> </li> |
| |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a> |
| <ul class="dropdown-menu"> |
| |
| <li><a href="../../packages/python/index.html"> |
| Python |
| </a></li> |
| |
| <li><a href="../../packages/r/index.html"> |
| R |
| </a></li> |
| |
| <li><a href="../../packages/julia/index.html"> |
| Julia |
| </a></li> |
| |
| <li><a href="../../packages/c++/index.html"> |
| C++ |
| </a></li> |
| |
| <li><a href="../../packages/scala/index.html"> |
| Scala |
| </a></li> |
| |
| <li><a href="../../packages/perl/index.html"> |
| Perl |
| </a></li> |
| |
| </ul> |
| </li> |
| |
| <li> <a href="../../system/index.html">System</a> </li> |
| <li> |
| <form class="" role="search" action="../../search.html" method="get" autocomplete="off"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input type="text" name="q" class="form-control" placeholder="Search"> |
| </div> |
| <input type="hidden" name="check_keywords" value="yes" /> |
| <input type="hidden" name="area" value="default" /> |
| |
| </form> </li> |
| </ul> |
| <ul id="navbar" class="navbar navbar-right"> |
| <li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li> |
| <li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| Previous Navbar Layout End --> |
| <div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../../" id="logo"><img alt="mxnet" src="http://data.mxnet.io/theme/mxnet.png"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="../../get_started/install.html">Install</a> |
| <a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a> |
| <a class="main-nav-link" href="../../how_to/index.html">How To</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li> |
| <li><a class="main-nav-link" href="../../api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li> |
| </ul> |
| </span> |
| <a class="main-nav-link" href="../../architecture/index.html">Architecture</a> |
| <!-- <a class="main-nav-link" href="../../community/index.html">Community</a> --> |
| <a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></span></nav> |
| <!-- getRootPath() returns the constant string "../../"; presumably the relative path from this page to the documentation root, for use by external scripts — TODO confirm callers. --> <script> function getRootPath(){ return "../../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu dropdown-menu-right" id="burgerMenu"> |
| <li><a href="../../get_started/install.html">Install</a></li> |
| <li><a href="../../tutorials/index.html">Tutorials</a></li> |
| <li><a href="../../how_to/index.html">How To</a></li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../../api/python/index.html" tabindex="-1">Python</a> |
| </li> |
| <li><a href="../../api/scala/index.html" tabindex="-1">Scala</a> |
| </li> |
| <li><a href="../../api/r/index.html" tabindex="-1">R</a> |
| </li> |
| <li><a href="../../api/julia/index.html" tabindex="-1">Julia</a> |
| </li> |
| <li><a href="../../api/c++/index.html" tabindex="-1">C++</a> |
| </li> |
| <li><a href="../../api/perl/index.html" tabindex="-1">Perl</a> |
| </li> |
| </ul> |
| </li> |
| <li><a href="../../architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"/> |
| <input name="area" type="hidden" value="default"/> |
| </form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| <!-- <div id="lang-select-wrap"> --> |
| <!-- <label id="lang-select-label"> --> |
| <!-- <\!-- <i class="fa fa-globe"></i> -\-> --> |
| <!-- <span></span> --> |
| <!-- </label> --> |
| <!-- <select id="lang-select"> --> |
| <!-- <option value="en">Eng</option> --> |
| <!-- <option value="zh">中文</option> --> |
| <!-- </select> --> |
| <!-- </div> --> |
| <!-- <a id="mobile-nav-toggle"> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| </a> --> |
| </div> |
| </div> |
| </div> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="index.html">Python Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../r/index.html">R Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../julia/index.html">Julia Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../c++/index.html">C++ Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../scala/index.html">Scala Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../perl/index.html">Perl Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="section" id="model-api"> |
| <span id="model-api"></span><h1>Model API<a class="headerlink" href="#model-api" title="Permalink to this headline">¶</a></h1> |
| <p>The model API provides a simplified way to train neural networks using common best practices. |
| It’s a thin wrapper built on top of the <a class="reference internal" href="ndarray.html"><em>ndarray</em></a> and <a class="reference internal" href="symbol.html"><em>symbolic</em></a> |
| modules that make neural network training easy.</p> |
| <p>Topics:</p> |
| <ul class="simple"> |
| <li><a class="reference external" href="#train-the-model">Train the Model</a></li> |
| <li><a class="reference external" href="#save-the-model">Save the Model</a></li> |
| <li><a class="reference external" href="#periodic-checkpointing">Periodic Checkpointing</a></li> |
| <li><a class="reference external" href="#initializer-api-reference">Initializer API Reference</a></li> |
| <li><a class="reference external" href="#evaluation-metric-api-reference">Evaluation Metric API Reference</a></li> |
| <li><a class="reference external" href="#optimizer-api-reference">Optimizer API Reference</a></li> |
| <li><a class="reference external" href="#model-api-reference">Model API Reference</a></li> |
| </ul> |
| <div class="section" id="train-the-model"> |
| <span id="train-the-model"></span><h2>Train the Model<a class="headerlink" href="#train-the-model" title="Permalink to this headline">¶</a></h2> |
| <p>To train a model, perform two steps: configure the model using the symbol parameter, |
| then call <code class="docutils literal"><span class="pre">model.FeedForward.create</span></code> to create the model. |
| The following example creates a two-layer neural network.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="c1"># configure a two-layer neural network</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span> |
| <span class="n">fc1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc1'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span> |
| <span class="n">act1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">fc1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'relu1'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span> |
| <span class="n">fc2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">act1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc2'</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span> |
| <span class="n">softmax</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">fc2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'sm'</span><span class="p">)</span> |
| <span class="c1"># create a model</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">num_epoch</span><span class="o">=</span><span class="n">num_epoch</span><span class="p">,</span> |
| <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>You can also use the <code class="docutils literal"><span class="pre">scikit-learn-style</span></code> construct and <code class="docutils literal"><span class="pre">fit</span></code> function to create a model.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="c1"># create a model using sklearn-style two-step way</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">num_epoch</span><span class="o">=</span><span class="n">num_epoch</span><span class="p">,</span> |
| <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span> |
| |
| <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>For more information, see <a class="reference external" href="#model-api-reference">Model API Reference</a>.</p> |
| </div> |
| <div class="section" id="save-the-model"> |
| <span id="save-the-model"></span><h2>Save the Model<a class="headerlink" href="#save-the-model" title="Permalink to this headline">¶</a></h2> |
| <p>After the job is done, save your work. |
| To save the model, you can directly pickle it with Python. |
| We also provide <code class="docutils literal"><span class="pre">save</span></code> and <code class="docutils literal"><span class="pre">load</span></code> functions.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="c1"># save a model to mymodel-symbol.json and mymodel-0100.params</span> |
| <span class="n">prefix</span> <span class="o">=</span> <span class="s1">'mymodel'</span> |
| <span class="n">iteration</span> <span class="o">=</span> <span class="mi">100</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span> |
| |
| <span class="c1"># load model back</span> |
| <span class="n">model_loaded</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The advantage of these two <code class="docutils literal"><span class="pre">save</span></code> and <code class="docutils literal"><span class="pre">load</span></code> functions is that they are language agnostic. |
| You should be able to save and load directly into cloud storage, such as Amazon S3 and HDFS.</p> |
| </div> |
| <div class="section" id="periodic-checkpointing"> |
| <span id="periodic-checkpointing"></span><h2>Periodic Checkpointing<a class="headerlink" href="#periodic-checkpointing" title="Permalink to this headline">¶</a></h2> |
| <p>We recommend checkpointing your model after each iteration. |
| To do this, add a checkpoint callback <code class="docutils literal"><span class="pre">do_checkpoint(path)</span></code> to the function. |
| The training process automatically saves a checkpoint to the specified location after |
| each iteration.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="n">prefix</span><span class="o">=</span><span class="s1">'models/chkpt'</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">iter_end_callback</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">do_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">),</span> |
| <span class="o">...</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>You can load the model checkpoint later using <code class="docutils literal"><span class="pre">FeedForward.load</span></code>.</p> |
| </div> |
| <div class="section" id="use-multiple-devices"> |
| <span id="use-multiple-devices"></span><h2>Use Multiple Devices<a class="headerlink" href="#use-multiple-devices" title="Permalink to this headline">¶</a></h2> |
| <p>Set <code class="docutils literal"><span class="pre">ctx</span></code> to the list of devices that you want to train on.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span> <span class="n">devices</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_device</span><span class="p">)]</span> |
| <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="n">softmax</span><span class="p">,</span> |
| <span class="n">X</span><span class="o">=</span><span class="n">data_set</span><span class="p">,</span> |
| <span class="n">ctx</span><span class="o">=</span><span class="n">devices</span><span class="p">,</span> |
| <span class="o">...</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>Training occurs in parallel on the GPUs that you specify.</p> |
| <blockquote> |
| <div><script src="../../_static/js/auto_module_index.js" type="text/javascript"></script></div></blockquote> |
| </div> |
| <div class="section" id="initializer-api-reference"> |
| <span id="initializer-api-reference"></span><h2>Initializer API Reference<a class="headerlink" href="#initializer-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.initializer"></span><p>Weight initializer.</p> |
| <dl class="class"> |
| <dt id="mxnet.initializer.InitDesc"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">InitDesc</code><a class="headerlink" href="#mxnet.initializer.InitDesc" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Descriptor for the initialization pattern.</p> |
| <dl class="docutils"> |
| <dt>name <span class="classifier-delimiter">:</span> <span class="classifier">str</span></dt> |
| <dd>Name of variable.</dd> |
| <dt>attrs <span class="classifier-delimiter">:</span> <span class="classifier">dict of str to str</span></dt> |
| <dd>Attributes of this variable taken from <code class="docutils literal"><span class="pre">Symbol.attr_dict</span></code>.</dd> |
| <dt>global_init <span class="classifier-delimiter">:</span> <span class="classifier">Initializer</span></dt> |
| <dd>Global initializer to fallback to.</dd> |
| </dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Initializer"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Initializer</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class of an initializer.</p> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.set_verbosity"> |
| <code class="descname">set_verbosity</code><span class="sig-paren">(</span><em>verbose=False</em>, <em>print_func=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.set_verbosity" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Switch on/off verbose mode</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>verbose</strong> (<em>bool</em>) – switch on/off verbose mode</li> |
| <li><strong>print_func</strong> (<em>function</em>) – A function that computes statistics of initialized arrays. |
| Takes an <cite>NDArray</cite> and returns an <cite>str</cite>. Defaults to mean |
| absolute value <code class="docutils literal"><span class="pre">str((|x|/size(x)).asscalar())</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.dumps"> |
| <code class="descname">dumps</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.dumps" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Saves the initializer to string</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">JSON formatted string that describes the initializer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">str</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create initializer and retrieve its parameters</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["normal", {"sigma": 0.5}]'</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Xavier</span><span class="p">(</span><span class="n">factor_type</span><span class="o">=</span><span class="s2">"in"</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mf">2.34</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">init</span><span class="o">.</span><span class="n">dumps</span><span class="p">()</span> |
| <span class="go">'["xavier", {"rnd_type": "uniform", "magnitude": 2.34, "factor_type": "in"}]'</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.initializer.Initializer.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>desc</em>, <em>arr</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Initializer.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize an array</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>desc</strong> (<a class="reference internal" href="optimization.html#mxnet.initializer.InitDesc" title="mxnet.initializer.InitDesc"><em>InitDesc</em></a>) – Initialization pattern descriptor.</li> |
| <li><strong>arr</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The array to be initialized.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.initializer.register"> |
| <code class="descclassname">mxnet.initializer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a custom initializer.</p> |
| <p>Custom initializers can be created by extending <cite>mx.init.Initializer</cite> and implementing the |
| required functions like <cite>_init_weight</cite> and <cite>_init_bias</cite>. The created initializer must be |
| registered using <cite>mx.init.register</cite> before it can be called by name.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>klass</strong> (<em>class</em>) – A subclass of <cite>mx.init.Initializer</cite> that needs to be registered as a custom initializer.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Create and register a custom initializer that</span> |
| <span class="gp">... </span><span class="c1"># initializes weights to 0.1 and biases to 1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="nd">@mx.init.register</span> |
| <span class="gp">... </span><span class="nd">@alias</span><span class="p">(</span><span class="s1">'myinit'</span><span class="p">)</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">CustomInit</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Initializer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="nb">super</span><span class="p">(</span><span class="n">CustomInit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_weight</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mf">0.1</span> |
| <span class="gp">... </span> <span class="k">def</span> <span class="nf">_init_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">1</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="c1"># Module is an instance of 'mxnet.module.Module'</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="s2">"custominit"</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params("myinit")</span> |
| <span class="gp">>>> </span><span class="c1"># module.init_params(CustomInit())</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Load"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Load</code><span class="sig-paren">(</span><em>param</em>, <em>default_init=None</em>, <em>verbose=False</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Load" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes variables by loading data from file or dict.</p> |
| <p><strong>Note</strong> Load will drop <code class="docutils literal"><span class="pre">arg:</span></code> or <code class="docutils literal"><span class="pre">aux:</span></code> from name and |
| initialize the variables that match with the prefix dropped.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>param</strong> (<em>str or dict of str->`NDArray`</em>) – Parameter file or dict mapping name to NDArray.</li> |
| <li><strong>default_init</strong> (<a class="reference internal" href="optimization.html#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Default initializer when name is not found in <cite>param</cite>.</li> |
| <li><strong>verbose</strong> (<em>bool</em>) – Flag for enabling logging of source when initializing.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Mixed"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Mixed</code><span class="sig-paren">(</span><em>patterns</em>, <em>initializers</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Mixed" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize parameters using multiple initializers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>patterns</strong> (<em>list of str</em>) – List of regular expressions matching parameter names.</li> |
| <li><strong>initializers</strong> (<em>list of Initializer</em>) – List of initializers corresponding to <cite>patterns</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize biases to zero</span> |
| <span class="gp">... </span><span class="c1"># and every other parameter to random values with uniform distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Mixed</span><span class="p">([</span><span class="s1">'bias'</span><span class="p">,</span> <span class="s1">'.*'</span><span class="p">],</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Zero</span><span class="p">(),</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)])</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="go">>>></span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected1_weight</span> |
| <span class="go">[[ 0.0097627 0.01856892 0.04303787]]</span> |
| <span class="go">fullyconnected1_bias</span> |
| <span class="go">[ 0.]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Zero"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Zero</code><a class="headerlink" href="#mxnet.initializer.Zero" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to zero.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to zero.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">Zero</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0. 0. 0.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.One"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">One</code><a class="headerlink" href="#mxnet.initializer.One" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights to one.</p> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights to one.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">initializer</span><span class="o">.</span><span class="n">One</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 1. 1. 1.]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Constant"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Constant</code><span class="sig-paren">(</span><em>value</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Constant" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes the weights to a scalar value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>value</strong> (<em>float</em>) – Fill value.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Uniform"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Uniform</code><span class="sig-paren">(</span><em>scale=0.07</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Uniform" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values uniformly sampled from a given range.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>scale</strong> (<em>float, optional</em>) – The bound on the range of the generated random values. |
| Values are generated from the range [-<cite>scale</cite>, <cite>scale</cite>]. |
| Default scale is 0.07.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values uniformly sampled between -0.1 and 0.1.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[ 0.01360891 -0.02144304 0.08511933]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Normal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Normal</code><span class="sig-paren">(</span><em>sigma=0.01</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Normal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initializes weights with random values sampled from a normal distribution |
| with a mean of zero and standard deviation of <cite>sigma</cite>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>sigma</strong> (<em>float, optional</em>) – Standard deviation of the normal distribution. |
| Default standard deviation is 0.01.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Given 'module', an instance of 'mxnet.module.Module', initialize weights</span> |
| <span class="gp">>>> </span><span class="c1"># to random values sampled from a normal distribution.</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">init</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">module</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">init</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">get_params</span><span class="p">():</span> |
| <span class="gp">... </span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> |
| <span class="gp">... </span> <span class="k">print</span><span class="p">(</span><span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span> |
| <span class="gp">...</span> |
| <span class="go">fullyconnected0_weight</span> |
| <span class="go">[[-0.3214761 -0.12660924 0.53789419]]</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Orthogonal"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Orthogonal</code><span class="sig-paren">(</span><em>scale=1.414</em>, <em>rand_type='uniform'</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Orthogonal" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize weight as orthogonal matrix.</p> |
| <p>This initializer implements <em>Exact solutions to the nonlinear dynamics of |
| learning in deep linear neural networks</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1312.6120">https://arxiv.org/abs/1312.6120</a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>scale</strong> (<em>float, optional</em>) – Scaling factor of weight.</li> |
| <li><strong>rand_type</strong> (<em>str, optional</em>) – Use “uniform” or “normal” random numbers to initialize the weight.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Xavier"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Xavier</code><span class="sig-paren">(</span><em>rnd_type='uniform'</em>, <em>factor_type='avg'</em>, <em>magnitude=3</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.Xavier" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns an initializer performing “Xavier” initialization for weights.</p> |
| <p>This initializer is designed to keep the scale of gradients roughly the same |
| in all layers.</p> |
| <p>By default, <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with random numbers in the range |
| of <span class="math">\([-c, c]\)</span>, where <span class="math">\(c = \sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>. |
| <span class="math">\(n_{in}\)</span> is the number of neurons feeding into weights, and <span class="math">\(n_{out}\)</span> is |
| the number of neurons the result is fed to.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'uniform'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'in'</span></code>, |
| the <span class="math">\(c = \sqrt{\frac{3.}{n_{in}}}\)</span>. |
| Similarly when <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'out'</span></code>, the <span class="math">\(c = \sqrt{\frac{3.}{n_{out}}}\)</span>.</p> |
| <p>If <cite>rnd_type</cite> is <code class="docutils literal"><span class="pre">'gaussian'</span></code> and <cite>factor_type</cite> is <code class="docutils literal"><span class="pre">'avg'</span></code>, |
| the initializer fills the weights with numbers from normal distribution with |
| a standard deviation of <span class="math">\(\sqrt{\frac{3.}{0.5 * (n_{in} + n_{out})}}\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rnd_type</strong> (<em>str, optional</em>) – Random generator type, can be <code class="docutils literal"><span class="pre">'gaussian'</span></code> or <code class="docutils literal"><span class="pre">'uniform'</span></code>.</li> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>magnitude</strong> (<em>float, optional</em>) – Scale of random number.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.MSRAPrelu"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">MSRAPrelu</code><span class="sig-paren">(</span><em>factor_type='avg'</em>, <em>slope=0.25</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.MSRAPrelu" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize the weight according to a MSRA paper.</p> |
| <p>This initializer implements <em>Delving Deep into Rectifiers: Surpassing |
| Human-Level Performance on ImageNet Classification</em>, available at |
| <a class="reference external" href="https://arxiv.org/abs/1502.01852">https://arxiv.org/abs/1502.01852</a>.</p> |
| <p>This initializer is proposed for initialization related to ReLU activation; |
| it makes some changes on top of the Xavier method.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>factor_type</strong> (<em>str, optional</em>) – Can be <code class="docutils literal"><span class="pre">'avg'</span></code>, <code class="docutils literal"><span class="pre">'in'</span></code>, or <code class="docutils literal"><span class="pre">'out'</span></code>.</li> |
| <li><strong>slope</strong> (<em>float, optional</em>) – initial slope of any PReLU (or similar) nonlinearities.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.Bilinear"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">Bilinear</code><a class="headerlink" href="#mxnet.initializer.Bilinear" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize weight for upsampling layers.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.LSTMBias"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">LSTMBias</code><span class="sig-paren">(</span><em>forget_bias</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.LSTMBias" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize all biases of an LSTMCell to 0.0 except for |
| the forget gate, whose bias is set to a custom value.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>forget_bias</strong> (<em>float</em>) – Bias for the forget gate. Jozefowicz et al. 2015 recommends setting this to 1.0.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.initializer.FusedRNN"> |
| <em class="property">class </em><code class="descclassname">mxnet.initializer.</code><code class="descname">FusedRNN</code><span class="sig-paren">(</span><em>init</em>, <em>num_hidden</em>, <em>num_layers</em>, <em>mode</em>, <em>bidirectional=False</em>, <em>forget_bias=1.0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.initializer.FusedRNN" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Initialize parameters for fused rnn layers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>init</strong> (<a class="reference internal" href="optimization.html#mxnet.initializer.Initializer" title="mxnet.initializer.Initializer"><em>Initializer</em></a>) – Initializer applied to unpacked weights. Falls back to the global |
| initializer if None.</li> |
| <li><strong>num_hidden</strong> (<em>int</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>num_layers</strong> (<em>int</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>mode</strong> (<em>str</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>bidirectional</strong> (<em>bool</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| <li><strong>forget_bias</strong> (<em>float</em>) – Should be the same as the argument passed to FusedRNNCell.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <script>auto_index("initializer-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="evaluation-metric-api-reference"> |
| <span id="evaluation-metric-api-reference"></span><h2>Evaluation Metric API Reference<a class="headerlink" href="#evaluation-metric-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.metric"></span><p>Online evaluation metric module.</p> |
| <dl class="class"> |
| <dt id="mxnet.metric.EvalMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">EvalMetric</code><span class="sig-paren">(</span><em>name</em>, <em>output_names=None</em>, <em>label_names=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Base class for all evaluation metrics.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">This is a base class that provides common metric interfaces. |
| One should not use this class directly, but instead create new metric |
| classes that extend it.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get_config"> |
| <code class="descname">get_config</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.get_config" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Save configurations of metric. Can be recreated |
| from configs with <code class="docutils literal"><span class="pre">metric.create(**config)</span></code>.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.update_dict"> |
| <code class="descname">update_dict</code><span class="sig-paren">(</span><em>label</em>, <em>pred</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.update_dict" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Update the internal evaluation with named label and pred.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>label</strong> (<em>OrderedDict of str -> NDArray</em>) – Name to array mapping for labels.</li> |
| <li><strong>pred</strong> (<em>OrderedDict of str -> NDArray</em>) – Name to array mapping of predicted outputs.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the internal evaluation result to initial state.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gets the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><ul class="simple"> |
| <li><strong>names</strong> (<em>list of str</em>) – |
| Name of the metrics.</li> |
| <li><strong>values</strong> (<em>list of float</em>) – |
| Value of the evaluations.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.EvalMetric.get_name_value"> |
| <code class="descname">get_name_value</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.EvalMetric.get_name_value" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns zipped name and value pairs.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">A (name, value) tuple list.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">list of tuples</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.metric.create"> |
| <code class="descclassname">mxnet.metric.</code><code class="descname">create</code><span class="sig-paren">(</span><em>metric</em>, <em>*args</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates an evaluation metric from metric names, instances of EvalMetric, |
| or a custom metric function.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>metric</strong> (<em>str or callable</em>) – <p>Specifies the metric to create. |
| This argument must be one of the below:</p> |
| <ul> |
| <li>Name of a metric.</li> |
| <li>An instance of <cite>EvalMetric</cite>.</li> |
| <li>A list, each element of which is a metric or a metric name.</li> |
| <li>An evaluation function that computes a custom metric for a given batch of |
| labels and predictions.</li> |
| </ul> |
| </li> |
| <li><strong>*args</strong> – <p>Additional arguments to metric constructor. |
| Only used when metric is str.</p> |
| </li> |
| <li><strong>**kwargs</strong> – <p>Additional arguments to metric constructor. |
| Only used when metric is str.</p> |
| </li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">custom_metric</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">label</span> <span class="o">-</span> <span class="n">pred</span><span class="p">))</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">metric1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'acc'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">metric2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">custom_metric</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">metric3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">([</span><span class="n">metric1</span><span class="p">,</span> <span class="n">metric2</span><span class="p">,</span> <span class="s1">'rmse'</span><span class="p">])</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CompositeEvalMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CompositeEvalMetric</code><span class="sig-paren">(</span><em>metrics=None</em>, <em>name='composite'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Manages multiple evaluation metrics.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>metrics</strong> (<em>list of EvalMetric</em>) – List of child metrics.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics_1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics_2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">F1</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CompositeEvalMetric</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="k">for</span> <span class="n">child_metric</span> <span class="ow">in</span> <span class="p">[</span><span class="n">eval_metrics_1</span><span class="p">,</span> <span class="n">eval_metrics_2</span><span class="p">]:</span> |
| <span class="gp">... </span> <span class="n">eval_metrics</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">child_metric</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">eval_metrics</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">(['accuracy', 'f1'], [0.6666666666666666, 0.8])</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.add"> |
| <code class="descname">add</code><span class="sig-paren">(</span><em>metric</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.add" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Adds a child metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>metric</strong> – A metric instance.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.get_metric"> |
| <code class="descname">get_metric</code><span class="sig-paren">(</span><em>index</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.get_metric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns a child metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>index</strong> (<em>int</em>) – Index of child metric in the list of metrics.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.reset"> |
| <code class="descname">reset</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.reset" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Resets the internal evaluation result to initial state.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.CompositeEvalMetric.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CompositeEvalMetric.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body"><ul class="simple"> |
| <li><strong>names</strong> (<em>list of str</em>) – |
| Name of the metrics.</li> |
| <li><strong>values</strong> (<em>list of float</em>) – |
| Value of the evaluations.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Accuracy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Accuracy</code><span class="sig-paren">(</span><em>axis=1</em>, <em>name='accuracy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Accuracy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes accuracy classification score.</p> |
| <p>The accuracy score is defined as</p> |
| <div class="math"> |
| \[\text{accuracy}(y, \hat{y}) = \frac{1}{n} \sum_{i=0}^{n-1} |
| \text{1}(\hat{y}_i == y_i)\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>axis</strong> (<em>int, default=1</em>) – The axis that represents classes</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">acc</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('accuracy', 0.6666666666666666)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.Accuracy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Accuracy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.TopKAccuracy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">TopKAccuracy</code><span class="sig-paren">(</span><em>top_k=1</em>, <em>name='top_k_accuracy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.TopKAccuracy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes top k predictions accuracy.</p> |
| <p><cite>TopKAccuracy</cite> differs from Accuracy in that it considers the prediction |
| to be <code class="docutils literal"><span class="pre">True</span></code> as long as the ground truth label is in the top K |
| predicted labels.</p> |
| <p>If <cite>top_k</cite> = <code class="docutils literal"><span class="pre">1</span></code>, then <cite>TopKAccuracy</cite> is identical to <cite>Accuracy</cite>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>top_k</strong> (<em>int</em>) – Number of top predictions to check for the ground truth label.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">999</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">top_k</span> <span class="o">=</span> <span class="mi">3</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">6</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">TopKAccuracy</span><span class="p">(</span><span class="n">top_k</span><span class="o">=</span><span class="n">top_k</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">acc</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('top_k_accuracy', 0.3)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.TopKAccuracy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.TopKAccuracy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.F1"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">F1</code><span class="sig-paren">(</span><em>name='f1'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.F1" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes the F1 score of a binary classification problem.</p> |
| <p>The F1 score is the harmonic mean of the precision and recall, |
| where the best value is 1.0 and the worst value is 0.0. The formula for F1 score is:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">F1</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">precision</span> <span class="o">*</span> <span class="n">recall</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">precision</span> <span class="o">+</span> <span class="n">recall</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>The formula for precision and recall is:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">precision</span> <span class="o">=</span> <span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="n">true_positives</span> <span class="o">+</span> <span class="n">false_positives</span><span class="p">)</span> |
| <span class="n">recall</span> <span class="o">=</span> <span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="n">true_positives</span> <span class="o">+</span> <span class="n">false_negatives</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">This F1 score only supports binary classification.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">acc</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">F1</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">acc</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">acc</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('f1', 0.8)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.F1.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.F1.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Perplexity"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Perplexity</code><span class="sig-paren">(</span><em>ignore_label</em>, <em>axis=-1</em>, <em>name='perplexity'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Perplexity" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes perplexity.</p> |
| <p>Perplexity is a measurement of how well a probability distribution |
| or model predicts a sample. A low perplexity indicates the model |
| is good at predicting the sample.</p> |
| <p>The perplexity of a model q is defined as</p> |
| <div class="math"> |
| \[b^{\big(-\frac{1}{N} \sum_{i=1}^N \log_b q(x_i) \big)} |
| = \exp \big(-\frac{1}{N} \sum_{i=1}^N \log q(x_i)\big)\]</div> |
| <p>where we let <cite>b = e</cite>.</p> |
| <p><span class="math">\(q(x_i)\)</span> is the predicted value of its ground truth |
| label on sample <span class="math">\(x_i\)</span>.</p> |
| <p>For example, we have three samples <span class="math">\(x_1, x_2, x_3\)</span> and their labels |
| are <span class="math">\([0, 1, 1]\)</span>. |
| Suppose our model predicts <span class="math">\(q(x_1) = p(y_1 = 0 | x_1) = 0.3\)</span> |
| and <span class="math">\(q(x_2) = 1.0\)</span>, |
| <span class="math">\(q(x_3) = 0.6\)</span>. The perplexity of model q is |
| <span class="math">\(\exp\big(-(\log 0.3 + \log 1.0 + \log 0.6) / 3\big) = 1.77109762852\)</span>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>ignore_label</strong> (<em>int or None</em>) – Index of invalid label to ignore when |
| counting. By default, sets to -1. |
| If set to <cite>None</cite>, it will include all entries.</li> |
| <li><strong>axis</strong> (<em>int (default -1)</em>) – The axis from prediction that was used to |
| compute softmax. By default use the last |
| axis.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">perp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Perplexity</span><span class="p">(</span><span class="n">ignore_label</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">perp</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">perp</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('perplexity', 1.7710976285155853)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.Perplexity.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Perplexity.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.metric.Perplexity.get"> |
| <code class="descname">get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Perplexity.get" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns the current evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Returns:</th><td class="field-body">Representing name of the metric and evaluation result.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Return type:</th><td class="field-body">Tuple of (str, float)</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.MAE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">MAE</code><span class="sig-paren">(</span><em>name='mae'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.MAE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Mean Absolute Error (MAE) loss.</p> |
| <p>The mean absolute error is given by</p> |
| <div class="math"> |
| \[\frac{\sum_i^n |y_i - \hat{y}_i|}{n}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">mean_absolute_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">MAE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">mean_absolute_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">mean_absolute_error</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('mae', 0.5)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.MAE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.MAE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.MSE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">MSE</code><span class="sig-paren">(</span><em>name='mse'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.MSE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Mean Squared Error (MSE) loss.</p> |
| <p>The mean squared error is given by</p> |
| <div class="math"> |
| \[\frac{\sum_i^n (y_i - \hat{y}_i)^2}{n}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">mean_squared_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">MSE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">mean_squared_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">mean_squared_error</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('mse', 0.375)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.MSE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.MSE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.RMSE"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">RMSE</code><span class="sig-paren">(</span><em>name='rmse'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.RMSE" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Root Mean Squared Error (RMSE) loss.</p> |
| <p>The root mean squared error is given by</p> |
| <div class="math"> |
| \[\sqrt{\frac{\sum_i^n (y_i - \hat{y}_i)^2}{n}}\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">root_mean_squared_error</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">RMSE</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">root_mean_squared_error</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">root_mean_squared_error</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('rmse', 0.612372457981)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.RMSE.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.RMSE.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CrossEntropy"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CrossEntropy</code><span class="sig-paren">(</span><em>eps=1e-08</em>, <em>name='cross-entropy'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CrossEntropy" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes Cross Entropy loss.</p> |
| <p>The cross entropy is given by</p> |
| <div class="math"> |
| \[-\left(y\log \hat{y} + (1-y)\log (1-\hat{y})\right)\]</div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>eps</strong> (<em>float</em>) – Cross Entropy loss is undefined when the predicted value is 0 or 1, |
| so a small constant is added to the predicted values.</li> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]])]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])]</span> |
| <span class="gp">>>> </span><span class="n">ce</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CrossEntropy</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">ce</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">ce</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('cross-entropy', 0.57159948348999023)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.CrossEntropy.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CrossEntropy.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Loss"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Loss</code><span class="sig-paren">(</span><em>name='loss'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Loss" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for directly printing loss.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Torch"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Torch</code><span class="sig-paren">(</span><em>name='torch'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Torch" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for torch criterions.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.Caffe"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">Caffe</code><span class="sig-paren">(</span><em>name='caffe'</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.Caffe" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Dummy metric for caffe criterions.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.metric.CustomMetric"> |
| <em class="property">class </em><code class="descclassname">mxnet.metric.</code><code class="descname">CustomMetric</code><span class="sig-paren">(</span><em>feval</em>, <em>name=None</em>, <em>allow_extra_outputs=False</em>, <em>output_names=None</em>, <em>label_names=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CustomMetric" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Computes a customized evaluation metric.</p> |
| <p>The <cite>feval</cite> function can return a <cite>tuple</cite> of (sum_metric, num_inst) or return |
| an <cite>int</cite> sum_metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>feval</strong> (<em>callable(label, pred)</em>) – Customized evaluation function.</li> |
| <li><strong>name</strong> (<em>str</em>) – The name of the metric. (the default is None).</li> |
| <li><strong>allow_extra_outputs</strong> (<em>bool, optional</em>) – If true, the prediction outputs can have extra outputs. |
| This is useful in RNN, where the states are also produced |
| in outputs for forwarding. (the default is False).</li> |
| <li><strong>name</strong> – Name of this metric instance for display.</li> |
| <li><strong>output_names</strong> (<em>list of str, or None</em>) – Name of predictions that should be used when updating with update_dict. |
| By default include all predictions.</li> |
| <li><strong>label_names</strong> (<em>list of str, or None</em>) – Name of labels that should be used when updating with update_dict. |
| By default include all labels.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">predicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">))]</span> |
| <span class="gp">>>> </span><span class="n">feval</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="p">:</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">CustomMetric</span><span class="p">(</span><span class="n">feval</span><span class="o">=</span><span class="n">feval</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="n">eval_metrics</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predicts</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span> <span class="n">eval_metrics</span><span class="o">.</span><span class="n">get</span><span class="p">()</span> |
| <span class="go">('custom(&lt;lambda&gt;)', 6.0)</span> |
| </pre></div> |
| </div> |
| <dl class="method"> |
| <dt id="mxnet.metric.CustomMetric.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>labels</em>, <em>preds</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.CustomMetric.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the internal evaluation result.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>labels</strong> (list of <cite>NDArray</cite>) – The labels of the data.</li> |
| <li><strong>preds</strong> (list of <cite>NDArray</cite>) – Predicted values.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.metric.np"> |
| <code class="descclassname">mxnet.metric.</code><code class="descname">np</code><span class="sig-paren">(</span><em>numpy_feval</em>, <em>name=None</em>, <em>allow_extra_outputs=False</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.metric.np" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates a custom evaluation metric that receives its inputs as numpy arrays.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>numpy_feval</strong> (<em>callable(label, pred)</em>) – Custom evaluation function that receives labels and predictions for a minibatch |
| as numpy arrays and returns the corresponding custom metric as a floating point number.</li> |
| <li><strong>name</strong> (<em>str, optional</em>) – Name of the custom metric.</li> |
| <li><strong>allow_extra_outputs</strong> (<em>bool, optional</em>) – Whether prediction output is allowed to have extra outputs. This is useful in cases |
| like RNN where states are also part of output which can then be fed back to the RNN |
| in the next step. By default, extra outputs are not allowed.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">Custom metric corresponding to the provided labels and predictions.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">float</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Example</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">custom_metric</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">label</span><span class="o">-</span><span class="n">pred</span><span class="p">))</span> |
| <span class="gp">...</span> |
| <span class="gp">>>> </span><span class="n">metric</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">np</span><span class="p">(</span><span class="n">custom_metric</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <script>auto_index("evaluation-metric-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="optimizer-api-reference"> |
| <span id="optimizer-api-reference"></span><h2>Optimizer API Reference<a class="headerlink" href="#optimizer-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.optimizer"></span><p>Weight updating functions.</p> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Optimizer"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Optimizer</code><span class="sig-paren">(</span><em>rescale_grad=1.0</em>, <em>param_idx2name=None</em>, <em>wd=0.0</em>, <em>clip_gradient=None</em>, <em>learning_rate=0.01</em>, <em>lr_scheduler=None</em>, <em>sym=None</em>, <em>begin_num_update=0</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The base class inherited by all optimizers.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rescale_grad</strong> (<em>float, optional</em>) – Multiply the gradient with <cite>rescale_grad</cite> before updating. Often |
| chosen to be <code class="docutils literal"><span class="pre">1.0/batch_size</span></code>.</li> |
| <li><strong>param_idx2name</strong> (<em>dict from int to string, optional</em>) – A dictionary that maps int index to string name.</li> |
| <li><strong>clip_gradient</strong> (<em>float, optional</em>) – Clip the gradient by projecting onto the box <code class="docutils literal"><span class="pre">[-clip_gradient,</span> <span class="pre">clip_gradient]</span></code>.</li> |
| <li><strong>learning_rate</strong> (<em>float, optional</em>) – The initial learning rate.</li> |
| <li><strong>lr_scheduler</strong> (<em>LRScheduler, optional</em>) – The learning rate scheduler.</li> |
| <li><strong>wd</strong> (<em>float, optional</em>) – The weight decay (or L2 regularization) coefficient. Modifies objective |
| by adding a penalty for having large weights.</li> |
| <li><strong>sym</strong> (<em>Symbol, optional</em>) – The Symbol this optimizer is applying to.</li> |
| <li><strong>begin_num_update</strong> (<em>int, optional</em>) – The initial number of updates.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.register"> |
| <em class="property">static </em><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.optimizer.Optimizer.create_optimizer"> |
| <em class="property">static </em><code class="descname">create_optimizer</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_optimizer" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.SGD'&gt;</span> |
| <span class="gp">>>> </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go">&lt;class 'mxnet.optimizer.Adam'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates auxiliary state for a given weight.</p> |
| <p>Some optimizers require additional states, e.g. momentum, in addition |
| to gradients in order to update weights. This function creates state |
| for a given weight which will be used in <cite>update</cite>. This function is |
| called only once for each weight.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>index</strong> (<em>int</em>) – A unique index to identify the weight.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The weight.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>state</strong> – |
| The state associated with the weight.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">any obj</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates the given parameter using the corresponding gradient and state.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>index</strong> (<em>int</em>) – The unique index of the parameter into the individual learning |
| rates and weight decays. Learning rates and weight decay |
| may be set via <cite>set_lr_mult()</cite> and <cite>set_wd_mult()</cite>, respectively.</li> |
| <li><strong>weight</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The parameter to be updated.</li> |
| <li><strong>grad</strong> (<a class="reference internal" href="ndarray.html#mxnet.ndarray.NDArray" title="mxnet.ndarray.NDArray"><em>NDArray</em></a>) – The gradient of the objective with respect to this parameter.</li> |
| <li><strong>state</strong> (<em>any obj</em>) – The state returned by <cite>create_state()</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_scale"> |
| <code class="descname">set_lr_scale</code><span class="sig-paren">(</span><em>args_lrscale</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_scale" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Sets lr scale. Use set_lr_mult instead.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_lr_mult"> |
| <code class="descname">set_lr_mult</code><span class="sig-paren">(</span><em>args_lr_mult</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_lr_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual learning rate multiplier for each parameter.</p> |
| <p>If you specify a learning rate multiplier for a parameter, then |
| the learning rate for the parameter will be set as the product of |
| the global learning rate <cite>self.lr</cite> and its multiplier.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default learning rate multiplier of a <cite>Variable</cite> |
| can be set with <cite>lr_mult</cite> argument in the constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_lr_mult</strong> (<em>dict of str/int to float</em>) – <p>For each of its key-value entries, the learning rate multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_lr_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Optimizer.set_wd_mult"> |
| <code class="descname">set_wd_mult</code><span class="sig-paren">(</span><em>args_wd_mult</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Optimizer.set_wd_mult" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets an individual weight decay multiplier for each parameter.</p> |
| <p>By default, if <cite>param_idx2name</cite> was provided in the |
| constructor, the weight decay multiplier is set as 0 for all |
| parameters whose names don’t end with <code class="docutils literal"><span class="pre">_weight</span></code> or |
| <code class="docutils literal"><span class="pre">_gamma</span></code>.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">The default weight decay multiplier for a <cite>Variable</cite> |
| can be set with its <cite>wd_mult</cite> argument in the constructor.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>args_wd_mult</strong> (<em>dict of string/int to float</em>) – <p>For each of its key-value entries, the weight decay multiplier for the |
| parameter specified in the key will be set as the given value.</p> |
| <p>You can specify the parameter with either its name or its index. |
| If you use the name, you should pass <cite>sym</cite> in the constructor, |
| and the name you specified in the key of <cite>args_wd_mult</cite> should match |
| the name of the parameter in <cite>sym</cite>. If you use the index, it should |
| correspond to the index of the parameter used in the <cite>update</cite> method.</p> |
| <p>Specifying a parameter by its index is only supported for backward |
| compatibility, and we recommend using the name instead.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.register"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">register</code><span class="sig-paren">(</span><em>klass</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.register" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Registers a new optimizer.</p> |
| <p>Once an optimizer is registered, we can create an instance of this |
| optimizer with <cite>create_optimizer</cite> later.</p> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="nd">@mx.optimizer.Optimizer.register</span> |
| <span class="gp">... </span><span class="k">class</span> <span class="nc">MyOptimizer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span> |
| <span class="gp">... </span> <span class="k">pass</span> |
| <span class="gp">>>> </span><span class="n">optim</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'MyOptimizer'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">optim</span><span class="p">))</span> |
| <span class="go">&lt;class '__main__.MyOptimizer'&gt;</span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>multi_precision=False</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.SGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The SGD optimizer with momentum and weight decay.</p> |
| <p>The optimizer updates the weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">rescale_grad</span> <span class="o">*</span> <span class="n">clip</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">clip_gradient</span><span class="p">)</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="n">state</span> |
| </pre></div> |
| </div> |
| <p>For details of the update algorithm see <a class="reference internal" href="ndarray.html#mxnet.ndarray.sgd_update" title="mxnet.ndarray.sgd_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_update</span></code></a> and |
| <a class="reference internal" href="ndarray.html#mxnet.ndarray.sgd_mom_update" title="mxnet.ndarray.sgd_mom_update"><code class="xref py py-class docutils literal"><span class="pre">sgd_mom_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>multi_precision</strong> (<em>bool, optional</em>) – <p>Flag to control the internal precision of the optimizer. |
| <code class="docutils literal"><span class="pre">False</span></code> results in using the same precision as the weights (default), |
| <code class="docutils literal"><span class="pre">True</span></code> makes an internal 32-bit copy of the weights and applies gradients |
| in 32-bit precision even if the actual weights used in the model have lower precision. |
| Turning this on can improve convergence and accuracy when training with float16.</p> |
| </li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.DCASGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">DCASGD</code><span class="sig-paren">(</span><em>momentum=0.0</em>, <em>lamda=0.04</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.DCASGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The DCASGD optimizer.</p> |
| <p>This class implements the optimizer described in <em>Asynchronous Stochastic Gradient Descent |
| with Delay Compensation for Distributed Deep Learning</em>, |
| available at <a class="reference external" href="https://arxiv.org/abs/1609.08326">https://arxiv.org/abs/1609.08326</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>momentum</strong> (<em>float, optional</em>) – The momentum value.</li> |
| <li><strong>lamda</strong> (<em>float, optional</em>) – Scale DC value.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.NAG"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">NAG</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.NAG" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Nesterov accelerated SGD.</p> |
| <p>This optimizer updates each weight by:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">state</span> <span class="o">=</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span> <span class="o">+</span> <span class="n">grad</span> <span class="o">+</span> <span class="n">wd</span> <span class="o">*</span> <span class="n">weight</span> |
| <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">-</span> <span class="p">(</span><span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="n">grad</span> <span class="o">+</span> <span class="n">momentum</span> <span class="o">*</span> <span class="n">state</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>This optimizer accepts the same arguments as <a class="reference internal" href="optimization.html#mxnet.optimizer.SGD" title="mxnet.optimizer.SGD"><code class="xref py py-class docutils literal"><span class="pre">SGD</span></code></a>.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.SGLD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">SGLD</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.SGLD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Stochastic Gradient Riemannian Langevin Dynamics.</p> |
| <p>This class implements the optimizer described in the paper <em>Stochastic Gradient |
| Riemannian Langevin Dynamics on the Probability Simplex</em>, available at |
| <a class="reference external" href="https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf">https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf</a>.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.ccSGD"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">ccSGD</code><span class="sig-paren">(</span><em>*args</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.ccSGD" title="Permalink to this definition">¶</a></dt> |
| <dd><p>[DEPRECATED] Same as <cite>SGD</cite>. Left here for backward compatibility.</p> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Adam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Adam optimizer.</p> |
| <p>This class implements the optimizer described in <em>Adam: A Method for |
| Stochastic Optimization</em>, available at <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <p>For details of the update algorithm, see <code class="xref py py-class docutils literal"><span class="pre">ndarray.adam_update</span></code>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaGrad"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaGrad</code><span class="sig-paren">(</span><em>eps=1e-07</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.AdaGrad" title="Permalink to this definition">¶</a></dt> |
| <dd><p>AdaGrad optimizer.</p> |
| <p>This class implements the AdaGrad optimizer described in <em>Adaptive Subgradient |
| Methods for Online Learning and Stochastic Optimization</em>, and available at |
| <a class="reference external" href="http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf">http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>eps</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.RMSProp"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">RMSProp</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>gamma1=0.9</em>, <em>gamma2=0.9</em>, <em>epsilon=1e-08</em>, <em>centered=False</em>, <em>clip_weights=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.RMSProp" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The RMSProp optimizer.</p> |
| <p>Two versions of RMSProp are implemented:</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=False</span></code>, we follow |
| <a class="reference external" href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf</a> by |
| Tieleman & Hinton, 2012. |
| For details of the update algorithm see <a class="reference internal" href="ndarray.html#mxnet.ndarray.rmsprop_update" title="mxnet.ndarray.rmsprop_update"><code class="xref py py-class docutils literal"><span class="pre">rmsprop_update</span></code></a>.</p> |
| <p>If <code class="docutils literal"><span class="pre">centered=True</span></code>, we follow <a class="reference external" href="http://arxiv.org/pdf/1308.0850v5.pdf">http://arxiv.org/pdf/1308.0850v5.pdf</a> (38)-(45) |
| by Alex Graves, 2013. |
| For details of the update algorithm see <a class="reference internal" href="ndarray.html#mxnet.ndarray.rmspropalex_update" title="mxnet.ndarray.rmspropalex_update"><code class="xref py py-class docutils literal"><span class="pre">rmspropalex_update</span></code></a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>gamma1</strong> (<em>float, optional</em>) – A decay factor of moving average over past squared gradient.</li> |
| <li><strong>gamma2</strong> (<em>float, optional</em>) – A “momentum” factor. Only used if <code class="docutils literal"><span class="pre">centered=True</span></code>.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>centered</strong> (<em>bool, optional</em>) – Flag to control which version of RMSProp to use. |
| <code class="docutils literal"><span class="pre">True</span></code> will use Graves’s version of <cite>RMSProp</cite>, |
| <code class="docutils literal"><span class="pre">False</span></code> will use Tieleman & Hinton’s version of <cite>RMSProp</cite>.</li> |
| <li><strong>clip_weights</strong> (<em>float, optional</em>) – Clips weights into range <code class="docutils literal"><span class="pre">[-clip_weights,</span> <span class="pre">clip_weights]</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.AdaDelta"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">AdaDelta</code><span class="sig-paren">(</span><em>rho=0.9</em>, <em>epsilon=1e-05</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.AdaDelta" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaDelta optimizer.</p> |
| <p>This class implements AdaDelta, an optimizer described in <em>ADADELTA: An adaptive |
| learning rate method</em>, available at <a class="reference external" href="https://arxiv.org/abs/1212.5701">https://arxiv.org/abs/1212.5701</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>rho</strong> (<em>float</em>) – Decay rate for both squared gradients and delta.</li> |
| <li><strong>epsilon</strong> (<em>float</em>) – Small value to avoid division by 0.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Ftrl"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Ftrl</code><span class="sig-paren">(</span><em>lamda1=0.01</em>, <em>learning_rate=0.1</em>, <em>beta=1</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Ftrl" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Ftrl optimizer.</p> |
| <p>Referenced from <em>Ad Click Prediction: a View from the Trenches</em>, available at |
| <a class="reference external" href="http://dl.acm.org/citation.cfm?id=2488200">http://dl.acm.org/citation.cfm?id=2488200</a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>lamda1</strong> (<em>float, optional</em>) – L1 regularization coefficient.</li> |
| <li><strong>learning_rate</strong> (<em>float, optional</em>) – The initial learning rate.</li> |
| <li><strong>beta</strong> (<em>float, optional</em>) – Per-coordinate learning rate correlation parameter.</li> |
| <li><strong>eta</strong> – <div class="math"> |
| \[\eta_{t,i} = \frac{learningrate}{\beta+\sqrt{\sum_{s=1}^t g_{s,i}^2}}\] |
| </div> |
| </li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Adamax"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Adamax</code><span class="sig-paren">(</span><em>learning_rate=0.002</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Adamax" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The AdaMax optimizer.</p> |
| <p>It is a variant of Adam based on the infinity norm, |
| described in Section 7 of <a class="reference external" href="http://arxiv.org/abs/1412.6980">http://arxiv.org/abs/1412.6980</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Nadam"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Nadam</code><span class="sig-paren">(</span><em>learning_rate=0.001</em>, <em>beta1=0.9</em>, <em>beta2=0.999</em>, <em>epsilon=1e-08</em>, <em>schedule_decay=0.004</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Nadam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Nesterov Adam optimizer.</p> |
| <p>Much like Adam is essentially RMSprop with momentum, |
| Nadam is essentially Adam with Nesterov momentum. It is described |
| at <a class="reference external" href="http://cs229.stanford.edu/proj2015/054_report.pdf">http://cs229.stanford.edu/proj2015/054_report.pdf</a>.</p> |
| <p>This optimizer accepts the following parameters in addition to those accepted |
| by <a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><code class="xref py py-class docutils literal"><span class="pre">Optimizer</span></code></a>.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>beta1</strong> (<em>float, optional</em>) – Exponential decay rate for the first moment estimates.</li> |
| <li><strong>beta2</strong> (<em>float, optional</em>) – Exponential decay rate for the second moment estimates.</li> |
| <li><strong>epsilon</strong> (<em>float, optional</em>) – Small value to avoid division by 0.</li> |
| <li><strong>schedule_decay</strong> (<em>float, optional</em>) – Exponential decay rate for the momentum schedule.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Test"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Test</code><span class="sig-paren">(</span><em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test" title="Permalink to this definition">¶</a></dt> |
| <dd><p>The Test optimizer.</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.create_state"> |
| <code class="descname">create_state</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test.create_state" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Creates a state to duplicate weight.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Test.update"> |
| <code class="descname">update</code><span class="sig-paren">(</span><em>index</em>, <em>weight</em>, <em>grad</em>, <em>state</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Test.update" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Performs w += rescale_grad * grad.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.create"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">create</code><span class="sig-paren">(</span><em>name</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Instantiates an optimizer with a given name and kwargs.</p> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p class="last">We can use the alias <cite>create</cite> for <code class="docutils literal"><span class="pre">Optimizer.create_optimizer</span></code>.</p> |
| </div> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>name</strong> (<em>str</em>) – Name of the optimizer. Should be the name |
| of a subclass of Optimizer. Case insensitive.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Parameters for the optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">An instantiated optimizer.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer">Optimizer</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Examples</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">sgd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">Optimizer</span><span class="o">.</span><span class="n">create_optimizer</span><span class="p">(</span><span class="s1">'sgd'</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">sgd</span><span class="p">)</span> |
| <span class="go"><class 'mxnet.optimizer.SGD'></span> |
| <span class="gp">>>> </span><span class="n">adam</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mi">1</span><span class="p">)</span> |
| <span class="gp">>>> </span><span class="nb">type</span><span class="p">(</span><span class="n">adam</span><span class="p">)</span> |
| <span class="go"><class 'mxnet.optimizer.Adam'></span> |
| </pre></div> |
| </div> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.optimizer.Updater"> |
| <em class="property">class </em><code class="descclassname">mxnet.optimizer.</code><code class="descname">Updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updater for kvstore.</p> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.__call__"> |
| <code class="descname">__call__</code><span class="sig-paren">(</span><em>index</em>, <em>grad</em>, <em>weight</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.__call__" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Updates weight given gradient and index.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.set_states"> |
| <code class="descname">set_states</code><span class="sig-paren">(</span><em>states</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.set_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Sets updater states.</p> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.optimizer.Updater.get_states"> |
| <code class="descname">get_states</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.Updater.get_states" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Gets updater states.</p> |
| </dd></dl> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.optimizer.get_updater"> |
| <code class="descclassname">mxnet.optimizer.</code><code class="descname">get_updater</code><span class="sig-paren">(</span><em>optimizer</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.optimizer.get_updater" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Returns a closure of the updater needed for kvstore.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>optimizer</strong> (<a class="reference internal" href="optimization.html#mxnet.optimizer.Optimizer" title="mxnet.optimizer.Optimizer"><em>Optimizer</em></a>) – The optimizer.</td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><strong>updater</strong> – |
| The closure of the updater.</td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body">function</td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <script>auto_index("optimizer-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="model-api-reference"> |
| <span id="model-api-reference"></span><h2>Model API Reference<a class="headerlink" href="#model-api-reference" title="Permalink to this headline">¶</a></h2> |
| <blockquote> |
| <div><span class="target" id="module-mxnet.model"></span><p>MXNet model module</p> |
| <dl class="attribute"> |
| <dt id="mxnet.model.BatchEndParam"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">BatchEndParam</code><a class="headerlink" href="#mxnet.model.BatchEndParam" title="Permalink to this definition">¶</a></dt> |
| <dd><p>alias of <code class="xref py py-class docutils literal"><span class="pre">BatchEndParams</span></code></p> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.model.save_checkpoint"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">save_checkpoint</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em>, <em>symbol</em>, <em>arg_params</em>, <em>aux_params</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.save_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Checkpoint the model data into file.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – The epoch number of the model.</li> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The input Symbol.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li> |
| <li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li> |
| </ul> |
| </dd></dl> |
| <dl class="function"> |
| <dt id="mxnet.model.load_checkpoint"> |
| <code class="descclassname">mxnet.model.</code><code class="descname">load_checkpoint</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.load_checkpoint" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Load model checkpoint from file.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first last"><ul class="simple"> |
| <li><strong>symbol</strong> (<em>Symbol</em>) – |
| The symbol configuration of computation network.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray</em>) – |
| Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray</em>) – |
| Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| </ul> |
| </p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li>Symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li> |
| <li>Parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li> |
| </ul> |
| </dd></dl> |
| <dl class="class"> |
| <dt id="mxnet.model.FeedForward"> |
| <em class="property">class </em><code class="descclassname">mxnet.model.</code><code class="descname">FeedForward</code><span class="sig-paren">(</span><em>symbol</em>, <em>ctx=None</em>, <em>num_epoch=None</em>, <em>epoch_size=None</em>, <em>optimizer='sgd'</em>, <em>initializer=<mxnet.initializer.Uniform object></em>, <em>numpy_batch_size=128</em>, <em>arg_params=None</em>, <em>aux_params=None</em>, <em>allow_extra_params=False</em>, <em>begin_epoch=0</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Model class of MXNet for training and predicting feedforward nets. |
| This class is designed for a single-data single output supervised network.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The symbol configuration of computation network.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction. |
| To use multi GPU training, pass in a list of gpu contexts.</li> |
| <li><strong>num_epoch</strong> (<em>int, optional</em>) – Training parameter, number of training epochs.</li> |
| <li><strong>epoch_size</strong> (<em>int, optional</em>) – Number of batches in an epoch. By default, it is set to |
| <code class="docutils literal"><span class="pre">ceil(num_train_examples</span> <span class="pre">/</span> <span class="pre">batch_size)</span></code>.</li> |
| <li><strong>optimizer</strong> (<em>str or Optimizer, optional</em>) – Training parameter, name or optimizer object for training.</li> |
| <li><strong>initializer</strong> (<em>initializer function, optional</em>) – Training parameter, the initialization scheme used.</li> |
| <li><strong>numpy_batch_size</strong> (<em>int, optional</em>) – The batch size of training data. |
| Only needed when input array is numpy.</li> |
| <li><strong>arg_params</strong> (<em>dict of str to NDArray, optional</em>) – Model parameter, dict of name to NDArray of net’s weights.</li> |
| <li><strong>aux_params</strong> (<em>dict of str to NDArray, optional</em>) – Model parameter, dict of name to NDArray of net’s auxiliary states.</li> |
| <li><strong>allow_extra_params</strong> (<em>boolean, optional</em>) – Whether to allow extra parameters that are not needed by symbol |
| to be passed by <code class="docutils literal"><span class="pre">aux_params</span></code> and <code class="docutils literal"><span class="pre">arg_params</span></code>. |
| If this is True, no error will be thrown when <code class="docutils literal"><span class="pre">aux_params</span></code> and <code class="docutils literal"><span class="pre">arg_params</span></code> |
| contain more parameters than needed.</li> |
| <li><strong>begin_epoch</strong> (<em>int, optional</em>) – The beginning training epoch.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – The additional keyword arguments passed to optimizer.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.predict"> |
| <code class="descname">predict</code><span class="sig-paren">(</span><em>X</em>, <em>num_batch=None</em>, <em>return_data=False</em>, <em>reset=True</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.predict" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Run the prediction. Always uses only one device.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>X</strong> (<em>mxnet.DataIter</em>) – </li> |
| <li><strong>num_batch</strong> (<em>int or None</em>) – The number of batches to run. Go through all batches if <code class="docutils literal"><span class="pre">None</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>y</strong> – |
| The predicted value of the output.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.score"> |
| <code class="descname">score</code><span class="sig-paren">(</span><em>X</em>, <em>eval_metric='acc'</em>, <em>num_batch=None</em>, <em>batch_end_callback=None</em>, <em>reset=True</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.score" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Run the model given an input and calculate the score |
| as assessed by an evaluation metric.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>X</strong> (<em>mxnet.DataIter</em>) – </li> |
| <li><strong>eval_metric</strong> (<em>metric.metric</em>) – The metric for calculating score.</li> |
| <li><strong>num_batch</strong> (<em>int or None</em>) – The number of batches to run. Go through all batches if <code class="docutils literal"><span class="pre">None</span></code>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>s</strong> – |
| The final score.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">float</p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.fit"> |
| <code class="descname">fit</code><span class="sig-paren">(</span><em>X</em>, <em>y=None</em>, <em>eval_data=None</em>, <em>eval_metric='acc'</em>, <em>epoch_end_callback=None</em>, <em>batch_end_callback=None</em>, <em>kvstore='local'</em>, <em>logger=None</em>, <em>work_load_list=None</em>, <em>monitor=None</em>, <em>eval_end_callback=<mxnet.callback.LogValidationMetricsCallback object></em>, <em>eval_batch_end_callback=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.fit" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Fit the model.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>X</strong> (<em>DataIter, or numpy.ndarray/NDArray</em>) – Training data. If <cite>X</cite> is a <cite>DataIter</cite>, the name or (if name not available) |
| the position of its outputs should match the corresponding variable |
| names defined in the symbolic graph.</li> |
| <li><strong>y</strong> (<em>numpy.ndarray/NDArray, optional</em>) – Training set label. |
| If X is <code class="docutils literal"><span class="pre">numpy.ndarray</span></code> or <cite>NDArray</cite>, <cite>y</cite> is required to be set. |
| While y can be 1D or 2D (with 2nd dimension as 1), its first dimension must be |
| the same as <cite>X</cite>, i.e. the number of data points and labels should be equal.</li> |
| <li><strong>eval_data</strong> (<em>DataIter or numpy.ndarray/list/NDArray pair</em>) – If eval_data is numpy.ndarray/list/NDArray pair, |
| it should be <code class="docutils literal"><span class="pre">(valid_data,</span> <span class="pre">valid_label)</span></code>.</li> |
| <li><strong>eval_metric</strong> (<em>metric.EvalMetric or str or callable</em>) – The evaluation metric. This could be the name of evaluation metric |
| or a custom evaluation function that returns statistics |
| based on a minibatch.</li> |
| <li><strong>epoch_end_callback</strong> (<em>callable(epoch, symbol, arg_params, aux_states)</em>) – A callback that is invoked at end of each epoch. |
| This can be used to checkpoint model each epoch.</li> |
| <li><strong>batch_end_callback</strong> (<em>callable(epoch)</em>) – A callback that is invoked at end of each batch for purposes of printing.</li> |
| <li><strong>kvstore</strong> (<em>KVStore or str, optional</em>) – The KVStore or a string kvstore type: ‘local’, ‘dist_sync’, ‘dist_async’. |
| Defaults to ‘local’; there is often no need to change it for a single machine.</li> |
| <li><strong>logger</strong> (<em>logging logger, optional</em>) – When not specified, default logger will be used.</li> |
| <li><strong>work_load_list</strong> (<em>list of float or int, optional</em>) – The list of work load for different devices, |
| in the same order as <cite>ctx</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <div class="admonition note"> |
| <p class="first admonition-title">Note</p> |
| <p>KVStore behavior:</p> |
| <ul class="last simple"> |
| <li>‘local’: multi-devices on a single machine, will automatically choose best type.</li> |
| <li>‘dist_sync’: multiple machines communicating via BSP.</li> |
| <li>‘dist_async’: multiple machines with asynchronous communication.</li> |
| </ul> |
| </div> |
| </dd></dl> |
| <dl class="method"> |
| <dt id="mxnet.model.FeedForward.save"> |
| <code class="descname">save</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch=None</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.save" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Save the model checkpoint to a file. |
| You can also use <cite>pickle</cite> to do the job if you only work in Python. |
| The advantage of <cite>load</cite> and <cite>save</cite> (as compared to <cite>pickle</cite>) is that |
| the resulting file can be loaded from other MXNet language bindings. |
| One can also directly <cite>load</cite>/<cite>save</cite> from/to cloud storage (S3, HDFS).</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li><code class="docutils literal"><span class="pre">prefix-symbol.json</span></code> will be saved for symbol.</li> |
| <li><code class="docutils literal"><span class="pre">prefix-epoch.params</span></code> will be saved for parameters.</li> |
| </ul> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.model.FeedForward.load"> |
| <em class="property">static </em><code class="descname">load</code><span class="sig-paren">(</span><em>prefix</em>, <em>epoch</em>, <em>ctx=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.load" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Load model checkpoint from file.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple"> |
| <li><strong>prefix</strong> (<em>str</em>) – Prefix of model name.</li> |
| <li><strong>epoch</strong> (<em>int</em>) – Epoch number of model we would like to load.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction.</li> |
| <li><strong>kwargs</strong> (<em>dict</em>) – Other parameters for model, including <cite>num_epoch</cite>, optimizer and <cite>numpy_batch_size</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| <tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first"><strong>model</strong> – |
| The loaded model that can be used for prediction.</p> |
| </td> |
| </tr> |
| <tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last"><a class="reference internal" href="#mxnet.model.FeedForward" title="mxnet.model.FeedForward">FeedForward</a></p> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| <p class="rubric">Notes</p> |
| <ul class="simple"> |
| <li>Symbol will be loaded from <code class="docutils literal"><span class="pre">prefix-symbol.json</span></code>.</li> |
| <li>Parameters will be loaded from <code class="docutils literal"><span class="pre">prefix-epoch.params</span></code>.</li> |
| </ul> |
| </dd></dl> |
| <dl class="staticmethod"> |
| <dt id="mxnet.model.FeedForward.create"> |
| <em class="property">static </em><code class="descname">create</code><span class="sig-paren">(</span><em>symbol</em>, <em>X</em>, <em>y=None</em>, <em>ctx=None</em>, <em>num_epoch=None</em>, <em>epoch_size=None</em>, <em>optimizer='sgd'</em>, <em>initializer=<mxnet.initializer.Uniform object></em>, <em>eval_data=None</em>, <em>eval_metric='acc'</em>, <em>epoch_end_callback=None</em>, <em>batch_end_callback=None</em>, <em>kvstore='local'</em>, <em>logger=None</em>, <em>work_load_list=None</em>, <em>eval_end_callback=<mxnet.callback.LogValidationMetricsCallback object></em>, <em>eval_batch_end_callback=None</em>, <em>**kwargs</em><span class="sig-paren">)</span><a class="headerlink" href="#mxnet.model.FeedForward.create" title="Permalink to this definition">¶</a></dt> |
| <dd><p>Functional style to create a model. |
| This function is more consistent with functional |
| languages such as R, where mutation is not allowed.</p> |
| <table class="docutils field-list" frame="void" rules="none"> |
| <col class="field-name"/> |
| <col class="field-body"/> |
| <tbody valign="top"> |
| <tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple"> |
| <li><strong>symbol</strong> (<a class="reference internal" href="symbol.html#mxnet.symbol.Symbol" title="mxnet.symbol.Symbol"><em>Symbol</em></a>) – The symbol configuration of a computation network.</li> |
| <li><strong>X</strong> (<a class="reference internal" href="io.html#mxnet.io.DataIter" title="mxnet.io.DataIter"><em>DataIter</em></a>) – Training data.</li> |
| <li><strong>y</strong> (<em>numpy.ndarray, optional</em>) – If <cite>X</cite> is a <code class="docutils literal"><span class="pre">numpy.ndarray</span></code>, <cite>y</cite> must be set.</li> |
| <li><strong>ctx</strong> (<em>Context or list of Context, optional</em>) – The device context of training and prediction. |
| To use multi-GPU training, pass in a list of GPU contexts.</li> |
| <li><strong>num_epoch</strong> (<em>int, optional</em>) – The number of training epochs.</li> |
| <li><strong>epoch_size</strong> (<em>int, optional</em>) – Number of batches in an epoch. By default, it is set to |
| <code class="docutils literal"><span class="pre">ceil(num_train_examples</span> <span class="pre">/</span> <span class="pre">batch_size)</span></code>.</li> |
| <li><strong>optimizer</strong> (<em>str or Optimizer, optional</em>) – The name of the chosen optimizer, or an optimizer object, used for training.</li> |
| <li><strong>initializer</strong> (<em>initializer function, optional</em>) – The initialization scheme used.</li> |
| <li><strong>eval_data</strong> (<em>DataIter or numpy.ndarray pair</em>) – If <cite>eval_data</cite> is a <code class="docutils literal"><span class="pre">numpy.ndarray</span></code> pair, it should |
| be (<cite>valid_data</cite>, <cite>valid_label</cite>).</li> |
| <li><strong>eval_metric</strong> (<em>metric.EvalMetric or str or callable</em>) – The evaluation metric. Can be the name of an evaluation metric |
| or a custom evaluation function that returns statistics |
| based on a minibatch.</li> |
| <li><strong>epoch_end_callback</strong> (<em>callable(epoch, symbol, arg_params, aux_states)</em>) – A callback that is invoked at the end of each epoch. |
| This can be used to checkpoint the model each epoch.</li> |
| <li><strong>batch_end_callback</strong> (<em>callable(epoch)</em>) – A callback that is invoked at the end of each batch for printing purposes.</li> |
| <li><strong>kvstore</strong> (<em>KVStore or str, optional</em>) – The KVStore or a string kvstore type: ‘local’, ‘dist_sync’, ‘dist_async’. |
| Defaults to ‘local’; there is often no need to change it for a single machine.</li> |
| <li><strong>logger</strong> (<em>logging logger, optional</em>) – When not specified, default logger will be used.</li> |
| <li><strong>work_load_list</strong> (<em>list of float or int, optional</em>) – The list of work loads for the different devices, |
| in the same order as <cite>ctx</cite>.</li> |
| </ul> |
| </td> |
| </tr> |
| </tbody> |
| </table> |
| </dd></dl> |
| </dd></dl> |
| <script>auto_index("model-api-reference");</script></div></blockquote> |
| </div> |
| <div class="section" id="next-steps"> |
| <h2>Next Steps<a class="headerlink" href="#next-steps" title="Permalink to this headline">¶</a></h2> |
| <ul class="simple"> |
| <li>See <a class="reference internal" href="symbol.html"><em>Symbolic API</em></a> for operations on NDArrays that assemble neural networks from layers.</li> |
| <li>See <a class="reference internal" href="io.html"><em>IO Data Loading API</em></a> for parsing and loading data.</li> |
| <li>See <a class="reference internal" href="ndarray.html"><em>NDArray API</em></a> for vector/matrix/tensor operations.</li> |
| <li>See <a class="reference internal" href="kvstore.html"><em>KVStore API</em></a> for multi-GPU and multi-host distributed training.</li> |
| </ul> |
| </div> |
| </div> |
| <div class="container"> |
| <div class="footer"> |
| <p> © 2015-2017 DMLC. All rights reserved. </p> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">Model API</a><ul> |
| <li><a class="reference internal" href="#train-the-model">Train the Model</a></li> |
| <li><a class="reference internal" href="#save-the-model">Save the Model</a></li> |
| <li><a class="reference internal" href="#periodic-checkpointing">Periodic Checkpointing</a></li> |
| <li><a class="reference internal" href="#use-multiple-devices">Use Multiple Devices</a></li> |
| <li><a class="reference internal" href="#initializer-api-reference">Initializer API Reference</a></li> |
| <li><a class="reference internal" href="#evaluation-metric-api-reference">Evaluation Metric API Reference</a></li> |
| <li><a class="reference internal" href="#optimizer-api-reference">Optimizer API Reference</a></li> |
| <li><a class="reference internal" href="#model-api-reference">Model API Reference</a></li> |
| <li><a class="reference internal" href="#next-steps">Next Steps</a></li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../../_static/js/search.js" type="text/javascript"></script> |
| <script src="../../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../../_static/js/copycode.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </div></body> |
| </html> |