<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>mxnet.model — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="../index.html" rel="up" title="Module code">
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../install/index.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../gluon/index.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">Tutorials</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="../../faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.0.0/example">Examples</a></li>
<li><a class="main-nav-link" href="../../model_zoo/index.html">Model Zoo</a></li>
</ul>
</span>
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="../../community/index.html">Community</a></li>
<li><a class="main-nav-link" href="../../community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="../../community/powered_by.html">Powered By</a></li>
</ul>
</span>
<a class="main-nav-link" href="http://discuss.mxnet.io">Discuss</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(1.0.0)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/>1.0.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a class="main-nav-link" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="../../install/index.html">Install</a></li>
<li><a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="../../community/index.html" tabindex="-1">Community</a></li>
<li><a href="../../community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="../../community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="../../tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="../../faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="../../architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.0.0/example" tabindex="-1">Examples</a></li>
<li><a href="../../model_zoo/index.html" tabindex="-1">Model Zoo</a></li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(1.0.0)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/>1.0.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a tabindex="-1" href=https://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/index.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<h1>Source code for mxnet.model</h1><div class="highlight"><pre>
<span></span><span class="c1"># Licensed to the Apache Software Foundation (ASF) under one</span>
<span class="c1"># or more contributor license agreements. See the NOTICE file</span>
<span class="c1"># distributed with this work for additional information</span>
<span class="c1"># regarding copyright ownership. The ASF licenses this file</span>
<span class="c1"># to you under the Apache License, Version 2.0 (the</span>
<span class="c1"># "License"); you may not use this file except in compliance</span>
<span class="c1"># with the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing,</span>
<span class="c1"># software distributed under the License is distributed on an</span>
<span class="c1"># "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY</span>
<span class="c1"># KIND, either express or implied. See the License for the</span>
<span class="c1"># specific language governing permissions and limitations</span>
<span class="c1"># under the License.</span>
<span class="c1"># pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, too-many-lines</span>
<span class="c1"># pylint: disable=too-many-branches, too-many-statements</span>
<span class="sd">"""MXNet model module"""</span>
<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">absolute_import</span><span class="p">,</span> <span class="n">print_function</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">namedtuple</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">io</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">nd</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">symbol</span> <span class="k">as</span> <span class="n">sym</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">optimizer</span> <span class="k">as</span> <span class="n">opt</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">metric</span>
<span class="kn">from</span> <span class="nn">.</span> <span class="kn">import</span> <span class="n">kvstore</span> <span class="k">as</span> <span class="n">kvs</span>
<span class="kn">from</span> <span class="nn">.context</span> <span class="kn">import</span> <span class="n">Context</span><span class="p">,</span> <span class="n">cpu</span>
<span class="kn">from</span> <span class="nn">.initializer</span> <span class="kn">import</span> <span class="n">Uniform</span>
<span class="kn">from</span> <span class="nn">.optimizer</span> <span class="kn">import</span> <span class="n">get_updater</span>
<span class="kn">from</span> <span class="nn">.executor_manager</span> <span class="kn">import</span> <span class="n">DataParallelExecutorManager</span><span class="p">,</span> <span class="n">_check_arguments</span><span class="p">,</span> <span class="n">_load_data</span>
<span class="kn">from</span> <span class="nn">.io</span> <span class="kn">import</span> <span class="n">DataDesc</span>
<span class="kn">from</span> <span class="nn">.base</span> <span class="kn">import</span> <span class="n">mx_real_t</span>
<span class="n">BASE_ESTIMATOR</span> <span class="o">=</span> <span class="nb">object</span>
<span class="k">try</span><span class="p">:</span>
    <span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">BaseEstimator</span>
    <span class="n">BASE_ESTIMATOR</span> <span class="o">=</span> <span class="n">BaseEstimator</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
    <span class="n">SKLEARN_INSTALLED</span> <span class="o">=</span> <span class="bp">False</span>
<span class="c1"># Parameter to pass to batch_end_callback</span>
<span class="n">BatchEndParam</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span><span class="s1">'BatchEndParams'</span><span class="p">,</span>
                           <span class="p">[</span><span class="s1">'epoch'</span><span class="p">,</span>
                            <span class="s1">'nbatch'</span><span class="p">,</span>
                            <span class="s1">'eval_metric'</span><span class="p">,</span>
                            <span class="s1">'locals'</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">_create_kvstore</span><span class="p">(</span><span class="n">kvstore</span><span class="p">,</span> <span class="n">num_device</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">):</span>
    <span class="sd">"""Create kvstore</span>
<span class="sd">    This function selects and creates a proper kvstore given the kvstore type.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    kvstore : KVStore or str</span>
<span class="sd">        The kvstore.</span>
<span class="sd">    num_device : int</span>
<span class="sd">        The number of devices.</span>
<span class="sd">    arg_params : dict of str to `NDArray`</span>
<span class="sd">        Model parameters: dict of name to `NDArray` of the net's weights.</span>
<span class="sd">    """</span>
    <span class="n">update_on_kvstore</span> <span class="o">=</span> <span class="bp">True</span>
    <span class="k">if</span> <span class="n">kvstore</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">kv</span> <span class="o">=</span> <span class="bp">None</span>
    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kvstore</span><span class="p">,</span> <span class="n">kvs</span><span class="o">.</span><span class="n">KVStore</span><span class="p">):</span>
        <span class="n">kv</span> <span class="o">=</span> <span class="n">kvstore</span>
    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kvstore</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
        <span class="c1"># create kvstore using the string type</span>
        <span class="k">if</span> <span class="n">num_device</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="s1">'dist'</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">kvstore</span><span class="p">:</span>
            <span class="c1"># no need to use kv for single device and single machine</span>
            <span class="n">kv</span> <span class="o">=</span> <span class="bp">None</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">kv</span> <span class="o">=</span> <span class="n">kvs</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">kvstore</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">kvstore</span> <span class="o">==</span> <span class="s1">'local'</span><span class="p">:</span>
                <span class="c1"># for 'local', skip updating on the kvstore when the largest parameter is big</span>
                <span class="n">max_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span>
                               <span class="n">arg_params</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
                <span class="k">if</span> <span class="n">max_size</span> <span class="o">&gt;</span> <span class="mi">1024</span> <span class="o">*</span> <span class="mi">1024</span> <span class="o">*</span> <span class="mi">16</span><span class="p">:</span>
                    <span class="n">update_on_kvstore</span> <span class="o">=</span> <span class="bp">False</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">'kvstore must be KVStore, str or None'</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">kv</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">update_on_kvstore</span> <span class="o">=</span> <span class="bp">False</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">kv</span><span class="p">,</span> <span class="n">update_on_kvstore</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_initialize_kvstore</span><span class="p">(</span><span class="n">kvstore</span><span class="p">,</span> <span class="n">param_arrays</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">param_names</span><span class="p">,</span> <span class="n">update_on_kvstore</span><span class="p">):</span>
    <span class="sd">"""Initialize kvstore"""</span>
    <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">param_on_devs</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">):</span>
        <span class="n">name</span> <span class="o">=</span> <span class="n">param_names</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
        <span class="n">kvstore</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
        <span class="k">if</span> <span class="n">update_on_kvstore</span><span class="p">:</span>
            <span class="n">kvstore</span><span class="o">.</span><span class="n">pull</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">param_on_devs</span><span class="p">,</span> <span class="n">priority</span><span class="o">=-</span><span class="n">idx</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_update_params_on_kvstore_nccl</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">,</span> <span class="n">grad_arrays</span><span class="p">,</span> <span class="n">kvstore</span><span class="p">,</span> <span class="n">param_names</span><span class="p">):</span>
    <span class="sd">"""Perform update of param_arrays from grad_arrays on NCCL kvstore."""</span>
    <span class="n">valid_indices</span> <span class="o">=</span> <span class="p">[</span><span class="n">index</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">grad_list</span> <span class="ow">in</span>
                     <span class="nb">enumerate</span><span class="p">(</span><span class="n">grad_arrays</span><span class="p">)</span> <span class="k">if</span> <span class="n">grad_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">]</span>
    <span class="n">valid_grad_arrays</span> <span class="o">=</span> <span class="p">[</span><span class="n">grad_arrays</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">valid_indices</span><span class="p">]</span>
    <span class="n">valid_param_arrays</span> <span class="o">=</span> <span class="p">[</span><span class="n">param_arrays</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">valid_indices</span><span class="p">]</span>
    <span class="n">valid_param_names</span> <span class="o">=</span> <span class="p">[</span><span class="n">param_names</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">valid_indices</span><span class="p">]</span>
    <span class="n">size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">valid_grad_arrays</span><span class="p">)</span>
    <span class="n">start</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="c1"># Use aggregation by default only with NCCL</span>
    <span class="n">default_batch</span> <span class="o">=</span> <span class="mi">16</span>
    <span class="n">batch</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="s1">'MXNET_UPDATE_AGGREGATION_SIZE'</span><span class="p">,</span> <span class="n">default_batch</span><span class="p">))</span>
    <span class="k">while</span> <span class="n">start</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">:</span>
        <span class="n">end</span> <span class="o">=</span> <span class="n">start</span> <span class="o">+</span> <span class="n">batch</span> <span class="k">if</span> <span class="n">start</span> <span class="o">+</span> <span class="n">batch</span> <span class="o">&lt;</span> <span class="n">size</span> <span class="k">else</span> <span class="n">size</span>
        <span class="c1"># push gradient, priority is negative index</span>
        <span class="n">kvstore</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">valid_param_names</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">],</span> <span class="n">valid_grad_arrays</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">],</span> <span class="n">priority</span><span class="o">=-</span><span class="n">start</span><span class="p">)</span>
        <span class="c1"># pull back the weights</span>
        <span class="n">kvstore</span><span class="o">.</span><span class="n">pull</span><span class="p">(</span><span class="n">valid_param_names</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">],</span> <span class="n">valid_param_arrays</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">],</span> <span class="n">priority</span><span class="o">=-</span><span class="n">start</span><span class="p">)</span>
        <span class="n">start</span> <span class="o">=</span> <span class="n">end</span>
<span class="k">def</span> <span class="nf">_update_params_on_kvstore</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">,</span> <span class="n">grad_arrays</span><span class="p">,</span> <span class="n">kvstore</span><span class="p">,</span> <span class="n">param_names</span><span class="p">):</span>
    <span class="sd">"""Perform update of param_arrays from grad_arrays on kvstore."""</span>
    <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">,</span> <span class="n">grad_arrays</span><span class="p">)):</span>
        <span class="n">arg_list</span><span class="p">,</span> <span class="n">grad_list</span> <span class="o">=</span> <span class="n">pair</span>
        <span class="k">if</span> <span class="n">grad_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="k">continue</span>
        <span class="n">name</span> <span class="o">=</span> <span class="n">param_names</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
        <span class="c1"># push gradient, priority is negative index</span>
        <span class="n">kvstore</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">grad_list</span><span class="p">,</span> <span class="n">priority</span><span class="o">=-</span><span class="n">index</span><span class="p">)</span>
        <span class="c1"># pull back the weights</span>
        <span class="n">kvstore</span><span class="o">.</span><span class="n">pull</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">arg_list</span><span class="p">,</span> <span class="n">priority</span><span class="o">=-</span><span class="n">index</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_update_params</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">,</span> <span class="n">grad_arrays</span><span class="p">,</span> <span class="n">updater</span><span class="p">,</span> <span class="n">num_device</span><span class="p">,</span>
<span class="n">kvstore</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">param_names</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="sd">"""Perform update of param_arrays from grad_arrays not on kvstore."""</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">param_arrays</span><span class="p">,</span> <span class="n">grad_arrays</span><span class="p">)):</span>
<span class="n">arg_list</span><span class="p">,</span> <span class="n">grad_list</span> <span class="o">=</span> <span class="n">pair</span>
<span class="k">if</span> <span class="n">grad_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">index</span> <span class="o">=</span> <span class="n">i</span>
<span class="k">if</span> <span class="n">kvstore</span><span class="p">:</span>
<span class="n">name</span> <span class="o">=</span> <span class="n">param_names</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="c1"># push gradient, priority is negative index</span>
<span class="n">kvstore</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">grad_list</span><span class="p">,</span> <span class="n">priority</span><span class="o">=-</span><span class="n">index</span><span class="p">)</span>
<span class="c1"># pull back the sum gradients, to the same locations.</span>
<span class="n">kvstore</span><span class="o">.</span><span class="n">pull</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">grad_list</span><span class="p">,</span> <span class="n">priority</span><span class="o">=-</span><span class="n">index</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">arg_list</span><span class="p">,</span> <span class="n">grad_list</span><span class="p">)):</span>
<span class="c1"># fake an index here so that the optimizer creates different</span>
<span class="c1"># state for the same index on different devices, TODO(mli)</span>
<span class="c1"># use a better solution later</span>
<span class="n">w</span><span class="p">,</span> <span class="n">g</span> <span class="o">=</span> <span class="n">p</span>
<span class="n">updater</span><span class="p">(</span><span class="n">index</span><span class="o">*</span><span class="n">num_device</span><span class="o">+</span><span class="n">k</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">w</span><span class="p">)</span>
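# Illustration only, not part of mxnet.model: a minimal sketch of the "faked"
# updater index used above. The helper name fake_updater_index is hypothetical;
# it only mirrors the expression index * num_device + k to show that each
# (parameter, device) pair gets its own key, assuming k < num_device.

```python
def fake_updater_index(index, num_device, k):
    """Mirror the expression passed to the updater: index * num_device + k."""
    return index * num_device + k


num_device = 2
# One key per (parameter index, device slot) pair, for 3 parameters on 2 devices.
keys = [fake_updater_index(i, num_device, k)
        for i in range(3)
        for k in range(num_device)]

# No two pairs share a key, so optimizer state is never shared across devices.
assert len(set(keys)) == len(keys)
```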
<span class="k">def</span> <span class="nf">_multiple_callbacks</span><span class="p">(</span><span class="n">callbacks</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Sends args and kwargs to any configured callbacks.</span>
<span class="sd"> This handles the cases where the 'callbacks' variable</span>
<span class="sd"> is ``None``, a single function, or a list.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">callbacks</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="k">for</span> <span class="n">cb</span> <span class="ow">in</span> <span class="n">callbacks</span><span class="p">:</span>
<span class="n">cb</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">return</span>
<span class="k">if</span> <span class="n">callbacks</span><span class="p">:</span>
<span class="n">callbacks</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
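# Illustration only, not part of mxnet.model: a standalone copy of the dispatch
# logic above, exercised with None, a single function, and a list. The names
# multiple_callbacks and seen are hypothetical and exist only for this sketch.

```python
def multiple_callbacks(callbacks, *args, **kwargs):
    """Call every configured callback with the same arguments."""
    if isinstance(callbacks, list):
        for cb in callbacks:
            cb(*args, **kwargs)
        return
    if callbacks:
        callbacks(*args, **kwargs)


seen = []

# None: nothing is called.
multiple_callbacks(None, 1)
# Single function: called once.
multiple_callbacks(lambda epoch: seen.append(("single", epoch)), 1)
# List: each entry is called in order.
multiple_callbacks([lambda epoch: seen.append(("a", epoch)),
                    lambda epoch: seen.append(("b", epoch))], 2)
```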
<span class="k">def</span> <span class="nf">_train_multi_device</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">arg_names</span><span class="p">,</span> <span class="n">param_names</span><span class="p">,</span> <span class="n">aux_names</span><span class="p">,</span>
<span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span>
<span class="n">begin_epoch</span><span class="p">,</span> <span class="n">end_epoch</span><span class="p">,</span> <span class="n">epoch_size</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span>
<span class="n">kvstore</span><span class="p">,</span> <span class="n">update_on_kvstore</span><span class="p">,</span>
<span class="n">train_data</span><span class="p">,</span> <span class="n">eval_data</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">eval_metric</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">epoch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">logger</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">work_load_list</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">monitor</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">eval_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">eval_batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">sym_gen</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="sd">"""Internal training function on multiple devices.</span>
<span class="sd"> This function also works for a single device.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> symbol : Symbol</span>
<span class="sd"> The network configuration.</span>
<span class="sd"> ctx : list of Context</span>
<span class="sd"> The training devices.</span>
<span class="sd"> arg_names: list of str</span>
<span class="sd"> Name of all arguments of the network.</span>
<span class="sd"> param_names: list of str</span>
<span class="sd"> Name of all trainable parameters of the network.</span>
<span class="sd"> aux_names: list of str</span>
<span class="sd"> Name of all auxiliary states of the network.</span>
<span class="sd"> arg_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's weights.</span>
<span class="sd"> aux_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's auxiliary states.</span>
<span class="sd"> begin_epoch : int</span>
<span class="sd"> The beginning training epoch.</span>
<span class="sd"> end_epoch : int</span>
<span class="sd"> The end training epoch.</span>
<span class="sd"> epoch_size : int, optional</span>
<span class="sd"> Number of batches in an epoch. By default, it is set to</span>
<span class="sd"> ``ceil(num_train_examples / batch_size)``.</span>
<span class="sd"> optimizer : Optimizer</span>
<span class="sd"> The optimization algorithm.</span>
<span class="sd"> train_data : DataIter</span>
<span class="sd"> Training data iterator.</span>
<span class="sd"> eval_data : DataIter</span>
<span class="sd"> Validation data iterator.</span>
<span class="sd"> eval_metric : EvalMetric</span>
<span class="sd"> An evaluation function or a list of evaluation functions.</span>
<span class="sd"> epoch_end_callback : callable(epoch, symbol, arg_params, aux_params)</span>
<span class="sd"> A callback that is invoked at the end of each epoch.</span>
<span class="sd"> This can be used to checkpoint the model each epoch.</span>
<span class="sd"> batch_end_callback : callable(BatchEndParam)</span>
<span class="sd"> A callback that is invoked at the end of each batch.</span>
<span class="sd"> This can be used to measure speed, get results from the evaluation metric, etc.</span>
<span class="sd"> kvstore : KVStore</span>
<span class="sd"> The KVStore.</span>
<span class="sd"> update_on_kvstore : bool</span>
<span class="sd"> Whether to perform weight updates on the kvstore.</span>
<span class="sd"> logger : logging logger</span>
<span class="sd"> When not specified, default logger will be used.</span>
<span class="sd"> work_load_list : list of float or int, optional</span>
<span class="sd"> The list of work load for different devices,</span>
<span class="sd"> in the same order as ``ctx``.</span>
<span class="sd"> monitor : Monitor, optional</span>
<span class="sd"> Monitor installed to executor,</span>
<span class="sd"> for monitoring outputs, weights, and gradients for debugging.</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> - This function updates the NDArrays in `arg_params` and `aux_params` in place.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">logger</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span>
<span class="n">executor_manager</span> <span class="o">=</span> <span class="n">DataParallelExecutorManager</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">symbol</span><span class="p">,</span>
<span class="n">sym_gen</span><span class="o">=</span><span class="n">sym_gen</span><span class="p">,</span>
<span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="n">train_data</span><span class="o">=</span><span class="n">train_data</span><span class="p">,</span>
<span class="n">param_names</span><span class="o">=</span><span class="n">param_names</span><span class="p">,</span>
<span class="n">arg_names</span><span class="o">=</span><span class="n">arg_names</span><span class="p">,</span>
<span class="n">aux_names</span><span class="o">=</span><span class="n">aux_names</span><span class="p">,</span>
<span class="n">work_load_list</span><span class="o">=</span><span class="n">work_load_list</span><span class="p">,</span>
<span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">)</span>
<span class="k">if</span> <span class="n">monitor</span><span class="p">:</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">install_monitor</span><span class="p">(</span><span class="n">monitor</span><span class="p">)</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">update_on_kvstore</span><span class="p">:</span>
<span class="n">updater</span> <span class="o">=</span> <span class="n">get_updater</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">kvstore</span><span class="p">:</span>
<span class="n">_initialize_kvstore</span><span class="p">(</span><span class="n">kvstore</span><span class="o">=</span><span class="n">kvstore</span><span class="p">,</span>
<span class="n">param_arrays</span><span class="o">=</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_arrays</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="n">arg_params</span><span class="p">,</span>
<span class="n">param_names</span><span class="o">=</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_names</span><span class="p">,</span>
<span class="n">update_on_kvstore</span><span class="o">=</span><span class="n">update_on_kvstore</span><span class="p">)</span>
<span class="k">if</span> <span class="n">update_on_kvstore</span><span class="p">:</span>
<span class="n">kvstore</span><span class="o">.</span><span class="n">set_optimizer</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
<span class="c1"># Now start training</span>
<span class="n">train_data</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">begin_epoch</span><span class="p">,</span> <span class="n">end_epoch</span><span class="p">):</span>
<span class="c1"># Training phase</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">eval_metric</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">nbatch</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Iterate over training data.</span>
<span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
<span class="n">do_reset</span> <span class="o">=</span> <span class="bp">True</span>
<span class="k">for</span> <span class="n">data_batch</span> <span class="ow">in</span> <span class="n">train_data</span><span class="p">:</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">load_data_batch</span><span class="p">(</span><span class="n">data_batch</span><span class="p">)</span>
<span class="k">if</span> <span class="n">monitor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">monitor</span><span class="o">.</span><span class="n">tic</span><span class="p">()</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="k">if</span> <span class="n">update_on_kvstore</span><span class="p">:</span>
<span class="k">if</span> <span class="s1">'nccl'</span> <span class="ow">in</span> <span class="n">kvstore</span><span class="o">.</span><span class="n">type</span><span class="p">:</span>
<span class="n">_update_params_on_kvstore_nccl</span><span class="p">(</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_arrays</span><span class="p">,</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">,</span>
<span class="n">kvstore</span><span class="p">,</span> <span class="n">executor_manager</span><span class="o">.</span><span class="n">param_names</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">_update_params_on_kvstore</span><span class="p">(</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_arrays</span><span class="p">,</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">,</span>
<span class="n">kvstore</span><span class="p">,</span> <span class="n">executor_manager</span><span class="o">.</span><span class="n">param_names</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">_update_params</span><span class="p">(</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_arrays</span><span class="p">,</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">,</span>
<span class="n">updater</span><span class="o">=</span><span class="n">updater</span><span class="p">,</span>
<span class="n">num_device</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">ctx</span><span class="p">),</span>
<span class="n">kvstore</span><span class="o">=</span><span class="n">kvstore</span><span class="p">,</span>
<span class="n">param_names</span><span class="o">=</span><span class="n">executor_manager</span><span class="o">.</span><span class="n">param_names</span><span class="p">)</span>
<span class="k">if</span> <span class="n">monitor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">monitor</span><span class="o">.</span><span class="n">toc_print</span><span class="p">()</span>
<span class="c1"># evaluate at end, so we can lazy copy</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">update_metric</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">,</span> <span class="n">data_batch</span><span class="o">.</span><span class="n">label</span><span class="p">)</span>
<span class="n">nbatch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="c1"># batch callback (for print purpose)</span>
<span class="k">if</span> <span class="n">batch_end_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">batch_end_params</span> <span class="o">=</span> <span class="n">BatchEndParam</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
<span class="n">nbatch</span><span class="o">=</span><span class="n">nbatch</span><span class="p">,</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="nb">locals</span><span class="o">=</span><span class="nb">locals</span><span class="p">())</span>
<span class="n">_multiple_callbacks</span><span class="p">(</span><span class="n">batch_end_callback</span><span class="p">,</span> <span class="n">batch_end_params</span><span class="p">)</span>
<span class="c1"># this epoch is done, possibly before the iterator is exhausted</span>
<span class="k">if</span> <span class="n">epoch_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">and</span> <span class="n">nbatch</span> <span class="o">>=</span> <span class="n">epoch_size</span><span class="p">:</span>
<span class="n">do_reset</span> <span class="o">=</span> <span class="bp">False</span>
<span class="k">break</span>
<span class="k">if</span> <span class="n">do_reset</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'Epoch[</span><span class="si">%d</span><span class="s1">] Resetting Data Iterator'</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">train_data</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="c1"># this epoch is done</span>
<span class="k">if</span> <span class="n">epoch_size</span> <span class="ow">is</span> <span class="bp">None</span> <span class="ow">or</span> <span class="n">nbatch</span> <span class="o">>=</span> <span class="n">epoch_size</span><span class="p">:</span>
<span class="k">break</span>
<span class="n">toc</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'Epoch[</span><span class="si">%d</span><span class="s1">] Time cost=</span><span class="si">%.3f</span><span class="s1">'</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="p">(</span><span class="n">toc</span> <span class="o">-</span> <span class="n">tic</span><span class="p">))</span>
<span class="k">if</span> <span class="n">epoch_end_callback</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">==</span> <span class="n">end_epoch</span><span class="p">:</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">copy_to</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
<span class="n">_multiple_callbacks</span><span class="p">(</span><span class="n">epoch_end_callback</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
<span class="c1"># evaluation</span>
<span class="k">if</span> <span class="n">eval_data</span><span class="p">:</span>
<span class="n">eval_metric</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">eval_data</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">total_num_batch</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">eval_batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">eval_data</span><span class="p">):</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">load_data_batch</span><span class="p">(</span><span class="n">eval_batch</span><span class="p">)</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">executor_manager</span><span class="o">.</span><span class="n">update_metric</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">,</span> <span class="n">eval_batch</span><span class="o">.</span><span class="n">label</span><span class="p">)</span>
<span class="k">if</span> <span class="n">eval_batch_end_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">batch_end_params</span> <span class="o">=</span> <span class="n">BatchEndParam</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
<span class="n">nbatch</span><span class="o">=</span><span class="n">i</span><span class="p">,</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="nb">locals</span><span class="o">=</span><span class="nb">locals</span><span class="p">())</span>
<span class="n">_multiple_callbacks</span><span class="p">(</span><span class="n">eval_batch_end_callback</span><span class="p">,</span> <span class="n">batch_end_params</span><span class="p">)</span>
<span class="n">total_num_batch</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">eval_end_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">eval_end_params</span> <span class="o">=</span> <span class="n">BatchEndParam</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
<span class="n">nbatch</span><span class="o">=</span><span class="n">total_num_batch</span><span class="p">,</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="nb">locals</span><span class="o">=</span><span class="nb">locals</span><span class="p">())</span>
<span class="n">_multiple_callbacks</span><span class="p">(</span><span class="n">eval_end_callback</span><span class="p">,</span> <span class="n">eval_end_params</span><span class="p">)</span>
<span class="n">eval_data</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="c1"># end of all epochs</span>
<span class="k">return</span>
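# Illustration only, not part of mxnet.model: a minimal sketch of how the
# training loop above treats epoch_size. ListIter and run_epoch are
# hypothetical stand-ins for a DataIter and for one pass of the epoch loop;
# only the reset/break logic is copied from the function above.

```python
class ListIter:
    """Tiny resettable iterator standing in for a DataIter."""

    def __init__(self, items):
        self.items = list(items)
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.items):
            raise StopIteration
        item = self.items[self.pos]
        self.pos += 1
        return item

    def reset(self):
        self.pos = 0


def run_epoch(data, epoch_size=None):
    """Return the batches consumed in one epoch."""
    seen = []
    while True:
        do_reset = True
        for batch in data:
            seen.append(batch)
            # Stop early once epoch_size batches were consumed; keep position.
            if epoch_size is not None and len(seen) >= epoch_size:
                do_reset = False
                break
        if do_reset:
            data.reset()
        if epoch_size is None or len(seen) >= epoch_size:
            break
    return seen


data = ListIter([0, 1, 2])
first = run_epoch(data, epoch_size=2)
# The second epoch resumes at the unread batch, then wraps around after a reset.
second = run_epoch(data, epoch_size=2)
full = run_epoch(ListIter([0, 1, 2]))
```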
<div class="viewcode-block" id="save_checkpoint"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.save_checkpoint">[docs]</a><span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">):</span>
<span class="sd">"""Checkpoint the model data to a file.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> prefix : str</span>
<span class="sd"> Prefix of model name.</span>
<span class="sd"> epoch : int</span>
<span class="sd"> The epoch number of the model.</span>
<span class="sd"> symbol : Symbol</span>
<span class="sd"> The input Symbol.</span>
<span class="sd"> arg_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's weights.</span>
<span class="sd"> aux_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's auxiliary states.</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> - ``prefix-symbol.json`` will be saved for symbol.</span>
<span class="sd"> - ``prefix-epoch.params`` will be saved for parameters.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">symbol</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">symbol</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'</span><span class="si">%s</span><span class="s1">-symbol.json'</span> <span class="o">%</span> <span class="n">prefix</span><span class="p">)</span>
<span class="n">save_dict</span> <span class="o">=</span> <span class="p">{(</span><span class="s1">'arg:</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">k</span><span class="p">)</span> <span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">cpu</span><span class="p">())</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">arg_params</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">save_dict</span><span class="o">.</span><span class="n">update</span><span class="p">({(</span><span class="s1">'aux:</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span> <span class="n">k</span><span class="p">)</span> <span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">cpu</span><span class="p">())</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">aux_params</span><span class="o">.</span><span class="n">items</span><span class="p">()})</span>
<span class="n">param_name</span> <span class="o">=</span> <span class="s1">'</span><span class="si">%s</span><span class="s1">-</span><span class="si">%04d</span><span class="s1">.params'</span> <span class="o">%</span> <span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">nd</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">param_name</span><span class="p">,</span> <span class="n">save_dict</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'Saved checkpoint to </span><span class="se">\"</span><span class="si">%s</span><span class="se">\"</span><span class="s1">'</span><span class="p">,</span> <span class="n">param_name</span><span class="p">)</span></div>
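# Illustration only, not part of mxnet.model: a minimal sketch of the file names
# and key prefixes that save_checkpoint produces, using plain Python values in
# place of NDArrays. checkpoint_paths and tag_params are hypothetical helpers;
# the format strings are the ones used in the function above.

```python
def checkpoint_paths(prefix, epoch):
    """Return (symbol file, parameter file) names for a checkpoint."""
    return '%s-symbol.json' % prefix, '%s-%04d.params' % (prefix, epoch)


def tag_params(arg_params, aux_params):
    """Merge both dicts into one, prefixing keys with 'arg:' or 'aux:'."""
    save_dict = {('arg:%s' % k): v for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v for k, v in aux_params.items()})
    return save_dict


paths = checkpoint_paths('model', 3)
tagged = tag_params({'fc1_weight': 1}, {'bn_moving_mean': 2})
```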
<div class="viewcode-block" id="load_checkpoint"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.load_checkpoint">[docs]</a><span class="k">def</span> <span class="nf">load_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
<span class="sd">"""Load model checkpoint from file.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> prefix : str</span>
<span class="sd"> Prefix of model name.</span>
<span class="sd"> epoch : int</span>
<span class="sd"> Epoch number of model we would like to load.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> symbol : Symbol</span>
<span class="sd"> The symbol configuration of computation network.</span>
<span class="sd"> arg_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's weights.</span>
<span class="sd"> aux_params : dict of str to NDArray</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's auxiliary states.</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> - Symbol will be loaded from ``prefix-symbol.json``.</span>
<span class="sd"> - Parameters will be loaded from ``prefix-epoch.params``.</span>
<span class="sd"> """</span>
<span class="n">symbol</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">'</span><span class="si">%s</span><span class="s1">-symbol.json'</span> <span class="o">%</span> <span class="n">prefix</span><span class="p">)</span>
<span class="n">save_dict</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">'</span><span class="si">%s</span><span class="s1">-</span><span class="si">%04d</span><span class="s1">.params'</span> <span class="o">%</span> <span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">))</span>
<span class="n">arg_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">aux_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">save_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">tp</span><span class="p">,</span> <span class="n">name</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">':'</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tp</span> <span class="o">==</span> <span class="s1">'arg'</span><span class="p">:</span>
<span class="n">arg_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="k">if</span> <span class="n">tp</span> <span class="o">==</span> <span class="s1">'aux'</span><span class="p">:</span>
<span class="n">aux_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="k">return</span> <span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span></div>
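# Illustration only, not part of mxnet.model: a minimal sketch of how
# load_checkpoint separates the saved dictionary back into arg_params and
# aux_params. split_saved_params is a hypothetical helper operating on plain
# Python values; the split on the first ':' matches the loop above.

```python
def split_saved_params(save_dict):
    """Split 'arg:name' / 'aux:name' keys into two dictionaries."""
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # maxsplit=1 keeps any further ':' characters inside the name.
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return arg_params, aux_params


args, auxs = split_saved_params({
    'arg:fc1_weight': 1,
    'arg:scope:bias': 2,
    'aux:bn_moving_mean': 3,
})
```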
<span class="kn">from</span> <span class="nn">.callback</span> <span class="kn">import</span> <span class="n">LogValidationMetricsCallback</span> <span class="c1"># pylint: disable=wrong-import-position</span>
<div class="viewcode-block" id="FeedForward"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward">[docs]</a><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">BASE_ESTIMATOR</span><span class="p">):</span>
<span class="sd">"""Model class of MXNet for training and predicting feedforward nets.</span>
<span class="sd"> This class is designed for a single-data, single-output supervised network.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> symbol : Symbol</span>
<span class="sd"> The symbol configuration of computation network.</span>
<span class="sd"> ctx : Context or list of Context, optional</span>
<span class="sd"> The device context of training and prediction.</span>
<span class="sd"> To use multi-GPU training, pass in a list of GPU contexts.</span>
<span class="sd"> num_epoch : int, optional</span>
<span class="sd"> Training parameter, number of training epochs.</span>
<span class="sd"> epoch_size : int, optional</span>
<span class="sd"> Number of batches in an epoch. By default, it is set to</span>
<span class="sd"> ``ceil(num_train_examples / batch_size)``.</span>
<span class="sd"> optimizer : str or Optimizer, optional</span>
<span class="sd"> Training parameter, name or optimizer object for training.</span>
<span class="sd"> initializer : initializer function, optional</span>
<span class="sd"> Training parameter, the initialization scheme used.</span>
<span class="sd"> numpy_batch_size : int, optional</span>
<span class="sd"> The batch size of training data.</span>
<span class="sd"> Only needed when input array is numpy.</span>
<span class="sd"> arg_params : dict of str to NDArray, optional</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's weights.</span>
<span class="sd"> aux_params : dict of str to NDArray, optional</span>
<span class="sd"> Model parameter, dict of name to NDArray of net's auxiliary states.</span>
<span class="sd"> allow_extra_params : boolean, optional</span>
<span class="sd"> Whether to allow extra parameters that are not needed by symbol</span>
<span class="sd"> to be passed by ``aux_params`` and ``arg_params``.</span>
<span class="sd"> If this is True, no error will be thrown when ``aux_params`` and ``arg_params``</span>
<span class="sd"> contain more parameters than needed.</span>
<span class="sd"> begin_epoch : int, optional</span>
<span class="sd"> The beginning training epoch.</span>
<span class="sd"> kwargs : dict</span>
<span class="sd"> The additional keyword arguments passed to optimizer.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">epoch_size</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
<span class="n">initializer</span><span class="o">=</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.01</span><span class="p">),</span>
<span class="n">numpy_batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">aux_params</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">allow_extra_params</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
<span class="n">begin_epoch</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
<span class="s1">'</span><span class="se">\033</span><span class="s1">[91mmxnet.model.FeedForward has been deprecated. '</span> <span class="o">+</span> \
<span class="s1">'Please use mxnet.mod.Module instead.</span><span class="se">\033</span><span class="s1">[0m'</span><span class="p">,</span>
<span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">sym</span><span class="o">.</span><span class="n">Symbol</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">symbol</span> <span class="o">=</span> <span class="n">symbol</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span> <span class="o">=</span> <span class="bp">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span><span class="p">(</span><span class="nb">callable</span><span class="p">(</span><span class="n">symbol</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">symbol</span> <span class="o">=</span> <span class="bp">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span> <span class="o">=</span> <span class="n">symbol</span>
<span class="c1"># model parameters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span> <span class="o">=</span> <span class="n">arg_params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span> <span class="o">=</span> <span class="n">aux_params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">allow_extra_params</span> <span class="o">=</span> <span class="n">allow_extra_params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">argument_checked</span> <span class="o">=</span> <span class="bp">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_check_arguments</span><span class="p">()</span>
<span class="c1"># basic configuration</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">cpu</span><span class="p">()]</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">Context</span><span class="p">):</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">ctx</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx</span> <span class="o">=</span> <span class="n">ctx</span>
<span class="c1"># training parameters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_epoch</span> <span class="o">=</span> <span class="n">num_epoch</span>
<span class="bp">self</span><span class="o">.</span><span class="n">epoch_size</span> <span class="o">=</span> <span class="n">epoch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">initializer</span> <span class="o">=</span> <span class="n">initializer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">numpy_batch_size</span> <span class="o">=</span> <span class="n">numpy_batch_size</span>
<span class="c1"># internal helper state</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span> <span class="o">=</span> <span class="bp">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">begin_epoch</span> <span class="o">=</span> <span class="n">begin_epoch</span>
<span class="k">def</span> <span class="nf">_check_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Verify the arguments of the default symbol and user-provided parameters."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">argument_checked</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">assert</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">argument_checked</span> <span class="o">=</span> <span class="bp">True</span>
<span class="c1"># check if symbol contains duplicated names.</span>
<span class="n">_check_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="p">)</span>
<span class="c1"># rematch parameters to delete useless ones</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">allow_extra_params</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">:</span>
<span class="n">arg_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span> <span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">arg_names</span><span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="p">:</span>
<span class="n">aux_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">list_auxiliary_states</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span> <span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">aux_names</span><span class="p">}</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">_is_data_arg</span><span class="p">(</span><span class="n">name</span><span class="p">):</span>
<span class="sd">"""Check if name is a data argument."""</span>
<span class="k">return</span> <span class="n">name</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span> <span class="ow">or</span> <span class="n">name</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s1">'label'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_init_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
<span class="sd">"""Initialize weight parameters and auxiliary states."""</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">DataDesc</span><span class="p">)</span> <span class="k">else</span> <span class="n">DataDesc</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">]</span>
<span class="n">input_shapes</span> <span class="o">=</span> <span class="p">{</span><span class="n">item</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">item</span><span class="o">.</span><span class="n">shape</span> <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">}</span>
<span class="n">arg_shapes</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">aux_shapes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">(</span><span class="o">**</span><span class="n">input_shapes</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">arg_shapes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span>
<span class="n">input_dtypes</span> <span class="o">=</span> <span class="p">{</span><span class="n">item</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">item</span><span class="o">.</span><span class="n">dtype</span> <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">}</span>
<span class="n">arg_dtypes</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">aux_dtypes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">infer_type</span><span class="p">(</span><span class="o">**</span><span class="n">input_dtypes</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">arg_dtypes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span>
<span class="n">arg_names</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()</span>
<span class="n">input_names</span> <span class="o">=</span> <span class="n">input_shapes</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="n">param_names</span> <span class="o">=</span> <span class="p">[</span><span class="n">key</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">arg_names</span> <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">input_names</span><span class="p">]</span>
<span class="n">aux_names</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">list_auxiliary_states</span><span class="p">()</span>
<span class="n">param_name_attrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">arg_names</span><span class="p">,</span> <span class="n">arg_shapes</span><span class="p">,</span> <span class="n">arg_dtypes</span><span class="p">)</span>
<span class="k">if</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">in</span> <span class="n">param_names</span><span class="p">]</span>
<span class="n">arg_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span> <span class="p">:</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">s</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">t</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">param_name_attrs</span><span class="p">}</span>
<span class="n">aux_name_attrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">aux_names</span><span class="p">,</span> <span class="n">aux_shapes</span><span class="p">,</span> <span class="n">aux_dtypes</span><span class="p">)</span>
<span class="k">if</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">in</span> <span class="n">aux_names</span><span class="p">]</span>
<span class="n">aux_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span> <span class="p">:</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">s</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">t</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">aux_name_attrs</span><span class="p">}</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">arg_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span> <span class="ow">and</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span> <span class="ow">and</span> <span class="p">(</span><span class="ow">not</span> <span class="n">overwrite</span><span class="p">):</span>
<span class="n">arg_params</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">initializer</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">aux_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span> <span class="ow">and</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span> <span class="ow">and</span> <span class="p">(</span><span class="ow">not</span> <span class="n">overwrite</span><span class="p">):</span>
<span class="n">aux_params</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">initializer</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span> <span class="o">=</span> <span class="n">arg_params</span>
<span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span> <span class="o">=</span> <span class="n">aux_params</span>
<span class="k">return</span> <span class="p">(</span><span class="n">arg_names</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">param_names</span><span class="p">),</span> <span class="n">aux_names</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__getstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">this</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="n">this</span><span class="p">[</span><span class="s1">'_pred_exec'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span>
<span class="k">return</span> <span class="n">this</span>
<span class="k">def</span> <span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_init_predictor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shapes</span><span class="p">,</span> <span class="n">type_dict</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="sd">"""Initialize the predictor module for running prediction."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">arg_shapes</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">(</span><span class="o">**</span><span class="nb">dict</span><span class="p">(</span><span class="n">input_shapes</span><span class="p">))</span>
<span class="k">assert</span> <span class="n">arg_shapes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">,</span> <span class="s2">"Incomplete input shapes"</span>
<span class="n">pred_shapes</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">arg_arrays</span><span class="p">]</span>
<span class="k">if</span> <span class="n">arg_shapes</span> <span class="o">==</span> <span class="n">pred_shapes</span><span class="p">:</span>
<span class="k">return</span>
<span class="c1"># for now only use the first device</span>
<span class="n">pred_exec</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">simple_bind</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'null'</span><span class="p">,</span> <span class="n">type_dict</span><span class="o">=</span><span class="n">type_dict</span><span class="p">,</span> <span class="o">**</span><span class="nb">dict</span><span class="p">(</span><span class="n">input_shapes</span><span class="p">))</span>
<span class="n">pred_exec</span><span class="o">.</span><span class="n">copy_params_from</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="p">)</span>
<span class="n">_check_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span> <span class="o">=</span> <span class="n">pred_exec</span>
<span class="k">def</span> <span class="nf">_init_iter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">is_train</span><span class="p">):</span>
<span class="sd">"""Initialize the iterator given input."""</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">nd</span><span class="o">.</span><span class="n">NDArray</span><span class="p">)):</span>
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">is_train</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'y must be specified when X is numpy.ndarray'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">nd</span><span class="o">.</span><span class="n">NDArray</span><span class="p">)):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">'y must be ndarray when X is numpy.ndarray'</span><span class="p">)</span>
<span class="k">if</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"The numbers of data points and labels are not equal"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">y</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">if</span> <span class="n">y</span><span class="o">.</span><span class="n">ndim</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Label must be 1D or 2D (with 2nd dimension being 1)"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_train</span><span class="p">:</span>
<span class="k">return</span> <span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">numpy_batch_size</span><span class="p">),</span>
<span class="n">shuffle</span><span class="o">=</span><span class="n">is_train</span><span class="p">,</span> <span class="n">last_batch_handle</span><span class="o">=</span><span class="s1">'roll_over'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">numpy_batch_size</span><span class="p">),</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">'X must be DataIter, NDArray or numpy.ndarray'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">X</span>
<span class="k">def</span> <span class="nf">_init_eval_iter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">eval_data</span><span class="p">):</span>
<span class="sd">"""Initialize the iterator given eval_data."""</span>
<span class="k">if</span> <span class="n">eval_data</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">eval_data</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_data</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">))</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">eval_data</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">if</span> <span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">eval_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">is</span> <span class="bp">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span>
<span class="k">return</span> <span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">input_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">list</span><span class="p">)</span>
<span class="k">else</span> <span class="n">eval_data</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">input_label</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">eval_data</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_data</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="nb">list</span><span class="p">)</span>
<span class="k">else</span> <span class="n">eval_data</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_iter</span><span class="p">(</span><span class="n">input_data</span><span class="p">,</span> <span class="n">input_label</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Eval data is NONE"</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_data</span><span class="p">,</span> <span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">'Eval data must be DataIter, or '</span> \
<span class="s1">'NDArray/numpy.ndarray/list pair (i.e. tuple/list of length 2)'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">eval_data</span>
<div class="viewcode-block" id="FeedForward.predict"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.predict">[docs]</a> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">num_batch</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">return_data</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">reset</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
<span class="sd">"""Run prediction, always using only one device.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> X : mxnet.DataIter</span>
<span class="sd"> num_batch : int or None</span>
<span class="sd"> The number of batches to run. Go through all batches if ``None``.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.</span>
<span class="sd"> The predicted value of the output.</span>
<span class="sd"> """</span>
<span class="n">X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_iter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">reset</span><span class="p">:</span>
<span class="n">X</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">data_shapes</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_data</span>
<span class="n">data_names</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data_shapes</span><span class="p">]</span>
<span class="n">type_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">((</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="o">.</span><span class="n">items</span><span class="p">())</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_data</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">DataDesc</span><span class="p">):</span>
<span class="n">type_dict</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">type_dict</span><span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">mx_real_t</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_predictor</span><span class="p">(</span><span class="n">data_shapes</span><span class="p">,</span> <span class="n">type_dict</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">batch_size</span>
<span class="n">data_arrays</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">data_names</span><span class="p">]</span>
<span class="n">output_list</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">))]</span>
<span class="k">if</span> <span class="n">return_data</span><span class="p">:</span>
<span class="n">data_list</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_data</span><span class="p">]</span>
<span class="n">label_list</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_label</span><span class="p">]</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">X</span><span class="p">:</span>
<span class="n">_load_data</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">data_arrays</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">padded</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">pad</span>
<span class="n">real_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">padded</span>
<span class="k">for</span> <span class="n">o_list</span><span class="p">,</span> <span class="n">o_nd</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">output_list</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">):</span>
<span class="n">o_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">o_nd</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">real_size</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span>
<span class="k">if</span> <span class="n">return_data</span><span class="p">:</span>
<span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">data</span><span class="p">):</span>
<span class="n">data_list</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">real_size</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span>
<span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">label</span><span class="p">):</span>
<span class="n">label_list</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">real_size</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span>
<span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">num_batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">==</span> <span class="n">num_batch</span><span class="p">:</span>
<span class="k">break</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">output_list</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">return_data</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data_list</span><span class="p">]</span>
<span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">label_list</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">label</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">label</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">label</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">outputs</span></div>
<div class="viewcode-block" id="FeedForward.score"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.score">[docs]</a> <span class="k">def</span> <span class="nf">score</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span> <span class="n">num_batch</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">reset</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
<span class="sd">"""Run the model given an input and calculate the score</span>
<span class="sd"> as assessed by an evaluation metric.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> X : mxnet.DataIter</span>
<span class="sd"> eval_metric : metric.EvalMetric or str or callable</span>
<span class="sd"> The metric for calculating score.</span>
<span class="sd"> num_batch : int or None</span>
<span class="sd"> The number of batches to run. Go through all batches if ``None``.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> s : float</span>
<span class="sd"> The final score.</span>
<span class="sd"> """</span>
<span class="c1"># setup metric</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">,</span> <span class="n">metric</span><span class="o">.</span><span class="n">EvalMetric</span><span class="p">):</span>
<span class="n">eval_metric</span> <span class="o">=</span> <span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">)</span>
<span class="n">X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_iter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">None</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">reset</span><span class="p">:</span>
<span class="n">X</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">data_shapes</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_data</span>
<span class="n">data_names</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">data_shapes</span><span class="p">]</span>
<span class="n">type_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">((</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="o">.</span><span class="n">items</span><span class="p">())</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">X</span><span class="o">.</span><span class="n">provide_data</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">DataDesc</span><span class="p">):</span>
<span class="n">type_dict</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">type_dict</span><span class="p">[</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">mx_real_t</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_predictor</span><span class="p">(</span><span class="n">data_shapes</span><span class="p">,</span> <span class="n">type_dict</span><span class="p">)</span>
<span class="n">data_arrays</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">data_names</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">X</span><span class="p">):</span>
<span class="k">if</span> <span class="n">num_batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">==</span> <span class="n">num_batch</span><span class="p">:</span>
<span class="k">break</span>
<span class="n">_load_data</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">data_arrays</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">eval_metric</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">label</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pred_exec</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">batch_end_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">batch_end_params</span> <span class="o">=</span> <span class="n">BatchEndParam</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">nbatch</span><span class="o">=</span><span class="n">i</span><span class="p">,</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="nb">locals</span><span class="o">=</span><span class="nb">locals</span><span class="p">())</span>
<span class="n">_multiple_callbacks</span><span class="p">(</span><span class="n">batch_end_callback</span><span class="p">,</span> <span class="n">batch_end_params</span><span class="p">)</span>
<span class="k">return</span> <span class="n">eval_metric</span><span class="o">.</span><span class="n">get</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span></div>
<div class="viewcode-block" id="FeedForward.fit"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.fit">[docs]</a> <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">eval_data</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span>
<span class="n">epoch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">kvstore</span><span class="o">=</span><span class="s1">'local'</span><span class="p">,</span> <span class="n">logger</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">work_load_list</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">monitor</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">eval_end_callback</span><span class="o">=</span><span class="n">LogValidationMetricsCallback</span><span class="p">(),</span>
<span class="n">eval_batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="sd">"""Fit the model.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> X : DataIter, or numpy.ndarray/NDArray</span>
<span class="sd"> Training data. If `X` is a `DataIter`, the name or (if name not available)</span>
<span class="sd"> the position of its outputs should match the corresponding variable</span>
<span class="sd"> names defined in the symbolic graph.</span>
<span class="sd"> y : numpy.ndarray/NDArray, optional</span>
<span class="sd"> Training set label.</span>
<span class="sd"> If X is ``numpy.ndarray`` or `NDArray`, `y` is required to be set.</span>
<span class="sd"> While y can be 1D or 2D (with 2nd dimension as 1), its first dimension must be</span>
<span class="sd"> the same as `X`, i.e. the number of data points and labels should be equal.</span>
<span class="sd"> eval_data : DataIter or numpy.ndarray/list/NDArray pair</span>
<span class="sd"> If eval_data is numpy.ndarray/list/NDArray pair,</span>
<span class="sd"> it should be ``(valid_data, valid_label)``.</span>
<span class="sd"> eval_metric : metric.EvalMetric or str or callable</span>
<span class="sd"> The evaluation metric. This could be the name of evaluation metric</span>
<span class="sd"> or a custom evaluation function that returns statistics</span>
<span class="sd"> based on a minibatch.</span>
<span class="sd"> epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)</span>
<span class="sd"> A callback that is invoked at the end of each epoch.</span>
<span class="sd"> This can be used to checkpoint the model each epoch.</span>
<span class="sd"> batch_end_callback: callable(epoch)</span>
<span class="sd"> A callback that is invoked at the end of each batch for purposes of printing.</span>
<span class="sd"> kvstore: KVStore or str, optional</span>
<span class="sd"> The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async'.</span>
<span class="sd"> Defaults to 'local'; there is often no need to change it for a single machine.</span>
<span class="sd"> logger : logging logger, optional</span>
<span class="sd"> When not specified, default logger will be used.</span>
<span class="sd"> work_load_list : list of float or int, optional</span>
<span class="sd"> The list of work loads for different devices,</span>
<span class="sd"> in the same order as `ctx`.</span>
<span class="sd"> Note</span>
<span class="sd"> ----</span>
<span class="sd"> KVStore behavior</span>
<span class="sd"> - 'local': multiple devices on a single machine; automatically chooses the best type.</span>
<span class="sd"> - 'dist_sync': multiple machines communicating via bulk synchronous parallel (BSP).</span>
<span class="sd"> - 'dist_async': multiple machines with asynchronous communication.</span>
<span class="sd"> """</span>
<span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_iter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">eval_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_eval_iter</span><span class="p">(</span><span class="n">eval_data</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">symbol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">default_bucket_key</span><span class="p">)</span> <span class="c1"># pylint: disable=no-member</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_check_arguments</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">[</span><span class="s2">"sym"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span>
<span class="n">arg_names</span><span class="p">,</span> <span class="n">param_names</span><span class="p">,</span> <span class="n">aux_names</span> <span class="o">=</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">_init_params</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">provide_data</span><span class="o">+</span><span class="n">data</span><span class="o">.</span><span class="n">provide_label</span><span class="p">)</span>
<span class="c1"># setup metric</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">,</span> <span class="n">metric</span><span class="o">.</span><span class="n">EvalMetric</span><span class="p">):</span>
<span class="n">eval_metric</span> <span class="o">=</span> <span class="n">metric</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">eval_metric</span><span class="p">)</span>
<span class="c1"># create kvstore</span>
<span class="p">(</span><span class="n">kvstore</span><span class="p">,</span> <span class="n">update_on_kvstore</span><span class="p">)</span> <span class="o">=</span> <span class="n">_create_kvstore</span><span class="p">(</span>
<span class="n">kvstore</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ctx</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">)</span>
<span class="n">param_idx2name</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="n">update_on_kvstore</span><span class="p">:</span>
<span class="n">param_idx2name</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">param_names</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">param_names</span><span class="p">):</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ctx</span><span class="p">)):</span>
<span class="n">param_idx2name</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ctx</span><span class="p">)</span><span class="o">+</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">[</span><span class="s2">"param_idx2name"</span><span class="p">]</span> <span class="o">=</span> <span class="n">param_idx2name</span>
<span class="c1"># init optimizer</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">batch_size</span>
<span class="k">if</span> <span class="n">kvstore</span> <span class="ow">and</span> <span class="s1">'dist'</span> <span class="ow">in</span> <span class="n">kvstore</span><span class="o">.</span><span class="n">type</span> <span class="ow">and</span> <span class="ow">not</span> <span class="s1">'_async'</span> <span class="ow">in</span> <span class="n">kvstore</span><span class="o">.</span><span class="n">type</span><span class="p">:</span>
<span class="n">batch_size</span> <span class="o">*=</span> <span class="n">kvstore</span><span class="o">.</span><span class="n">num_workers</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">opt</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span>
<span class="n">rescale_grad</span><span class="o">=</span><span class="p">(</span><span class="mf">1.0</span><span class="o">/</span><span class="n">batch_size</span><span class="p">),</span>
<span class="o">**</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span>
<span class="c1"># do training</span>
<span class="n">_train_multi_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ctx</span><span class="p">,</span> <span class="n">arg_names</span><span class="p">,</span> <span class="n">param_names</span><span class="p">,</span> <span class="n">aux_names</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="p">,</span>
<span class="n">begin_epoch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">begin_epoch</span><span class="p">,</span> <span class="n">end_epoch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_epoch</span><span class="p">,</span>
<span class="n">epoch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">epoch_size</span><span class="p">,</span>
<span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span>
<span class="n">train_data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">eval_data</span><span class="o">=</span><span class="n">eval_data</span><span class="p">,</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="n">epoch_end_callback</span><span class="o">=</span><span class="n">epoch_end_callback</span><span class="p">,</span>
<span class="n">batch_end_callback</span><span class="o">=</span><span class="n">batch_end_callback</span><span class="p">,</span>
<span class="n">kvstore</span><span class="o">=</span><span class="n">kvstore</span><span class="p">,</span> <span class="n">update_on_kvstore</span><span class="o">=</span><span class="n">update_on_kvstore</span><span class="p">,</span>
<span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span> <span class="n">work_load_list</span><span class="o">=</span><span class="n">work_load_list</span><span class="p">,</span> <span class="n">monitor</span><span class="o">=</span><span class="n">monitor</span><span class="p">,</span>
<span class="n">eval_end_callback</span><span class="o">=</span><span class="n">eval_end_callback</span><span class="p">,</span>
<span class="n">eval_batch_end_callback</span><span class="o">=</span><span class="n">eval_batch_end_callback</span><span class="p">,</span>
<span class="n">sym_gen</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sym_gen</span><span class="p">)</span></div>
<div class="viewcode-block" id="FeedForward.save"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.save">[docs]</a> <span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="sd">"""Save the model checkpoint to file.</span>
<span class="sd"> You can also use `pickle` to do the job if you only work in Python.</span>
<span class="sd"> The advantage of `load` and `save` (as compared to `pickle`) is that</span>
<span class="sd"> the resulting file can be loaded from other MXNet language bindings.</span>
<span class="sd"> One can also directly `load`/`save` from/to cloud storage (S3, HDFS).</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> prefix : str</span>
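<span class="sd"> Prefix of model name.</span>
<span class="sd"> epoch : int, optional</span>
<span class="sd"> Epoch number to save under. Defaults to `num_epoch` of the model.</span>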
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> - ``prefix-symbol.json`` will be saved for symbol.</span>
<span class="sd"> - ``prefix-epoch.params`` will be saved for parameters.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">epoch</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">epoch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_epoch</span>
<span class="k">assert</span> <span class="n">epoch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span>
<span class="n">save_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">symbol</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">arg_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_params</span><span class="p">)</span></div>
<span class="nd">@staticmethod</span>
<div class="viewcode-block" id="FeedForward.load"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.load">[docs]</a> <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Load model checkpoint from file.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> prefix : str</span>
<span class="sd"> Prefix of model name.</span>
<span class="sd"> epoch : int</span>
<span class="sd"> Epoch number of the model to load.</span>
<span class="sd"> ctx : Context or list of Context, optional</span>
<span class="sd"> The device context of training and prediction.</span>
<span class="sd"> kwargs : dict</span>
<span class="sd"> Other parameters for the model, including `num_epoch`, `optimizer` and `numpy_batch_size`.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> model : FeedForward</span>
<span class="sd"> The loaded model that can be used for prediction.</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> - ``prefix-symbol.json`` will be loaded for symbol.</span>
<span class="sd"> - ``prefix-epoch.params`` will be loaded for parameters.</span>
<span class="sd"> """</span>
<span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">load_checkpoint</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="o">=</span><span class="n">aux_params</span><span class="p">,</span>
<span class="n">begin_epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="nd">@staticmethod</span>
<div class="viewcode-block" id="FeedForward.create"><a class="viewcode-back" href="../../api/python/model.html#mxnet.model.FeedForward.create">[docs]</a> <span class="k">def</span> <span class="nf">create</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">epoch_size</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.01</span><span class="p">),</span>
<span class="n">eval_data</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span>
<span class="n">epoch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">kvstore</span><span class="o">=</span><span class="s1">'local'</span><span class="p">,</span> <span class="n">logger</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">work_load_list</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span>
<span class="n">eval_end_callback</span><span class="o">=</span><span class="n">LogValidationMetricsCallback</span><span class="p">(),</span>
<span class="n">eval_batch_end_callback</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Functional style to create a model.</span>
<span class="sd"> This function is more consistent with functional</span>
<span class="sd"> languages such as R, where mutation is not allowed.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> symbol : Symbol</span>
<span class="sd"> The symbol configuration of a computation network.</span>
<span class="sd"> X : DataIter or numpy.ndarray</span>
<span class="sd"> Training data.</span>
<span class="sd"> y : numpy.ndarray, optional</span>
<span class="sd"> If `X` is a ``numpy.ndarray``, `y` must be set.</span>
<span class="sd"> ctx : Context or list of Context, optional</span>
<span class="sd"> The device context of training and prediction.</span>
<span class="sd"> To use multi-GPU training, pass in a list of GPU contexts.</span>
<span class="sd"> num_epoch : int, optional</span>
<span class="sd"> The number of training epochs.</span>
<span class="sd"> epoch_size : int, optional</span>
<span class="sd"> Number of batches in an epoch. By default, it is set to</span>
<span class="sd"> ``ceil(num_train_examples / batch_size)``.</span>
<span class="sd"> optimizer : str or Optimizer, optional</span>
<span class="sd"> The name of the chosen optimizer, or an optimizer object, used for training.</span>
<span class="sd"> initializer : initializer function, optional</span>
<span class="sd"> The initialization scheme used.</span>
<span class="sd"> eval_data : DataIter or numpy.ndarray pair</span>
<span class="sd"> If `eval_data` is a ``numpy.ndarray`` pair, it should</span>
<span class="sd"> be (`valid_data`, `valid_label`).</span>
<span class="sd"> eval_metric : metric.EvalMetric or str or callable</span>
<span class="sd"> The evaluation metric. Can be the name of an evaluation metric</span>
<span class="sd"> or a custom evaluation function that returns statistics</span>
<span class="sd"> based on a minibatch.</span>
<span class="sd"> epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)</span>
<span class="sd"> A callback that is invoked at the end of each epoch.</span>
<span class="sd"> This can be used to checkpoint the model each epoch.</span>
<span class="sd"> batch_end_callback: callable(epoch)</span>
<span class="sd"> A callback that is invoked at the end of each batch for printing purposes.</span>
<span class="sd"> kvstore: KVStore or str, optional</span>
<span class="sd"> The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async'.</span>
<span class="sd"> Defaults to 'local'; there is often no need to change it on a single machine.</span>
<span class="sd"> logger : logging logger, optional</span>
<span class="sd"> When not specified, default logger will be used.</span>
<span class="sd"> work_load_list : list of float or int, optional</span>
<span class="sd"> The list of work load for different devices,</span>
<span class="sd"> in the same order as `ctx`.</span>
<span class="sd"> """</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">num_epoch</span><span class="o">=</span><span class="n">num_epoch</span><span class="p">,</span>
<span class="n">epoch_size</span><span class="o">=</span><span class="n">epoch_size</span><span class="p">,</span>
<span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="n">initializer</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">eval_data</span><span class="o">=</span><span class="n">eval_data</span><span class="p">,</span> <span class="n">eval_metric</span><span class="o">=</span><span class="n">eval_metric</span><span class="p">,</span>
<span class="n">epoch_end_callback</span><span class="o">=</span><span class="n">epoch_end_callback</span><span class="p">,</span>
<span class="n">batch_end_callback</span><span class="o">=</span><span class="n">batch_end_callback</span><span class="p">,</span>
<span class="n">kvstore</span><span class="o">=</span><span class="n">kvstore</span><span class="p">,</span>
<span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span>
<span class="n">work_load_list</span><span class="o">=</span><span class="n">work_load_list</span><span class="p">,</span>
<span class="n">eval_end_callback</span><span class="o">=</span><span class="n">eval_end_callback</span><span class="p">,</span>
<span class="n">eval_batch_end_callback</span><span class="o">=</span><span class="n">eval_batch_end_callback</span><span class="p">)</span>
<span class="k">return</span> <span class="n">model</span></div></div>
</pre></div>
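<p>The bookkeeping that <code>fit</code> performs before training (building <code>param_idx2name</code> and choosing <code>rescale_grad</code>) can be tried in isolation. The following is a minimal standalone sketch, not part of <code>mxnet.model</code>: the helper names <code>build_param_idx2name</code> and <code>rescale_grad</code> are hypothetical, and the kvstore is represented only by its type string and worker count rather than by a real <code>KVStore</code> object.</p>

```python
# Standalone sketch of two steps from FeedForward.fit; it does not import mxnet.
# Both helper names are hypothetical and exist only for illustration.


def build_param_idx2name(param_names, num_devices, update_on_kvstore):
    """Map optimizer indices to parameter names."""
    if update_on_kvstore:
        # Updates happen on the kvstore: one index per parameter.
        return dict(enumerate(param_names))
    # Updates happen locally: one index per (parameter, device) pair,
    # laid out as i * num_devices + k, as in the listing above.
    return {i * num_devices + k: name
            for i, name in enumerate(param_names)
            for k in range(num_devices)}


def rescale_grad(batch_size, kvstore_type=None, num_workers=1):
    """Gradient scale applied when the optimizer is given by name."""
    # Assumption: kvstore_type stands in for `kvstore.type` in the listing.
    if kvstore_type and 'dist' in kvstore_type and '_async' not in kvstore_type:
        # Synchronous distributed training aggregates over all workers.
        batch_size *= num_workers
    return 1.0 / batch_size


if __name__ == "__main__":
    print(build_param_idx2name(['fc1_weight', 'fc1_bias'], 2, False))
    print(rescale_grad(32, 'dist_sync', 4))
```

<p>With two devices and local updates, each parameter name appears under two consecutive indices; with <code>update_on_kvstore</code> set, each name appears once.</p>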
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>