<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
--><!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>A Beginner’s Guide to Implementing Operators in MXNet Backend — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../_static/underscore.js" type="text/javascript"></script>
<script src="../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../_static/doctools.js" type="text/javascript"></script>
<script src="../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="../genindex.html" rel="index" title="Index">
<link href="../search.html" rel="search" title="Search"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background.png" role="document">
<div class="content-block">
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../install/index.html">Install</a>
<a class="main-nav-link" href="../tutorials/index.html">Tutorials</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../gluon/index.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">Tutorials</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../api/perl/index.html">Perl</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="../faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.0.0/example">Examples</a></li>
<li><a class="main-nav-link" href="../model_zoo/index.html">Model Zoo</a></li>
</ul>
</span>
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="../community/index.html">Community</a></li>
<li><a class="main-nav-link" href="../community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="../community/powered_by.html">Powered By</a></li>
</ul>
</span>
<a class="main-nav-link" href="http://discuss.mxnet.io">Discuss</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(1.0.0)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/>1.1.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="../install/index.html">Install</a></li>
<li><a class="main-nav-link" href="../tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="../community/index.html" tabindex="-1">Community</a></li>
<li><a href="../community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="../community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="../tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="../faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="../architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.0.0/example" tabindex="-1">Examples</a></li>
<li><a href="../model_zoo/index.html" tabindex="-1">Model Zoo</a></li>
</ul>
</li>
<li><a href="../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(1.0.0)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/>1.1.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/1.0.0/index.html>1.0.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/0.12.0/index.html>0.12.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/0.11.0/index.html>0.11.0</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorials/index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/index.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="a-beginner-s-guide-to-implementing-operators-in-mxnet-backend">
<span id="a-beginner-s-guide-to-implementing-operators-in-mxnet-backend"></span><h1>A Beginner’s Guide to Implementing Operators in MXNet Backend<a class="headerlink" href="#a-beginner-s-guide-to-implementing-operators-in-mxnet-backend" title="Permalink to this headline"></a></h1>
<div class="section" id="introduction">
<span id="introduction"></span><h2>Introduction<a class="headerlink" href="#introduction" title="Permalink to this headline"></a></h2>
<p>Operators are essential elements for constructing neural networks. They define the mathematical formulas
for transforming input data (tensors) into outputs. MXNet has a rich set of operators, from simple ones
such as element-wise sum to complicated ones such as convolution, that are
capable of constructing most of the popular neural networks. You may have noticed
that many operators implemented in MXNet have equivalent forms in NumPy, such as
<a class="reference external" href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html">repeat</a>,
<a class="reference external" href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html">tile</a>,
etc., and wondered why we could not simply use those NumPy operators in MXNet. One
major reason is that operators in MXNet need to support both CPU and GPU computing,
while NumPy operators do not possess GPU computing capability.
In addition, we have performed plenty of
optimizations for various components in MXNet, such as the tensor data structure (<code class="docutils literal"><span class="pre">NDArray</span></code>),
the execution engine, the computational graph, and so on, to maximize memory and runtime efficiency.
An operator implemented under the MXNet operator framework can
leverage those optimizations for better performance.</p>
<p>In this tutorial, we are going to practice implementing an operator using
C++ in the MXNet backend. After finishing the implementation,
we will add unit tests using Python for the operator we just implemented.</p>
</div>
<div class="section" id="implementation">
<span id="implementation"></span><h2>Implementation<a class="headerlink" href="#implementation" title="Permalink to this headline"></a></h2>
<div class="section" id="an-operator-example">
<span id="an-operator-example"></span><h3>An Operator Example<a class="headerlink" href="#an-operator-example" title="Permalink to this headline"></a></h3>
<p>Let’s take the <a class="reference external" href="https://en.wikipedia.org/wiki/Quadratic_function">quadratic function</a>
as an example: <code class="docutils literal"><span class="pre">f(x)</span> <span class="pre">=</span> <span class="pre">ax^2+bx+c</span></code>. We want to implement an operator called <code class="docutils literal"><span class="pre">quadratic</span></code>
that takes a tensor <code class="docutils literal"><span class="pre">x</span></code> as input and generates an output tensor <code class="docutils literal"><span class="pre">y</span></code>
satisfying <code class="docutils literal"><span class="pre">y.shape=x.shape</span></code>, where each element of <code class="docutils literal"><span class="pre">y</span></code> is calculated by feeding the
corresponding element of <code class="docutils literal"><span class="pre">x</span></code> into the quadratic function <code class="docutils literal"><span class="pre">f</span></code>.
Here, <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> are user input parameters.
In the frontend, the operator works like this:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">x</span> <span class="o">=</span> <span class="p">[[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]]</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">quadratic</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="p">[[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">11</span><span class="p">],</span> <span class="p">[</span><span class="mi">18</span><span class="p">,</span> <span class="mi">27</span><span class="p">]]</span>
</pre></div>
</div>
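<p>As a reference for what the operator should compute, the following is a minimal pure-Python sketch of the same
element-wise formula applied to nested lists. The helper name <code class="docutils literal"><span class="pre">quadratic_reference</span></code>
is hypothetical and is not part of MXNet; it only mirrors the frontend example above.</p>

```python
def quadratic_reference(data, a=0.0, b=0.0, c=0.0):
    """Apply f(x) = a*x^2 + b*x + c to every element of a nested list.

    Hypothetical reference helper (not an MXNet API): the output has the
    same nesting, and therefore the same shape, as the input.
    """
    if isinstance(data, (list, tuple)):
        # Recurse into each sub-list so the output keeps the input's shape.
        return [quadratic_reference(item, a, b, c) for item in data]
    # Scalar case: evaluate the quadratic function for one element.
    return a * data * data + b * data + c


x = [[1, 2], [3, 4]]
y = quadratic_reference(x, a=1, b=2, c=3)
print(y)  # [[6, 11], [18, 27]]
```

<p>A reference like this can also serve as the expected value when unit-testing the real operator later on.</p>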
<p>To implement this, we first create three files: <code class="docutils literal"><span class="pre">quadratic_op-inl.h</span></code>,
<code class="docutils literal"><span class="pre">quadratic_op.cc</span></code>, and <code class="docutils literal"><span class="pre">quadratic_op.cu</span></code>. The header file’s name
consists of the operator name followed by <code class="docutils literal"><span class="pre">op</span></code> and <code class="docutils literal"><span class="pre">-inl</span></code>,
indicating that it contains an operator implementation with inline
functions shared by CPU and GPU computing. The CPU- and GPU-specific
implementations reside in their own <code class="docutils literal"><span class="pre">.cc</span></code> and <code class="docutils literal"><span class="pre">.cu</span></code> files,
respectively. We normally put purely tensor-related operators
(e.g. <code class="docutils literal"><span class="pre">tile</span></code>, <code class="docutils literal"><span class="pre">repeat</span></code>, etc.) under
the directory <code class="docutils literal"><span class="pre">src/operator/tensor</span></code>, and neural network operators
(e.g. <code class="docutils literal"><span class="pre">Convolution</span></code>, <code class="docutils literal"><span class="pre">Pooling</span></code>, etc.) under <code class="docutils literal"><span class="pre">src/operator/nn</span></code>.
You may have noticed that many neural network operators, including
<code class="docutils literal"><span class="pre">Convolution</span></code> and <code class="docutils literal"><span class="pre">Pooling</span></code>, are currently saved under <code class="docutils literal"><span class="pre">src/operator</span></code>.
We plan to move them to <code class="docutils literal"><span class="pre">src/operator/nn</span></code> in the future for better file organization
and a clearer hierarchy.</p>
<p>Next, we are going to</p>
<ol class="simple">
<li>Define the parameter struct
for registering <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> in <code class="docutils literal"><span class="pre">quadratic_op-inl.h</span></code>.</li>
<li>Define type and shape inference functions in <code class="docutils literal"><span class="pre">quadratic_op-inl.h</span></code>.</li>
<li>Define forward and backward functions in <code class="docutils literal"><span class="pre">quadratic_op-inl.h</span></code>.</li>
<li>Register the operator using <a class="reference external" href="https://github.com/dmlc/nnvm">nnvm</a>
in <code class="docutils literal"><span class="pre">quadratic_op.cc</span></code> and <code class="docutils literal"><span class="pre">quadratic_op.cu</span></code> for
CPU and GPU computing, respectively.</li>
</ol>
<p>Now let’s walk through the process step by step.</p>
</div>
<div class="section" id="parameter-registration">
<span id="parameter-registration"></span><h3>Parameter Registration<a class="headerlink" href="#parameter-registration" title="Permalink to this headline"></a></h3>
<p>We first define <code class="docutils literal"><span class="pre">struct</span> <span class="pre">QuadraticParam</span></code> as a placeholder for the
parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> in <code class="docutils literal"><span class="pre">quadratic_op-inl.h</span></code>.
The struct inherits from a base template
struct named <code class="docutils literal"><span class="pre">dmlc::Parameter</span></code>, where the template argument is the derived struct
<code class="docutils literal"><span class="pre">QuadraticParam</span></code>. This technique, called the <a class="reference external" href="https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern">curiously recurring template
pattern</a>,
achieves static polymorphism. It is similar to using a virtual function,
but without the runtime cost associated with dynamic polymorphism.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">struct</span> <span class="nl">QuadraticParam</span> <span class="p">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span> <span class="p">{</span>
<span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">;</span>
<span class="n">DMLC_DECLARE_PARAMETER</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="p">)</span> <span class="p">{</span>
<span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Coefficient of the quadratic term in the quadratic function."</span><span class="p">);</span>
<span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Coefficient of the linear term in the quadratic function."</span><span class="p">);</span>
<span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">c</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Constant term in the quadratic function."</span><span class="p">);</span>
<span class="p">}</span>
<span class="p">};</span>
</pre></div>
</div>
<p>The names of the function calls in the above parameter struct are self-explanatory.
Note that for each parameter, we set the default value to <code class="docutils literal"><span class="pre">0.0</span></code> so that users can
skip passing zero-valued parameters through the quadratic operator interface. You
can choose not to define a default value for a parameter if it must be supplied
at runtime. Meanwhile, adding brief descriptions to the parameters enables
the documentation engine to display them on the
<a class="reference external" href="https://mxnet.incubator.apache.org/api/python/index.html">MXNet documentation web page</a>.</p>
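<p>To make the role of defaults and descriptions concrete, here is a rough Python analogue of the parameter struct.
It is only an illustration and not how <code class="docutils literal"><span class="pre">dmlc::Parameter</span></code> is implemented: the names
<code class="docutils literal"><span class="pre">QuadraticParamSketch</span></code> and <code class="docutils literal"><span class="pre">init_param</span></code> are hypothetical, and the sketch
assumes that parameter values arrive as string key-value pairs.</p>

```python
from dataclasses import dataclass, field, fields


@dataclass
class QuadraticParamSketch:
    """Hypothetical Python stand-in for the C++ QuadraticParam struct."""
    a: float = field(default=0.0, metadata={"doc": "Coefficient of the quadratic term."})
    b: float = field(default=0.0, metadata={"doc": "Coefficient of the linear term."})
    c: float = field(default=0.0, metadata={"doc": "Constant term."})


def init_param(kwargs):
    """Build the parameter object from string key-value pairs (an assumption).

    Unknown keys are rejected; omitted keys keep their default of 0.0.
    """
    known = {f.name for f in fields(QuadraticParamSketch)}
    unknown = set(kwargs) - known
    if unknown:
        raise ValueError("unknown parameter(s): %s" % sorted(unknown))
    return QuadraticParamSketch(**{k: float(v) for k, v in kwargs.items()})


param = init_param({"a": "1", "c": "3"})  # b is omitted and stays 0.0
print(param)
for f in fields(param):
    # The stored descriptions are what a documentation tool could display.
    print(f.name, "-", f.metadata["doc"])
```

<p>Dropping the default from a field in this sketch would make it a required argument, which corresponds to
not calling <code class="docutils literal"><span class="pre">set_default</span></code> in the C++ struct.</p>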
</div>
<div class="section" id="attribute-inference">
<span id="attribute-inference"></span><h3>Attribute Inference<a class="headerlink" href="#attribute-inference" title="Permalink to this headline"></a></h3>
<p>Attribute inference is the process of deducing the properties of <code class="docutils literal"><span class="pre">NDArray</span></code>s
in neural networks from user-provided information. The two most common attributes
of an <code class="docutils literal"><span class="pre">NDArray</span></code> are data shape and data type.
Let’s take a look at the following example.
Given an input <code class="docutils literal"><span class="pre">NDArray</span></code> called <code class="docutils literal"><span class="pre">data</span></code>, you invoke the <code class="docutils literal"><span class="pre">quadratic</span></code> operator
like this: <code class="docutils literal"><span class="pre">output</span> <span class="pre">=</span> <span class="pre">mx.nd.quadratic(data,</span> <span class="pre">a=1,</span> <span class="pre">b=2,</span> <span class="pre">c=3)</span></code>. Before the
<code class="docutils literal"><span class="pre">output</span></code> values are calculated, the output’s shape and data type are inferred from the input
<code class="docutils literal"><span class="pre">data</span></code>’s shape and type, following
the rules you defined, so that memory can be allocated for the output tensor.</p>
<p>One important thing to note is that inference functions should be capable of
performing <strong>mutual inference</strong>, i.e.
inferring one argument’s attribute from another argument’s attribute whenever
the definition of the operator makes that possible.
This is very useful for a computational graph to deduce unknown attributes
for a neural network in symbolic programming. Users can view the computational
graph as a symbol with every element initialized for running data
through the neural network, including memory allocation for each tensor,
device placement for each operator, etc. Users normally just need
to provide the minimum necessary information, such as input data shapes,
to the computational graph, and the graph will fill in the unknown attributes
using the attribute inference functions defined in the operators that make up
the neural network.</p>
<p>Let’s consider the following example.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="gp">>>> </span><span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'a'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="gp">>>> </span><span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'b'</span><span class="p">)</span>
<span class="gp">>>> </span><span class="n">c</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'c'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="gp">>>> </span><span class="n">d</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="n">b</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">c</span>
<span class="gp">>>> </span><span class="k">print</span><span class="p">(</span><span class="n">d</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">())</span>
<span class="go">([(2L, 3L), (2L, 3L), (2L, 3L)], [(2L, 3L)], [])</span>
</pre></div>
</div>
<p>The last line of the above code snippet is a tuple of three lists returned
by <code class="docutils literal"><span class="pre">d.infer_shape()</span></code>. The first list contains the argument shapes
of <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code>. The second contains the output shape of <code class="docutils literal"><span class="pre">d</span></code>. The
third represents the shapes of auxiliary states, which are not used
in this case, so it is empty.
In this example, we only specified values for variable <code class="docutils literal"><span class="pre">a</span></code>’s first dimension
and <code class="docutils literal"><span class="pre">c</span></code>’s second dimension. The <code class="docutils literal"><span class="pre">0</span></code> in shape <code class="docutils literal"><span class="pre">(2,</span> <span class="pre">0)</span></code> indicates that the size
of the second dimension is unknown, and likewise for the first dimension in shape <code class="docutils literal"><span class="pre">(0,</span> <span class="pre">3)</span></code>.
Nevertheless, the symbol <code class="docutils literal"><span class="pre">d</span></code> successfully inferred the shapes
of all the variables and the final output. This is a result of mutual
inference. In MXNet, the whole process can be interpreted as follows:</p>
<ol class="simple">
<li><code class="docutils literal"><span class="pre">a</span></code> and <code class="docutils literal"><span class="pre">b</span></code> are combined via an element-wise multiplication operator,
so the shapes of <code class="docutils literal"><span class="pre">a</span></code> and <code class="docutils literal"><span class="pre">b</span></code> are the same and <code class="docutils literal"><span class="pre">b</span></code>’s first dimension size is <code class="docutils literal"><span class="pre">2</span></code>.</li>
<li><code class="docutils literal"><span class="pre">b</span></code> and <code class="docutils literal"><span class="pre">c</span></code> are combined via an element-wise multiplication operator too,
so the shapes of <code class="docutils literal"><span class="pre">b</span></code> and <code class="docutils literal"><span class="pre">c</span></code> are the same and <code class="docutils literal"><span class="pre">b</span></code>’s second dimension size is <code class="docutils literal"><span class="pre">3</span></code>.</li>
<li>Now <code class="docutils literal"><span class="pre">b</span></code>’s shape is completely known, so the missing dimension sizes of <code class="docutils literal"><span class="pre">a</span></code> and <code class="docutils literal"><span class="pre">c</span></code>
are known as well.</li>
<li><code class="docutils literal"><span class="pre">d</span></code> is the result of adding <code class="docutils literal"><span class="pre">a</span> <span class="pre">*</span> <span class="pre">b</span></code> and <code class="docutils literal"><span class="pre">b</span> <span class="pre">*</span> <span class="pre">c</span></code>, so <code class="docutils literal"><span class="pre">d</span></code> should also
have the same shape as <code class="docutils literal"><span class="pre">b</span></code>.</li>
</ol>
<p>The above four steps illustrate how shape inference logic works in MXNet.
It is actually implemented in the shape inference functions of the operators for
element-wise multiplication and addition.</p>
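<p>The four steps above can be sketched in standalone C++ without any MXNet headers. This is only an illustration under stated assumptions: the alias <code class="docutils literal"><span class="pre">Shape</span></code> and the helpers <code class="docutils literal"><span class="pre">AssignKnownDims</span></code>, <code class="docutils literal"><span class="pre">InferElemwiseBinary</span></code>, and <code class="docutils literal"><span class="pre">InferQuadraticExample</span></code> are hypothetical names that do not exist in MXNet, the number of dimensions of every shape is assumed to be known up front, and <code class="docutils literal"><span class="pre">0</span></code> marks an unknown dimension size as described above.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// A shape is a list of dimension sizes; 0 marks an unknown size,
// following the convention described in the text.
using Shape = std::vector<int>;

// Hypothetical helper: copy every known dimension of src into dst.
// Throws if both sides know a dimension but disagree on its size.
// Returns true if dst changed.
inline bool AssignKnownDims(Shape* dst, const Shape& src) {
  assert(dst->size() == src.size());
  bool changed = false;
  for (std::size_t i = 0; i < src.size(); ++i) {
    if (src[i] == 0) continue;  // nothing to learn from src here
    if ((*dst)[i] == 0) {
      (*dst)[i] = src[i];
      changed = true;
    } else if ((*dst)[i] != src[i]) {
      throw std::runtime_error("incompatible shapes");
    }
  }
  return changed;
}

// Element-wise binary operator: both inputs and the output share one
// shape, so information flows in every direction between the three.
inline bool InferElemwiseBinary(Shape* lhs, Shape* rhs, Shape* out) {
  bool changed = false;
  changed |= AssignKnownDims(out, *lhs);
  changed |= AssignKnownDims(out, *rhs);
  changed |= AssignKnownDims(lhs, *out);
  changed |= AssignKnownDims(rhs, *out);
  return changed;
}

// Infer all shapes for d = a * b + b * c, repeating until nothing changes.
inline Shape InferQuadraticExample(Shape* a, Shape* b, Shape* c) {
  Shape ab(a->size(), 0), bc(a->size(), 0), d(a->size(), 0);
  bool changed = true;
  while (changed) {
    changed = false;
    changed |= InferElemwiseBinary(a, b, &ab);    // a * b
    changed |= InferElemwiseBinary(b, c, &bc);    // b * c
    changed |= InferElemwiseBinary(&ab, &bc, &d); // sum of the products
  }
  return d;
}
```

<p>Starting from the shapes <code class="docutils literal"><span class="pre">(2,</span> <span class="pre">0)</span></code>, <code class="docutils literal"><span class="pre">(0,</span> <span class="pre">0)</span></code>, and <code class="docutils literal"><span class="pre">(0,</span> <span class="pre">3)</span></code> for <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code>, the loop in this sketch fills in <code class="docutils literal"><span class="pre">(2,</span> <span class="pre">3)</span></code> for all three variables and for the result.</p>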
<p>For our <code class="docutils literal"><span class="pre">quadratic</span></code> operator, shape inference follows quite similar logic.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">QuadraticOpShape</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">>*</span> <span class="n">in_attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">>*</span> <span class="n">out_attrs</span><span class="p">)</span> <span class="p">{</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">in_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">out_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">SHAPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">out_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">in_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span>
<span class="n">SHAPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span>
<span class="k">return</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0U</span> <span class="o">&amp;&amp;</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">Size</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0U</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>Here are a few things to note about the above function:</p>
<ol class="simple">
<li><code class="docutils literal"><span class="pre">attrs</span></code> contains parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> from user input.
It’s not used here since we don’t rely on that information for shape inference.</li>
<li><code class="docutils literal"><span class="pre">in_attrs</span></code> is a vector containing all input shapes. Since there is
only one input argument for operator <code class="docutils literal"><span class="pre">quadratic</span></code>, we used the macro <code class="docutils literal"><span class="pre">CHECK_EQ</span></code>
to assert that the vector’s size is <code class="docutils literal"><span class="pre">1</span></code>.</li>
<li><code class="docutils literal"><span class="pre">out_attrs</span></code> is a vector containing all output shapes. We also used
<code class="docutils literal"><span class="pre">CHECK_EQ</span></code> to verify the size of the vector since there is only one output.</li>
<li>We called the macro <code class="docutils literal"><span class="pre">SHAPE_ASSIGN_CHECK</span></code> twice for mutual inference: once to
infer the output shape from the input shape, and once to infer
the input shape from the output shape.
If the two shapes have unequal non-zero values in the same
dimension, such as <code class="docutils literal"><span class="pre">(2,</span> <span class="pre">3)</span></code> and <code class="docutils literal"><span class="pre">(3,</span> <span class="pre">3)</span></code>, the macro throws an
exception with an error message for shape inference.</li>
<li>At the end of the function body, we checked whether the output shape
is completely known by testing whether the shape is not empty and
the shape’s size is greater than <code class="docutils literal"><span class="pre">0</span></code>. Note that in MXNet, an empty shape
means that the shape is unknown, and
a <code class="docutils literal"><span class="pre">0</span></code> in a shape means that the size of that dimension is unknown. In both
situations, the missing shape information must
be inferred from other shapes. If it cannot be inferred,
the function should return <code class="docutils literal"><span class="pre">false</span></code> to notify the caller about shape inference failure.</li>
<li>MXNet provides a convenience function implementing the logic of mutual inference
for general element-wise operators with the following interface. Users can
instantiate this function with <code class="docutils literal"><span class="pre">n_in=1</span></code> and <code class="docutils literal"><span class="pre">n_out=1</span></code> to replace the above
function <code class="docutils literal"><span class="pre">QuadraticOpShape</span></code> in operator registration (explained later).
The function <code class="docutils literal"><span class="pre">QuadraticOpShape</span></code> posted here is for the purpose of illustration only.</li>
</ol>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">n_in</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n_out</span><span class="o">></span>
<span class="kr">inline</span> <span class="kt">bool</span> <span class="n">ElemwiseShape</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">*</span><span class="n">out_attrs</span><span class="p">);</span>
</pre></div>
</div>
<p>The same logic goes for data type inference. We leave the analysis of
the following code sample to the reader. Note that <code class="docutils literal"><span class="pre">-1</span></code> means the data type
is unknown and must be inferred from other input or output data types.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">QuadraticOpType</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">in_attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">out_attrs</span><span class="p">)</span> <span class="p">{</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">in_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">out_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">TYPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">out_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">in_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span>
<span class="n">TYPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span>
<span class="k">return</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>Again, MXNet provides the following convenience function for mutual
type inference of element-wise operators. Users can use that
in operator registration (explained later).</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">n_in</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n_out</span><span class="o">></span>
<span class="kr">inline</span> <span class="kt">bool</span> <span class="n">ElemwiseType</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">in_attrs</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">out_attrs</span><span class="p">);</span>
</pre></div>
</div>
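<p>To make the mutual type inference concrete, here is a minimal standalone sketch of what a helper with <code class="docutils literal"><span class="pre">n_in</span></code> inputs and <code class="docutils literal"><span class="pre">n_out</span></code> outputs might do. It is not MXNet’s implementation: the name <code class="docutils literal"><span class="pre">ElemwiseTypeSketch</span></code> and the type flag constants are hypothetical, and only the convention that <code class="docutils literal"><span class="pre">-1</span></code> means unknown is taken from the text above.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// -1 marks an unknown data type, as in the text. The other flag values
// are illustrative placeholders, not MXNet's actual type flags.
constexpr int kUnknownType = -1;
constexpr int kFloat32Sketch = 0;
constexpr int kFloat64Sketch = 1;

// Hypothetical stand-in for a mutual type inference helper: all n_in
// inputs and n_out outputs must end up with one common data type.
template <int n_in, int n_out>
bool ElemwiseTypeSketch(std::vector<int>* in_types,
                        std::vector<int>* out_types) {
  assert(in_types->size() == static_cast<std::size_t>(n_in));
  assert(out_types->size() == static_cast<std::size_t>(n_out));
  int deduced = kUnknownType;
  // Pass 1: collect the single known type and reject conflicts.
  auto collect = [&deduced](const std::vector<int>& types) {
    for (int t : types) {
      if (t == kUnknownType) continue;
      if (deduced == kUnknownType) {
        deduced = t;
      } else if (deduced != t) {
        throw std::runtime_error("incompatible data types");
      }
    }
  };
  collect(*in_types);
  collect(*out_types);
  // Pass 2: write the deduced type back to every input and output.
  for (int& t : *in_types) t = deduced;
  for (int& t : *out_types) t = deduced;
  // Inference succeeds only if some type was actually known.
  return deduced != kUnknownType;
}
```

<p>With one unknown input and one known output, the sketch copies the output type to the input and returns <code class="docutils literal"><span class="pre">true</span></code>; if every entry is <code class="docutils literal"><span class="pre">-1</span></code>, it returns <code class="docutils literal"><span class="pre">false</span></code> so that the caller knows inference has not completed.</p>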
</div>
<div class="section" id="forward-function">
<span id="forward-function"></span><h3>Forward Function<a class="headerlink" href="#forward-function" title="Permalink to this headline"></a></h3>
<p>The forward function defines the operator’s behavior in the forward pass
of neural networks. For our <code class="docutils literal"><span class="pre">quadratic</span></code> operator, it simply implements
the logic of running a tensor through the quadratic function by performing
a few element-wise operations. The forward function’s signature is fixed
in MXNet as follows:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="kt">void</span> <span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">outputs</span><span class="p">);</span>
</pre></div>
</div>
<p>We first paste the whole forward function code here
and then go through it line by line.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span> <span class="c1">// 1</span>
<span class="kt">void</span> <span class="n">QuadraticOpForward</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span> <span class="c1">// 2</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span> <span class="c1">// 3</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">inputs</span><span class="p">,</span> <span class="c1">// 4</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&amp;</span> <span class="n">req</span><span class="p">,</span> <span class="c1">// 5</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 6</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 7</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 8</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 9</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span> <span class="c1">// 10</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">in_data</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 11</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">out_data</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 12</span>
<span class="k">const</span> <span class="n">QuadraticParam</span><span class="o">&amp;</span> <span class="n">param</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">get</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">(</span><span class="n">attrs</span><span class="p">.</span><span class="n">parsed</span><span class="p">);</span> <span class="c1">// 13</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet_op</span><span class="p">;</span> <span class="c1">// 14</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">out_data</span><span class="p">.</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 15</span>
<span class="n">MXNET_ASSIGN_REQ_SWITCH</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req_type</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 16</span>
<span class="n">Kernel</span><span class="o"><</span><span class="n">quadratic_forward</span><span class="o"><</span><span class="n">req_type</span><span class="o">></span><span class="p">,</span> <span class="n">xpu</span><span class="o">>::</span><span class="n">Launch</span><span class="p">(</span> <span class="c1">// 17</span>
<span class="n">s</span><span class="p">,</span> <span class="n">out_data</span><span class="p">.</span><span class="n">Size</span><span class="p">(),</span> <span class="n">out_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">in_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="c1">// 18</span>
<span class="n">param</span><span class="p">.</span><span class="n">a</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">b</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">c</span><span class="p">);</span> <span class="c1">// 19</span>
<span class="p">});</span> <span class="c1">// 20</span>
<span class="p">});</span> <span class="c1">// 21</span>
<span class="p">}</span> <span class="c1">// 22</span>
</pre></div>
</div>
<ul class="simple">
<li>Line 1: <code class="docutils literal"><span class="pre">xpu</span></code> stands for a generic device type so that the function can be instantiated
for both CPU and GPU computing using concrete types <code class="docutils literal"><span class="pre">cpu</span></code> and <code class="docutils literal"><span class="pre">gpu</span></code>. The instantiation happens
at the time when the operator is registered in <code class="docutils literal"><span class="pre">.cc</span></code> and <code class="docutils literal"><span class="pre">.cu</span></code> files.</li>
<li>Line 2: <code class="docutils literal"><span class="pre">attrs</span></code> is a node attribute containing the user input parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code>.
Here the node represents a placeholder for the operator in the whole computational graph for
the neural network.</li>
<li>Line 3: <code class="docutils literal"><span class="pre">ctx</span></code> holds a <code class="docutils literal"><span class="pre">stream</span></code>, which
serializes asynchronous executions. Consider the following
example of what a <code class="docutils literal"><span class="pre">stream</span></code> does.
Suppose we launch several GPU kernels on the same <code class="docutils literal"><span class="pre">stream</span></code> from the CPU.
Even though each launch is non-blocking, the <code class="docutils literal"><span class="pre">stream</span></code> guarantees
that the kernels execute on the GPU in the same order in which they were launched from the CPU.</li>
<li>Line 4: <code class="docutils literal"><span class="pre">inputs</span></code> is a vector of input tensors (only one input tensor
for the <code class="docutils literal"><span class="pre">quadratic</span></code> operator).</li>
<li>Line 5: <code class="docutils literal"><span class="pre">req</span></code> is a vector of <code class="docutils literal"><span class="pre">OpReqType</span></code> values. Each value defines
how the calculated values are written to the corresponding output tensor.
Therefore, the number of <code class="docutils literal"><span class="pre">req</span></code>s must be the same as the number of output tensors.
MXNet currently supports three types of <code class="docutils literal"><span class="pre">req</span></code> in the frontend: <code class="docutils literal"><span class="pre">null</span></code>, <code class="docutils literal"><span class="pre">write</span></code>, and <code class="docutils literal"><span class="pre">add</span></code>.
<code class="docutils literal"><span class="pre">null</span></code> means skipping the calculation of the corresponding output tensor,
<code class="docutils literal"><span class="pre">write</span></code> means overwriting the values in the output tensor with the ones
calculated by this operator, and <code class="docutils literal"><span class="pre">add</span></code> means adding the calculated values
to the existing ones in the output tensor. Note that <code class="docutils literal"><span class="pre">null</span></code> and <code class="docutils literal"><span class="pre">add</span></code> are usually
seen in backward passes. The former skips calculating
the gradients of un-learnable parameters (such as index arrays),
and the latter accumulates gradients throughout the network.</li>
<li>Line 6: <code class="docutils literal"><span class="pre">outputs</span></code> is a vector of output tensors (only one
output tensor for the <code class="docutils literal"><span class="pre">quadratic</span></code> operator).</li>
<li>Lines 7-9: Verify that each vector has the expected size.
Otherwise, abort execution and print an error message.</li>
<li>Line 10: Get the <code class="docutils literal"><span class="pre">stream</span></code> from the <code class="docutils literal"><span class="pre">ctx</span></code> for launching kernels.</li>
<li>Lines 11-12: Define references to the input and output tensors
for later coding convenience. Note that <code class="docutils literal"><span class="pre">TBlob</span></code> can be understood
as a uniform data structure for tensors of various dimensions, such
that tensors of different dimensions can be put in a homogeneous container,
such as <code class="docutils literal"><span class="pre">std::vector</span></code> and <code class="docutils literal"><span class="pre">std::list</span></code>. You can still
get tensors of desired dimensions from a <code class="docutils literal"><span class="pre">TBlob</span></code> object through
the interface <code class="docutils literal"><span class="pre">get_with_shape</span></code>.</li>
<li>Line 13: Get user input parameters from the node attribute.</li>
<li>Lines 15-21: This is the place where the mathematical formula of the operator
is implemented. The macros <code class="docutils literal"><span class="pre">MSHADOW_TYPE_SWITCH</span></code> and <code class="docutils literal"><span class="pre">MXNET_ASSIGN_REQ_SWITCH</span></code> enable
the code block to work for all the supported data types and <code class="docutils literal"><span class="pre">req</span></code> types in MXNet.
Inside the inner-most macro, we launch the kernel for calculating
the output tensor such that each thread takes an element from
the input tensor, feeds it into the quadratic function, and assigns
the output element to the output tensor based on <code class="docutils literal"><span class="pre">req</span></code> type. Note that
<code class="docutils literal"><span class="pre">Kernel::Launch</span></code> serves as a universal interface for launching
parallel computation on both CPU and GPU. This allows most of
the simple operators to share the same piece of code for CPU and GPU as
parallelization approaches are often identical on both types of devices.
The kernel function is defined as follows, where the function
<code class="docutils literal"><span class="pre">Map</span></code> is executed by each thread for each input element. A few more words
on the two macros used in the kernel struct: (1) <code class="docutils literal"><span class="pre">MSHADOW_XINLINE</span></code> is
a consolidated macro for inlining functions compiled by both CPU and GPU
compilers. It enables CPU and GPU computing to share the same piece of code.
(2) <code class="docutils literal"><span class="pre">KERNEL_ASSIGN</span></code> is a macro for unifying the statements of different <code class="docutils literal"><span class="pre">req</span></code>s
into the same line of code. It’s named <code class="docutils literal"><span class="pre">KERNEL_ASSIGN</span></code> because we refer to
the code blocks that run parallel computation as kernels.
On CPUs, the kernels are normally wrapped by the OpenMP <code class="docutils literal"><span class="pre">parallel</span></code> directive,
while on GPUs, they are the kernel functions launched through the CUDA library.</li>
</ul>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">req</span><span class="o">></span>
<span class="k">struct</span> <span class="n">quadratic_forward</span> <span class="p">{</span>
<span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="o">></span>
<span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="kt">void</span> <span class="n">Map</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="p">,</span> <span class="n">DType</span><span class="o">*</span> <span class="n">out_data</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">b</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">c</span><span class="p">)</span> <span class="p">{</span>
<span class="n">KERNEL_ASSIGN</span><span class="p">(</span><span class="n">out_data</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">req</span><span class="p">,</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span> <span class="o">*</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span> <span class="o">+</span> <span class="n">c</span><span class="p">);</span>
<span class="p">}</span>
<span class="p">};</span>
</pre></div>
</div>
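<p>The interplay between the <code class="docutils literal"><span class="pre">req</span></code> switch, the launcher, and the kernel struct can be imitated in a few lines of standalone C++. The sketch below is a serial, CPU-only analogue and makes no claim about how MXNet implements its macros: <code class="docutils literal"><span class="pre">ReqSketch</span></code>, <code class="docutils literal"><span class="pre">AssignSketch</span></code>, <code class="docutils literal"><span class="pre">LaunchSketch</span></code>, <code class="docutils literal"><span class="pre">QuadraticForwardSketch</span></code>, and <code class="docutils literal"><span class="pre">QuadraticSketch</span></code> are all hypothetical names introduced for illustration.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for the three request types described in the text.
enum ReqSketch { kNullSketch, kWriteSketch, kAddSketch };

// Plays the role of KERNEL_ASSIGN: one statement covers every request type.
template <int req, typename DType>
inline void AssignSketch(DType* out, DType val) {
  if (req == kWriteSketch) {
    *out = val;
  } else if (req == kAddSketch) {
    *out += val;
  }
  // kNullSketch: leave the output untouched.
}

template <int req>
struct QuadraticForwardSketch {
  template <typename DType>
  static void Map(int i, DType* out, const DType* in,
                  float a, float b, float c) {
    // Horner form of a*x^2 + b*x + c, as in the kernel above.
    AssignSketch<req>(&out[i],
                      static_cast<DType>(in[i] * (a * in[i] + b) + c));
  }
};

// Serial stand-in for a kernel launcher: calls OP::Map once per element.
template <typename OP, typename... Args>
inline void LaunchSketch(int n, Args... args) {
  for (int i = 0; i < n; ++i) OP::Map(i, args...);
}

// Turns the run-time request value into a compile-time template argument,
// which is the job the request switch macro performs in the text.
template <typename DType>
void QuadraticSketch(ReqSketch req, const std::vector<DType>& in,
                     std::vector<DType>* out, float a, float b, float c) {
  assert(in.size() == out->size());
  const int n = static_cast<int>(in.size());
  switch (req) {
    case kNullSketch:
      break;  // skip the computation entirely
    case kWriteSketch:
      LaunchSketch<QuadraticForwardSketch<kWriteSketch>>(
          n, out->data(), in.data(), a, b, c);
      break;
    case kAddSketch:
      LaunchSketch<QuadraticForwardSketch<kAddSketch>>(
          n, out->data(), in.data(), a, b, c);
      break;
  }
}
```

<p>Because the request type is a template argument, the branch inside <code class="docutils literal"><span class="pre">AssignSketch</span></code> is decided at compile time for each instantiation, so the per-element loop carries no run-time dispatch on <code class="docutils literal"><span class="pre">req</span></code>.</p>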
</div>
<div class="section" id="backward-function">
<span id="backward-function"></span><h3>Backward Function<a class="headerlink" href="#backward-function" title="Permalink to this headline"></a></h3>
<p>Backward functions propagate the derivatives of the loss function
with respect to the outputs of the last layer back through the network to the first
layer. The whole process is often known as backward propagation. We are not
going to explain the principle of backward propagation here since readers can find
it covered in great detail in other resources, such as
<a class="reference external" href="http://cs231n.github.io/optimization-2/">CS231n</a> and
<a class="reference external" href="http://neuralnetworksanddeeplearning.com/chap2.html">How the backpropagation algorithm works</a>.
The problem we are going to solve here for the <code class="docutils literal"><span class="pre">quadratic</span></code> operator is that
given a tensor representing the gradient of the loss function with respect
to the output of the operator, calculate the gradient with respect to
the input of the operator. There is no need to calculate the derivatives
of the loss function with respect to the user input parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code>
since they are not learnable parameters in the network. To formulate the problem:
given <code class="docutils literal"><span class="pre">dL/dy</span></code> and <code class="docutils literal"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">a*x^2</span> <span class="pre">+</span> <span class="pre">b*x</span> <span class="pre">+</span> <span class="pre">c</span></code>, where <code class="docutils literal"><span class="pre">L</span></code> represents the loss function and
<code class="docutils literal"><span class="pre">y</span></code> stands for the output of the <code class="docutils literal"><span class="pre">quadratic</span></code> operator, we need to solve for
<code class="docutils literal"><span class="pre">dL/dx</span></code>. Using the chain rule, we find that</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">dL</span><span class="o">/</span><span class="n">dx</span> <span class="o">=</span> <span class="n">dL</span><span class="o">/</span><span class="n">dy</span> <span class="o">*</span> <span class="n">dy</span><span class="o">/</span><span class="n">dx</span> <span class="o">=</span> <span class="n">dL</span><span class="o">/</span><span class="n">dy</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">a</span><span class="o">*</span><span class="n">x</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span><span class="o">.</span>
</pre></div>
</div>
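<p>A quick way to gain confidence in this chain-rule result is to compare it with a finite-difference estimate. The standalone sketch below does so for a single element. All function names are hypothetical, and the loss is assumed to be <code class="docutils literal"><span class="pre">L(y)</span> <span class="pre">=</span> <span class="pre">out_grad</span> <span class="pre">*</span> <span class="pre">y</span></code> purely so that <code class="docutils literal"><span class="pre">dL/dy</span></code> equals a chosen constant.</p>

```cpp
#include <cassert>
#include <cmath>

// Forward pass of the quadratic function for a single element.
inline double QuadraticForward(double x, double a, double b, double c) {
  return x * (a * x + b) + c;
}

// Analytic input gradient from the chain rule:
// dL/dx = dL/dy * (2*a*x + b), with out_grad standing for dL/dy.
inline double QuadraticInputGrad(double out_grad, double x,
                                 double a, double b) {
  return out_grad * (2.0 * a * x + b);
}

// Central finite-difference estimate of dL/dx for the assumed loss
// L(y) = out_grad * y, whose derivative dL/dy is exactly out_grad.
inline double NumericInputGrad(double out_grad, double x, double a,
                               double b, double c, double eps = 1e-4) {
  const double loss_plus = out_grad * QuadraticForward(x + eps, a, b, c);
  const double loss_minus = out_grad * QuadraticForward(x - eps, a, b, c);
  return (loss_plus - loss_minus) / (2.0 * eps);
}

// Returns true when the analytic and numeric gradients agree within tol.
inline bool GradientsAgree(double out_grad, double x, double a, double b,
                           double c, double tol = 1e-6) {
  assert(tol > 0.0);
  const double analytic = QuadraticInputGrad(out_grad, x, a, b);
  const double numeric = NumericInputGrad(out_grad, x, a, b, c);
  return std::fabs(analytic - numeric) <= tol;
}
```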
<p>The chain-rule result shows that <code class="docutils literal"><span class="pre">dL/dx</span></code> depends on the gradient
of the output tensor and the value of the input tensor.
The backward function’s signature is the same as the forward function’s.
With this information in mind,
let’s break down the following backward function line by line.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span> <span class="c1">// 1</span>
<span class="kt">void</span> <span class="n">QuadraticOpBackward</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span> <span class="c1">// 2</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span> <span class="c1">// 3</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">inputs</span><span class="p">,</span> <span class="c1">// 4</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&amp;</span> <span class="n">req</span><span class="p">,</span> <span class="c1">// 5</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&amp;</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 6</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">2U</span><span class="p">);</span> <span class="c1">// 7</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 8</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 9</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span> <span class="c1">// 10</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">out_grad</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 11</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">in_data</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span> <span class="c1">// 12</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">in_grad</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 13</span>
<span class="k">const</span> <span class="n">QuadraticParam</span><span class="o">&amp;</span> <span class="n">param</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">get</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">(</span><span class="n">attrs</span><span class="p">.</span><span class="n">parsed</span><span class="p">);</span> <span class="c1">// 14</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet_op</span><span class="p">;</span> <span class="c1">// 15</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">out_grad</span><span class="p">.</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 16</span>
<span class="n">MXNET_ASSIGN_REQ_SWITCH</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req_type</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 17</span>
<span class="n">Kernel</span><span class="o"><</span><span class="n">quadratic_backward</span><span class="o"><</span><span class="n">req_type</span><span class="o">></span><span class="p">,</span> <span class="n">xpu</span><span class="o">>::</span><span class="n">Launch</span><span class="p">(</span> <span class="c1">// 18</span>
<span class="n">s</span><span class="p">,</span> <span class="n">in_grad</span><span class="p">.</span><span class="n">Size</span><span class="p">(),</span> <span class="n">in_grad</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">out_grad</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="c1">// 19</span>
<span class="n">in_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">param</span><span class="p">.</span><span class="n">a</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">b</span><span class="p">);</span> <span class="c1">// 20</span>
<span class="p">});</span> <span class="c1">// 21</span>
<span class="p">});</span> <span class="c1">// 22</span>
<span class="p">}</span> <span class="c1">// 23</span>
</pre></div>
</div>
<ul class="simple">
<li>Lines 1-6: The backward function has the same signature as the forward function.</li>
<li>Lines 7-9: Check the sizes of the function arguments. Note that
since the gradient of the input depends on both the gradient of the output and
the input tensor itself, <code class="docutils literal"><span class="pre">inputs</span></code> must contain two <code class="docutils literal"><span class="pre">TBlob</span></code> objects.</li>
<li>Line 10: Get the <code class="docutils literal"><span class="pre">stream</span></code> of the context for serializing asynchronous executions.</li>
<li>Lines 11-13: Convenience reference variables for later use. We name <code class="docutils literal"><span class="pre">out_grad</span></code>
as the gradient of the operator output, <code class="docutils literal"><span class="pre">in_data</span></code> as the input of the operator,
and <code class="docutils literal"><span class="pre">in_grad</span></code> as the gradient of the operator input.</li>
<li>Line 14: Get the parameter object of <code class="docutils literal"><span class="pre">QuadraticParam</span></code>.</li>
<li>Lines 16-22: As in the forward function, this is where the parallel
computation of <code class="docutils literal"><span class="pre">in_grad</span></code> happens. The struct <code class="docutils literal"><span class="pre">quadratic_backward</span></code> implements
the formula for calculating each element of <code class="docutils literal"><span class="pre">in_grad</span></code>, one element per thread, as follows.</li>
</ul>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">req</span><span class="o">></span>
<span class="k">struct</span> <span class="n">quadratic_backward</span> <span class="p">{</span>
<span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="o">></span>
<span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="kt">void</span> <span class="n">Map</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="p">,</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_data</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="n">KERNEL_ASSIGN</span><span class="p">(</span><span class="n">in_grad</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">req</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">));</span>
<span class="p">}</span>
<span class="p">};</span>
</pre></div>
</div>
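<p>The role of the <code class="docutils literal"><span class="pre">req</span></code> template argument can be sketched in plain Python.
The sketch below assumes that <code class="docutils literal"><span class="pre">KERNEL_ASSIGN</span></code> skips the write for a null request, overwrites for the two write requests,
and accumulates for an add request. The constant names mirror MXNet's <code class="docutils literal"><span class="pre">OpReqType</span></code> values, but their numeric values and
the helper functions here are illustrative only, and the loop is a serial stand-in for the parallel kernel launch.</p>

```python
# Request types, named after MXNet's OpReqType values (numbers are illustrative).
K_NULL_OP, K_WRITE_TO, K_WRITE_INPLACE, K_ADD_TO = range(4)


def kernel_assign(out, i, req, val):
    """Mimic KERNEL_ASSIGN(out[i], req, val) for one element."""
    if req == K_NULL_OP:
        return                      # nothing is written
    if req in (K_WRITE_TO, K_WRITE_INPLACE):
        out[i] = val                # overwrite the existing value
    elif req == K_ADD_TO:
        out[i] += val               # accumulate into the existing gradient


def launch_quadratic_backward(req, in_grad, out_grad, in_data, a, b):
    """Serial stand-in for launching quadratic_backward over all elements."""
    for i in range(len(in_grad)):
        kernel_assign(in_grad, i, req, out_grad[i] * (2 * a * in_data[i] + b))


in_data = [1.0, 2.0]
out_grad = [1.0, 1.0]

written = [100.0, 100.0]
launch_quadratic_backward(K_WRITE_TO, written, out_grad, in_data, 1.0, 2.0)

accumulated = [100.0, 100.0]
launch_quadratic_backward(K_ADD_TO, accumulated, out_grad, in_data, 1.0, 2.0)
```

<p>With a write request the previous contents of <code class="docutils literal"><span class="pre">in_grad</span></code> are replaced, while an add request adds the new gradient to whatever is already stored.</p>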
</div>
<div class="section" id="operator-registration">
<span id="operator-registration"></span><h3>Operator Registration<a class="headerlink" href="#operator-registration" title="Permalink to this headline"></a></h3>
<p>So far, we have implemented the necessary data structures and functions for the operator <code class="docutils literal"><span class="pre">quadratic</span></code>.
Now let’s register them using <code class="docutils literal"><span class="pre">nnvm</span></code> to expose the operator <code class="docutils literal"><span class="pre">quadratic</span></code>
to the frontend. Users can think of the registration process as creating the operator object
instance, saving it in the operator manager (a singleton),
and setting attributes for the operator instance.</p>
<p>The following code is from <code class="docutils literal"><span class="pre">quadratic_op.cc</span></code>, which is responsible
for registering the operator working on CPU.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">DMLC_REGISTER_PARAMETER</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="p">);</span> <span class="c1">// 1</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">quadratic</span><span class="p">)</span> <span class="c1">// 2</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="sa">R</span><span class="s">"</span><span class="dl">code(</span><span class="s">This operator implements the quadratic function: // 3</span>
<span class="s">.. math::</span>
<span class="s"> f(x) = ax^2+bx+c</span>
<span class="s">where :math:`x` is an input tensor and all operations</span>
<span class="s">in the function are element-wise.</span>
<span class="s">Example::</span>
<span class="s"> x = [[1, 2], [3, 4]]</span>
<span class="s"> y = quadratic(data=x, a=1, b=2, c=3)</span>
<span class="s"> y = [[6, 11], [18, 27]]</span>
<span class="dl">)code</span><span class="s">"</span> <span class="n">ADD_FILELINE</span><span class="p">)</span> <span class="c1">// 4</span>
<span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">)</span> <span class="c1">// 5</span>
<span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 6</span>
<span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 7</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FListInputNames</span><span class="o">></span><span class="p">(</span><span class="s">"FListInputNames"</span><span class="p">,</span> <span class="c1">// 8</span>
<span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 9</span>
<span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">></span><span class="p">{</span><span class="s">"data"</span><span class="p">};</span> <span class="c1">// 10</span>
<span class="p">})</span> <span class="c1">// 11</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">QuadraticOpShape</span><span class="p">)</span> <span class="c1">// 12</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">QuadraticOpType</span><span class="p">)</span> <span class="c1">// 13</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">QuadraticOpForward</span><span class="o"><</span><span class="n">cpu</span><span class="o">></span><span class="p">)</span> <span class="c1">// 14</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FGradient</span><span class="o">></span><span class="p">(</span><span class="s">"FGradient"</span><span class="p">,</span> <span class="n">ElemwiseGradUseIn</span><span class="p">{</span><span class="s">"_backward_quadratic"</span><span class="p">})</span> <span class="c1">// 15</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> <span class="c1">// 16</span>
<span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 17</span>
<span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> <span class="c1">// 18</span>
<span class="p">})</span> <span class="c1">// 19</span>
<span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"data"</span><span class="p">,</span> <span class="s">"NDArray-or-Symbol"</span><span class="p">,</span> <span class="s">"Input ndarray"</span><span class="p">)</span> <span class="c1">// 20</span>
<span class="p">.</span><span class="n">add_arguments</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="o">::</span><span class="n">__FIELDS__</span><span class="p">());</span> <span class="c1">// 21</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_quadratic</span><span class="p">)</span> <span class="c1">// 22</span>
<span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">)</span> <span class="c1">// 23</span>
<span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="c1">// 24</span>
<span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 25</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">TIsBackward</span><span class="o">></span><span class="p">(</span><span class="s">"TIsBackward"</span><span class="p">,</span> <span class="nb">true</span><span class="p">)</span> <span class="c1">// 26</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">QuadraticOpBackward</span><span class="o"><</span><span class="n">cpu</span><span class="o">></span><span class="p">);</span> <span class="c1">// 27</span>
</pre></div>
</div>
<ul class="simple">
<li>Line 1: Register the parameter struct.</li>
<li>Line 2: Register an operator named <code class="docutils literal"><span class="pre">quadratic</span></code> by creating an instance
of type <code class="docutils literal"><span class="pre">Op</span></code>, saving it in the operator manager, and returning a reference
to the newly created operator object.</li>
<li>Lines 3-4: Add a description as an operator attribute,
including examples of the operator. The documentation engine extracts
this description and displays it on the documentation web page.</li>
<li>Line 5: Set the parameter struct parser for the operator. It is used for parsing
the parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> passed in from the frontend.</li>
<li>Line 6: Set the number of inputs for the operator.</li>
<li>Line 7: Set the number of outputs for the operator.</li>
<li>Lines 8-11: Define a function that generates a vector of names of
the operator input arguments. This function is used to add missing
arguments that users did not specify when creating a symbolic operator.
For example, <code class="docutils literal"><span class="pre">quad_func=mx.sym.quadratic()</span></code> is still a valid symbol
since we have added the attribute <code class="docutils literal"><span class="pre">FListInputNames</span></code> to the operator node
in the computational graph. MXNet would
add the missing argument with the name <code class="docutils literal"><span class="pre">quadratic0_data</span></code>, where the prefix
<code class="docutils literal"><span class="pre">quadratic0</span></code> is the operator name appended with an index and the postfix
<code class="docutils literal"><span class="pre">data</span></code> comes from the return value of the user-defined <code class="docutils literal"><span class="pre">FListInputNames</span></code> function.
Users can still generate an executor for <code class="docutils literal"><span class="pre">quad_func</span></code> like the following:</li>
</ul>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">quad_exe</span> <span class="o">=</span> <span class="n">quad_func</span><span class="o">.</span><span class="n">simple_bind</span><span class="p">(</span><span class="n">ctx</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">quadratic0_data</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,))</span>
</pre></div>
</div>
<ul class="simple">
<li>Line 12: Register shape inference function.</li>
<li>Line 13: Register type inference function.</li>
<li>Line 14: Register forward function.</li>
<li>Line 15: Register the function for creating the node of the operator in
a backward pass. Note that we used a convenience functor struct <code class="docutils literal"><span class="pre">ElemwiseGradUseIn</span></code>.
As the name suggests, the registered functor creates the node for gradient computation
with dependencies on the output gradient node and the input node. Similarly, there are
three other functors, <code class="docutils literal"><span class="pre">ElemwiseGradUseOut</span></code>, <code class="docutils literal"><span class="pre">ElemwiseGradUseInOut</span></code>,
and <code class="docutils literal"><span class="pre">ElemwiseGradUseNone</span></code>, for developers’ convenience. For
this attribute to work, we also need to register a backward operator for <code class="docutils literal"><span class="pre">quadratic</span></code>. It needs only
several basic attributes, as it can share attribute inference
functions with the forward operator and is not exposed to the frontend.</li>
<li>Lines 16-19: This registered function specifies which output tensor can reuse
which input tensor’s memory space instead of allocating new memory for the output.
In the operator <code class="docutils literal"><span class="pre">quadratic</span></code>, there is only one input and one output, and the output can reuse
the input memory space, so we store a pair of zeros in the returned vector,
indicating that <code class="docutils literal"><span class="pre">inputs[0]</span></code>’s memory space can be reused by <code class="docutils literal"><span class="pre">outputs[0]</span></code>.
Note that this function only provides a hint to the computational graph initializer.
If other nodes depend on the input tensor, the memory space
of the input tensor will not be overwritten by the output.</li>
<li>Line 20: Define the input argument name as <code class="docutils literal"><span class="pre">data</span></code> for the operator.</li>
<li>Line 21: Add user input parameters <code class="docutils literal"><span class="pre">a</span></code>, <code class="docutils literal"><span class="pre">b</span></code>, and <code class="docutils literal"><span class="pre">c</span></code> as the attributes of the operator.</li>
<li>Line 22: Register an operator named <code class="docutils literal"><span class="pre">_backward_quadratic</span></code> for the backward pass
of the operator <code class="docutils literal"><span class="pre">quadratic</span></code>. The underscore prefix in the operator name indicates
that this operator is not exposed to users. By convention,
an internally used backward operator is named by prepending the prefix <code class="docutils literal"><span class="pre">_backward_</span></code>
to the corresponding forward operator name.</li>
<li>Line 23: Set the parameter parser for the operator <code class="docutils literal"><span class="pre">_backward_quadratic</span></code>.</li>
<li>Line 24: Set the number of inputs.</li>
<li>Line 25: Set the number of outputs.</li>
<li>Line 26: Add <code class="docutils literal"><span class="pre">TIsBackward</span></code> attribute for the operator. The shape and type
inference passes use this attribute to determine whether a node in the graph is a
forward or backward node.</li>
<li>Line 27: Register backward function.</li>
</ul>
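<p>To make the difference between the four gradient functors concrete, the sketch below models how each one
could assemble the input list of the backward node. It is a simplified illustration under the assumption that
output gradients always come first, followed by the forward inputs and/or outputs; the function name
<code class="docutils literal"><span class="pre">backward_inputs</span></code> and the string placeholders are hypothetical and not part of MXNet.</p>

```python
def backward_inputs(functor, out_grads, fwd_inputs, fwd_outputs):
    """Return the ordered input list that the backward node would receive."""
    extra = {
        "ElemwiseGradUseNone": [],
        "ElemwiseGradUseIn": fwd_inputs,
        "ElemwiseGradUseOut": fwd_outputs,
        "ElemwiseGradUseInOut": fwd_inputs + fwd_outputs,
    }[functor]
    # Output gradients come first, then whatever the functor depends on.
    return out_grads + extra


# For quadratic: one output gradient, one forward input, one forward output.
use_in = backward_inputs("ElemwiseGradUseIn", ["out_grad"], ["in_data"], ["out_data"])
use_none = backward_inputs("ElemwiseGradUseNone", ["out_grad"], ["in_data"], ["out_data"])
```

<p>For <code class="docutils literal"><span class="pre">ElemwiseGradUseIn</span></code> the result has two entries, which matches the backward function above reading
<code class="docutils literal"><span class="pre">out_grad</span></code> from <code class="docutils literal"><span class="pre">inputs[0]</span></code> and <code class="docutils literal"><span class="pre">in_data</span></code> from <code class="docutils literal"><span class="pre">inputs[1]</span></code>,
and with <code class="docutils literal"><span class="pre">set_num_inputs(2)</span></code> in the registration of <code class="docutils literal"><span class="pre">_backward_quadratic</span></code>.</p>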
<p>At this point, we have an operator that works on CPU and is available in the frontend.
To register the operator for GPUs as well, we just need to add the following
code to <code class="docutils literal"><span class="pre">quadratic_op.cu</span></code>. Note that the forward and backward functions
are registered with the attribute key <code class="docutils literal"><span class="pre">FCompute<gpu></span></code>, rather than <code class="docutils literal"><span class="pre">FCompute<cpu></span></code>.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">quadratic</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">QuadraticOpForward</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">);</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_quadratic</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">QuadraticOpBackward</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">);</span>
</pre></div>
</div>
</div>
<div class="section" id="unit-test">
<span id="unit-test"></span><h3>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h3>
<p>Now we have finished implementing the operator <code class="docutils literal"><span class="pre">quadratic</span></code> in the MXNet backend.
If you use Python, when you type <code class="docutils literal"><span class="pre">import</span> <span class="pre">mxnet</span> <span class="pre">as</span> <span class="pre">mx</span></code>, two Python
functions for invoking your backend implementation are
generated on the fly: one is for imperative programming,
registered as <code class="docutils literal"><span class="pre">mxnet.ndarray.quadratic</span></code> or <code class="docutils literal"><span class="pre">mx.nd.quadratic</span></code> for short;
the other one is for symbolic
programming, registered as <code class="docutils literal"><span class="pre">mxnet.symbol.quadratic</span></code>
or <code class="docutils literal"><span class="pre">mx.sym.quadratic</span></code> for short.</p>
<p>To unit test it in the frontend, we need to add the following code
to the Python file <code class="docutils literal"><span class="pre">test_operator.py</span></code>. Note that while testing the
forward pass is straightforward using <code class="docutils literal"><span class="pre">mx.nd.quadratic</span></code>, testing
the backward pass takes a bit more effort. We create a
<code class="docutils literal"><span class="pre">quadratic</span></code> symbol and feed it into the utility function <code class="docutils literal"><span class="pre">check_numeric_gradient</span></code>.
The utility function perturbs the input
and calculates the response rate of the output using the
<a class="reference external" href="https://en.wikipedia.org/wiki/Finite_difference_method">finite difference method</a>.
Then it compares the gradient from the backward pass with
the values from the finite difference method. The test
passes if the comparison satisfies the user-specified
relative and absolute thresholds.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">test_quadratic_function</span><span class="p">():</span>
    <span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">c</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">ndim</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">):</span>
        <span class="c1"># check forward</span>
        <span class="n">shape</span> <span class="o">=</span> <span class="n">rand_shape_nd</span><span class="p">(</span><span class="n">ndim</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">rand_ndarray</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">,</span> <span class="n">stype</span><span class="o">=</span><span class="s1">'default'</span><span class="p">)</span>
        <span class="n">data_np</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
        <span class="n">expected</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">data_np</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">quadratic</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">c</span><span class="p">)</span>
        <span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">expected</span><span class="p">)</span>
        <span class="c1"># check backward using finite difference</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
        <span class="n">quad_sym</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">quadratic</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">c</span><span class="p">)</span>
        <span class="n">check_numeric_gradient</span><span class="p">(</span><span class="n">quad_sym</span><span class="p">,</span> <span class="p">[</span><span class="n">data_np</span><span class="p">])</span>
</pre></div>
</div>
<p>Note that we used <code class="docutils literal"><span class="pre">mx.nd.quadratic</span></code> to test the forward function
and <code class="docutils literal"><span class="pre">check_numeric_gradient</span></code> to test the backward function. MXNet provides
two other commonly used utility functions: <code class="docutils literal"><span class="pre">check_symbolic_forward</span></code>
and <code class="docutils literal"><span class="pre">check_symbolic_backward</span></code>. To use them in unit tests,
you pass in the operator symbol together with the expected results,
and the utilities compare the computed values against them. We
strongly recommend adding a <code class="docutils literal"><span class="pre">check_numeric_gradient</span></code> test for every operator
that implements a backward function. Because it compares against finite-difference estimates
rather than hand-computed gradients, it avoids the risk
of passing incorrect expected results into <code class="docutils literal"><span class="pre">check_symbolic_backward</span></code>.</p>
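<p>To make the idea behind a numeric gradient check concrete, here is a minimal pure-Python sketch of a central finite-difference comparison for the quadratic function. It is not MXNet's implementation of <code class="docutils literal"><span class="pre">check_numeric_gradient</span></code>; the helper names (<code class="docutils literal"><span class="pre">numeric_grad</span></code>, <code class="docutils literal"><span class="pre">gradients_agree</span></code>), the step size, and the tolerances are assumptions chosen for illustration.</p>

```python
def quadratic(xs, a, b, c):
    """Forward pass: f(x) = a*x^2 + b*x + c, applied element-wise."""
    return [a * x * x + b * x + c for x in xs]


def quadratic_grad(xs, a, b):
    """Analytic gradient: df/dx = 2*a*x + b, applied element-wise."""
    return [2.0 * a * x + b for x in xs]


def numeric_grad(f, xs, eps=1e-4):
    """Central-difference estimate of d(sum(f(xs)))/dx_i for every i.

    Hypothetical helper: perturbs one input element at a time.
    """
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        minus = list(xs)
        plus[i] += eps
        minus[i] -= eps
        grads.append((sum(f(plus)) - sum(f(minus))) / (2.0 * eps))
    return grads


def gradients_agree(analytic, numeric, rtol=1e-4, atol=1e-6):
    """Return True when every pair of entries matches within tolerance."""
    return all(abs(g - n) <= atol + rtol * abs(n)
               for g, n in zip(analytic, numeric))


a, b, c = 2.0, 3.0, 1.0
xs = [-1.5, 0.0, 0.5, 4.0]
numeric = numeric_grad(lambda v: quadratic(v, a, b, c), xs)

# The correct analytic gradient agrees with the numeric estimate.
ok = gradients_agree(quadratic_grad(xs, a, b), numeric)

# A deliberately wrong gradient (missing the factor of 2) is rejected.
bad = gradients_agree([a * x + b for x in xs], numeric)
```

<p>The check needs only the forward function, which is why a numeric test can catch a wrong backward implementation even when no hand-computed expected gradient is available.</p>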
</div>
</div>
<div class="section" id="summary">
<span id="summary"></span><h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline"></a></h2>
<p>In this tutorial, we implemented the operator <code class="docutils literal"><span class="pre">quadratic</span></code> in the MXNet backend
and unit tested the implementation from the frontend. More specifically, we added a parameter
struct for user-input parameters, walked through the shape and type inference workflow,
implemented the forward and backward functions, and registered the operator
using nnvm. Congratulations! You now know how to add operators.
We welcome your contributions to MXNet.</p>
<p><strong>Note</strong>: Source code in the tutorial can be found in
<a class="reference external" href="https://github.com/reminisce/mxnet/blob/add_op_example_for_tutorial/src/operator/tensor/quadratic_op-inl.h">quadratic_op-inl.h</a>,
<a class="reference external" href="https://github.com/reminisce/mxnet/blob/add_op_example_for_tutorial/src/operator/tensor/quadratic_op.cc">quadratic_op.cc</a>,
<a class="reference external" href="https://github.com/reminisce/mxnet/blob/add_op_example_for_tutorial/src/operator/tensor/quadratic_op.cu">quadratic_op.cu</a>,
and
<a class="reference external" href="https://github.com/reminisce/mxnet/blob/add_op_example_for_tutorial/tests/python/unittest/test_operator.py#L4008">test_operator.py</a>.</p>
</div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">A Beginner’s Guide to Implementing Operators in MXNet Backend</a><ul>
<li><a class="reference internal" href="#introduction">Introduction</a></li>
<li><a class="reference internal" href="#implementation">Implementation</a><ul>
<li><a class="reference internal" href="#an-operator-example">An Operator Example</a></li>
<li><a class="reference internal" href="#parameter-registration">Parameter Registration</a></li>
<li><a class="reference internal" href="#attribute-inference">Attribute Inference</a></li>
<li><a class="reference internal" href="#forward-function">Forward Function</a></li>
<li><a class="reference internal" href="#backward-function">Backward Function</a></li>
<li><a class="reference internal" href="#operator-registration">Operator Registration</a></li>
<li><a class="reference internal" href="#unit-test">Unit Test</a></li>
</ul>
</li>
<li><a class="reference internal" href="#summary">Summary</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
<div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../_static/js/search.js" type="text/javascript"></script>
<script src="../_static/js/navbar.js" type="text/javascript"></script>
<script src="../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../_static/js/copycode.js" type="text/javascript"></script>
<script src="../_static/js/page.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>