blob: 847907aa10b634948cd1aea5b8e030123091db84 [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Create New Operators | Apache MXNet</title>
<meta name="generator" content="Jekyll v3.8.6" />
<meta property="og:title" content="Create New Operators" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/faq/new_op" />
<meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/faq/new_op" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"description":"A flexible and efficient library for deep learning.","headline":"Create New Operators","@type":"WebPage","url":"https://mxnet.apache.org/versions/1.9.1/api/faq/new_op","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
// Matomo command queue: reuse the queue already on the page, or start an empty one.
window._paq = window._paq || [];
var _paq = window._paq;
/* Tracker methods such as "setCustomDimension" must be queued before "trackPageView". */
/* Cookie tracking is explicitly disabled to avoid privacy issues. */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
(function () {
  var trackerBase = "https://analytics.apache.org/";
  _paq.push(['setTrackerUrl', trackerBase + 'matomo.php']);
  _paq.push(['setSiteId', '23']);
  // Insert the asynchronous Matomo loader ahead of the first script element in the document.
  var firstScript = document.getElementsByTagName('script')[0];
  var loader = document.createElement('script');
  loader.async = true;
  loader.src = trackerBase + 'matomo.js';
  firstScript.parentNode.insertBefore(loader, firstScript);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script>
<script src="/versions/1.9.1/assets/js/clipboard.js" defer></script>
<script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(function () {
  // Header opacity: the background alpha starts at 0.4 and increases by (scroll offset / 300).
  function opacity_header() {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  }
  $(window).on("scroll", opacity_header);
  // Apply once immediately so the header is styled before any scroll event fires.
  opacity_header();

  // Menu selector: mark every nav link whose resolved href occurs in the current URL.
  $('.page-link')
    .filter(function () {
      return window.location.href.includes(this.href);
    })
    .addClass("page-current");
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.9.1/"><img
src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">1.9.1</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/1.9.1/get_started">Get Started</a>
<a class="page-link" href="/versions/1.9.1/features">Features</a>
<a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/1.9.1/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">1.9.1
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a href="/">master</a>
<a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Create New Operators</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<h3 style="text-transform: capitalize; padding-left:10px">faq</h3>
<ul>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/add_op_in_backend">A Beginner's Guide to Implementing Operators in MXNet Backend</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/caffe">Convert from Caffe to MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/cloud">MXNet on the Cloud</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/distributed_training">Distributed Training in MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/env_var">Environment Variables</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/float16">Float16</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/large_tensor_support">Using MXNet with Large Tensor Support</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/model_parallel_lstm">Model Parallel</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/multi_device">Data Parallelism with Multiple CPU/GPUs on MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/new_op">Create New Operators</a></li>
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/nnpack">NNPACK for Multi-Core CPU Support in MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/perf">Some Tips for Improving MXNet Performance</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/recordio">Create a Dataset Using RecordIO</a></li>
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/s3_integration">Use data from S3 for training</a></li>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/security">MXNet Security Best Practices</a></li>
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/smart_device">Deep Learning at the Edge</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/visualize_graph">Visualize Neural Networks</a></li>
<!-- page-category -->
<li><a href="/versions/1.9.1/api/faq/why_mxnet">Why MXNet came to be?</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- resource-p -->
</ul>
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="how-to-create-new-operators-layers">How to Create New Operators (Layers)</h1>
<p>This tutorial walks you through the process of creating new MXNet operators (or layers).
We&#39;ve done our best to provide high-speed operators for most common use cases.
However, if you&#39;re engaged in research,
there&#39;s a good chance you&#39;ll want to define custom layers,
like a novel loss function. In these cases, you have two options:</p>
<ul>
<li><p>Use CustomOp to write new operators using a front-end language (e.g., Python) that run on CPUs or GPUs.
Depending on your implementation, this can range from very fast (if you only use operators under mx.nd) to very slow (if you copy out the data, using <code>.asnumpy()</code>).</p></li>
<li><p>Use C++/mshadow (CUDA). This provides the best performance, but can be difficult
if you&#39;re not familiar with MXNet, mshadow, or CUDA.</p></li>
</ul>
<h2 id="customop">CustomOp</h2>
<p>Implementing an operator in Python is simple.
As an example, let&#39;s create a softmax operator.
Start by subclassing <code>mxnet.operator.CustomOp</code>,
and then override a few methods:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="n">mx</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="k">class</span> <span class="nc">Softmax</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOp</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">is_train</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="o">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)))</span>
<span class="n">y</span> <span class="o">/=</span> <span class="n">y</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>
</code></pre></div>
<p>We defined the computation for the forward pass of our operator.
The forward function takes a list of input and a list of output NDArrays.
For convenience, we called <code>.asnumpy()</code> on the first NDArray in input
and converted it to a CPU-based NumPy array.
This can be very slow. If you want the best performance,
keep data in the NDArray format and use operators under mx.nd to do the computation.</p>
<p>At the end, we used CustomOp.assign to assign the resulting array y to out_data[0]. It handles assignment based on the value of req, which can be &#39;write&#39;, &#39;add&#39;, or &#39;null&#39;.</p>
<p>Then do the same for the backward pass:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">in_grad</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span>
<span class="n">l</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="n">y</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">l</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">l</span><span class="p">]</span> <span class="o">-=</span> <span class="mf">1.0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">in_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>
</code></pre></div>
<p>Softmax defines the computation of our custom operator,
but you still need to define its input/output format
by subclassing mx.operator.CustomOpProp.
First, register the new operator with the name &#39;softmax&#39;:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="s">"softmax"</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">SoftmaxProp</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOpProp</span><span class="p">):</span>
</code></pre></div>
<p>Then, call the base constructor with <code>need_top_grad=False</code>
because softmax is a loss layer and you don&#39;t need gradient input from preceding layers:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">SoftmaxProp</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">(</span><span class="n">need_top_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div>
<p>Then declare the input and output:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">list_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span><span class="s">'data'</span><span class="p">,</span> <span class="s">'label'</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">list_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span><span class="s">'output'</span><span class="p">]</span>
</code></pre></div>
<p>Note that list_arguments declares both input and parameter.
We recommend ordering them as follows: <code>[&#39;input1&#39;, &#39;input2&#39;, ... , &#39;weight1&#39;, &#39;weight2&#39;, ...]</code></p>
<p>Next, provide <code>infer_shape</code> to declare the shape of the output/weight
and check the consistency of the input shapes:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">infer_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_shape</span><span class="p">):</span>
<span class="n">data_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">label_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">],)</span>
<span class="n">output_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">data_shape</span><span class="p">,</span> <span class="n">label_shape</span><span class="p">],</span> <span class="p">[</span><span class="n">output_shape</span><span class="p">],</span> <span class="p">[]</span>
</code></pre></div>
<p>The first axis of an input/output tensor corresponds to different examples within the batch.
The label is a set of integers, one for each data entry,
and the output has the same shape as the input.
The <code>infer_shape</code> function should always return three lists in this order:
inputs, outputs, and auxiliary states (which we don&#39;t have here),
even if one of them is empty.</p>
<p>Optionally, you can also define <code>infer_type</code> to declare the input and output data type of your operator. Supported types are <code>np.float32</code>, <code>np.float64</code>, <code>np.float16</code>, <code>np.uint8</code>, and <code>np.int32</code>.</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">infer_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_type</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">in_type</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">dtype</span><span class="p">,</span> <span class="n">dtype</span><span class="p">],</span> <span class="p">[</span><span class="n">dtype</span><span class="p">],</span> <span class="p">[]</span>
</code></pre></div>
<p>Finally, define a create_operator function that will be called by the back end to create an instance of softmax:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">create_operator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">dtypes</span><span class="p">):</span>
<span class="k">return</span> <span class="n">Softmax</span><span class="p">()</span>
</code></pre></div>
<p>To use the custom operator, create a mx.sym.Custom symbol with op_type as the registered name:</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">mlp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Custom</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'softmax'</span><span class="p">,</span> <span class="n">op_type</span><span class="o">=</span><span class="s">'softmax'</span><span class="p">)</span>
</code></pre></div>
<p>Please see the full code for this example <a href="https://github.com/apache/mxnet/blob/v1.x/example/numpy-ops/custom_softmax.py">here</a>.</p>
<h2 id="c">C++</h2>
<p>With MXNet v0.9 (the NNVM refactor) or later, creating new operators has become easier.
Operators are now registered with NNVM.
The following code is an example of how to register an operator (check out <a href="https://github.com/apache/mxnet/tree/v1.x/src/operator/tensor">src/operator/tensor</a> for more examples):</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span>
<span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o">&lt;</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="o">&gt;</span><span class="p">);</span>
</code></pre></div>
<p>The syntax is quite simple: we register the operator with a name,
then set the number of inputs and outputs.
You can register attributes with any key (<code>FInferShape</code> for example) to any operator,
without having to modify a central class interface definition.</p>
<h3 id="operator-attribute-system">Operator Attribute System</h3>
<p>One of the biggest improvements brought by NNVM is the operator attribute system.
This is like traits for types in common languages like C++.
We can register any attribute to any operator, with the syntax</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">op</span><span class="o">-</span><span class="n">name</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">AttributeType</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"AttributeKey"</span><span class="p">,</span> <span class="n">CorrespondingAttributeObject</span><span class="p">);</span>
</code></pre></div>
<p>These attributes can be retrieved later for various purposes.
For example, <code>FInferShape</code> is used for shape inference, <code>FCompute&lt;cpu&gt;</code> is used for carrying out actual computation on CPU.</p>
<p>As long as all attributes registered with the same key have the same type,
we can register any attributes to operators.
The more attributes an operator provides,
the more information the system can use for optimization.</p>
<h3 id="list-of-basic-attributes">List of basic attributes</h3>
<p>In this section, we will go through the basic attributes MXNet expects for all operators.
You can find the definition for them in the following two files:</p>
<ul>
<li><a href="https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h">nnvm/op_attr_types.h</a></li>
<li><a href="https://github.com/apache/mxnet/blob/v1.x/include/mxnet/op_attr_types.h">mxnet/op_attr_types.h</a></li>
</ul>
<h4 id="descriptions-optional">Descriptions (Optional)</h4>
<p><code>.describe(comment)</code> adds a comment to the operator. Use <code>.MXNET_DESCRIBE(comment)</code> to add the current file name and line number to comment.</p>
<h4 id="attribute-parser-optional">Attribute Parser (Optional)</h4>
<p>Set the attribute parser with <code>.set_attr_parser(PARSER)</code> where PARSER is a function with prototype <code>void(nnvm::NodeAttrs* attrs)</code>. This function should parse the keyword arguments in <code>attrs-&gt;dict</code> and store the result in <code>attrs-&gt;parsed</code>.</p>
<p>Simple arguments can be parsed like this:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">NNVM_REGISTER_OP(scalar_op)
.set_attr_parser(
  [](NodeAttrs* attrs) {
    attrs-&gt;parsed = dmlc::stod(attrs-&gt;dict[&quot;scalar&quot;]);
  })
</code></pre></div>
<p>The parsed arguments can then be accessed in other attribute functions with:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++">double alpha = nnvm::get&lt;double&gt;(attrs.parsed);
</code></pre></div>
<p>More complex ops can use <code>dmlc::Parameter</code> and <code>ParamParser</code> (defined in operator_common.h) for parsing:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="cp">#include &lt;dmlc/parameter.h&gt;
#include &lt;operator_common.h&gt;
</span><span class="k">struct</span> <span class="n">ActivationParam</span> <span class="o">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o">&lt;</span><span class="n">ActivationParam</span><span class="o">&gt;</span> <span class="p">{</span>
<span class="c1">// use int for enumeration</span>
<span class="kt">int</span> <span class="n">act_type</span><span class="p">;</span>
<span class="n">DMLC_DECLARE_PARAMETER</span><span class="p">(</span><span class="n">ActivationParam</span><span class="p">)</span> <span class="p">{</span>
<span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">act_type</span><span class="p">)</span>
<span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"relu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kReLU</span><span class="p">)</span>
<span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"sigmoid"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSigmoid</span><span class="p">)</span>
<span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"tanh"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kTanh</span><span class="p">)</span>
<span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"softrelu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSoftReLU</span><span class="p">)</span>
<span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Activation function to be applied."</span><span class="p">);</span>
<span class="p">}</span>
<span class="p">};</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">Activation</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o">&lt;</span><span class="n">ActivationParam</span><span class="o">&gt;</span><span class="p">);</span>
<span class="c1">// access with:</span>
<span class="c1">// const ActivationParam&amp; param = nnvm::get&lt;ActivationParam&gt;(attrs.parsed);</span>
</code></pre></div>
<h4 id="inputs-outputs">Inputs &amp; Outputs</h4>
<p>Number of inputs/outputs can be set with <code>.set_num_inputs(n_in)</code> and <code>.set_num_outputs(n_out)</code>
where n_in and n_out are integers.</p>
<p>Alternatively, if the number of inputs/outputs is variable and depends on arguments,
you can set <code>n_in</code>/<code>n_out</code> to functions with prototype <code>uint32_t(const nnvm::NodeAttrs&amp; attrs)</code>
that return the number of inputs/outputs based on parsed arguments.</p>
<p>Outputs can be made invisible to other operators by registering <code>FNumVisibleOutputs</code>
and returning an integer smaller than <code>n_out</code>.</p>
<p>Inputs/outputs can be named by registering <code>FListInputNames</code> and <code>FListOutputNames</code> with prototype <code>std::vector&lt;std::string&gt;(const NodeAttrs&amp; attrs)</code>.</p>
<h4 id="argument-descriptions">Argument Descriptions</h4>
<p>Set argument descriptions with <code>.add_argument(name, type, comment)</code>.
This is necessary for operators to be properly called imperatively.</p>
<p>First, add NDArray arguments <code>num_inputs</code> times with type &quot;NDArray&quot;
or one time with type &quot;NDArray[]&quot; for ops with variable length inputs.</p>
<p>Then add key-word arguments with proper type (float, string, etc).
Operators that parse key-word arguments with <code>dmlc::Parameter</code>
can add argument descriptions in bulk with <code>.add_arguments(ActivationParam::__FIELDS__())</code>
(NDArray arguments still need to be manually added with type &quot;NDArray&quot;).</p>
<h4 id="finfershape-or-tisbackward-for-backward-only-ops">FInferShape or TIsBackward (for Backward Only Ops)</h4>
<p>Normally operators need to have <code>FInferShape</code> with prototype <code>bool(const nnvm::NodeAttrs&amp; attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs)</code>. <code>FInferShape</code> fills unknown shapes (<code>shape.ndim() == 0</code>) in in_attrs/out_attrs based on known shapes in in_attrs/out_attrs. Use <code>ElemwiseShape&lt;n_in, n_out&gt;</code> for simple operators with uniform shapes.</p>
<p>Operators that are only used for a backward pass can instead register <code>.set_attr&lt;nnvm::TIsBackward&gt;(&quot;TIsBackward&quot;, true)</code>
and their shapes will be copied from the corresponding forward operators.</p>
<h4 id="finfertype">FInferType</h4>
<p>Similar to <code>FInferShape</code>, <code>FInferType</code> fills unknown types (-1) based on known types. Use <code>ElemwiseType&lt;n_in, n_out&gt;</code> for simple operators with uniform types. Operators that registered <code>TIsBackward</code> don&#39;t need to register this.</p>
<h4 id="finplaceoption-optional">FInplaceOption (Optional)</h4>
<p><code>FInplaceOption</code> with prototype <code>std::vector&lt;std::pair&lt;int, int&gt; &gt;(const NodeAttrs&amp; attrs)</code>
specifies which input/output pairs can be computed in-place
and share memory with each other.
Each pair (i, j) in the returned list means
that the i-th input can share memory with the j-th output.</p>
<h4 id="fgradient-optional-for-imperative-use-required-for-symbolic-use">FGradient (Optional for imperative use, required for symbolic use)</h4>
<p>If an operator has gradient, it can be described with <code>FGradient</code> with prototype</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">&gt;</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">ObjectPtr</span><span class="o">&amp;</span> <span class="n">n</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">&gt;&amp;</span> <span class="n">ograds</span><span class="p">)</span>
</code></pre></div>
<p>Use utility functions <code>ElemwiseGradUseIn{op_name}</code>, <code>ElemwiseGradUseOut{op_name}</code>, <code>ElemwiseGradUseNone{op_name}</code> for ops that need corresponding forward op&#39;s input,
output or nothing to calculate the gradient.</p>
<p>For more complicated patterns, use <code>MakeGradNode(op_name, n, heads, dict)</code> to create gradient entries,
where heads are input entries to the backward op, composed from ograds and n-&gt;inputs.</p>
<p>When assembling a return vector of <code>std::vector&lt;nnvm::NodeEntry&gt; ret;</code> a common pattern would be to
either create nodes in place as in:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">ret</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="n">MakeNode</span><span class="p">(</span><span class="s">"zeros_like"</span><span class="p">,</span> <span class="n">n</span><span class="o">-&gt;</span><span class="n">attrs</span><span class="p">.</span><span class="n">name</span> <span class="o">+</span> <span class="s">"_xyz_backward"</span><span class="p">,</span>
<span class="p">{</span><span class="n">n</span><span class="o">-&gt;</span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]},</span> <span class="nb">nullptr</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">n</span><span class="p">))</span>
</code></pre></div>
<p>Or create the node, modify and then move into NodeEntry&#39;s constructor if this node is not to be used
again. This avoids unnecessary copies of the shared_ptr.</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n</span><span class="o">-&gt;</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
<span class="n">nnvm</span><span class="o">::</span><span class="n">ObjectPtr</span> <span class="n">node</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">Node</span><span class="o">::</span><span class="n">Create</span><span class="p">();</span>
<span class="n">node</span><span class="o">-&gt;</span><span class="n">attrs</span><span class="p">.</span><span class="n">op</span> <span class="o">=</span> <span class="n">copy_op</span><span class="p">;</span>
<span class="n">node</span><span class="o">-&gt;</span><span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">ograds</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span>
<span class="n">ret</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">node</span><span class="p">));</span>
<span class="p">}</span>
</code></pre></div>
<p>The first case uses RVO and the second in place construction.</p>
<h4 id="fcomputexpu">FCompute&lt;xpu&gt;</h4>
<p>Simple operators can register <code>FCompute&lt;xpu&gt;</code> with <code>.set_attr&lt;FCompute&gt;(&quot;FCompute&lt;cpu&gt;&quot;, ...)</code> and <code>.set_attr&lt;FCompute&gt;(&quot;FCompute&lt;gpu&gt;&quot;, ...)</code> for both CPU and (optionally) GPU computation.</p>
<p>FCompute has prototype</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">outputs</span><span class="p">)</span>
</code></pre></div>
<p><code>req</code> has the same length as <code>outputs</code>.
Each entry of <code>req</code> specifies
how the corresponding <code>output</code> should be written to.
<code>OpReqType</code> is defined as:</p>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span>
<span class="n">kNullOp</span><span class="p">,</span>
<span class="n">kWriteTo</span><span class="p">,</span>
<span class="n">kWriteInplace</span><span class="p">,</span>
<span class="n">kAddTo</span>
<span class="p">};</span>
</code></pre></div>
<p>Normally, the <code>req</code> of all <code>outputs</code> should be <code>kWriteTo</code>,
meaning that the provided <code>outputs</code> tensor is a <em>raw</em> memory block,
so the operator should write results directly into it.
In some cases, for example, when calculating the gradient tensor,
it would be great if we could accumulate the result,
rather than directly overwrite the tensor contents
so that no extra space needs to be created each time.
In such cases, the corresponding <code>req</code> is set to <code>kAddTo</code>,
indicating that a <code>+=</code> should be used.</p>
<h3 id="example-abs-operator">Example: abs operator</h3>
<div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span>
<span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o">&lt;</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">&gt;</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o">&lt;</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">&gt;</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span>
<span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">){</span>
<span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">&gt;</span> <span class="o">&gt;</span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span>
<span class="p">})</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;cpu&gt;"</span><span class="p">,</span> <span class="n">UnaryCompute</span><span class="o">&lt;</span><span class="n">cpu</span><span class="p">,</span> <span class="n">mshadow_op</span><span class="o">::</span><span class="n">abs</span><span class="o">&gt;</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FGradient</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FGradient"</span><span class="p">,</span> <span class="n">ElemwiseGradUseIn</span><span class="p">{</span><span class="s">"_backward_abs"</span><span class="p">})</span>
<span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"data"</span><span class="p">,</span> <span class="s">"NDArray"</span><span class="p">,</span> <span class="s">"Source input"</span><span class="p">);</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_abs</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">&gt;</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">&gt;</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span>
<span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">){</span>
<span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o">&lt;</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">&gt;</span> <span class="o">&gt;</span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">},</span> <span class="p">{</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span>
<span class="p">})</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;cpu&gt;"</span><span class="p">,</span> <span class="n">BinaryCompute</span><span class="o">&lt;</span><span class="n">cpu</span><span class="p">,</span> <span class="n">backward_grad</span><span class="o">&lt;</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">sign</span><span class="o">&gt;</span> <span class="o">&gt;</span><span class="p">);</span>
</code></pre></div>
<h3 id="legacy-operators">Legacy Operators</h3>
<p>For the legacy (pre 0.9) way of defining operators with C++, please see:</p>
<ul>
<li><a href="/versions/1.9.1/api/architecture/overview.html#operators-in-mxnet">Developer Guide - Operators</a></li>
<li><a href="/versions/1.9.1/api/architecture/overview.html#simpleop-the-unified-operator-api">Developer Guide - SimpleOp</a></li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/1.9.1/community/contribute#mxnet-dev-communications">Mailing lists</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
<li><a href="https://github.com/apache/mxnet/labels/Roadmap">Github Roadmap</a></li>
<li><a href="https://medium.com/apache-mxnet">Blog</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/1.9.1/community/contribute">Contribute</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/1.9.1/assets/img/asf_logo.svg" class="footer-logo col-2" alt="Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>"Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>