blob: b123ec365eced13b01ff25d6ff5ba0c1e64d4671 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Metadata, stylesheets and scripts for the "Source code for mxnet.gluon.parameter" page.
     meta/link are void elements: they take no closing tags. -->
<meta charset="utf-8">
<meta content="IE=edge" http-equiv="X-UA-Compatible">
<meta content="width=device-width, initial-scale=1" name="viewport">
<!-- Open Graph tags used when the page is shared. -->
<meta content="mxnet.gluon.parameter" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="mxnet.gluon.parameter" property="og:description">
<title>mxnet.gluon.parameter — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet">
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet">
<link href="../../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../../_static/mxnet.css" rel="stylesheet" type="text/css">
<!-- Sphinx page options; presumably consumed by doctools.js / searchtools_custom.js loaded below. -->
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../../_static/underscore.js" type="text/javascript"></script>
<script src="../../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../../_static/doctools.js" type="text/javascript"></script>
<script src="../../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- On DOM ready, load the prebuilt search index for the 1.2.1 docs and initialise search. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.2.1/searchindex.js"); Search.init();}); </script>
<!-- Google Analytics (analytics.js) page-view tracking. -->
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="../../../genindex.html" rel="index" title="Index">
<link href="../../../search.html" rel="search" title="Search">
<link href="../../index.html" rel="up" title="Module code">
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png">
</head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<!-- Site logo linking back to the documentation root; the alt text names the link target. -->
<h1 id="logo-wrap">
<a href="../../../" id="logo"><img alt="MXNet documentation home" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"></a>
</h1>
<!-- Desktop top navigation: one Bootstrap dropdown per section, followed by the docs-version switcher.
     Toggles start with aria-expanded="false" because every menu is closed on page load.
     NOTE(review): ids "dropdown-menu-position-anchor" and "package-dropdown-menu" repeat across menus
     (ids should be unique); kept because theme CSS/JS may select on them. Confirm before renaming. -->
<nav aria-label="Site" class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/1.2.1/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.2.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.2.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/1.2.1/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/tutorials/index.html">Tutorials</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.2.1/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.2.1">Github</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">1.2.1<span class="caret"></span></a><ul class="dropdown-menu" id="package-dropdown-menu">
<li><a href="/">master</a></li>
<li><a href="/versions/1.7.0/">1.7.0</a></li>
<li><a href="/versions/1.6.0/">1.6.0</a></li>
<li><a href="/versions/1.5.0/">1.5.0</a></li>
<li><a href="/versions/1.4.1/">1.4.1</a></li>
<li><a href="/versions/1.3.1/">1.3.1</a></li>
<li><a href="/versions/1.2.1/">1.2.1</a></li>
<li><a href="/versions/1.1.0/">1.1.0</a></li>
<li><a href="/versions/1.0.0/">1.0.0</a></li>
<li><a href="/versions/0.12.1/">0.12.1</a></li>
<li><a href="/versions/0.11.0/">0.11.0</a></li>
</ul></span></nav>
<script>
  // Relative path from this page back to the documentation root
  // (the same value as DOCUMENTATION_OPTIONS.URL_ROOT in the head).
  function getRootPath() {
    return "../../../";
  }
</script>
<!-- Mobile ("burger") menu: the same sections and links as the desktop navigation, as nested dropdown submenus.
     The icon-only toggle gets an aria-label so it has an accessible name; submenus start collapsed. -->
<div class="burgerIcon dropdown">
<a aria-label="Menu" class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/1.2.1/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.2.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.2.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.2.1/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/1.2.1/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/1.2.1/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.2.1/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/1.2.1/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/1.2.1/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.2.1" tabindex="-1">Github</a></li>
<li><a href="/versions/1.2.1/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/1.2.1/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li class="dropdown-submenu" id="dropdown-menu-position-anchor-version-mobile" style="position: relative"><a href="#" tabindex="-1">1.2.1</a><ul class="dropdown-menu">
<li><a href="/" tabindex="-1">master</a></li>
<li><a href="/versions/1.7.0/" tabindex="-1">1.7.0</a></li>
<li><a href="/versions/1.6.0/" tabindex="-1">1.6.0</a></li>
<li><a href="/versions/1.5.0/" tabindex="-1">1.5.0</a></li>
<li><a href="/versions/1.4.1/" tabindex="-1">1.4.1</a></li>
<li><a href="/versions/1.3.1/" tabindex="-1">1.3.1</a></li>
<li><a href="/versions/1.2.1/" tabindex="-1">1.2.1</a></li>
<li><a href="/versions/1.1.0/" tabindex="-1">1.1.0</a></li>
<li><a href="/versions/1.0.0/" tabindex="-1">1.0.0</a></li>
<li><a href="/versions/0.12.1/" tabindex="-1">0.12.1</a></li>
<li><a href="/versions/0.11.0/" tabindex="-1">0.11.0</a></li>
</ul></li></ul>
</div>
<div class="plusIcon dropdown">
  <!-- "+" toggle for an extra dropdown; the glyph is decorative and hidden from assistive technology. -->
  <a class="dropdown-toggle" href="#" role="button" data-toggle="dropdown"><span class="glyphicon glyphicon-plus" aria-hidden="true"></span></a>
  <!-- Empty in the markup; presumably populated by a theme script not visible here (confirm). -->
  <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<!-- Site search: a GET form that submits q plus the check_keywords/area hidden fields to search.html.
     The text input has no visible label, so it is named with aria-label; the glyph is decorative. -->
<div id="search-input-wrap">
<form action="../../../search.html" autocomplete="off" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i aria-hidden="true" class="glyphicon glyphicon-search"></i>
<input aria-label="Search" class="form-control" name="q" placeholder="Search" type="text">
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default">
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
  <!-- Magnifier glyph only; no text, and hidden from assistive technology. -->
  <span class="glyphicon glyphicon-search" aria-hidden="true"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
  // Give the page a plain white background; the inline style set here
  // takes precedence over the image from the body's background attribute.
  $(document.body).css('background', 'white');
</script>
<div class="container">
<div class="row">
<div class="sphinxsidebar leftsidebar" role="navigation" aria-label="main navigation">
  <div class="sphinxsidebarwrapper">
    <!-- Top-level documentation sections: per-language API references first, then general guides. -->
    <ul>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/python/index.html">Python Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/r/index.html">R Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/julia/index.html">Julia Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/c++/index.html">C++ Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/scala/index.html">Scala Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../api/perl/index.html">Perl Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../faq/index.html">HowTo Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../architecture/index.html">System Documents</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../tutorials/index.html">Tutorials</a></li>
      <li class="toctree-l1"><a class="reference internal" href="../../../community/index.html">Community</a></li>
    </ul>
  </div>
</div>
<div class="content">
<div class="page-tracker"></div>
<h1>Source code for mxnet.gluon.parameter</h1><div class="highlight"><pre>
<span></span><span class="c1"># Licensed to the Apache Software Foundation (ASF) under one</span>
<span class="c1"># or more contributor license agreements. See the NOTICE file</span>
<span class="c1"># distributed with this work for additional information</span>
<span class="c1"># regarding copyright ownership. The ASF licenses this file</span>
<span class="c1"># to you under the Apache License, Version 2.0 (the</span>
<span class="c1"># "License"); you may not use this file except in compliance</span>
<span class="c1"># with the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing,</span>
<span class="c1"># software distributed under the License is distributed on an</span>
<span class="c1"># "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY</span>
<span class="c1"># KIND, either express or implied. See the License for the</span>
<span class="c1"># specific language governing permissions and limitations</span>
<span class="c1"># under the License.</span>
<span class="c1"># coding: utf-8</span>
<span class="c1"># pylint: disable=</span>
<span class="sd">"""Neural network parameter."""</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'DeferredInitializationError'</span><span class="p">,</span> <span class="s1">'Parameter'</span><span class="p">,</span> <span class="s1">'Constant'</span><span class="p">,</span>
<span class="s1">'ParameterDict'</span><span class="p">,</span> <span class="s1">'tensor_types'</span><span class="p">]</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="k">import</span> <span class="n">OrderedDict</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">..base</span> <span class="k">import</span> <span class="n">mx_real_t</span><span class="p">,</span> <span class="n">MXNetError</span>
<span class="kn">from</span> <span class="nn">..</span> <span class="k">import</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">ndarray</span><span class="p">,</span> <span class="n">initializer</span><span class="p">,</span> <span class="n">context</span>
<span class="kn">from</span> <span class="nn">..context</span> <span class="k">import</span> <span class="n">Context</span><span class="p">,</span> <span class="n">cpu</span>
<span class="kn">from</span> <span class="nn">..</span> <span class="k">import</span> <span class="n">autograd</span>
<span class="kn">from</span> <span class="nn">.utils</span> <span class="k">import</span> <span class="n">_indent</span><span class="p">,</span> <span class="n">_brief_print_list</span>
<span class="c1"># pylint: disable= invalid-name</span>
<span class="n">tensor_types</span> <span class="o">=</span> <span class="p">(</span><span class="n">symbol</span><span class="o">.</span><span class="n">Symbol</span><span class="p">,</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">NDArray</span><span class="p">)</span>
<span class="c1"># pylint: enable= invalid-name</span>
<div class="viewcode-block" id="DeferredInitializationError"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.DeferredInitializationError">[docs]</a><span class="k">class</span> <span class="nc">DeferredInitializationError</span><span class="p">(</span><span class="n">MXNetError</span><span class="p">):</span>
<span class="sd">"""Error for unfinished deferred initialization."""</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="Parameter"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter">[docs]</a><span class="k">class</span> <span class="nc">Parameter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="sd">"""A Container holding parameters (weights) of Blocks.</span>
<span class="sd"> :py:class:`Parameter` holds a copy of the parameter on each :py:class:`Context` after</span>
<span class="sd"> it is initialized with ``Parameter.initialize(...)``. If :py:attr:`grad_req` is</span>
<span class="sd"> not ``'null'``, it will also hold a gradient array on each :py:class:`Context`::</span>
<span class="sd"> ctx = mx.gpu(0)</span>
<span class="sd"> x = mx.nd.zeros((16, 100), ctx=ctx)</span>
<span class="sd"> w = mx.gluon.Parameter('fc_weight', shape=(64, 100), init=mx.init.Xavier())</span>
<span class="sd"> b = mx.gluon.Parameter('fc_bias', shape=(64,), init=mx.init.Zero())</span>
<span class="sd"> w.initialize(ctx=ctx)</span>
<span class="sd"> b.initialize(ctx=ctx)</span>
<span class="sd"> out = mx.nd.FullyConnected(x, w.data(ctx), b.data(ctx), num_hidden=64)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> name : str</span>
<span class="sd"> Name of this parameter.</span>
<span class="sd"> grad_req : {'write', 'add', 'null'}, default 'write'</span>
<span class="sd"> Specifies how to update gradient to grad arrays.</span>
<span class="sd"> - ``'write'`` means everytime gradient is written to grad :py:class:`NDArray`.</span>
<span class="sd"> - ``'add'`` means everytime gradient is added to the grad :py:class:`NDArray`. You need</span>
<span class="sd"> to manually call ``zero_grad()`` to clear the gradient buffer before each</span>
<span class="sd"> iteration when using this option.</span>
<span class="sd"> - 'null' means gradient is not requested for this parameter. gradient arrays</span>
<span class="sd"> will not be allocated.</span>
<span class="sd"> shape : tuple of int, default None</span>
<span class="sd"> Shape of this parameter. By default shape is not specified. Parameter with</span>
<span class="sd"> unknown shape can be used for :py:class:`Symbol` API, but ``init`` will throw an error</span>
<span class="sd"> when using :py:class:`NDArray` API.</span>
<span class="sd"> dtype : numpy.dtype or str, default 'float32'</span>
<span class="sd"> Data type of this parameter. For example, ``numpy.float32`` or ``'float32'``.</span>
<span class="sd"> lr_mult : float, default 1.0</span>
<span class="sd"> Learning rate multiplier. Learning rate will be multiplied by lr_mult</span>
<span class="sd"> when updating this parameter with optimizer.</span>
<span class="sd"> wd_mult : float, default 1.0</span>
<span class="sd"> Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.</span>
<span class="sd"> init : Initializer, default None</span>
<span class="sd"> Initializer of this parameter. Will use the global initializer by default.</span>
<span class="sd"> Attributes</span>
<span class="sd"> ----------</span>
<span class="sd"> grad_req : {'write', 'add', 'null'}</span>
<span class="sd"> This can be set before or after initialization. Setting ``grad_req`` to ``'null'``</span>
<span class="sd"> with ``x.grad_req = 'null'`` saves memory and computation when you don't</span>
<span class="sd"> need gradient w.r.t x.</span>
<span class="sd"> lr_mult : float</span>
<span class="sd"> Local learning rate multiplier for this Parameter. The actual learning rate</span>
<span class="sd"> is calculated with ``learning_rate * lr_mult``. You can set it with</span>
<span class="sd"> ``param.lr_mult = 2.0``</span>
<span class="sd"> wd_mult : float</span>
<span class="sd"> Local weight decay multiplier for this Parameter.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'write'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">mx_real_t</span><span class="p">,</span>
<span class="n">lr_mult</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">wd_mult</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_deferred_init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">differentiable</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_var</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_ctx_map</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_differentiable</span> <span class="o">=</span> <span class="n">differentiable</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_allow_deferred_init</span> <span class="o">=</span> <span class="n">allow_deferred_init</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad_req</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lr_mult</span> <span class="o">=</span> <span class="n">lr_mult</span>
<span class="bp">self</span><span class="o">.</span><span class="n">wd_mult</span> <span class="o">=</span> <span class="n">wd_mult</span>
<span class="bp">self</span><span class="o">.</span><span class="n">grad_req</span> <span class="o">=</span> <span class="n">grad_req</span>
<span class="bp">self</span><span class="o">.</span><span class="n">init</span> <span class="o">=</span> <span class="n">init</span>
<span class="k">def</span> <span class="nf">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">s</span> <span class="o">=</span> <span class="s1">'Parameter </span><span class="si">{name}</span><span class="s1"> (shape=</span><span class="si">{shape}</span><span class="s1">, dtype=</span><span class="si">{dtype}</span><span class="s1">)'</span>
<span class="k">return</span> <span class="n">s</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">grad_req</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad_req</span>
<span class="nd">@grad_req</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">grad_req</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">req</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">req</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'write'</span><span class="p">,</span> <span class="s1">'add'</span><span class="p">,</span> <span class="s1">'null'</span><span class="p">],</span> \
<span class="s2">"grad_req must be one of 'write', 'add', or 'null', but got '</span><span class="si">%s</span><span class="s2">'"</span><span class="o">%</span><span class="n">req</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_differentiable</span><span class="p">:</span>
<span class="n">req</span> <span class="o">=</span> <span class="s1">'null'</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad_req</span> <span class="o">==</span> <span class="n">req</span><span class="p">:</span>
<span class="k">return</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad_req</span> <span class="o">=</span> <span class="n">req</span>
<span class="k">if</span> <span class="n">req</span> <span class="o">==</span> <span class="s1">'null'</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">]</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_grad</span><span class="p">()</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span> <span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">new_shape</span>
<span class="k">return</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="ow">and</span> \
<span class="nb">all</span><span class="p">(</span><span class="n">j</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">i</span> <span class="o">==</span> <span class="n">j</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">new_shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">)),</span> \
<span class="s2">"Expected shape </span><span class="si">%s</span><span class="s2"> is incompatible with given shape </span><span class="si">%s</span><span class="s2">."</span><span class="o">%</span><span class="p">(</span>
<span class="nb">str</span><span class="p">(</span><span class="n">new_shape</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_shape</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">new_shape</span>
<!-- Rendered method _check_and_get(arr_list, ctx): return the copy in arr_list that lives on ctx.
  arr_list not None:
    ctx is the builtin list object itself (identity test): return arr_list unchanged.
    ctx is None: return the single copy when there is exactly one, otherwise use context.current_context().
    Lookup: self._ctx_map[ctx.device_typeid & 1][ctx.device_id] holds an index into arr_list;
    a missing or None slot raises RuntimeError naming self._ctx_list.
  arr_list None: DeferredInitializationError when self._deferred_init is truthy, else RuntimeError.
  NOTE(review): the last message says "the later" where "the latter" is meant; the string is kept
  as released because this page mirrors the 1.2.1 source.
  Markup: the less-than operator is written as an entity so the page is valid HTML. --><span class="k">def</span> <span class="nf">_check_and_get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arr_list</span><span class="p">,</span> <span class="n">ctx</span><span class="p">):</span>
<span class="k">if</span> <span class="n">arr_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="nb">list</span><span class="p">:</span>
<span class="k">return</span> <span class="n">arr_list</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">arr_list</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">arr_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">current_context</span><span class="p">()</span>
<span class="n">ctx_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ctx_map</span><span class="p">[</span><span class="n">ctx</span><span class="o">.</span><span class="n">device_typeid</span><span class="o">&amp;</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">ctx</span><span class="o">.</span><span class="n">device_id</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">ctx_list</span><span class="p">):</span>
<span class="n">idx</span> <span class="o">=</span> <span class="n">ctx_list</span><span class="p">[</span><span class="n">ctx</span><span class="o">.</span><span class="n">device_id</span><span class="p">]</span>
<span class="k">if</span> <span class="n">idx</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">arr_list</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' was not initialized on context </span><span class="si">%s</span><span class="s2">. "</span>
<span class="s2">"It was only initialized on </span><span class="si">%s</span><span class="s2">."</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">ctx</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span><span class="p">)))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">:</span>
<span class="k">raise</span> <span class="n">DeferredInitializationError</span><span class="p">(</span>
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' has not been initialized yet because initialization was "</span> \
<span class="s2">"deferred. Actual initialization happens during the first forward pass. "</span> \
<span class="s2">"Please pass one batch of data through the network before accessing Parameters. "</span> \
<span class="s2">"You can also avoid deferred initialization by specifying in_units, "</span> \
<span class="s2">"num_features, etc., for network layers."</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' has not been initialized. Note that "</span> \
<span class="s2">"you should initialize parameters and create Trainer "</span> \
<span class="s2">"with Block.collect_params() instead of Block.params "</span> \
<span class="s2">"because the later does not include Parameters of "</span> \
<span class="s2">"nested child Blocks"</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<!-- Rendered method _load_init(data, ctx): (re)initialize this Parameter from a loaded array.
  Shape: every nonzero dim of self.shape must equal the matching dim of data.shape; zero dims are
  filled from data.shape. NOTE(review): zip() stops at the shorter shape, so this loop alone does
  not detect a rank mismatch.
  Dtype: when self.dtype is set, np.dtype(self.dtype).type must equal data.dtype.
  A single Context is wrapped in a list.
  Not yet holding data: with a pending deferred init, ctx (if given) must match the deferred
  contexts and those contexts are reused; otherwise ctx defaults to [cpu()]. _init_impl then
  copies data to each context.
  Already holding data: ctx (if given) must match list_ctx(); set_data overwrites every copy.
  Any pending deferred initialization is cleared at the end.
  NOTE(review): "previous initialized" in the messages should read "previously initialized";
  kept as released because this page mirrors the 1.2.1 source. --><span class="k">def</span> <span class="nf">_load_init</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">ctx</span><span class="p">):</span>
<span class="sd">"""(Re)initializes by loading from data."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
<span class="k">for</span> <span class="n">self_dim</span><span class="p">,</span> <span class="n">data_dim</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">self_dim</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">self_dim</span> <span class="o">==</span> <span class="n">data_dim</span><span class="p">,</span> \
<span class="s2">"Failed loading Parameter '</span><span class="si">%s</span><span class="s2">' from saved params: "</span> \
<span class="s2">"shape incompatible expected </span><span class="si">%s</span><span class="s2"> vs saved </span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">i</span> <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">j</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">data</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> \
<span class="s2">"Failed loading Parameter '</span><span class="si">%s</span><span class="s2">' from saved params: "</span> \
<span class="s2">"dtype incompatible expected </span><span class="si">%s</span><span class="s2"> vs saved </span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">Context</span><span class="p">):</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">ctx</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">set</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span> <span class="o">==</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> \
<span class="s2">"Failed to load Parameter '</span><span class="si">%s</span><span class="s2">' on </span><span class="si">%s</span><span class="s2"> because it was "</span> \
<span class="s2">"previous initialized on </span><span class="si">%s</span><span class="s2">."</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">ctx</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">list_ctx</span><span class="p">()))</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">elif</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">cpu</span><span class="p">()]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_impl</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">set</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span> <span class="o">==</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">list_ctx</span><span class="p">()),</span> \
<span class="s2">"Failed to load Parameter '</span><span class="si">%s</span><span class="s2">' on </span><span class="si">%s</span><span class="s2"> because it was "</span> \
<span class="s2">"previous initialized on </span><span class="si">%s</span><span class="s2">."</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">ctx</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">list_ctx</span><span class="p">()))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">set_data</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">()</span>
<!-- Rendered method _finish_deferred_init(): complete an initialization that initialize() postponed.
  No-op when self._deferred_init is empty. Otherwise unpacks (init, ctx, default_init, data),
  clears the pending tuple, then asserts the shape is known with np.prod(shape) greater than 0.
  With autograd paused: when no data was stored (set_data puts it in the fourth slot), a zero
  CPU array of self.shape and self.dtype is created and filled by initializer.create(default_init),
  given an InitDesc whose attributes are {'__init__': init}. Presumably this lets a per-parameter
  init override default_init; confirm in initializer.py. _init_impl then copies the array to every
  context in ctx. --><span class="k">def</span> <span class="nf">_finish_deferred_init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Finishes deferred initialization."""</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">:</span>
<span class="k">return</span>
<span class="n">init</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">default_init</span><span class="p">,</span> <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">()</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> \
<span class="s2">"Cannot initialize Parameter '</span><span class="si">%s</span><span class="s2">' because it has "</span> \
<span class="s2">"invalid shape: </span><span class="si">%s</span><span class="s2">. Please specify in_units, "</span> \
<span class="s2">"in_channels, etc for `Block`s."</span><span class="o">%</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">pause</span><span class="p">():</span>
<span class="k">if</span> <span class="n">data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">ctx</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
<span class="n">initializer</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">default_init</span><span class="p">)(</span>
<span class="n">initializer</span><span class="o">.</span><span class="n">InitDesc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="p">{</span><span class="s1">'__init__'</span><span class="p">:</span> <span class="n">init</span><span class="p">}),</span> <span class="n">data</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_impl</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<!-- Rendered method _init_impl(data, ctx_list): store one copy of data per context and set up gradients.
  Records self._ctx_list and rebuilds self._ctx_map as two lists selected by ctx.device_typeid & 1.
  Each list maps a device_id to that context's position in self._ctx_list, padded with None for unused ids.
  Then copies data to every context and calls _init_grad.
  NOTE(review): which device types land in each of the two buckets is not shown here; presumably
  CPU-like vs GPU-like. Confirm against Context.
  Markup: the less-or-equal operator is written with an entity so the page is valid HTML. --><span class="k">def</span> <span class="nf">_init_impl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">ctx_list</span><span class="p">):</span>
<span class="sd">"""Sets data and grad."""</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ctx_list</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_ctx_map</span> <span class="o">=</span> <span class="p">[[],</span> <span class="p">[]]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ctx</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span><span class="p">):</span>
<span class="n">dev_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ctx_map</span><span class="p">[</span><span class="n">ctx</span><span class="o">.</span><span class="n">device_typeid</span><span class="o">&amp;</span><span class="mi">1</span><span class="p">]</span>
<span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">dev_list</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">device_id</span><span class="p">:</span>
<span class="n">dev_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
<span class="n">dev_list</span><span class="p">[</span><span class="n">ctx</span><span class="o">.</span><span class="n">device_id</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="o">=</span> <span class="p">[</span><span class="n">data</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span> <span class="k">for</span> <span class="n">ctx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_grad</span><span class="p">()</span>
<!-- Rendered method _init_grad(): allocate gradient buffers.
  With grad_req 'null', self._grad is set to None and nothing is registered.
  Otherwise one ndarray.zeros_like buffer is created per data copy, and the data/grad lists are
  registered via autograd.mark_variables with self.grad_req. --><span class="k">def</span> <span class="nf">_init_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Initialize grad buffers."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_req</span> <span class="o">==</span> <span class="s1">'null'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">return</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="p">[</span><span class="n">ndarray</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">]</span>
<span class="n">autograd</span><span class="o">.</span><span class="n">mark_variables</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">list_data</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_grad</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_req</span><span class="p">)</span>
<!-- Rendered method _reduce(): return the element-wise mean of all data copies, computed on CPU.
  Each copy is moved with copyto(context.cpu()), summed with ndarray.add_n, and divided by the number of copies. --><span class="k">def</span> <span class="nf">_reduce</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Reduce data from multiple context."""</span>
<span class="n">block</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_data</span><span class="p">()</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">add_n</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">block</span><span class="p">))</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">block</span><span class="p">)</span>
<span class="k">return</span> <span class="n">data</span>
<!-- Rendered method Parameter.initialize: allocate data (and gradient) arrays on the requested contexts.
  If data already exists and force_reinit is False, it warns and returns.
  ctx defaults to [context.current_context()]; a single Context is wrapped in a list.
  init falls back to self.init, then to default_init.
  If the shape is empty/None or np.prod(shape) is not positive, initialization is deferred (when
  _allow_deferred_init is set) or a ValueError is raised. Otherwise the pending tuple is recorded
  and completed immediately by _finish_deferred_init.
  Markup: the NDArray reprs in the docstring example, and the less-or-equal operator, are written
  with entities. Unescaped, browsers parse the reprs as unknown start tags and hide that text. --><div class="viewcode-block" id="Parameter.initialize"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.initialize">[docs]</a> <span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">default_init</span><span class="o">=</span><span class="n">initializer</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(),</span>
<span class="n">force_reinit</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Initializes parameter and gradient arrays. Only used for :py:class:`NDArray` API.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> init : Initializer</span>
<span class="sd"> The initializer to use. Overrides :py:meth:`Parameter.init` and default_init.</span>
<span class="sd"> ctx : Context or list of Context, defaults to :py:meth:`context.current_context()`.</span>
<span class="sd"> Initialize Parameter on given context. If ctx is a list of Context, a</span>
<span class="sd"> copy will be made for each context.</span>
<span class="sd"> .. note::</span>
<span class="sd"> Copies are independent arrays. User is responsible for keeping</span>
<span class="sd"> their values consistent when updating.</span>
<span class="sd"> Normally :py:class:`gluon.Trainer` does this for you.</span>
<span class="sd"> default_init : Initializer</span>
<span class="sd"> Default initializer is used when both :py:func:`init`</span>
<span class="sd"> and :py:meth:`Parameter.init` are ``None``.</span>
<span class="sd"> force_reinit : bool, default False</span>
<span class="sd"> Whether to force re-initialization if parameter is already initialized.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> >>> weight = mx.gluon.Parameter('weight', shape=(2, 2))</span>
<span class="sd"> >>> weight.initialize(ctx=mx.cpu(0))</span>
<span class="sd"> >>> weight.data()</span>
<span class="sd"> [[-0.01068833 0.01729892]</span>
<span class="sd"> [ 0.02042518 -0.01618656]]</span>
<span class="sd"> &lt;NDArray 2x2 @cpu(0)&gt;</span>
<span class="sd"> >>> weight.grad()</span>
<span class="sd"> [[ 0. 0.]</span>
<span class="sd"> [ 0. 0.]]</span>
<span class="sd"> &lt;NDArray 2x2 @cpu(0)&gt;</span>
<span class="sd"> >>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)])</span>
<span class="sd"> >>> weight.data(mx.gpu(0))</span>
<span class="sd"> [[-0.00873779 -0.02834515]</span>
<span class="sd"> [ 0.05484822 -0.06206018]]</span>
<span class="sd"> &lt;NDArray 2x2 @gpu(0)&gt;</span>
<span class="sd"> >>> weight.data(mx.gpu(1))</span>
<span class="sd"> [[-0.00873779 -0.02834515]</span>
<span class="sd"> [ 0.05484822 -0.06206018]]</span>
<span class="sd"> &lt;NDArray 2x2 @gpu(1)&gt;</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">force_reinit</span><span class="p">:</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' is already initialized, ignoring. "</span> \
<span class="s2">"Set force_reinit=True to re-initialize."</span><span class="o">%</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">stacklevel</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">return</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">context</span><span class="o">.</span><span class="n">current_context</span><span class="p">()]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">Context</span><span class="p">):</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">ctx</span><span class="p">]</span>
<span class="k">if</span> <span class="n">init</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">init</span> <span class="o">=</span> <span class="n">default_init</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">init</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">init</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_allow_deferred_init</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">(</span><span class="n">init</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">default_init</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">return</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Cannot initialize Parameter '</span><span class="si">%s</span><span class="s2">' because it has "</span> \
<span class="s2">"invalid shape: </span><span class="si">%s</span><span class="s2">."</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">(</span><span class="n">init</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">default_init</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_finish_deferred_init</span><span class="p">()</span></div>
<!-- Rendered method Parameter.reset_ctx: move this Parameter to other contexts.
  ctx None means [context.current_context()]; a single Context is wrapped in a list.
  When data exists, the copies are averaged onto CPU by _reduce and redistributed to the new
  contexts by _init_impl, with autograd paused.
  When initialization is only pending, just the context entry of the deferred tuple is replaced.
  Otherwise raises ValueError. --><div class="viewcode-block" id="Parameter.reset_ctx"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.reset_ctx">[docs]</a> <span class="k">def</span> <span class="nf">reset_ctx</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="p">):</span>
<span class="sd">"""Re-assign Parameter to other contexts.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ctx : Context or list of Context, default ``context.current_context()``.</span>
<span class="sd"> Assign Parameter to given context. If ctx is a list of Context, a</span>
<span class="sd"> copy will be made for each context.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">context</span><span class="o">.</span><span class="n">current_context</span><span class="p">()]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">Context</span><span class="p">):</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="p">[</span><span class="n">ctx</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce</span><span class="p">()</span>
<span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">pause</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_impl</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">:</span>
<span class="n">init</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">default_init</span><span class="p">,</span> <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="p">(</span><span class="n">init</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">default_init</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Cannot reset context for Parameter '</span><span class="si">%s</span><span class="s2">' because it "</span>
<span class="s2">"has not been initialized."</span><span class="o">%</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span></div>
<!-- Rendered method Parameter.set_data: overwrite this Parameter's value on every context.
  self.shape is updated from data.shape first.
  When no data arrays exist yet, data is stored as the fourth element of the pending
  deferred-initialization tuple, to be used when that initialization completes.
  Otherwise data is assigned in place into every copy from list_data().
  NOTE(review): elsewhere on this page _deferred_init is reset to an empty tuple, never None, so
  the "is not None" assert always passes. An uninitialized, non-deferred Parameter therefore gets
  a 1-tuple holding only data instead of this error message. The upstream source should test
  truthiness instead; kept as released because this page mirrors the 1.2.1 code. --><div class="viewcode-block" id="Parameter.set_data"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.set_data">[docs]</a> <span class="k">def</span> <span class="nf">set_data</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
<span class="sd">"""Sets this parameter's value on all contexts."""</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> \
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' has not been initialized"</span><span class="o">%</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="n">data</span><span class="p">,)</span>
<span class="k">return</span>
<span class="k">for</span> <span class="n">arr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_data</span><span class="p">():</span>
<span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">data</span></div>
<div class="viewcode-block" id="Parameter.data"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.data">[docs]</a> <span class="k">def</span> <span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span><!-- Rendered source of Parameter.data: delegates to self._check_and_get(self._data, ctx). NOTE(review): the docstring says "copy"; whether _check_and_get returns a copy or the stored array cannot be confirmed from this excerpt. -->
<span class="sd">"""Returns a copy of this parameter on one context. Must have been</span>
<span class="sd"> initialized on this context before.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> Desired context.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> NDArray on ctx</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_and_get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span></div>
<div class="viewcode-block" id="Parameter.list_data"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.list_data">[docs]</a> <span class="k">def</span> <span class="nf">list_data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of Parameter.list_data: delegates to self._check_and_get(self._data, list), passing the builtin list where Parameter.data passes a context. -->
<span class="sd">"""Returns copies of this parameter on all contexts, in the same order</span>
<span class="sd"> as creation."""</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_and_get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span></div>
<div class="viewcode-block" id="Parameter.grad"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.grad">[docs]</a> <span class="k">def</span> <span class="nf">grad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span><!-- Rendered source of Parameter.grad: raises RuntimeError when data arrays exist but no gradient buffers do (the message attributes this to grad_req='null'); otherwise delegates to self._check_and_get(self._grad, ctx). -->
<span class="sd">"""Returns a gradient buffer for this parameter on one context.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> Desired context.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">"Cannot get gradient array for Parameter '</span><span class="si">%s</span><span class="s2">' "</span> \
<span class="s2">"because grad_req='null'"</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_and_get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_grad</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span></div>
<div class="viewcode-block" id="Parameter.list_grad"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.list_grad">[docs]</a> <span class="k">def</span> <span class="nf">list_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of Parameter.list_grad: applies the same missing-gradient guard as Parameter.grad, then returns the buffers for all contexts via self._check_and_get(self._grad, list). NOTE(review): the docstring refers to :py:meth:`values`, which is not one of the Parameter methods shown here; list_data looks like the intended reference. Fix this upstream in parameter.py, not in this generated page. -->
<span class="sd">"""Returns gradient buffers on all contexts, in the same order</span>
<span class="sd"> as :py:meth:`values`."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">"Cannot get gradient array for Parameter '</span><span class="si">%s</span><span class="s2">' "</span> \
<span class="s2">"because grad_req='null'"</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_and_get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_grad</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span></div>
<div class="viewcode-block" id="Parameter.list_ctx"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.list_ctx">[docs]</a> <span class="k">def</span> <span class="nf">list_ctx</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of Parameter.list_ctx: returns self._ctx_list once data arrays exist. Before that, it returns element 1 (the ctx slot) of a pending deferred-init tuple, or raises RuntimeError when there is none. -->
<span class="sd">"""Returns a list of contexts this parameter is initialized on."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_deferred_init</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' has not been initialized"</span><span class="o">%</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ctx_list</span></div>
<div class="viewcode-block" id="Parameter.zero_grad"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.zero_grad">[docs]</a> <span class="k">def</span> <span class="nf">zero_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of Parameter.zero_grad: returns immediately when self._grad is None; otherwise writes 0 into each gradient buffer in place with slice assignment. -->
<span class="sd">"""Sets gradient buffer on all contexts to 0. No action is taken if</span>
<span class="sd"> parameter is uninitialized or doesn't require gradient."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span><span class="p">:</span>
<span class="n">i</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">0</span></div>
<div class="viewcode-block" id="Parameter.var"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.var">[docs]</a> <span class="k">def</span> <span class="nf">var</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of Parameter.var: on the first call, builds a symbol.var from the parameter's name, shape, dtype, lr_mult, wd_mult and init, and caches it in self._var. Later calls return the cached symbol. -->
<span class="sd">"""Returns a symbol representing this parameter."""</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_var</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_var</span> <span class="o">=</span> <span class="n">symbol</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">lr_mult</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lr_mult</span><span class="p">,</span> <span class="n">wd_mult</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">wd_mult</span><span class="p">,</span>
<span class="n">init</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">init</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_var</span></div>
<div class="viewcode-block" id="Parameter.cast"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Parameter.cast">[docs]</a> <span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span><!-- Rendered source of Parameter.cast: always updates self.dtype. When data arrays exist, it replaces them (and the gradient buffers, if any) with astype(dtype) results inside autograd.pause(), then re-registers the new pairs with autograd.mark_variables. -->
<span class="sd">"""Cast data and gradient of this Parameter to a new data type.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dtype : str or numpy.dtype</span>
<span class="sd"> The new data type.</span>
<span class="sd"> """</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">pause</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_grad</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span><span class="p">]</span>
<span class="n">autograd</span><span class="o">.</span><span class="n">mark_variables</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_req</span><span class="p">)</span></div></div><!-- The second closing div ends the enclosing class-level viewcode block that opens before this excerpt (presumably Parameter). -->
<div class="viewcode-block" id="Constant"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.Constant">[docs]</a><span class="k">class</span> <span class="nc">Constant</span><span class="p">(</span><span class="n">Parameter</span><span class="p">):</span><!-- Rendered source of class Constant, a Parameter subclass. __init__ wraps a non-NDArray value with ndarray.array and defines a local Initializer whose _init_weight copies that value into the target array. It registers that initializer under the alias Constant_{name}_{id(self)}, then calls Parameter.__init__ with grad_req='null', the value's shape and dtype, and the alias as init. -->
<span class="sd">"""A constant parameter for holding immutable tensors.</span>
<span class="sd"> `Constant`s are ignored by `autograd` and `Trainer`, thus their values</span>
<span class="sd"> will not change during training. But you can still update their values</span>
<span class="sd"> manually with the `set_data` method.</span>
<span class="sd"> `Constant`s can be created with either::</span>
<span class="sd"> const = mx.gluon.Constant('const', [[1,2],[3,4]])</span>
<span class="sd"> or::</span>
<span class="sd"> class Block(gluon.Block):</span>
<span class="sd"> def __init__(self, **kwargs):</span>
<span class="sd"> super(Block, self).__init__(**kwargs)</span>
<span class="sd"> self.const = self.params.get_constant('const', [[1,2],[3,4]])</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> name : str</span>
<span class="sd"> Name of the parameter.</span>
<span class="sd"> value : array-like</span>
<span class="sd"> Initial value for the constant.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">NDArray</span><span class="p">):</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">class</span> <span class="nc">Init</span><span class="p">(</span><span class="n">initializer</span><span class="o">.</span><span class="n">Initializer</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">_init_weight</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">arr</span><span class="p">):</span>
<span class="n">value</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">arr</span><span class="p">)</span>
<span class="n">init_name</span> <span class="o">=</span> <span class="s1">'Constant_</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span><!-- Including id(self) gives each Constant instance its own initializer alias, even when two instances share a name. -->
<span class="n">initializer</span><span class="o">.</span><span class="n">alias</span><span class="p">(</span><span class="n">init_name</span><span class="p">)(</span><span class="n">Init</span><span class="p">)</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Constant</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'null'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">init</span><span class="o">=</span><span class="n">init_name</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">s</span> <span class="o">=</span> <span class="s1">'Constant </span><span class="si">{name}</span><span class="s1"> (shape=</span><span class="si">{shape}</span><span class="s1">, dtype=</span><span class="si">{dtype}</span><span class="s1">)'</span>
<span class="k">return</span> <span class="n">s</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="ParameterDict"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict">[docs]</a><span class="k">class</span> <span class="nc">ParameterDict</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="sd">"""A dictionary managing a set of parameters.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> prefix : str, default ``''``</span>
<span class="sd"> The prefix to be prepended to all Parameters' names created by this dict.</span>
<span class="sd"> shared : ParameterDict or None</span>
<span class="sd"> If not ``None``, when this dict's :py:meth:`get` method creates a new parameter, will</span>
<span class="sd"> first try to retrieve it from "shared" dict. Usually used for sharing</span>
<span class="sd"> parameters with another Block.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="s1">''</span><span class="p">,</span> <span class="n">shared</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span><!-- Rendered source of ParameterDict.__init__: stores the name prefix, an empty OrderedDict for the parameters owned by this dict, and the optional shared ParameterDict. -->
<span class="bp">self</span><span class="o">.</span><span class="n">_prefix</span> <span class="o">=</span> <span class="n">prefix</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_params</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shared</span> <span class="o">=</span> <span class="n">shared</span>
<span class="k">def</span> <span class="nf">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span><!-- Rendered source of ParameterDict.__repr__: prints the prefix (followed by a space, if non-empty) and then one line per parameter, each indented via _indent(..., 2). -->
<span class="n">s</span> <span class="o">=</span> <span class="s1">'</span><span class="si">{name}</span><span class="s1">(</span><span class="se">\n</span><span class="si">{content}</span><span class="se">\n</span><span class="s1">)'</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prefix</span><span class="o">+</span><span class="s1">' '</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prefix</span> <span class="k">else</span> <span class="s1">''</span>
<span class="k">return</span> <span class="n">s</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span>
<span class="n">content</span><span class="o">=</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">_indent</span><span class="p">(</span><span class="s1">' </span><span class="si">{0}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">v</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">()]))</span>
<span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span><!-- Rendered source of the ParameterDict mapping accessors: __getitem__, __iter__, items, keys and values all delegate directly to the internal OrderedDict self._params. -->
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">iter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">items</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">keys</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">values</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="o">.</span><span class="n">values</span><span class="p">()</span>
<span class="nd">@property</span><!-- Rendered source of the ParameterDict.prefix property getter, which returns self._prefix. -->
<span class="k">def</span> <span class="nf">prefix</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Prefix of this dict. It will be prepended to :py:class:`Parameter`s' name created</span>
<span class="sd"> with :py:func:`get`."""</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prefix</span>
<span class="k">def</span> <span class="nf">_get_impl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span><!-- Rendered source of ParameterDict._get_impl: looks up the full name in this dict first, then in the shared dict, and returns None when neither contains it. A shared hit is copied into this dict so later lookups find it locally. -->
<span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shared</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shared</span><span class="o">.</span><span class="n">_params</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shared</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shared</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">return</span> <span class="kc">None</span>
<div class="viewcode-block" id="ParameterDict.get"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.get">[docs]</a> <span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span><!-- Rendered source of ParameterDict.get: prepends the prefix, reuses an existing local or shared Parameter when one is found, and otherwise creates a new Parameter from kwargs and stores it. For an existing Parameter, each kwarg is reconciled. Equal-length shapes are merged: a 0 in either shape matches any size, and the non-zero size wins. A shape that cannot be merged falls through to the equality check. Other attributes that are already set must equal the requested value, or the request must be None. Attributes that are missing or None are assigned. NOTE(review): the mismatch check is an assert, so it is skipped when Python runs with optimizations enabled. -->
<span class="sd">"""Retrieves a :py:class:`Parameter` with name ``self.prefix+name``. If not found,</span>
<span class="sd"> :py:func:`get` will first try to retrieve it from "shared" dict. If still not</span>
<span class="sd"> found, :py:func:`get` will create a new :py:class:`Parameter` with key-word arguments and</span>
<span class="sd"> insert it to self.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> name : str</span>
<span class="sd"> Name of the desired Parameter. It will be prepended with this dictionary's</span>
<span class="sd"> prefix.</span>
<span class="sd"> **kwargs : dict</span>
<span class="sd"> The rest of key-word arguments for the created :py:class:`Parameter`.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> Parameter</span>
<span class="sd"> The created or retrieved :py:class:`Parameter`.</span>
<span class="sd"> """</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prefix</span> <span class="o">+</span> <span class="n">name</span>
<span class="n">param</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_impl</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">param</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># pylint: disable=too-many-nested-blocks</span>
<span class="n">param</span> <span class="o">=</span> <span class="n">Parameter</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">existing</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
<span class="k">if</span> <span class="n">k</span> <span class="o">==</span> <span class="s1">'shape'</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">existing</span><span class="p">):</span>
<span class="n">inferred_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">matched</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">dim1</span><span class="p">,</span> <span class="n">dim2</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">existing</span><span class="p">):</span>
<span class="k">if</span> <span class="n">dim1</span> <span class="o">!=</span> <span class="n">dim2</span> <span class="ow">and</span> <span class="n">dim1</span> <span class="o">*</span> <span class="n">dim2</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span><!-- Two different sizes conflict only when both are non-zero; a 0 on either side is compatible with anything. -->
<span class="n">matched</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">break</span>
<span class="k">elif</span> <span class="n">dim1</span> <span class="o">==</span> <span class="n">dim2</span><span class="p">:</span>
<span class="n">inferred_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim1</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">dim1</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">inferred_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim2</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">inferred_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">matched</span><span class="p">:</span>
<span class="n">param</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">inferred_shape</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">assert</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">v</span> <span class="o">==</span> <span class="n">existing</span><span class="p">,</span> \
<span class="s2">"Cannot retrieve Parameter '</span><span class="si">%s</span><span class="s2">' because desired attribute "</span> \
<span class="s2">"does not match with stored for attribute '</span><span class="si">%s</span><span class="s2">': "</span> \
<span class="s2">"desired '</span><span class="si">%s</span><span class="s2">' vs stored '</span><span class="si">%s</span><span class="s2">'."</span><span class="o">%</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">k</span><span class="p">)))</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">setattr</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="k">return</span> <span class="n">param</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.get_constant; this listing mirrors the Python module, so fixes belong there. NOTE(review): the docstring says value is array-like, but only NDArray is converted before value.shape is read, so a plain list would presumably raise AttributeError; the assert message also says "it's value" (should be "its"). --><div class="viewcode-block" id="ParameterDict.get_constant"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.get_constant">[docs]</a> <span class="k">def</span> <span class="nf">get_constant</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Retrieves a :py:class:`Constant` with name ``self.prefix+name``. If not found,</span>
<span class="sd"> :py:func:`get` will first try to retrieve it from "shared" dict. If still not</span>
<span class="sd"> found, :py:func:`get` will create a new :py:class:`Constant` with key-word</span>
<span class="sd"> arguments and insert it to self.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> name : str</span>
<span class="sd"> Name of the desired Constant. It will be prepended with this dictionary's</span>
<span class="sd"> prefix.</span>
<span class="sd"> value : array-like</span>
<span class="sd"> Initial value of constant.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> Constant</span>
<span class="sd"> The created or retrieved :py:class:`Constant`.</span>
<span class="sd"> """</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prefix</span> <span class="o">+</span> <span class="n">name</span>
<span class="n">param</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_impl</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">param</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s2">"No constant named '</span><span class="si">{}</span><span class="s2">'. Please specify value "</span> \
<span class="s2">"if you want to create a new constant."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">name</span><span class="p">))</span>
<span class="n">param</span> <span class="o">=</span> <span class="n">Constant</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span>
<span class="k">elif</span> <span class="n">value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">Constant</span><span class="p">),</span> \
<span class="s2">"Parameter '</span><span class="si">{}</span><span class="s2">' already exists but it is not a constant."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">name</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">nd</span><span class="o">.</span><span class="n">NDArray</span><span class="p">):</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">param</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">value</span><span class="o">.</span><span class="n">shape</span> <span class="ow">and</span> \
<span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="o">==</span> <span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">(),</span> \
<span class="s2">"Constant '</span><span class="si">{}</span><span class="s2">' already exists but it's value doesn't match new "</span> \
<span class="s2">"value"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">return</span> <span class="n">param</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.update: a first pass asserts every name shared with other maps to the identical Parameter object, and only then a second pass copies the entries, so a conflict leaves self unmodified while asserts are enabled. --><div class="viewcode-block" id="ParameterDict.update"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.update">[docs]</a> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="sd">"""Copies all Parameters in ``other`` to self."""</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">other</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">:</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="ow">is</span> <span class="n">v</span><span class="p">,</span> \
<span class="s2">"Cannot update self with other because they have different "</span> \
<span class="s2">"Parameters with the same name '</span><span class="si">%s</span><span class="s2">'"</span><span class="o">%</span><span class="n">k</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">other</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.initialize: forwards init, ctx and force_reinit to each Parameter's initialize. NOTE(review): the default init=initializer.Uniform() is evaluated once at definition time and set_verbosity is called on it when verbose is true, so verbosity presumably persists on that shared default for later calls; confirm against Initializer.set_verbosity and fix in the Python source. --><div class="viewcode-block" id="ParameterDict.initialize"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.initialize">[docs]</a> <span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="n">initializer</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(),</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">force_reinit</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Initializes all Parameters managed by this dictionary to be used for :py:class:`NDArray`</span>
<span class="sd"> API. It has no effect when using :py:class:`Symbol` API.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> init : Initializer</span>
<span class="sd"> Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``.</span>
<span class="sd"> Otherwise, :py:meth:`Parameter.init` takes precedence.</span>
<span class="sd"> ctx : Context or list of Context</span>
<span class="sd"> Keeps a copy of Parameters on one or many context(s).</span>
<span class="sd"> verbose : bool, default False</span>
<span class="sd"> Whether to verbosely print out details on initialization.</span>
<span class="sd"> force_reinit : bool, default False</span>
<span class="sd"> Whether to force re-initialization if parameter is already initialized.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
<span class="n">init</span><span class="o">.</span><span class="n">set_verbosity</span><span class="p">(</span><span class="n">verbose</span><span class="o">=</span><span class="n">verbose</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">v</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">init</span><span class="p">,</span> <span class="n">force_reinit</span><span class="o">=</span><span class="n">force_reinit</span><span class="p">)</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.zero_grad: calls zero_grad() on every Parameter in the dict. --><div class="viewcode-block" id="ParameterDict.zero_grad"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.zero_grad">[docs]</a> <span class="k">def</span> <span class="nf">zero_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Sets all Parameters' gradient buffer to 0."""</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">i</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.reset_ctx: calls reset_ctx(ctx) on every Parameter. NOTE(review): the docstring advertises a current_context() default for ctx, but the signature gives ctx no default; reconcile in the Python source. --><div class="viewcode-block" id="ParameterDict.reset_ctx"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.reset_ctx">[docs]</a> <span class="k">def</span> <span class="nf">reset_ctx</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="p">):</span>
<span class="sd">"""Re-assign all Parameters to other contexts.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ctx : Context or list of Context, default :py:meth:`context.current_context()`.</span>
<span class="sd"> Assign Parameter to given context. If ctx is a list of Context, a</span>
<span class="sd"> copy will be made for each context.</span>
<span class="sd"> """</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">i</span><span class="o">.</span><span class="n">reset_ctx</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.setattr: applies the built-in setattr(param, name, value) to every Parameter; name is not validated here. --><div class="viewcode-block" id="ParameterDict.setattr"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.setattr">[docs]</a> <span class="k">def</span> <span class="nf">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
<span class="sd">"""Set an attribute to a new value for all Parameters.</span>
<span class="sd"> For example, set grad_req to null if you don't need gradient w.r.t a</span>
<span class="sd"> model's Parameters::</span>
<span class="sd"> model.collect_params().setattr('grad_req', 'null')</span>
<span class="sd"> or change the learning rate multiplier::</span>
<span class="sd"> model.collect_params().setattr('lr_mult', 0.5)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> name : str</span>
<span class="sd"> Name of the attribute.</span>
<span class="sd"> value : valid type for attribute name</span>
<span class="sd"> The new value for the attribute.</span>
<span class="sd"> """</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="nb">setattr</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.save: reduces each Parameter with _reduce(), strips strip_prefix from its name (raising ValueError when the name lacks that prefix) and writes the resulting dict with ndarray.save. NOTE(review): _reduce() runs before the prefix check, and the error text says "striped" (should be "stripped"); fix in the Python source. --><div class="viewcode-block" id="ParameterDict.save"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.save">[docs]</a> <span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">strip_prefix</span><span class="o">=</span><span class="s1">''</span><span class="p">):</span>
<span class="sd">"""Save parameters to file.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> filename : str</span>
<span class="sd"> Path to parameter file.</span>
<span class="sd"> strip_prefix : str, default ''</span>
<span class="sd"> Strip prefix from parameter names before saving.</span>
<span class="sd"> """</span>
<span class="n">arg_dict</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">_reduce</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">strip_prefix</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">"Prefix '</span><span class="si">%s</span><span class="s2">' is to be striped before saving, but Parameter's "</span>
<span class="s2">"name '</span><span class="si">%s</span><span class="s2">' does not start with '</span><span class="si">%s</span><span class="s2">'. "</span>
<span class="s2">"this may be due to your Block shares parameters from other "</span>
<span class="s2">"Blocks or you forgot to use 'with name_scope()' when creating "</span>
<span class="s2">"child blocks. For more info on naming, please see "</span>
<span class="s2">"/versions/1.2.1/tutorials/basic/naming.html"</span><span class="o">%</span><span class="p">(</span>
<span class="n">strip_prefix</span><span class="p">,</span> <span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">strip_prefix</span><span class="p">))</span>
<span class="n">arg_dict</span><span class="p">[</span><span class="n">param</span><span class="o">.</span><span class="n">name</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">strip_prefix</span><span class="p">):]]</span> <span class="o">=</span> <span class="n">weight</span>
<span class="n">ndarray</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">arg_dict</span><span class="p">)</span></div>
<!-- Rendered source (Sphinx viewcode) of ParameterDict.load: drops "arg:"/"aux:" key prefixes read from the file, prepends restore_prefix, asserts nothing is missing (unless allow_missing) and nothing is extra (unless ignore_extra), then calls _load_init on each matching Parameter. The final closing div also ends the enclosing ParameterDict class block. NOTE(review): docstring typos "not represents" and "Context(s) initialize"; fix in the Python source. --><div class="viewcode-block" id="ParameterDict.load"><a class="viewcode-back" href="../../../api/python/gluon/gluon.html#mxnet.gluon.ParameterDict.load">[docs]</a> <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">allow_missing</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">ignore_extra</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">restore_prefix</span><span class="o">=</span><span class="s1">''</span><span class="p">):</span>
<span class="sd">"""Load parameters from file.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> filename : str</span>
<span class="sd"> Path to parameter file.</span>
<span class="sd"> ctx : Context or list of Context</span>
<span class="sd"> Context(s) initialize loaded parameters on.</span>
<span class="sd"> allow_missing : bool, default False</span>
<span class="sd"> Whether to silently skip loading parameters not represents in the file.</span>
<span class="sd"> ignore_extra : bool, default False</span>
<span class="sd"> Whether to silently ignore parameters from the file that are not</span>
<span class="sd"> present in this ParameterDict.</span>
<span class="sd"> restore_prefix : str, default ''</span>
<span class="sd"> prepend prefix to names of stored parameters before loading.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">restore_prefix</span><span class="p">:</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">assert</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">restore_prefix</span><span class="p">),</span> \
<span class="s2">"restore_prefix is '</span><span class="si">%s</span><span class="s2">' but Parameters name '</span><span class="si">%s</span><span class="s2">' does not start "</span> \
<span class="s2">"with '</span><span class="si">%s</span><span class="s2">'"</span><span class="o">%</span><span class="p">(</span><span class="n">restore_prefix</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">restore_prefix</span><span class="p">)</span>
<span class="n">lprefix</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">restore_prefix</span><span class="p">)</span>
<span class="n">loaded</span> <span class="o">=</span> <span class="p">[(</span><span class="n">k</span><span class="p">[</span><span class="mi">4</span><span class="p">:]</span> <span class="k">if</span> <span class="n">k</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'arg:'</span><span class="p">)</span> <span class="ow">or</span> <span class="n">k</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'aux:'</span><span class="p">)</span> <span class="k">else</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> \
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">()]</span>
<span class="n">arg_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">restore_prefix</span><span class="o">+</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">loaded</span><span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">allow_missing</span><span class="p">:</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">assert</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">arg_dict</span><span class="p">,</span> \
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' is missing in file '</span><span class="si">%s</span><span class="s2">', which contains parameters: </span><span class="si">%s</span><span class="s2">. "</span> \
<span class="s2">"Please make sure source and target networks have the same prefix."</span><span class="o">%</span><span class="p">(</span>
<span class="n">name</span><span class="p">[</span><span class="n">lprefix</span><span class="p">:],</span> <span class="n">filename</span><span class="p">,</span> <span class="n">_brief_print_list</span><span class="p">(</span><span class="n">arg_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">arg_dict</span><span class="p">:</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">ignore_extra</span><span class="p">,</span> \
<span class="s2">"Parameter '</span><span class="si">%s</span><span class="s2">' loaded from file '</span><span class="si">%s</span><span class="s2">' is not present in ParameterDict, "</span> \
<span class="s2">"choices are: </span><span class="si">%s</span><span class="s2">. Set ignore_extra to True to ignore. "</span> \
<span class="s2">"Please make sure source and target networks have the same prefix."</span><span class="o">%</span><span class="p">(</span>
<span class="n">name</span><span class="p">[</span><span class="n">lprefix</span><span class="p">:],</span> <span class="n">filename</span><span class="p">,</span> <span class="n">_brief_print_list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="k">continue</span>
<span class="bp">self</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">_load_init</span><span class="p">(</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">ctx</span><span class="p">)</span></div></div>
</pre></div>
</div>
</div>
<!-- Right-hand navigation sidebar. Its wrapper is empty in the generated markup; presumably it is filled at runtime by _static/js/sidebar.js (verify before removing). -->
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img alt="Apache Incubator" height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
Copyright © 2017-2018, The Apache Software Foundation.
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation.
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../../_static/js/search.js" type="text/javascript"></script>
<script src="../../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../../_static/js/page.js" type="text/javascript"></script>
<script src="../../../_static/js/docversion.js" type="text/javascript"></script>
<!-- Makes the body visible once the DOM is ready. NOTE(review): this presumably pairs with a stylesheet rule that hides the body until then (not visible in this file); if that is so, a failure in the scripts above would leave the page blank. -->
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>