blob: 2181790feac0b6717b064c6211d34221934b251d [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="mxnet.test_utils" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="mxnet.test_utils" property="og:description"/>
<title>mxnet.test_utils — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Page-level settings for the Sphinx helper scripts loaded further down the
// <head> (doctools.js, searchtools_custom.js). Kept as a global `var`, since
// those files presumably look up DOCUMENTATION_OPTIONS by name (TODO confirm).
var DOCUMENTATION_OPTIONS = {
// Relative path from this page back to the docs root; matches the
// ../../_static/* asset links used in this <head>.
URL_ROOT: '../../',
// Empty in this build: no version string was configured.
VERSION: '',
COLLAPSE_INDEX: false,
// Suffix appended to generated page links (presumably used when building
// search-result URLs; verify against doctools.js).
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- Load the prebuilt search index and initialise search once the DOM is ready.
     NOTE(review): this path is absolute and pins version 1.5.0, unlike the
     URL_ROOT-relative assets above; it only resolves when the docs are served
     under /versions/1.5.0/. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.5.0/searchindex.js"); Search.init();}); </script>
<script>
// Standard Google Analytics (analytics.js) loader: defines a `ga` stub that
// queues its arguments in ga.q, records a timestamp in ga.l, and injects an
// async <script> for analytics.js before the first script tag on the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Create the tracker for property UA-96378503-1 and record this page view.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="../index.html" rel="up" title="Module code"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet home" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/1.5.0/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.5.0/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.5.0/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/java/index.html">Java</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/1.5.0/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/tutorials/index.html">Tutorials</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.5.0/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.5.0">Github</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/community/ecosystem.html">Ecosystem</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="false">1.5.0<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7/">1.7</a></li><li><a href="/versions/1.6/">1.6</a></li><li><a href="/versions/1.5.0/">1.5.0</a></li><li><a href="/versions/1.4.1/">1.4.1</a></li><li><a href="/versions/1.3.1/">1.3.1</a></li><li><a href="/versions/1.2.1/">1.2.1</a></li><li><a href="/versions/1.1.0/">1.1.0</a></li><li><a href="/versions/1.0.0/">1.0.0</a></li><li><a href="/versions/0.12.1/">0.12.1</a></li><li><a href="/versions/0.11.0/">0.11.0</a></li></ul></span></nav>
<script>
  // Returns the relative path from this page back to the documentation root.
  // Reuses DOCUMENTATION_OPTIONS.URL_ROOT (declared in <head>) so the two
  // values cannot drift apart; if that global or its URL_ROOT string is
  // missing, falls back to this page's literal depth ("../../").
  function getRootPath() {
    if (typeof DOCUMENTATION_OPTIONS !== "undefined" &&
        DOCUMENTATION_OPTIONS &&
        typeof DOCUMENTATION_OPTIONS.URL_ROOT === "string") {
      // A string check rather than a truthiness check, so an empty URL_ROOT
      // (a page at the docs root) is honoured instead of replaced.
      return DOCUMENTATION_OPTIONS.URL_ROOT;
    }
    return "../../";
  }
</script>
<div class="burgerIcon dropdown">
<a aria-expanded="false" aria-haspopup="true" aria-label="Menu" class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/1.5.0/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.5.0/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.5.0/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/java/index.html">Java</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.5.0/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="false" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/1.5.0/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/1.5.0/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.5.0/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/1.5.0/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/1.5.0/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.5.0" tabindex="-1">Github</a></li>
<li><a href="/versions/1.5.0/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/1.5.0/community/ecosystem.html" tabindex="-1">Ecosystem</a></li>
<li><a href="/versions/1.5.0/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">1.5.0</a><ul class="dropdown-menu"><li><a tabindex="-1" href="/">master</a></li><li><a tabindex="-1" href="/versions/1.7/">1.7</a></li><li><a tabindex="-1" href="/versions/1.6/">1.6</a></li><li><a tabindex="-1" href="/versions/1.5.0/">1.5.0</a></li><li><a tabindex="-1" href="/versions/1.4.1/">1.4.1</a></li><li><a tabindex="-1" href="/versions/1.3.1/">1.3.1</a></li><li><a tabindex="-1" href="/versions/1.2.1/">1.2.1</a></li><li><a tabindex="-1" href="/versions/1.1.0/">1.1.0</a></li><li><a tabindex="-1" href="/versions/1.0.0/">1.0.0</a></li><li><a tabindex="-1" href="/versions/0.12.1/">0.12.1</a></li><li><a tabindex="-1" href="/versions/0.11.0/">0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input aria-label="Search" class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script>
  // Once parsing reaches this point in <body>, give the page a plain white
  // background via an inline style, which takes precedence over the image
  // set by the <body background="..."> attribute.
  $(document.body).css({ background: 'white' });
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/index.html">MXNet APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">MXNet Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/index.html">MXNet Community</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">MXNet FAQ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../gluon/index.html">About Gluon</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing MXNet</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html#nvidia-jetson-tx-family">Nvidia Jetson TX family</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html#source-download">Source Download</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../model_zoo/index.html">MXNet Model Zoo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../tutorials/index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<h1>Source code for mxnet.test_utils</h1><div class="highlight"><pre>
<span></span><span class="c1"># Licensed to the Apache Software Foundation (ASF) under one</span>
<span class="c1"># or more contributor license agreements. See the NOTICE file</span>
<span class="c1"># distributed with this work for additional information</span>
<span class="c1"># regarding copyright ownership. The ASF licenses this file</span>
<span class="c1"># to you under the Apache License, Version 2.0 (the</span>
<span class="c1"># "License"); you may not use this file except in compliance</span>
<span class="c1"># with the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing,</span>
<span class="c1"># software distributed under the License is distributed on an</span>
<span class="c1"># "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY</span>
<span class="c1"># KIND, either express or implied. See the License for the</span>
<span class="c1"># specific language governing permissions and limitations</span>
<span class="c1"># under the License.</span>
<span class="sd">"""Tools for testing."""</span>
<span class="c1"># pylint: disable=too-many-lines</span>
<span class="kn">from</span> <span class="nn">__future__</span> <span class="k">import</span> <span class="n">absolute_import</span><span class="p">,</span> <span class="n">print_function</span><span class="p">,</span> <span class="n">division</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">gzip</span>
<span class="kn">import</span> <span class="nn">struct</span>
<span class="kn">import</span> <span class="nn">traceback</span>
<span class="kn">import</span> <span class="nn">numbers</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">errno</span>
<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">import</span> <span class="nn">bz2</span>
<span class="kn">import</span> <span class="nn">zipfile</span>
<span class="kn">from</span> <span class="nn">contextlib</span> <span class="k">import</span> <span class="n">contextmanager</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">numpy.testing</span> <span class="k">as</span> <span class="nn">npt</span>
<span class="kn">import</span> <span class="nn">numpy.random</span> <span class="k">as</span> <span class="nn">rnd</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">scipy.stats</span> <span class="k">as</span> <span class="nn">ss</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="n">ss</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">requests</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="c1"># in rare cases requests may be not installed</span>
<span class="k">pass</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">.context</span> <span class="k">import</span> <span class="n">Context</span><span class="p">,</span> <span class="n">current_context</span>
<span class="kn">from</span> <span class="nn">.ndarray.ndarray</span> <span class="k">import</span> <span class="n">_STORAGE_TYPE_STR_TO_ID</span>
<span class="kn">from</span> <span class="nn">.ndarray</span> <span class="k">import</span> <span class="n">array</span>
<span class="kn">from</span> <span class="nn">.symbol</span> <span class="k">import</span> <span class="n">Symbol</span>
<div class="viewcode-block" id="default_context"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.default_context">[docs]</a><span class="k">def</span> <span class="nf">default_context</span><span class="p">():</span>
<span class="sd">"""Get default context for regression test."""</span>
<span class="c1"># _TODO: get context from environment variable to support</span>
<span class="c1"># testing with GPUs</span>
<span class="k">return</span> <span class="n">current_context</span><span class="p">()</span></div>
<div class="viewcode-block" id="set_default_context"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.set_default_context">[docs]</a><span class="k">def</span> <span class="nf">set_default_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">):</span>
<span class="sd">"""Set default context."""</span>
<span class="n">Context</span><span class="o">.</span><span class="n">_default_ctx</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">ctx</span></div>
<div class="viewcode-block" id="default_dtype"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.default_dtype">[docs]</a><span class="k">def</span> <span class="nf">default_dtype</span><span class="p">():</span>
<span class="sd">"""Get default data type for regression test."""</span>
<span class="c1"># _TODO: get default dtype from environment variable</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span></div>
<div class="viewcode-block" id="get_atol"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_atol">[docs]</a><span class="k">def</span> <span class="nf">get_atol</span><span class="p">(</span><span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Get default numerical threshold for regression test."""</span>
<span class="c1"># _TODO: get from env variable, different threshold might</span>
<span class="c1"># be needed for different device and dtype</span>
<span class="k">return</span> <span class="mf">1e-20</span> <span class="k">if</span> <span class="n">atol</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">atol</span></div>
<div class="viewcode-block" id="get_rtol"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_rtol">[docs]</a><span class="k">def</span> <span class="nf">get_rtol</span><span class="p">(</span><span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Get default numerical threshold for regression test."""</span>
<span class="c1"># _TODO: get from env variable, different threshold might</span>
<span class="c1"># be needed for different device and dtype</span>
<span class="k">return</span> <span class="mf">1e-5</span> <span class="k">if</span> <span class="n">rtol</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">rtol</span></div>
<div class="viewcode-block" id="get_etol"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_etol">[docs]</a><span class="k">def</span> <span class="nf">get_etol</span><span class="p">(</span><span class="n">etol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Get default numerical threshold for regression test."""</span>
<span class="c1"># _TODO: get from env variable, different threshold might</span>
<span class="c1"># be needed for different device and dtype</span>
<span class="k">return</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">etol</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">etol</span></div>
<div class="viewcode-block" id="random_arrays"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.random_arrays">[docs]</a><span class="k">def</span> <span class="nf">random_arrays</span><span class="p">(</span><span class="o">*</span><span class="n">shapes</span><span class="p">):</span>
<span class="sd">"""Generate some random numpy arrays."""</span>
<span class="n">arrays</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="o">*</span><span class="n">s</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">default_dtype</span><span class="p">())</span>
<span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">shapes</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">arrays</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">arrays</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">arrays</span></div>
<div class="viewcode-block" id="random_sample"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.random_sample">[docs]</a><span class="k">def</span> <span class="nf">random_sample</span><span class="p">(</span><span class="n">population</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
<span class="sd">"""Return a k length list of the elements chosen from the population sequence."""</span>
<span class="k">assert</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">k</span> <span class="o"><=</span> <span class="nb">len</span><span class="p">(</span><span class="n">population</span><span class="p">)</span>
<span class="n">population_copy</span> <span class="o">=</span> <span class="n">population</span><span class="p">[:]</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">population_copy</span><span class="p">)</span>
<span class="k">return</span> <span class="n">population_copy</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">k</span><span class="p">]</span></div>
<span class="k">def</span> <span class="nf">_validate_csr_generation_inputs</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="p">,</span>
<span class="n">distribution</span><span class="o">=</span><span class="s2">"uniform"</span><span class="p">):</span>
<span class="sd">"""Validates inputs for csr generation helper functions</span>
<span class="sd"> """</span>
<span class="n">total_nnz</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_rows</span> <span class="o">*</span> <span class="n">num_cols</span> <span class="o">*</span> <span class="n">density</span><span class="p">)</span>
<span class="k">if</span> <span class="n">density</span> <span class="o"><</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">density</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"density has to be between 0 and 1"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">num_rows</span> <span class="o"><=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">num_cols</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"num_rows or num_cols should be greater than 0"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">distribution</span> <span class="o">==</span> <span class="s2">"powerlaw"</span><span class="p">:</span>
<span class="k">if</span> <span class="n">total_nnz</span> <span class="o"><</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_rows</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"not supported for this density: </span><span class="si">%s</span><span class="s2">"</span>
<span class="s2">" for this shape (</span><span class="si">%s</span><span class="s2">, </span><span class="si">%s</span><span class="s2">)"</span>
<span class="s2">" Please keep :"</span>
<span class="s2">" num_rows * num_cols * density >= 2 * num_rows"</span>
<span class="o">%</span> <span class="p">(</span><span class="n">density</span><span class="p">,</span> <span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">))</span>
<div class="viewcode-block" id="shuffle_csr_column_indices"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.shuffle_csr_column_indices">[docs]</a><span class="k">def</span> <span class="nf">shuffle_csr_column_indices</span><span class="p">(</span><span class="n">csr</span><span class="p">):</span>
<span class="sd">"""Shuffle CSR column indices per row</span>
<span class="sd"> This allows validation of unordered column indices, which is not a requirement</span>
<span class="sd"> for a valid CSR matrix</span>
<span class="sd"> """</span>
<span class="n">row_count</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">row_count</span><span class="p">):</span>
<span class="n">start_index</span> <span class="o">=</span> <span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">end_index</span> <span class="o">=</span> <span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">sublist</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">csr</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">start_index</span> <span class="p">:</span> <span class="n">end_index</span><span class="p">])</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">sublist</span><span class="p">)</span>
<span class="n">csr</span><span class="o">.</span><span class="n">indices</span><span class="p">[</span><span class="n">start_index</span> <span class="p">:</span> <span class="n">end_index</span><span class="p">]</span> <span class="o">=</span> <span class="n">sublist</span></div>
<span class="k">def</span> <span class="nf">_get_uniform_dataset_csr</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span><!--
  Review notes for the Python listed below. This is a generated viewcode
  listing, so fixes belong in the upstream Python source.
  Purpose: build a (num_rows, num_cols) CSR array with roughly a density
  fraction of non-zeros.
  scipy path: scipy.sparse.rand(..., format="csr"). It optionally fills every
  stored value with data_init and optionally shuffles column indices via
  shuffle_csr_column_indices. It then wraps the result with
  mx.nd.sparse.csr_matrix using the requested shape and dtype.
  Fallback path (ImportError): asserts data_init is None and that
  shuffle_csr_indices is falsy. It then masks a dense uniform draw with
  (dns < density) and converts it with tostype('csr').
  NOTE(review): the docstring's row_index/col_index recipe describes neither
  path. Read literally it is also wrong: row-major placement would need
  row_index = random_number_generated // num_cols.
-->
<span class="sd">"""Returns CSRNDArray with uniform distribution</span>
<span class="sd"> This generates a csr matrix with totalnnz unique randomly chosen numbers</span>
<span class="sd"> from num_rows*num_cols and arranges them in the 2d array in the</span>
<span class="sd"> following way:</span>
<span class="sd"> row_index = (random_number_generated / num_rows)</span>
<span class="sd"> col_index = random_number_generated - row_index * num_cols</span>
<span class="sd"> """</span>
<span class="n">_validate_csr_generation_inputs</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="p">,</span>
<span class="n">distribution</span><span class="o">=</span><span class="s2">"uniform"</span><span class="p">)</span><!-- presumably validates shape/density for this distribution; it is defined outside this chunk, so confirm there -->
<span class="k">try</span><span class="p">:</span><!-- NOTE(review): this try also covers the csr_matrix construction below, so an ImportError raised there would silently switch to the fallback -->
<span class="kn">from</span> <span class="nn">scipy</span> <span class="k">import</span> <span class="n">sparse</span> <span class="k">as</span> <span class="n">spsp</span>
<span class="n">csr</span> <span class="o">=</span> <span class="n">spsp</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">format</span><span class="o">=</span><span class="s2">"csr"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">data_init</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">csr</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">data_init</span><span class="p">)</span>
<span class="k">if</span> <span class="n">shuffle_csr_indices</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span><!-- NOTE(review): only the literal True enables shuffling here, while the fallback rejects any truthy value. For example, shuffle_csr_indices=1 is silently ignored when scipy is present -->
<span class="n">shuffle_csr_column_indices</span><span class="p">(</span><span class="n">csr</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">sparse</span><span class="o">.</span><span class="n">csr_matrix</span><span class="p">((</span><span class="n">csr</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">),</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="k">assert</span><span class="p">(</span><span class="n">data_init</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">),</span> \
<span class="s2">"data_init option is not supported when scipy is absent"</span>
<span class="k">assert</span><span class="p">(</span><span class="ow">not</span> <span class="n">shuffle_csr_indices</span><span class="p">),</span> \
<span class="s2">"shuffle_csr_indices option is not supported when scipy is absent"</span>
<span class="c1"># scipy not available. try to generate one from a dense array</span>
<span class="n">dns</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">masked_dns</span> <span class="o">=</span> <span class="n">dns</span> <span class="o">*</span> <span class="p">(</span><span class="n">dns</span> <span class="o"><</span> <span class="n">density</span><span class="p">)</span><!-- only draws below density survive, so every stored value is below density. The non-zero fraction matches density only in expectation (assuming the default uniform range is [0, 1); confirm) -->
<span class="n">result</span> <span class="o">=</span> <span class="n">masked_dns</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s1">'csr'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span>
<span class="k">def</span> <span class="nf">_get_powerlaw_dataset_csr</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span><!--
  Review notes for the Python listed below. This is a generated viewcode
  listing, so fixes belong in the upstream Python source.
  Budget: total_nnz = int(num_rows * num_cols * density) non-zeros. They are
  written into a dense numpy array that is converted to CSR on return.
  Phase 1 writes column 0 of every row (1 + uniform(0.001, 2)) so rows are
  non-empty. It returns at once if the budget is exhausted.
  Phase 2: row r (0-based) fills columns 1 .. min(num_cols, 2**(r+1)) - 1,
  giving up to min(num_cols, 2**(r+1)) non-zeros in that row. When the cap
  equals num_cols and more than num_cols non-zeros remain in the budget, the
  whole row is assigned one random scalar instead of per-element values.
  The function returns as soon as the budget is spent and raises ValueError
  if it runs out of rows first.
  NOTE(review): the docstring's "rows with no zeros" should read "no empty
  rows", since each row gets at least one non-zero.
  NOTE(review): the total_nnz < 2*num_rows restriction is not checked in this
  body. Presumably _validate_csr_generation_inputs enforces it; confirm.
  NOTE(review): mx.nd.array(output_arr) is called without dtype=. If it
  defaults numpy input to float32 (believed true in MXNet 1.x), any other
  requested dtype is silently dropped. Confirm and pass dtype=dtype upstream.
-->
<span class="sd">"""Returns CSRNDArray with powerlaw distribution</span>
<span class="sd"> with exponentially increasing number of non zeros in each row.</span>
<span class="sd"> Not supported for cases where total_nnz < 2*num_rows. This is because</span>
<span class="sd"> the algorithm first tries to ensure that there are rows with no zeros by</span>
<span class="sd"> putting non zeros at beginning of each row.</span>
<span class="sd"> """</span>
<span class="n">_validate_csr_generation_inputs</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">density</span><span class="p">,</span>
<span class="n">distribution</span><span class="o">=</span><span class="s2">"powerlaw"</span><span class="p">)</span>
<span class="n">total_nnz</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_rows</span> <span class="o">*</span> <span class="n">num_cols</span> <span class="o">*</span> <span class="n">density</span><span class="p">)</span>
<span class="n">unused_nnz</span> <span class="o">=</span> <span class="n">total_nnz</span>
<span class="n">output_arr</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># Start with ones on each row so that no row is empty</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">):</span>
<span class="n">output_arr</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">rnd</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.001</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">unused_nnz</span> <span class="o">=</span> <span class="n">unused_nnz</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">unused_nnz</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">output_arr</span><span class="p">)</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s2">"csr"</span><span class="p">)</span>
<span class="c1"># Populate rest of matrix with 2^i items in ith row.</span>
<span class="c1"># if we have used all total nnz return the sparse matrix</span>
<span class="c1"># else if we reached max column size then fill up full columns until we use all nnz</span>
<span class="n">col_max</span> <span class="o">=</span> <span class="mi">2</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">):</span>
<span class="n">col_limit</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_cols</span><span class="p">,</span> <span class="n">col_max</span><span class="p">)</span>
<span class="c1"># In case col_limit reached assign same value to all elements, which is much faster</span>
<span class="k">if</span> <span class="n">col_limit</span> <span class="o">==</span> <span class="n">num_cols</span> <span class="ow">and</span> <span class="n">unused_nnz</span> <span class="o">></span> <span class="n">col_limit</span><span class="p">:</span>
<span class="n">output_arr</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">rnd</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.001</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">unused_nnz</span> <span class="o">=</span> <span class="n">unused_nnz</span> <span class="o">-</span> <span class="n">col_limit</span> <span class="o">+</span> <span class="mi">1</span><!-- the + 1 is there because column 0 of this row was already counted in phase 1 -->
<span class="k">if</span> <span class="n">unused_nnz</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">output_arr</span><span class="p">)</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s2">"csr"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">continue</span>
<span class="k">for</span> <span class="n">col_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">col_limit</span><span class="p">):</span>
<span class="n">output_arr</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">col_index</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">rnd</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.001</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">unused_nnz</span> <span class="o">=</span> <span class="n">unused_nnz</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">unused_nnz</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">output_arr</span><span class="p">)</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s2">"csr"</span><span class="p">)</span>
<span class="n">col_max</span> <span class="o">=</span> <span class="n">col_max</span> <span class="o">*</span> <span class="mi">2</span>
<span class="k">if</span> <span class="n">unused_nnz</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># pylint: disable=no-else-raise</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"not supported for this density: </span><span class="si">%s</span><span class="s2">"</span>
<span class="s2">" for this shape (</span><span class="si">%s</span><span class="s2">,</span><span class="si">%s</span><span class="s2">)"</span> <span class="o">%</span> <span class="p">(</span><span class="n">density</span><span class="p">,</span> <span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">output_arr</span><span class="p">)</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s2">"csr"</span><span class="p">)</span>
<div class="viewcode-block" id="assign_each"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assign_each">[docs]</a><span class="k">def</span> <span class="nf">assign_each</span><span class="p">(</span><span class="n">the_input</span><span class="p">,</span> <span class="n">function</span><span class="p">):</span><!--
  Review notes for the Python listed below. This is a generated viewcode
  listing, so fixes belong in the upstream Python source.
  If function is None: returns np.array(the_input), a copy of the input.
  Otherwise: allocates np.zeros(the_input.shape), which is float64 by numpy's
  default whatever the input dtype. It then fills each slot with
  function(value), walking input and output with two independent np.nditer
  objects. On this path the_input must expose .shape, so a plain list fails.
  NOTE(review): both nditers use the default order='K' (memory order). The
  output is C-contiguous, so for a non-C-contiguous input (a transposed or
  Fortran-ordered array) results are written to the wrong positions. A
  single nditer over both operands, or np.vectorize, would keep them aligned.
  The 'f_index' flag is requested but the index is never read.
-->
<span class="sd">"""Return ndarray composed of passing each array value through some function"""</span>
<span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">the_input</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">it_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nditer</span><span class="p">(</span><span class="n">the_input</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'f_index'</span><span class="p">])</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">the_input</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">it_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nditer</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'f_index'</span><span class="p">],</span> <span class="n">op_flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'writeonly'</span><span class="p">])</span><!-- NOTE(review): this iterator advances in lockstep with it_input, but each follows its own array's memory order -->
<span class="k">while</span> <span class="ow">not</span> <span class="n">it_input</span><span class="o">.</span><span class="n">finished</span><span class="p">:</span>
<span class="n">val_input</span> <span class="o">=</span> <span class="n">it_input</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">it_out</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">function</span><span class="p">(</span><span class="n">val_input</span><span class="p">)</span>
<span class="n">it_input</span><span class="o">.</span><span class="n">iternext</span><span class="p">()</span>
<span class="n">it_out</span><span class="o">.</span><span class="n">iternext</span><span class="p">()</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="assign_each2"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assign_each2">[docs]</a><span class="k">def</span> <span class="nf">assign_each2</span><span class="p">(</span><span class="n">input1</span><span class="p">,</span> <span class="n">input2</span><span class="p">,</span> <span class="n">function</span><span class="p">):</span><!--
  Review notes for the Python listed below. This is a generated viewcode
  listing, so fixes belong in the upstream Python source.
  If function is None: returns np.array(input1). input2 is ignored and no
  shape check is made.
  Otherwise: asserts input1.shape == input2.shape. It then writes
  function(a, b) for each element pair into np.zeros(input1.shape), which is
  float64 by numpy's default, using three independent np.nditer objects.
  NOTE(review): each nditer follows its own operand's memory order (default
  order='K'). If input1 and input2 differ in layout (e.g. one is a
  transposed view), mismatched elements are paired, and results can land in
  the wrong output slots. One nditer over all three operands, or a
  broadcasting ufunc / np.vectorize, would keep them aligned.
-->
<span class="sd">"""Return ndarray composed of passing two array values through some function"""</span>
<span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">input1</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">input2</span><span class="o">.</span><span class="n">shape</span>
<span class="n">it_input1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nditer</span><span class="p">(</span><span class="n">input1</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'f_index'</span><span class="p">])</span>
<span class="n">it_input2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nditer</span><span class="p">(</span><span class="n">input2</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'f_index'</span><span class="p">])</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">input1</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">it_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">nditer</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'f_index'</span><span class="p">],</span> <span class="n">op_flags</span><span class="o">=</span><span class="p">[</span><span class="s1">'writeonly'</span><span class="p">])</span><!-- NOTE(review): the three iterators advance together but are not coupled; the 'f_index' flags are never read -->
<span class="k">while</span> <span class="ow">not</span> <span class="n">it_input1</span><span class="o">.</span><span class="n">finished</span><span class="p">:</span>
<span class="n">val_input1</span> <span class="o">=</span> <span class="n">it_input1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">val_input2</span> <span class="o">=</span> <span class="n">it_input2</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">it_out</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">function</span><span class="p">(</span><span class="n">val_input1</span><span class="p">,</span> <span class="n">val_input2</span><span class="p">)</span>
<span class="n">it_input1</span><span class="o">.</span><span class="n">iternext</span><span class="p">()</span>
<span class="n">it_input2</span><span class="o">.</span><span class="n">iternext</span><span class="p">()</span>
<span class="n">it_out</span><span class="o">.</span><span class="n">iternext</span><span class="p">()</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="rand_sparse_ndarray"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.rand_sparse_ndarray">[docs]</a><span class="k">def</span> <span class="nf">rand_sparse_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">distribution</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">rsp_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">modifier_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span><!--
  Review notes for the Python listed below. This is a generated viewcode
  listing, so fixes belong in the upstream Python source.
  Defaults: ctx falls back to default_context() when falsy. density falls
  back to rnd.rand(), dtype to default_dtype() and distribution to "uniform",
  each when None.
  row_sparse (uniform only): rows are rsp_indices if given. Otherwise each
  row is kept when its rnd.rand() draw is below density. Values are
  rnd.rand(...) cast to dtype, optionally filled with data_init, then passed
  through modifier_func via assign_each. After modifier_func the returned
  values array is float64, because assign_each allocates with np.zeros, while
  the NDArray itself uses dtype. The return is
  (RowSparseNDArray, (values, indices)). With no rows selected it is an
  all-zero row_sparse array plus two empty numpy arrays.
  csr (2-D shape only): uniform forwards data_init and shuffle_csr_indices.
  powerlaw forwards only density and dtype. The return is
  (CSRNDArray, (indptr, indices, data)); these are the array's own
  attributes, not numpy arrays as the docstring's first line implies.
  NOTE(review): modifier_func and rsp_indices are ignored for csr, and
  data_init / shuffle_csr_indices are ignored for powerlaw.
  NOTE(review): the unsupported distribution / stype branches use
  assert(False) followed by return False. Under python -O the asserts are
  stripped and callers silently get False; raising ValueError would be safer.
  NOTE(review): rsp_indices is checked with len() but later read via
  .shape[0], so a plain list fails with AttributeError. Apply np.asarray first.
  NOTE(review): the docstring omits data_init, rsp_indices, modifier_func,
  shuffle_csr_indices and ctx.
-->
<span class="sd">"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> shape: list or tuple</span>
<span class="sd"> stype: str</span>
<span class="sd"> valid values: "csr" or "row_sparse"</span>
<span class="sd"> density: float, optional</span>
<span class="sd"> should be between 0 and 1</span>
<span class="sd"> distribution: str, optional</span>
<span class="sd"> valid values: "uniform" or "powerlaw"</span>
<span class="sd"> dtype: numpy.dtype, optional</span>
<span class="sd"> default value is None</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> Result of type CSRNDArray or RowSparseNDArray</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> Below is an example of the powerlaw distribution with csr as the stype.</span>
<span class="sd"> It calculates the nnz using the shape and density.</span>
<span class="sd"> It fills up the ndarray with exponentially increasing number of elements.</span>
<span class="sd"> If there are enough unused_nnzs, n+1th row will have twice more nnzs compared to nth row.</span>
<span class="sd"> else, remaining unused_nnzs will be used in n+1th row</span>
<span class="sd"> If number of cols is too small and we have already reached column size it will fill up</span>
<span class="sd"> all following columns in all followings rows until we reach the required density.</span>
<span class="sd"> >>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",</span>
<span class="sd"> density=0.50, distribution="powerlaw")</span>
<span class="sd"> >>> indptr = csr_arr.indptr.asnumpy()</span>
<span class="sd"> >>> indices = csr_arr.indices.asnumpy()</span>
<span class="sd"> >>> data = csr_arr.data.asnumpy()</span>
<span class="sd"> >>> row2nnz = len(data[indptr[1]:indptr[2]])</span>
<span class="sd"> >>> row3nnz = len(data[indptr[2]:indptr[3]])</span>
<span class="sd"> >>> assert(row3nnz == 2*row2nnz)</span>
<span class="sd"> >>> row4nnz = len(data[indptr[3]:indptr[4]])</span>
<span class="sd"> >>> assert(row4nnz == 2*row3nnz)</span>
<span class="sd"> """</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">ctx</span> <span class="k">if</span> <span class="n">ctx</span> <span class="k">else</span> <span class="n">default_context</span><span class="p">()</span>
<span class="n">density</span> <span class="o">=</span> <span class="n">rnd</span><span class="o">.</span><span class="n">rand</span><span class="p">()</span> <span class="k">if</span> <span class="n">density</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">density</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">default_dtype</span><span class="p">()</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">dtype</span>
<span class="n">distribution</span> <span class="o">=</span> <span class="s2">"uniform"</span> <span class="k">if</span> <span class="n">distribution</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">distribution</span>
<span class="k">if</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'row_sparse'</span><span class="p">:</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">distribution</span> <span class="o">==</span> <span class="s2">"uniform"</span><span class="p">),</span> \
<span class="s2">"Distribution </span><span class="si">%s</span><span class="s2"> not supported for row_sparse"</span> <span class="o">%</span> <span class="p">(</span><span class="n">distribution</span><span class="p">)</span>
<span class="c1"># sample index</span>
<span class="k">if</span> <span class="n">rsp_indices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">rsp_indices</span>
<span class="k">assert</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indices</span><span class="p">)</span> <span class="o"><=</span> <span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">idx_sample</span> <span class="o">=</span> <span class="n">rnd</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argwhere</span><span class="p">(</span><span class="n">idx_sample</span> <span class="o"><</span> <span class="n">density</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">if</span> <span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="o">=</span><span class="s1">'row_sparse'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span><span class="p">,</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([]))</span>
<span class="c1"># generate random values</span>
<span class="n">val</span> <span class="o">=</span> <span class="n">rnd</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># Allow caller to override or adjust random values</span>
<span class="k">if</span> <span class="n">data_init</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">val</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">data_init</span><span class="p">)</span>
<span class="k">if</span> <span class="n">modifier_func</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">val</span> <span class="o">=</span> <span class="n">assign_each</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">modifier_func</span><span class="p">)</span>
<span class="n">arr</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">sparse</span><span class="o">.</span><span class="n">row_sparse_array</span><span class="p">((</span><span class="n">val</span><span class="p">,</span> <span class="n">indices</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">arr</span><span class="p">,</span> <span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'csr'</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
<span class="k">if</span> <span class="n">distribution</span> <span class="o">==</span> <span class="s2">"uniform"</span><span class="p">:</span>
<span class="n">csr</span> <span class="o">=</span> <span class="n">_get_uniform_dataset_csr</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">density</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="n">data_init</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="n">shuffle_csr_indices</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">csr</span><span class="p">,</span> <span class="p">(</span><span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">distribution</span> <span class="o">==</span> <span class="s2">"powerlaw"</span><span class="p">:</span>
<span class="n">csr</span> <span class="o">=</span> <span class="n">_get_powerlaw_dataset_csr</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">density</span><span class="o">=</span><span class="n">density</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">csr</span><span class="p">,</span> <span class="p">(</span><span class="n">csr</span><span class="o">.</span><span class="n">indptr</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">indices</span><span class="p">,</span> <span class="n">csr</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span><span class="p">(</span><span class="kc">False</span><span class="p">),</span> <span class="s2">"Distribution not supported: </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">distribution</span><span class="p">)</span><!-- NOTE(review): stripped under python -O, which leaves the return False below; raise ValueError instead -->
<span class="k">return</span> <span class="kc">False</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span><span class="p">(</span><span class="kc">False</span><span class="p">),</span> <span class="s2">"unknown storage type"</span><!-- NOTE(review): same python -O hazard as above; an unknown stype would return False -->
<span class="k">return</span> <span class="kc">False</span></div>
<div class="viewcode-block" id="rand_ndarray"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.rand_ndarray">[docs]</a><span class="k">def</span> <span class="nf">rand_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="o">=</span><span class="s1">'default'</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">modifier_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">distribution</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Generate a random sparse ndarray. Returns the generated ndarray."""</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">ctx</span> <span class="k">if</span> <span class="n">ctx</span> <span class="k">else</span> <span class="n">default_context</span><span class="p">()</span>
<span class="k">if</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'default'</span><span class="p">:</span>
<span class="n">arr</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">random_arrays</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">arr</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">rand_sparse_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="n">density</span><span class="p">,</span>
<span class="n">modifier_func</span><span class="o">=</span><span class="n">modifier_func</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="n">shuffle_csr_indices</span><span class="p">,</span>
<span class="n">distribution</span><span class="o">=</span><span class="n">distribution</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">arr</span></div>
<div class="viewcode-block" id="create_sparse_array"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.create_sparse_array">[docs]</a><span class="k">def</span> <span class="nf">create_sparse_array</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">data_init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">rsp_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">modifier_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">density</span><span class="o">=.</span><span class="mi">5</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Create a sparse array, For Rsp, assure indices are in a canonical format"""</span>
<span class="k">if</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'row_sparse'</span><span class="p">:</span>
<span class="k">if</span> <span class="n">rsp_indices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">arr_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">rsp_indices</span><span class="p">)</span>
<span class="n">arr_indices</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">arr_indices</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">arr_data</span><span class="p">,</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="o">=</span> <span class="n">rand_sparse_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span>
<span class="n">density</span><span class="o">=</span><span class="n">density</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="n">data_init</span><span class="p">,</span>
<span class="n">rsp_indices</span><span class="o">=</span><span class="n">arr_indices</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">modifier_func</span><span class="o">=</span><span class="n">modifier_func</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'csr'</span><span class="p">:</span>
<span class="n">arr_data</span><span class="p">,</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="o">=</span> <span class="n">rand_sparse_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span>
<span class="n">stype</span><span class="p">,</span>
<span class="n">density</span><span class="o">=</span><span class="n">density</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="n">data_init</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">modifier_func</span><span class="o">=</span><span class="n">modifier_func</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="n">shuffle_csr_indices</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">msg</span> <span class="o">=</span> <span class="s2">"Unknown storage type: "</span> <span class="o">+</span> <span class="n">stype</span>
<span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
<span class="k">return</span> <span class="n">arr_data</span></div>
<div class="viewcode-block" id="create_sparse_array_zd"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.create_sparse_array_zd">[docs]</a><span class="k">def</span> <span class="nf">create_sparse_array_zd</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">density</span><span class="p">,</span> <span class="n">data_init</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">rsp_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">modifier_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Create sparse array, using only rsp_indices to determine density"""</span>
<span class="k">if</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'row_sparse'</span><span class="p">:</span>
<span class="n">density</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">if</span> <span class="n">rsp_indices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">rsp_indices</span><span class="p">)</span> <span class="o"><=</span> <span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">create_sparse_array</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span>
<span class="n">data_init</span><span class="o">=</span><span class="n">data_init</span><span class="p">,</span>
<span class="n">rsp_indices</span><span class="o">=</span><span class="n">rsp_indices</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">modifier_func</span><span class="o">=</span><span class="n">modifier_func</span><span class="p">,</span>
<span class="n">density</span><span class="o">=</span><span class="n">density</span><span class="p">,</span>
<span class="n">shuffle_csr_indices</span><span class="o">=</span><span class="n">shuffle_csr_indices</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">rand_shape_2d</span><span class="p">(</span><span class="n">dim0</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">dim1</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="k">return</span> <span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim0</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">rand_shape_3d</span><span class="p">(</span><span class="n">dim0</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">dim1</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">dim2</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="k">return</span> <span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim0</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim1</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">rand_shape_nd</span><span class="p">(</span><span class="n">num_dim</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">rnd</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">num_dim</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">rand_coord_2d</span><span class="p">(</span><span class="n">x_low</span><span class="p">,</span> <span class="n">x_high</span><span class="p">,</span> <span class="n">y_low</span><span class="p">,</span> <span class="n">y_high</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">x_low</span><span class="p">,</span> <span class="n">x_high</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">y_low</span><span class="p">,</span> <span class="n">y_high</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span>
<div class="viewcode-block" id="np_reduce"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.np_reduce">[docs]</a><span class="k">def</span> <span class="nf">np_reduce</span><span class="p">(</span><span class="n">dat</span><span class="p">,</span> <span class="n">axis</span><span class="p">,</span> <span class="n">keepdims</span><span class="p">,</span> <span class="n">numpy_reduce_func</span><span class="p">):</span>
<span class="sd">"""Compatible reduce for old version of NumPy.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dat : np.ndarray</span>
<span class="sd"> Same as NumPy.</span>
<span class="sd"> axis : None or int or list-like</span>
<span class="sd"> Same as NumPy.</span>
<span class="sd"> keepdims : bool</span>
<span class="sd"> Same as NumPy.</span>
<span class="sd"> numpy_reduce_func : function</span>
<span class="sd"> A NumPy reducing function like ``np.sum`` or ``np.max``.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">axis</span> <span class="o">=</span> <span class="p">[</span><span class="n">axis</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">axis</span><span class="p">)</span> <span class="k">if</span> <span class="n">axis</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">dat</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">sorted</span><span class="p">(</span><span class="n">axis</span><span class="p">)):</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">numpy_reduce_func</span><span class="p">(</span><span class="n">ret</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">keepdims</span><span class="p">:</span>
<span class="n">keepdims_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">axis</span><span class="p">:</span>
<span class="n">keepdims_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">keepdims_shape</span><span class="p">))</span>
<span class="k">return</span> <span class="n">ret</span></div>
<div class="viewcode-block" id="find_max_violation"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.find_max_violation">[docs]</a><span class="k">def</span> <span class="nf">find_max_violation</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Finds and returns the location of maximum violation."""</span>
<span class="n">rtol</span> <span class="o">=</span> <span class="n">get_rtol</span><span class="p">(</span><span class="n">rtol</span><span class="p">)</span>
<span class="n">atol</span> <span class="o">=</span> <span class="n">get_atol</span><span class="p">(</span><span class="n">atol</span><span class="p">)</span>
<span class="n">diff</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">a</span><span class="o">-</span><span class="n">b</span><span class="p">)</span>
<span class="n">tol</span> <span class="o">=</span> <span class="n">atol</span> <span class="o">+</span> <span class="n">rtol</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">violation</span> <span class="o">=</span> <span class="n">diff</span><span class="o">/</span><span class="p">(</span><span class="n">tol</span><span class="o">+</span><span class="mf">1e-20</span><span class="p">)</span>
<span class="n">loc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">violation</span><span class="p">)</span>
<span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">unravel_index</span><span class="p">(</span><span class="n">loc</span><span class="p">,</span> <span class="n">violation</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">idx</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">violation</span><span class="p">)</span></div>
<div class="viewcode-block" id="same"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.same">[docs]</a><span class="k">def</span> <span class="nf">same</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="sd">"""Test if two NumPy arrays are the same.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> a : np.ndarray</span>
<span class="sd"> b : np.ndarray</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span></div>
<div class="viewcode-block" id="almost_equal"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.almost_equal">[docs]</a><span class="k">def</span> <span class="nf">almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Test if two numpy arrays are almost equal."""</span>
<span class="c1"># pylint: disable=unexpected-keyword-arg</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="n">get_rtol</span><span class="p">(</span><span class="n">rtol</span><span class="p">),</span> <span class="n">atol</span><span class="o">=</span><span class="n">get_atol</span><span class="p">(</span><span class="n">atol</span><span class="p">),</span> <span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span></div>
<span class="c1"># pylint: enable=unexpected-keyword-arg</span>
<div class="viewcode-block" id="assert_almost_equal"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assert_almost_equal">[docs]</a><span class="k">def</span> <span class="nf">assert_almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">names</span><span class="o">=</span><span class="p">(</span><span class="s1">'a'</span><span class="p">,</span> <span class="s1">'b'</span><span class="p">),</span> <span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Test that two numpy arrays are almost equal. Raise exception message if not.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> a : np.ndarray</span>
<span class="sd"> b : np.ndarray</span>
<span class="sd"> threshold : None or float</span>
<span class="sd"> The checking threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> """</span>
<span class="n">rtol</span> <span class="o">=</span> <span class="n">get_rtol</span><span class="p">(</span><span class="n">rtol</span><span class="p">)</span>
<span class="n">atol</span> <span class="o">=</span> <span class="n">get_atol</span><span class="p">(</span><span class="n">atol</span><span class="p">)</span>
<span class="k">if</span> <span class="n">almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">):</span>
<span class="k">return</span>
<span class="n">index</span><span class="p">,</span> <span class="n">rel</span> <span class="o">=</span> <span class="n">find_max_violation</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">suppress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">msg</span> <span class="o">=</span> <span class="n">npt</span><span class="o">.</span><span class="n">build_err_msg</span><span class="p">([</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">],</span>
<span class="n">err_msg</span><span class="o">=</span><span class="s2">"Error </span><span class="si">%f</span><span class="s2"> exceeds tolerance rtol=</span><span class="si">%f</span><span class="s2">, atol=</span><span class="si">%f</span><span class="s2">. "</span>
<span class="s2">" Location of maximum error:</span><span class="si">%s</span><span class="s2">, a=</span><span class="si">%f</span><span class="s2">, b=</span><span class="si">%f</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="n">rel</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">index</span><span class="p">),</span> <span class="n">a</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">index</span><span class="p">]),</span>
<span class="n">names</span><span class="o">=</span><span class="n">names</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span></div>
<div class="viewcode-block" id="assert_almost_equal_with_err"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assert_almost_equal_with_err">[docs]</a><span class="k">def</span> <span class="nf">assert_almost_equal_with_err</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">etol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">names</span><span class="o">=</span><span class="p">(</span><span class="s1">'a'</span><span class="p">,</span> <span class="s1">'b'</span><span class="p">),</span> <span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="sd">"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> a : np.ndarray</span>
<span class="sd"> b : np.ndarray</span>
<span class="sd"> threshold : None or float</span>
<span class="sd"> The checking threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> etol : None or float</span>
<span class="sd"> The error rate threshold. If etol is float, return true if error_rate < etol even if</span>
<span class="sd"> any error is found.</span>
<span class="sd"> """</span>
<span class="n">rtol</span> <span class="o">=</span> <span class="n">get_rtol</span><span class="p">(</span><span class="n">rtol</span><span class="p">)</span>
<span class="n">atol</span> <span class="o">=</span> <span class="n">get_atol</span><span class="p">(</span><span class="n">atol</span><span class="p">)</span>
<span class="n">etol</span> <span class="o">=</span> <span class="n">get_etol</span><span class="p">(</span><span class="n">etol</span><span class="p">)</span>
<span class="k">if</span> <span class="n">etol</span><span class="p">:</span>
<span class="n">equals</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">isclose</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="n">atol</span><span class="p">)</span>
<span class="n">err</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">count_nonzero</span><span class="p">(</span><span class="n">equals</span><span class="p">)</span> <span class="o">/</span> <span class="n">equals</span><span class="o">.</span><span class="n">size</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">></span> <span class="n">etol</span><span class="p">:</span>
<span class="c1">#if True:</span>
<span class="n">index</span><span class="p">,</span> <span class="n">rel</span> <span class="o">=</span> <span class="n">find_max_violation</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">suppress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">msg</span> <span class="o">=</span> <span class="n">npt</span><span class="o">.</span><span class="n">build_err_msg</span><span class="p">([</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">],</span>
<span class="n">err_msg</span><span class="o">=</span><span class="s2">"Error </span><span class="si">%f</span><span class="s2"> exceeds tolerance rtol=</span><span class="si">%f</span><span class="s2">, atol=</span><span class="si">%f</span><span class="s2">, etol=</span><span class="si">%f</span><span class="s2">."</span>
<span class="s2">" Error_rate=</span><span class="si">%f</span><span class="s2">. Location of maximum error:</span><span class="si">%s</span><span class="s2">, a=</span><span class="si">%f</span><span class="s2">, b=</span><span class="si">%f</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="n">rel</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="n">etol</span><span class="p">,</span> <span class="n">err</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">index</span><span class="p">),</span> <span class="n">a</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">index</span><span class="p">]),</span>
<span class="n">names</span><span class="o">=</span><span class="n">names</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
<span class="k">if</span> <span class="n">almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">):</span>
<span class="k">return</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">):</span>
<span class="k">return</span>
<span class="n">index</span><span class="p">,</span> <span class="n">rel</span> <span class="o">=</span> <span class="n">find_max_violation</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">suppress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">msg</span> <span class="o">=</span> <span class="n">npt</span><span class="o">.</span><span class="n">build_err_msg</span><span class="p">([</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">],</span>
<span class="n">err_msg</span><span class="o">=</span><span class="s2">"Error </span><span class="si">%f</span><span class="s2"> exceeds tolerance rtol=</span><span class="si">%f</span><span class="s2">, atol=</span><span class="si">%f</span><span class="s2">. "</span>
<span class="s2">" Location of maximum error:</span><span class="si">%s</span><span class="s2">, a=</span><span class="si">%f</span><span class="s2">, b=</span><span class="si">%f</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="n">rel</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">index</span><span class="p">),</span> <span class="n">a</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="n">b</span><span class="p">[</span><span class="n">index</span><span class="p">]),</span>
<span class="n">names</span><span class="o">=</span><span class="n">names</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span></div>
<div class="viewcode-block" id="almost_equal_ignore_nan"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.almost_equal_ignore_nan">[docs]</a><span class="k">def</span> <span class="nf">almost_equal_ignore_nan</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Test that two NumPy arrays are almost equal (ignoring NaN in either array).</span>
<span class="sd"> Combines a relative and absolute measure of approximate eqality.</span>
<span class="sd"> If either the relative or absolute check passes, the arrays are considered equal.</span>
<span class="sd"> Including an absolute check resolves issues with the relative check where all</span>
<span class="sd"> array values are close to zero.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> a : np.ndarray</span>
<span class="sd"> b : np.ndarray</span>
<span class="sd"> rtol : None or float</span>
<span class="sd"> The relative threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> atol : None or float</span>
<span class="sd"> The absolute threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> """</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">nan_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">a</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
<span class="n">a</span><span class="p">[</span><span class="n">nan_mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">b</span><span class="p">[</span><span class="n">nan_mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">return</span> <span class="n">almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">)</span></div>
<div class="viewcode-block" id="assert_almost_equal_ignore_nan"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assert_almost_equal_ignore_nan">[docs]</a><span class="k">def</span> <span class="nf">assert_almost_equal_ignore_nan</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">names</span><span class="o">=</span><span class="p">(</span><span class="s1">'a'</span><span class="p">,</span> <span class="s1">'b'</span><span class="p">)):</span>
<span class="sd">"""Test that two NumPy arrays are almost equal (ignoring NaN in either array).</span>
<span class="sd"> Combines a relative and absolute measure of approximate eqality.</span>
<span class="sd"> If either the relative or absolute check passes, the arrays are considered equal.</span>
<span class="sd"> Including an absolute check resolves issues with the relative check where all</span>
<span class="sd"> array values are close to zero.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> a : np.ndarray</span>
<span class="sd"> b : np.ndarray</span>
<span class="sd"> rtol : None or float</span>
<span class="sd"> The relative threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> atol : None or float</span>
<span class="sd"> The absolute threshold. Default threshold will be used if set to ``None``.</span>
<span class="sd"> """</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="n">nan_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">a</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
<span class="n">a</span><span class="p">[</span><span class="n">nan_mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">b</span><span class="p">[</span><span class="n">nan_mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="n">names</span><span class="p">)</span></div>
<div class="viewcode-block" id="assert_exception"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.assert_exception">[docs]</a><span class="k">def</span> <span class="nf">assert_exception</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">exception_type</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Test that function f will throw an exception of type given by `exception_type`"""</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">f</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">assert</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
<span class="k">except</span> <span class="n">exception_type</span><span class="p">:</span>
<span class="k">return</span></div>
<div class="viewcode-block" id="retry"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.retry">[docs]</a><span class="k">def</span> <span class="nf">retry</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="sd">"""Retry n times before failing for stochastic test cases."""</span>
<span class="k">assert</span> <span class="n">n</span> <span class="o">></span> <span class="mi">0</span>
<span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">f</span><span class="p">):</span>
<span class="sd">"""Decorate a test case."""</span>
<span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Wrapper for tests function."""</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">f</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">AssertionError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">e</span>
<span class="k">continue</span>
<span class="k">return</span>
<span class="k">raise</span> <span class="n">err</span>
<span class="k">return</span> <span class="n">wrapper</span>
<span class="k">return</span> <span class="n">decorate</span></div>
<div class="viewcode-block" id="simple_forward"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.simple_forward">[docs]</a><span class="k">def</span> <span class="nf">simple_forward</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">inputs</span><span class="p">):</span>
<span class="sd">"""A simple forward function for a symbol.</span>
<span class="sd"> Primarily used in doctest to test the functionality of a symbol.</span>
<span class="sd"> Takes NumPy arrays as inputs and outputs are also converted to NumPy arrays.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> If ``None``, will take the default context.</span>
<span class="sd"> inputs : keyword arguments</span>
<span class="sd"> Mapping each input name to a NumPy array.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> The result as a numpy array. Multiple results will</span>
<span class="sd"> be returned as a list of NumPy arrays.</span>
<span class="sd"> """</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">ctx</span> <span class="ow">or</span> <span class="n">default_context</span><span class="p">()</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">exe</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="n">is_train</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">]</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">outputs</span></div>
<span class="k">def</span> <span class="nf">_parse_location</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Parses the given location to a dictionary.</span>
<span class="sd"> Arguments of the provided op `sym` are used as dictionary keys</span>
<span class="sd"> and elements of `location` are used as values.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> Symbol containing op</span>
<span class="sd"> location : list or tuple or dict</span>
<span class="sd"> Argument values location</span>
<span class="sd"> - if type is list or tuple of `np.ndarray`</span>
<span class="sd"> inner elements are arrays correspoding to</span>
<span class="sd"> ``sym.list_arguments()``.</span>
<span class="sd"> - if type is dict of str -> `np.ndarray`</span>
<span class="sd"> maps the name of arguments to the corresponding `np.ndarray`.</span>
<span class="sd"> *In either case, value of all the arguments must be provided.*</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> Device context.</span>
<span class="sd"> dtype: "asnumpy" or np.float16 or np.float32 or np.float64</span>
<span class="sd"> If dtype is "asnumpy" then the mx.nd.array created will have the same</span>
<span class="sd"> type as th numpy array from which it is copied.</span>
<span class="sd"> Otherwise, dtype is the explicit datatype for all mx.nd.array objects</span>
<span class="sd"> created in this function.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> dict</span>
<span class="sd"> Dictionary with `sym` arguments as keys and `location` elements as</span>
<span class="sd"> values.</span>
<span class="sd"> Examples</span>
<span class="sd"> -------</span>
<span class="sd"> >>> a = mx.symbol.Variable('a')</span>
<span class="sd"> >>> b = mx.symbol.Variable('b')</span>
<span class="sd"> >>> l1 = np.ndarray([2,3])</span>
<span class="sd"> >>> l2 = np.ndarray([3,4])</span>
<span class="sd"> >>> _parse_location(a * b, [l1, l2], None)</span>
<span class="sd"> {'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}</span>
<span class="sd"> >>> _parse_location(a * b, {'a': l1, 'b': l2}, None)</span>
<span class="sd"> {'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}</span>
<span class="sd"> >>> _parse_location(a * b, {'a': l1}, None)</span>
<span class="sd"> ValueError: Symbol arguments and keys of the given location do not match.</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">location</span><span class="p">,</span> <span class="p">(</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="ow">or</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">location</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">location</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Symbol arguments and keys of the given location do not match."</span>
<span class="s2">"symbol args:</span><span class="si">%s</span><span class="s2">, location.keys():</span><span class="si">%s</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">())),</span> <span class="nb">str</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">location</span><span class="o">.</span><span class="n">keys</span><span class="p">()))))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">location</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">(),</span> <span class="n">location</span><span class="p">)}</span>
<span class="n">location</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="k">else</span> <span class="n">dtype</span><span class="p">)</span> \
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">return</span> <span class="n">location</span>
<span class="k">def</span> <span class="nf">_parse_aux_states</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">aux_states</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Parses the given auxiliary states to a dictionary.</span>
<span class="sd"> Auxiliary states of the provided op `sym` are used as dictionary</span>
<span class="sd"> keys and elements of `aux_states` are used as values.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> Symbol containing op</span>
<span class="sd"> aux_states : None or list or dict</span>
<span class="sd"> Aux states</span>
<span class="sd"> - if type is list or tuple of `np.ndarray`</span>
<span class="sd"> inner elements are arrays correspoding to</span>
<span class="sd"> ``sym.list_auxiliary_states()``.</span>
<span class="sd"> - if type is dict of str -> `np.ndarray`</span>
<span class="sd"> maps the name of arguments to the corresponding `np.ndarray`.</span>
<span class="sd"> *In either case, all aux states of `sym` must be provided.*</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> Device context.</span>
<span class="sd"> dtype: "asnumpy" or np.float16 or np.float32 or np.float64</span>
<span class="sd"> If dtype is "asnumpy" then the mx.nd.array created will have the same</span>
<span class="sd"> type as th numpy array from which it is copied.</span>
<span class="sd"> Otherwise, dtype is the explicit datatype for all mx.nd.array objects</span>
<span class="sd"> created in this function.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> dict</span>
<span class="sd"> Dictionary with `sym` aux states as keys and `aux_states` elements</span>
<span class="sd"> as values.</span>
<span class="sd"> Examples</span>
<span class="sd"> -------</span>
<span class="sd"> >>> data = mx.symbol.Variable('data')</span>
<span class="sd"> >>> weight = mx.sym.Variable(name='fc1_weight')</span>
<span class="sd"> >>> fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128)</span>
<span class="sd"> >>> fc2 = mx.symbol.BatchNorm(fc1, name='batchnorm0')</span>
<span class="sd"> >>> mean_states = np.ones(3)</span>
<span class="sd"> >>> var_states = np.ones(3)</span>
<span class="sd"> >>> _parse_aux_states(fc2, [mean_states, var_states], None)</span>
<span class="sd"> {'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}</span>
<span class="sd"> >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states,</span>
<span class="sd"> ... 'batchnorm0_moving_mean': var_states}, None)</span>
<span class="sd"> {'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}</span>
<span class="sd"> >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)</span>
<span class="sd"> ValueError: Symbol aux_states names and given aux_states do not match.</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="ow">or</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="k">if</span> <span class="n">aux_states</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">aux_states</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">aux_states</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_auxiliary_states</span><span class="p">()):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Symbol aux_states names and given aux_states do not match."</span>
<span class="s2">"symbol aux_names:</span><span class="si">%s</span><span class="s2">, aux_states.keys:</span><span class="si">%s</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_auxiliary_states</span><span class="p">())),</span>
<span class="nb">str</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">aux_states</span><span class="o">.</span><span class="n">keys</span><span class="p">()))))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">aux_states</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">aux_names</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">list_auxiliary_states</span><span class="p">()</span>
<span class="n">aux_states</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">aux_names</span><span class="p">,</span> <span class="n">aux_states</span><span class="p">)}</span>
<span class="n">aux_states</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="k">else</span> <span class="n">dtype</span><span class="p">)</span> \
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">aux_states</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">return</span> <span class="n">aux_states</span>
<!-- Reviewer notes for the highlighted source of mxnet.test_utils.numeric_grad (docs version 1.5.0).
     This markup is Sphinx viewcode output generated from the mxnet.test_utils module
     (presumably python/mxnet/test_utils.py). Fixes to the code or docstring shown below
     belong in that Python file, and this page should then be regenerated, not hand-edited.

     Purpose: estimate, by central finite differences, the gradient of sum(executor.outputs[0])
     with respect to every floating-point array in `location`.

     What the listing shows:
       * `location` and `aux_states` are iterated with .items(), so in practice both must be
         dicts (aux_states may also be None). The docstring's "list of numpy.ndarray"
         alternative would fail here.
       * dtype is asserted to be np.float16, np.float32 or np.float64.
       * `location` is mutated in place: each value is replaced by a C-ordered numpy array.
       * executor.arg_dict (and aux_dict, when aux_states is given) is overwritten during the
         sweep. Each perturbed argument is copied back to its original value afterwards.
       * Each element is moved to +eps/2 and then to -eps/2, and the estimate is
         (sum(f_plus) - sum(f_minus)) / eps.
       * Non-floating arrays are skipped. Their entry in the result stays all zeros.
       * Returns a dict with the same keys as `location`, each a numpy array of `dtype`.

     Possible defects to fix upstream (not in this generated page):
       * range(np.prod(v.shape)): for a 0-d array np.prod(()) is the float 1.0, so range()
         raises TypeError. range(v.size) avoids this.
       * The +eps/2 pass assigns aux state values as-is. The -eps/2 pass casts them through
         as_stype with the aux array's storage type and dtype. The two should probably match.
       * Docstring: "Class based on" describes a function. The citation directive needs a
         space after the two dots (".. [1]"). The References underline is one character
         shorter than its heading.
--><div class="viewcode-block" id="numeric_grad"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.numeric_grad">[docs]</a><span class="k">def</span> <span class="nf">numeric_grad</span><span class="p">(</span><span class="n">executor</span><span class="p">,</span> <span class="n">location</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span>
<span class="n">use_forward_train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Calculates a numeric gradient via finite difference method.</span>
<span class="sd"> Class based on Theano's `theano.gradient.numeric_grad` [1]</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> executor : Executor</span>
<span class="sd"> Executor that computes the forward pass.</span>
<span class="sd"> location : list of numpy.ndarray or dict of str to numpy.ndarray</span>
<span class="sd"> Argument values used as location to compute gradient</span>
<span class="sd"> Maps the name of arguments to the corresponding numpy.ndarray.</span>
<span class="sd"> Value of all the arguments must be provided.</span>
<span class="sd"> aux_states : None or list of numpy.ndarray or dict of str to numpy.ndarray, optional</span>
<span class="sd"> Auxiliary states values used as location to compute gradient</span>
<span class="sd"> Maps the name of aux_states to the corresponding numpy.ndarray.</span>
<span class="sd"> Value of all the auxiliary arguments must be provided.</span>
<span class="sd"> eps : float, optional</span>
<span class="sd"> Epsilon for the finite-difference method.</span>
<span class="sd"> use_forward_train : bool, optional</span>
<span class="sd"> Whether to use `is_train=True` in testing.</span>
<span class="sd"> dtype: np.float16 or np.float32 or np.float64</span>
<span class="sd"> Datatype for mx.nd.array.</span>
<span class="sd"> References</span>
<span class="sd"> ---------</span>
<span class="sd"> ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">as_stype</span><span class="p">(</span><span class="n">var</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span><!-- helper: numpy value -> mx.nd array of `dtype`, cast to storage type `stype` -->
<span class="k">return</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">cast_storage</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">var</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">stype</span><span class="o">=</span><span class="n">stype</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="n">approx_grads</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span><!-- result: one zero array per argument; entries are filled element by element below -->
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">stype</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">stype</span>
<span class="k">if</span> <span class="n">stype</span> <span class="o">==</span> <span class="s1">'default'</span><span class="p">:</span><!-- only default-storage args are pre-loaded here; other storage types are written only inside the perturbation loop -->
<span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">as_stype</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">location</span><span class="p">:</span>
<span class="n">location</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">location</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">order</span><span class="o">=</span><span class="s1">'C'</span><span class="p">)</span><!-- C order makes v.ravel() below a view, so the in-place edits change v itself; note this rewrites the caller's dict -->
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">v</span><span class="o">.</span><span class="n">dtype</span><span class="o">.</span><span class="n">kind</span> <span class="o">!=</span> <span class="s1">'f'</span><span class="p">:</span><!-- non-floating arrays are skipped; their approx_grads entry stays zero -->
<span class="k">continue</span>
<span class="n">stype</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">stype</span>
<span class="n">old_value</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)):</span><!-- NOTE(review): for a 0-d array np.prod(()) is the float 1.0 and range() raises TypeError; range(v.size) would handle it. Fix upstream. -->
<span class="c1"># inplace update</span>
<span class="n">v</span><span class="o">.</span><span class="n">ravel</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">eps</span><span class="o">/</span><span class="mf">2.0</span>
<span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">as_stype</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">aux_states</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span><!-- aux states are re-assigned before every forward, presumably so updates made by a training-mode forward do not carry over between evaluations (confirm) -->
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">aux_states</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">executor</span><span class="o">.</span><span class="n">aux_dict</span><span class="p">[</span><span class="n">key</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">val</span><!-- NOTE(review): this pass copies val as-is, while the minus pass below casts it with as_stype(val, aux stype, dtype); presumably the two should match -->
<span class="n">executor</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="n">use_forward_train</span><span class="p">)</span>
<span class="n">f_peps</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><!-- only the first output is used -->
<span class="n">v</span><span class="o">.</span><span class="n">ravel</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-=</span> <span class="n">eps</span><!-- from +eps/2 to -eps/2 relative to the original value -->
<span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">as_stype</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">aux_states</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">aux_states</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">adstype</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">aux_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">stype</span>
<span class="n">executor</span><span class="o">.</span><span class="n">aux_dict</span><span class="p">[</span><span class="n">key</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">as_stype</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">adstype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">executor</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="n">use_forward_train</span><span class="p">)</span>
<span class="n">f_neps</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="n">approx_grad</span> <span class="o">=</span> <span class="p">(</span><span class="n">f_peps</span> <span class="o">-</span> <span class="n">f_neps</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="n">eps</span><!-- central difference of sum(outputs[0]); the two inputs differ by eps in element i -->
<span class="n">approx_grads</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">ravel</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">approx_grad</span>
<span class="n">v</span><span class="o">.</span><span class="n">ravel</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">old_value</span><span class="o">.</span><span class="n">ravel</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span><!-- restore element i in the numpy copy before moving to i + 1 -->
<span class="c1"># copy back the original value</span>
<span class="n">executor</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">k</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">as_stype</span><span class="p">(</span><span class="n">old_value</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">approx_grads</span></div>
<div class="viewcode-block" id="check_numeric_gradient"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.check_numeric_gradient">[docs]</a><span class="k">def</span> <span class="nf">check_numeric_gradient</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">numeric_eps</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span>
<span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">grad_nodes</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">use_forward_train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">grad_stype_dict</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Verify an operation by checking backward pass via finite difference method.</span>
<span class="sd"> Based on Theano's `theano.gradient.verify_grad` [1]</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> Symbol containing op to test</span>
<span class="sd"> location : list or tuple or dict</span>
<span class="sd"> Argument values used as location to compute gradient</span>
<span class="sd"> - if type is list of numpy.ndarray, \</span>
<span class="sd"> inner elements should have the same order as mxnet.sym.list_arguments().</span>
<span class="sd"> - if type is dict of str -> numpy.ndarray, \</span>
<span class="sd"> maps the name of arguments to the corresponding numpy.ndarray.</span>
<span class="sd"> *In either case, value of all the arguments must be provided.*</span>
<span class="sd"> aux_states : list or tuple or dict, optional</span>
<span class="sd"> The auxiliary states required when generating the executor for the symbol.</span>
<span class="sd"> numeric_eps : float, optional</span>
<span class="sd"> Delta for the finite difference method that approximates the gradient.</span>
<span class="sd"> check_eps : float, optional</span>
<span class="sd"> relative error eps used when comparing numeric grad to symbolic grad.</span>
<span class="sd"> grad_nodes : None or list or tuple or dict, optional</span>
<span class="sd"> Names of the nodes to check gradient on</span>
<span class="sd"> use_forward_train : bool</span>
<span class="sd"> Whether to use is_train=True when computing the finite-difference.</span>
<span class="sd"> ctx : Context, optional</span>
<span class="sd"> Check the gradient computation on the specified device.</span>
<span class="sd"> grad_stype_dict : dict of str->str, optional</span>
<span class="sd"> Storage type dictionary for gradient ndarrays.</span>
<span class="sd"> dtype: np.float16 or np.float32 or np.float64</span>
<span class="sd"> Datatype for mx.nd.array.</span>
<span class="sd"> References</span>
<span class="sd"> ---------</span>
<span class="sd"> [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="c1"># cannot use finite differences with small eps without high precision</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">numeric_eps</span> <span class="o">>=</span> <span class="mf">1e-5</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">default_context</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">random_projection</span><span class="p">(</span><span class="n">shape</span><span class="p">):</span>
<span class="sd">"""Get a random weight matrix with not too small elements</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> shape : list or tuple</span>
<span class="sd"> """</span>
<span class="c1"># random_projection should not have elements too small,</span>
<span class="c1"># otherwise too much precision is lost in numerical gradient</span>
<span class="n">plain</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.1</span>
<span class="k">return</span> <span class="n">plain</span>
<span class="n">location</span> <span class="o">=</span> <span class="n">_parse_location</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">location_npy</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">aux_states</span> <span class="o">=</span> <span class="n">_parse_aux_states</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">aux_states</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">aux_states_npy</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">aux_states</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">aux_states_npy</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">grad_nodes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">grad_nodes</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="s1">'write'</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">grad_nodes</span><span class="p">}</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">grad_nodes</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">grad_nodes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">grad_nodes</span><span class="p">)</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="s1">'write'</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">grad_nodes</span><span class="p">}</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">grad_nodes</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="n">grad_nodes</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="n">grad_nodes</span> <span class="o">=</span> <span class="n">grad_nodes</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">_</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">(</span><span class="o">**</span><span class="n">input_shape</span><span class="p">)</span>
<span class="n">proj</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s2">"__random_proj"</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">sym</span> <span class="o">*</span> <span class="n">proj</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">make_loss</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="n">location</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">())</span> <span class="o">+</span>
<span class="p">[(</span><span class="s2">"__random_proj"</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">random_projection</span><span class="p">(</span><span class="n">out_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span>
<span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))])</span>
<span class="n">args_grad_npy</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">([(</span><span class="n">k</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">location</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">grad_nodes</span><span class="p">]</span>
<span class="o">+</span> <span class="p">[(</span><span class="s2">"__random_proj"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">out_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))])</span>
<span class="n">args_grad</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">args_grad_npy</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">if</span> <span class="n">grad_stype_dict</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">grad_stype_dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">),</span> <span class="s2">"grad_stype_dict must be a dict"</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">grad_stype_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">args_grad</span> <span class="ow">and</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">_STORAGE_TYPE_STR_TO_ID</span> <span class="ow">and</span> <span class="n">v</span> <span class="o">!=</span> <span class="s1">'default'</span><span class="p">:</span>
<span class="c1"># create an uninitialized sparse ndarray for executor</span>
<span class="c1"># if the symbolic grad is expected to be zero, it should not be initialized at all</span>
<span class="n">args_grad</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">args_grad</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">args_grad</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">context</span><span class="p">,</span>
<span class="n">args_grad</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="n">executor</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="n">grad_req</span><span class="p">,</span>
<span class="n">args</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">args_grad</span><span class="o">=</span><span class="n">args_grad</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">)</span>
<span class="n">inps</span> <span class="o">=</span> <span class="n">executor</span><span class="o">.</span><span class="n">arg_arrays</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">inps</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">location</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Executor arg_arrays and and location len do not match."</span>
<span class="s2">"Got </span><span class="si">%d</span><span class="s2"> inputs and </span><span class="si">%d</span><span class="s2"> locations"</span><span class="o">%</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">inps</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">location</span><span class="p">)))</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">executor</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
<span class="n">executor</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">executor</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">symbolic_grads</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">executor</span><span class="o">.</span><span class="n">grad_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">grad_nodes</span><span class="p">}</span>
<span class="n">numeric_gradients</span> <span class="o">=</span> <span class="n">numeric_grad</span><span class="p">(</span>
<span class="n">executor</span><span class="p">,</span> <span class="n">location_npy</span><span class="p">,</span> <span class="n">aux_states_npy</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="n">numeric_eps</span><span class="p">,</span> <span class="n">use_forward_train</span><span class="o">=</span><span class="n">use_forward_train</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">grad_nodes</span><span class="p">:</span>
<span class="n">fd_grad</span> <span class="o">=</span> <span class="n">numeric_gradients</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="n">orig_grad</span> <span class="o">=</span> <span class="n">args_grad_npy</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="n">sym_grad</span> <span class="o">=</span> <span class="n">symbolic_grads</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">if</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'write'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">fd_grad</span><span class="p">,</span> <span class="n">sym_grad</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span>
<span class="p">(</span><span class="s2">"NUMERICAL_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">))</span>
<span class="k">elif</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'add'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">fd_grad</span><span class="p">,</span> <span class="n">sym_grad</span> <span class="o">-</span> <span class="n">orig_grad</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span>
<span class="p">(</span><span class="s2">"NUMERICAL_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">))</span>
<span class="k">elif</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'null'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">orig_grad</span><span class="p">,</span> <span class="n">sym_grad</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span>
<span class="p">(</span><span class="s2">"NUMERICAL_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Invalid grad_req </span><span class="si">%s</span><span class="s2"> for argument </span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="p">(</span><span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">name</span><span class="p">))</span></div>
<div class="viewcode-block" id="check_symbolic_forward"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.check_symbolic_forward">[docs]</a><span class="k">def</span> <span class="nf">check_symbolic_forward</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="p">,</span> <span class="n">expected</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1E-4</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">aux_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Compares a symbol's forward results with the expected ones.</span>
<span class="sd"> Prints error messages if the forward results are not the same as the expected ones.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> output symbol</span>
<span class="sd"> location : list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> The evaluation point</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains all the numpy arrays corresponding to `sym.list_arguments()`.</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains the mapping between argument names and their values.</span>
<span class="sd"> expected : list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> The expected output value</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains arrays corresponding to exe.outputs.</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains mapping between sym.list_outputs() and exe.outputs.</span>
<span class="sd"> rtol : float, optional</span>
<span class="sd"> Relative tolerance passed to assert_almost_equal for each output.</span>
<span class="sd"> atol : float, optional</span>
<span class="sd"> Absolute tolerance passed to assert_almost_equal for each output.</span>
<span class="sd"> aux_states : list of np.ndarray or dict of str to np.ndarray, optional</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains all the NumPy arrays corresponding to sym.list_auxiliary_states</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains the mapping between names of auxiliary states and their values.</span>
<span class="sd"> ctx : Context, optional</span>
<span class="sd"> Running context; defaults to default_context() when None.</span>
<span class="sd"> dtype: "asnumpy" or np.float16 or np.float32 or np.float64</span>
<span class="sd"> If dtype is "asnumpy" then the mx.nd.array created will have the same</span>
<span class="sd"> type as the numpy array from which it is copied.</span>
<span class="sd"> Otherwise, dtype is the explicit datatype for all mx.nd.array objects</span>
<span class="sd"> created in this function.</span>
<span class="sd"> equal_nan: Boolean</span>
<span class="sd"> if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> list of NDArray</span>
<span class="sd"> The executor's outputs after the forward pass (executor.outputs).</span>
<span class="sd"> Example</span>
<span class="sd"> -------</span>
<span class="sd"> >>> shape = (2, 2)</span>
<span class="sd"> >>> lhs = mx.symbol.Variable('lhs')</span>
<span class="sd"> >>> rhs = mx.symbol.Variable('rhs')</span>
<span class="sd"> >>> sym_dot = mx.symbol.dot(lhs, rhs)</span>
<span class="sd"> >>> mat1 = np.array([[1, 2], [3, 4]])</span>
<span class="sd"> >>> mat2 = np.array([[5, 6], [7, 8]])</span>
<span class="sd"> >>> ret_expected = np.array([[19, 22], [43, 50]])</span>
<span class="sd"> >>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="ow">or</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">default_context</span><span class="p">()</span>
<span class="n">location</span> <span class="o">=</span> <span class="n">_parse_location</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">aux_states</span> <span class="o">=</span> <span class="n">_parse_aux_states</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">expected</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="n">expected</span> <span class="o">=</span> <span class="p">[</span><span class="n">expected</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">sym</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">()]</span>
<span class="n">args_grad_data</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">"asnumpy"</span> <span class="k">else</span> <span class="n">dtype</span><span class="p">)</span> \
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">executor</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">args_grad</span><span class="o">=</span><span class="n">args_grad_data</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">)</span>
<span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">executor</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">:</span>
<span class="n">g</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">executor</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">executor</span><span class="o">.</span><span class="n">outputs</span><span class="p">]</span>
<span class="k">for</span> <span class="n">output_name</span><span class="p">,</span> <span class="n">expect</span><span class="p">,</span> <span class="n">output</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">(),</span> <span class="n">expected</span><span class="p">,</span> <span class="n">outputs</span><span class="p">):</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">expect</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span>
<span class="p">(</span><span class="s2">"EXPECTED_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">output_name</span><span class="p">,</span> <span class="s2">"FORWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">output_name</span><span class="p">),</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">return</span> <span class="n">executor</span><span class="o">.</span><span class="n">outputs</span></div>
<div class="viewcode-block" id="check_symbolic_backward"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.check_symbolic_backward">[docs]</a><span class="k">def</span> <span class="nf">check_symbolic_backward</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="p">,</span> <span class="n">out_grads</span><span class="p">,</span> <span class="n">expected</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">aux_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'write'</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">grad_stypes</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">default_dtype</span><span class="p">()):</span>
<span class="sd">"""Compares a symbol's backward results with the expected ones.</span>
<span class="sd"> Prints error messages if the backward results are not the same as the expected results.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> output symbol</span>
<span class="sd"> location : list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> The evaluation point</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains all the NumPy arrays corresponding to ``mx.sym.list_arguments``.</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains the mapping between argument names and their values.</span>
<span class="sd"> out_grads : None or list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> NumPy arrays corresponding to sym.outputs for the incoming gradient.</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains arrays corresponding to ``exe.outputs``.</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains mapping between mxnet.sym.list_outputs() and Executor.outputs.</span>
<span class="sd"> expected : list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> expected gradient values</span>
<span class="sd"> - if type is list of np.ndarray</span>
<span class="sd"> Contains arrays corresponding to exe.grad_arrays</span>
<span class="sd"> - if type is dict of str to np.ndarray</span>
<span class="sd"> Contains mapping between argument names (``sym.list_arguments()``) and their expected gradients.</span>
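<span class="sd"> rtol : float, optional</span>
<span class="sd"> Relative tolerance passed to assert_almost_equal for each gradient.</span>
<span class="sd"> atol : float, optional</span>
<span class="sd"> Absolute tolerance passed to assert_almost_equal for each gradient.</span>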
<span class="sd"> aux_states : list of np.ndarray or dict of str to np.ndarray</span>
<span class="sd"> grad_req : str or list of str or dict of str to str, optional</span>
<span class="sd"> Gradient requirements. 'write', 'add' or 'null'.</span>
<span class="sd"> ctx : Context, optional</span>
<span class="sd"> Running context.</span>
<span class="sd"> grad_stypes: dict of str->str</span>
<span class="sd"> dictionary of mapping argument name to stype for the gradient</span>
<span class="sd"> equal_nan: Boolean</span>
<span class="sd"> if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`)</span>
<span class="sd"> dtype: np.float16 or np.float32 or np.float64</span>
<span class="sd"> Datatype for mx.nd.array.</span>
<span class="sd"> Example</span>
<span class="sd"> -------</span>
<span class="sd"> >>> lhs = mx.symbol.Variable('lhs')</span>
<span class="sd"> >>> rhs = mx.symbol.Variable('rhs')</span>
<span class="sd"> >>> sym_add = mx.symbol.elemwise_add(lhs, rhs)</span>
<span class="sd"> >>> mat1 = np.array([[1, 2], [3, 4]])</span>
<span class="sd"> >>> mat2 = np.array([[5, 6], [7, 8]])</span>
<span class="sd"> >>> shape = (2, 2)</span>
<span class="sd"> >>> grad1 = mx.nd.zeros(shape)</span>
<span class="sd"> >>> grad2 = mx.nd.zeros(shape)</span>
<span class="sd"> >>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2},</span>
<span class="sd"> ... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'})</span>
<span class="sd"> >>> exec_add.forward(is_train=True)</span>
<span class="sd"> >>> ograd = mx.nd.ones(shape)</span>
<span class="sd"> >>> grad_expected = ograd.copy().asnumpy()</span>
<span class="sd"> >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">default_context</span><span class="p">()</span>
<span class="n">location</span> <span class="o">=</span> <span class="n">_parse_location</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">aux_states</span> <span class="o">=</span> <span class="n">_parse_aux_states</span><span class="p">(</span><span class="n">sym</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">expected</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">expected</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">(),</span> <span class="n">expected</span><span class="p">)}</span>
<span class="n">args_grad_npy</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">expected</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">args_grad_data</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">args_grad_npy</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">nd</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">grad_stypes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">grad_stypes</span><span class="p">:</span>
<span class="n">stype</span> <span class="o">=</span> <span class="n">grad_stypes</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
<span class="k">if</span> <span class="n">stype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">stype</span> <span class="o">!=</span> <span class="s1">'default'</span><span class="p">:</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">create_sparse_array</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">nd</span>
<span class="n">args_grad_data</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">out</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">args_grad_data</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">nd</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">grad_req</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">grad_req</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()}</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">grad_req</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">sym</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">(),</span> <span class="n">grad_req</span><span class="p">)}</span>
<span class="n">executor</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="n">location</span><span class="p">,</span> <span class="n">args_grad</span><span class="o">=</span><span class="n">args_grad_data</span><span class="p">,</span>
<span class="n">aux_states</span><span class="o">=</span><span class="n">aux_states</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="n">grad_req</span><span class="p">)</span>
<span class="n">executor</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">out_grads</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
<span class="n">outg</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span>
<span class="k">for</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">out_grads</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">arr</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span>
<span class="n">outg</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">arr</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">outg</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">arr</span><span class="p">)</span>
<span class="n">out_grads</span> <span class="o">=</span> <span class="n">outg</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">out_grads</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="n">outg</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">out_grads</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span>
<span class="n">outg</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">outg</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="n">out_grads</span> <span class="o">=</span> <span class="n">outg</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">out_grads</span> <span class="ow">is</span> <span class="kc">None</span>
<span class="n">executor</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">out_grads</span><span class="p">)</span>
<span class="n">grads</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">args_grad_data</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">expected</span><span class="p">:</span>
<span class="k">if</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'write'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">expected</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">grads</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span>
<span class="p">(</span><span class="s2">"EXPECTED_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">),</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'add'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">expected</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">grads</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">-</span> <span class="n">args_grad_npy</span><span class="p">[</span><span class="n">name</span><span class="p">],</span>
<span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="p">(</span><span class="s2">"EXPECTED_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">),</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'null'</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">args_grad_npy</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">grads</span><span class="p">[</span><span class="n">name</span><span class="p">],</span>
<span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">,</span> <span class="p">(</span><span class="s2">"EXPECTED_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">,</span> <span class="s2">"BACKWARD_</span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="n">name</span><span class="p">),</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Invalid grad_req </span><span class="si">%s</span><span class="s2"> for argument </span><span class="si">%s</span><span class="s2">"</span><span class="o">%</span><span class="p">(</span><span class="n">grad_req</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="n">args_grad_data</span></div>
<!-- check_speed (rendered Python source). Binds sym with simple_bind, fills every argument from location (random normal with scale 1.0 over exe.arg_dict shapes when location is None, cast to each argument's dtype), runs one warm-up pass, then times N iterations with time.time() and returns (toc - tic) / N, the mean wall-clock seconds per iteration. typ="whole" times forward(is_train=True) plus backward(out_grads=exe.outputs); typ="forward" times forward(is_train=False); any other typ raises ValueError, but only after the executor has already been bound. The docstring below has no Returns section. NOTE(review): this markup appears to be Sphinx viewcode output; fix the Python module and rebuild rather than editing this page. --><div class="viewcode-block" id="check_speed"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.check_speed">[docs]</a><span class="k">def</span> <span class="nf">check_speed</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">location</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">typ</span><span class="o">=</span><span class="s2">"whole"</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="sd">"""Check the running speed of a symbol.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol</span>
<span class="sd"> Symbol to run the speed test.</span>
<span class="sd"> location : none or dict of str to np.ndarray</span>
<span class="sd"> Location to evaluate the inner executor.</span>
<span class="sd"> ctx : Context</span>
<span class="sd"> Running context.</span>
<span class="sd"> N : int, optional</span>
<span class="sd"> Repeat times.</span>
<span class="sd"> grad_req : None or str or list of str or dict of str to str, optional</span>
<span class="sd"> Gradient requirements.</span>
<span class="sd"> typ : str, optional</span>
<span class="sd"> "whole" or "forward"</span>
<span class="sd"> - "whole"</span>
<span class="sd"> Test the forward_backward speed.</span>
<span class="sd"> - "forward"</span>
<span class="sd"> Only test the forward speed.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">ctx</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">default_context</span><span class="p">()</span>
<span class="k">if</span> <span class="n">grad_req</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">grad_req</span> <span class="o">=</span> <span class="s1">'write'</span>
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">exe</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">simple_bind</span><span class="p">(</span><span class="n">grad_req</span><span class="o">=</span><span class="n">grad_req</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">location</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">arr</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span>
<span class="n">exe</span><span class="o">.</span><span class="n">arg_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">location</span><span class="p">,</span> <span class="nb">dict</span><span class="p">),</span> <span class="s2">"Expect dict, get </span><span class="se">\"</span><span class="s2">location</span><span class="se">\"</span><span class="s2">=</span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span><span class="nb">str</span><span class="p">(</span><span class="n">location</span><span class="p">)</span>
<span class="n">exe</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">simple_bind</span><span class="p">(</span><span class="n">grad_req</span><span class="o">=</span><span class="n">grad_req</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span>
<span class="o">**</span><span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">()})</span><!-- NOTE(review): on this path the caller's **kwargs are not forwarded to simple_bind; argument shapes come only from location. -->
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">iarr</span> <span class="ow">in</span> <span class="n">location</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">exe</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">][:]</span> <span class="o">=</span> <span class="n">iarr</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">exe</span><span class="o">.</span><span class="n">arg_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">typ</span> <span class="o">==</span> <span class="s2">"whole"</span><span class="p">:</span>
<span class="c1"># Warm up</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">out_grads</span><span class="o">=</span><span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">:</span>
<span class="n">output</span><span class="o">.</span><span class="n">wait_to_read</span><span class="p">()</span>
<span class="c1"># Test forward + backward</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">out_grads</span><span class="o">=</span><span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">waitall</span><span class="p">()</span>
<span class="n">toc</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">forward_backward_time</span> <span class="o">=</span> <span class="p">(</span><span class="n">toc</span> <span class="o">-</span> <span class="n">tic</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">N</span>
<span class="k">return</span> <span class="n">forward_backward_time</span>
<span class="k">elif</span> <span class="n">typ</span> <span class="o">==</span> <span class="s2">"forward"</span><span class="p">:</span>
<span class="c1"># Warm up</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">:</span>
<span class="n">output</span><span class="o">.</span><span class="n">wait_to_read</span><span class="p">()</span>
<span class="c1"># Test forward only</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">waitall</span><span class="p">()</span>
<span class="n">toc</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">forward_time</span> <span class="o">=</span> <span class="p">(</span><span class="n">toc</span> <span class="o">-</span> <span class="n">tic</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">N</span>
<span class="k">return</span> <span class="n">forward_time</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'typ can only be "whole" or "forward".'</span><span class="p">)</span></div>
<div class="viewcode-block" id="check_consistency"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.check_consistency">[docs]</a><span class="k">def</span> <span class="nf">check_consistency</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">ctx_list</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">grad_req</span><span class="o">=</span><span class="s1">'write'</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">aux_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">raise_on_err</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ground_truth</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">equal_nan</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_uniform</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">rand_type</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">):</span>
<span class="sd">"""Check symbol gives the same output for different running context</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> sym : Symbol or list of Symbols</span>
<span class="sd"> Symbol(s) to run the consistency test.</span>
<span class="sd"> ctx_list : list</span>
<span class="sd"> Running context. See example for more detail.</span>
<span class="sd"> scale : float, optional</span>
<span class="sd"> Standard deviation of the inner normal distribution. Used in initialization.</span>
<span class="sd"> grad_req : str or list of str or dict of str to str</span>
<span class="sd"> Gradient requirement.</span>
<span class="sd"> use_unifrom: bool</span>
<span class="sd"> Optional, When flag set to true,</span>
<span class="sd"> random input data generated follows uniform distribution,</span>
<span class="sd"> not normal distribution</span>
<span class="sd"> rand_type: np.dtype</span>
<span class="sd"> casts the randomly generated data to this type</span>
<span class="sd"> Optional, when input data is passed via arg_params,</span>
<span class="sd"> defaults to np.float64 (numpy float default)</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> >>> # create the symbol</span>
<span class="sd"> >>> sym = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')</span>
<span class="sd"> >>> # initialize the running context</span>
<span class="sd"> >>> ctx_list =\</span>
<span class="sd">[{'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}},\</span>
<span class="sd"> {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}},\</span>
<span class="sd"> {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float16}},\</span>
<span class="sd"> {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}},\</span>
<span class="sd"> {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}]</span>
<span class="sd"> >>> check_consistency(sym, ctx_list)</span>
<span class="sd"> >>> sym = mx.sym.Concat(name='concat', num_args=2)</span>
<span class="sd"> >>> ctx_list = \</span>
<span class="sd">[{'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\</span>
<span class="sd"> 'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}},\</span>
<span class="sd"> {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\</span>
<span class="sd"> 'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}},\</span>
<span class="sd"> {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\</span>
<span class="sd"> 'type_dict': {'concat_arg0': np.float16, 'concat_arg1': np.float16}},\</span>
<span class="sd"> {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\</span>
<span class="sd"> 'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}},\</span>
<span class="sd"> {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\</span>
<span class="sd"> 'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}]</span>
<span class="sd"> >>> check_consistency(sym, ctx_list)</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">tol</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tol</span> <span class="o">=</span> <span class="p">{</span><span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span> <span class="mf">1e-1</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span> <span class="mf">1e-3</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">):</span> <span class="mf">1e-5</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">):</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">):</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">):</span> <span class="mi">0</span><span class="p">}</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tol</span><span class="p">,</span> <span class="n">numbers</span><span class="o">.</span><span class="n">Number</span><span class="p">):</span>
<span class="n">tol</span> <span class="o">=</span> <span class="p">{</span><span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">):</span> <span class="n">tol</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span> <span class="n">tol</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">):</span> <span class="n">tol</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">):</span> <span class="n">tol</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">):</span> <span class="n">tol</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">):</span> <span class="n">tol</span><span class="p">}</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">ctx_list</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">Symbol</span><span class="p">):</span>
<span class="n">sym</span> <span class="o">=</span> <span class="p">[</span><span class="n">sym</span><span class="p">]</span><span class="o">*</span><span class="nb">len</span><span class="p">(</span><span class="n">ctx_list</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">sym</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">ctx_list</span><span class="p">)</span>
<span class="n">output_names</span> <span class="o">=</span> <span class="n">sym</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">()</span>
<span class="n">arg_names</span> <span class="o">=</span> <span class="n">sym</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()</span>
<span class="n">exe_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">s</span><span class="p">,</span> <span class="n">ctx</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">ctx_list</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">s</span><span class="o">.</span><span class="n">list_arguments</span><span class="p">()</span> <span class="o">==</span> <span class="n">arg_names</span>
<span class="k">assert</span> <span class="n">s</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">()</span> <span class="o">==</span> <span class="n">output_names</span>
<span class="n">exe_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="o">.</span><span class="n">simple_bind</span><span class="p">(</span><span class="n">grad_req</span><span class="o">=</span><span class="n">grad_req</span><span class="p">,</span> <span class="o">**</span><span class="n">ctx</span><span class="p">))</span>
<span class="n">arg_params</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">if</span> <span class="n">arg_params</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">arg_params</span>
<span class="n">aux_params</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">if</span> <span class="n">aux_params</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">aux_params</span>
<span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">arg_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">n</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">arg_params</span><span class="p">:</span>
<span class="k">if</span> <span class="n">use_uniform</span><span class="p">:</span>
<span class="n">arg_params</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=-</span><span class="mf">0.92</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mf">0.92</span><span class="p">,</span>
<span class="n">size</span><span class="o">=</span><span class="n">arr</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">rand_type</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">arg_params</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">arr</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">rand_type</span><span class="p">)</span>
<span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">aux_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">n</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">aux_params</span><span class="p">:</span>
<span class="n">aux_params</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">exe</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">:</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">arg_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">arg_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">aux_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">aux_params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="c1"># We need to initialize the gradient arrays if it's add.</span>
<span class="k">if</span> <span class="p">(</span><span class="n">grad_req</span> <span class="o">==</span> <span class="s2">"add"</span><span class="p">):</span>
<span class="k">for</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">exe</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">:</span>
<span class="n">arr</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">arr</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">arr</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">dtypes</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">dtype</span><span class="p">(</span><span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">exe</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">]</span>
<span class="n">max_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dtypes</span><span class="p">)</span>
<span class="n">gt</span> <span class="o">=</span> <span class="n">ground_truth</span>
<span class="k">if</span> <span class="n">gt</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">gt</span> <span class="o">=</span> <span class="n">exe_list</span><span class="p">[</span><span class="n">max_idx</span><span class="p">]</span><span class="o">.</span><span class="n">output_dict</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="k">if</span> <span class="n">grad_req</span> <span class="o">!=</span> <span class="s1">'null'</span><span class="p">:</span>
<span class="n">gt</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">exe_list</span><span class="p">[</span><span class="n">max_idx</span><span class="p">]</span><span class="o">.</span><span class="n">grad_dict</span><span class="p">)</span>
<span class="c1"># test</span>
<span class="k">for</span> <span class="n">exe</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">:</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">exe</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">exe_list</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="n">max_idx</span><span class="p">:</span>
<span class="k">continue</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">output_names</span><span class="p">,</span> <span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">):</span>
<span class="n">gtarr</span> <span class="o">=</span> <span class="n">gt</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">])</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="n">arr</span> <span class="o">=</span> <span class="n">arr</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">arr</span><span class="p">,</span> <span class="n">gtarr</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="n">tol</span><span class="p">[</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="n">atol</span><span class="o">=</span><span class="n">tol</span><span class="p">[</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">AssertionError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Predict Err: ctx </span><span class="si">%d</span><span class="s1"> vs ctx </span><span class="si">%d</span><span class="s1"> at </span><span class="si">%s</span><span class="s1">'</span><span class="o">%</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">max_idx</span><span class="p">,</span> <span class="n">name</span><span class="p">))</span>
<span class="n">traceback</span><span class="o">.</span><span class="n">print_exc</span><span class="p">()</span>
<span class="k">if</span> <span class="n">raise_on_err</span><span class="p">:</span> <span class="c1"># pylint: disable=no-else-raise</span>
<span class="k">raise</span> <span class="n">e</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span>
<span class="c1"># train</span>
<span class="k">if</span> <span class="n">grad_req</span> <span class="o">!=</span> <span class="s1">'null'</span><span class="p">:</span>
<span class="k">for</span> <span class="n">exe</span> <span class="ow">in</span> <span class="n">exe_list</span><span class="p">:</span>
<span class="n">exe</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">exe</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">exe</span><span class="o">.</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">exe</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">exe_list</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="n">max_idx</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">curr</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="n">output_names</span> <span class="o">+</span> <span class="n">arg_names</span><span class="p">,</span> <span class="n">exe</span><span class="o">.</span><span class="n">outputs</span> <span class="o">+</span> <span class="n">exe</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">arr</span> <span class="ow">in</span> <span class="n">curr</span><span class="p">:</span>
<span class="k">if</span> <span class="n">gt</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">arr</span> <span class="ow">is</span> <span class="kc">None</span>
<span class="k">continue</span>
<span class="n">gtarr</span> <span class="o">=</span> <span class="n">gt</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">])</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="n">arr</span> <span class="o">=</span> <span class="n">arr</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">arr</span><span class="p">,</span> <span class="n">gtarr</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="n">tol</span><span class="p">[</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="n">atol</span><span class="o">=</span><span class="n">tol</span><span class="p">[</span><span class="n">dtypes</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span>
<span class="n">equal_nan</span><span class="o">=</span><span class="n">equal_nan</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">AssertionError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Train Err: ctx </span><span class="si">%d</span><span class="s1"> vs ctx </span><span class="si">%d</span><span class="s1"> at </span><span class="si">%s</span><span class="s1">'</span><span class="o">%</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">max_idx</span><span class="p">,</span> <span class="n">name</span><span class="p">))</span>
<span class="n">traceback</span><span class="o">.</span><span class="n">print_exc</span><span class="p">()</span>
<span class="k">if</span> <span class="n">raise_on_err</span><span class="p">:</span> <span class="c1"># pylint: disable=no-else-raise</span>
<span class="k">raise</span> <span class="n">e</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span>
<span class="k">return</span> <span class="n">gt</span></div>
<!--
list_gpus(): returns range(mx.util.get_gpu_count()), i.e. the GPU indices
0..n-1, which is empty when the reported count is 0.
NOTE(review): the docstring promises "list of int", but the code returns a
range object rather than a list; callers that need a real list (mutation,
equality against a list literal) must wrap it in list(...). Any fix belongs in
the upstream Python source, not in this generated listing.
--><div class="viewcode-block" id="list_gpus"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.list_gpus">[docs]</a><span class="k">def</span> <span class="nf">list_gpus</span><span class="p">():</span>
<span class="sd">"""Return a list of GPUs</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> list of int:</span>
<span class="sd"> If there are n GPUs, then return a list [0,1,...,n-1]. Otherwise returns</span>
<span class="sd"> [].</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="nb">range</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">util</span><span class="o">.</span><span class="n">get_gpu_count</span><span class="p">())</span></div>
<!--
download(url, fname=None, dirname=None, overwrite=False, retries=5):
fetches url with requests.get(stream=True) and writes the body to a local file
in 1024-byte chunks, skipping empty keep-alive chunks. Returns the local path.
* fname defaults to the last "/"-separated component of url.
* If dirname is None it is derived with os.path.dirname(fname); otherwise
  fname is joined onto dirname. A non-empty dirname is created when missing
  (an EEXIST race is tolerated; other OSErrors are re-raised).
* With overwrite=False an existing target file is returned without downloading.
* Any exception, including the failed status-code assert, triggers a retry;
  the last exception is re-raised when retries run out.
NOTE(review), for the upstream Python source:
  * Attempt count: the counter is decremented before the "retries <= 0" check,
    so retries=5 gives 5 attempts in total while retries=0 and retries=1 both
    give exactly one attempt. The "while retries+1 > 0" condition suggests
    retries+1 attempts were intended.
  * requests.get is called without a timeout, so a stalled server can block
    the call indefinitely.
  * The 200 check is an assert statement, which is removed under python -O;
    a non-200 body would then be written to fname as if it succeeded.
  * If every attempt fails, a partially written fname may remain on disk, and
    a later call with overwrite=False would then skip the download.
  * Typos: docstring "Download an given URL" (should be "a given URL") and
    comment "Disable pyling" (should be "pylint").
--><div class="viewcode-block" id="download"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.download">[docs]</a><span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">fname</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dirname</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">retries</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
<span class="sd">"""Download an given URL</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> url : str</span>
<span class="sd"> URL to download</span>
<span class="sd"> fname : str, optional</span>
<span class="sd"> filename of the downloaded file. If None, then will guess a filename</span>
<span class="sd"> from url.</span>
<span class="sd"> dirname : str, optional</span>
<span class="sd"> output directory name. If None, then guess from fname or use the current</span>
<span class="sd"> directory</span>
<span class="sd"> overwrite : bool, optional</span>
<span class="sd"> Default is false, which means skipping download if the local file</span>
<span class="sd"> exists. If true, then download the url to overwrite the local file if</span>
<span class="sd"> exists.</span>
<span class="sd"> retries : integer, default 5</span>
<span class="sd"> The number of times to attempt the download in case of failure or non 200 return codes</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> str</span>
<span class="sd"> The filename of the downloaded file</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">retries</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"Number of retries should be at least 0"</span>
<span class="k">if</span> <span class="n">fname</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'/'</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">dirname</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">dirname</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">fname</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">dirname</span><span class="p">,</span> <span class="n">fname</span><span class="p">)</span>
<span class="k">if</span> <span class="n">dirname</span> <span class="o">!=</span> <span class="s2">""</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">dirname</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">'create directory </span><span class="si">%s</span><span class="s1">'</span><span class="p">,</span> <span class="n">dirname</span><span class="p">)</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">dirname</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">OSError</span> <span class="k">as</span> <span class="n">exc</span><span class="p">:</span>
<span class="k">if</span> <span class="n">exc</span><span class="o">.</span><span class="n">errno</span> <span class="o">!=</span> <span class="n">errno</span><span class="o">.</span><span class="n">EEXIST</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">OSError</span><span class="p">(</span><span class="s1">'failed to create '</span> <span class="o">+</span> <span class="n">dirname</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> exists, skipping download"</span><span class="p">,</span> <span class="n">fname</span><span class="p">)</span>
<span class="k">return</span> <span class="n">fname</span>
<span class="k">while</span> <span class="n">retries</span><span class="o">+</span><span class="mi">1</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span><!-- NOTE(review): in practice this runs max(retries, 1) times, not retries+1 -->
<span class="c1"># Disable pyling too broad Exception</span>
<span class="c1"># pylint: disable=W0703</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">r</span> <span class="o">=</span> <span class="n">requests</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><!-- NOTE(review): no timeout is passed to requests.get -->
<span class="k">assert</span> <span class="n">r</span><span class="o">.</span><span class="n">status_code</span> <span class="o">==</span> <span class="mi">200</span><span class="p">,</span> <span class="s2">"failed to open </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">url</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">fname</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">r</span><span class="o">.</span><span class="n">iter_content</span><span class="p">(</span><span class="n">chunk_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">):</span>
<span class="k">if</span> <span class="n">chunk</span><span class="p">:</span> <span class="c1"># filter out keep-alive new chunks</span>
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span>
<span class="k">break</span>
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
<span class="n">retries</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">retries</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># pylint: disable=no-else-raise</span>
<span class="k">raise</span> <span class="n">e</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"download failed, retrying, </span><span class="si">{}</span><span class="s2"> attempt</span><span class="si">{}</span><span class="s2"> left"</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">retries</span><span class="p">,</span> <span class="s1">'s'</span> <span class="k">if</span> <span class="n">retries</span> <span class="o">></span> <span class="mi">1</span> <span class="k">else</span> <span class="s1">''</span><span class="p">))</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"downloaded </span><span class="si">%s</span><span class="s2"> into </span><span class="si">%s</span><span class="s2"> successfully"</span><span class="p">,</span> <span class="n">url</span><span class="p">,</span> <span class="n">fname</span><span class="p">)</span>
<span class="k">return</span> <span class="n">fname</span></div>
<!--
get_mnist(): downloads the four MNIST gzip files from
http://data.mxnet.io/data/mnist/ through mx.test_utils.download and parses them.
read_data() discards the 8-byte label header, reads the 16-byte image header as
four big-endian unsigned ints (keeping rows and cols), loads labels as int8 and
pixels as uint8, then reshapes images to (N, 1, 28, 28) float32 divided by 255.
Returns a dict with keys train_data, train_label, test_data, test_label.
NOTE(review), for the upstream Python source:
  * The final reshape hard-codes 28x28 instead of reusing rows and cols from
    the header.
  * download() is called without dirname, so the .gz files land in the current
    working directory.
  * The URL is plain http.
--><div class="viewcode-block" id="get_mnist"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_mnist">[docs]</a><span class="k">def</span> <span class="nf">get_mnist</span><span class="p">():</span>
<span class="sd">"""Download and load the MNIST dataset</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> dict</span>
<span class="sd"> A dict containing the data</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">read_data</span><span class="p">(</span><span class="n">label_url</span><span class="p">,</span> <span class="n">image_url</span><span class="p">):</span>
<span class="k">with</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">label_url</span><span class="p">))</span> <span class="k">as</span> <span class="n">flbl</span><span class="p">:</span>
<span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s2">">II"</span><span class="p">,</span> <span class="n">flbl</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">flbl</span><span class="o">.</span><span class="n">read</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">)</span>
<span class="k">with</span> <span class="n">gzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">image_url</span><span class="p">),</span> <span class="s1">'rb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">fimg</span><span class="p">:</span>
<span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s2">">IIII"</span><span class="p">,</span> <span class="n">fimg</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">16</span><span class="p">))</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">fimg</span><span class="o">.</span><span class="n">read</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">label</span><span class="p">),</span> <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><span class="o">/</span><span class="mi">255</span>
<span class="k">return</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">image</span><span class="p">)</span>
<span class="c1"># changed to mxnet.io for more stable hosting</span>
<span class="c1"># path = 'http://yann.lecun.com/exdb/mnist/'</span>
<span class="n">path</span> <span class="o">=</span> <span class="s1">'http://data.mxnet.io/data/mnist/'</span>
<span class="p">(</span><span class="n">train_lbl</span><span class="p">,</span> <span class="n">train_img</span><span class="p">)</span> <span class="o">=</span> <span class="n">read_data</span><span class="p">(</span>
<span class="n">path</span><span class="o">+</span><span class="s1">'train-labels-idx1-ubyte.gz'</span><span class="p">,</span> <span class="n">path</span><span class="o">+</span><span class="s1">'train-images-idx3-ubyte.gz'</span><span class="p">)</span>
<span class="p">(</span><span class="n">test_lbl</span><span class="p">,</span> <span class="n">test_img</span><span class="p">)</span> <span class="o">=</span> <span class="n">read_data</span><span class="p">(</span>
<span class="n">path</span><span class="o">+</span><span class="s1">'t10k-labels-idx1-ubyte.gz'</span><span class="p">,</span> <span class="n">path</span><span class="o">+</span><span class="s1">'t10k-images-idx3-ubyte.gz'</span><span class="p">)</span>
<span class="k">return</span> <span class="p">{</span><span class="s1">'train_data'</span><span class="p">:</span><span class="n">train_img</span><span class="p">,</span> <span class="s1">'train_label'</span><span class="p">:</span><span class="n">train_lbl</span><span class="p">,</span>
<span class="s1">'test_data'</span><span class="p">:</span><span class="n">test_img</span><span class="p">,</span> <span class="s1">'test_label'</span><span class="p">:</span><span class="n">test_lbl</span><span class="p">}</span></div>
<!--
get_mnist_pkl(): creates ./data if it is missing, then downloads
http://deeplearning.net/data/mnist/mnist.pkl.gz into data/ unless
data/mnist.pkl.gz already exists. Has no return statement (returns None).
NOTE(review): the existence check duplicates the one download() already does
(overwrite defaults to False). The URL is plain http, and whether this host is
still reachable should be confirmed upstream.
--><div class="viewcode-block" id="get_mnist_pkl"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_mnist_pkl">[docs]</a><span class="k">def</span> <span class="nf">get_mnist_pkl</span><span class="p">():</span>
<span class="sd">"""Downloads MNIST dataset as a pkl.gz into a directory in the current directory</span>
<span class="sd"> with the name `data`</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="s2">"data"</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/mnist.pkl.gz'</span><span class="p">):</span>
<span class="n">download</span><span class="p">(</span><span class="s1">'http://deeplearning.net/data/mnist/mnist.pkl.gz'</span><span class="p">,</span>
<span class="n">dirname</span><span class="o">=</span><span class="s1">'data'</span><span class="p">)</span></div>
<!--
get_mnist_ubyte(): creates ./data if it is missing. If any of the four
extracted MNIST ubyte files (train/t10k images and labels) is absent, it
downloads http://data.mxnet.io/mxnet/data/mnist.zip into data/ and extracts the
whole archive into data/. Has no return statement (returns None).
NOTE(review): download() is called with its default overwrite=False, so an
existing data/mnist.zip (for example a truncated one from an earlier failed
run) is reused rather than fetched again, and ZipFile would then fail on it.
--><div class="viewcode-block" id="get_mnist_ubyte"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_mnist_ubyte">[docs]</a><span class="k">def</span> <span class="nf">get_mnist_ubyte</span><span class="p">():</span>
<span class="sd">"""Downloads ubyte version of the MNIST dataset into a directory in the current directory</span>
<span class="sd"> with the name `data` and extracts all files in the zip archive to this directory.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="s2">"data"</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/train-images-idx3-ubyte'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/train-labels-idx1-ubyte'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/t10k-images-idx3-ubyte'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/t10k-labels-idx1-ubyte'</span><span class="p">)):</span>
<span class="n">zip_file_path</span> <span class="o">=</span> <span class="n">download</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/mxnet/data/mnist.zip'</span><span class="p">,</span>
<span class="n">dirname</span><span class="o">=</span><span class="s1">'data'</span><span class="p">)</span>
<span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">zip_file_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">zf</span><span class="p">:</span>
<span class="n">zf</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span></div>
<!--
get_cifar10(): creates ./data if it is missing. If any of data/cifar/train.rec,
test.rec, train.lst or test.lst is absent, it downloads
http://data.mxnet.io/mxnet/data/cifar10.zip into data/ and extracts the whole
archive into data/. The docstring says the files end up under data/cifar,
which presumably relies on the archive's internal cifar/ prefix (not visible
here; confirm). Has no return statement (returns None).
NOTE(review): as in get_mnist_ubyte, an existing data/cifar10.zip is reused
because download() defaults to overwrite=False, even if it is incomplete.
--><div class="viewcode-block" id="get_cifar10"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_cifar10">[docs]</a><span class="k">def</span> <span class="nf">get_cifar10</span><span class="p">():</span>
<span class="sd">"""Downloads CIFAR10 dataset into a directory in the current directory with the name `data`,</span>
<span class="sd"> and then extracts all files into the directory `data/cifar`.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="s2">"data"</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/cifar/train.rec'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/cifar/test.rec'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/cifar/train.lst'</span><span class="p">))</span> <span class="ow">or</span> \
<span class="p">(</span><span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'data/cifar/test.lst'</span><span class="p">)):</span>
<span class="n">zip_file_path</span> <span class="o">=</span> <span class="n">download</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/mxnet/data/cifar10.zip'</span><span class="p">,</span>
<span class="n">dirname</span><span class="o">=</span><span class="s1">'data'</span><span class="p">)</span>
<span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">zip_file_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">zf</span><span class="p">:</span>
<span class="n">zf</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.get_mnist_iterator. The code shown must stay identical to the library source, so issues are flagged here and should be fixed in mxnet/test_utils.py.
Purpose: make sure the MNIST ubyte files are present (get_mnist_ubyte, defined elsewhere in the module), then build a (train, val) pair of mx.io.MNISTIter.
Parameters: batch_size, input_shape, num_parts and part_index are forwarded unchanged to both iterators.
flat is False exactly when len(input_shape) == 3 and True otherwise (equivalent to len(input_shape) != 3).
Only the training iterator passes shuffle=True; the validation iterator does not pass shuffle at all.
The image and label paths are relative ("data/..."), so the result depends on the current working directory.
Returns: the tuple (train_dataiter, val_dataiter).
Docstring gap: the upstream docstring documents none of the parameters or the return value.
--><div class="viewcode-block" id="get_mnist_iterator"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_mnist_iterator">[docs]</a><span class="k">def</span> <span class="nf">get_mnist_iterator</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">,</span> <span class="n">num_parts</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">part_index</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="sd">"""Returns training and validation iterators for MNIST dataset</span>
<span class="sd"> """</span>
<span class="n">get_mnist_ubyte</span><span class="p">()</span>
<span class="n">flat</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="k">else</span> <span class="kc">True</span> <span class="c1"># pylint: disable=simplifiable-if-expression</span>
<span class="n">train_dataiter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">MNISTIter</span><span class="p">(</span>
<span class="n">image</span><span class="o">=</span><span class="s2">"data/train-images-idx3-ubyte"</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s2">"data/train-labels-idx1-ubyte"</span><span class="p">,</span>
<span class="n">input_shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">flat</span><span class="o">=</span><span class="n">flat</span><span class="p">,</span>
<span class="n">num_parts</span><span class="o">=</span><span class="n">num_parts</span><span class="p">,</span>
<span class="n">part_index</span><span class="o">=</span><span class="n">part_index</span><span class="p">)</span>
<span class="n">val_dataiter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">MNISTIter</span><span class="p">(</span>
<span class="n">image</span><span class="o">=</span><span class="s2">"data/t10k-images-idx3-ubyte"</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s2">"data/t10k-labels-idx1-ubyte"</span><span class="p">,</span>
<span class="n">input_shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">flat</span><span class="o">=</span><span class="n">flat</span><span class="p">,</span>
<span class="n">num_parts</span><span class="o">=</span><span class="n">num_parts</span><span class="p">,</span>
<span class="n">part_index</span><span class="o">=</span><span class="n">part_index</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">train_dataiter</span><span class="p">,</span> <span class="n">val_dataiter</span><span class="p">)</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.get_zip_data. Fixes belong in mxnet/test_utils.py, not in this generated page.
Purpose: download url into data_dir unless data_dir/data_origin_name already exists, then extract the archive into data_dir.
Extraction runs on every call, even when the archive (and its contents) were already present.
Defect: the zipfile.ZipFile object is never closed (no with-statement, no close()), so the handle stays open until garbage collection. The snippet just above this block already uses "with zipfile.ZipFile(...) as zf:"; this function should do the same.
NOTE(review): download() is called without fname, so the saved file name is presumably derived from url. If that name differs from data_origin_name, the ZipFile open fails. Confirm against download().
--><div class="viewcode-block" id="get_zip_data"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_zip_data">[docs]</a><span class="k">def</span> <span class="nf">get_zip_data</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">url</span><span class="p">,</span> <span class="n">data_origin_name</span><span class="p">):</span>
<span class="sd">"""Download and extract zip data.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data_dir : str</span>
<span class="sd"> Absolute or relative path of the directory name to store zip files</span>
<span class="sd"> url : str</span>
<span class="sd"> URL to download data from</span>
<span class="sd"> data_origin_name : str</span>
<span class="sd"> Name of the downloaded zip file</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> >>> get_zip_data("data_dir",</span>
<span class="sd"> "http://files.grouplens.org/datasets/movielens/ml-10m.zip",</span>
<span class="sd"> "ml-10m.zip")</span>
<span class="sd"> """</span>
<span class="n">data_origin_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">data_origin_name</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">data_origin_name</span><span class="p">):</span>
<span class="n">download</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">dirname</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">zip_file</span> <span class="o">=</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">data_origin_name</span><span class="p">)</span>
<span class="n">zip_file</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">path</span><span class="o">=</span><span class="n">data_dir</span><span class="p">)</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.get_bz2_data. Fixes belong in mxnet/test_utils.py, not in this generated page.
Purpose: if data_dir/data_name is missing, do three things in order:
1. Download the archive to data_dir/data_origin_name.
2. Decompress it line by line into data_dir/data_name.
3. Delete the archive.
Defect: bz2.BZ2File is opened without a with-statement. If opening the output file or any write raises, bz_file.close() is skipped and the handle leaks. A fix in the source: "with bz2.BZ2File(data_origin_name, 'rb') as bz_file, open(data_name, 'wb') as fout:" and then shutil.copyfileobj(bz_file, fout).
Defect: open(data_name, 'wb') creates the output before decompression. If decompression fails part way, a truncated data_name is left behind, and the exists() check skips it on every later call.
NOTE(review): fname passed to download() is already joined with data_dir, and dirname=data_dir is passed as well. Confirm that download() does not join them a second time.
Docstring typo: "Name of the downloaded b2 file" should read "bz2 file".
--><div class="viewcode-block" id="get_bz2_data"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_bz2_data">[docs]</a><span class="k">def</span> <span class="nf">get_bz2_data</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">data_name</span><span class="p">,</span> <span class="n">url</span><span class="p">,</span> <span class="n">data_origin_name</span><span class="p">):</span>
<span class="sd">"""Download and extract bz2 data.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data_dir : str</span>
<span class="sd"> Absolute or relative path of the directory name to store bz2 files</span>
<span class="sd"> data_name : str</span>
<span class="sd"> Name of the output file in which bz2 contents will be extracted</span>
<span class="sd"> url : str</span>
<span class="sd"> URL to download data from</span>
<span class="sd"> data_origin_name : str</span>
<span class="sd"> Name of the downloaded b2 file</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> >>> get_bz2_data("data_dir", "kdda.t",</span>
<span class="sd"> "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",</span>
<span class="sd"> "kdda.t.bz2")</span>
<span class="sd"> """</span>
<span class="n">data_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">data_name</span><span class="p">)</span>
<span class="n">data_origin_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">data_origin_name</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">data_name</span><span class="p">):</span>
<span class="n">download</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">fname</span><span class="o">=</span><span class="n">data_origin_name</span><span class="p">,</span> <span class="n">dirname</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">bz_file</span> <span class="o">=</span> <span class="n">bz2</span><span class="o">.</span><span class="n">BZ2File</span><span class="p">(</span><span class="n">data_origin_name</span><span class="p">,</span> <span class="s1">'rb'</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_name</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">fout</span><span class="p">:</span>
<span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">bz_file</span><span class="p">:</span>
<span class="n">fout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">line</span><span class="p">)</span>
<span class="n">bz_file</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
<span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">data_origin_name</span><span class="p">)</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.set_env_var.
Purpose: set os.environ[key] = val and return the previous value, or default_val when key was not set.
NOTE(review): a caller that restores the old state with set_env_var(key, prev) sets the variable to default_val ("" by default) instead of removing it when it was originally unset.
val must be a str; os.environ raises TypeError for other types.
--><div class="viewcode-block" id="set_env_var"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.set_env_var">[docs]</a><span class="k">def</span> <span class="nf">set_env_var</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="s2">""</span><span class="p">):</span>
<span class="sd">"""Set environment variable</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> key : str</span>
<span class="sd"> Env var to set</span>
<span class="sd"> val : str</span>
<span class="sd"> New value assigned to the env var</span>
<span class="sd"> default_val : str, optional</span>
<span class="sd"> Default value returned if the env var doesn't exist</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> str</span>
<span class="sd"> The value of env var before it is set to the new value</span>
<span class="sd"> """</span>
<span class="n">prev_val</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">default_val</span><span class="p">)</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
<span class="k">return</span> <span class="n">prev_val</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.same_array. Fixes belong in mxnet/test_utils.py.
Purpose: heuristically decide whether two NDArrays alias the same memory, by writing through array1 and checking whether array2 sees the write.
Steps:
1. Add 1 to array1 in place.
2. If array2 no longer matches array1, undo the write and return False.
3. Otherwise undo the write and return whether the two arrays match again.
Side effect: array1, and array2 when aliased, is modified in place and then restored.
NOTE(review): the restore relies on (x + 1) - 1 == x. That does not hold exactly in floating point; for example 1e-20 becomes 0.0. For such data the function permanently alters array1.
NOTE(review): same() is defined elsewhere in this module and is presumably an exact element-wise equality check. Confirm.
Docstring grammar: "Check whether two NDArrays sharing the same memory block" should read "share".
--><div class="viewcode-block" id="same_array"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.same_array">[docs]</a><span class="k">def</span> <span class="nf">same_array</span><span class="p">(</span><span class="n">array1</span><span class="p">,</span> <span class="n">array2</span><span class="p">):</span>
<span class="sd">"""Check whether two NDArrays sharing the same memory block</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> array1 : NDArray</span>
<span class="sd"> First NDArray to be checked</span>
<span class="sd"> array2 : NDArray</span>
<span class="sd"> Second NDArray to be checked</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> bool</span>
<span class="sd"> Whether two NDArrays share the same memory</span>
<span class="sd"> """</span>
<span class="n">array1</span><span class="p">[:]</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">same</span><span class="p">(</span><span class="n">array1</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">array2</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()):</span>
<span class="n">array1</span><span class="p">[:]</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="n">array1</span><span class="p">[:]</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">same</span><span class="p">(</span><span class="n">array1</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">array2</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.discard_stderr. Fixes belong in mxnet/test_utils.py.
Purpose: a context manager that points the process-level stderr file descriptor at os.devnull for the duration of the with-block, then restores it with os.dup2.
Defect: the descriptor returned by os.dup(stderr_fileno) is never closed, so each use leaks one fd. It should be closed with os.close(old_stderr) after the dup2 restore.
Defect: the "except AttributeError" clause also encloses the first yield. An AttributeError raised inside the caller's with-block is therefore caught here, and the generator yields a second time. contextlib then raises RuntimeError("generator didn't stop after throw()") instead of the caller's error.
NOTE(review): on Python 3, sys.stderr.fileno() on a stream without a real descriptor (for example io.StringIO) raises io.UnsupportedOperation, not AttributeError. The fallback branch is likely not reached in that case. Confirm on the supported Python versions.
Comment wording: "On some systems is stderr not a file descriptor" reads better as "stderr is not".
--><span class="nd">@contextmanager</span>
<div class="viewcode-block" id="discard_stderr"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.discard_stderr">[docs]</a><span class="k">def</span> <span class="nf">discard_stderr</span><span class="p">():</span>
<span class="sd">"""</span>
<span class="sd"> Discards error output of a routine if invoked as:</span>
<span class="sd"> with discard_stderr():</span>
<span class="sd"> ...</span>
<span class="sd"> """</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">devnull</span><span class="p">,</span> <span class="s1">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="n">bit_bucket</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">stderr_fileno</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">fileno</span><span class="p">()</span>
<span class="n">old_stderr</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">dup</span><span class="p">(</span><span class="n">stderr_fileno</span><span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">os</span><span class="o">.</span><span class="n">dup2</span><span class="p">(</span><span class="n">bit_bucket</span><span class="o">.</span><span class="n">fileno</span><span class="p">(),</span> <span class="n">stderr_fileno</span><span class="p">)</span>
<span class="k">yield</span>
<span class="k">finally</span><span class="p">:</span>
<span class="n">os</span><span class="o">.</span><span class="n">dup2</span><span class="p">(</span><span class="n">old_stderr</span><span class="p">,</span> <span class="n">stderr_fileno</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">AttributeError</span><span class="p">:</span>
<span class="c1"># On some systems is stderr not a file descriptor but actually a virtual pipeline</span>
<span class="c1"># that can not be copied</span>
<span class="k">yield</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.DummyIter.
Purpose: a DataIter that takes one batch from real_iter at construction time and then returns that same batch from every next() call.
__init__ copies provide_data, provide_label and batch_size from real_iter, and consumes the first batch of real_iter via next(real_iter).
Per its docstring and body, next() never raises StopIteration. Assuming the base mx.io.DataIter routes __next__ to next(), a plain for-loop over this iterator never terminates, so callers must bound the number of steps.
reset() is not overridden here; presumably the base class provides it. Confirm.
Docstring formatting: "real_iter: mx.io.DataIter" lacks the space before the colon used by the other numpydoc entries ("real_iter : mx.io.DataIter").
--><div class="viewcode-block" id="DummyIter"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.DummyIter">[docs]</a><span class="k">class</span> <span class="nc">DummyIter</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span>
<span class="sd">"""A dummy iterator that always returns the same batch of data</span>
<span class="sd"> (the first data batch of the real data iter). This is usually used for speed testing.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> real_iter: mx.io.DataIter</span>
<span class="sd"> The real data iterator where the first batch of data comes from</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">real_iter</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">DummyIter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">real_iter</span> <span class="o">=</span> <span class="n">real_iter</span>
<span class="bp">self</span><span class="o">.</span><span class="n">provide_data</span> <span class="o">=</span> <span class="n">real_iter</span><span class="o">.</span><span class="n">provide_data</span>
<span class="bp">self</span><span class="o">.</span><span class="n">provide_label</span> <span class="o">=</span> <span class="n">real_iter</span><span class="o">.</span><span class="n">provide_label</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">real_iter</span><span class="o">.</span><span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">the_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">real_iter</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span>
<div class="viewcode-block" id="DummyIter.next"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.DummyIter.next">[docs]</a> <span class="k">def</span> <span class="nf">next</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sd">"""Get a data batch from iterator. The first data batch of real iter is always returned.</span>
<span class="sd"> StopIteration will never be raised.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> DataBatch</span>
<span class="sd"> The data of next batch.</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">the_batch</span></div></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.gen_buckets_probs_with_ppf.
Purpose: split a distribution into nbuckets equal-probability intervals using its quantile function.
probs is nbuckets copies of 1.0 / nbuckets. Bucket i is the pair (ppf(i / nbuckets), ppf((i + 1) / nbuckets)).
The outer edges are ppf(0.0) and ppf(1.0). For unbounded distributions these are presumably infinite (scipy.stats ppf returns plus or minus inf there), which is what lets the buckets cover the whole range.
Each interior edge is evaluated twice (2 * nbuckets ppf calls in total).
nbuckets must be positive; this is enforced only by an assert, which python -O strips.
Docstring wording: "size of the buckets" actually means the number of buckets, and "The generate probabilities" should read "generated".
--><div class="viewcode-block" id="gen_buckets_probs_with_ppf"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.gen_buckets_probs_with_ppf">[docs]</a><span class="k">def</span> <span class="nf">gen_buckets_probs_with_ppf</span><span class="p">(</span><span class="n">ppf</span><span class="p">,</span> <span class="n">nbuckets</span><span class="p">):</span>
<span class="sd">"""Generate the buckets and probabilities for chi_square test when the ppf (Quantile function)</span>
<span class="sd"> is specified.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> ppf : function</span>
<span class="sd"> The Quantile function that takes a probability and maps it back to a value.</span>
<span class="sd"> It's the inverse of the cdf function</span>
<span class="sd"> nbuckets : int</span>
<span class="sd"> size of the buckets</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> buckets : list of tuple</span>
<span class="sd"> The generated buckets</span>
<span class="sd"> probs : list</span>
<span class="sd"> The generate probabilities</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="n">nbuckets</span> <span class="o">></span> <span class="mi">0</span>
<span class="n">probs</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">nbuckets</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nbuckets</span><span class="p">)]</span>
<span class="n">buckets</span> <span class="o">=</span> <span class="p">[(</span><span class="n">ppf</span><span class="p">(</span><span class="n">i</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">nbuckets</span><span class="p">)),</span> <span class="n">ppf</span><span class="p">((</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">nbuckets</span><span class="p">)))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nbuckets</span><span class="p">)]</span>
<span class="k">return</span> <span class="n">buckets</span><span class="p">,</span> <span class="n">probs</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.mean_check.
Purpose: return True when the sample mean of generator(nsamples) lies strictly inside mu plus or minus 3 * sigma / sqrt(nsamples).
sigma is the population standard deviation of the generator, not the standard error; the function divides by sqrt(nsamples) itself.
By the central limit theorem, a correct generator fails this check with probability of roughly 0.27 percent (the two-sided 3-sigma tail), so occasional flaky failures are expected.
HTML fix: the "less than" operator on the second comparison line was emitted as a raw character, which is a tag-open parse error. It is now the entity &lt;, and the rendered text is unchanged.
--><div class="viewcode-block" id="mean_check"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.mean_check">[docs]</a><span class="k">def</span> <span class="nf">mean_check</span><span class="p">(</span><span class="n">generator</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma</span><span class="p">,</span> <span class="n">nsamples</span><span class="o">=</span><span class="mi">1000000</span><span class="p">):</span>
<span class="sd">"""Test the generator by matching the mean.</span>
<span class="sd"> We test the sample mean by checking if it falls inside the range</span>
<span class="sd"> (mu - 3 * sigma / sqrt(n), mu + 3 * sigma / sqrt(n))</span>
<span class="sd"> References::</span>
<span class="sd"> @incollection{goucher2009beautiful,</span>
<span class="sd"> title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},</span>
<span class="sd"> author={Goucher, Adam and Riley, Tim},</span>
<span class="sd"> year={2009},</span>
<span class="sd"> chapter=10</span>
<span class="sd"> }</span>
<span class="sd"> Examples::</span>
<span class="sd"> generator = lambda x: np.random.normal(0, 1.0, size=x)</span>
<span class="sd"> mean_check_ret = mean_check(generator, 0, 1.0)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> generator : function</span>
<span class="sd"> The generator function. It's expected to generate N i.i.d samples by calling generator(N).</span>
<span class="sd"> mu : float</span>
<span class="sd"> sigma : float</span>
<span class="sd"> nsamples : int</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> ret : bool</span>
<span class="sd"> Whether the mean test succeeds</span>
<span class="sd"> """</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">generator</span><span class="p">(</span><span class="n">nsamples</span><span class="p">))</span>
<span class="n">sample_mean</span> <span class="o">=</span> <span class="n">samples</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">(</span><span class="n">sample_mean</span> <span class="o">></span> <span class="n">mu</span> <span class="o">-</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">nsamples</span><span class="p">))</span> <span class="ow">and</span>\
<span class="p">(</span><span class="n">sample_mean</span> <span class="o">&lt;</span> <span class="n">mu</span> <span class="o">+</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">nsamples</span><span class="p">))</span>
<span class="k">return</span> <span class="n">ret</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.get_im2rec_path.
Purpose: locate the tools/im2rec.py script.
The base directory is os.environ[home_env] when that variable is set, otherwise the directory of the imported mxnet package.
Search order:
1. base/tools/im2rec.py (the pip-install layout, per the source comment).
2. base/../../tools/im2rec.py (a local build, per the source comment).
Raises IOError when neither file exists.
The returned path is built with os.path.join and is not normalized, so it may contain ".." components.
--><div class="viewcode-block" id="get_im2rec_path"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.get_im2rec_path">[docs]</a><span class="k">def</span> <span class="nf">get_im2rec_path</span><span class="p">(</span><span class="n">home_env</span><span class="o">=</span><span class="s2">"MXNET_HOME"</span><span class="p">):</span>
<span class="sd">"""Get path to the im2rec.py tool</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> home_env : str</span>
<span class="sd"> Env variable that holds the path to the MXNET folder</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> str</span>
<span class="sd"> The path to im2rec.py</span>
<span class="sd"> """</span>
<span class="c1"># Check first if the path to MXNET is passed as an env variable</span>
<span class="k">if</span> <span class="n">home_env</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">:</span>
<span class="n">mxnet_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="n">home_env</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Else use currently imported mxnet as reference</span>
<span class="n">mxnet_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="vm">__file__</span><span class="p">)</span>
<span class="c1"># If MXNet was installed through pip, the location of im2rec.py</span>
<span class="n">im2rec_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">mxnet_path</span><span class="p">,</span> <span class="s1">'tools'</span><span class="p">,</span> <span class="s1">'im2rec.py'</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">im2rec_path</span><span class="p">):</span>
<span class="k">return</span> <span class="n">im2rec_path</span>
<span class="c1"># If MXNet has been built locally</span>
<span class="n">im2rec_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">mxnet_path</span><span class="p">,</span> <span class="s1">'..'</span><span class="p">,</span> <span class="s1">'..'</span><span class="p">,</span> <span class="s1">'tools'</span><span class="p">,</span> <span class="s1">'im2rec.py'</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">im2rec_path</span><span class="p">):</span>
<span class="k">return</span> <span class="n">im2rec_path</span>
<span class="k">raise</span> <span class="ne">IOError</span><span class="p">(</span><span class="s1">'Could not find path to tools/im2rec.py'</span><span class="p">)</span></div>
<!-- REVIEW NOTE: rendered copy of mxnet.test_utils.var_check. Code and docstring issues belong in mxnet/test_utils.py.
Purpose: return True when the unbiased sample variance (ddof=1) of generator(nsamples) lies strictly inside sigma**2 plus or minus 3 * sqrt(2 * sigma**4 / (nsamples - 1)).
NOTE(review): the half-width uses Var(s^2) = 2 * sigma^4 / (n - 1), which holds for normally distributed samples. For other distributions the bound is only approximate.
nsamples must be greater than 1; with nsamples == 1 the bound divides by zero.
Docstring defect: the example "var_check(generator, 0, 1.0)" passes sigma=0 and nsamples=1.0. The intended call is var_check(generator, 1.0).
HTML fix: the "less than" operator on the second comparison line was emitted as a raw character, which is a tag-open parse error. It is now the entity &lt;, and the rendered text is unchanged.
--><div class="viewcode-block" id="var_check"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.var_check">[docs]</a><span class="k">def</span> <span class="nf">var_check</span><span class="p">(</span><span class="n">generator</span><span class="p">,</span> <span class="n">sigma</span><span class="p">,</span> <span class="n">nsamples</span><span class="o">=</span><span class="mi">1000000</span><span class="p">):</span>
<span class="sd">"""Test the generator by matching the variance.</span>
<span class="sd"> It will need a large number of samples and is not recommended to use</span>
<span class="sd"> We test the sample variance by checking if it falls inside the range</span>
<span class="sd"> (sigma^2 - 3 * sqrt(2 * sigma^4 / (n-1)), sigma^2 + 3 * sqrt(2 * sigma^4 / (n-1)))</span>
<span class="sd"> References::</span>
<span class="sd"> @incollection{goucher2009beautiful,</span>
<span class="sd"> title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},</span>
<span class="sd"> author={Goucher, Adam and Riley, Tim},</span>
<span class="sd"> year={2009},</span>
<span class="sd"> chapter=10</span>
<span class="sd"> }</span>
<span class="sd"> Examples::</span>
<span class="sd"> generator = lambda x: np.random.normal(0, 1.0, size=x)</span>
<span class="sd"> var_check_ret = var_check(generator, 0, 1.0)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> generator : function</span>
<span class="sd"> The generator function. It's expected to generate N i.i.d samples by calling generator(N).</span>
<span class="sd"> sigma : float</span>
<span class="sd"> nsamples : int</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> ret : bool</span>
<span class="sd"> Whether the variance test succeeds</span>
<span class="sd"> """</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">generator</span><span class="p">(</span><span class="n">nsamples</span><span class="p">))</span>
<span class="n">sample_var</span> <span class="o">=</span> <span class="n">samples</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">ddof</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">(</span><span class="n">sample_var</span> <span class="o">></span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">-</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">4</span> <span class="o">/</span> <span class="p">(</span><span class="n">nsamples</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)))</span> <span class="ow">and</span>\
<span class="p">(</span><span class="n">sample_var</span> <span class="o">&lt;</span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">4</span> <span class="o">/</span> <span class="p">(</span><span class="n">nsamples</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">ret</span></div>
<div class="viewcode-block" id="chi_square_check"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.chi_square_check">[docs]</a><span class="k">def</span> <span class="nf">chi_square_check</span><span class="p">(</span><span class="n">generator</span><span class="p">,</span> <span class="n">buckets</span><span class="p">,</span> <span class="n">probs</span><span class="p">,</span> <span class="n">nsamples</span><span class="o">=</span><span class="mi">1000000</span><span class="p">):</span>
<span class="sd">"""Run the chi-square test for the generator. The generator can be both continuous and discrete.</span>
<span class="sd"> If the generator is continuous, the buckets should contain tuples of (range_min, range_max) \</span>
<span class="sd"> and the probs should be the corresponding ideal probability within the specific ranges. \</span>
<span class="sd"> Otherwise, the buckets should contain all the possible values generated over the discrete distribution and the \</span>
<span class="sd"> probs should be the ground-truth probabilities.</span>
<span class="sd"> Usually the user is required to specify the probs parameter.</span>
<span class="sd"> After obtaining the p value, we could further use the standard p > 0.05 (alpha) threshold to get \</span>
<span class="sd"> the final result.</span>
<span class="sd"> Examples::</span>
<span class="sd"> buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, 0, 1), 5)</span>
<span class="sd"> generator = lambda x: np.random.normal(0, 1.0, size=x)</span>
<span class="sd"> p = chi_square_check(generator=generator, buckets=buckets, probs=probs)</span>
<span class="sd"> assert(p > 0.05)</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> generator: function</span>
<span class="sd"> A function that is assumed to generate i.i.d samples from a specific distribution.</span>
<span class="sd"> generator(N) should generate N random samples.</span>
<span class="sd"> buckets: list of tuple or list of number</span>
<span class="sd"> The buckets to run the chi-square test. Make sure that the buckets cover</span>
<span class="sd"> the whole range of the distribution. Also, the buckets must be in ascending order and have</span>
<span class="sd"> no intersection.</span>
<span class="sd"> probs: list or tuple</span>
<span class="sd"> The ground-truth probability of the random value falling in a specific bucket.</span>
<span class="sd"> nsamples: int</span>
<span class="sd"> The number of samples to generate for the testing</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> p : float</span>
<span class="sd"> p value that the generator has the expected distribution.</span>
<span class="sd"> A higher value indicates a larger confidence</span>
<span class="sd"> obs_freq : list</span>
<span class="sd"> Observed frequency of buckets</span>
<span class="sd"> expected_freq : list</span>
<span class="sd"> The expected (ground-truth) frequency of the buckets</span>
<span class="sd"> Raises</span>
<span class="sd"> ------</span>
<span class="sd"> ImportError</span>
<span class="sd"> If scipy is not available.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ss</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span><span class="s2">"scipy is not available."</span>
<span class="s2">" Please check if the scipy python bindings are installed."</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">buckets</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">nsamples</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">probs</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">buckets</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">buckets</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="c1"># Check whether the buckets are valid and fill them into a npy array</span>
<span class="n">continuous_dist</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">buckets_npy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">buckets</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">buckets</span><span class="p">):</span>
<span class="k">assert</span><span class="p">(</span><span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o"><=</span> <span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="n">buckets</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span><span class="p">(</span><span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o"><=</span> <span class="n">buckets</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
<span class="n">buckets_npy</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="n">buckets_npy</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">continuous_dist</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Expected counts are truncated to int32, so their total may fall slightly below nsamples</span>
<span class="n">expected_freq</span> <span class="o">=</span> <span class="p">(</span><span class="n">nsamples</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">continuous_dist</span><span class="p">:</span>
<span class="n">sample_bucket_ids</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">buckets_npy</span><span class="p">,</span> <span class="n">samples</span><span class="p">,</span> <span class="n">side</span><span class="o">=</span><span class="s1">'right'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">sample_bucket_ids</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span>
<span class="k">if</span> <span class="n">continuous_dist</span><span class="p">:</span>
<span class="c1"># searchsorted over the flattened [lo_0, hi_0, lo_1, hi_1, ...] edges yields 2*i+1 when lo_i <= x < hi_i,</span>
<span class="c1"># so halving recovers the bucket index. Samples outside the buckets are miscounted (below lo_0 -> bucket 0,</span>
<span class="c1"># inside a gap -> the next bucket, >= the last hi -> dropped), hence the full-coverage requirement.</span>
<span class="n">sample_bucket_ids</span> <span class="o">=</span> <span class="n">sample_bucket_ids</span> <span class="o">//</span> <span class="mi">2</span>
<span class="c1"># Builtin int instead of the np.int alias (deprecated in NumPy 1.20, removed in 1.24); same dtype.</span>
<span class="n">obs_freq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">buckets</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">buckets</span><span class="p">):</span>
<span class="k">if</span> <span class="n">continuous_dist</span><span class="p">:</span>
<span class="n">obs_freq</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">sample_bucket_ids</span> <span class="o">==</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">obs_freq</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">sample_bucket_ids</span> <span class="o">==</span> <span class="n">buckets</span><span class="p">[</span><span class="n">i</span><span class="p">])</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="n">_</span><span class="p">,</span> <span class="n">p</span> <span class="o">=</span> <span class="n">ss</span><span class="o">.</span><span class="n">chisquare</span><span class="p">(</span><span class="n">f_obs</span><span class="o">=</span><span class="n">obs_freq</span><span class="p">,</span> <span class="n">f_exp</span><span class="o">=</span><span class="n">expected_freq</span><span class="p">)</span>
<span class="k">return</span> <span class="n">p</span><span class="p">,</span> <span class="n">obs_freq</span><span class="p">,</span> <span class="n">expected_freq</span></div>
<div class="viewcode-block" id="verify_generator"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.verify_generator">[docs]</a><span class="k">def</span> <span class="nf">verify_generator</span><span class="p">(</span><span class="n">generator</span><span class="p">,</span> <span class="n">buckets</span><span class="p">,</span> <span class="n">probs</span><span class="p">,</span> <span class="n">nsamples</span><span class="o">=</span><span class="mi">1000000</span><span class="p">,</span> <span class="n">nrepeat</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">success_rate</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.05</span><span class="p">):</span>
<span class="sd">"""Verify whether the generator is correct using chi-square testing.</span>
<span class="sd"> The test is repeated "nrepeat" times. A repeat succeeds when its p value is greater</span>
<span class="sd"> than "alpha", and the check passes when at least "nrepeat * success_rate" repeats</span>
<span class="sd"> succeed (20% of the repeats by default).</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> generator: function</span>
<span class="sd"> A function that is assumed to generate i.i.d samples from a specific distribution.</span>
<span class="sd"> generator(N) should generate N random samples.</span>
<span class="sd"> buckets: list of tuple or list of number</span>
<span class="sd"> The buckets to run the chi-square test. Make sure that the buckets cover</span>
<span class="sd"> the whole range of the distribution. Also, the buckets must be in ascending order and</span>
<span class="sd"> have no intersection</span>
<span class="sd"> probs: list or tuple</span>
<span class="sd"> The ground-truth probability of the random value falling in a specific bucket.</span>
<span class="sd"> nsamples: int</span>
<span class="sd"> The number of samples to generate for the testing</span>
<span class="sd"> nrepeat: int</span>
<span class="sd"> The number of times to repeat the test</span>
<span class="sd"> success_rate: float</span>
<span class="sd"> The minimum fraction of repeats that must succeed</span>
<span class="sd"> alpha: float</span>
<span class="sd"> The desired threshold for type-I error i.e. when a true null hypothesis is rejected</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> cs_ret_l: list</span>
<span class="sd"> The p values of the chi-square test.</span>
<span class="sd"> Raises</span>
<span class="sd"> ------</span>
<span class="sd"> AssertionError</span>
<span class="sd"> If fewer than nrepeat * success_rate repeats have p > alpha. The message lists the</span>
<span class="sd"> p values, observed/expected frequencies, buckets and probs of every repeat.</span>
<span class="sd"> """</span>
<span class="n">cs_ret_l</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">obs_freq_l</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">expected_freq_l</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nrepeat</span><span class="p">):</span>
<span class="n">cs_ret</span><span class="p">,</span> <span class="n">obs_freq</span><span class="p">,</span> <span class="n">expected_freq</span> <span class="o">=</span> <span class="n">chi_square_check</span><span class="p">(</span><span class="n">generator</span><span class="o">=</span><span class="n">generator</span><span class="p">,</span> <span class="n">buckets</span><span class="o">=</span><span class="n">buckets</span><span class="p">,</span>
<span class="n">probs</span><span class="o">=</span><span class="n">probs</span><span class="p">,</span> <span class="n">nsamples</span><span class="o">=</span><span class="n">nsamples</span><span class="p">)</span>
<span class="n">cs_ret_l</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cs_ret</span><span class="p">)</span>
<span class="n">obs_freq_l</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">obs_freq</span><span class="p">)</span>
<span class="n">expected_freq_l</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">expected_freq</span><span class="p">)</span>
<span class="c1"># A repeat counts as a success when its p value exceeds alpha (null hypothesis not rejected)</span>
<span class="n">success_num</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cs_ret_l</span><span class="p">)</span> <span class="o">></span> <span class="n">alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="k">if</span> <span class="n">success_num</span> <span class="o"><</span> <span class="n">nrepeat</span> <span class="o">*</span> <span class="n">success_rate</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="s2">"Generator test fails, Chi-square p=</span><span class="si">%s</span><span class="s2">, obs_freq=</span><span class="si">%s</span><span class="s2">, expected_freq=</span><span class="si">%s</span><span class="s2">."</span>
<span class="s2">"</span><span class="se">\n</span><span class="s2">buckets=</span><span class="si">%s</span><span class="s2">, probs=</span><span class="si">%s</span><span class="s2">"</span>
<span class="o">%</span> <span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">cs_ret_l</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">obs_freq_l</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">expected_freq_l</span><span class="p">),</span>
<span class="nb">str</span><span class="p">(</span><span class="n">buckets</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">probs</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">cs_ret_l</span></div>
<div class="viewcode-block" id="compare_ndarray_tuple"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.compare_ndarray_tuple">[docs]</a><span class="k">def</span> <span class="nf">compare_ndarray_tuple</span><span class="p">(</span><span class="n">t1</span><span class="p">,</span> <span class="n">t2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sd">"""Compare ndarray tuple.</span>
<span class="sd"> Compares two NDArray-like values (anything with .asnumpy()), or two possibly nested</span>
<span class="sd"> tuples of them, using assert_almost_equal on the asnumpy() results.</span>
<span class="sd"> Nothing is checked when either argument is None. Whether to recurse is decided from</span>
<span class="sd"> t1 alone, and tuple elements are paired with zip, so surplus elements of the longer</span>
<span class="sd"> tuple are not compared.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> t1, t2 : NDArray-like, tuple or None</span>
<span class="sd"> The values to compare.</span>
<span class="sd"> rtol, atol : float or None</span>
<span class="sd"> Tolerances forwarded to assert_almost_equal.</span>
<span class="sd"> """</span>
<span class="c1"># NOTE(review): a None on only one side passes silently -- confirm this is intended</span>
<span class="k">if</span> <span class="n">t1</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">t2</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">t1</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
<span class="k">for</span> <span class="n">s1</span><span class="p">,</span> <span class="n">s2</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">t1</span><span class="p">,</span> <span class="n">t2</span><span class="p">):</span>
<span class="n">compare_ndarray_tuple</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">s2</span><span class="p">,</span> <span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">t1</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">t2</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">rtol</span><span class="o">=</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="n">atol</span><span class="p">)</span></div>
<div class="viewcode-block" id="compare_optimizer"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.compare_optimizer">[docs]</a><span class="k">def</span> <span class="nf">compare_optimizer</span><span class="p">(</span><span class="n">opt1</span><span class="p">,</span> <span class="n">opt2</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">w_stype</span><span class="o">=</span><span class="s1">'default'</span><span class="p">,</span> <span class="n">g_stype</span><span class="o">=</span><span class="s1">'default'</span><span class="p">,</span>
<span class="n">rtol</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">compare_states</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="sd">"""Compare opt1 and opt2.</span>
<span class="sd"> Builds a random weight and gradient of the given shape and dtype on default_context(),</span>
<span class="sd"> runs one update_multi_precision step (index 0) with each optimizer, and checks with</span>
<span class="sd"> assert_almost_equal that the updated weights match. If compare_states is True, the</span>
<span class="sd"> optimizer states are also compared, both right after creation and after the update.</span>
<span class="sd"> opt2 receives the weight/gradient in the requested storage types; opt1 always receives</span>
<span class="sd"> dense copies of the same values. Sparse weights are generated with density=1.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> opt1, opt2 : optimizer objects</span>
<span class="sd"> Must provide create_state_multi_precision and update_multi_precision</span>
<span class="sd"> (presumably mx.optimizer.Optimizer instances).</span>
<span class="sd"> shape : tuple of int</span>
<span class="sd"> Shape of the weight and gradient.</span>
<span class="sd"> dtype : numpy dtype</span>
<span class="sd"> Data type of the weight and gradient.</span>
<span class="sd"> w_stype, g_stype : str</span>
<span class="sd"> Storage type of opt2's weight and gradient: 'default', 'row_sparse' or 'csr'.</span>
<span class="sd"> rtol, atol : float</span>
<span class="sd"> Tolerances for the post-update comparisons of states and weights.</span>
<span class="sd"> compare_states : bool</span>
<span class="sd"> Whether to compare the optimizer states in addition to the weights.</span>
<span class="sd"> Raises</span>
<span class="sd"> ------</span>
<span class="sd"> Exception</span>
<span class="sd"> If w_stype or g_stype is not one of the supported storage types.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">w_stype</span> <span class="o">==</span> <span class="s1">'default'</span><span class="p">:</span>
<span class="n">w2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">default_context</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">w1</span> <span class="o">=</span> <span class="n">w2</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">default_context</span><span class="p">())</span>
<span class="k">elif</span> <span class="n">w_stype</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'row_sparse'</span><span class="p">,</span> <span class="s1">'csr'</span><span class="p">):</span>
<span class="n">w2</span> <span class="o">=</span> <span class="n">rand_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">w_stype</span><span class="p">,</span> <span class="n">density</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">w1</span> <span class="o">=</span> <span class="n">w2</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">default_context</span><span class="p">())</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s1">'default'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"type not supported yet"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">g_stype</span> <span class="o">==</span> <span class="s1">'default'</span><span class="p">:</span>
<span class="n">g2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">default_context</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">g1</span> <span class="o">=</span> <span class="n">g2</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">default_context</span><span class="p">())</span>
<span class="k">elif</span> <span class="n">g_stype</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'row_sparse'</span><span class="p">,</span> <span class="s1">'csr'</span><span class="p">):</span>
<span class="n">g2</span> <span class="o">=</span> <span class="n">rand_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">g_stype</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">g1</span> <span class="o">=</span> <span class="n">g2</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">default_context</span><span class="p">())</span><span class="o">.</span><span class="n">tostype</span><span class="p">(</span><span class="s1">'default'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">"type not supported yet"</span><span class="p">)</span>
<span class="n">state1</span> <span class="o">=</span> <span class="n">opt1</span><span class="o">.</span><span class="n">create_state_multi_precision</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w1</span><span class="p">)</span>
<span class="n">state2</span> <span class="o">=</span> <span class="n">opt2</span><span class="o">.</span><span class="n">create_state_multi_precision</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w2</span><span class="p">)</span>
<span class="k">if</span> <span class="n">compare_states</span><span class="p">:</span>
<span class="c1"># Freshly created states are compared with compare_ndarray_tuple's default tolerances</span>
<span class="n">compare_ndarray_tuple</span><span class="p">(</span><span class="n">state1</span><span class="p">,</span> <span class="n">state2</span><span class="p">)</span>
<span class="n">opt1</span><span class="o">.</span><span class="n">update_multi_precision</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w1</span><span class="p">,</span> <span class="n">g1</span><span class="p">,</span> <span class="n">state1</span><span class="p">)</span>
<span class="n">opt2</span><span class="o">.</span><span class="n">update_multi_precision</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w2</span><span class="p">,</span> <span class="n">g2</span><span class="p">,</span> <span class="n">state2</span><span class="p">)</span>
<span class="k">if</span> <span class="n">compare_states</span><span class="p">:</span>
<span class="n">compare_ndarray_tuple</span><span class="p">(</span><span class="n">state1</span><span class="p">,</span> <span class="n">state2</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="n">atol</span><span class="p">)</span>
<span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">w1</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">w2</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">rtol</span><span class="o">=</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="n">atol</span><span class="p">)</span></div>
<div class="viewcode-block" id="EnvManager"><a class="viewcode-back" href="../../api/python/tools/test_utils.html#mxnet.test_utils.EnvManager">[docs]</a><span class="k">class</span> <span class="nc">EnvManager</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="sd">"""Environment variable setter and unsetter via with idiom</span>
<span class="sd"> On entry, sets os.environ[key] to val and remembers the previous value. On exit,</span>
<span class="sd"> restores that previous value (including an empty string), or removes the variable</span>
<span class="sd"> if it was not set before entry.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> key : str</span>
<span class="sd"> Name of the environment variable.</span>
<span class="sd"> val : str</span>
<span class="sd"> Value the variable holds inside the with block.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_key</span> <span class="o">=</span> <span class="n">key</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_next_val</span> <span class="o">=</span> <span class="n">val</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_prev_val</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span> <span class="nf">__enter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_prev_val</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_key</span><span class="p">)</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_key</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_next_val</span>
<span class="k">def</span> <span class="nf">__exit__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ptype</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">trace</span><span class="p">):</span>
<span class="c1"># Test against None, not truthiness: a variable previously set to the empty</span>
<span class="c1"># string must be restored rather than deleted.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prev_val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_key</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prev_val</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># pop() tolerates the variable having been removed inside the with block</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_key</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
</pre></div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img alt="Apache Incubator logo" height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
Copyright © 2017-2018, The Apache Software Foundation.
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation.
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script src="../../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
// Reveal the page body once the DOM is ready (jQuery(fn) is the documented
// shorthand for registering a document-ready handler).
jQuery(function () {
  $('body').css('visibility', 'visible');
});
</script>
</body>
</html>