<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="MXNet System Architecture" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="MXNet System Architecture" property="og:description"/>
<title>MXNet System Architecture — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../_static/underscore.js" type="text/javascript"></script>
<script src="../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../_static/doctools.js" type="text/javascript"></script>
<script src="../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/0.12.1/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../genindex.html" rel="index" title="Index">
<link href="../search.html" rel="search" title="Search"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></meta></meta></meta></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/0.12.1/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.12.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.12.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/0.12.1/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/tutorials/index.html">Tutorials</a>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/0.12.1/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</li></ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/0.12.1">Github</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">0.12.1<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/index.html>1.5.0</a></li><li><a href=/versions/1.4.1/index.html>1.4.1</a></li><li><a href=/versions/1.3.1/index.html>1.3.1</a></li><li><a href=/versions/1.2.1/index.html>1.2.1</a></li><li><a href=/versions/1.1.0/index.html>1.1.0</a></li><li><a href=/versions/1.0.0/index.html>1.0.0</a></li><li><a href=/versions/0.12.1/index.html>0.12.1</a></li><li><a href=/versions/0.11.0/index.html>0.11.0</a></li></ul></span></nav>
<script> function getRootPath(){ return "../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/0.12.1/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.12.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.12.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/0.12.1/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/0.12.1/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/0.12.1/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/0.12.1/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/0.12.1/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/0.12.1/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/0.12.1" tabindex="-1">Github</a></li>
<li><a href="/versions/0.12.1/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/0.12.1/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">0.12.1</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6.0/>1.6.0</a></li><li><a tabindex="-1" href=/versions/1.5.0/index.html>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/index.html>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/index.html>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/index.html>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/index.html>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/index.html>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/index.html>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/index.html>0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tutorials/index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/index.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="mxnet-system-architecture">
<span id="mxnet-system-architecture"></span><h1>MXNet System Architecture<a class="headerlink" href="#mxnet-system-architecture" title="Permalink to this headline"></a></h1>
<p><img alt="System Overview" src="https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/mxnet/system/overview.png"/></p>
<p>This figure shows the major modules and components of the MXNet system and their interaction. The modules are:</p>
<ul class="simple">
<li>Runtime Dependency Engine: Schedules and executes the
operations according to their read/write dependency.</li>
<li>Storage Allocator: Efficiently allocates and recycles memory blocks
on host (CPU) and devices (GPUs).</li>
<li>Resource Manager: Manages global resources, such as the random number generator
and temporary space.</li>
<li>NDArray: Dynamic, asynchronous n-dimensional arrays,
which provide flexible imperative programs for MXNet.</li>
<li>Symbolic Execution: Static symbolic graph executor,
which provides efficient symbolic graph execution and optimization.</li>
<li>Operator: Operators that define static forward and gradient
calculation (backprop).</li>
<li>SimpleOp: Operators that extend NDArray operators and symbolic operators
in a unified fashion.</li>
<li>Symbol Construction: Symbolic construction, which provides a way to construct
a computation graph (net configuration).</li>
<li>KVStore: Key-value store interface for efficient parameter synchronization.</li>
<li>Data Loading (IO): Efficient distributed data loading and augmentation.</li>
</ul>
</div>
<div class="section" id="mxnet-system-components">
<span id="mxnet-system-components"></span><h1>MXNet System Components<a class="headerlink" href="#mxnet-system-components" title="Permalink to this headline"></a></h1>
<div class="section" id="execution-engine">
<span id="execution-engine"></span><h2>Execution Engine<a class="headerlink" href="#execution-engine" title="Permalink to this headline"></a></h2>
<p>You can use MXNet’s engine not only for deep learning,
but for any domain-specific problem.
It’s designed to solve a general problem:
executing a set of functions in an order that respects their dependencies.
Any two functions with a dependency between them must be executed serially.
To boost performance, functions with no dependencies <em>can</em> be executed in parallel.
For a general discussion of this topic,
see our <a class="reference internal" href="note_engine.html"><span class="doc">notes on the dependency engine</span></a>.</p>
<div class="section" id="interface">
<span id="interface"></span><h3>Interface<a class="headerlink" href="#interface" title="Permalink to this headline"></a></h3>
<p>The following API is the core interface for the execution engine:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="kt">void</span> <span class="nf">PushSync</span><span class="p">(</span><span class="n">Fn</span> <span class="n">exec_fun</span><span class="p">,</span> <span class="n">Context</span> <span class="n">exec_ctx</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">VarHandle</span><span class="o">></span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">const_vars</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">VarHandle</span><span class="o">></span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">mutate_vars</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
<p>This API allows you to push a function (<code class="docutils literal"><span class="pre">exec_fun</span></code>),
along with its context information and dependencies, to the engine.
<code class="docutils literal"><span class="pre">exec_ctx</span></code> is the context information in which the <code class="docutils literal"><span class="pre">exec_fun</span></code> should be executed,
<code class="docutils literal"><span class="pre">const_vars</span></code> denotes the variables that the function reads from,
and <code class="docutils literal"><span class="pre">mutate_vars</span></code> are the variables to be modified.
The engine provides the following guarantee:</p>
<blockquote>
<div><em>The execution of any two functions
that modify a common variable
is serialized in their push order.</em></div></blockquote>
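<p>The calling convention can be sketched with a toy stand-in. This is not MXNet's engine: <code class="docutils literal"><span class="pre">ToyEngine</span></code>, <code class="docutils literal"><span class="pre">RunAll</span></code>, <code class="docutils literal"><span class="pre">RunExample</span></code>, the integer <code class="docutils literal"><span class="pre">VarHandle</span></code>, and the simplified <code class="docutils literal"><span class="pre">Context</span></code> are assumptions made for illustration. The toy runs everything serially in push order, which trivially satisfies the guarantee; a real engine may run independent functions in parallel.</p>

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the engine types named in the text.
struct RunContext { void* stream = nullptr; };
struct Context { int dev_id = 0; };
using VarHandle = int;  // assumption: the real handle is an opaque engine token
using Fn = std::function<void(RunContext)>;

// Toy engine: queues functions and later runs them one by one in push order.
class ToyEngine {
 public:
  void PushSync(Fn exec_fun, Context /*exec_ctx*/,
                std::vector<VarHandle> const& /*const_vars*/,
                std::vector<VarHandle> const& /*mutate_vars*/) {
    queue_.push_back(std::move(exec_fun));  // returns without running exec_fun
  }
  // Hypothetical helper: drain the queue serially.
  void RunAll() {
    for (auto& fn : queue_) fn(RunContext{});
    queue_.clear();
  }

 private:
  std::vector<Fn> queue_;
};

// Pushes two functions that both declare the variable guarding `x` as mutated.
inline int RunExample() {
  ToyEngine engine;
  VarHandle x_var = 0;
  int x = 1;
  engine.PushSync([&x](RunContext) { x += 2; }, Context{}, {}, {x_var});
  engine.PushSync([&x](RunContext) { x *= 10; }, Context{}, {}, {x_var});
  engine.RunAll();
  return x;  // (1 + 2) * 10, because push order is preserved
}
```

<p>Because both functions list the same variable in <code class="docutils literal"><span class="pre">mutate_vars</span></code>, any engine honoring the guarantee has to apply the addition before the multiplication.</p>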
</div>
<div class="section" id="function">
<span id="function"></span><h3>Function<a class="headerlink" href="#function" title="Permalink to this headline"></a></h3>
<p>The function type of the engine is:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">using</span> <span class="n">Fn</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o"><</span><span class="kt">void</span><span class="p">(</span><span class="n">RunContext</span><span class="p">)</span><span class="o">></span><span class="p">;</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">RunContext</span></code> contains runtime information, which is determined by the engine:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">RunContext</span> <span class="p">{</span>
<span class="c1">// stream pointer which could be safely cast to</span>
<span class="c1">// cudaStream_t* type</span>
<span class="kt">void</span> <span class="o">*</span><span class="n">stream</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p>Alternatively, you could use <code class="docutils literal"><span class="pre">mxnet::engine::DAGEngine::Fn</span></code>, which has the same type definition.</p>
<p>All of the functions are executed by the engine’s internal threads.
In such a model, it’s usually not a good idea to push <em>blocking</em> functions
to the engine (typically I/O tasks involving disk, web services, UI, etc.)
because they occupy an execution thread and reduce total throughput.
For such cases, we provide another <em>asynchronous</em> function type:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">using</span> <span class="n">Callback</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o"><</span><span class="kt">void</span><span class="p">()</span><span class="o">></span><span class="p">;</span>
<span class="k">using</span> <span class="n">AsyncFn</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o"><</span><span class="kt">void</span><span class="p">(</span><span class="n">RunContext</span><span class="p">,</span> <span class="n">Callback</span><span class="p">)</span><span class="o">></span><span class="p">;</span>
</pre></div>
</div>
<p>In the <code class="docutils literal"><span class="pre">AsyncFn</span></code> function, you can pass the heavy part to your own threads
and safely exit the body of the function.
The engine doesn’t consider the function finished
until the <code class="docutils literal"><span class="pre">Callback</span></code> function is called.</p>
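<p>A minimal sketch of this pattern follows. To keep it deterministic, the "own threads" are modeled by a hypothetical <code class="docutils literal"><span class="pre">ExternalWorker</span></code> job list rather than real threads; <code class="docutils literal"><span class="pre">MakeSlowRead</span></code> and <code class="docutils literal"><span class="pre">CompletesOnlyAfterCallback</span></code> are likewise illustrative names, not part of MXNet.</p>

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

struct RunContext { void* stream = nullptr; };
using Callback = std::function<void()>;
using AsyncFn = std::function<void(RunContext, Callback)>;

// Hypothetical stand-in for "your own threads": jobs queued here run later,
// outside the engine's call into the AsyncFn.
struct ExternalWorker {
  std::vector<std::function<void()>> jobs;
  void Submit(std::function<void()> job) { jobs.push_back(std::move(job)); }
  void Drain() {
    std::vector<std::function<void()>> pending;
    pending.swap(jobs);
    for (auto& job : pending) job();
  }
};

// Builds an AsyncFn that hands the slow part to `worker` and returns at once.
inline AsyncFn MakeSlowRead(ExternalWorker* worker, int* result) {
  return [worker, result](RunContext, Callback on_complete) {
    worker->Submit([result, on_complete]() {
      *result = 42;   // placeholder for blocking I/O
      on_complete();  // only now may the engine treat the function as finished
    });
  };
}

// True if the function body returned before completion was signaled,
// and completion arrived once the external work ran.
inline bool CompletesOnlyAfterCallback() {
  ExternalWorker worker;
  int result = 0;
  bool finished = false;
  AsyncFn fn = MakeSlowRead(&worker, &result);
  fn(RunContext{}, [&finished]() { finished = true; });
  bool finished_at_return = finished;
  worker.Drain();
  return !finished_at_return && finished && result == 42;
}
```

<p>The point of the design is that the engine thread is released as soon as the body returns, while dependent functions still wait for the callback.</p>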
</div>
<div class="section" id="context">
<span id="context"></span><h3>Context<a class="headerlink" href="#context" title="Permalink to this headline"></a></h3>
<p>You can specify the <code class="docutils literal"><span class="pre">Context</span></code> in which a function is to be executed.
This usually includes whether the function should be run on a CPU or a GPU,
and if you specify a GPU, which GPU to use.
<code class="docutils literal"><span class="pre">Context</span></code> is different from <code class="docutils literal"><span class="pre">RunContext</span></code>.
<code class="docutils literal"><span class="pre">Context</span></code> contains device type (GPU/CPU) and device id,
while <code class="docutils literal"><span class="pre">RunContext</span></code> contains information that can be decided only during runtime,
for example, on which stream the function should be executed.</p>
</div>
<div class="section" id="varhandle">
<span id="varhandle"></span><h3>VarHandle<a class="headerlink" href="#varhandle" title="Permalink to this headline"></a></h3>
<p><code class="docutils literal"><span class="pre">VarHandle</span></code> is used to specify the dependencies of functions.
The MXNet engine is designed to be decoupled from other MXNet modules.
So <code class="docutils literal"><span class="pre">VarHandle</span></code> is like an engine-provided token you use
to represent the external resources the functions can use or modify.
It’s designed to be lightweight, so creating,
deleting, or copying a variable incurs little overhead.
Upon pushing the functions, you need to specify the variables
that will be used (immutable) in the <code class="docutils literal"><span class="pre">const_vars</span></code> vector,
and the variables that will be modified (mutable) in the <code class="docutils literal"><span class="pre">mutate_vars</span></code> vector.
The engine uses one rule for resolving the dependencies among functions:</p>
<blockquote>
<div><em>The execution of any two functions that share a variable, where at least one of them modifies it, is serialized in their push order.</em></div></blockquote>
<p>For example, if <code class="docutils literal"><span class="pre">Fn1</span></code> and <code class="docutils literal"><span class="pre">Fn2</span></code> both mutate <code class="docutils literal"><span class="pre">V2</span></code>, then <code class="docutils literal"><span class="pre">Fn2</span></code>
is guaranteed to be executed after <code class="docutils literal"><span class="pre">Fn1</span></code>
if <code class="docutils literal"><span class="pre">Fn2</span></code> is pushed after <code class="docutils literal"><span class="pre">Fn1</span></code>.
On the other hand, if <code class="docutils literal"><span class="pre">Fn1</span></code> and <code class="docutils literal"><span class="pre">Fn2</span></code> both only read <code class="docutils literal"><span class="pre">V2</span></code>,
the engine may execute them in any order, or in parallel.</p>
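<p>The rule can be written as a small predicate over two dependency declarations. This is only a sketch of the rule as stated here, not the engine's actual scheduling code; <code class="docutils literal"><span class="pre">OpDeps</span></code>, <code class="docutils literal"><span class="pre">Contains</span></code>, <code class="docutils literal"><span class="pre">MustSerialize</span></code>, and the integer <code class="docutils literal"><span class="pre">VarHandle</span></code> are hypothetical names.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using VarHandle = int;  // assumption: stand-in for the engine's opaque token

// Dependency declaration of one pushed function.
struct OpDeps {
  std::vector<VarHandle> const_vars;   // variables only read
  std::vector<VarHandle> mutate_vars;  // variables modified
};

inline bool Contains(std::vector<VarHandle> const& vars, VarHandle v) {
  return std::find(vars.begin(), vars.end(), v) != vars.end();
}

// True if `a` and `b` must run in push order: they share a variable
// and at least one of them mutates it.
inline bool MustSerialize(OpDeps const& a, OpDeps const& b) {
  for (VarHandle v : a.mutate_vars) {
    if (Contains(b.const_vars, v) || Contains(b.mutate_vars, v)) return true;
  }
  for (VarHandle v : b.mutate_vars) {
    if (Contains(a.const_vars, v)) return true;
  }
  return false;  // only shared reads, or no shared variables at all
}
```

<p>Two functions that both mutate the same variable make the predicate true, while two functions that only read it make it false, so they are free to overlap.</p>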
<p>This design allows the engine to schedule <em>state-mutating</em> operations in a manner
that minimizes calls to allocate new memory.
For example, the weight update function in a deep neural network (DNN)
can use the <code class="docutils literal"><span class="pre">+=</span></code> operator
to update the weights in place,
rather than generating a new weight array each time.</p>
<p>To create a variable, use the <code class="docutils literal"><span class="pre">NewVariable()</span></code> API.
To delete a variable, use the <code class="docutils literal"><span class="pre">DeleteVariable</span></code> API.</p>
</div>
<div class="section" id="push-and-wait">
<span id="push-and-wait"></span><h3>Push and Wait<a class="headerlink" href="#push-and-wait" title="Permalink to this headline"></a></h3>
<p><em>All <code class="docutils literal"><span class="pre">Push</span></code> APIs are asynchronous.</em> The API call returns immediately
regardless of whether the pushed <code class="docutils literal"><span class="pre">Fn</span></code> is finished or not.
This allows the engine to start computing at the same time
as the user thread is pushing functions.
<code class="docutils literal"><span class="pre">Push</span></code> APIs are not thread-safe.
To be specific, only one thread should make engine API calls at a time.</p>
<p>If you want to wait for a specific <code class="docutils literal"><span class="pre">Fn</span></code> to finish,
include a callback function in the closure,
and call the function at the end of your <code class="docutils literal"><span class="pre">Fn</span></code>.</p>
<p>If you want to wait for all <code class="docutils literal"><span class="pre">Fn</span></code>s
that involve (use or mutate) a certain variable to finish,
use the <code class="docutils literal"><span class="pre">WaitForVar(var)</span></code> API.</p>
<p>If you want to wait for all pushed <code class="docutils literal"><span class="pre">Fn</span></code>s to finish,
use the <code class="docutils literal"><span class="pre">WaitForAll()</span></code> API.</p>
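<p>The difference between the two wait calls can be illustrated with a single-threaded toy. Everything here is an assumption made for illustration: <code class="docutils literal"><span class="pre">LazyEngine</span></code> defers all work until a wait call, whereas the real engine executes in the background, and <code class="docutils literal"><span class="pre">PendingCount</span></code> exists only so the effect can be observed. For simplicity the toy also runs every function pushed before the last one that touches the variable, which is more conservative than strictly necessary.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using VarHandle = int;  // assumption: stand-in for the engine's opaque token
using Fn = std::function<void()>;

class LazyEngine {
 private:
  struct Op {
    Fn fn;
    std::vector<VarHandle> reads;
    std::vector<VarHandle> writes;
  };
  std::vector<Op> pending_;

  static bool Touches(Op const& op, VarHandle v) {
    return std::count(op.reads.begin(), op.reads.end(), v) > 0 ||
           std::count(op.writes.begin(), op.writes.end(), v) > 0;
  }
  // Runs the first n pending functions in push order, then forgets them.
  void RunPrefix(std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) pending_[i].fn();
    pending_.erase(pending_.begin(),
                   pending_.begin() + static_cast<std::ptrdiff_t>(n));
  }

 public:
  // Returns immediately; the function runs later.
  void Push(Fn fn, std::vector<VarHandle> const_vars,
            std::vector<VarHandle> mutate_vars) {
    pending_.push_back(
        {std::move(fn), std::move(const_vars), std::move(mutate_vars)});
  }
  // Returns once every pushed function that uses or mutates `var` has run.
  void WaitForVar(VarHandle var) {
    std::size_t last = 0;  // one past the last pending function touching var
    for (std::size_t i = 0; i < pending_.size(); ++i) {
      if (Touches(pending_[i], var)) last = i + 1;
    }
    RunPrefix(last);
  }
  // Returns once every pushed function has run.
  void WaitForAll() { RunPrefix(pending_.size()); }
  std::size_t PendingCount() const { return pending_.size(); }
};
```

<p>Waiting on one variable leaves unrelated functions pushed later still pending, while <code class="docutils literal"><span class="pre">WaitForAll()</span></code> drains everything.</p>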
</div>
<div class="section" id="save-object-creation-cost">
<span id="save-object-creation-cost"></span><h3>Save Object Creation Cost<a class="headerlink" href="#save-object-creation-cost" title="Permalink to this headline"></a></h3>
<p>In some cases, you need to keep pushing functions to the engine over a long period of time.
If the computation of these functions is light,
the overhead of copying lambdas and creating use/mutate variable lists becomes relatively high.
We provide an API to create an <code class="docutils literal"><span class="pre">OprHandle</span></code> beforehand:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="n">OprHandle</span> <span class="nf">NewOperator</span><span class="p">(</span><span class="n">AsyncFn</span> <span class="n">fn</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">VarHandle</span><span class="o">></span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">const_vars</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">VarHandle</span><span class="o">></span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">mutate_vars</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
<p>You can then push the same <code class="docutils literal"><span class="pre">OprHandle</span></code> repeatedly without re-creating it:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="kt">void</span> <span class="nf">Push</span><span class="p">(</span><span class="n">OprHandle</span> <span class="n">op</span><span class="p">,</span> <span class="n">Context</span> <span class="n">exec_ctx</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
<p>To delete it, call the <code class="docutils literal"><span class="pre">DeleteOperator(OprHandle</span> <span class="pre">op)</span></code> API.
Ensure that the operator has finished computing before calling this API.</p>
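<p>The create-once, push-many lifecycle might look like the following sketch. It is a simplified stand-in rather than the real interface: this <code class="docutils literal"><span class="pre">ToyEngine</span></code> takes a plain function instead of an <code class="docutils literal"><span class="pre">AsyncFn</span></code>, omits the <code class="docutils literal"><span class="pre">Context</span></code> argument, and runs queued operators only inside <code class="docutils literal"><span class="pre">WaitForAll()</span></code>; <code class="docutils literal"><span class="pre">Opr</span></code> and <code class="docutils literal"><span class="pre">AccumulateSteps</span></code> are hypothetical names.</p>

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

using VarHandle = int;  // assumption: stand-in for the engine's opaque token
using Fn = std::function<void()>;

// Hypothetical operator record: the function and its dependency lists
// are stored once, when the operator is created.
struct Opr {
  Fn fn;
  std::vector<VarHandle> const_vars;
  std::vector<VarHandle> mutate_vars;
};
using OprHandle = Opr*;

class ToyEngine {
 public:
  OprHandle NewOperator(Fn fn, std::vector<VarHandle> const_vars,
                        std::vector<VarHandle> mutate_vars) {
    return new Opr{std::move(fn), std::move(const_vars),
                   std::move(mutate_vars)};
  }
  // Pushing copies only a pointer; the lambda and vectors are not rebuilt.
  void Push(OprHandle op) { queue_.push_back(op); }
  void WaitForAll() {
    for (OprHandle op : queue_) op->fn();
    queue_.clear();
  }
  // The caller must ensure no pending push still refers to `op`.
  void DeleteOperator(OprHandle op) { delete op; }

 private:
  std::vector<OprHandle> queue_;
};

// Creates one operator, pushes it `steps` times, waits, then deletes it.
inline int AccumulateSteps(int steps) {
  ToyEngine engine;
  int total = 0;
  OprHandle add_one = engine.NewOperator([&total]() { ++total; }, {}, {0});
  for (int i = 0; i < steps; ++i) engine.Push(add_one);
  engine.WaitForAll();             // finish all work before deleting
  engine.DeleteOperator(add_one);
  return total;
}
```

<p>Note the ordering at the end: the wait comes before the delete, matching the requirement that the operator has finished computing before it is destroyed.</p>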
</div>
<div class="section" id="api-reference">
<span id="api-reference"></span><h3>API Reference<a class="headerlink" href="#api-reference" title="Permalink to this headline"></a></h3>
<blockquote>
<div><dl class="class">
<dt id="_CPPv2N5mxnet6EngineE">
<span id="mxnet::Engine"></span><span class="target" id="classmxnet_1_1Engine"></span><em class="property">class </em><code class="descname">Engine</code><a class="headerlink" href="#_CPPv2N5mxnet6EngineE" title="Permalink to this definition"></a><br/></dt>
<dd><p>Dependency engine that schedules operations. </p>
<div class="breathe-sectiondef docutils container">
<p class="breathe-sectiondef-title rubric">Public Types</p>
<dl class="type">
<dt id="_CPPv2N5mxnet6Engine18CallbackOnCompleteE">
<span id="mxnet::Engine::CallbackOnComplete"></span><span class="target" id="classmxnet_1_1Engine_1a16b757432556f835d27f1b5e1dbe1b06"></span><em class="property">typedef </em>engine::CallbackOnComplete <code class="descname">CallbackOnComplete</code><a class="headerlink" href="#_CPPv2N5mxnet6Engine18CallbackOnCompleteE" title="Permalink to this definition"></a><br/></dt>
<dd><p>callback on complete </p>
</dd></dl>
<dl class="type">
<dt id="_CPPv2N5mxnet6Engine6SyncFnE">
<span id="mxnet::Engine::SyncFn"></span><span class="target" id="classmxnet_1_1Engine_1a07f30ab85fca436e1bbcc72cd4d8bb35"></span><em class="property">typedef </em>std::function<void<span class="sig-paren">(</span>RunContext<span class="sig-paren">)</span>> <code class="descname">SyncFn</code><a class="headerlink" href="#_CPPv2N5mxnet6Engine6SyncFnE" title="Permalink to this definition"></a><br/></dt>
<dd><p>Synchronous operation to pass to engine. </p>
</dd></dl>
<dl class="type">
<dt id="_CPPv2N5mxnet6Engine7AsyncFnE">
<span id="mxnet::Engine::AsyncFn"></span><span class="target" id="classmxnet_1_1Engine_1ad41feff70bba0f29fc24f60b5381984c"></span><em class="property">typedef </em>std::function<void<span class="sig-paren">(</span>RunContext, <a class="reference internal" href="#_CPPv2N5mxnet6Engine18CallbackOnCompleteE" title="mxnet::Engine::CallbackOnComplete">CallbackOnComplete</a><span class="sig-paren">)</span>> <code class="descname">AsyncFn</code><a class="headerlink" href="#_CPPv2N5mxnet6Engine7AsyncFnE" title="Permalink to this definition"></a><br/></dt>
<dd><p>Asynchronous operation to pass to engine. </p>
</dd></dl>
<dl class="type">
<dt id="_CPPv2N5mxnet6Engine9VarHandleE">
<span id="mxnet::Engine::VarHandle"></span><span class="target" id="classmxnet_1_1Engine_1aac31510c793a12944c33f9cac6150491"></span><em class="property">typedef </em>engine::VarHandle <code class="descname">VarHandle</code><a class="headerlink" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="Permalink to this definition"></a><br/></dt>
<dd><p>Variable pointer. </p>
</dd></dl>
<dl class="type">
<dt id="_CPPv2N5mxnet6Engine9OprHandleE">
<span id="mxnet::Engine::OprHandle"></span><span class="target" id="classmxnet_1_1Engine_1a832436e413a075291aa1a631942c3f01"></span><em class="property">typedef </em>engine::OprHandle <code class="descname">OprHandle</code><a class="headerlink" href="#_CPPv2N5mxnet6Engine9OprHandleE" title="Permalink to this definition"></a><br/></dt>
<dd><p>Operator pointer. </p>
</dd></dl>
</div>
<div class="breathe-sectiondef docutils container">
<p class="breathe-sectiondef-title rubric">Public Functions</p>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine14NotifyShutdownEv">
<span id="mxnet::Engine::NotifyShutdown"></span><span class="target" id="classmxnet_1_1Engine_1a3c0e2989538b5369c1592eddbcf0181c"></span><em class="property">virtual</em> void <code class="descname">NotifyShutdown</code><span class="sig-paren">(</span><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine14NotifyShutdownEv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Notify the engine about a shutdown. This can help the engine print fewer messages to the display. </p>
<p>Users do not have to call this function. <dl class="docutils">
<dt><strong>Return</strong></dt>
<dd>0 when success, -1 when failure happens. </dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine11NewVariableEv">
<span id="mxnet::Engine::NewVariable"></span><span class="target" id="classmxnet_1_1Engine_1a6e141b188f018d5d933ab99868631d5e"></span><em class="property">virtual</em> <a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a> <code class="descname">NewVariable</code><span class="sig-paren">(</span><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine11NewVariableEv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Allocate a new variable. The variable can then be used to schedule operations concurrently via dependency patterns. </p>
<p><dl class="docutils">
<dt><strong>Return</strong></dt>
<dd>The new variable allocated. </dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine11NewOperatorE7AsyncFnRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyPKc">
<span id="mxnet::Engine::NewOperator__AsyncFn.std::vector:VarHandle:CR.std::vector:VarHandle:CR.FnProperty.cCP"></span><span class="target" id="classmxnet_1_1Engine_1aff3332258a158ef33a9a4b7bcdc2fe6f"></span><em class="property">virtual</em> <a class="reference internal" href="#_CPPv2N5mxnet6Engine9OprHandleE" title="mxnet::Engine::OprHandle">OprHandle</a> <code class="descname">NewOperator</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine7AsyncFnE" title="mxnet::Engine::AsyncFn">AsyncFn</a> <em>fn</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>const_vars</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>mutable_vars</em>, FnProperty <em>prop</em> = FnProperty::kNormal, <em class="property">const</em> char *<em>opr_name</em> = nullptr<span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine11NewOperatorE7AsyncFnRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyPKc" title="Permalink to this definition"></a><br/></dt>
<dd><p>Create a new operator. The returned operator can be saved externally so that it can be reused for scheduling. </p>
<p><dl class="docutils">
<dt><strong>Return</strong></dt>
<dd>The new operator allocated. </dd>
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">fn</span></code>: The execution function. </li>
<li><code class="docutils literal"><span class="pre">const_vars</span></code>: The variables that the current operation will use but not mutate. </li>
<li><code class="docutils literal"><span class="pre">mutable_vars</span></code>: The variables that the current operation will mutate. </li>
<li><code class="docutils literal"><span class="pre">prop</span></code>: Property of the function. </li>
<li><code class="docutils literal"><span class="pre">opr_name</span></code>: The operator name. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine14DeleteOperatorE9OprHandle">
<span id="mxnet::Engine::DeleteOperator__OprHandle"></span><span class="target" id="classmxnet_1_1Engine_1a3fd7c8b35a2f52805506cc242bf82ca7"></span><em class="property">virtual</em> void <code class="descname">DeleteOperator</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine9OprHandleE" title="mxnet::Engine::OprHandle">OprHandle</a> <em>op</em><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine14DeleteOperatorE9OprHandle" title="Permalink to this definition"></a><br/></dt>
<dd><p>Delete the given operator. </p>
<p><p>The deletion does not happen immediately; it waits until all the operations using this operator have completed. </p>
<dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">op</span></code>: The operator to delete.</li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine4PushE9OprHandle7Contextib">
<span id="mxnet::Engine::Push__OprHandle.Context.i.b"></span><span class="target" id="classmxnet_1_1Engine_1ad213d6b1a7c1e0d4d41275b9efe5f097"></span><em class="property">virtual</em> void <code class="descname">Push</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine9OprHandleE" title="mxnet::Engine::OprHandle">OprHandle</a> <em>op</em>, Context <em>exec_ctx</em>, int <em>priority</em> = 0, bool <em>profiling</em> = false<span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine4PushE9OprHandle7Contextib" title="Permalink to this definition"></a><br/></dt>
<dd><p>Push an operator to the engine. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">op</span></code>: The operator to push. </li>
<li><code class="docutils literal"><span class="pre">exec_ctx</span></code>: Execution context. </li>
<li><code class="docutils literal"><span class="pre">priority</span></code>: Priority of the action, as hint to the engine. </li>
<li><code class="docutils literal"><span class="pre">profiling</span></code>: Indicates whether to profile this operator. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine9PushAsyncE7AsyncFn7ContextRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyiPKc">
<span id="mxnet::Engine::PushAsync__AsyncFn.Context.std::vector:VarHandle:CR.std::vector:VarHandle:CR.FnProperty.i.cCP"></span><span class="target" id="classmxnet_1_1Engine_1ac71feb4f966cd4573452bd148b850c82"></span><em class="property">virtual</em> void <code class="descname">PushAsync</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine7AsyncFnE" title="mxnet::Engine::AsyncFn">AsyncFn</a> <em>exec_fun</em>, Context <em>exec_ctx</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>const_vars</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>mutable_vars</em>, FnProperty <em>prop</em> = FnProperty::kNormal, int <em>priority</em> = 0, <em class="property">const</em> char *<em>opr_name</em> = nullptr<span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine9PushAsyncE7AsyncFn7ContextRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyiPKc" title="Permalink to this definition"></a><br/></dt>
<dd><p>Push an asynchronous operation to the engine. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">exec_fun</span></code>: Execution function. It takes a parameter <code class="docutils literal"><span class="pre">on_complete</span></code> that must be called when the execution completes. </li>
<li><code class="docutils literal"><span class="pre">exec_ctx</span></code>: Execution context. </li>
<li><code class="docutils literal"><span class="pre">const_vars</span></code>: The variables that the current operation will use but not mutate. </li>
<li><code class="docutils literal"><span class="pre">mutable_vars</span></code>: The variables that the current operation will mutate. </li>
<li><code class="docutils literal"><span class="pre">prop</span></code>: Property of the function. </li>
<li><code class="docutils literal"><span class="pre">priority</span></code>: Priority of the action, as hint to the engine. </li>
<li><code class="docutils literal"><span class="pre">opr_name</span></code>: The operator name. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine14DeleteVariableE6SyncFn7Context9VarHandle">
<span id="mxnet::Engine::DeleteVariable__SyncFn.Context.VarHandle"></span><span class="target" id="classmxnet_1_1Engine_1a738e5192dab345ab0ec9888b095903cf"></span><em class="property">virtual</em> void <code class="descname">DeleteVariable</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine6SyncFnE" title="mxnet::Engine::SyncFn">SyncFn</a> <em>delete_fn</em>, Context <em>exec_ctx</em>, <a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a> <em>var</em><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine14DeleteVariableE6SyncFn7Context9VarHandle" title="Permalink to this definition"></a><br/></dt>
<dd><p>Schedule the deletion of a variable. </p>
<p>The deletion does not happen immediately; it waits until all the operations depending on var have completed.</p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">delete_fn</span></code>: A function that will be called after the variable is deleted. </li>
<li><code class="docutils literal"><span class="pre">exec_ctx</span></code>: Execution context. </li>
<li><code class="docutils literal"><span class="pre">var</span></code>: The variable to be deleted. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine10WaitForVarE9VarHandle">
<span id="mxnet::Engine::WaitForVar__VarHandle"></span><span class="target" id="classmxnet_1_1Engine_1aed51bd7f294d9f2b569764a0c151d883"></span><em class="property">virtual</em> void <code class="descname">WaitForVar</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a> <em>var</em><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine10WaitForVarE9VarHandle" title="Permalink to this definition"></a><br/></dt>
<dd><p>Wait for a variable. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">var</span></code>: The variable we should wait for. This function returns when the variable is ready. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine10WaitForAllEv">
<span id="mxnet::Engine::WaitForAll"></span><span class="target" id="classmxnet_1_1Engine_1a64483aecce780e96056be89d6289e782"></span><em class="property">virtual</em> void <code class="descname">WaitForAll</code><span class="sig-paren">(</span><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine10WaitForAllEv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Wait until all the activity of engine finishes. </p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6EngineD0Ev">
<span id="mxnet::Engine::~Engine"></span><span class="target" id="classmxnet_1_1Engine_1aff025321827e15096c02342225f2395b"></span><em class="property">virtual</em> <code class="descname">~Engine</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#_CPPv2N5mxnet6EngineD0Ev" title="Permalink to this definition"></a><br/></dt>
<dd><p>virtual destructor </p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine8PushSyncE6SyncFn7ContextRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyiPKc">
<span id="mxnet::Engine::PushSync__SyncFn.Context.std::vector:VarHandle:CR.std::vector:VarHandle:CR.FnProperty.i.cCP"></span><span class="target" id="classmxnet_1_1Engine_1a1c2f38927e4bf7a62e23353b0bd3d619"></span>void <code class="descname">PushSync</code><span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6Engine6SyncFnE" title="mxnet::Engine::SyncFn">SyncFn</a> <em>exec_fn</em>, Context <em>exec_ctx</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>const_vars</em>, std::vector<<a class="reference internal" href="#_CPPv2N5mxnet6Engine9VarHandleE" title="mxnet::Engine::VarHandle">VarHandle</a>> <em class="property">const</em> &amp;<em>mutable_vars</em>, FnProperty <em>prop</em> = FnProperty::kNormal, int <em>priority</em> = 0, <em class="property">const</em> char *<em>opr_name</em> = nullptr<span class="sig-paren">)</span><a class="headerlink" href="#_CPPv2N5mxnet6Engine8PushSyncE6SyncFn7ContextRNSt6vectorI9VarHandleEERNSt6vectorI9VarHandleEE10FnPropertyiPKc" title="Permalink to this definition"></a><br/></dt>
<dd><p>Push a synchronous operation to the engine. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">exec_fn</span></code>: Execution function that executes the operation. </li>
<li><code class="docutils literal"><span class="pre">exec_ctx</span></code>: Execution context. </li>
<li><code class="docutils literal"><span class="pre">const_vars</span></code>: The variables that the current operation will use but not mutate. </li>
<li><code class="docutils literal"><span class="pre">mutable_vars</span></code>: The variables that the current operation will mutate. </li>
<li><code class="docutils literal"><span class="pre">prop</span></code>: Property of the function. </li>
<li><code class="docutils literal"><span class="pre">priority</span></code>: Priority of the action, as hint to the engine. </li>
<li><code class="docutils literal"><span class="pre">opr_name</span></code>: The operator name. </li>
</ul>
</dd>
<dt><strong>Template Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">SyncFn</span></code>: the synchronous function to be pushed. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine14CreateCallbackEPFvP6EnginePvEPv">
<span class="target" id="classmxnet_1_1Engine_1acf25be815b4200b48ee6e4e169bf95b8"></span><a class="reference internal" href="#_CPPv2N5mxnet6Engine18CallbackOnCompleteE" title="mxnet::Engine::CallbackOnComplete">CallbackOnComplete</a> <code class="descname">CreateCallback</code><span class="sig-paren">(</span>void (*<em>callback</em>)<span class="sig-paren">(</span><a class="reference internal" href="#_CPPv2N5mxnet6EngineE" title="mxnet::Engine">Engine</a> *, void *<span class="sig-paren">)</span>, void *<em>param</em><span class="sig-paren">)</span><a class="headerlink" href="#_CPPv2N5mxnet6Engine14CreateCallbackEPFvP6EnginePvEPv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Factory function to create an OnComplete callback. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">callback</span></code>: The static callback function. </li>
<li><code class="docutils literal"><span class="pre">param</span></code>: The parameter passed to the callback. </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2NK5mxnet6Engine26num_omp_threads_per_workerEv">
<span id="mxnet::Engine::num_omp_threads_per_workerC"></span><span class="target" id="classmxnet_1_1Engine_1ae33a279d7bf72f7aff61b315fe8793a4"></span><em class="property">virtual</em> int <code class="descname">num_omp_threads_per_worker</code><span class="sig-paren">(</span><span class="sig-paren">)</span> <em class="property">const</em> = 0<a class="headerlink" href="#_CPPv2NK5mxnet6Engine26num_omp_threads_per_workerEv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Return the number of OMP threads that should be used per worker. </p>
<p><dl class="docutils">
<dt><strong>Return</strong></dt>
<dd>Number of OMP threads that should be used per worker </dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine30set_num_omp_threads_per_workerEi">
<span id="mxnet::Engine::set_num_omp_threads_per_worker__i"></span><span class="target" id="classmxnet_1_1Engine_1aa49c861a1e2eb60b1ec52e45c1c35098"></span><em class="property">virtual</em> void <code class="descname">set_num_omp_threads_per_worker</code><span class="sig-paren">(</span>int <em>num_omp_threads_per_worker</em><span class="sig-paren">)</span> = 0<a class="headerlink" href="#_CPPv2N5mxnet6Engine30set_num_omp_threads_per_workerEi" title="Permalink to this definition"></a><br/></dt>
<dd><p>Set the number of OMP threads that should be used per worker. </p>
<p><dl class="docutils">
<dt><strong>Parameters</strong></dt>
<dd><ul class="breatheparameterlist first last simple">
<li><code class="docutils literal"><span class="pre">num_omp_threads_per_worker</span></code>: Number of OMP threads to be used per worker </li>
</ul>
</dd>
</dl>
</p>
</dd></dl>
</div>
<div class="breathe-sectiondef docutils container">
<p class="breathe-sectiondef-title rubric">Public Static Functions</p>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine3GetEv">
<span id="mxnet::Engine::Get"></span><span class="target" id="classmxnet_1_1Engine_1ae0a23da15ef63d9479c7468e1f2f825f"></span><em class="property">static</em> <a class="reference internal" href="#_CPPv2N5mxnet6EngineE" title="mxnet::Engine">Engine</a> *<code class="descname">Get</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#_CPPv2N5mxnet6Engine3GetEv" title="Permalink to this definition"></a><br/></dt>
<dd><p><dl class="docutils">
<dt><strong>Return</strong></dt>
<dd><a class="reference internal" href="#classmxnet_1_1Engine"><span class="std std-ref">Engine</span></a> singleton. </dd>
</dl>
</p>
</dd></dl>
<dl class="function">
<dt id="_CPPv2N5mxnet6Engine13_GetSharedRefEv">
<span id="mxnet::Engine::_GetSharedRef"></span><span class="target" id="classmxnet_1_1Engine_1ab6417f2ae519b946c104f975d84d55d5"></span><em class="property">static</em> std::shared_ptr<<a class="reference internal" href="#_CPPv2N5mxnet6EngineE" title="mxnet::Engine">Engine</a>> <code class="descname">_GetSharedRef</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#_CPPv2N5mxnet6Engine13_GetSharedRefEv" title="Permalink to this definition"></a><br/></dt>
<dd><p>Get a shared pointer reference to the engine singleton. Most users should not call this function. It is called by another singleton X that requires the engine to be destructed after X. </p>
<p><dl class="docutils">
<dt><strong>Return</strong></dt>
<dd>A shared pointer to <a class="reference internal" href="#classmxnet_1_1Engine"><span class="std std-ref">Engine</span></a> singleton. </dd>
</dl>
</p>
</dd></dl>
</div>
</dd></dl>
</div></blockquote>
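<p>The <code class="docutils literal"><span class="pre">const_vars</span></code> and <code class="docutils literal"><span class="pre">mutable_vars</span></code> arguments are what let the engine decide which operations may overlap. The following is a minimal sketch of that rule, not the engine's actual implementation: <code class="docutils literal"><span class="pre">VarId</span></code>, <code class="docutils literal"><span class="pre">OpDeps</span></code>, <code class="docutils literal"><span class="pre">Intersects</span></code>, and <code class="docutils literal"><span class="pre">MustSerialize</span></code> are hypothetical names introduced only for illustration, and variables are modeled as plain integers rather than <code class="docutils literal"><span class="pre">VarHandle</span></code> pointers.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for Engine::VarHandle (the real type is an opaque handle).
using VarId = int;

// The two dependency lists an operation declares when it is pushed.
struct OpDeps {
  std::vector<VarId> const_vars;    // variables the operation only reads
  std::vector<VarId> mutable_vars;  // variables the operation writes
};

// Returns true if any element of `a` also appears in `b`.
inline bool Intersects(const std::vector<VarId>& a, const std::vector<VarId>& b) {
  return std::any_of(a.begin(), a.end(), [&b](VarId v) {
    return std::find(b.begin(), b.end(), v) != b.end();
  });
}

// Two operations must be ordered if either one writes a variable
// that the other reads or writes. Shared reads alone do not conflict.
inline bool MustSerialize(const OpDeps& x, const OpDeps& y) {
  return Intersects(x.mutable_vars, y.mutable_vars) ||
         Intersects(x.mutable_vars, y.const_vars) ||
         Intersects(x.const_vars, y.mutable_vars);
}
```

<p>Under this model, two operations that both read variable 1 and write different outputs can be scheduled concurrently, while an operation that writes variable 1 has to be ordered against both of them.</p>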
</div>
</div>
<div class="section" id="operators-in-mxnet">
<span id="operators-in-mxnet"></span><h2>Operators in MXNet<a class="headerlink" href="#operators-in-mxnet" title="Permalink to this headline"></a></h2>
<p>In MXNet, an operator is a class that contains both actual computation logic
and auxiliary information that can aid the system in performing optimizations,
like in-place updates and auto-derivatives.
To understand the remainder of the document,
we recommend that you familiarize yourself with the <code class="docutils literal"><span class="pre">mshadow</span></code> library,
because all operators compute on the tensor-like structure <code class="docutils literal"><span class="pre">mshadow::TBlob</span></code>
provided by the system during runtime.</p>
<p>MXNet’s operator interface allows you to:</p>
<ul class="simple">
<li>Reduce memory allocation cost by specifying in-place updates.</li>
<li>Hide some internal arguments from Python to keep the Python interface cleaner.</li>
<li>Define the relationships among input tensors and output tensors,
which allows the system to perform shape checking for you.</li>
<li>Acquire additional temporary spaces from the system
to perform computation (e.g., calling <code class="docutils literal"><span class="pre">cudnn</span></code> routines).</li>
</ul>
<div class="section" id="operator-interface">
<span id="operator-interface"></span><h3>Operator Interface<a class="headerlink" href="#operator-interface" title="Permalink to this headline"></a></h3>
<p><code class="docutils literal"><span class="pre">Forward</span></code> is the core operator interface:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="kt">void</span> <span class="nf">Forward</span><span class="p">(</span><span class="k">const</span> <span class="n">OpContext</span> <span class="o">&amp;</span><span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">></span> <span class="o">&amp;</span><span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">aux_states</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
<p>The <code class="docutils literal"><span class="pre">OpContext</span></code> structure is:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">OpContext</span> <span class="p">{</span>
<span class="kt">int</span> <span class="n">is_train</span><span class="p">;</span>
<span class="n">RunContext</span> <span class="n">run_ctx</span><span class="p">;</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">Resource</span><span class="o">></span> <span class="n">requested</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p>It describes whether the operator is in the train or test phase,
which device the operator should be run on (in <code class="docutils literal"><span class="pre">run_ctx</span></code>),
and requested resources (covered in the following sections).</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">in_data</span></code> and <code class="docutils literal"><span class="pre">out_data</span></code> represent the input and output tensors, respectively.
All of the tensor spaces have been allocated by the system.</li>
<li><code class="docutils literal"><span class="pre">req</span></code> denotes how the computation results are written into the <code class="docutils literal"><span class="pre">out_data</span></code>.
In other words, <code class="docutils literal"><span class="pre">req.size()</span> <span class="pre">==</span> <span class="pre">out_data.size()</span></code> and <code class="docutils literal"><span class="pre">req[i]</span></code>
corresponds to the write type of <code class="docutils literal"><span class="pre">out_data[i]</span></code>.</li>
<li>The <code class="docutils literal"><span class="pre">OpReqType</span></code> is defined as:</li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span>
<span class="n">kNullOp</span><span class="p">,</span>
<span class="n">kWriteTo</span><span class="p">,</span>
<span class="n">kWriteInplace</span><span class="p">,</span>
<span class="n">kAddTo</span>
<span class="p">};</span>
</pre></div>
</div>
<p>Normally, the types of all <code class="docutils literal"><span class="pre">out_data</span></code> should be <code class="docutils literal"><span class="pre">kWriteTo</span></code>,
meaning that the provided <code class="docutils literal"><span class="pre">out_data</span></code> tensor is a <em>raw</em> memory block,
so the operator should write results directly into it.
In some cases, for example when calculating the <code class="docutils literal"><span class="pre">gradient</span></code> tensor,
it is preferable to accumulate the result
rather than overwrite the tensor contents,
so that no extra space needs to be allocated each time.
In such a case, the corresponding <code class="docutils literal"><span class="pre">req</span></code> type is set as <code class="docutils literal"><span class="pre">kAddTo</span></code>,
indicating that a <code class="docutils literal"><span class="pre">+=</span></code> should be called.</p>
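<p>As a rough illustration of how an operator might honor <code class="docutils literal"><span class="pre">req</span></code>, the sketch below applies each write type to a single output element. It assumes plain <code class="docutils literal"><span class="pre">std::vector&lt;float&gt;</span></code> buffers instead of <code class="docutils literal"><span class="pre">TBlob</span></code>, and <code class="docutils literal"><span class="pre">AssignWithReq</span></code> and <code class="docutils literal"><span class="pre">ScaleByTwo</span></code> are hypothetical helpers, not MXNet functions.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the OpReqType enum shown above.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical helper: store `value` into `*out` according to `req`.
inline void AssignWithReq(float* out, OpReqType req, float value) {
  switch (req) {
    case kNullOp:
      break;  // output is not requested, leave it untouched
    case kWriteTo:
    case kWriteInplace:
      *out = value;  // overwrite the existing contents
      break;
    case kAddTo:
      *out += value;  // accumulate into the existing contents
      break;
  }
}

// Toy element-wise operator (out = 2 * in) that honors req for its single output.
// Assumes `out` already has the same size as `in`, as the system pre-allocates outputs.
inline void ScaleByTwo(const std::vector<float>& in, OpReqType req,
                       std::vector<float>* out) {
  for (std::size_t i = 0; i < in.size(); ++i) {
    AssignWithReq(&(*out)[i], req, 2.0f * in[i]);
  }
}
```

<p>With <code class="docutils literal"><span class="pre">kAddTo</span></code>, repeated calls keep adding to the same buffer, which is the behavior gradient accumulation relies on.</p>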
<ul class="simple">
<li><code class="docutils literal"><span class="pre">aux_states</span></code> is intentionally designed for auxiliary tensors used to help computation. Currently, it is unused.</li>
</ul>
<p>Aside from the <code class="docutils literal"><span class="pre">Forward</span></code> interface, you can optionally implement the <code class="docutils literal"><span class="pre">Backward</span></code> interface:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="kt">void</span> <span class="nf">Backward</span><span class="p">(</span><span class="k">const</span> <span class="n">OpContext</span> <span class="o">&amp;</span><span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">></span> <span class="o">&amp;</span><span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">></span> <span class="o">&amp;</span><span class="n">aux_states</span><span class="p">);</span>
</pre></div>
</div>
<p>This interface follows the same design principle as the <code class="docutils literal"><span class="pre">Forward</span></code> interface,
except that <code class="docutils literal"><span class="pre">out_grad</span></code>, <code class="docutils literal"><span class="pre">in_data</span></code>, and <code class="docutils literal"><span class="pre">out_data</span></code> are given,
and the operator computes <code class="docutils literal"><span class="pre">in_grad</span></code> as the result.
The naming strategy is similar to Torch’s convention,
and can be summarized in the following figure:</p>
<p>[input/output semantics figure]</p>
<p>Some operators might not require all of the following:
<code class="docutils literal"><span class="pre">out_grad</span></code>, <code class="docutils literal"><span class="pre">in_data</span></code> and <code class="docutils literal"><span class="pre">out_data</span></code>.
You can specify these dependencies with the <code class="docutils literal"><span class="pre">DeclareBackwardDependency</span></code> interface in <code class="docutils literal"><span class="pre">OperatorProperty</span></code>.</p>
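<p>To make the naming concrete, here is a small sketch of a ReLU-style operator whose backward pass needs only <code class="docutils literal"><span class="pre">out_grad</span></code> and <code class="docutils literal"><span class="pre">out_data</span></code>, which is the kind of reduced dependency described above. It is an illustration under simplifying assumptions: <code class="docutils literal"><span class="pre">Tensor</span></code> is a hypothetical alias for <code class="docutils literal"><span class="pre">std::vector&lt;float&gt;</span></code> standing in for <code class="docutils literal"><span class="pre">TBlob</span></code>, and the function names are not part of MXNet.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for TBlob: a flat buffer of floats.
using Tensor = std::vector<float>;

// Forward pass: out_data = max(in_data, 0).
inline void ReluForward(const Tensor& in_data, Tensor* out_data) {
  out_data->resize(in_data.size());
  for (std::size_t i = 0; i < in_data.size(); ++i) {
    (*out_data)[i] = in_data[i] > 0.0f ? in_data[i] : 0.0f;
  }
}

// Backward pass: in_grad = out_grad where the output was positive, else 0.
// Note that in_data is not consulted, so it need not be kept alive for backward.
inline void ReluBackward(const Tensor& out_grad, const Tensor& out_data,
                         Tensor* in_grad) {
  in_grad->resize(out_grad.size());
  for (std::size_t i = 0; i < out_grad.size(); ++i) {
    (*in_grad)[i] = out_data[i] > 0.0f ? out_grad[i] : 0.0f;
  }
}
```

<p>Because the backward pass never reads <code class="docutils literal"><span class="pre">in_data</span></code>, a system that knows this can release or reuse that memory earlier.</p>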
</div>
<div class="section" id="operator-property">
<span id="operator-property"></span><h3>Operator Property<a class="headerlink" href="#operator-property" title="Permalink to this headline"></a></h3>
<p>One convolution might have several implementations,
and you might want to switch among them to achieve the best performance.
Therefore, we separate the operator <em>semantic</em> interfaces
from the implementation interface (<code class="docutils literal"><span class="pre">Operator</span></code> class)
into the <code class="docutils literal"><span class="pre">OperatorProperty</span></code> class.
The <code class="docutils literal"><span class="pre">OperatorProperty</span></code> interface consists of:</p>
<ul class="simple">
<li><strong>InferShape:</strong></li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="kt">bool</span> <span class="nf">InferShape</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">*</span><span class="n">in_shape</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">*</span><span class="n">out_shape</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">*</span><span class="n">aux_shape</span><span class="p">)</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
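<p>As a rough illustration of how an implementation of this signature might behave, the sketch below infers shapes for a hypothetical fully connected layer. <code class="docutils literal"><span class="pre">Shape</span></code> and <code class="docutils literal"><span class="pre">InferFullyConnectedShape</span></code> are stand-ins invented for this example, not MXNet types, and filling in the unknown weight shape is an assumption of the sketch.</p>

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Stand-in for TShape (an assumption); an empty vector means "shape not yet known".
using Shape = std::vector<std::size_t>;

// Hypothetical shape inference for a fully connected layer with num_hidden outputs.
// in_shape holds {data, weight}: data is (batch, features),
// weight is (num_hidden, features).
bool InferFullyConnectedShape(std::size_t num_hidden,
                              std::vector<Shape> *in_shape,
                              std::vector<Shape> *out_shape) {
  if (in_shape->size() != 2) {
    throw std::invalid_argument("expected inputs: data, weight");
  }
  const Shape &data = (*in_shape)[0];
  if (data.empty()) return false;  // not enough information to infer shapes
  if (data.size() != 2) {
    throw std::invalid_argument("data must be 2-dimensional");
  }
  const Shape expected_weight = {num_hidden, data[1]};
  Shape &weight = (*in_shape)[1];
  if (weight.empty()) {
    weight = expected_weight;  // fill in the unknown input shape
  } else if (weight != expected_weight) {
    throw std::invalid_argument("weight shape is inconsistent with data");
  }
  out_shape->assign(1, Shape{data[0], num_hidden});
  return true;
}
```

<p>The three outcomes are a return value of <code class="docutils literal"><span class="pre">true</span></code> with the shapes filled in, a return value of <code class="docutils literal"><span class="pre">false</span></code> when the data shape is still unknown, and an exception when the shapes disagree.</p>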
<p>The <code class="docutils literal"><span class="pre">InferShape</span></code> interface has two purposes:</p>
<ul class="simple">
<li>Tell the system the size of each input and output tensor,
so it can allocate space for them before the <code class="docutils literal"><span class="pre">Forward</span></code> and <code class="docutils literal"><span class="pre">Backward</span></code> call.</li>
<li>Perform a size check to make sure that there isn’t an obvious error before running.
The shape in <code class="docutils literal"><span class="pre">in_shape</span></code> is set by the system
(from the <code class="docutils literal"><span class="pre">out_shape</span></code> of the previous operators).
It returns <code class="docutils literal"><span class="pre">false</span></code> when there is not enough information
to infer shapes or throws an error when the shape is inconsistent.</li>
<li><strong>Request Resources:</strong> Operations like <code class="docutils literal"><span class="pre">cudnnConvolutionForward</span></code> need a work space for computation.
If the system manages that work space, it can perform optimizations,
such as reusing the space across operators.
MXNet defines two interfaces to achieve this:</li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">ResourceRequest</span><span class="o">></span> <span class="n">ForwardResource</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_shape</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
<span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">ResourceRequest</span><span class="o">></span> <span class="n">BackwardResource</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TShape</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_shape</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
</pre></div>
</div>
<p>The <code class="docutils literal"><span class="pre">ResourceRequest</span></code> structure (in <code class="docutils literal"><span class="pre">resource.h</span></code>) currently contains only a type flag:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">ResourceRequest</span> <span class="p">{</span>
<span class="k">enum</span> <span class="n">Type</span> <span class="p">{</span>
<span class="n">kRandom</span><span class="p">,</span> <span class="c1">// get a mshadow::Random<xpu> object</span>
<span class="n">kTempSpace</span><span class="p">,</span> <span class="c1">// request temporary space</span>
<span class="p">};</span>
<span class="n">Type</span> <span class="n">type</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p>If <code class="docutils literal"><span class="pre">ForwardResource</span></code> and <code class="docutils literal"><span class="pre">BackwardResource</span></code> return non-empty arrays,
the system offers the corresponding resources through the <code class="docutils literal"><span class="pre">ctx</span></code> parameter
in the <code class="docutils literal"><span class="pre">Forward</span></code> and <code class="docutils literal"><span class="pre">Backward</span></code> interface of <code class="docutils literal"><span class="pre">Operator</span></code>.
To access those resources, write:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">auto</span> <span class="n">tmp_space_res</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">requested</span><span class="p">[</span><span class="n">kTempSpace</span><span class="p">].</span><span class="n">get_space</span><span class="p">(</span><span class="n">some_shape</span><span class="p">,</span> <span class="n">some_stream</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">rand_res</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">requested</span><span class="p">[</span><span class="n">kRandom</span><span class="p">].</span><span class="n">get_random</span><span class="p">(</span><span class="n">some_stream</span><span class="p">);</span>
</pre></div>
</div>
<p>For an example, see <code class="docutils literal"><span class="pre">src/operator/cudnn_convolution-inl.h</span></code>.</p>
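<p>The request-then-access flow can be sketched with simplified stand-in types. Everything below is an assumption made for illustration: the real <code class="docutils literal"><span class="pre">Resource</span></code> and <code class="docutils literal"><span class="pre">OpContext</span></code> types have different members, and <code class="docutils literal"><span class="pre">MakeContext</span></code> is a hypothetical helper that plays the role of the system.</p>

```cpp
#include <cstddef>
#include <vector>

// Simplified stand-ins for the types named above (assumptions, not MXNet's definitions).
struct ResourceRequest {
  enum Type { kRandom, kTempSpace };
  Type type;
};

struct Resource {
  ResourceRequest::Type type;
  std::vector<float> buffer;  // backing storage, reused between calls
  // Return scratch space that holds at least `size` floats.
  float *get_space(std::size_t size) {
    if (buffer.size() < size) buffer.resize(size);
    return buffer.data();
  }
};

struct OpContext {
  std::vector<Resource> requested;  // one entry per request, in request order
};

// What an operator property might return from ForwardResource.
std::vector<ResourceRequest> ForwardResource() {
  return {ResourceRequest{ResourceRequest::kTempSpace}};
}

// What the system might do before calling Forward: satisfy each request in order.
OpContext MakeContext(const std::vector<ResourceRequest> &requests) {
  OpContext ctx;
  for (const ResourceRequest &req : requests) {
    ctx.requested.push_back(Resource{req.type, {}});
  }
  return ctx;
}
```

<p>In this sketch the operator finds its temporary space at the position in <code class="docutils literal"><span class="pre">ctx.requested</span></code> that matches the position of the request it returned. Because the system owns the buffer, it is free to hand the same storage to another operator later.</p>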
<ul class="simple">
<li><strong>Backward dependency:</strong> Let’s look at two different operator signatures
(we name all of the arguments for demonstration purposes):</li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="kt">void</span> <span class="nf">FullyConnectedForward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">weight</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">FullyConnectedBackward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">weight</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_grad</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">PoolingForward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">PoolingBackward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_grad</span><span class="p">);</span>
</pre></div>
</div>
<p>Note that <code class="docutils literal"><span class="pre">out_data</span></code> in <code class="docutils literal"><span class="pre">FullyConnectedForward</span></code>
is not used by <code class="docutils literal"><span class="pre">FullyConnectedBackward</span></code>,
while <code class="docutils literal"><span class="pre">PoolingBackward</span></code> requires all of the arguments of <code class="docutils literal"><span class="pre">PoolingForward</span></code>.
Therefore, for <code class="docutils literal"><span class="pre">FullyConnectedForward</span></code>,
the <code class="docutils literal"><span class="pre">out_data</span></code> tensor can be safely freed once it has been consumed
because the backward function will not need it.
This provides a chance for the system to collect some tensors
as garbage as soon as possible.
To specify this situation, we provide an interface:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
</pre></div>
</div>
<p>Each <code class="docutils literal"><span class="pre">int</span></code> element of the argument vectors is an ID
that identifies a distinct array.
Let’s see how this interface specifies different dependencies
for <code class="docutils literal"><span class="pre">FullyConnected</span></code> and <code class="docutils literal"><span class="pre">Pooling</span></code>:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="n">FullyConnectedProperty</span><span class="o">::</span><span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span> <span class="c1">// NOTE: out_data[0] is NOT included</span>
<span class="p">}</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="n">PoolingProperty</span><span class="o">::</span><span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span>
<span class="p">}</span>
</pre></div>
</div>
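<p>To see what the system can do with the returned IDs, consider the sketch below. The free functions mirror the two property methods above, and <code class="docutils literal"><span class="pre">Releasable</span></code> is a hypothetical helper, not an MXNet function: it lists the tensors that <code class="docutils literal"><span class="pre">Backward</span></code> does not depend on.</p>

```cpp
#include <algorithm>
#include <vector>

// Free-function versions of the two dependency declarations above.
std::vector<int> FullyConnectedDeps(const std::vector<int> &out_grad,
                                    const std::vector<int> &in_data,
                                    const std::vector<int> &out_data) {
  (void)out_data;  // Backward does not read the forward output
  return {out_grad[0], in_data[0]};
}

std::vector<int> PoolingDeps(const std::vector<int> &out_grad,
                             const std::vector<int> &in_data,
                             const std::vector<int> &out_data) {
  return {out_grad[0], in_data[0], out_data[0]};
}

// Hypothetical helper: return the IDs in `candidates` that are absent from `deps`,
// that is, tensors Backward will never read.
std::vector<int> Releasable(const std::vector<int> &candidates,
                            const std::vector<int> &deps) {
  std::vector<int> result;
  for (int id : candidates) {
    if (std::find(deps.begin(), deps.end(), id) == deps.end()) {
      result.push_back(id);
    }
  }
  return result;
}
```

<p>With IDs 0, 1, and 2 assigned to <code class="docutils literal"><span class="pre">out_grad[0]</span></code>, <code class="docutils literal"><span class="pre">in_data[0]</span></code>, and <code class="docutils literal"><span class="pre">out_data[0]</span></code>, the fully connected declaration leaves ID 2 releasable, while the pooling declaration leaves nothing releasable.</p>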
<ul class="simple">
<li><strong>In place Option:</strong> To further save the cost of memory allocation,
you can use in-place updates.
They are appropriate for element-wise operations
when the input tensor and output tensor have the same shape.
You specify an in-place update with the following interface:</li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">void</span><span class="o">*>></span> <span class="n">ElewiseOpProperty</span><span class="o">::</span><span class="n">ForwardInplaceOption</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">void</span><span class="o">*></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span> <span class="p">{</span><span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]}</span> <span class="p">};</span>
<span class="p">}</span>
<span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">void</span><span class="o">*>></span> <span class="n">ElewiseOpProperty</span><span class="o">::</span><span class="n">BackwardInplaceOption</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">></span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">void</span><span class="o">*></span> <span class="o">&amp;</span><span class="n">in_grad</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">]}</span> <span class="p">};</span>
<span class="p">}</span>
</pre></div>
</div>
<p>This tells the system that the <code class="docutils literal"><span class="pre">in_data[0]</span></code> and <code class="docutils literal"><span class="pre">out_data[0]</span></code> tensors could share the same memory spaces during <code class="docutils literal"><span class="pre">Forward</span></code>, and so do <code class="docutils literal"><span class="pre">out_grad[0]</span></code> and <code class="docutils literal"><span class="pre">in_grad[0]</span></code> during <code class="docutils literal"><span class="pre">Backward</span></code>.</p>
<blockquote>
<div><strong>Important:</strong> Even if you use the preceding specification, it’s <em>not</em> guaranteed that the input and output tensors will share the same space. In fact, this is only a suggestion for the system, which makes the final decision. However, in either case, the decision is completely transparent to you, so the actual <code class="docutils literal"><span class="pre">Forward</span></code> and <code class="docutils literal"><span class="pre">Backward</span></code> implementation does not need to consider that.</div></blockquote>
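<p>One way a system could treat these suggestions as hints rather than commands is sketched below. The policy shown (share memory only when the operator is the last reader of the input) is an assumption chosen for illustration; it is not a description of how MXNet decides. <code class="docutils literal"><span class="pre">InplacePair</span></code> and <code class="docutils literal"><span class="pre">AcceptInplace</span></code> are hypothetical names.</p>

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// One in-place suggestion: the output may reuse the input's memory.
// Both tensors are identified by plain int IDs in this sketch.
using InplacePair = std::pair<int, int>;  // {input id, output id}

// Hypothetical policy: accept a suggestion only when exactly one operator
// still needs the input, so overwriting it cannot affect any other reader.
// remaining_readers[id] counts the operators that still read tensor `id`.
std::vector<InplacePair> AcceptInplace(const std::vector<InplacePair> &suggested,
                                       const std::vector<int> &remaining_readers) {
  std::vector<InplacePair> accepted;
  for (const InplacePair &p : suggested) {
    if (remaining_readers.at(static_cast<std::size_t>(p.first)) == 1) {
      accepted.push_back(p);
    }
  }
  return accepted;
}
```

<p>Suggestions that are rejected simply fall back to separate input and output buffers, which is why the operator code never has to know which case applies.</p>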
<ul class="simple">
<li><strong>Expose Operator to Python:</strong> Because of the restrictions of C++, you need to implement the following interfaces:</li>
</ul>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="c1">// initialize the property class from a list of key-value string pairs</span>
<span class="k">virtual</span> <span class="kt">void</span> <span class="nf">Init</span><span class="p">(</span><span class="k">const</span> <span class="n">vector</span><span class="o"><</span><span class="n">pair</span><span class="o"><</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">>></span> <span class="o">&amp;</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="c1">// return the parameters in a key-value string map</span>
<span class="k">virtual</span> <span class="n">map</span><span class="o"><</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">></span> <span class="n">GetParams</span><span class="p">()</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="c1">// return the name of arguments (for generating signature in python)</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o"><</span><span class="n">string</span><span class="o">></span> <span class="n">ListArguments</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the name of output values</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o"><</span><span class="n">string</span><span class="o">></span> <span class="n">ListOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the name of auxiliary states</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o"><</span><span class="n">string</span><span class="o">></span> <span class="n">ListAuxiliaryStates</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the number of output values</span>
<span class="k">virtual</span> <span class="kt">int</span> <span class="nf">NumOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the number of visible outputs</span>
<span class="k">virtual</span> <span class="kt">int</span> <span class="nf">NumVisibleOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</pre></div>
</div>
</div>
<div class="section" id="create-an-operator-from-the-operator-property">
<span id="create-an-operator-from-the-operator-property"></span><h3>Create an Operator from the Operator Property<a class="headerlink" href="#create-an-operator-from-the-operator-property" title="Permalink to this headline"></a></h3>
<p><code class="docutils literal"><span class="pre">OperatorProperty</span></code> includes all <em>semantic</em> attributes of an operation. It’s also responsible for creating the <code class="docutils literal"><span class="pre">Operator</span></code> pointer for actual computation.</p>
<div class="section" id="create-operator">
<span id="create-operator"></span><h4>Create Operator<a class="headerlink" href="#create-operator" title="Permalink to this headline"></a></h4>
<p>Implement the following interface in <code class="docutils literal"><span class="pre">OperatorProperty</span></code>:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">virtual</span> <span class="n">Operator</span><span class="o">*</span> <span class="nf">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</pre></div>
</div>
<p>For example:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">class</span> <span class="nc">ConvolutionOp</span> <span class="o">:</span> <span class="k">public</span> <span class="n">Operator</span> <span class="p">{</span>
<span class="k">public</span><span class="o">:</span>
<span class="kt">void</span> <span class="n">Forward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="kt">void</span> <span class="n">Backward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="p">};</span>
<span class="k">class</span> <span class="nc">ConvolutionOpProperty</span> <span class="o">:</span> <span class="k">public</span> <span class="n">OperatorProperty</span> <span class="p">{</span>
<span class="k">public</span><span class="o">:</span>
<span class="n">Operator</span><span class="o">*</span> <span class="n">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="k">new</span> <span class="n">ConvolutionOp</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">};</span>
</pre></div>
</div>
</div>
<div class="section" id="parametrize-operator">
<span id="parametrize-operator"></span><h4>Parametrize Operator<a class="headerlink" href="#parametrize-operator" title="Permalink to this headline"></a></h4>
<p>When implementing a convolution operator, you need to know the kernel size,
the stride size, padding size, and so on.
These parameters should be passed to the operator
before any <code class="docutils literal"><span class="pre">Forward</span></code> or <code class="docutils literal"><span class="pre">Backward</span></code> interface is called.
To do so, you could define a <code class="docutils literal"><span class="pre">ConvolutionParam</span></code> structure, as follows:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="cp">#include</span> <span class="cpf"><dmlc/parameter.h></span><span class="cp"></span>
<span class="k">struct</span> <span class="nl">ConvolutionParam</span> <span class="p">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o"><</span><span class="n">ConvolutionParam</span><span class="o">></span> <span class="p">{</span>
<span class="n">TShape</span> <span class="n">kernel</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">pad</span><span class="p">;</span>
<span class="kt">uint32_t</span> <span class="n">num_filter</span><span class="p">,</span> <span class="n">num_group</span><span class="p">,</span> <span class="n">workspace</span><span class="p">;</span>
<span class="kt">bool</span> <span class="n">no_bias</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p>Put it in <code class="docutils literal"><span class="pre">ConvolutionOpProperty</span></code>, and pass it to the operator class during construction:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">class</span> <span class="nc">ConvolutionOp</span> <span class="o">:</span> <span class="k">public</span> <span class="n">Operator</span> <span class="p">{</span>
<span class="k">public</span><span class="o">:</span>
<span class="n">ConvolutionOp</span><span class="p">(</span><span class="n">ConvolutionParam</span> <span class="n">p</span><span class="p">)</span><span class="o">:</span> <span class="n">param_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="p">{}</span>
<span class="kt">void</span> <span class="n">Forward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="kt">void</span> <span class="n">Backward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="k">private</span><span class="o">:</span>
<span class="n">ConvolutionParam</span> <span class="n">param_</span><span class="p">;</span>
<span class="p">};</span>
<span class="k">class</span> <span class="nc">ConvolutionOpProperty</span> <span class="o">:</span> <span class="k">public</span> <span class="n">OperatorProperty</span> <span class="p">{</span>
<span class="k">public</span><span class="o">:</span>
<span class="kt">void</span> <span class="n">Init</span><span class="p">(</span><span class="k">const</span> <span class="n">vector</span><span class="o"><</span><span class="n">pair</span><span class="o"><</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">>>&amp;</span> <span class="n">kwargs</span><span class="p">)</span> <span class="p">{</span>
<span class="c1">// initialize param_ using kwargs</span>
<span class="p">}</span>
<span class="n">Operator</span><span class="o">*</span> <span class="n">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="k">new</span> <span class="n">ConvolutionOp</span><span class="p">(</span><span class="n">param_</span><span class="p">);</span>
<span class="p">}</span>
<span class="k">private</span><span class="o">:</span>
<span class="n">ConvolutionParam</span> <span class="n">param_</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
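<p>The step marked "initialize param_ using kwargs" can be pictured as follows. This is a hand-written sketch under stated assumptions: <code class="docutils literal"><span class="pre">SimpleConvParam</span></code> and <code class="docutils literal"><span class="pre">ParseKwargs</span></code> are hypothetical, the struct keeps only scalar fields, and it does not derive from <code class="docutils literal"><span class="pre">dmlc::Parameter</span></code>, which is intended to take care of this parsing in real operators.</p>

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for ConvolutionParam (an assumption): scalar fields only.
struct SimpleConvParam {
  std::uint32_t num_filter = 0;
  std::uint32_t num_group = 1;
  bool no_bias = false;
};

// Fill the parameter struct from key-value string pairs, rejecting unknown keys.
SimpleConvParam ParseKwargs(
    const std::vector<std::pair<std::string, std::string>> &kwargs) {
  SimpleConvParam param;
  for (const auto &kv : kwargs) {
    if (kv.first == "num_filter") {
      param.num_filter = static_cast<std::uint32_t>(std::stoul(kv.second));
    } else if (kv.first == "num_group") {
      param.num_group = static_cast<std::uint32_t>(std::stoul(kv.second));
    } else if (kv.first == "no_bias") {
      param.no_bias =
          (kv.second == "true" || kv.second == "True" || kv.second == "1");
    } else {
      throw std::invalid_argument("unknown parameter: " + kv.first);
    }
  }
  return param;
}
```

<p>Keys that are absent keep their defaults, so a call that sets only <code class="docutils literal"><span class="pre">num_filter</span></code> and <code class="docutils literal"><span class="pre">no_bias</span></code> leaves <code class="docutils literal"><span class="pre">num_group</span></code> at 1.</p>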
</div>
<div class="section" id="register-the-operator-property-class-and-the-parameter-class-to-mxnet">
<span id="register-the-operator-property-class-and-the-parameter-class-to-mxnet"></span><h4>Register the Operator Property Class and the Parameter Class to MXNet<a class="headerlink" href="#register-the-operator-property-class-and-the-parameter-class-to-mxnet" title="Permalink to this headline"></a></h4>
<p>Use the following macros to register the parameter structure and the operator property class to MXNet:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="n">DMLC_REGISTER_PARAMETER</span><span class="p">(</span><span class="n">ConvolutionParam</span><span class="p">);</span>
<span class="n">MXNET_REGISTER_OP_PROPERTY</span><span class="p">(</span><span class="n">Convolution</span><span class="p">,</span> <span class="n">ConvolutionOpProperty</span><span class="p">);</span>
</pre></div>
</div>
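<p>For <code class="docutils literal"><span class="pre">MXNET_REGISTER_OP_PROPERTY</span></code>, the first argument is the operator name and the second is the property class name.</p>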
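<p>Conceptually, registration maps a name to a factory for the property class so that the class can later be created from a string. The sketch below shows that idea with a plain function template instead of a macro. It is not how the MXNet macros are implemented; <code class="docutils literal"><span class="pre">Registry</span></code>, <code class="docutils literal"><span class="pre">RegisterOpProperty</span></code>, <code class="docutils literal"><span class="pre">CreateProperty</span></code>, and the reduced <code class="docutils literal"><span class="pre">OperatorProperty</span></code> base are all assumptions made for the example.</p>

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

// Reduced stand-in for the property base class (an assumption of this sketch).
struct OperatorProperty {
  virtual ~OperatorProperty() = default;
  virtual std::string TypeString() const = 0;
};

struct ConvolutionOpProperty : OperatorProperty {
  std::string TypeString() const override { return "Convolution"; }
};

using PropertyFactory = std::function<std::unique_ptr<OperatorProperty>()>;

// Global name-to-factory table, created on first use.
std::map<std::string, PropertyFactory> &Registry() {
  static std::map<std::string, PropertyFactory> registry;
  return registry;
}

// Hypothetical helper that plays the role of the registration macro.
template <typename Property>
bool RegisterOpProperty(const std::string &name) {
  Registry()[name] = [] {
    return std::unique_ptr<OperatorProperty>(new Property);
  };
  return true;
}

// Create a property object by name; returns nullptr if the name is unknown.
std::unique_ptr<OperatorProperty> CreateProperty(const std::string &name) {
  auto it = Registry().find(name);
  if (it == Registry().end()) return nullptr;
  return it->second();
}
```

<p>A front end such as the Python binding only needs the name string to obtain a property object, which is why the name is registered alongside the class.</p>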
</div>
</div>
<div class="section" id="interface-summary">
<span id="interface-summary"></span><h3>Interface Summary<a class="headerlink" href="#interface-summary" title="Permalink to this headline"></a></h3>
<p>We’ve almost covered the entire interface required to define a new operator. Let’s do a recap:</p>
<ul class="simple">
<li>Use the <code class="docutils literal"><span class="pre">Operator</span></code> interface to write your computation logic (<code class="docutils literal"><span class="pre">Forward</span></code> and <code class="docutils literal"><span class="pre">Backward</span></code>).</li>
<li>Use the <code class="docutils literal"><span class="pre">OperatorProperty</span></code> interface to:<ul>
<li>Pass the parameter to the operator class (you can use the <code class="docutils literal"><span class="pre">Init</span></code> interface).</li>
<li>Create an operator using the <code class="docutils literal"><span class="pre">CreateOperator</span></code> interface.</li>
<li>Correctly implement the operator description interface, such as the names of arguments, etc.</li>
<li>Correctly implement the <code class="docutils literal"><span class="pre">InferShape</span></code> interface to set the output tensor shape.</li>
<li>[Optional] If additional resources are needed, check <code class="docutils literal"><span class="pre">ForwardResource</span></code> and <code class="docutils literal"><span class="pre">BackwardResource</span></code>.</li>
<li>[Optional] If <code class="docutils literal"><span class="pre">Backward</span></code> doesn’t need all of the input and output of <code class="docutils literal"><span class="pre">Forward</span></code>, check <code class="docutils literal"><span class="pre">DeclareBackwardDependency</span></code>.</li>
<li>[Optional] If in-place update is supported, check <code class="docutils literal"><span class="pre">ForwardInplaceOption</span></code> and <code class="docutils literal"><span class="pre">BackwardInplaceOption</span></code>.</li>
</ul>
</li>
<li>Register the <code class="docutils literal"><span class="pre">OperatorProperty</span></code> class and the parameter class.</li>
</ul>
</div>
</div>
<div class="section" id="unifying-the-ndarray-operator-and-symbolic-operator">
<span id="unifying-the-ndarray-operator-and-symbolic-operator"></span><h2>Unifying the NDArray Operator and Symbolic Operator<a class="headerlink" href="#unifying-the-ndarray-operator-and-symbolic-operator" title="Permalink to this headline"></a></h2>
<p>NDArray operations are similar to symbolic operations,
except that sometimes you can’t write in place to the operands
without a complete dependency graph.
However, the logic underlying NDArray and symbolic operations is almost identical.
<em>SimpleOp</em>, a new unified operator API,
unifies different invoking processes
and returns to the fundamental elements of operators.
Because most mathematical operators take one or two operands,
and more operands make dependency-related optimization useful,
the unified operator is specifically designed for unary and binary operations.</p>
<p>Consider the elements of an operation.
Ideally, you need only functions and derivatives
to describe an operation.
Let’s restrict that to the space of unary and binary operations.
How do we classify all operations to maximize the possibility
of in-place write optimization?
Note that you can separate functions by the number of operands.
Derivatives are a bit more complex.
To construct a dependency graph, you need to know whether the output value,
the input data, or neither is needed alongside the head gradient.
Gradient functions in the unified API are differentiated
by the types of operands they take for calculation.</p>
<p>Before you learn more about the SimpleOp interface,
we recommend that you review the
<a class="reference external" href="https://github.com/dmlc/mshadow/tree/master/guide">mshadow library guide</a>
because calculations will be done in the <code class="docutils literal"><span class="pre">mshadow::TBlob</span></code> structure.</p>
<p>In the following example, we’ll create an operator
functioning as a smooth l1 loss,
which is a mixture of l1 loss and l2 loss. The loss itself can be written as:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span> <span class="n">loss</span> <span class="o">=</span> <span class="n">outside_weight</span> <span class="o">.*</span> <span class="n">f</span><span class="p">(</span><span class="n">inside_weight</span> <span class="o">.*</span> <span class="p">(</span><span class="n">data</span> <span class="o">-</span> <span class="n">label</span><span class="p">))</span>
<span class="n">grad</span> <span class="o">=</span> <span class="n">outside_weight</span> <span class="o">.*</span> <span class="n">inside_weight</span> <span class="o">.*</span> <span class="n">f</span><span class="s1">'(inside_weight .* (data - label))</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">.*</span></code> stands for element-wise multiplication, and <code class="docutils literal"><span class="pre">f</span></code> and <code class="docutils literal"><span class="pre">f'</span></code> are the smooth l1 loss function and its derivative,
which we assume are available in <code class="docutils literal"><span class="pre">mshadow</span></code> for now.
At first glance, it seems impossible to implement
this particular loss as a unary or binary operator.
But we have automatic differentiation in symbolic execution,
which reduces the loss to <code class="docutils literal"><span class="pre">f</span></code> and <code class="docutils literal"><span class="pre">f'</span></code> directly.
This loss is no more complex than a <code class="docutils literal"><span class="pre">sin</span></code> or an <code class="docutils literal"><span class="pre">abs</span></code> function,
and can certainly be implemented as a unary operator.</p>
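<p>To make the unary view concrete, here is a minimal scalar sketch of <code class="docutils literal"><span class="pre">f</span></code> and <code class="docutils literal"><span class="pre">f'</span></code>. The text above does not define <code class="docutils literal"><span class="pre">f</span></code>, so this assumes the commonly used smooth l1 definition with a squared-scale parameter <code class="docutils literal"><span class="pre">sigma2</span></code>. The names <code class="docutils literal"><span class="pre">SmoothL1</span></code> and <code class="docutils literal"><span class="pre">SmoothL1Grad</span></code> are illustrative and are not the <code class="docutils literal"><span class="pre">mshadow</span></code> operators.</p>

```cpp
#include <cmath>

// Smooth l1 loss for one element, assuming the common definition with a
// squared-scale parameter sigma2 (an assumption; the text does not define f):
//   f(x) = 0.5 * sigma2 * x^2    if |x| < 1 / sigma2
//        = |x| - 0.5 / sigma2    otherwise
inline double SmoothL1(double x, double sigma2) {
  const double threshold = 1.0 / sigma2;
  if (std::fabs(x) < threshold) {
    return 0.5 * sigma2 * x * x;  // quadratic (l2-like) region near zero
  }
  return std::fabs(x) - 0.5 / sigma2;  // linear (l1-like) region
}

// Derivative f'(x) of the function above.
inline double SmoothL1Grad(double x, double sigma2) {
  const double threshold = 1.0 / sigma2;
  if (std::fabs(x) < threshold) {
    return sigma2 * x;
  }
  return x > 0.0 ? 1.0 : -1.0;
}
```

<p>Both functions take a single data operand plus one scalar, which is why the loss fits the unary operator form once the weights and the subtraction are handled by other symbols.</p>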
</div>
<div class="section" id="simpleop-the-unified-operator-api">
<span id="simpleop-the-unified-operator-api"></span><h2>SimpleOp: The Unified Operator API<a class="headerlink" href="#simpleop-the-unified-operator-api" title="Permalink to this headline"></a></h2>
<div class="section" id="define-shapes">
<span id="define-shapes"></span><h3>Define Shapes<a class="headerlink" href="#define-shapes" title="Permalink to this headline"></a></h3>
<p>The <code class="docutils literal"><span class="pre">mshadow</span></code> library requires explicit memory allocation.
As a consequence, all data shapes
must be provided before any calculation occurs.
Before we proceed with defining functions and gradients,
let’s check input data shape consistency and provide the output shape.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">typedef</span> <span class="nf">TShape</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryShapeFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TShape</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">);</span>
 <span class="k">typedef</span> <span class="nf">TShape</span> <span class="p">(</span><span class="o">*</span><span class="n">BinaryShapeFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TShape</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span>
 <span class="k">const</span> <span class="n">TShape</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">);</span>
</pre></div>
</div>
<p>You can use <code class="docutils literal"><span class="pre">mshadow::TShape</span></code> to check input data shape and designate output data shape.
If you don’t define this function, the default output shape is the same as the input shape.
In the case of a binary operator, the shapes of <code class="docutils literal"><span class="pre">lhs</span></code> and <code class="docutils literal"><span class="pre">rhs</span></code> are checked to be the same by default.</p>
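<p>The default behavior described above can be sketched without any MXNet headers. In this illustration, <code class="docutils literal"><span class="pre">Shape</span></code> is a hypothetical stand-in for <code class="docutils literal"><span class="pre">TShape</span></code>, the function names are made up for the example, and the <code class="docutils literal"><span class="pre">EnvArguments</span></code> parameter is left out for brevity.</p>

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for TShape: a list of dimension sizes.
using Shape = std::vector<std::size_t>;

// Default unary behavior: the output shape equals the input shape.
inline Shape DefaultUnaryShape(const Shape& src) {
  return src;
}

// Default binary behavior: both operands must have the same shape,
// and the output takes that shape.
inline Shape DefaultBinaryShape(const Shape& lhs, const Shape& rhs) {
  if (lhs != rhs) {
    throw std::invalid_argument("lhs and rhs must have the same shape");
  }
  return lhs;
}
```

<p>An operator whose output shape differs from its input, such as a reduction, would be the case where a custom shape function is needed.</p>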
<p>You can also use shape functions to check if any additional arguments and resources are present.
Refer to the additional usages of <code class="docutils literal"><span class="pre">EnvArguments</span></code> to accomplish this.</p>
<p>Before we start on our smooth l1 loss example, we define <code class="docutils literal"><span class="pre">XPU</span></code> as <code class="docutils literal"><span class="pre">cpu</span></code> or <code class="docutils literal"><span class="pre">gpu</span></code> in the header implementation
<code class="docutils literal"><span class="pre">smooth_l1_unary-inl.h</span></code> so that we can reuse the same code in <code class="docutils literal"><span class="pre">smooth_l1_unary.cc</span></code> and
<code class="docutils literal"><span class="pre">smooth_l1_unary.cu</span></code>.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="cp">#include</span> <span class="cpf"><mxnet/operator_util.h></span><span class="cp"></span>
<span class="cp">#if defined(__CUDACC__)</span>
<span class="cp">#define XPU gpu</span>
<span class="cp">#else</span>
<span class="cp">#define XPU cpu</span>
<span class="cp">#endif</span>
</pre></div>
</div>
<p>In our smooth l1 loss example, it’s okay to use the default behavior whereby the output has the same shape as the source.
Written explicitly, it is:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="kr">inline</span> <span class="n">TShape</span> <span class="nf">SmoothL1Shape_</span><span class="p">(</span><span class="k">const</span> <span class="n">TShape</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">)</span> <span class="p">{</span>
 <span class="k">return</span> <span class="n">TShape</span><span class="p">(</span><span class="n">src</span><span class="p">);</span>
 <span class="p">}</span>
</pre></div>
</div>
</div>
<div class="section" id="define-functions">
<span id="define-functions"></span><h3>Define Functions<a class="headerlink" href="#define-functions" title="Permalink to this headline"></a></h3>
<p>Create a unary or binary function with one output: <code class="docutils literal"><span class="pre">mshadow::TBlob</span></code>.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">BinaryFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
</pre></div>
</div>
<ul class="simple">
<li>Functions are differentiated by the types of input arguments.</li>
<li><code class="docutils literal"><span class="pre">RunContext</span> <span class="pre">ctx</span></code> contains information needed during runtime for execution.</li>
</ul>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">RunContext</span> <span class="p">{</span>
<span class="kt">void</span> <span class="o">*</span><span class="n">stream</span><span class="p">;</span> <span class="c1">// the stream of the device, can be NULL or Stream<gpu>* in GPU mode</span>
 <span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span> <span class="kr">inline</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">>*</span> <span class="n">get_stream</span><span class="p">();</span> <span class="c1">// get mshadow stream from Context</span>
 <span class="p">};</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">mshadow::Stream<xpu></span> <span class="pre">*s</span> <span class="pre">=</span> <span class="pre">ctx.get_stream<xpu>();</span></code> is an example of obtaining a stream from <code class="docutils literal"><span class="pre">ctx</span></code>.</p>
<ul class="simple">
<li><code class="docutils literal"><span class="pre">OpReqType</span> <span class="pre">req</span></code> denotes how computation results are written into <code class="docutils literal"><span class="pre">ret</span></code>.</li>
</ul>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span>
<span class="n">kNullOp</span><span class="p">,</span> <span class="c1">// no operation, do not write anything</span>
<span class="n">kWriteTo</span><span class="p">,</span> <span class="c1">// write gradient to provided space</span>
<span class="n">kWriteInplace</span><span class="p">,</span> <span class="c1">// perform an in-place write</span>
<span class="n">kAddTo</span> <span class="c1">// add to the provided space</span>
<span class="p">};</span>
</pre></div>
</div>
<p>A macro is defined in <code class="docutils literal"><span class="pre">operator_util.h</span></code> for a simplified use of <code class="docutils literal"><span class="pre">OpReqType</span></code>.
<code class="docutils literal"><span class="pre">ASSIGN_DISPATCH(out,</span> <span class="pre">req,</span> <span class="pre">exp)</span></code> checks <code class="docutils literal"><span class="pre">req</span></code> and performs an assignment.</p>
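<p>As a rough illustration of what checking <code class="docutils literal"><span class="pre">req</span></code> and performing an assignment amounts to, the sketch below writes the same dispatch as a plain function over <code class="docutils literal"><span class="pre">std::vector</span></code>. The helper <code class="docutils literal"><span class="pre">AssignByRequest</span></code> is hypothetical and is not the macro itself; it only mirrors the four request kinds listed in the enum.</p>

```cpp
#include <cstddef>
#include <vector>

// Same four request kinds as the OpReqType enum shown above.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Function-style illustration of request dispatch (hypothetical helper).
// Assumes value holds at least as many elements as *out.
inline void AssignByRequest(std::vector<float>* out, OpReqType req,
                            const std::vector<float>& value) {
  switch (req) {
    case kNullOp:
      break;  // leave the destination untouched
    case kWriteTo:
    case kWriteInplace:
      // Overwrite the destination with the computed values.
      for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] = value[i];
      break;
    case kAddTo:
      // Accumulate into the destination, as when gradients are summed.
      for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] += value[i];
      break;
  }
}
```

<p>Routing every write through one dispatch point means an operator body does not need separate code paths for overwriting and accumulating.</p>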
<p>In our smooth l1 loss example, we use <code class="docutils literal"><span class="pre">UnaryFunction</span></code> to define the function of this operator.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span>
<span class="kt">void</span> <span class="n">SmoothL1Forward_</span><span class="p">(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span> <span class="o">*</span><span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="p">;</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">expr</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span>
<span class="n">real_t</span> <span class="n">sigma2</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span><span class="p">;</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">ret</span><span class="o">-></span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span> <span class="n">out</span> <span class="o">=</span> <span class="n">ret</span><span class="o">-></span><span class="n">get</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span> <span class="n">in</span> <span class="o">=</span> <span class="n">src</span><span class="p">.</span><span class="n">get</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">ASSIGN_DISPATCH</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span>
<span class="n">F</span><span class="o"><</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">smooth_l1_loss</span><span class="o">></span><span class="p">(</span><span class="n">in</span><span class="p">,</span> <span class="n">ScalarExp</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">sigma2</span><span class="p">)));</span>
<span class="p">});</span>
<span class="p">}</span>
</pre></div>
</div>
<p>After obtaining <code class="docutils literal"><span class="pre">mshadow::Stream</span></code> from <code class="docutils literal"><span class="pre">RunContext</span></code>, we get <code class="docutils literal"><span class="pre">mshadow::Tensor</span></code> from <code class="docutils literal"><span class="pre">mshadow::TBlob</span></code>.
<code class="docutils literal"><span class="pre">mshadow::F</span></code> is a shortcut for creating an <code class="docutils literal"><span class="pre">mshadow</span></code> expression. The macro <code class="docutils literal"><span class="pre">MSHADOW_TYPE_SWITCH(type,</span> <span class="pre">DType,</span> <span class="pre">...)</span></code>
handles the details of different data types, and the macro <code class="docutils literal"><span class="pre">ASSIGN_DISPATCH(out,</span> <span class="pre">req,</span> <span class="pre">exp)</span></code> checks <code class="docutils literal"><span class="pre">OpReqType</span></code> and
performs actions accordingly. <code class="docutils literal"><span class="pre">sigma2</span></code> is a special parameter in this loss, which we will cover later.</p>
</div>
<div class="section" id="define-gradients-optional">
<span id="define-gradients-optional"></span><h3>Define Gradients (Optional)<a class="headerlink" href="#define-gradients-optional" title="Permalink to this headline"></a></h3>
<p>Create a gradient function with various types of inputs.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="c1">// depending only on out_grad</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT0</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="c1">// depending only on out_value</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT1</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OutputValue</span><span class="o">&amp;</span> <span class="n">out_value</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="c1">// depending only on in_data</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT2</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">Input0</span><span class="o">&amp;</span> <span class="n">in_data0</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
</pre></div>
</div>
<p>Gradient functions of binary operators have similar structures, except that <code class="docutils literal"><span class="pre">Input</span></code>, <code class="docutils literal"><span class="pre">TBlob</span></code>, and <code class="docutils literal"><span class="pre">OpReqType</span></code>
are doubled.</p>
<p><code class="docutils literal"><span class="pre">GradFunctionArgument</span></code></p>
<p><code class="docutils literal"><span class="pre">Input0</span></code>, <code class="docutils literal"><span class="pre">Input1</span></code>, <code class="docutils literal"><span class="pre">OutputValue</span></code>, and <code class="docutils literal"><span class="pre">OutputGrad</span></code> all share the structure of <code class="docutils literal"><span class="pre">GradFunctionArgument</span></code>,
which is defined as:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">GradFunctionArgument</span> <span class="p">{</span>
<span class="n">TBlob</span> <span class="n">data</span><span class="p">;</span>
 <span class="p">};</span>
</pre></div>
</div>
<p>In our smooth l1 loss example, the gradient is <code class="docutils literal"><span class="pre">f'(x)</span></code>,
which uses the input data for the gradient calculation,
so <code class="docutils literal"><span class="pre">UnaryGradFunctionT2</span></code> is suitable.
To apply the chain rule,
we also need to multiply <code class="docutils literal"><span class="pre">out_grad</span></code> from the top into the result written to <code class="docutils literal"><span class="pre">in_grad</span></code>.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span>
<span class="kt">void</span> <span class="n">SmoothL1BackwardUseIn_</span><span class="p">(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">Input0</span><span class="o">&amp;</span> <span class="n">in_data0</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span> <span class="o">*</span><span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="p">;</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">expr</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span>
<span class="n">real_t</span> <span class="n">sigma2</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span><span class="p">;</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">in_grad</span><span class="o">-></span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span> <span class="n">src</span> <span class="o">=</span> <span class="n">in_data0</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">get</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span> <span class="n">ograd</span> <span class="o">=</span> <span class="n">out_grad</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">get</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span> <span class="n">igrad</span> <span class="o">=</span> <span class="n">in_grad</span><span class="o">-></span><span class="n">get</span><span class="o"><</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">ASSIGN_DISPATCH</span><span class="p">(</span><span class="n">igrad</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span>
<span class="n">ograd</span> <span class="o">*</span> <span class="n">F</span><span class="o"><</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">smooth_l1_gradient</span><span class="o">></span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">ScalarExp</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(</span><span class="n">sigma2</span><span class="p">)));</span>
<span class="p">});</span>
<span class="p">}</span>
</pre></div>
</div>
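<p>The chain-rule step can be checked in isolation. The following standalone sketch computes <code class="docutils literal"><span class="pre">in_grad</span> <span class="pre">=</span> <span class="pre">out_grad</span> <span class="pre">*</span> <span class="pre">f'(in_data)</span></code> element by element and provides a central finite difference for comparison. It uses plain <code class="docutils literal"><span class="pre">std::vector</span></code> instead of <code class="docutils literal"><span class="pre">mshadow</span></code> tensors, assumes the common smooth l1 definition with <code class="docutils literal"><span class="pre">sigma2</span></code>, and all function names are illustrative.</p>

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Assumed smooth l1 definition with a squared-scale parameter sigma2.
inline double SmoothL1(double x, double sigma2) {
  if (std::fabs(x) < 1.0 / sigma2) return 0.5 * sigma2 * x * x;
  return std::fabs(x) - 0.5 / sigma2;
}

// Analytic derivative of SmoothL1.
inline double SmoothL1Grad(double x, double sigma2) {
  if (std::fabs(x) < 1.0 / sigma2) return sigma2 * x;
  return x > 0.0 ? 1.0 : -1.0;
}

// Chain rule: in_grad[i] = out_grad[i] * f'(in_data[i]).
// Assumes out_grad and in_data have the same length.
inline std::vector<double> SmoothL1Backward(const std::vector<double>& out_grad,
                                            const std::vector<double>& in_data,
                                            double sigma2) {
  std::vector<double> in_grad(in_data.size());
  for (std::size_t i = 0; i < in_data.size(); ++i) {
    in_grad[i] = out_grad[i] * SmoothL1Grad(in_data[i], sigma2);
  }
  return in_grad;
}

// Central finite difference, useful for sanity-checking the analytic gradient.
inline double NumericGrad(double x, double sigma2, double eps = 1e-6) {
  return (SmoothL1(x + eps, sigma2) - SmoothL1(x - eps, sigma2)) / (2.0 * eps);
}
```

<p>Comparing <code class="docutils literal"><span class="pre">SmoothL1Grad</span></code> against <code class="docutils literal"><span class="pre">NumericGrad</span></code> at a few points away from the kink is a cheap way to catch sign or scaling mistakes before registering a gradient.</p>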
</div>
<div class="section" id="register-simpleop-to-mxnet">
<span id="register-simpleop-to-mxnet"></span><h3>Register SimpleOp to MXNet<a class="headerlink" href="#register-simpleop-to-mxnet" title="Permalink to this headline"></a></h3>
<p>After creating the shape, function, and gradient, register them as both an NDArray operator and
a symbolic operator. To simplify this process, use the registration macro defined in <code class="docutils literal"><span class="pre">operator_util.h</span></code>.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="n">MXNET_REGISTER_SIMPLE_OP</span><span class="p">(</span><span class="n">Name</span><span class="p">,</span> <span class="n">DEV</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_shape_function</span><span class="p">(</span><span class="n">Shape</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_function</span><span class="p">(</span><span class="n">DEV</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">Function</span><span class="o"><</span><span class="n">XPU</span><span class="o">></span><span class="p">,</span> <span class="n">SimpleOpInplaceOption</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_gradient</span><span class="p">(</span><span class="n">DEV</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">Gradient</span><span class="o"><</span><span class="n">XPU</span><span class="o">></span><span class="p">,</span> <span class="n">SimpleOpInplaceOption</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"description"</span><span class="p">);</span>
</pre></div>
</div>
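<p>The chained calls in the registration macro follow a familiar pattern: each setter returns a reference to the entry so that calls can be strung together. The toy registry below illustrates only that pattern. It is not MXNet's registry, and <code class="docutils literal"><span class="pre">ToyOpEntry</span></code>, <code class="docutils literal"><span class="pre">RegisterToyOp</span></code>, <code class="docutils literal"><span class="pre">Square</span></code>, and <code class="docutils literal"><span class="pre">SquareGrad</span></code> are hypothetical names.</p>

```cpp
#include <map>
#include <string>

// Toy registry entry with chained setters (hypothetical; not MXNet's registry).
struct ToyOpEntry {
  std::string name;
  std::string description;
  double (*function)(double) = nullptr;
  double (*gradient)(double) = nullptr;

  // Each setter returns *this so calls can be chained.
  ToyOpEntry& set_function(double (*f)(double)) { function = f; return *this; }
  ToyOpEntry& set_gradient(double (*g)(double)) { gradient = g; return *this; }
  ToyOpEntry& describe(const std::string& text) { description = text; return *this; }
};

// Returns the entry for `name`, creating it on first use.
inline ToyOpEntry& RegisterToyOp(const std::string& name) {
  static std::map<std::string, ToyOpEntry> registry;
  ToyOpEntry& entry = registry[name];
  entry.name = name;
  return entry;
}

// Example forward function and gradient to register.
inline double Square(double x) { return x * x; }
inline double SquareGrad(double x) { return 2.0 * x; }
```

<p>With these definitions, <code class="docutils literal"><span class="pre">RegisterToyOp("square").set_function(Square).set_gradient(SquareGrad).describe("x</span> <span class="pre">squared")</span></code> fills in one entry in a single expression, which is the shape the real macro call takes.</p>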
<p><code class="docutils literal"><span class="pre">SimpleOpInplaceOption</span></code> is defined as:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">enum</span> <span class="n">SimpleOpInplaceOption</span> <span class="p">{</span>
<span class="n">kNoInplace</span><span class="p">,</span> <span class="c1">// do not allow inplace in arguments</span>
<span class="n">kInplaceInOut</span><span class="p">,</span> <span class="c1">// allow inplace in with out (unary)</span>
<span class="n">kInplaceOutIn</span><span class="p">,</span> <span class="c1">// allow inplace out_grad with in_grad (unary)</span>
<span class="n">kInplaceLhsOut</span><span class="p">,</span> <span class="c1">// allow inplace left operand with out (binary)</span>
<span class="n">kInplaceOutLhs</span> <span class="c1">// allow inplace out_grad with lhs_grad (binary)</span>
<span class="p">};</span>
</pre></div>
</div>
<p>In our example, the gradient function relies on the input data, so the forward function can’t write its output in
place over the input. The output gradient is no longer needed after the gradient computation, so the gradient can be written in place.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="n">MXNET_REGISTER_SIMPLE_OP</span><span class="p">(</span><span class="n">smooth_l1</span><span class="p">,</span> <span class="n">XPU</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_function</span><span class="p">(</span><span class="n">XPU</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">SmoothL1Forward_</span><span class="o"><</span><span class="n">XPU</span><span class="o">></span><span class="p">,</span> <span class="n">kNoInplace</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_gradient</span><span class="p">(</span><span class="n">XPU</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">SmoothL1BackwardUseIn_</span><span class="o"><</span><span class="n">XPU</span><span class="o">></span><span class="p">,</span> <span class="n">kInplaceOutIn</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_enable_scalar</span><span class="p">(</span><span class="nb">true</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Calculate Smooth L1 Loss(lhs, scalar)"</span><span class="p">);</span>
</pre></div>
</div>
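<p>To see why the forward function is registered with <code class="docutils literal"><span class="pre">kNoInplace</span></code>, consider what would happen if it overwrote its input. The scalar sketch below contrasts the two cases, assuming the common smooth l1 definition with <code class="docutils literal"><span class="pre">sigma2</span></code>. The helper names are illustrative only.</p>

```cpp
#include <cmath>

// Assumed smooth l1 definition with a squared-scale parameter sigma2.
inline double SmoothL1(double x, double sigma2) {
  if (std::fabs(x) < 1.0 / sigma2) return 0.5 * sigma2 * x * x;
  return std::fabs(x) - 0.5 / sigma2;
}

// Analytic derivative of SmoothL1.
inline double SmoothL1Grad(double x, double sigma2) {
  if (std::fabs(x) < 1.0 / sigma2) return sigma2 * x;
  return x > 0.0 ? 1.0 : -1.0;
}

// Forward writes to a separate buffer, so backward still sees the input x.
inline double GradWithSeparateOutput(double x, double sigma2) {
  double input = x;
  double output = SmoothL1(input, sigma2);  // forward result stored elsewhere
  (void)output;
  return SmoothL1Grad(input, sigma2);       // correct: f'(x)
}

// Forward overwrites its input, so backward reads f(x) where it expects x.
inline double GradAfterInplaceForward(double x, double sigma2) {
  double buffer = x;
  buffer = SmoothL1(buffer, sigma2);        // input is clobbered
  return SmoothL1Grad(buffer, sigma2);      // wrong: evaluates f'(f(x))
}
```

<p>For <code class="docutils literal"><span class="pre">x</span> <span class="pre">=</span> <span class="pre">0.5</span></code> and <code class="docutils literal"><span class="pre">sigma2</span> <span class="pre">=</span> <span class="pre">1</span></code>, the first helper returns 0.5 and the second returns 0.125, so an in-place forward pass would silently corrupt the gradient.</p>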
<p>Recall from the discussion of shape functions that, without <code class="docutils literal"><span class="pre">set_shape_function</span></code>, the default behavior forces the inputs
(for a binary operator) to have the same shape and gives the output that same shape. We’ll discuss <code class="docutils literal"><span class="pre">set_enable_scalar</span></code> later.</p>
</div>
<div class="section" id="ndarray-operator-summary">
<span id="ndarray-operator-summary"></span><h3>NDArray Operator Summary<a class="headerlink" href="#ndarray-operator-summary" title="Permalink to this headline"></a></h3>
<ul class="simple">
<li>Create a shape function for determining the output shape.</li>
<li>Create a function as the forward routine by choosing a suitable function type.</li>
<li>Create a gradient as the backward routine by choosing a suitable gradient type.</li>
<li>Register the operator using the registration process.</li>
</ul>
</div>
</div>
<div class="section" id="additional-information-on-simpleop">
<span id="additional-information-on-simpleop"></span><h2>Additional Information on SimpleOp<a class="headerlink" href="#additional-information-on-simpleop" title="Permalink to this headline"></a></h2>
<div class="section" id="using-simpleop-on-envarguments">
<span id="using-simpleop-on-envarguments"></span><h3>Using SimpleOp on EnvArguments<a class="headerlink" href="#using-simpleop-on-envarguments" title="Permalink to this headline"></a></h3>
<p>Some operations need a scalar input (such as a gradient scale), a set of keyword arguments
that control behavior, or temporary space to speed up calculations. <code class="docutils literal"><span class="pre">EnvArguments</span></code> provides these additional arguments and resources to make calculations more scalable
and efficient.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">EnvArguments</span> <span class="p">{</span>
<span class="n">real_t</span> <span class="n">scalar</span><span class="p">;</span> <span class="c1">// scalar argument, if enabled</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">></span> <span class="o">></span> <span class="n">kwargs</span><span class="p">;</span> <span class="c1">// keyword arguments</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">Resource</span><span class="o">></span> <span class="n">resource</span><span class="p">;</span> <span class="c1">// pointer to the resources requested</span>
<span class="p">};</span>
</pre></div>
</div>
<p>More registration parameters are required to enable these additional features. To avoid ambiguity between parameters, <code class="docutils literal"><span class="pre">scalar</span></code> and <code class="docutils literal"><span class="pre">kwargs</span></code>
can’t be enabled at the same time. To enable <code class="docutils literal"><span class="pre">scalar</span></code>, use
<code class="docutils literal"><span class="pre">set_enable_scalar(bool</span> <span class="pre">enable_scalar)</span></code> in registration. Then, in forward functions and gradients, the scalar is available as <code class="docutils literal"><span class="pre">env.scalar</span></code>, where <code class="docutils literal"><span class="pre">env</span></code> is the function parameter <code class="docutils literal"><span class="pre">EnvArguments</span> <span class="pre">env</span></code>.</p>
<p>To enable <code class="docutils literal"><span class="pre">kwargs</span></code>, use <code class="docutils literal"><span class="pre">set_enable_kwargs(bool</span> <span class="pre">enable_kwargs)</span></code> in registration. Then, in forward
functions and gradients, additional arguments are contained in <code class="docutils literal"><span class="pre">env.kwargs</span></code>, which is defined as
<code class="docutils literal"><span class="pre">std::vector<std::pair<std::string,</span> <span class="pre">std::string></span> <span class="pre">></span></code>. Use the DMLC parameter structure to
simplify parsing keyword arguments. For more details, see the <a class="reference external" href="https://github.com/dmlc/dmlc-core/blob/master/doc/parameter.md">guide on parameter structure</a>.</p>
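<p>To make the parsing step concrete, here is a minimal hand-written sketch of turning string key/value pairs into typed fields. It does not use the DMLC parameter structure, which automates this kind of work. The struct <code class="docutils literal"><span class="pre">ExampleParam</span></code>, the function <code class="docutils literal"><span class="pre">ParseKwargs</span></code>, and the keys <code class="docutils literal"><span class="pre">sigma</span></code> and <code class="docutils literal"><span class="pre">normalize</span></code> are hypothetical names chosen for illustration only.</p>

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Same container type as the kwargs field of EnvArguments.
using Kwargs = std::vector<std::pair<std::string, std::string> >;

// Hypothetical parameter set, used only for illustration.
struct ExampleParam {
  float sigma = 1.0f;      // default used when the key is absent
  bool normalize = false;  // default used when the key is absent
};

// Convert string key/value pairs into typed fields and reject unknown keys.
inline ExampleParam ParseKwargs(const Kwargs& kwargs) {
  ExampleParam param;
  for (const auto& kv : kwargs) {
    if (kv.first == "sigma") {
      // std::stof throws std::invalid_argument if no number can be parsed.
      param.sigma = std::stof(kv.second);
    } else if (kv.first == "normalize") {
      param.normalize = (kv.second == "true" || kv.second == "1");
    } else {
      throw std::invalid_argument("unknown keyword argument: " + kv.first);
    }
  }
  return param;
}
```

<p>A forward function or gradient could call such a helper on <code class="docutils literal"><span class="pre">env.kwargs</span></code> before doing any computation, so that malformed arguments fail early.</p>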
<p>Additional resources like <code class="docutils literal"><span class="pre">mshadow::Random<xpu></span></code> and temporary memory space can also be requested and
accessed from <code class="docutils literal"><span class="pre">EnvArguments.resource</span></code>. The registration routine is <code class="docutils literal"><span class="pre">set_resource_request(ResourceRequest</span> <span class="pre">req)</span></code>
or <code class="docutils literal"><span class="pre">set_resource_request(const</span> <span class="pre">std::vector<ResourceRequest>)</span></code>, where <code class="docutils literal"><span class="pre">mxnet::ResourceRequest</span></code> is defined as:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">struct</span> <span class="n">ResourceRequest</span> <span class="p">{</span>
<span class="k">enum</span> <span class="n">Type</span> <span class="p">{</span> <span class="c1">// Resource type, indicating what the pointer type is</span>
<span class="n">kRandom</span><span class="p">,</span> <span class="c1">// mshadow::Random<xpu> object</span>
<span class="n">kTempSpace</span> <span class="c1">// A dynamic temp space that can be arbitrary size</span>
<span class="p">};</span>
<span class="n">Type</span> <span class="n">type</span><span class="p">;</span> <span class="c1">// type of resources</span>
<span class="p">};</span>
</pre></div>
</div>
<p>Registration requests the declared resources from <code class="docutils literal"><span class="pre">mxnet::ResourceManager</span></code> and places them
in the <code class="docutils literal"><span class="pre">std::vector<Resource></span> <span class="pre">resource</span></code> field of <code class="docutils literal"><span class="pre">EnvArguments</span></code>. To access resources, use the following:</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">auto</span> <span class="n">tmp_space_res</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">resource</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">get_space</span><span class="p">(</span><span class="n">some_shape</span><span class="p">,</span> <span class="n">some_stream</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">rand_res</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">resource</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">get_random</span><span class="p">(</span><span class="n">some_stream</span><span class="p">);</span>
</pre></div>
</div>
<p>For an example, see <code class="docutils literal"><span class="pre">src/operator/loss_binary_op-inl.h</span></code>.</p>
<p>In our smooth l1 loss example, a scalar input is needed to mark the turning point of a loss function. Therefore,
in the registration process, we use <code class="docutils literal"><span class="pre">set_enable_scalar(true)</span></code>, and use <code class="docutils literal"><span class="pre">env.scalar</span></code> in function and gradient
declarations.</p>
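<p>The role of <code class="docutils literal"><span class="pre">env.scalar</span></code> in a forward routine can be sketched without MXNet as follows. <code class="docutils literal"><span class="pre">EnvArgumentsSketch</span></code> and <code class="docutils literal"><span class="pre">SmoothL1ForwardSketch</span></code> are hypothetical stand-ins, and plain <code class="docutils literal"><span class="pre">std::vector</span></code> replaces tensors. How the real operator derives <code class="docutils literal"><span class="pre">sigma2</span></code> from the scalar is not shown in this section, so the sketch simply passes the scalar through.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the scalar part of EnvArguments.
struct EnvArgumentsSketch {
  float scalar;  // scalar argument, if enabled
};

// Scalar case of smooth l1 loss; a is x, b is sigma2.
inline float SmoothL1(float a, float b) {
  if (a > 1.0f / b) return a - 0.5f / b;
  if (a < -1.0f / b) return -a - 0.5f / b;
  return 0.5f * a * a * b;
}

// Sketch of a forward routine: read the scalar from env, then map it
// over every element. Assumption: the scalar is used directly as sigma2.
inline std::vector<float> SmoothL1ForwardSketch(const std::vector<float>& src,
                                                const EnvArgumentsSketch& env) {
  assert(env.scalar > 0.0f);  // the turning point 1 / sigma2 must be positive
  std::vector<float> out(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    out[i] = SmoothL1(src[i], env.scalar);
  }
  return out;
}
```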
</div>
<div class="section" id="crafting-a-tensor-operation">
<span id="crafting-a-tensor-operation"></span><h3>Crafting a Tensor Operation<a class="headerlink" href="#crafting-a-tensor-operation" title="Permalink to this headline"></a></h3>
<p>Because computation uses the <code class="docutils literal"><span class="pre">mshadow</span></code> library and the functions we need are sometimes not readily available, we
can craft tensor operations in operator implementations. If a function is element-wise, you
can implement it as a <code class="docutils literal"><span class="pre">mxnet::op::mshadow_op</span></code>. For example, <code class="docutils literal"><span class="pre">src/operator/mshadow_op.h</span></code> contains many <code class="docutils literal"><span class="pre">mshadow_op</span></code> definitions.
An <code class="docutils literal"><span class="pre">mshadow_op</span></code> is an expression mapper that handles the scalar case of the desired function. For details, see the
<a class="reference external" href="https://github.com/dmlc/mshadow/tree/master/doc">mshadow expression API guide</a>.</p>
<p>If an operation can’t be done in an element-wise way, like the softmax loss and gradient, then you need to create a new tensor operation. You need to create it directly as both an <code class="docutils literal"><span class="pre">mshadow</span></code> function and an <code class="docutils literal"><span class="pre">mshadow::cuda</span></code>
function. For details, see the <code class="docutils literal"><span class="pre">mshadow</span></code> library. For an example, see <code class="docutils literal"><span class="pre">src/operator/roi_pooling.cc</span></code>.</p>
<p>In our smooth l1 loss example, we create two mappers, namely the scalar cases of smooth l1 loss and gradient.</p>
<div class="highlight-cpp"><div class="highlight"><pre><span></span> <span class="k">namespace</span> <span class="n">mshadow_op</span> <span class="p">{</span>
<span class="k">struct</span> <span class="n">smooth_l1_loss</span> <span class="p">{</span>
<span class="c1">// a is x, b is sigma2</span>
<span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="n">real_t</span> <span class="n">Map</span><span class="p">(</span><span class="n">real_t</span> <span class="n">a</span><span class="p">,</span> <span class="n">real_t</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="n">a</span> <span class="o">></span> <span class="mf">1.0f</span> <span class="o">/</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="n">a</span> <span class="o">-</span> <span class="mf">0.5f</span> <span class="o">/</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">a</span> <span class="o"><</span> <span class="o">-</span><span class="mf">1.0f</span> <span class="o">/</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="o">-</span><span class="n">a</span> <span class="o">-</span> <span class="mf">0.5f</span> <span class="o">/</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
<span class="k">return</span> <span class="mf">0.5f</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span>
<span class="p">};</span>
<span class="p">}</span>
</pre></div>
</div>
<p>The gradient, which can be found in <code class="docutils literal"><span class="pre">src/operator/smooth_l1_unary-inl.h</span></code>, is similar.</p>
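<p>For reference, the following standalone sketch shows what such a gradient looks like when derived by differentiating the forward mapper above branch by branch. It is not copied from the MXNet source file, and the free functions <code class="docutils literal"><span class="pre">SmoothL1Loss</span></code>, <code class="docutils literal"><span class="pre">SmoothL1Gradient</span></code>, and <code class="docutils literal"><span class="pre">NumericGradient</span></code> are hypothetical names used only for illustration.</p>

```cpp
#include <cassert>
#include <cmath>

// Forward mapper from the text, restated as a free function;
// a is x, b is sigma2.
inline float SmoothL1Loss(float a, float b) {
  if (a > 1.0f / b) return a - 0.5f / b;
  if (a < -1.0f / b) return -a - 0.5f / b;
  return 0.5f * a * a * b;
}

// Derivative of SmoothL1Loss with respect to a, taken branch by branch:
// the linear branches have slope +1 or -1, the quadratic branch has slope a * b.
inline float SmoothL1Gradient(float a, float b) {
  if (a > 1.0f / b) return 1.0f;
  if (a < -1.0f / b) return -1.0f;
  return a * b;
}

// Central finite difference, used only to cross-check the analytic gradient.
inline float NumericGradient(float a, float b, float eps = 1e-3f) {
  return (SmoothL1Loss(a + eps, b) - SmoothL1Loss(a - eps, b)) / (2.0f * eps);
}
```

<p>At the turning point <code class="docutils literal"><span class="pre">a</span> <span class="pre">=</span> <span class="pre">1</span> <span class="pre">/</span> <span class="pre">b</span></code> the quadratic branch gives <code class="docutils literal"><span class="pre">a</span> <span class="pre">*</span> <span class="pre">b</span> <span class="pre">=</span> <span class="pre">1</span></code>, so the gradient is continuous across the branches.</p>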
</div>
<div class="section" id="beyond-two-operands">
<span id="beyond-two-operands"></span><h3>Beyond Two Operands<a class="headerlink" href="#beyond-two-operands" title="Permalink to this headline"></a></h3>
<p>The new unified API is designed to cover the fundamentals of an operation. For operators that take more than two inputs,
produce more than one output, or need more features, see the original <a class="reference external" href="/versions/0.12.1/architecture/overview.html#operators-in-mxnet">Operator API</a>.</p>
</div>
</div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">MXNet System Architecture</a></li>
<li><a class="reference internal" href="#mxnet-system-components">MXNet System Components</a><ul>
<li><a class="reference internal" href="#execution-engine">Execution Engine</a><ul>
<li><a class="reference internal" href="#interface">Interface</a></li>
<li><a class="reference internal" href="#function">Function</a></li>
<li><a class="reference internal" href="#context">Context</a></li>
<li><a class="reference internal" href="#varhandle">VarHandle</a></li>
<li><a class="reference internal" href="#push-and-wait">Push and Wait</a></li>
<li><a class="reference internal" href="#save-object-creation-cost">Save Object Creation Cost</a></li>
<li><a class="reference internal" href="#api-reference">API Reference</a></li>
</ul>
</li>
<li><a class="reference internal" href="#operators-in-mxnet">Operators in MXNet</a><ul>
<li><a class="reference internal" href="#operator-interface">Operator Interface</a></li>
<li><a class="reference internal" href="#operator-property">Operator Property</a></li>
<li><a class="reference internal" href="#create-an-operator-from-the-operator-property">Create an Operator from the Operator Property</a><ul>
<li><a class="reference internal" href="#create-operator">Create Operator</a></li>
<li><a class="reference internal" href="#parametrize-operator">Parametrize Operator</a></li>
<li><a class="reference internal" href="#register-the-operator-property-class-and-the-parameter-class-to-mxnet">Register the Operator Property Class and the Parameter Class to MXNet</a></li>
</ul>
</li>
<li><a class="reference internal" href="#interface-summary">Interface Summary</a></li>
</ul>
</li>
<li><a class="reference internal" href="#unifying-the-ndarray-operator-and-symbolic-operator">Unifying the NDArray Operator and Symbolic Operator</a></li>
<li><a class="reference internal" href="#simpleop-the-unified-operator-api">SimpleOp: The Unified Operator API</a><ul>
<li><a class="reference internal" href="#define-shapes">Define Shapes</a></li>
<li><a class="reference internal" href="#define-functions">Define Functions</a></li>
<li><a class="reference internal" href="#define-gradients-optional">Define Gradients (Optional)</a></li>
<li><a class="reference internal" href="#register-simpleop-to-mxnet">Register SimpleOp to MXNet</a></li>
<li><a class="reference internal" href="#ndarray-operator-summary">NDArray Operator Summary</a></li>
</ul>
</li>
<li><a class="reference internal" href="#additional-information-on-simpleop">Additional Information on SimpleOp</a><ul>
<li><a class="reference internal" href="#using-simpleop-on-envarguments">Using SimpleOp on EnvArguments</a></li>
<li><a class="reference internal" href="#crafting-a-tensor-operation">Crafting a Tensor Operation</a></li>
<li><a class="reference internal" href="#beyond-two-operands">Beyond Two Operands</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../_static/js/search.js" type="text/javascript"></script>
<script src="../_static/js/navbar.js" type="text/javascript"></script>
<script src="../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../_static/js/copycode.js" type="text/javascript"></script>
<script src="../_static/js/page.js" type="text/javascript"></script>
<script src="../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>