| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <title>How to Create New Operators (Layers) — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../_static/basic.css" rel="stylesheet" type="text/css"> |
| <link href="../_static/pygments.css" rel="stylesheet" type="text/css"> |
| <link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript"> |
| // Sphinx client-side configuration, read by doctools.js and the search |
| // scripts loaded below. URL_ROOT is the relative path from this page back |
| // to the documentation root; FILE_SUFFIX matches the generated page names. |
| var DOCUMENTATION_OPTIONS = { |
| URL_ROOT: '../', |
| VERSION: '', |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '' |
| }; |
| </script> |
| <script src="../_static/jquery-1.11.1.js" type="text/javascript"></script> |
| <script src="../_static/underscore.js" type="text/javascript"></script> |
| <script src="../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../_static/doctools.js" type="text/javascript"></script> |
| <script src="../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script> |
| <script> |
| // Standard Google Analytics (analytics.js) async loader snippet: defines a |
| // stub `ga` command queue, then injects the analytics.js script tag before |
| // the first existing <script> element so calls made now are replayed later. |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new |
| Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
|  |
| // Create the tracker for this property and record the initial page view. |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
|  |
| </script> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../_static/jquery.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../_static/underscore.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="../_static/doctools.js"></script> --> |
| <!-- --> |
| <!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> --> |
| <!-- --> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> |
| </head> |
| <body role="document"><!-- Previous Navbar Layout |
| <div class="navbar navbar-default navbar-fixed-top"> |
| <div class="container"> |
| <div class="navbar-header"> |
| <button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar"> |
| <span class="sr-only">Toggle navigation</span> |
| <span class="icon-bar"></span> |
| <span class="icon-bar"></span> |
| <span class="icon-bar"></span> |
| </button> |
| <a href="../" class="navbar-brand"> |
| <img src="http://data.mxnet.io/theme/mxnet.png"> |
| </a> |
| </div> |
| <div id="navbar" class="navbar-collapse collapse"> |
| <ul id="navbar" class="navbar navbar-left"> |
| |
| <li> <a href="../get_started/index.html">Get Started</a> </li> |
| |
| <li> <a href="../tutorials/index.html">Tutorials</a> </li> |
| |
| <li> <a href="../how_to/index.html">How To</a> </li> |
| |
| |
| <li class="dropdown"> |
| <a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a> |
| <ul class="dropdown-menu"> |
| |
| <li><a href="../packages/python/index.html"> |
| Python |
| </a></li> |
| |
| <li><a href="../packages/r/index.html"> |
| R |
| </a></li> |
| |
| <li><a href="../packages/julia/index.html"> |
| Julia |
| </a></li> |
| |
| <li><a href="../packages/c++/index.html"> |
| C++ |
| </a></li> |
| |
| <li><a href="../packages/scala/index.html"> |
| Scala |
| </a></li> |
| |
| <li><a href="../packages/perl/index.html"> |
| Perl |
| </a></li> |
| |
| </ul> |
| </li> |
| |
| <li> <a href="../system/index.html">System</a> </li> |
| <li> |
| <form class="" role="search" action="../search.html" method="get" autocomplete="off"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input type="text" name="q" class="form-control" placeholder="Search"> |
| </div> |
| <input type="hidden" name="check_keywords" value="yes" /> |
| <input type="hidden" name="area" value="default" /> |
| |
| </form> </li> |
| </ul> |
| <ul id="navbar" class="navbar navbar-right"> |
| <li> <a href="../index.html"><span class="flag-icon flag-icon-us"></span></a> </li> |
| <li> <a href="..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| Previous Navbar Layout End --> |
| <div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../" id="logo"><img src="http://data.mxnet.io/theme/mxnet.png"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="../get_started/install.html">Install</a> |
| <a class="main-nav-link" href="../tutorials/index.html">Tutorials</a> |
| <a class="main-nav-link" href="../how_to/index.html">How To</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="../api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="../api/scala/index.html">Scala</a></li> |
| <li><a class="main-nav-link" href="../api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="../api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="../api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="../api/perl/index.html">Perl</a></li> |
| </ul> |
| </span> |
| <a class="main-nav-link" href="../architecture/index.html">Architecture</a> |
| <!-- <a class="main-nav-link" href="../community/index.html">Community</a> --> |
| <a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></span></nav> |
| <script> function getRootPath(){ return "../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu dropdown-menu-right" id="burgerMenu"> |
| <li><a href="../get_started/install.html">Install</a></li> |
| <li><a href="../tutorials/index.html">Tutorials</a></li> |
| <li><a href="../how_to/index.html">How To</a></li> |
| <li class="dropdown-submenu"> |
| <a href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a href="../api/python/index.html" tabindex="-1">Python</a> |
| </li> |
| <li><a href="../api/scala/index.html" tabindex="-1">Scala</a> |
| </li> |
| <li><a href="../api/r/index.html" tabindex="-1">R</a> |
| </li> |
| <li><a href="../api/julia/index.html" tabindex="-1">Julia</a> |
| </li> |
| <li><a href="../api/c++/index.html" tabindex="-1">C++</a> |
| </li> |
| <li><a href="../api/perl/index.html" tabindex="-1">Perl</a> |
| </li> |
| </ul> |
| </li> |
| <li><a href="../architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li><li><a tabindex="-1" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"> |
| <input name="area" type="hidden" value="default"/> |
| </form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| <!-- <div id="lang-select-wrap"> --> |
| <!-- <label id="lang-select-label"> --> |
| <!-- <\!-- <i class="fa fa-globe"></i> -\-> --> |
| <!-- <span></span> --> |
| <!-- </label> --> |
| <!-- <select id="lang-select"> --> |
| <!-- <option value="en">Eng</option> --> |
| <!-- <option value="zh">中文</option> --> |
| <!-- </select> --> |
| <!-- </div> --> |
| <!-- <a id="mobile-nav-toggle"> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| <span class="mobile-nav-toggle-bar"></span> |
| </a> --> |
| </div> |
| </div> |
| </div> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../api/python/index.html">Python Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../api/r/index.html">R Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../api/julia/index.html">Julia Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../api/c++/index.html">C++ Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../api/scala/index.html">Scala Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../api/perl/index.html">Perl Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="index.html">HowTo Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../architecture/index.html">System Documents</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../tutorials/index.html">Tutorials</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="section" id="how-to-create-new-operators-layers"> |
| <span id="how-to-create-new-operators-layers"></span><h1>How to Create New Operators (Layers)<a class="headerlink" href="#how-to-create-new-operators-layers" title="Permalink to this headline">¶</a></h1> |
| <p>This tutorial walks you through the process of creating new MXNet operators (or layers). |
| We’ve done our best to provide high-speed operators for most common use cases. |
| However, if you’re engaged in research, |
| there’s a good chance you’ll want to define custom layers, |
| like a novel loss function. In these cases, you have two options:</p> |
| <ul class="simple"> |
| <li>Use CustomOp to write new operators using a front-end language (e.g., Python) that run on CPUs or GPUs. |
| Depending on your implementation, this can range from very fast (if you only use operators under mx.nd) to very slow (if you copy out the data, using <code class="docutils literal"><span class="pre">.asnumpy()</span></code>).</li> |
| <li>Use C++/mshadow (CUDA). This provides the best performance, but can be difficult |
| if you’re not familiar with MXNet, mshadow, or Cuda.</li> |
| </ul> |
| <div class="section" id="customop"> |
| <span id="customop"></span><h2>CustomOp<a class="headerlink" href="#customop" title="Permalink to this headline">¶</a></h2> |
| <p>Implementing an operator in Python is simple. |
| As an example, let’s create a softmax operator. |
| Start by subclassing <code class="docutils literal"><span class="pre">mxnet.operator.CustomOp</span></code>, |
| and then override a few methods:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span> |
| <span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span> |
| |
| <span class="k">class</span> <span class="nc">Softmax</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOp</span><span class="p">):</span> |
| <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">is_train</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span> |
| <span class="n">x</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> |
| <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)))</span> |
| <span class="n">y</span> <span class="o">/=</span> <span class="n">y</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>We defined the computation for the forward pass of our operator. |
| The forward function takes a list of input and a list of output NDArrays. |
| For convenience, we called <code class="docutils literal"><span class="pre">.asnumpy()</span></code> on the first NDArray in input |
| and convert it to a CPU-based NumPy array. |
| This can be very slow. If you want the best performance, |
| keep data in the NDArray format and use operators under mx.nd to do the computation.</p> |
| <p>At the end, we used CustomOp.assign to assign the resulting array y to out_data[0]. It handles assignment based on the value of req, which can be ‘write’, ‘add’, or ‘null’.</p> |
| <p>Then do the same for the backward pass:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">req</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">,</span> <span class="n">in_data</span><span class="p">,</span> <span class="n">out_data</span><span class="p">,</span> <span class="n">in_grad</span><span class="p">,</span> <span class="n">aux</span><span class="p">):</span> |
| <span class="n">l</span> <span class="o">=</span> <span class="n">in_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int</span><span class="p">)</span> |
| <span class="n">y</span> <span class="o">=</span> <span class="n">out_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span> |
| <span class="n">y</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">l</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">l</span><span class="p">]</span> <span class="o">-=</span> <span class="mf">1.0</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">in_grad</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>Softmax defines the computation of our custom operator, |
| but you still need to define its input/output format |
| by subclassing mx.operator.CustomOpProp. |
| First, register the new operator with the name ‘softmax’:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="nd">@mx.operator.register</span><span class="p">(</span><span class="s2">"softmax"</span><span class="p">)</span> |
| <span class="k">class</span> <span class="nc">SoftmaxProp</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">operator</span><span class="o">.</span><span class="n">CustomOpProp</span><span class="p">):</span> |
| </pre></div> |
| </div> |
| <p>Then, call the base constructor with <code class="docutils literal"><span class="pre">need_top_grad=False</span></code> |
| because softmax is a loss layer and you don’t need gradient input from preceding layers:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="nb">super</span><span class="p">(</span><span class="n">SoftmaxProp</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">need_top_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>Then declare the input and output:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">list_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="p">[</span><span class="s1">'data'</span><span class="p">,</span> <span class="s1">'label'</span><span class="p">]</span> |
| |
| <span class="k">def</span> <span class="nf">list_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> |
| <span class="k">return</span> <span class="p">[</span><span class="s1">'output'</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| <p>Note that list_arguments declares both input and parameter. |
| We recommend ordering them as follows: <code class="docutils literal"><span class="pre">['input1',</span> <span class="pre">'input2',</span> <span class="pre">...</span> <span class="pre">,</span> <span class="pre">'weight1',</span> <span class="pre">'weight2',</span> <span class="pre">...]</span></code></p> |
| <p>Next, provide <code class="docutils literal"><span class="pre">infer_shape</span></code> to declare the shape of the output/weight |
| and check the consistency of the input shapes:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">infer_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_shape</span><span class="p">):</span> |
| <span class="n">data_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="n">label_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">],)</span> |
| <span class="n">output_shape</span> <span class="o">=</span> <span class="n">in_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="k">return</span> <span class="p">[</span><span class="n">data_shape</span><span class="p">,</span> <span class="n">label_shape</span><span class="p">],</span> <span class="p">[</span><span class="n">output_shape</span><span class="p">],</span> <span class="p">[]</span> |
| </pre></div> |
| </div> |
| <p>The first axis of an input/output tensor corresponds to different examples within the batch. |
| The label is a set of integers, one for each data entry, |
| and the output has the same shape as the input. |
| The <code class="docutils literal"><span class="pre">infer_shape</span></code> function should always return three lists in this order: |
| inputs, outputs, and auxiliary states (which we don’t have here), |
| even if one of them is empty.</p> |
| <p>Optionally, you can also define <code class="docutils literal"><span class="pre">infer_type</span></code> to declare the input and output data type of your operator. Supported types are <code class="docutils literal"><span class="pre">np.float32</span></code>, <code class="docutils literal"><span class="pre">np.float64</span></code>, <code class="docutils literal"><span class="pre">np.float16</span></code>, <code class="docutils literal"><span class="pre">np.uint8</span></code>, and <code class="docutils literal"><span class="pre">np.int32</span></code>.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">infer_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_type</span><span class="p">):</span> |
| <span class="n">dtype</span> <span class="o">=</span> <span class="n">in_type</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <span class="k">return</span> <span class="p">[</span><span class="n">dtype</span><span class="p">,</span> <span class="n">dtype</span><span class="p">],</span> <span class="p">[</span><span class="n">dtype</span><span class="p">],</span> <span class="p">[]</span> |
| </pre></div> |
| </div> |
| <p>Finally, define a create_operator function that will be called by the back end to create an instance of softmax:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">create_operator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ctx</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">dtypes</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">Softmax</span><span class="p">()</span> |
| </pre></div> |
| </div> |
| <p>To use the custom operator, create a mx.sym.Custom symbol with op_type as the registered name:</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">mlp</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Custom</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">,</span> <span class="n">op_type</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>Please see the full code for this example <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/example/numpy-ops/custom_softmax.py">here</a>.</p> |
| </div> |
| <div class="section" id="c"> |
| <span id="c"></span><h2>C++<a class="headerlink" href="#c" title="Permalink to this headline">¶</a></h2> |
| <p>With MXNet v0.9 (the NNVM refactor) or later, creating new operators has become easier. |
| Operators are now registered with NNVM. |
| The following code is an example on how to register an operator (checkout <a class="reference external" href="https://github.com/dmlc/mxnet/tree/master/src/operator/tensor">src/operator/tensor</a> for more examples):</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="o">></span><span class="p">);</span> |
| </pre></div> |
| </div> |
| <p>The syntax is quite simple: we register the operator with a name, |
| then set number of inputs and outputs. |
| You can register attributes with any key (<code class="docutils literal"><span class="pre">FInferShape</span></code> for example) to any operator, |
| without having to modify a central class interface definition.</p> |
| <div class="section" id="operator-attribute-system"> |
| <span id="operator-attribute-system"></span><h3>Operator Attribute System<a class="headerlink" href="#operator-attribute-system" title="Permalink to this headline">¶</a></h3> |
| <p>One of the biggest improvements brought by NNVM is the operator attribute system. |
| This is like traits for types in common languages like C++. |
| We can register any attribute to any operator, with the syntax</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">op</span><span class="o">-</span><span class="n">name</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">AttributeType</span><span class="o">></span><span class="p">(</span><span class="s">"AttributeKey"</span><span class="p">,</span> <span class="n">CorrespondingAttributeObject</span><span class="p">);</span> |
| </pre></div> |
| </div> |
| <p>These attributes can be retrieved later for various purposes. |
| For example, <code class="docutils literal"><span class="pre">FInferShape</span></code> is used for shape inference, <code class="docutils literal"><span class="pre">FCompute<cpu></span></code> is used for carrying out actual computation on CPU.</p> |
| <p>As long as all attributes registered with the same key have the same type, |
| we can register any attributes to operators. |
| The more attributes an operator provides, |
| the more information the system can use for optimization.</p> |
| </div> |
| <div class="section" id="list-of-basic-attributes"> |
| <span id="list-of-basic-attributes"></span><h3>List of basic attributes<a class="headerlink" href="#list-of-basic-attributes" title="Permalink to this headline">¶</a></h3> |
| <p>In this section, we will go through the basic attributes MXNet expects for all operators. |
| You can find the definition for them in the following two files:</p> |
| <div class="toctree-wrapper compound"> |
| <ul> |
| <li class="toctree-l1"><a class="reference external" href="https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h">nnvm/op</a></li> |
| <li class="toctree-l1"><a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/include/mxnet/op_attr_types.h">mxnet/op</a></li> |
| </ul> |
| </div> |
| <div class="section" id="descriptions-optional"> |
| <span id="descriptions-optional"></span><h4>Descriptions (Optional)<a class="headerlink" href="#descriptions-optional" title="Permalink to this headline">¶</a></h4> |
| <p><code class="docutils literal"><span class="pre">.describe(comment)</span></code> adds a comment to the operator. Use <code class="docutils literal"><span class="pre">.MXNET_DESCRIBE(comment)</span></code> to add the current file name and line number to comment.</p> |
| </div> |
| <div class="section" id="attribute-parser-optional"> |
| <span id="attribute-parser-optional"></span><h4>Attribute Parser (Optional)<a class="headerlink" href="#attribute-parser-optional" title="Permalink to this headline">¶</a></h4> |
| <p>Set attribute parser with <code class="docutils literal"><span class="pre">.set_attr_parser(PARSER)</span></code> where PARSER is a function with prototype <code class="docutils literal"><span class="pre">void(nnvm::NodeAttr*</span> <span class="pre">attrs)</span></code>. This function should parse the key-word arguments in <code class="docutils literal"><span class="pre">attrs->dict</span></code> and store the result in <code class="docutils literal"><span class="pre">attrs->parsed</span></code>.</p> |
| <p>Simple arguments can be parsed like</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">scalar_op</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span> |
| <span class="p">[](</span><span class="n">NodeAttrs</span><span class="o">*</span> <span class="n">attrs</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">attrs</span><span class="o">-></span><span class="n">parsed</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">stod</span><span class="p">(</span><span class="n">attrs</span><span class="o">-></span><span class="n">dict</span><span class="p">[</span><span class="s">"scalar"</span><span class="p">]);</span> |
| <span class="p">})</span> |
| </pre></div> |
| </div> |
| <p>The parsed arguments can then be accessed in other attribute functions with</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span>double alpha = nnvm::get&lt;double&gt;(attrs.parsed); |
| </pre></div> |
| </div> |
| <p>More complex ops can use <code class="docutils literal"><span class="pre">dmlc::Parameter</span></code> and <code class="docutils literal"><span class="pre">ParamParser</span></code> (defined in operator_common.h) for parsing:</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="cp">#include</span> <span class="cpf"><dmlc/parameter.h></span><span class="cp"></span> |
| <span class="cp">#include</span> <span class="cpf"><operator_common.h></span><span class="cp"></span> |
| <span class="k">struct</span> <span class="nl">ActivationParam</span> <span class="p">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o"><</span><span class="n">ActivationParam</span><span class="o">></span> <span class="p">{</span> |
| <span class="c1">// use int for enumeration</span> |
| <span class="kt">int</span> <span class="n">act_type</span><span class="p">;</span> |
| <span class="n">DMLC_DECLARE_PARAMETER</span><span class="p">(</span><span class="n">ActivationParam</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">act_type</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"relu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kReLU</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"sigmoid"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"tanh"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kTanh</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">add_enum</span><span class="p">(</span><span class="s">"softrelu"</span><span class="p">,</span> <span class="n">activation</span><span class="o">::</span><span class="n">kSoftReLU</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Activation function to be applied."</span><span class="p">);</span> |
| <span class="p">}</span> |
| <span class="p">};</span> |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">Activation</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">ActivationParam</span><span class="o">></span><span class="p">);</span> |
| <span class="c1">// access with:</span> |
| <span class="c1">// const ActivationParam&amp; param = nnvm::get&lt;ActivationParam&gt;(attrs.parsed);</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="inputs-outputs"> |
| <span id="inputs-outputs"></span><h4>Inputs &amp; Outputs<a class="headerlink" href="#inputs-outputs" title="Permalink to this headline">¶</a></h4> |
| <p>Number of inputs/outputs can be set with <code class="docutils literal"><span class="pre">.set_num_inputs(n_in)</span></code> and <code class="docutils literal"><span class="pre">.set_num_outputs(n_out)</span></code> |
| where n_in and n_out are integers.</p> |
| <p>Alternatively, if the number of inputs/outputs is variable and depends on arguments, |
| you can set <code class="docutils literal"><span class="pre">n_in</span></code>/<code class="docutils literal"><span class="pre">n_out</span></code> to functions with prototype <code class="docutils literal"><span class="pre">uint32_t(const</span> <span class="pre">nnvm::NodeAttrs&amp;</span> <span class="pre">attrs)</span></code> |
| that return the number of inputs/outputs based on parsed arguments.</p> |
| <p>Outputs can be made invisible to other operators by registering <code class="docutils literal"><span class="pre">FNumVisibleOutputs</span></code> |
| and returning an integer smaller than <code class="docutils literal"><span class="pre">n_out</span></code>.</p> |
| <p>Inputs/outputs can be named by registering <code class="docutils literal"><span class="pre">FListInputNames</span></code> and <code class="docutils literal"><span class="pre">FListOutputNames</span></code> with prototype <code class="docutils literal"><span class="pre">std::vector&lt;std::string&gt;(const</span> <span class="pre">NodeAttrs&amp;</span> <span class="pre">attrs)</span></code>.</p> |
| </div> |
| <div class="section" id="argument-descriptions"> |
| <span id="argument-descriptions"></span><h4>Argument Descriptions<a class="headerlink" href="#argument-descriptions" title="Permalink to this headline">¶</a></h4> |
| <p>Set argument descriptions with <code class="docutils literal"><span class="pre">.add_argument(name,</span> <span class="pre">type,</span> <span class="pre">comment)</span></code>. |
| This is necessary for operators to be properly called imperatively.</p> |
| <p>First, add NDArray arguments <code class="docutils literal"><span class="pre">num_inputs</span></code> times with type “NDArray” |
| or one time with type “NDArray[]” for ops with variable length inputs.</p> |
| <p>Then add key-word arguments with proper type (float, string, etc). |
| Operators that parse key-word arguments with <code class="docutils literal"><span class="pre">dmlc::Parameter</span></code> |
| can add argument descriptions in bulk with <code class="docutils literal"><span class="pre">.add_arguments(ActivationParam::__FIELDS__())</span></code> |
| (NDArray arguments still need to be manually added with type “NDArray”).</p> |
| </div> |
| <div class="section" id="finfershape-or-tisbackward-for-backward-only-ops"> |
| <span id="finfershape-or-tisbackward-for-backward-only-ops"></span><h4>FInferShape or TIsBackward (for Backward Only Ops)<a class="headerlink" href="#finfershape-or-tisbackward-for-backward-only-ops" title="Permalink to this headline">¶</a></h4> |
| <p>Normally operators need to have <code class="docutils literal"><span class="pre">FInferShape</span></code> with prototype <code class="docutils literal"><span class="pre">bool(const</span> <span class="pre">nnvm::NodeAttrs&amp;</span> <span class="pre">attrs,</span> <span class="pre">std::vector&lt;TShape&gt;</span> <span class="pre">*in_attrs,</span> <span class="pre">std::vector&lt;TShape&gt;</span> <span class="pre">*out_attrs)</span></code>. <code class="docutils literal"><span class="pre">FInferShape</span></code> fills unknown shapes (<code class="docutils literal"><span class="pre">shape.ndim()</span> <span class="pre">==</span> <span class="pre">0</span></code>) in in_attrs/out_attrs based on known shapes in in_attrs/out_attrs. Use <code class="docutils literal"><span class="pre">ElemwiseShape&lt;n_in,</span> <span class="pre">n_out&gt;</span></code> for simple operators with uniform shapes.</p> |
| <p>Operators that are only used for a backward pass can instead register <code class="docutils literal"><span class="pre">.set_attr&lt;nnvm::TIsBackward&gt;("TIsBackward",</span> <span class="pre">true)</span></code> |
| and their shapes will be copied from the corresponding forward operators.</p> |
| </div> |
| <div class="section" id="finfertype"> |
| <span id="finfertype"></span><h4>FInferType<a class="headerlink" href="#finfertype" title="Permalink to this headline">¶</a></h4> |
| <p>Similar to <code class="docutils literal"><span class="pre">FInferShape</span></code>, <code class="docutils literal"><span class="pre">FInferType</span></code> fills unknown types (-1) based on known types. Use <code class="docutils literal"><span class="pre">ElemwiseType&lt;n_in,</span> <span class="pre">n_out&gt;</span></code> for simple operators with uniform types. Operators that registered <code class="docutils literal"><span class="pre">TIsBackward</span></code> don’t need to register this.</p> |
| </div> |
| <div class="section" id="finplaceoption-optional"> |
| <span id="finplaceoption-optional"></span><h4>FInplaceOption (Optional)<a class="headerlink" href="#finplaceoption-optional" title="Permalink to this headline">¶</a></h4> |
| <p><code class="docutils literal"><span class="pre">FInplaceOption</span></code> with prototype <code class="docutils literal"><span class="pre">std::vector&lt;std::pair&lt;int,</span> <span class="pre">int&gt;</span> <span class="pre">&gt;(const</span> <span class="pre">NodeAttrs&amp;</span> <span class="pre">attrs)</span></code> |
| specifies which input/output pairs can be computed in-place |
| and share memory with each other. |
| Each pair (i, j) in the returned list means |
| that the i-th input can share memory with the j-th output.</p> |
| </div> |
| <div class="section" id="fgradient-optional-for-imperative-use-required-for-symbolic-use"> |
| <span id="fgradient-optional-for-imperative-use-required-for-symbolic-use"></span><h4>FGradient (Optional for imperative use, required for symbolic use)<a class="headerlink" href="#fgradient-optional-for-imperative-use-required-for-symbolic-use" title="Permalink to this headline">¶</a></h4> |
| <p>If an operator has gradient, it can be described with <code class="docutils literal"><span class="pre">FGradient</span></code> with prototype</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">></span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodePtr</span><span class="o">&</span> <span class="n">n</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">NodeEntry</span><span class="o">>&</span> <span class="n">ograds</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>Use utility functions <code class="docutils literal"><span class="pre">ElemwiseGradUseIn{op_name}</span></code>, <code class="docutils literal"><span class="pre">ElemwiseGradUseOut{op_name}</span></code>, <code class="docutils literal"><span class="pre">ElemwiseGradUseNone{op_name}</span></code> for ops that need the corresponding forward op’s input, |
| output, or nothing to calculate the gradient.</p> |
| <p>For more complicated patterns, use <code class="docutils literal"><span class="pre">MakeGradNode(op_name,</span> <span class="pre">n,</span> <span class="pre">heads,</span> <span class="pre">dict)</span></code> to create gradient entries, |
| where heads are input entries to the backward op, composed from ograds and n->inputs.</p> |
| </div> |
| <div class="section" id="fcompute-xpu"> |
| <span id="fcompute-xpu"></span><h4>FCompute&lt;xpu&gt;<a class="headerlink" href="#fcompute-xpu" title="Permalink to this headline">¶</a></h4> |
| <p>Simple operators can register FCompute&lt;xpu&gt; with <code class="docutils literal"><span class="pre">.set_attr&lt;FCompute&gt;("FCompute&lt;cpu&gt;",</span> <span class="pre">...)</span></code> and <code class="docutils literal"><span class="pre">.set_attr&lt;FCompute&gt;("FCompute&lt;gpu&gt;",</span> <span class="pre">...)</span></code> for both CPU and (optionally) GPU computation.</p> |
| <p>FCompute has prototype</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="kt">void</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><code class="docutils literal"><span class="pre">req</span></code> has the same length as <code class="docutils literal"><span class="pre">outputs</span></code>. |
| Each entry of <code class="docutils literal"><span class="pre">req</span></code> specifies |
| how the corresponding <code class="docutils literal"><span class="pre">output</span></code> should be written to. |
| <code class="docutils literal"><span class="pre">OpReqType</span></code> is defined as:</p> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="k">enum</span> <span class="n">OpReqType</span> <span class="p">{</span> |
| <span class="n">kNullOp</span><span class="p">,</span> |
| <span class="n">kWriteTo</span><span class="p">,</span> |
| <span class="n">kWriteInplace</span><span class="p">,</span> |
| <span class="n">kAddTo</span> |
| <span class="p">};</span> |
| </pre></div> |
| </div> |
| <p>Normally, the <code class="docutils literal"><span class="pre">req</span></code> of all <code class="docutils literal"><span class="pre">outputs</span></code> should be <code class="docutils literal"><span class="pre">kWriteTo</span></code>, |
| meaning that the provided <code class="docutils literal"><span class="pre">outputs</span></code> tensor is a <em>raw</em> memory block, |
| so the operator should write results directly into it. |
| In some cases, for example, when calculating the gradient tensor, |
| it would be great if we could accumulate the result, |
| rather than directly overwrite the tensor contents |
| so that no extra space needs to be created each time. |
| In such cases, the corresponding <code class="docutils literal"><span class="pre">req</span></code> is set to <code class="docutils literal"><span class="pre">kAddTo</span></code>, |
| indicating that a <code class="docutils literal"><span class="pre">+=</span></code> should be used.</p> |
| </div> |
| </div> |
| <div class="section" id="example-abs-operator"> |
| <span id="example-abs-operator"></span><h3>Example: abs operator<a class="headerlink" href="#example-abs-operator" title="Permalink to this headline">¶</a></h3> |
| <div class="highlight-c++"><div class="highlight"><pre><span></span><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">MXNET_DESCRIBE</span><span class="p">(</span><span class="s">"Take absolute value of the src"</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o"><</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">){</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> |
| <span class="p">})</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">UnaryCompute</span><span class="o"><</span><span class="n">cpu</span><span class="p">,</span> <span class="n">mshadow_op</span><span class="o">::</span><span class="n">abs</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FGradient</span><span class="o">></span><span class="p">(</span><span class="s">"FGradient"</span><span class="p">,</span> <span class="n">ElemwiseGradUseIn</span><span class="p">{</span><span class="s">"_backward_abs"</span><span class="p">});</span> |
| <span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"data"</span><span class="p">,</span> <span class="s">"NDArray"</span><span class="p">,</span> <span class="s">"Source input"</span><span class="p">)</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_abs</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">ElemwiseShape</span><span class="o"><</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">ElemwiseType</span><span class="o"><</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="o">></span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">){</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">},</span> <span class="p">{</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> |
| <span class="p">})</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">BinaryCompute</span><span class="o"><</span><span class="n">cpu</span><span class="p">,</span> <span class="n">unary_bwd</span><span class="o"><</span><span class="n">mshadow_op</span><span class="o">::</span><span class="n">sign</span><span class="o">></span> <span class="o">></span><span class="p">);</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="legacy-operators"> |
| <span id="legacy-operators"></span><h3>Legacy Operators<a class="headerlink" href="#legacy-operators" title="Permalink to this headline">¶</a></h3> |
| <p>For the legacy (pre 0.9) way of defining operators with C++, please see:</p> |
| <div class="toctree-wrapper compound"> |
| <ul> |
| <li class="toctree-l1"><a class="reference external" href="http://mxnet.io/architecture/overview.html#operators-in-mxnet">Developer Guide - Operators</a></li> |
| <li class="toctree-l1"><a class="reference external" href="http://mxnet.io/architecture/overview.html#simpleop-the-unified-operator-api">Developer Guide - SimpleOp</a></li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| <div class="container"> |
| <div class="footer"> |
| <p> © 2015-2017 DMLC. All rights reserved. </p> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">How to Create New Operators (Layers)</a><ul> |
| <li><a class="reference internal" href="#customop">CustomOp</a></li> |
| <li><a class="reference internal" href="#c">C++</a><ul> |
| <li><a class="reference internal" href="#operator-attribute-system">Operator Attribute System</a></li> |
| <li><a class="reference internal" href="#list-of-basic-attributes">List of basic attributes</a><ul> |
| <li><a class="reference internal" href="#descriptions-optional">Descriptions (Optional)</a></li> |
| <li><a class="reference internal" href="#attribute-parser-optional">Attribute Parser (Optional)</a></li> |
| <li><a class="reference internal" href="#inputs-outputs">Inputs &amp; Outputs</a></li> |
| <li><a class="reference internal" href="#argument-descriptions">Argument Descriptions</a></li> |
| <li><a class="reference internal" href="#finfershape-or-tisbackward-for-backward-only-ops">FInferShape or TIsBackward (for Backward Only Ops)</a></li> |
| <li><a class="reference internal" href="#finfertype">FInferType</a></li> |
| <li><a class="reference internal" href="#finplaceoption-optional">FInplaceOption (Optional)</a></li> |
| <li><a class="reference internal" href="#fgradient-optional-for-imperative-use-required-for-symbolic-use">FGradient (Optional for imperative use, required for symbolic use)</a></li> |
| <li><a class="reference internal" href="#fcompute-xpu">FCompute&lt;xpu&gt;</a></li> |
| </ul> |
| </li> |
| <li><a class="reference internal" href="#example-abs-operator">Example: abs operator</a></li> |
| <li><a class="reference internal" href="#legacy-operators">Legacy Operators</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../_static/js/search.js" type="text/javascript"></script> |
| <script src="../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../_static/js/copycode.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| // Make the page visible once the DOM is ready. |
| // NOTE(review): presumably the body starts with visibility:hidden (set in CSS |
| // elsewhere) to avoid a flash of unstyled content — confirm against mxnet.css. |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </div></body> |
| </html> |