| <!DOCTYPE html> |
| |
| <!--- |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| http://www.apache.org/licenses/LICENSE-2.0 |
| Unless required by applicable law or agreed to in writing, |
| software distributed under the License is distributed on an |
| "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| KIND, either express or implied. See the License for the |
| specific language governing permissions and limitations |
| under the License. |
| --> |
| |
| <html lang="en"><head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <link href="/versions/1.9.0/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 --> |
| <title>Create New Operators | Apache MXNet</title> |
| <meta name="generator" content="Jekyll v3.8.6" /> |
| <meta property="og:title" content="Create New Operators" /> |
| <meta property="og:locale" content="en_US" /> |
| <meta name="description" content="A flexible and efficient library for deep learning." /> |
| <meta property="og:description" content="A flexible and efficient library for deep learning." /> |
| <link rel="canonical" href="https://mxnet.apache.org/versions/1.9.0/api/faq/new_op" /> |
| <meta property="og:url" content="https://mxnet.apache.org/versions/1.9.0/api/faq/new_op" /> |
| <meta property="og:site_name" content="Apache MXNet" /> |
| <script type="application/ld+json"> |
| {"headline":"Create New Operators","@type":"WebPage","description":"A flexible and efficient library for deep learning.","url":"https://mxnet.apache.org/versions/1.9.0/api/faq/new_op","@context":"https://schema.org"}</script> |
| <!-- End Jekyll SEO tag --> |
| <link rel="stylesheet" href="/versions/1.9.0/assets/docsearch.min.css" /> |
| <link rel="stylesheet" href="/versions/1.9.0/assets/retainable.css" /><link rel="stylesheet" href="/versions/1.9.0/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.0/feed.xml" title="Apache MXNet" /><script> |
| if(!(window.doNotTrack === "1" || navigator.doNotTrack === "1" || navigator.doNotTrack === "yes" || navigator.msDoNotTrack === "1")) { |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| } |
| </script> |
| |
| <script src="/versions/1.9.0/assets/js/jquery-3.3.1.min.js"></script> |
| <script src="/versions/1.9.0/assets/js/docsearch.min.js"></script><script src="/versions/1.9.0/assets/js/globalSearch.js" defer></script> |
| <script src="/versions/1.9.0/assets/js/clipboard.js" defer></script> |
| <script src="/versions/1.9.0/assets/js/copycode.js" defer></script></head> |
| <body><header class="site-header" role="banner"> |
| |
| <script> |
| $(document).ready(function () { |
| |
| // HEADER OPACITY LOGIC |
| |
| function opacity_header() { |
| var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")" |
| $('.site-header').css("background-color", value) |
| } |
| |
| $(window).scroll(function () { |
| opacity_header() |
| }) |
| opacity_header(); |
| |
| // MENU SELECTOR LOGIC |
| $('.page-link').each( function () { |
| if (window.location.href.includes(this.href)) { |
| $(this).addClass("page-current"); |
| } |
| }); |
| }) |
| </script> |
| <div class="wrapper"> |
| <a class="site-title" rel="author" href="/versions/1.9.0/"><img |
| src="/versions/1.9.0/assets/img/mxnet_logo.png" class="site-header-logo"></a> |
| <nav class="site-nav"> |
| <input type="checkbox" id="nav-trigger" class="nav-trigger"/> |
| <label for="nav-trigger"> |
| <span class="menu-icon"> |
| <svg viewBox="0 0 18 15" width="18px" height="15px"> |
| <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/> |
| </svg> |
| </span> |
| </label> |
| <div class="gs-search-border"> |
| <div id="gs-search-icon"></div> |
| <form id="global-search-form"> |
| <input id="global-search" type="text" title="Search" placeholder="Search" /> |
| <div id="global-search-dropdown-container"> |
| <button class="gs-current-version btn" type="button" data-toggle="dropdown"> |
| <span id="gs-current-version-label">1.9.0</span> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| <span id="global-search-close">x</span> |
| </form> |
| </div> |
| <div class="trigger"> |
| <div id="global-search-mobile-border"> |
| <div id="gs-search-icon-mobile"></div> |
| <input id="global-search-mobile" placeholder="Search..." type="text"/> |
| <div id="global-search-dropdown-container-mobile"> |
| <button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown"> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown-mobile"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| </div> |
| <a class="page-link" href="/versions/1.9.0/get_started">Get Started</a> |
| <a class="page-link" href="/versions/1.9.0/blog">Blog</a> |
| <a class="page-link" href="/versions/1.9.0/features">Features</a> |
| <a class="page-link" href="/versions/1.9.0/ecosystem">Ecosystem</a> |
| <a class="page-link" href="/versions/1.9.0/api">Docs & Tutorials</a> |
| <a class="page-link" href="/versions/1.9.0/trusted_by">Trusted By</a> |
| <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a> |
| <div class="dropdown"> |
| <span class="dropdown-header">1.9.0 |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content"> |
| <a href="/">master</a> |
| <a class="dropdown-option-active" href="/versions/1.9.0/">1.9.0</a> |
| <a href="/versions/1.8.0/">1.8.0</a> |
| <a href="/versions/1.7.0/">1.7.0</a> |
| <a href="/versions/1.6.0/">1.6.0</a> |
| <a href="/versions/1.5.0/">1.5.0</a> |
| <a href="/versions/1.4.1/">1.4.1</a> |
| <a href="/versions/1.3.1/">1.3.1</a> |
| <a href="/versions/1.2.1/">1.2.1</a> |
| <a href="/versions/1.1.0/">1.1.0</a> |
| <a href="/versions/1.0.0/">1.0.0</a> |
| <a href="/versions/0.12.1/">0.12.1</a> |
| <a href="/versions/0.11.0/">0.11.0</a> |
| </div> |
| </div> |
| </div> |
| </nav> |
| </div> |
| </header> |
| <main class="page-content" aria-label="Content"> |
| <article class="post"> |
| |
| <header class="post-header wrapper"> |
| <h1 class="post-title">Create New Operators</h1> |
| <h3></h3></header> |
| |
| <div class="post-content"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3 docs-side-bar"> |
| <h3 style="text-transform: capitalize; padding-left:10px">faq</h3> |
| <ul> |
| <li><a href="/versions/1.9.0/api/faq/add_op_in_backend">A Beginner's Guide to Implementing Operators in MXNet Backend</a></li> |
| <li><a href="/versions/1.9.0/api/faq/caffe">Convert from Caffe to MXNet</a></li> |
| <li><a href="/versions/1.9.0/api/faq/cloud">MXNet on the Cloud</a></li> |
| <li><a href="/versions/1.9.0/api/faq/distributed_training">Distributed Training in MXNet</a></li> |
| <li><a href="/versions/1.9.0/api/faq/env_var">Environment Variables</a></li> |
| <li><a href="/versions/1.9.0/api/faq/float16">Float16</a></li> |
| <li><a href="/versions/1.9.0/api/faq/large_tensor_support">Using MXNet with Large Tensor Support</a></li> |
| <li><a href="/versions/1.9.0/api/faq/model_parallel_lstm">Model Parallel</a></li> |
| <li><a href="/versions/1.9.0/api/faq/multi_device">Data Parallelism with Multiple CPU/GPUs on MXNet</a></li> |
| <li><a href="/versions/1.9.0/api/faq/new_op">Create New Operators</a></li> |
| <li><a href="/versions/1.9.0/api/faq/nnpack">NNPACK for Multi-Core CPU Support in MXNet</a></li> |
| <li><a href="/versions/1.9.0/api/faq/perf">Some Tips for Improving MXNet Performance</a></li> |
| <li><a href="/versions/1.9.0/api/faq/recordio">Create a Dataset Using RecordIO</a></li> |
| <li><a href="/versions/1.9.0/api/faq/s3_integration">Use data from S3 for training</a></li> |
| <li><a href="/versions/1.9.0/api/faq/security">MXNet Security Best Practices</a></li> |
| <li><a href="/versions/1.9.0/api/faq/smart_device">Deep Learning at the Edge</a></li> |
| <li><a href="/versions/1.9.0/api/faq/visualize_graph">Visualize Neural Networks</a></li> |
| <li><a href="/versions/1.9.0/api/faq/why_mxnet">Why MXNet came to be?</a></li> |
| </ul> |
| </div> |
| <div class="col-9"> |
| <h1 id="how-to-create-new-operators-layers">How to Create New Operators (Layers)</h1> |
| |
| <p>This tutorial walks you through the process of creating new MXNet operators (or layers). |
| We've done our best to provide high-speed operators for the most common use cases. |
| However, if you're engaged in research, |
| there's a good chance you'll want to define custom layers, |
| such as a novel loss function. In these cases, you have two options:</p> |
| |
| <ul> |
| <li><p>Use CustomOp to write new operators in a front-end language (e.g., Python) that run on CPUs or GPUs. |
| Depending on your implementation, performance can range from very fast (if you only use operators under <code>mx.nd</code>) to very slow (if you copy the data out with <code>.asnumpy()</code>).</p></li> |
| <li><p>Use C++/mshadow (CUDA). This provides the best performance, but can be difficult |
| if you're not familiar with MXNet, mshadow, or CUDA.</p></li> |
| </ul> |
| |
| <h2 id="customop">CustomOp</h2> |
| |
| <p>Implementing an operator in Python is simple. |
| As an example, let's create a softmax operator. |
| Start by subclassing <code>mxnet.operator.CustomOp</code>, |
| and then override a few methods:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">os</span> |
| <span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="n">mx</span> |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span> |
| |
| <span class="k">class</span> <span class="nc">Softmax</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOp</span><span class="p">):</span> |
|     <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">is_train</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span> |
|         <span class="n">x</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> |
|         <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="o">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)))</span> |
|         <span class="n">y</span> <span class="o">/=</span> <span class="n">y</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> |
| </code></pre></div> |
| <p>We defined the computation for the forward pass of our operator. |
| The forward function takes a list of input NDArrays and a list of output NDArrays. |
| For convenience, we called <code>.asnumpy()</code> on the first input NDArray |
| to convert it to a CPU-based NumPy array. |
| This can be very slow. If you want the best performance, |
| keep the data in NDArray format and use operators under <code>mx.nd</code> to do the computation.</p> |
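<p>The forward pass subtracts each row's maximum before exponentiating. Softmax is unchanged when a whole row is shifted by a constant, but the shift keeps the exponentials from overflowing for large inputs. The following pure-Python sketch shows the same trick; the helper name <code>stable_softmax_rows</code> is hypothetical, and plain lists stand in for NDArrays so that it runs without MXNet or NumPy:</p>

```python
import math


def stable_softmax_rows(x):
    """Row-wise softmax over a batch stored as a list of lists (one row per example)."""
    out = []
    for row in x:
        m = max(row)                           # shift by the row maximum, as in forward()
        exps = [math.exp(v - m) for v in row]  # every exponent is now at most 0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


batch = [[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]]
probs = stable_softmax_rows(batch)
print(probs[1])

# Without the shift, the second row would need exp(1000.0), which overflows a double.
try:
    math.exp(1000.0)
except OverflowError:
    print("naive exp(1000.0) overflows")
```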
| |
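| <p>At the end of <code>forward</code>, we used <code>CustomOp.assign</code> to assign the resulting array <code>y</code> to <code>out_data[0]</code>. It handles the assignment based on the value of <code>req</code>, which can be <code>'write'</code>, <code>'add'</code>, or <code>'null'</code>.</p> |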
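<p>To make the role of <code>req</code> concrete, here is a minimal sketch of the dispatch described above. <code>assign_sketch</code> is a hypothetical helper that works on Python lists; it is not MXNet's implementation of <code>CustomOp.assign</code>, which operates on NDArrays:</p>

```python
def assign_sketch(dst, req, src):
    """Hypothetical list-based stand-in for CustomOp.assign.

    'null'  -> leave dst untouched (this output is not needed)
    'write' -> overwrite dst with src
    'add'   -> accumulate src into dst element by element
    """
    if req == 'null':
        return
    if req == 'write':
        dst[:] = src
    elif req == 'add':
        for i, value in enumerate(src):
            dst[i] += value
    else:
        raise ValueError(f"unsupported req: {req!r}")


buf = [1.0, 2.0]
assign_sketch(buf, 'add', [0.5, 0.5])    # buf becomes [1.5, 2.5]
assign_sketch(buf, 'null', [7.0, 7.0])   # buf is left unchanged
assign_sketch(buf, 'write', [9.0, 9.0])  # buf becomes [9.0, 9.0]
print(buf)
```

<p>Routing every output through a helper like this, rather than overwriting the output directly, is what keeps the operator correct when the caller asks for accumulation instead of a fresh write.</p>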
| |
| <p>Then do the same for the backward pass:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">in_grad</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span> |
|         <span class="n">l</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span> |
|         <span class="n">y</span> <span class="o">=</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> |
|         <span class="n">y</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">l</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">l</span><span class="p">]</span> <span class="o">-=</span> <span class="mf">1.0</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">in_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> |
| </code></pre></div> |
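<p>The backward pass writes <code>y - onehot(label)</code> into <code>in_grad[0]</code>. For softmax followed by cross-entropy loss, that difference is the gradient of the loss with respect to the softmax input, which is also why this operator never reads <code>out_grad</code>. A finite-difference check is a quick way to confirm the rule. The following pure-Python sketch works on a single example row; all helper names are hypothetical:</p>

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(row, label):
    # Negative log-probability assigned to the correct class.
    return -math.log(softmax(row)[label])


def analytic_grad(row, label):
    # The rule used by backward() above: softmax output minus a one-hot label.
    grad = softmax(row)
    grad[label] -= 1.0
    return grad


def numeric_grad(row, label, eps=1e-6):
    # Central finite differences, perturbing one logit at a time.
    grads = []
    for i in range(len(row)):
        plus = list(row)
        minus = list(row)
        plus[i] += eps
        minus[i] -= eps
        grads.append((cross_entropy(plus, label) - cross_entropy(minus, label)) / (2 * eps))
    return grads


logits, label = [0.2, -1.3, 2.5, 0.0], 2
print(analytic_grad(logits, label))
print(numeric_grad(logits, label))
```

<p>The two printed lists should agree closely; a large discrepancy would point to a bug in the backward rule.</p>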
| <p><code>Softmax</code> defines the computation of our custom operator, |
| but you still need to define its input/output format |
| by subclassing <code>mx.operator.CustomOpProp</code>. |
| First, register the new operator with the name <code>'softmax'</code>:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="s">"softmax"</span><span class="p">)</span> |
| <span class="k">class</span> <span class="nc">SoftmaxProp</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOpProp</span><span class="p">):</span> |
| </code></pre></div> |
| <p>Then, in the constructor, call the base constructor with <code>need_top_grad=False</code>, |
| because softmax is a loss layer and doesn't need a gradient input from the layers that follow it:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
|         <span class="nb">super</span><span class="p">(</span><span class="n">SoftmaxProp</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">(</span><span class="n">need_top_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> |
| </code></pre></div> |
| <p>Then declare the input and output:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">list_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
|         <span class="k">return</span> <span class="p">[</span><span class="s">'data'</span><span class="p">,</span> <span class="s">'label'</span><span class="p">]</span> |
| |
|     <span class="k">def</span> <span class="nf">list_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
|         <span class="k">return</span> <span class="p">[</span><span class="s">'output'</span><span class="p">]</span> |
| </code></pre></div> |
| <p>Note that <code>list_arguments</code> declares both inputs and parameters. |
| We recommend ordering them as follows: <code>['input1', 'input2', ... , 'weight1', 'weight2', ...]</code></p> |
| |
| <p>Next, provide <code>infer_shape</code> to declare the shape of the output/weight |
| and check the consistency of the input shapes:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">infer_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_shape</span><span class="p">):</span> |
|         <span class="n">data_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
|         <span class="n">label_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">],)</span> |
|         <span class="n">output_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
|         <span class="k">return</span> <span class="p">[</span><span class="n">data_shape</span><span class="p">,</span> <span class="n">label_shape</span><span class="p">],</span> <span class="p">[</span><span class="n">output_shape</span><span class="p">],</span> <span class="p">[]</span> |
| </code></pre></div> |
| <p>The first axis of an input/output tensor corresponds to different examples within the batch. |
| The label is a set of integers, one for each data entry, |
| and the output has the same shape as the input. |
| The <code>infer_shape</code> function should always return three lists in this order: |
| inputs, outputs, and auxiliary states (which we don't have here), |
| even if one of them is empty.</p> |
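<p>The <code>infer_shape</code> above never inspects <code>in_shape[1]</code>; it derives the label shape from the data shape. If you prefer the operator itself to reject a mismatched label with a clear message, one option is sketched below as a standalone function. The name <code>softmax_infer_shape</code> and the convention that an unknown shape arrives as <code>None</code> or an empty tuple are assumptions made for illustration:</p>

```python
def softmax_infer_shape(in_shape):
    """Hypothetical standalone version of SoftmaxProp.infer_shape with a consistency check.

    in_shape is [data_shape, label_shape]; the label entry may be None or () if unknown.
    Returns three lists: input shapes, output shapes, auxiliary-state shapes.
    """
    data_shape = tuple(in_shape[0])
    batch_size = data_shape[0]   # first axis indexes examples in the batch
    label_shape = (batch_size,)  # one integer label per example
    given = in_shape[1] if len(in_shape) > 1 else None
    if given and tuple(given) != label_shape:
        raise ValueError(f"label shape {tuple(given)} does not match batch size {batch_size}")
    output_shape = data_shape    # softmax keeps the input shape
    return [data_shape, label_shape], [output_shape], []


print(softmax_infer_shape([(32, 10), ()]))
```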
| |
| <p>Optionally, you can also define <code>infer_type</code> to declare the input and output data type of your operator. Supported types are <code>np.float32</code>, <code>np.float64</code>, <code>np.float16</code>, <code>np.uint8</code>, and <code>np.int32</code>.</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">infer_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_type</span><span class="p">):</span> |
|         <span class="n">dtype</span> <span class="o">=</span> <span class="n">in_type</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
|         <span class="k">return</span> <span class="p">[</span><span class="n">dtype</span><span class="p">,</span> <span class="n">dtype</span><span class="p">],</span> <span class="p">[</span><span class="n">dtype</span><span class="p">],</span> <span class="p">[]</span> |
| </code></pre></div> |
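<p>As a sketch of a stricter <code>infer_type</code>, the function below propagates the data type to the label and the output and rejects anything outside the supported types listed above. It uses dtype names as strings so that it runs without NumPy, and the function name is hypothetical:</p>

```python
SUPPORTED_DTYPES = ("float32", "float64", "float16", "uint8", "int32")


def softmax_infer_type(in_type):
    """Hypothetical standalone infer_type: data, label, and output all use data's dtype.

    Returns three lists in the same order as infer_shape: inputs, outputs, auxiliary states.
    """
    dtype = in_type[0]
    if dtype not in SUPPORTED_DTYPES:
        raise TypeError(f"unsupported dtype {dtype!r}; expected one of {SUPPORTED_DTYPES}")
    return [dtype, dtype], [dtype], []


print(softmax_infer_type(["float16", None]))
```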
| <p>Finally, define a <code>create_operator</code> function that the backend calls to create an instance of <code>Softmax</code>:</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python">    <span class="k">def</span> <span class="nf">create_operator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">dtypes</span><span class="p">):</span> |
|         <span class="k">return</span> <span class="n">Softmax</span><span class="p">()</span> |
| </code></pre></div> |
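| <p>To use the custom operator, create an <code>mx.sym.Custom</code> symbol with <code>op_type</code> set to the registered name (here, <code>fc3</code> is the output symbol of the preceding layer):</p> |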
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">mlp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Custom</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'softmax'</span><span class="p">,</span> <span class="n">op_type</span><span class="o">=</span><span class="s">'softmax'</span><span class="p">)</span> |
| </code></pre></div> |
| <p>Please see the full code for this example <a href="https://github.com/apache/mxnet/blob/v1.x/example/numpy-ops/custom_softmax.py">here</a>.</p> |
| |
| <h2 id="c">C++</h2> |
| |
| <p>Starting with MXNet v0.9 (the NNVM refactor), creating new operators has become easier. |
| Operators are now registered with NNVM. |
| The following code is an example of how to register an operator (check out <a href="https://github.com/apache/mxnet/tree/v1.x/src/operator/tensor">src/operator/tensor</a> for more examples):</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="o">></span><span class="p">);</span> |
| </code></pre></div> |
| <p>The syntax is simple: we register the operator with a name, |
| then set the number of inputs and outputs. |
| You can register attributes with any key (<code>FInferShape</code>, for example) to any operator |
| without having to modify a central class interface definition.</p> |
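<p>The chained registration style works because each setter returns the object it modifies, and attributes live in a key-value map rather than in fixed fields. The following Python sketch imitates the <code>abs</code> registration above; <code>OpDef</code> and <code>elemwise_shape</code> are hypothetical, and <code>elemwise_shape</code> is a simplified, forward-only stand-in for <code>ElemwiseShape&lt;1,1&gt;</code> that only copies the input shape to the output:</p>

```python
class OpDef:
    """Hypothetical fluent builder echoing NNVM_REGISTER_OP(name).set_...(...) chains."""

    def __init__(self, name):
        self.name = name
        self.description = ""
        self.num_inputs = 1
        self.num_outputs = 1
        self.attrs = {}

    # Each setter returns self, which is what allows the calls to be chained.
    def describe(self, text):
        self.description = text
        return self

    def set_num_inputs(self, n):
        self.num_inputs = n
        return self

    def set_num_outputs(self, n):
        self.num_outputs = n
        return self

    def set_attr(self, key, value):
        self.attrs[key] = value
        return self


def elemwise_shape(in_shapes):
    # Simplified, forward-only analogue of ElemwiseShape<1, 1>:
    # the single output takes the shape of the single input.
    return [in_shapes[0]]


abs_op = (OpDef("abs")
          .describe("Take absolute value of the src")
          .set_num_inputs(1)
          .set_num_outputs(1)
          .set_attr("FInferShape", elemwise_shape))

print(abs_op.attrs["FInferShape"]([(3, 4)]))  # [(3, 4)]
```

<p>Because new keys go into <code>attrs</code>, adding another attribute kind does not require changing <code>OpDef</code> itself.</p>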
| |
| <h3 id="operator-attribute-system">Operator Attribute System</h3> |
| |
| <p>One of the biggest improvements brought by NNVM is the operator attribute system. |
| It is similar to type traits in languages such as C++. |
| We can register any attribute to any operator with the following syntax:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">op</span><span class="o">-</span><span class="n">name</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">AttributeType</span><span class="o">></span><span class="p">(</span><span class="s">"AttributeKey"</span><span class="p">,</span> <span class="n">CorrespondingAttributeObject</span><span class="p">);</span> |
| </code></pre></div> |
| <p>These attributes can be retrieved later for various purposes. |
| For example, <code>FInferShape</code> is used for shape inference, <code>FCompute<cpu></code> is used for carrying out actual computation on CPU.</p> |
| |
| <p>As long as all attributes registered under the same key have the same type, |
| we can register any attribute on any operator. |
| The more attributes an operator provides, |
| the more information the system can use for optimization.</p> |
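| |
| <p>To make the idea concrete, here is a toy model of such a registry in plain C++. It is a sketch under stated assumptions, not NNVM's implementation: <code>AttrTable</code>, <code>SetAttr</code>, <code>GetAttr</code> and the key <code>"num_inputs_hint"</code> are hypothetical, and each value type simply gets its own table that maps attribute key to operator name to value.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <map> |
| #include <string> |
| #include <utility> |
| |
| // Hypothetical toy registry (not NNVM's code): one table per value type T, |
| // mapping attribute key -> operator name -> value. |
| template <typename T> |
| std::map<std::string, std::map<std::string, T>>& AttrTable() { |
|   static std::map<std::string, std::map<std::string, T>> table; |
|   return table; |
| } |
| |
| // Attach a value to an operator under a key, in the spirit of .set_attr<T>(key, value). |
| template <typename T> |
| void SetAttr(const std::string& op, const std::string& key, T value) { |
|   AttrTable<T>()[key][op] = std::move(value); |
| } |
| |
| // Retrieve the value later; returns nullptr if the operator never registered it. |
| // A lookup with a different T searches a different table and finds nothing, |
| // which is why all registrations under one key must use the same type. |
| template <typename T> |
| const T* GetAttr(const std::string& op, const std::string& key) { |
|   const auto& table = AttrTable<T>(); |
|   auto by_key = table.find(key); |
|   if (by_key == table.end()) return nullptr; |
|   auto by_op = by_key->second.find(op); |
|   return by_op == by_key->second.end() ? nullptr : &by_op->second; |
| } |
| </code></pre></div> |
| <p>For example, after <code>SetAttr<int>("abs", "num_inputs_hint", 1)</code>, the call <code>GetAttr<int>("abs", "num_inputs_hint")</code> points to <code>1</code>, while <code>GetAttr<double>("abs", "num_inputs_hint")</code> and lookups for operators that never registered the key return <code>nullptr</code>.</p> |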
| |
| <h3 id="list-of-basic-attributes">List of basic attributes</h3> |
| |
| <p>In this section, we go through the basic attributes that MXNet expects for all operators. |
| You can find their definitions in the following two files:</p> |
| |
| <ul> |
| <li><a href="https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h">nnvm/op_attr_types.h</a></li> |
| <li><a href="https://github.com/apache/mxnet/blob/v1.x/include/mxnet/op_attr_types.h">mxnet/op_attr_types.h</a></li> |
| </ul> |
| |
| <h4 id="descriptions-optional">Descriptions (Optional)</h4> |
| |
| <p><code>.describe(comment)</code> adds a comment to the operator. Use <code>.MXNET_DESCRIBE(comment)</code> to also append the current file name and line number to the comment.</p> |
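| |
| <p>The exact text appended by <code>MXNET_DESCRIBE</code> is defined in MXNet's own headers. As a rough sketch of the underlying technique only, the hypothetical <code>TOY_DESCRIBE</code> macro below joins the description with <code>__FILE__</code> and a stringized <code>__LINE__</code> through compile-time string-literal concatenation; <code>ToyOp</code> is a stand-in for a registry entry.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <string> |
| |
| // Two-level stringize so that __LINE__ expands to its number before # is applied. |
| #define TOY_STR_IMPL(x) #x |
| #define TOY_STR(x) TOY_STR_IMPL(x) |
| |
| // Hypothetical analogue of MXNET_DESCRIBE: adjacent string literals are merged |
| // by the compiler into one "description + file:line" literal. |
| #define TOY_DESCRIBE(text) text "\n\nFrom:" __FILE__ ":" TOY_STR(__LINE__) |
| |
| // Stand-in for an operator registry entry that stores its documentation string. |
| struct ToyOp { |
|   std::string doc; |
|   ToyOp& describe(const std::string& d) { |
|     doc = d; |
|     return *this; |
|   } |
| }; |
| </code></pre></div> |
| <p>Calling <code>op.describe(TOY_DESCRIBE("Take absolute value of the src"))</code> stores the description followed by the source location of that call.</p> |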
| |
| <h4 id="attribute-parser-optional">Attribute Parser (Optional)</h4> |
| |
| <p>Set the attribute parser with <code>.set_attr_parser(PARSER)</code>, where PARSER is a function with prototype <code>void(nnvm::NodeAttrs* attrs)</code>. This function should parse the keyword arguments in <code>attrs->dict</code> and store the result in <code>attrs->parsed</code>.</p> |
| |
| <p>Simple arguments can be parsed like this:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">NNVM_REGISTER_OP(scalar_op) |
| .set_attr_parser( |
|   [](NodeAttrs* attrs) { |
|     attrs->parsed = dmlc::stod(attrs->dict["scalar"]); |
|   }); |
| </code></pre></div> |
| |
| <p>The parsed arguments can then be accessed in other attribute functions with:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">double alpha = nnvm::get<double>(attrs.parsed); |
| </code></pre></div> |
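| |
| <p>The same parse-then-retrieve flow can be modeled without MXNet. In this sketch, <code>ToyNodeAttrs</code>, <code>ParseScalar</code> and <code>GetScalar</code> are hypothetical, <code>std::any</code> stands in for the type-erased <code>parsed</code> field, and <code>std::any_cast</code> plays the role of <code>nnvm::get</code>.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <any> |
| #include <cassert> |
| #include <map> |
| #include <string> |
| |
| // Hypothetical stand-in for nnvm::NodeAttrs with only the two fields used here. |
| struct ToyNodeAttrs { |
|   std::map<std::string, std::string> dict;  // raw keyword arguments as strings |
|   std::any parsed;                          // parser output, type-erased |
| }; |
| |
| // Parser in the style of the scalar_op example: read "scalar", store a double. |
| // dict.at throws std::out_of_range if the argument is missing. |
| void ParseScalar(ToyNodeAttrs* attrs) { |
|   attrs->parsed = std::stod(attrs->dict.at("scalar")); |
| } |
| |
| // Later attribute functions recover the typed value from the type-erased field. |
| double GetScalar(const ToyNodeAttrs& attrs) { |
|   return std::any_cast<double>(attrs.parsed); |
| } |
| </code></pre></div> |
| <p><code>std::any_cast</code> throws <code>std::bad_any_cast</code> if the stored type differs from the requested one, so the parser and every reader must agree on the type stored in <code>parsed</code>.</p> |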
| |
| <p>More complex operators can use <code>dmlc::Parameter</code> and <code>ParamParser</code> (defined in <code>operator_common.h</code>) for parsing:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="cp">#include <dmlc/parameter.h> |
| #include <operator_common.h> |
| </span><span class="k">struct</span> <span class="n">ActivationParam</span> <span class="o">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o"><</span><span class="n">ActivationParam</span><span class="o">></span> <span class="p">{</span> |
| <span class="c1">// use int for enumeration</span> |
| <span class="kt">int</span> <span class="n">act_type</span><span class="p">;</span> |
| <span class="n">DMLC_DECLARE_PARAMETER</span><span class="p">(</span><span class="n">ActivationParam</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">act_type</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"relu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kReLU</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"sigmoid"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"tanh"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kTanh</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"softrelu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSoftReLU</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Activation function to be applied."</span><span class="p">);</span> |
| <span class="p">}</span> |
| <span class="p">};</span> |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">Activation</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">ActivationParam</span><span class="o">></span><span class="p">);</span> |
| <span class="c1">// access with:</span> |
| <span class="c1">// const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);</span> |
| </code></pre></div> |
| <h4 id="inputs-outputs">Inputs & Outputs</h4> |
| |
| <p>The number of inputs and outputs can be set with <code>.set_num_inputs(n_in)</code> and <code>.set_num_outputs(n_out)</code>, |
| where <code>n_in</code> and <code>n_out</code> are integers.</p> |
| |
| <p>Alternatively, if the number of inputs/outputs is variable and depends on arguments, |
| you can set <code>n_in</code>/<code>n_out</code> to functions with prototype <code>uint32_t(const nnvm::NodeAttrs& attrs)</code> |
| that return the number of inputs/outputs based on parsed arguments.</p> |
| |
| <p>Outputs can be made invisible to other operators by registering <code>FNumVisibleOutputs</code> |
| and returning an integer smaller than <code>n_out</code>.</p> |
| |
| <p>Inputs/outputs can be named by registering <code>FListInputNames</code> and <code>FListOutputNames</code> with prototype <code>std::vector<std::string>(const NodeAttrs& attrs)</code>.</p> |
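| |
| <p>As an illustration, the following sketch shows the shape of an <code>n_in</code> function and an <code>FListInputNames</code> function for a variadic operator. It assumes a hypothetical keyword argument <code>num_args</code> and hypothetical input names <code>arg0</code>, <code>arg1</code>, and so on; a real operator would usually read the count from its parsed parameter struct rather than from the raw string dictionary.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <cstdint> |
| #include <map> |
| #include <string> |
| #include <vector> |
| |
| // Hypothetical stand-in for NodeAttrs holding only the keyword arguments. |
| struct ToyAttrs { |
|   std::map<std::string, std::string> dict; |
| }; |
| |
| // n_in as a function: the number of inputs depends on an argument. |
| std::uint32_t ToyNumInputs(const ToyAttrs& attrs) { |
|   return static_cast<std::uint32_t>(std::stoul(attrs.dict.at("num_args"))); |
| } |
| |
| // FListInputNames-style function: one generated name per input. |
| std::vector<std::string> ToyListInputNames(const ToyAttrs& attrs) { |
|   std::vector<std::string> names; |
|   const std::uint32_t n = ToyNumInputs(attrs); |
|   for (std::uint32_t i = 0; i < n; ++i) { |
|     names.push_back("arg" + std::to_string(i)); |
|   } |
|   return names; |
| } |
| </code></pre></div> |
| <p>With <code>num_args</code> set to <code>"3"</code>, the operator would expect three inputs named <code>arg0</code>, <code>arg1</code> and <code>arg2</code>.</p> |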
| |
| <h4 id="argument-descriptions">Argument Descriptions</h4> |
| |
| <p>Set argument descriptions with <code>.add_argument(name, type, comment)</code>. |
| This is necessary for operators to be properly called imperatively.</p> |
| |
| <p>First, add the NDArray arguments: <code>num_inputs</code> times with type "NDArray", |
| or once with type "NDArray[]" for ops with variable-length inputs.</p> |
| |
| <p>Then add keyword arguments with the proper type (float, string, etc.). |
| Operators that parse keyword arguments with <code>dmlc::Parameter</code> |
| can add argument descriptions in bulk with <code>.add_arguments(ActivationParam::__FIELDS__())</code> |
| (NDArray arguments still need to be added manually with type "NDArray").</p> |
| |
| <h4 id="finfershape-or-tisbackward-for-backward-only-ops">FInferShape or TIsBackward (for Backward Only Ops)</h4> |
| |
| <p>Normally operators need to have <code>FInferShape</code> with prototype <code>bool(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs)</code>. <code>FInferShape</code> fills unknown shapes (<code>shape.ndim() == 0</code>) in in_attrs/out_attrs based on known shapes in in_attrs/out_attrs. Use <code>ElemwiseShape<n_in, n_out></code> for simple operators with uniform shapes.</p> |
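| |
| <p>The following framework-free sketch illustrates what an element-wise shape function such as <code>ElemwiseShape</code> does conceptually: every input and output must share one shape, so any known shape is copied into the unknown slots. It is not MXNet's implementation; <code>ToyShape</code> and <code>ToyElemwiseShape</code> are hypothetical, and an empty vector is used here to mean "unknown".</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <initializer_list> |
| #include <vector> |
| |
| // Hypothetical shape type: an empty vector stands for "unknown". |
| using ToyShape = std::vector<int>; |
| |
| // Returns false if two known shapes disagree or if nothing is known yet; |
| // otherwise fills every unknown input/output slot and returns true. |
| bool ToyElemwiseShape(std::vector<ToyShape>* in_attrs, |
|                       std::vector<ToyShape>* out_attrs) { |
|   ToyShape known; |
|   for (std::vector<ToyShape>* attrs : {in_attrs, out_attrs}) { |
|     for (const ToyShape& s : *attrs) { |
|       if (s.empty()) continue; |
|       if (known.empty()) { |
|         known = s; |
|       } else if (s != known) { |
|         return false;  // inconsistent shapes |
|       } |
|     } |
|   } |
|   if (known.empty()) return false;  // nothing to infer from yet |
|   for (std::vector<ToyShape>* attrs : {in_attrs, out_attrs}) { |
|     for (ToyShape& s : *attrs) { |
|       if (s.empty()) s = known; |
|     } |
|   } |
|   return true; |
| } |
| </code></pre></div> |
| <p>Inference runs in both directions: a known output shape can fill an unknown input just as a known input fills an unknown output.</p> |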
| |
| <p>Operators that are only used for a backward pass can instead register <code>.set_attr<nnvm::TIsBackward>("TIsBackward", true)</code>, |
| and their shapes will be copied from the corresponding forward operators.</p> |
| |
| <h4 id="finfertype">FInferType</h4> |
| |
| <p>Similar to <code>FInferShape</code>, <code>FInferType</code> fills unknown types (-1) based on known types. Use <code>ElemwiseType<n_in, n_out></code> for simple operators with uniform types. Operators that registered <code>TIsBackward</code> don't need to register this.</p> |
| |
| <h4 id="finplaceoption-optional">FInplaceOption (Optional)</h4> |
| |
| <p><code>FInplaceOption</code> with prototype <code>std::vector<std::pair<int, int> >(const NodeAttrs& attrs)</code> |
| specifies which input/output pairs can be computed in-place |
| and share memory with each other. |
| Each pair (i, j) in the returned list means |
| that the i-th input can share memory with the j-th output.</p> |
| |
| <h4 id="fgradient-optional-for-imperative-use-required-for-symbolic-use">FGradient (Optional for imperative use, required for symbolic use)</h4> |
| |
| <p>If an operator has a gradient, it can be described with <code>FGradient</code>, which has the prototype</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">></span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">ObjectPtr</span><span class="o">&</span> <span class="n">n</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">>&</span> <span class="n">ograds</span><span class="p">)</span> |
| </code></pre></div> |
| <p>Use the utility functions <code>ElemwiseGradUseIn{op_name}</code>, <code>ElemwiseGradUseOut{op_name}</code>, and <code>ElemwiseGradUseNone{op_name}</code> for ops whose gradient calculation needs the corresponding forward op's input, |
| its output, or nothing, respectively.</p> |
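| |
| <p>To see why <code>abs</code> uses <code>ElemwiseGradUseIn</code>, note that the derivative of the absolute value of x is sign(x) (taken as 0 at x = 0), which depends on the forward input. The sketch below computes that gradient directly on plain vectors; <code>ToySign</code> and <code>ToyBackwardAbs</code> are hypothetical helpers, not MXNet functions, and they mirror the two inputs (output gradient and forward input) that the <code>_backward_abs</code> operator in the example below receives.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <cstddef> |
| #include <vector> |
| |
| // sign(x) in {-1, 0, 1}. |
| float ToySign(float x) { return static_cast<float>((x > 0.0f) - (x < 0.0f)); } |
| |
| // Chain rule for y = abs(x): dL/dx = dL/dy * sign(x). |
| // Needs both the output gradient and the forward input x. |
| std::vector<float> ToyBackwardAbs(const std::vector<float>& ograd, |
|                                   const std::vector<float>& x) { |
|   assert(ograd.size() == x.size()); |
|   std::vector<float> igrad(x.size()); |
|   for (std::size_t i = 0; i < x.size(); ++i) { |
|     igrad[i] = ograd[i] * ToySign(x[i]); |
|   } |
|   return igrad; |
| } |
| </code></pre></div> |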
| |
| <p>For more complicated patterns, use <code>MakeGradNode(op_name, n, heads, dict)</code> to create gradient entries, |
| where <code>heads</code> are the input entries to the backward op, composed from <code>ograds</code> and <code>n->inputs</code>.</p> |
| |
| <p>When assembling the return vector <code>std::vector<nnvm::NodeEntry> ret;</code>, a common pattern is to |
| either create nodes in place, as in:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">ret</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="n">MakeNode</span><span class="p">(</span><span class="s">"zeros_like"</span><span class="p">,</span> <span class="n">n</span><span class="o">-></span><span class="n">attrs</span><span class="p">.</span><span class="n">name</span> <span class="o">+</span> <span class="s">"_xyz_backward"</span><span class="p">,</span> |
| <span class="p">{</span><span class="n">n</span><span class="o">-></span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">]},</span> <span class="nb">nullptr</span><span class="p">,</span> <span class="o">&</span><span class="n">n</span><span class="p">));</span> |
| </code></pre></div> |
| <p>Or create the node, modify it, and then move it into NodeEntry's constructor if the node is not going to be used |
| again. This avoids unnecessary copies of the shared_ptr.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o"><</span> <span class="n">n</span><span class="o">-></span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">();</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">nnvm</span><span class="o">::</span><span class="n">ObjectPtr</span> <span class="n">node</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">Node</span><span class="o">::</span><span class="n">Create</span><span class="p">();</span> |
| <span class="n">node</span><span class="o">-></span><span class="n">attrs</span><span class="p">.</span><span class="n">op</span> <span class="o">=</span> <span class="n">copy_op</span><span class="p">;</span> |
| <span class="n">node</span><span class="o">-></span><span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">ograds</span><span class="p">[</span><span class="mi">0</span><span class="p">]};</span> |
| <span class="n">ret</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">node</span><span class="p">));</span> |
| <span class="p">}</span> |
| </code></pre></div> |
| <p>The first case relies on RVO and the second on in-place construction.</p> |
| |
| <h4 id="fcomputexpu">FCompute<xpu></h4> |
| |
| <p>Simple operators can register FCompute<xpu> with <code>.set_attr<FCompute>("FCompute<cpu>", ...)</code> and <code>.set_attr<FCompute>("FCompute<gpu>", ...)</code> for both CPU and (optionally) GPU computation.</p> |
| |
| <p>FCompute has prototype</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> |
| </code></pre></div> |
| <p><code>req</code> has the same length as <code>outputs</code>. |
| Each entry of <code>req</code> specifies |
| how the corresponding <code>output</code> should be written to. |
| <code>OpReqType</code> is defined as:</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span> |
| <span class="n">kNullOp</span><span class="p">,</span> |
| <span class="n">kWriteTo</span><span class="p">,</span> |
| <span class="n">kWriteInplace</span><span class="p">,</span> |
| <span class="n">kAddTo</span> |
| <span class="p">};</span> |
| </code></pre></div> |
| <p>Normally, the <code>req</code> of all <code>outputs</code> should be <code>kWriteTo</code>, |
| meaning that the provided <code>outputs</code> tensor is a <em>raw</em> memory block, |
| so the operator should write results directly into it. |
| In some cases, for example when calculating the gradient tensor, |
| it is useful to accumulate the result |
| rather than overwrite the tensor contents, |
| so that no extra space needs to be allocated each time. |
| In such cases, the corresponding <code>req</code> is set to <code>kAddTo</code>, |
| indicating that a <code>+=</code> should be used. |
| <code>kNullOp</code> means the output does not need to be written at all, and <code>kWriteInplace</code> means the output shares memory with an input, so results are still written directly.</p> |
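| |
| <p>A minimal sketch of how an element-wise kernel might honor <code>req</code>, assuming a hypothetical helper <code>ToyAssign</code> and plain <code>std::vector<float></code> buffers instead of <code>TBlob</code>. The enum is redeclared only so that the snippet compiles on its own.</p> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++">#include <cassert> |
| #include <cstddef> |
| #include <vector> |
| |
| // Same enumerators as the OpReqType enum shown above. |
| enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo }; |
| |
| // Hypothetical helper: store value into *out according to req. |
| inline void ToyAssign(float* out, OpReqType req, float value) { |
|   switch (req) { |
|     case kNullOp:        // output not needed: leave memory untouched |
|       break; |
|     case kWriteTo:       // fresh buffer: overwrite |
|     case kWriteInplace:  // buffer aliases an input: still overwrite |
|       *out = value; |
|       break; |
|     case kAddTo:         // accumulate, e.g. into an existing gradient |
|       *out += value; |
|       break; |
|   } |
| } |
| |
| // Toy element-wise abs that respects req for its single output. |
| // Reading in[i] before writing out[i] keeps the in-place case correct. |
| void ToyAbsForward(const std::vector<float>& in, OpReqType req, |
|                    std::vector<float>* out) { |
|   assert(in.size() == out->size()); |
|   for (std::size_t i = 0; i < in.size(); ++i) { |
|     ToyAssign(&(*out)[i], req, in[i] < 0.0f ? -in[i] : in[i]); |
|   } |
| } |
| </code></pre></div> |
| <p>With <code>kAddTo</code>, an output buffer holding <code>{10, 10}</code> and input <code>{-1, 2}</code> ends up as <code>{11, 12}</code>; with <code>kWriteTo</code> the same call would leave <code>{1, 2}</code>.</p> |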
| |
| <h3 id="example-abs-operator">Example: abs operator</h3> |
| <div class="highlight"><pre><code class="language-c++" data-lang="c++"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">){</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> |
| <span class="p">})</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">UnaryCompute</span><span class="o"><</span><span class="n">cpu</span><span class="p">,</span> <span class="n">mshadow_op</span><span class="o">::</span><span class="n">abs</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FGradient</span><span class="o">></span><span class="p">(</span><span class="s">"FGradient"</span><span class="p">,</span> <span class="n">ElemwiseGradUseIn</span><span class="p">{</span><span class="s">"_backward_abs"</span><span class="p">})</span> |
| <span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"data"</span><span class="p">,</span> <span class="s">"NDArray"</span><span class="p">,</span> <span class="s">"Source input"</span><span class="p">);</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o"><</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">){</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">},</span> <span class="p">{</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> |
| <span class="p">})</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">BinaryCompute</span><span class="o"><</span><span class="n">cpu</span><span class="p">,</span> <span class="n">backward_grad</span><span class="o"><</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">sign</span><span class="o">></span> <span class="o">></span><span class="p">);</span> |
| </code></pre></div> |
| <h3 id="legacy-operators">Legacy Operators</h3> |
| |
| <p>For the legacy (pre-0.9) way of defining operators with C++, please see:</p> |
| <ul> |
| <li><a href="/versions/1.9.0/api/architecture/overview.html#operators-in-mxnet">Developer Guide - Operators</a></li> |
| <li><a href="/versions/1.9.0/api/architecture/overview.html#simpleop-the-unified-operator-api">Developer Guide - SimpleOp</a></li> |
| </ul> |
| |
| </div> |
| </div> |
| |
| </div> |
| </div> |
| |
| </article> |
| |
| </main><footer class="site-footer h-card"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <h4 class="footer-category-title">Resources</h4> |
| <ul class="contact-list"> |
| <li><a href="/versions/1.9.0/community/contribute#mxnet-dev-communications">Mailing lists</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/labels/Roadmap">Github Roadmap</a></li> |
| <li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li> |
| <li><a href="/versions/1.9.0/community/contribute">Contribute To MXNet</a></li> |
| </ul> |
| </div> |
| <div class="col-3"> |
| <h4 class="footer-category-title">Apache</h4> |
| <ul class="apache-list"> |
| <li><a href="https://www.apache.org/foundation/">Foundation</a></li> |
| <li><a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a></li> |
| <li><a href="/versions/1.9.0/api/faq/security.html">Security</a></li> |
| <li><a href="https://www.apache.org/licenses/">License</a></li> |
| <li><a href="https://www.apache.org/events/current-event">Events</a></li> |
| <li><a href="https://www.apache.org/foundation/thanks.html">Thanks</a></li> |
| </ul> |
| </div> |
| |
| <div class="col-3"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.0/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.0/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.0/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul> |
| </div> |
| |
| <div class="col-3 footer-text"> |
| <p>A flexible and efficient library for deep learning.</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| <footer class="site-footer2"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <img src="/versions/1.9.0/assets/img/apache_incubator_logo.png" class="footer-logo col-2"> |
| </div> |
| <div class="footer-bottom-warning col-9"> |
| <p>Apache MXNet is an effort undergoing incubation at <a href="http://www.apache.org/">The Apache Software Foundation</a> (ASF), <span |
| style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required |
| of all newly accepted projects until a further review indicates that the infrastructure, |
| communications, and decision making process have stabilized in a manner consistent with other |
| successful ASF projects. While incubation status is not necessarily a reflection of the completeness |
| or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p><p>"Copyright © 2017-2022, The Apache Software Foundation Apache MXNet, MXNet, Apache, the Apache |
| feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the |
| Apache Software Foundation."</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| |
| |
| |
| |
| </body> |
| |
| </html> |