<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang=" en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Apache MXNet System Architecture | Apache MXNet</title>
<meta name="generator" content="Jekyll v3.8.6" />
<meta property="og:title" content="Apache MXNet System Architecture" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/architecture/overview" />
<meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/architecture/overview" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"description":"A flexible and efficient library for deep learning.","headline":"Apache MXNet System Architecture","@type":"WebPage","url":"https://mxnet.apache.org/versions/1.9.1/api/architecture/overview","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
(function() {
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script>
<script src="/versions/1.9.1/assets/js/clipboard.js" defer></script>
<script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(document).ready(function () {
// HEADER OPACITY LOGIC
function opacity_header() {
var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")"
$('.site-header').css("background-color", value)
}
$(window).scroll(function () {
opacity_header()
})
opacity_header();
// MENU SELECTOR LOGIC
$('.page-link').each( function () {
if (window.location.href.includes(this.href)) {
$(this).addClass("page-current");
}
});
})
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.9.1/"><img
src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">1.9.1</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/1.9.1/get_started">Get Started</a>
<a class="page-link" href="/versions/1.9.1/features">Features</a>
<a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/1.9.1/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">1.9.1
<svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a href="/">master</a>
<a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Apache MXNet System Architecture</h1>
</header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<h3 style="text-transform: capitalize; padding-left:10px">architecture</h3>
<ul>
<li><a href="/versions/1.9.1/api/architecture/exception_handling">Exception Handling in Apache MXNet</a></li>
<li><a href="/versions/1.9.1/api/architecture/note_data_loading">Efficient Data Loaders</a></li>
<li><a href="/versions/1.9.1/api/architecture/note_engine">Dependency Engine</a></li>
<li><a href="/versions/1.9.1/api/architecture/note_memory">Memory Consumption</a></li>
<li><a href="/versions/1.9.1/api/architecture/overview">Apache MXNet System Architecture</a></li>
<li><a href="/versions/1.9.1/api/architecture/program_model">Deep Learning Programming Paradigm</a></li>
</ul>
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="mxnet-system-architecture">MXNet System Architecture</h1>
<p><img src="https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/mxnet/system/overview.png" alt="System Overview"></p>
<p>This figure shows the major modules and components of the MXNet system and their interaction. The modules are:</p>
<ul>
<li>Runtime Dependency Engine: Schedules and executes the
operations according to their read/write dependency.</li>
<li>Storage Allocator: Efficiently allocates and recycles memory blocks
on host (CPU) and devices (GPUs).</li>
<li>Resource Manager: Manages global resources, such as the random number generator
and temporary space.</li>
<li>NDArray: Dynamic, asynchronous n-dimensional arrays,
which provide flexible imperative programming in MXNet.</li>
<li>Symbolic Execution: Static symbolic graph executor,
which provides efficient symbolic graph execution and optimization.</li>
<li>Operator: Operators that define static forward and gradient
calculation (backprop).</li>
<li>SimpleOp: Operators that extend NDArray operators and symbolic operators
in a unified fashion.</li>
<li>Symbol Construction: Symbolic construction, which provides a way to construct
a computation graph (net configuration).</li>
<li>KVStore: Key-value store interface for efficient parameter synchronization.</li>
<li>Data Loading (IO): Efficient distributed data loading and augmentation.</li>
</ul>
<h1 id="mxnet-system-components">MXNet System Components</h1>
<h2 id="execution-engine">Execution Engine</h2>
<p>You can use MXNet&#39;s engine not only for deep learning,
but for any domain-specific problem.
It&#39;s designed to solve a general problem:
execute a bunch of functions following their dependencies.
Execution of any two functions that depend on each other must be serialized.
To boost performance, functions with no dependencies <em>can</em> be executed in parallel.
For a general discussion of this topic,
see our <a href="note_engine">notes on the dependency engine</a>.</p>
<h3 id="interface">Interface</h3>
<p>The following API is the core interface for the execution engine:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="kt">void</span> <span class="n">PushSync</span><span class="p">(</span><span class="n">Fn</span> <span class="n">exec_fun</span><span class="p">,</span> <span class="n">Context</span> <span class="n">exec_ctx</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">VarHandle</span><span class="o">&gt;</span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">const_vars</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">VarHandle</span><span class="o">&gt;</span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">mutate_vars</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>This API allows you to push a function (<code>exec_fun</code>),
along with its context information and dependencies, to the engine.
<code>exec_ctx</code> is the context information in which the <code>exec_fun</code> should be executed,
<code>const_vars</code> denotes the variables that the function reads from,
and <code>mutate_vars</code> are the variables to be modified.
The engine provides the following guarantee:</p>
<blockquote>
<p><em>The execution of any two functions
that modify a common variable
is serialized in their push order.</em></p>
</blockquote>
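<p>For illustration, here is a minimal sketch of a push. The <code>engine</code> pointer, the variables, and the <code>CopyData</code> helper are hypothetical; the <code>Fn</code> and <code>Context</code> types are described in the following subsections:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Hypothetical sketch: copy `src` into `dst` through the engine.
  // `engine`, `src_var`, `dst_var`, and `CopyData` are assumed to exist.
  engine-&gt;PushSync([=](RunContext rctx) {
      CopyData(src, dst);               // the actual work happens here
    },
    Context::CPU(),                     // device on which to run
    {src_var},                          // const_vars: read-only dependencies
    {dst_var});                         // mutate_vars: variables being written
</code></pre></div>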
<h3 id="function">Function</h3>
<p>The function type of the engine is:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">using</span> <span class="n">Fn</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="n">RunContext</span><span class="p">)</span><span class="o">&gt;</span><span class="p">;</span>
</code></pre></div>
<p><code>RunContext</code> contains runtime information, which is determined by the engine:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">struct</span> <span class="n">RunContext</span> <span class="p">{</span>
<span class="c1">// stream pointer which could be safely cast to</span>
<span class="c1">// cudaStream_t* type</span>
<span class="kt">void</span> <span class="o">*</span><span class="n">stream</span><span class="p">;</span>
<span class="p">};</span>
</code></pre></div>
<p>Alternatively, you could use <code>mxnet::engine::DAGEngine::Fn</code>, which has the same type definition.</p>
<p>All of the functions are executed by the engine&#39;s internal threads.
In this model, it&#39;s usually a bad idea to push <em>blocking</em> functions
to the engine (typically functions that deal with I/O tasks such as disk access, web services, or UIs),
because they would occupy an execution thread and reduce total throughput.
For such cases, we provide another <em>asynchronous</em> function type:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">using</span> <span class="n">Callback</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">()</span><span class="o">&gt;</span><span class="p">;</span>
<span class="k">using</span> <span class="n">AsyncFn</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">function</span><span class="o">&lt;</span><span class="kt">void</span><span class="p">(</span><span class="n">RunContext</span><span class="p">,</span> <span class="n">Callback</span><span class="p">)</span><span class="o">&gt;</span><span class="p">;</span>
</code></pre></div>
<p>In the <code>AsyncFn</code> function, you can pass the heavy part to your own threads
and safely exit the body of the function.
The engine doesn&#39;t consider the function finished
until the <code>Callback</code> function is called.</p>
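<p>As a sketch, an <code>AsyncFn</code> that offloads blocking work to its own thread might look like the following (the I/O itself is left abstract):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  #include &lt;thread&gt;

  // Sketch: hand the heavy work to a separate thread and return immediately;
  // the engine keeps the operation pending until `on_complete` is invoked.
  AsyncFn read_async = [](RunContext rctx, Callback on_complete) {
    std::thread([on_complete]() {
      // ... blocking I/O happens here, off the engine's threads ...
      on_complete();  // tell the engine the function has finished
    }).detach();
  };
</code></pre></div>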
<h3 id="context">Context</h3>
<p>You can specify the <code>Context</code> of the function to be executed within.
This usually includes whether the function should be run on a CPU or a GPU,
and if you specify a GPU, which GPU to use.
<code>Context</code> is different from <code>RunContext</code>.
<code>Context</code> contains device type (GPU/CPU) and device id,
while <code>RunContext</code> contains information that can be decided only during runtime,
for example, on which stream the function should be executed.</p>
<h3 id="varhandle">VarHandle</h3>
<p><code>VarHandle</code> is used to specify the dependencies of functions.
The MXNet engine is designed to be decoupled from other MXNet modules.
So <code>VarHandle</code> is like an engine-provided token you use
to represent the external resources the functions can use or modify.
It&#39;s designed to be lightweight, so creating,
deleting, or copying a variable incurs little overhead.
Upon pushing the functions, you need to specify the variables
that will be used (immutable) in the <code>const_vars</code> vector,
and the variables that will be modified (mutable) in the <code>mutate_vars</code> vector.
The engine uses one rule for resolving the dependencies among functions:</p>
<blockquote>
<p><em>The execution of any two functions when one of them modifies at least one common variable is serialized in their push order.</em></p>
</blockquote>
<p>For example, if <code>Fn1</code> and <code>Fn2</code> both mutate <code>V2</code> then <code>Fn2</code>
is guaranteed to be executed after <code>Fn1</code>
if <code>Fn2</code> is pushed after <code>Fn1</code>.
On the other hand, if <code>Fn1</code> and <code>Fn2</code> both use <code>V2</code>,
their actual execution order is arbitrary.</p>
<p>This design allows the engine to schedule <em>state-mutating</em> operations in a manner
that minimizes calls to allocate new memory.
For example, the weight update function in DNN
can now use the <code>+=</code> operator
to update the weights in place,
rather than generating a new weight array each time.</p>
<p>To create a variable, use the <code>NewVar()</code> API.
To delete a variable, use the <code>PushDelete</code> API.</p>
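<p>The following sketch shows a variable&#39;s lifecycle end to end; the <code>engine</code> pointer is assumed, and the exact <code>PushDelete</code> signature is an assumption for illustration:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch of a variable's lifecycle, assuming an `engine` pointer.
  VarHandle var = engine-&gt;NewVar();   // create a dependency token
  engine-&gt;PushSync([](RunContext rctx) {
      // ... initialize the resource represented by `var` ...
    }, Context::CPU(), {}, {var});
  engine-&gt;PushDelete([](RunContext rctx) {
      // ... release the resource ...
    }, Context::CPU(), var);          // runs after all users of `var` finish
</code></pre></div>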
<h3 id="push-and-wait">Push and Wait</h3>
<p><em>All <code>Push</code> APIs are asynchronous.</em> The API call returns immediately
regardless of whether the pushed <code>Fn</code> is finished or not.
This allows the engine to start computing at the same time
as the user thread is pushing functions.
<code>Push</code> APIs are not thread-safe.
To be specific, only one thread should make engine API calls at a time.</p>
<p>If you want to wait for a specific <code>Fn</code> to finish,
include a callback function in the closure,
and call the function at the end of your <code>Fn</code>.</p>
<p>If you want to wait for all <code>Fn</code>s
that involve (use or mutate) a certain variable to finish,
use the <code>WaitForVar(var)</code> API.</p>
<p>If you want to wait for all pushed <code>Fn</code>s to finish,
use the <code>WaitForAll()</code> API.</p>
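<p>The two engine-level waits, sketched together (<code>engine</code> and <code>var</code> are assumed to exist):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch: synchronizing the user thread with the engine.
  engine-&gt;WaitForVar(var);  // blocks until every pushed Fn that uses or
                            // mutates `var` has completed
  engine-&gt;WaitForAll();     // blocks until all pushed Fns have completed
</code></pre></div>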
<h3 id="save-object-creation-cost">Save Object Creation Cost</h3>
<p>In some cases, you push the same functions to the engine repeatedly over a long period of time.
If the computation of these functions is light,
the overhead of copying lambdas and creating use/mutate variable lists becomes relatively high.
We provide an API to create an <code>OprHandle</code> beforehand:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="n">OprHandle</span> <span class="n">NewOperator</span><span class="p">(</span><span class="n">AsyncFn</span> <span class="n">fn</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">VarHandle</span><span class="o">&gt;</span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">const_vars</span><span class="p">,</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">VarHandle</span><span class="o">&gt;</span> <span class="k">const</span><span class="o">&amp;</span> <span class="n">mutate_vars</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>You can keep pushing the <code>OprHandle</code> without repeatedly creating them:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="kt">void</span> <span class="n">Push</span><span class="p">(</span><span class="n">OprHandle</span> <span class="n">op</span><span class="p">,</span> <span class="n">Context</span> <span class="n">exec_ctx</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>To delete it, call the <code>DeleteOperator(OprHandle op)</code> API.
Ensure that the operator has finished computing before calling this API.</p>
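<p>A sketch of the whole pattern (the function, variables, and loop bound are illustrative):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch: create the operator once, then push it many times.
  OprHandle op = engine-&gt;NewOperator(my_async_fn, {in_var}, {out_var});
  for (int i = 0; i &lt; num_batches; ++i) {
    engine-&gt;Push(op, Context::CPU());  // no per-push lambda or vector copies
  }
  engine-&gt;WaitForAll();                // ensure `op` has finished computing...
  engine-&gt;DeleteOperator(op);          // ...before deleting it
</code></pre></div>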
<h2 id="operators-in-mxnet">Operators in MXNet</h2>
<p>In MXNet, an operator is a class that contains both actual computation logic
and auxiliary information that can aid the system in performing optimizations,
such as in-place updates and automatic differentiation.
To understand the remainder of the document,
we recommend that you familiarize yourself with the <code>mshadow</code> library,
because all operators compute on the tensor-like structure <code>mshadow::TBlob</code>
provided by the system during runtime.</p>
<p>MXNet&#39;s operator interface allows you to:</p>
<ul>
<li>Reduce memory allocation cost by specifying in-place updates.</li>
<li>Hide some internal arguments from Python to keep the interface clean.</li>
<li>Define the relationships among input tensors and output tensors,
which allows the system to perform shape checking for you.</li>
<li>Acquire additional temporary spaces from the system
to perform computation (e.g., calling <code>cudnn</code> routines).</li>
</ul>
<h3 id="operator-interface">Operator Interface</h3>
<p><code>Forward</code> is the core operator interface:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="kt">void</span> <span class="n">Forward</span><span class="p">(</span><span class="k">const</span> <span class="n">OpContext</span> <span class="o">&amp;</span><span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">aux_states</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>The <code>OpContext</code> structure is:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">struct</span> <span class="n">OpContext</span> <span class="p">{</span>
<span class="kt">int</span> <span class="n">is_train</span><span class="p">;</span>
<span class="n">RunContext</span> <span class="n">run_ctx</span><span class="p">;</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">Resource</span><span class="o">&gt;</span> <span class="n">requested</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div>
<p>It describes whether the operator is in the train or test phase,
which device the operator should be run on (in <code>run_ctx</code>),
and requested resources (covered in the following sections).</p>
<ul>
<li><code>in_data</code> and <code>out_data</code> represent the input and output tensors, respectively.
All of the tensor spaces have been allocated by the system.</li>
<li><p><code>req</code> denotes how the computation results are written into the <code>out_data</code>.
In other words, <code>req.size() == out_data.size()</code>, and <code>req[i]</code>
corresponds to the write type of <code>out_data[i]</code>.</p></li>
<li><p>The <code>OpReqType</code> is defined as:</p></li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span>
<span class="n">kNullOp</span><span class="p">,</span>
<span class="n">kWriteTo</span><span class="p">,</span>
<span class="n">kWriteInplace</span><span class="p">,</span>
<span class="n">kAddTo</span>
<span class="p">};</span>
</code></pre></div>
<p>Normally, the types of all <code>out_data</code> elements should be <code>kWriteTo</code>,
meaning that the provided <code>out_data</code> tensor is a <em>raw</em> memory block,
so the operator should write results directly into it.
In some cases, for example when calculating the <code>gradient</code> tensor,
it is better to accumulate the result into the tensor,
rather than directly overwrite its contents,
so that no extra space needs to be allocated each time.
In such cases, the corresponding <code>req</code> type is set to <code>kAddTo</code>,
indicating that a <code>+=</code> should be performed.</p>
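<p>The following sketch shows how an operator body might honor <code>req[i]</code> when producing output <code>i</code>; here <code>out</code> and <code>result</code> are stand-ins for the output tensor and the computed expression:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch: dispatch on the write type requested for output i.
  switch (req[i]) {
    case kNullOp:       break;                 // skip this output entirely
    case kWriteTo:
    case kWriteInplace: out  = result; break;  // overwrite the raw block
    case kAddTo:        out += result; break;  // accumulate, e.g. for gradients
  }
</code></pre></div>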
<ul>
<li><code>aux_states</code> is intentionally designed for auxiliary tensors used to help computation. Currently, it is unused.</li>
</ul>
<p>Aside from the <code>Forward</code> operator, you could optionally implement the <code>Backward</code> interface:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="kt">void</span> <span class="nf">Backward</span><span class="p">(</span><span class="k">const</span> <span class="n">OpContext</span> <span class="o">&amp;</span><span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">aux_states</span><span class="p">);</span>
</code></pre></div>
<p>This interface follows the same design principle as the <code>Forward</code> interface,
except that <code>out_grad</code>, <code>in_data</code>, and <code>out_data</code> are given,
and the operator computes <code>in_grad</code> as the results.
The naming strategy is similar to Torch&#39;s convention,
and can be summarized in the following figure:</p>
<p>[input/output semantics figure]</p>
<p>Some operators might not require all of the following:
<code>out_grad</code>, <code>in_data</code> and <code>out_data</code>.
You can specify these dependencies with the <code>DeclareBackwardDependency</code> interface in <code>OperatorProperty</code>.</p>
<h3 id="operator-property">Operator Property</h3>
<p>A convolution might have several implementations,
and you might want to switch among them to achieve the best performance.
Therefore, we separate the operator <em>semantic</em> interfaces
from the implementation interface (<code>Operator</code> class)
into the <code>OperatorProperty</code> class.
The <code>OperatorProperty</code> interface consists of:</p>
<ul>
<li><strong>InferShape:</strong></li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="kt">bool</span> <span class="n">InferShape</span><span class="p">(</span><span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">*</span><span class="n">in_shape</span><span class="p">,</span>
<span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">*</span><span class="n">out_shape</span><span class="p">,</span>
<span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">*</span><span class="n">aux_shape</span><span class="p">)</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>This interface has two purposes:</p>
<ul>
<li>Tell the system the size of each input and output tensor,
so that it can allocate space for them before the <code>Forward</code> and <code>Backward</code> calls.</li>
<li>Perform a size check to make sure that there isn&#39;t an obvious error before running.
The shape in <code>in_shape</code> is set by the system
(from the <code>out_shape</code> of the previous operators).</li>
</ul>
<p><code>InferShape</code> returns <code>false</code> when there is not enough information
to infer shapes, and throws an error when the shapes are inconsistent.</p>
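<p>For example, shape inference for a fully connected layer might look roughly like this simplified sketch, where <code>num_hidden</code> stands for the layer&#39;s output dimension and is not part of the interface:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Simplified sketch of InferShape for a fully connected layer.
  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    const mxnet::TShape &amp;dshape = (*in_shape)[0];  // data: (batch, n_in)
    if (dshape.ndim() == 0) return false;  // not enough information yet
    // weight: (num_hidden, n_in); set it from, or check it against, the data shape
    SHAPE_ASSIGN_CHECK(*in_shape, 1, mshadow::Shape2(num_hidden, dshape[1]));
    out_shape-&gt;clear();
    out_shape-&gt;push_back(mshadow::Shape2(dshape[0], num_hidden));
    return true;
  }
</code></pre></div>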
<ul>
<li><strong>Request Resources:</strong> Operations like <code>cudnnConvolutionForward</code> need a work space for computation.
If the system manages the work space, it can perform optimizations,
such as reusing the space across operators.
MXNet defines two interfaces to achieve this:</li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">ResourceRequest</span><span class="o">&gt;</span> <span class="n">ForwardResource</span><span class="p">(</span>
<span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">&amp;</span><span class="n">in_shape</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
<span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">ResourceRequest</span><span class="o">&gt;</span> <span class="n">BackwardResource</span><span class="p">(</span>
<span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">&amp;</span><span class="n">in_shape</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
</code></pre></div>
<p>The <code>ResourceRequest</code> structure (in <code>resource.h</code>) currently contains only a type flag:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">struct</span> <span class="n">ResourceRequest</span> <span class="p">{</span>
<span class="k">enum</span> <span class="n">Type</span> <span class="p">{</span>
<span class="n">kRandom</span><span class="p">,</span> <span class="c1">// get a mshadow::Random&lt;xpu&gt; object</span>
<span class="n">kTempSpace</span><span class="p">,</span> <span class="c1">// request temporary space</span>
<span class="p">};</span>
<span class="n">Type</span> <span class="n">type</span><span class="p">;</span>
<span class="p">};</span>
</code></pre></div>
<p>If <code>ForwardResource</code> and <code>BackwardResource</code> return non-empty arrays,
the system offers the corresponding resources through the <code>ctx</code> parameter
in the <code>Forward</code> and <code>Backward</code> interface of <code>Operator</code>.
Basically, to access those resources, simply write:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">auto</span> <span class="n">tmp_space_res</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">requested</span><span class="p">[</span><span class="n">kTempSpace</span><span class="p">].</span><span class="n">get_space</span><span class="p">(</span><span class="n">some_shape</span><span class="p">,</span> <span class="n">some_stream</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">rand_res</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">requested</span><span class="p">[</span><span class="n">kRandom</span><span class="p">].</span><span class="n">get_random</span><span class="p">(</span><span class="n">some_stream</span><span class="p">);</span>
</code></pre></div>
<p>For an example, see <code>src/operator/cudnn_convolution-inl.h</code>.</p>
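<p>For instance, an operator that needs a scratch buffer during <code>Forward</code> could request it with a minimal sketch like this:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch: request one temporary-space buffer for the forward pass.
  std::vector&lt;ResourceRequest&gt; ForwardResource(
      const mxnet::ShapeVector &amp;in_shape) const override {
    return {ResourceRequest{ResourceRequest::kTempSpace}};
  }
</code></pre></div>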
<ul>
<li><strong>Backward dependency:</strong> Let&#39;s look at two different operator signatures
(we name all of the arguments for demonstration purposes):</li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="kt">void</span> <span class="nf">FullyConnectedForward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">weight</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">FullyConnectedBackward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">weight</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_grad</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">PoolingForward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">PoolingBackward</span><span class="p">(</span><span class="n">TBlob</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">TBlob</span> <span class="n">in_grad</span><span class="p">);</span>
</code></pre></div>
<p>Note that <code>out_data</code> in <code>FullyConnectedForward</code>
is not used by <code>FullyConnectedBackward</code>,
while <code>PoolingBackward</code> requires all of the arguments of <code>PoolingForward</code>.
Therefore, for <code>FullyConnectedForward</code>,
the <code>out_data</code> tensor can be safely freed once it has been consumed,
because the backward function does not need it.
This gives the system a chance to garbage-collect
such tensors as early as possible.
To specify this situation, we provide an interface:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span><span class="p">;</span>
</code></pre></div>
<p>The <code>int</code> element of the argument vector is an ID
to distinguish different arrays.
Let&#39;s see how this interface specifies different dependencies
for <code>FullyConnected</code> and <code>Pooling</code>:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">FullyConnectedProperty</span><span class="o">::</span><span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span> <span class="c1">// NOTE: out_data[0] is NOT included</span>
<span class="p">}</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">PoolingProperty</span><span class="o">::</span><span class="n">DeclareBackwardDependency</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span>
<span class="p">}</span>
</code></pre></div>
<ul>
<li><strong>In-place Option:</strong> To further save the cost of memory allocation,
you can use in-place updates.
They are appropriate for element-wise operations
when the input tensor and output tensor have the same shape.
You specify an in-place update with the following interface:</li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="kt">int</span><span class="p">,</span> <span class="kt">void</span><span class="o">*&gt;&gt;</span> <span class="n">ElewiseOpProperty</span><span class="o">::</span><span class="n">ForwardInplaceOption</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">void</span><span class="o">*&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span> <span class="p">{</span><span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]}</span> <span class="p">};</span>
<span class="p">}</span>
<span class="k">virtual</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="kt">int</span><span class="p">,</span> <span class="kt">void</span><span class="o">*&gt;&gt;</span> <span class="n">ElewiseOpProperty</span><span class="o">::</span><span class="n">BackwardInplaceOption</span><span class="p">(</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">in_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">out_data</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">void</span><span class="o">*&gt;</span> <span class="o">&amp;</span><span class="n">in_grad</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="p">{</span> <span class="p">{</span><span class="n">out_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">]}</span> <span class="p">}</span>
<span class="p">}</span>
</code></pre></div>
<p>This tells the system that the <code>in_data[0]</code> and <code>out_data[0]</code> tensors could share the same memory space during <code>Forward</code>, and that <code>out_grad[0]</code> and <code>in_grad[0]</code> could do the same during <code>Backward</code>.</p>
<blockquote>
<p><strong>Important:</strong> Even if you use the preceding specification, it&#39;s <em>not</em> guaranteed that the input and output tensors will share the same space. In fact, this is only a suggestion for the system, which makes the final decision. However, in either case, the decision is completely transparent to you, so the actual <code>Forward</code> and <code>Backward</code> implementation does not need to consider that.</p>
</blockquote>
<ul>
<li><strong>Expose Operator to Python:</strong> Because of the restrictions of C++, you need to implement the following interfaces:</li>
</ul>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="c1">// initial the property class from a list of key-value string pairs</span>
<span class="k">virtual</span> <span class="kt">void</span> <span class="n">Init</span><span class="p">(</span><span class="k">const</span> <span class="n">vector</span><span class="o">&lt;</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">&gt;&gt;</span> <span class="o">&amp;</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="c1">// return the parameters in a key-value string map</span>
<span class="k">virtual</span> <span class="n">map</span><span class="o">&lt;</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">&gt;</span> <span class="n">GetParams</span><span class="p">()</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="c1">// return the name of arguments (for generating signature in python)</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o">&lt;</span><span class="n">string</span><span class="o">&gt;</span> <span class="n">ListArguments</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the name of output values</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o">&lt;</span><span class="n">string</span><span class="o">&gt;</span> <span class="n">ListOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the name of auxiliary states</span>
<span class="k">virtual</span> <span class="n">vector</span><span class="o">&lt;</span><span class="n">string</span><span class="o">&gt;</span> <span class="n">ListAuxiliaryStates</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the number of output values</span>
<span class="k">virtual</span> <span class="kt">int</span> <span class="n">NumOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
<span class="c1">// return the number of visible outputs</span>
<span class="k">virtual</span> <span class="kt">int</span> <span class="n">NumVisibleOutputs</span><span class="p">()</span> <span class="k">const</span><span class="p">;</span>
</code></pre></div>
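<p>For example, a convolution property might describe its arguments and outputs as follows (a sketch; the returned names determine the generated Python signature):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // Sketch: argument and output names for a convolution property.
  std::vector&lt;std::string&gt; ListArguments() const override {
    return {"data", "weight", "bias"};  // becomes the generated Python signature
  }
  std::vector&lt;std::string&gt; ListOutputs() const override {
    return {"output"};
  }
</code></pre></div>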
<h3 id="create-an-operator-from-the-operator-property">Create an Operator from the Operator Property</h3>
<p><code>OperatorProperty</code> includes all <em>semantic</em> attributes of an operation. It&#39;s also responsible for creating the <code>Operator</code> pointer for actual computation.</p>
<h4 id="create-operator">Create Operator</h4>
<p>Implement the following interface in <code>OperatorProperty</code>:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">virtual</span> <span class="n">Operator</span><span class="o">*</span> <span class="n">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
</code></pre></div>
<p>For example:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">class</span> <span class="nc">ConvolutionOp</span> <span class="p">{</span>
<span class="nl">public:</span>
<span class="kt">void</span> <span class="n">Forward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="kt">void</span> <span class="n">Backward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="p">};</span>
<span class="k">class</span> <span class="nc">ConvolutionOpProperty</span> <span class="o">:</span> <span class="k">public</span> <span class="n">OperatorProperty</span> <span class="p">{</span>
<span class="nl">public:</span>
<span class="n">Operator</span><span class="o">*</span> <span class="n">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="k">new</span> <span class="n">ConvolutionOp</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">};</span>
</code></pre></div>
<h4 id="parametrize-operator">Parametrize Operator</h4>
<p>When implementing a convolution operator, you need to know the kernel size,
the stride size, the padding size, and so on.
These parameters should be passed to the operator
before any <code>Forward</code> or <code>Backward</code> interface is called.
To do so, you could define a <code>ConvolutionParam</code> structure, as follows:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="cp">#include &lt;dmlc/parameter.h&gt;
</span> <span class="k">struct</span> <span class="n">ConvolutionParam</span> <span class="o">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o">&lt;</span><span class="n">ConvolutionParam</span><span class="o">&gt;</span> <span class="p">{</span>
<span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span> <span class="n">kernel</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">pad</span><span class="p">;</span>
<span class="kt">uint32_t</span> <span class="n">num_filter</span><span class="p">,</span> <span class="n">num_group</span><span class="p">,</span> <span class="n">workspace</span><span class="p">;</span>
<span class="kt">bool</span> <span class="n">no_bias</span><span class="p">;</span>
<span class="p">};</span>
</code></pre></div>
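<p>In practice, a <code>dmlc::Parameter</code> subclass also declares its fields so that <code>Init</code> can parse them from key-value strings, with defaults and descriptions. A rough sketch (the field names and descriptions here are illustrative):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  #include &lt;dmlc/parameter.h&gt;
  struct ConvolutionParam : public dmlc::Parameter&lt;ConvolutionParam&gt; {
    mxnet::TShape kernel;
    uint32_t num_filter;
    bool no_bias;
    // generates the parsing and defaulting logic used by Init()
    DMLC_DECLARE_PARAMETER(ConvolutionParam) {
      DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w)");
      DMLC_DECLARE_FIELD(num_filter).describe("Number of output filters");
      DMLC_DECLARE_FIELD(no_bias).set_default(false)
          .describe("Whether to disable the bias term");
    }
  };
</code></pre></div>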
<p>Put it in <code>ConvolutionOpProperty</code>, and pass it to the operator class during construction:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="k">class</span> <span class="nc">ConvolutionOp</span> <span class="p">{</span>
<span class="nl">public:</span>
<span class="n">ConvolutionOp</span><span class="p">(</span><span class="n">ConvolutionParam</span> <span class="n">p</span><span class="p">)</span><span class="o">:</span> <span class="n">param_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="p">{}</span>
<span class="kt">void</span> <span class="n">Forward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="kt">void</span> <span class="n">Backward</span><span class="p">(</span> <span class="p">...</span> <span class="p">)</span> <span class="p">{</span> <span class="p">...</span> <span class="p">}</span>
<span class="nl">private:</span>
<span class="n">ConvolutionParam</span> <span class="n">param_</span><span class="p">;</span>
<span class="p">};</span>
<span class="k">class</span> <span class="nc">ConvolutionOpProperty</span> <span class="o">:</span> <span class="k">public</span> <span class="n">OperatorProperty</span> <span class="p">{</span>
<span class="nl">public:</span>
<span class="kt">void</span> <span class="n">Init</span><span class="p">(</span><span class="k">const</span> <span class="n">vector</span><span class="o">&lt;</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">string</span><span class="p">,</span> <span class="n">string</span><span class="o">&gt;&amp;</span> <span class="n">kwargs</span><span class="p">)</span> <span class="p">{</span>
<span class="c1">// initialize param_ using kwargs</span>
<span class="p">}</span>
<span class="n">Operator</span><span class="o">*</span> <span class="n">CreateOperator</span><span class="p">(</span><span class="n">Context</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
<span class="k">return</span> <span class="k">new</span> <span class="n">ConvolutionOp</span><span class="p">(</span><span class="n">param_</span><span class="p">);</span>
<span class="p">}</span>
<span class="nl">private:</span>
<span class="n">ConvolutionParam</span> <span class="n">param_</span><span class="p">;</span>
<span class="p">};</span>
</code></pre></div>
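<p>To see how these pieces fit together, here is a hypothetical driver that mimics what the framework does with a registered property (the flow is real; the literal code is only an illustration):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">  // what happens, conceptually, when a user creates a Convolution node
  std::vector&lt;std::pair&lt;std::string, std::string&gt; &gt; kwargs =
      {{"kernel", "(3,3)"}, {"num_filter", "32"}};
  ConvolutionOpProperty prop;
  prop.Init(kwargs);                                   // parse kwargs into ConvolutionParam
  Operator* op = prop.CreateOperator(Context::CPU());  // build the device-specific operator
  // op-&gt;Forward( ... ); op-&gt;Backward( ... );
</code></pre></div>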
<h4 id="register-the-operator-property-class-and-the-parameter-class-to-mxnet">Register the Operator Property Class and the Parameter Class to MXNet</h4>
<p>Use the following macros to register the parameter structure and the operator property class to MXNet:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"> <span class="n">DMLC_REGISTER_PARAMETER</span><span class="p">(</span><span class="n">ConvolutionParam</span><span class="p">);</span>
<span class="n">MXNET_REGISTER_OP_PROPERTY</span><span class="p">(</span><span class="n">Convolution</span><span class="p">,</span> <span class="n">ConvolutionOpProperty</span><span class="p">);</span>
</code></pre></div>
<p>The first argument is the operator&#39;s name string; the second is the property class name.</p>
<h3 id="interface-summary">Interface Summary</h3>
<p>We&#39;ve almost covered the entire interface required to define a new operator. Let&#39;s do a recap:</p>
<ul>
<li>Use the <code>Operator</code> interface to write your computation logic (<code>Forward</code> and <code>Backward</code>).</li>
<li>Use the <code>OperatorProperty</code> interface to:
<ul>
<li>Pass the parameter to the operator class (you can use the <code>Init</code> interface).</li>
<li>Create an operator using the <code>CreateOperator</code> interface.</li>
<li>Correctly implement the operator description interfaces, such as the names of the arguments.</li>
<li>Correctly implement the <code>InferShape</code> interface to set the output tensor shape.</li>
<li>[Optional] If additional resources are needed, check <code>ForwardResource</code> and <code>BackwardResource</code>.</li>
<li>[Optional] If <code>Backward</code> doesn&#39;t need all of the input and output of <code>Forward</code>, check <code>DeclareBackwardDependency</code>.</li>
<li>[Optional] If in-place update is supported, check <code>ForwardInplaceOption</code> and <code>BackwardInplaceOption</code>.</li>
</ul></li>
<li>Register the <code>OperatorProperty</code> class and the parameter class.</li>
</ul>
<h2 id="unifying-the-ndarray-operator-and-symbolic-operator">Unifying the NDArray Operator and Symbolic Operator</h2>
<p>NDArray operations are similar to symbolic operations,
except that sometimes you can&#39;t write in place to the operands
without a complete dependency graph.
However, the logic underlying NDArray and symbolic operations is almost identical.
<em>SimpleOp</em>, a new unified operator API,
unifies the different invocation processes
and returns to the fundamental elements of operators.
Because most mathematical operators take one or two operands,
and because more operands make dependency-related optimization useful,
the unified operator is specifically designed for unary and binary operations.</p>
<p>Consider the elements of an operation.
Ideally, you need only functions and derivatives
to describe an operation.
Let&#39;s restrict that to the space of unary and binary operations.
How do we classify all operations to maximize the possibility
of in-place write optimization?
Note that you can separate functions by the number of operands.
Derivatives are a bit more complex.
To construct a dependency graph, you need to know whether the output value,
the input data, or neither is needed alongside the head gradient.
Gradient functions in the unified API are differentiated
by the types of operands they take for calculation.</p>
<p>Before you learn more about the SimpleOp interface,
we recommend that you review the
<a href="https://github.com/dmlc/mshadow/tree/master/guide">mshadow library guide</a>
because calculations will be done in the <code>mshadow::TBlob</code> structure.</p>
<p>In the following example, we&#39;ll create an operator
functioning as a smooth l1 loss,
which is a mixture of l1 loss and l2 loss. The loss itself can be written as:</p>
<div class="highlight"><pre><code class="language-" data-lang=""> loss = outside_weight .* f(inside_weight .* (data - label))
grad = outside_weight .* inside_weight .* f'(inside_weight .* (data - label))
</code></pre></div>
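<p>Concretely, the smooth l1 function <code>f</code>, with a scalar <code>sigma</code> controlling the turning point (matching the <code>Map</code> implementation shown later), is:</p>
<div class="highlight"><pre><code class="language-" data-lang=""> f(x) = 0.5 * (sigma * x)^2,   if |x| &lt; 1 / sigma^2
 f(x) = |x| - 0.5 / sigma^2,   otherwise
</code></pre></div>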
<p><code>.*</code> stands for element-wise multiplication, and <code>f</code> and <code>f&#39;</code> are the smooth l1 loss function and its derivative,
which we are assuming are available in <code>mshadow</code> for now.
At first glance, it&#39;s impossible to implement
this particular loss as a unary or binary operator.
But we have automatic differentiation in symbolic execution.
That simplifies the loss to <code>f</code> and <code>f&#39;</code> directly.
This loss is no more complex than a <code>sin</code> or an <code>abs</code> function,
and can certainly be implemented as a unary operator.</p>
<h2 id="simpleop-the-unified-operator-api">SimpleOp: The Unified Operator API</h2>
<h3 id="define-shapes">Define Shapes</h3>
<p>The <code>mshadow</code> library requires explicit memory allocation.
As a consequence, all data shapes
must be provided before any calculation occurs.
Before we proceed with defining functions and gradients,
let&#39;s check input data shape consistency and provide the output shape.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">typedef</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryShapeFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">);</span>
<span class="k">typedef</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span> <span class="p">(</span><span class="o">*</span><span class="n">BinaryShapeFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">);</span>
</code></pre></div>
<p>You can use <code>mxnet::TShape</code> to check the input data shapes and designate the output data shape.
If you don&#39;t define this function, the default output shape is the same as the input shape.
In the case of a binary operator, the shapes of <code>lhs</code> and <code>rhs</code> are checked to be the same by default.</p>
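<p>If you need different behavior, you can write the check yourself. A sketch of a hand-written binary shape function (the name <code>BinaryMatchShape_</code> is hypothetical; this merely reproduces the default):</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> inline mxnet::TShape BinaryMatchShape_(const mxnet::TShape&amp; lhs,
                                        const mxnet::TShape&amp; rhs,
                                        const EnvArguments&amp; env) {
   CHECK_EQ(lhs, rhs) &lt;&lt; "operands must have the same shape";
   return mxnet::TShape(lhs);  // output shape equals the common input shape
 }
</code></pre></div>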
<p>You can also use shape functions to check if any additional arguments and resources are present.
Refer to the additional usages of <code>EnvArguments</code> to accomplish this.</p>
<p>Before we start on our smooth l1 loss example, we define an <code>XPU</code> macro that resolves to <code>cpu</code> or <code>gpu</code> in the implementation header
<code>smooth_l1_unary-inl.h</code>, so that the same code can be reused in <code>smooth_l1_unary.cc</code> and
<code>smooth_l1_unary.cu</code>.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="cp">#include &lt;mxnet/operator_util.h&gt;
</span> <span class="cp">#if defined(__CUDACC__)
</span> <span class="cp">#define XPU gpu
</span> <span class="cp">#else
</span> <span class="cp">#define XPU cpu
</span> <span class="cp">#endif
</span></code></pre></div>
<p>In our smooth l1 loss example, it&#39;s okay to use the default behavior whereby the output has the same shape as the source.
Written explicitly, it is:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="kr">inline</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span> <span class="nf">SmoothL1Shape_</span><span class="p">(</span><span class="k">const</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">TShape</span><span class="p">(</span><span class="n">src</span><span class="p">);</span>
<span class="p">}</span>
</code></pre></div>
<h3 id="define-functions">Define Functions</h3>
<p>Create a unary or binary function with one output of type <code>mshadow::TBlob</code>.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">BinaryFunction</span><span class="p">)(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">lhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">rhs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
</code></pre></div>
<ul>
<li>Functions are differentiated by the types of input arguments.</li>
<li><code>RunContext ctx</code> contains information needed during runtime for execution.</li>
</ul>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">struct</span> <span class="n">RunContext</span> <span class="p">{</span>
<span class="kt">void</span> <span class="o">*</span><span class="n">stream</span><span class="p">;</span> <span class="c1">// the stream of the device, can be NULL or Stream&lt;gpu&gt;* in GPU mode</span>
<span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">&gt;</span> <span class="kr">inline</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o">&lt;</span><span class="n">xpu</span><span class="o">&gt;*</span> <span class="n">get_stream</span><span class="p">()</span> <span class="c1">// get mshadow stream from Context</span>
<span class="p">}</span> <span class="c1">// namespace mxnet</span>
</code></pre></div>
<p><code>mshadow::Stream&lt;xpu&gt; *s = ctx.get_stream&lt;xpu&gt;();</code> is an example of obtaining a stream from <code>ctx</code>.</p>
<ul>
<li><code>OpReqType req</code> denotes how computation results are written into <code>ret</code>.</li>
</ul>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span>
<span class="n">kNullOp</span><span class="p">,</span> <span class="c1">// no operation, do not write anything</span>
<span class="n">kWriteTo</span><span class="p">,</span> <span class="c1">// write gradient to provided space</span>
<span class="n">kWriteInplace</span><span class="p">,</span> <span class="c1">// perform an in-place write</span>
<span class="n">kAddTo</span> <span class="c1">// add to the provided space</span>
<span class="p">};</span>
</code></pre></div>
<p>A macro is defined in <code>operator_util.h</code> for a simplified use of <code>OpReqType</code>.
<code>ASSIGN_DISPATCH(out, req, exp)</code> checks <code>req</code> and performs an assignment.</p>
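<p>Conceptually, <code>ASSIGN_DISPATCH(out, req, exp)</code> behaves like the following helper (a simplified sketch, not the literal macro body):</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> template&lt;typename Out, typename Exp&gt;
 inline void AssignDispatchSketch(Out&amp; out, OpReqType req, const Exp&amp; exp) {
   switch (req) {
     case kNullOp: break;                    // write nothing
     case kWriteTo:
     case kWriteInplace: out = exp; break;   // overwrite the output
     case kAddTo: out += exp; break;         // accumulate into the output
   }
 }
</code></pre></div>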
<p>In our smooth l1 loss example, we use <code>UnaryFunction</code> to define the function of this operator.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">&gt;</span>
<span class="kt">void</span> <span class="nf">SmoothL1Forward_</span><span class="p">(</span><span class="k">const</span> <span class="n">TBlob</span><span class="o">&amp;</span> <span class="n">src</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span> <span class="o">*</span><span class="n">ret</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="p">;</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">expr</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o">&lt;</span><span class="n">xpu</span><span class="o">&gt;</span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o">&lt;</span><span class="n">xpu</span><span class="o">&gt;</span><span class="p">();</span>
<span class="n">real_t</span> <span class="n">sigma2</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span><span class="p">;</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">ret</span><span class="o">-&gt;</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span> <span class="n">out</span> <span class="o">=</span> <span class="n">ret</span><span class="o">-&gt;</span><span class="n">get</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span> <span class="n">in</span> <span class="o">=</span> <span class="n">src</span><span class="p">.</span><span class="n">get</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">ASSIGN_DISPATCH</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span>
<span class="n">F</span><span class="o">&lt;</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">smooth_l1_loss</span><span class="o">&gt;</span><span class="p">(</span><span class="n">in</span><span class="p">,</span> <span class="n">ScalarExp</span><span class="o">&lt;</span><span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">sigma2</span><span class="p">)));</span>
<span class="p">});</span>
<span class="p">}</span>
</code></pre></div>
<p>After obtaining <code>mshadow::Stream</code> from <code>RunContext</code>, we get <code>mshadow::Tensor</code> from <code>mshadow::TBlob</code>.
<code>mshadow::F</code> is a shortcut to initiate a <code>mshadow</code> expression. The macro <code>MSHADOW_TYPE_SWITCH(type, DType, ...)</code>
handles details on different types, and the macro <code>ASSIGN_DISPATCH(out, req, exp)</code> checks <code>OpReqType</code> and
performs actions accordingly. <code>sigma2</code> is a special parameter in this loss, which we will cover later.</p>
<h3 id="define-gradients-optional">Define Gradients (Optional)</h3>
<p>Create a gradient function with various types of inputs.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="c1">// depending only on out_grad</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT0</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="c1">// depending only on out_value</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT1</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OutputValue</span><span class="o">&amp;</span> <span class="n">out_value</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
<span class="c1">// depending only on in_data</span>
<span class="k">typedef</span> <span class="nf">void</span> <span class="p">(</span><span class="o">*</span><span class="n">UnaryGradFunctionT2</span><span class="p">)(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">Input0</span><span class="o">&amp;</span> <span class="n">in_data0</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">);</span>
</code></pre></div>
<p>Gradient functions of binary operators have similar structures, except that <code>Input</code>, <code>TBlob</code>, and <code>OpReqType</code>
are doubled.</p>
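<p>For example, a binary gradient that depends on both inputs would plausibly have this doubled signature (a sketch; check <code>operator_util.h</code> for the exact typedefs):</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> // depending on in_data of both operands (sketch)
 typedef void (*BinaryGradFunctionT2)(const OutputGrad&amp; out_grad,
                                      const Input0&amp; lhs,
                                      const Input1&amp; rhs,
                                      const EnvArguments&amp; env,
                                      TBlob* lhs_grad,
                                      TBlob* rhs_grad,
                                      OpReqType req_lhs_grad,
                                      OpReqType req_rhs_grad,
                                      RunContext ctx);
</code></pre></div>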
<p><code>GradFunctionArgument</code></p>
<p><code>Input0</code>, <code>Input1</code>, <code>OutputValue</code>, and <code>OutputGrad</code> all share the structure of <code>GradFunctionArgument</code>,
which is defined as:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">struct</span> <span class="n">GradFunctionArgument</span> <span class="p">{</span>
<span class="n">TBlob</span> <span class="n">data</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div>
<p>In our smooth l1 loss example, the gradient is an <code>f&#39;(x)</code>
that uses the input for its calculation,
so <code>UnaryGradFunctionT2</code> is suitable.
To enable the chain rule of the gradient,
we also need to multiply the result by <code>out_grad</code> from the top before writing it to <code>in_grad</code>.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">template</span><span class="o">&lt;</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">&gt;</span>
<span class="kt">void</span> <span class="nf">SmoothL1BackwardUseIn_</span><span class="p">(</span><span class="k">const</span> <span class="n">OutputGrad</span><span class="o">&amp;</span> <span class="n">out_grad</span><span class="p">,</span>
<span class="k">const</span> <span class="n">Input0</span><span class="o">&amp;</span> <span class="n">in_data0</span><span class="p">,</span>
<span class="k">const</span> <span class="n">EnvArguments</span><span class="o">&amp;</span> <span class="n">env</span><span class="p">,</span>
<span class="n">TBlob</span> <span class="o">*</span><span class="n">in_grad</span><span class="p">,</span>
<span class="n">OpReqType</span> <span class="n">req</span><span class="p">,</span>
<span class="n">RunContext</span> <span class="n">ctx</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="p">;</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">expr</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o">&lt;</span><span class="n">xpu</span><span class="o">&gt;</span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o">&lt;</span><span class="n">xpu</span><span class="o">&gt;</span><span class="p">();</span>
<span class="n">real_t</span> <span class="n">sigma2</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span> <span class="o">*</span> <span class="n">env</span><span class="p">.</span><span class="n">scalar</span><span class="p">;</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">in_grad</span><span class="o">-&gt;</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span> <span class="n">src</span> <span class="o">=</span> <span class="n">in_data0</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">get</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span> <span class="n">ograd</span> <span class="o">=</span> <span class="n">out_grad</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">get</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span> <span class="n">igrad</span> <span class="o">=</span> <span class="n">in_grad</span><span class="o">-&gt;</span><span class="n">get</span><span class="o">&lt;</span><span class="n">xpu</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">s</span><span class="p">);</span>
<span class="n">ASSIGN_DISPATCH</span><span class="p">(</span><span class="n">igrad</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span>
<span class="n">ograd</span> <span class="o">*</span> <span class="n">F</span><span class="o">&lt;</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">smooth_l1_gradient</span><span class="o">&gt;</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">ScalarExp</span><span class="o">&lt;</span><span class="n">DType</span><span class="o">&gt;</span><span class="p">(</span><span class="n">sigma2</span><span class="p">)));</span>
<span class="p">});</span>
<span class="p">}</span>
</code></pre></div>
<h3 id="register-simpleop-to-mxnet">Register SimpleOp to MXNet</h3>
<p>After creating the shape, function, and gradient, register them as both an NDArray operator and
a symbolic operator. To simplify this process, use the registration macro defined in <code>operator_util.h</code>.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="n">MXNET_REGISTER_SIMPLE_OP</span><span class="p">(</span><span class="n">Name</span><span class="p">,</span> <span class="n">DEV</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_shape_function</span><span class="p">(</span><span class="n">Shape</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_function</span><span class="p">(</span><span class="n">DEV</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">Function</span><span class="o">&lt;</span><span class="n">XPU</span><span class="o">&gt;</span><span class="p">,</span> <span class="n">SimpleOpInplaceOption</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_gradient</span><span class="p">(</span><span class="n">DEV</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">Gradient</span><span class="o">&lt;</span><span class="n">XPU</span><span class="o">&gt;</span><span class="p">,</span> <span class="n">SimpleOpInplaceOption</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"description"</span><span class="p">);</span>
</code></pre></div>
<p><code>SimpleOpInplaceOption</code> is defined as:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">enum</span> <span class="n">SimpleOpInplaceOption</span> <span class="p">{</span>
<span class="n">kNoInplace</span><span class="p">,</span> <span class="c1">// do not allow inplace in arguments</span>
<span class="n">kInplaceInOut</span><span class="p">,</span> <span class="c1">// allow inplace in with out (unary)</span>
<span class="n">kInplaceOutIn</span><span class="p">,</span> <span class="c1">// allow inplace out_grad with in_grad (unary)</span>
<span class="n">kInplaceLhsOut</span><span class="p">,</span> <span class="c1">// allow inplace left operand with out (binary)</span>
<span class="n">kInplaceOutLhs</span> <span class="c1">// allow inplace out_grad with lhs_grad (binary)</span>
<span class="p">};</span>
</code></pre></div>
<p>In our example, the gradient relies on the input data of the forward function, so the function can&#39;t be written in
place: overwriting the input would destroy data that the gradient computation still needs. The output gradient has no purpose after the gradient computation, so the gradient can be written in place.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="n">MXNET_REGISTER_SIMPLE_OP</span><span class="p">(</span><span class="n">smooth_l1</span><span class="p">,</span> <span class="n">XPU</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_function</span><span class="p">(</span><span class="n">XPU</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">SmoothL1Forward_</span><span class="o">&lt;</span><span class="n">XPU</span><span class="o">&gt;</span><span class="p">,</span> <span class="n">kNoInplace</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_gradient</span><span class="p">(</span><span class="n">XPU</span><span class="o">::</span><span class="n">kDevMask</span><span class="p">,</span> <span class="n">SmoothL1BackwardUseIn_</span><span class="o">&lt;</span><span class="n">XPU</span><span class="o">&gt;</span><span class="p">,</span> <span class="n">kInplaceOutIn</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_enable_scalar</span><span class="p">(</span><span class="nb">true</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Calculate Smooth L1 Loss(lhs, scalar)"</span><span class="p">);</span>
</code></pre></div>
<p>Remember from the discussion of shape functions that, without <code>set_shape_function</code>, the default behavior forces the inputs
(for a binary operator) to have the same shape and gives the output that same shape. We&#39;ll discuss <code>set_enable_scalar</code> later.</p>
<h3 id="ndarray-operator-summary">NDArray Operator Summary</h3>
<ul>
<li>Create a shape function for determining the output shape.</li>
<li>Create a function as the forward routine by choosing a suitable function type.</li>
<li>Create a gradient as the backward routine by choosing a suitable gradient type.</li>
<li>Register the operator using the registration process.</li>
</ul>
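<p>Putting these steps together, a minimal skeleton for a hypothetical element-wise operator (an absolute-value op named <code>my_abs</code>, assuming the <code>abs</code> mapper from <code>src/operator/mshadow_op.h</code>; the gradient is omitted for brevity) might look like:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> #include &lt;mxnet/operator_util.h&gt;

 // 1. shape: identity, the same as the default behavior
 inline mxnet::TShape MyAbsShape_(const mxnet::TShape&amp; src,
                                  const EnvArguments&amp; env) {
   return mxnet::TShape(src);
 }

 // 2. forward: |x| element-wise
 template&lt;typename xpu&gt;
 void MyAbsForward_(const TBlob&amp; src, const EnvArguments&amp; env,
                    TBlob* ret, OpReqType req, RunContext ctx) {
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream&lt;xpu&gt; *s = ctx.get_stream&lt;xpu&gt;();
   MSHADOW_TYPE_SWITCH(ret-&gt;type_flag_, DType, {
     Tensor&lt;xpu, 2, DType&gt; out = ret-&gt;get&lt;xpu, 2, DType&gt;(s);
     Tensor&lt;xpu, 2, DType&gt; in = src.get&lt;xpu, 2, DType&gt;(s);
     ASSIGN_DISPATCH(out, req, F&lt;mshadow_op::abs&gt;(in));
   });
 }

 // 3. + 4. registration (a gradient would be added with set_gradient)
 MXNET_REGISTER_SIMPLE_OP(my_abs, XPU)
 .set_shape_function(MyAbsShape_)
 .set_function(XPU::kDevMask, MyAbsForward_&lt;XPU&gt;, kNoInplace)
 .describe("Hypothetical element-wise absolute value");
</code></pre></div>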
<h2 id="additional-information-on-simpleop">Additional Information on SimpleOp</h2>
<h3 id="using-simpleop-on-envarguments">Using SimpleOp on EnvArguments</h3>
<p>Some operations might need a scalar as input, such as a gradient scale, a set of keyword arguments
controlling behavior, or a temporary space to speed up calculations. <code>EnvArguments</code> provides additional arguments and resources to make calculations more scalable
and efficient.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">struct</span> <span class="n">EnvArguments</span> <span class="p">{</span>
<span class="n">real_t</span> <span class="n">scalar</span><span class="p">;</span> <span class="c1">// scalar argument, if enabled</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">,</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&gt;</span> <span class="o">&gt;</span> <span class="n">kwargs</span><span class="p">;</span> <span class="c1">// keyword arguments</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">Resource</span><span class="o">&gt;</span> <span class="n">resource</span><span class="p">;</span> <span class="c1">// pointer to the resources requested</span>
<span class="p">};</span>
</code></pre></div>
<p>More registration parameters are required to enable these additional features. To prevent parameter confusion, <code>scalar</code> and <code>kwargs</code>
can&#39;t be present at the same time. To enable <code>scalar</code>, use
<code>set_enable_scalar(bool enable_scalar)</code> in registration. Then, in forward functions and gradients, the scalar can be accessed as <code>env.scalar</code> through the function parameter <code>EnvArguments env</code>.</p>
<p>To enable <code>kwargs</code>, use <code>set_enable_kwargs(bool enable_kwargs)</code> in registration. Then, in forward
functions and gradients, additional arguments are contained in <code>env.kwargs</code>, which is defined as
<code>std::vector&lt;std::pair&lt;std::string, std::string&gt; &gt;</code>. Use the DMLC parameter structure to
simplify parsing keyword arguments. For more details, see the <a href="https://github.com/dmlc/dmlc-core/blob/master/doc/parameter.md">guide on parameter structure</a>.</p>
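<p>For example, a function body could hand <code>env.kwargs</code> straight to a dmlc parameter structure (a sketch; <code>MyParam</code> is a hypothetical <code>dmlc::Parameter</code> subclass like <code>ConvolutionParam</code> earlier):</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> template&lt;typename xpu&gt;
 void MyKwargForward_(const TBlob&amp; src, const EnvArguments&amp; env,
                      TBlob* ret, OpReqType req, RunContext ctx) {
   MyParam param;
   param.Init(env.kwargs);  // dmlc-core parses the string pairs into typed fields
   // ... use the parsed fields in the computation ...
 }
</code></pre></div>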
<p>Additional resources like <code>mshadow::Random&lt;xpu&gt;</code> and temporary memory space can also be requested and
accessed from <code>EnvArguments.resource</code>. The registration routine is <code>set_resource_request(ResourceRequest req)</code>
or <code>set_resource_request(const std::vector&lt;ResourceRequest&gt;)</code>, where <code>mxnet::ResourceRequest</code> is defined as:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">struct</span> <span class="n">ResourceRequest</span> <span class="p">{</span>
<span class="k">enum</span> <span class="n">Type</span> <span class="p">{</span> <span class="c1">// Resource type, indicating what the pointer type is</span>
<span class="n">kRandom</span><span class="p">,</span> <span class="c1">// mshadow::Random&lt;xpu&gt; object</span>
<span class="n">kTempSpace</span> <span class="c1">// A dynamic temp space that can be arbitrary size</span>
<span class="p">};</span>
<span class="n">Type</span> <span class="n">type</span><span class="p">;</span> <span class="c1">// type of resources</span>
<span class="p">};</span>
</code></pre></div>
<p>Registration forwards the declared resource requests to <code>mxnet::ResourceManager</code>, which places the resources
in <code>std::vector&lt;Resource&gt; resource</code> in <code>EnvArguments</code>. To access them, use the following:</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">auto</span> <span class="n">tmp_space_res</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">resources</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">get_space</span><span class="p">(</span><span class="n">some_shape</span><span class="p">,</span> <span class="n">some_stream</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">rand_res</span> <span class="o">=</span> <span class="n">env</span><span class="p">.</span><span class="n">resources</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">get_random</span><span class="p">(</span><span class="n">some_stream</span><span class="p">);</span>
</code></pre></div>
<p>For an example, see <code>src/operator/loss_binary_op-inl.h</code>.</p>
<p>In our smooth l1 loss example, a scalar input is needed to mark the turning point of a loss function. Therefore,
in the registration process, we use <code>set_enable_scalar(true)</code>, and use <code>env.scalar</code> in function and gradient
declarations.</p>
<h3 id="crafting-a-tensor-operation">Crafting a Tensor Operation</h3>
<p>Because computation utilizes the <code>mshadow</code> library and the functions we need aren&#39;t always readily available, we
sometimes have to craft tensor operations in operator implementations. If you can define such a function element-wise, you
can implement it as an <code>mxnet::op::mshadow_op</code>; <code>src/operator/mshadow_op.h</code> contains many examples.
<code>mshadow_op</code>s are expression mappers: they define the scalar case of the desired function. For details, see the
<a href="https://github.com/dmlc/mshadow/tree/master/doc">mshadow expression API guide</a>.</p>
<p>If an operation can&#39;t be done in an element-wise way, like the softmax loss and gradient, then you need to create a new tensor operation, writing a <code>mshadow</code> function and a <code>mshadow::cuda</code>
function directly. For details, see the <code>mshadow</code> library. For an example, see <code>src/operator/roi_pooling.cc</code>.</p>
<p>In our smooth l1 loss example, we create two mappers, namely the scalar cases of smooth l1 loss and gradient.</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> <span class="k">namespace</span> <span class="n">mshadow_op</span> <span class="p">{</span>
<span class="k">struct</span> <span class="n">smooth_l1_loss</span> <span class="p">{</span>
<span class="c1">// a is x, b is sigma2</span>
<span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="n">real_t</span> <span class="n">Map</span><span class="p">(</span><span class="n">real_t</span> <span class="n">a</span><span class="p">,</span> <span class="n">real_t</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="n">a</span> <span class="o">&gt;</span> <span class="mf">1.0</span><span class="n">f</span> <span class="o">/</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="n">a</span> <span class="o">-</span> <span class="mf">0.5</span><span class="n">f</span> <span class="o">/</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">a</span> <span class="o">&lt;</span> <span class="o">-</span><span class="mf">1.0</span><span class="n">f</span> <span class="o">/</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="o">-</span><span class="n">a</span> <span class="o">-</span> <span class="mf">0.5</span><span class="n">f</span> <span class="o">/</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
<span class="k">return</span> <span class="mf">0.5</span><span class="n">f</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span>
<span class="p">}</span>
<span class="p">};</span>
<span class="p">}</span>
</code></pre></div>
<p>The gradient, which can be found in <code>src/operator/smooth_l1_unary-inl.h</code>, is similar.</p>
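<p>Written out, the gradient mapper follows the same pattern; differentiating <code>f</code> piece by piece gives (a sketch consistent with that file):</p>
<div class="highlight"><pre><code class="language-cpp" data-lang="cpp"> namespace mshadow_op {
 struct smooth_l1_gradient {
   // a is x, b is sigma2
   MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
     if (a &gt; 1.0f / b) {
       return 1.0f;    // d/dx (x - 0.5/b) = 1
     } else if (a &lt; -1.0f / b) {
       return -1.0f;   // d/dx (-x - 0.5/b) = -1
     } else {
       return a * b;   // d/dx (0.5 * b * x^2) = b * x
     }
   }
 };
 }
</code></pre></div>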
<h3 id="beyond-two-operands">Beyond Two Operands</h3>
<p>The new unified API is designed to fulfill the fundamentals of an operation. For operators with more than two inputs,
more than one output, or that need more features, see the original <a href="overview#operators-in-mxnet">Operator API</a>.</p>
</div>
</div>
</div>
</div>
</article>
</main>
</body>
</html>