| <!DOCTYPE html> |
| |
| <!--- |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| http://www.apache.org/licenses/LICENSE-2.0 |
| Unless required by applicable law or agreed to in writing, |
| software distributed under the License is distributed on an |
| "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| KIND, either express or implied. See the License for the |
| specific language governing permissions and limitations |
| under the License. |
| --> |
| |
<html lang="en"><head>
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 --> |
| <title>A Beginner's Guide to Implementing Operators in MXNet Backend | Apache MXNet</title> |
| <meta name="generator" content="Jekyll v3.8.6" /> |
| <meta property="og:title" content="A Beginner's Guide to Implementing Operators in MXNet Backend" /> |
| <meta property="og:locale" content="en_US" /> |
| <meta name="description" content="A flexible and efficient library for deep learning." /> |
| <meta property="og:description" content="A flexible and efficient library for deep learning." /> |
| <link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/faq/add_op_in_backend" /> |
| <meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/faq/add_op_in_backend" /> |
| <meta property="og:site_name" content="Apache MXNet" /> |
| <script type="application/ld+json"> |
| {"headline":"A Beginner's Guide to Implementing Operators in MXNet Backend","description":"A flexible and efficient library for deep learning.","url":"https://mxnet.apache.org/versions/1.9.1/api/faq/add_op_in_backend","@type":"WebPage","@context":"https://schema.org"}</script> |
| <!-- End Jekyll SEO tag --> |
| <link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo --> |
| <script> |
| var _paq = window._paq = window._paq || []; |
| /* tracker methods like "setCustomDimension" should be called before "trackPageView" */ |
| /* We explicitly disable cookie tracking to avoid privacy issues */ |
| _paq.push(['disableCookies']); |
| _paq.push(['trackPageView']); |
| _paq.push(['enableLinkTracking']); |
| (function() { |
| var u="https://analytics.apache.org/"; |
| _paq.push(['setTrackerUrl', u+'matomo.php']); |
| _paq.push(['setSiteId', '23']); |
| var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0]; |
| g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s); |
| })(); |
| </script> |
| <!-- End Matomo Code --> |
| |
| <script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script> |
| <script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script> |
| <script src="/versions/1.9.1/assets/js/clipboard.js" defer></script> |
| <script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head> |
| <body><header class="site-header" role="banner"> |
| |
| <script> |
| $(document).ready(function () { |
| |
| // HEADER OPACITY LOGIC |
| |
| function opacity_header() { |
| var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")" |
| $('.site-header').css("background-color", value) |
| } |
| |
| $(window).scroll(function () { |
| opacity_header() |
| }) |
| opacity_header(); |
| |
| // MENU SELECTOR LOGIC |
| $('.page-link').each( function () { |
| if (window.location.href.includes(this.href)) { |
| $(this).addClass("page-current"); |
| } |
| }); |
| }) |
| </script> |
| <div class="wrapper"> |
| <a class="site-title" rel="author" href="/versions/1.9.1/"><img |
| src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo"></a> |
| <nav class="site-nav"> |
| <input type="checkbox" id="nav-trigger" class="nav-trigger"/> |
| <label for="nav-trigger"> |
| <span class="menu-icon"> |
| <svg viewBox="0 0 18 15" width="18px" height="15px"> |
| <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/> |
| </svg> |
| </span> |
| </label> |
| <div class="gs-search-border"> |
| <div id="gs-search-icon"></div> |
| <form id="global-search-form"> |
| <input id="global-search" type="text" title="Search" placeholder="Search" /> |
| <div id="global-search-dropdown-container"> |
| <button class="gs-current-version btn" type="button" data-toggle="dropdown"> |
| <span id="gs-current-version-label">1.9.1</span> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| <span id="global-search-close">x</span> |
| </form> |
| </div> |
| <div class="trigger"> |
| <div id="global-search-mobile-border"> |
| <div id="gs-search-icon-mobile"></div> |
| <input id="global-search-mobile" placeholder="Search..." type="text"/> |
| <div id="global-search-dropdown-container-mobile"> |
| <button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown"> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown-mobile"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| </div> |
| <a class="page-link" href="/versions/1.9.1/get_started">Get Started</a> |
| <a class="page-link" href="/versions/1.9.1/features">Features</a> |
| <a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a> |
| <a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a> |
| <a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a> |
| <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a> |
| <div class="dropdown" style="min-width:100px"> |
| <span class="dropdown-header">Apache |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content" style="min-width:250px"> |
| <a href="https://www.apache.org/foundation/">Apache Software Foundation</a> |
| <a href="https://incubator.apache.org/">Apache Incubator</a> |
| <a href="https://www.apache.org/licenses/">License</a> |
| <a href="/versions/1.9.1/api/faq/security.html">Security</a> |
| <a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </div> |
| </div> |
| <div class="dropdown"> |
| <span class="dropdown-header">1.9.1 |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content"> |
| <a href="/">master</a> |
| <a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a> |
| <a href="/versions/1.8.0/">1.8.0</a> |
| <a href="/versions/1.7.0/">1.7.0</a> |
| <a href="/versions/1.6.0/">1.6.0</a> |
| <a href="/versions/1.5.0/">1.5.0</a> |
| <a href="/versions/1.4.1/">1.4.1</a> |
| <a href="/versions/1.3.1/">1.3.1</a> |
| <a href="/versions/1.2.1/">1.2.1</a> |
| <a href="/versions/1.1.0/">1.1.0</a> |
| <a href="/versions/1.0.0/">1.0.0</a> |
| <a href="/versions/0.12.1/">0.12.1</a> |
| <a href="/versions/0.11.0/">0.11.0</a> |
| </div> |
| </div> |
| </div> |
| </nav> |
| </div> |
| </header> |
| <main class="page-content" aria-label="Content"> |
| <script> |
| |
| </script> |
| <article class="post"> |
| |
| <header class="post-header wrapper"> |
| <h1 class="post-title">A Beginner's Guide to Implementing Operators in MXNet Backend</h1> |
| <h3></h3></header> |
| |
| <div class="post-content"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3 docs-side-bar"> |
| <h3 style="text-transform: capitalize; padding-left:10px">faq</h3> |
| <ul> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/add_op_in_backend">A Beginner's Guide to Implementing Operators in MXNet Backend</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/caffe">Convert from Caffe to MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/cloud">MXNet on the Cloud</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/distributed_training">Distributed Training in MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/env_var">Environment Variables</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/float16">Float16</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/large_tensor_support">Using MXNet with Large Tensor Support</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/model_parallel_lstm">Model Parallel</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/multi_device">Data Parallelism with Multiple CPU/GPUs on MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/new_op">Create New Operators</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/nnpack">NNPACK for Multi-Core CPU Support in MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/perf">Some Tips for Improving MXNet Performance</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/recordio">Create a Dataset Using RecordIO</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/s3_integration">Use data from S3 for training</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/security">MXNet Security Best Practices</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/smart_device">Deep Learning at the Edge</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/visualize_graph">Visualize Neural Networks</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/1.9.1/api/faq/why_mxnet">Why MXNet came to be?</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| </ul> |
| </div> |
| <div class="col-9"> |
| |
| <h1 id="a-beginners-guide-to-implementing-operators-in-mxnet-backend">A Beginner's Guide to Implementing Operators in MXNet Backend</h1> |
| |
| <h2 id="introduction">Introduction</h2> |
| |
<p>Operators are essential building blocks of neural networks. They define the mathematical formulas
that transform input data (tensors) into outputs. MXNet has a rich set of operators, from simple ones,
such as element-wise sum, to complicated ones, such as convolution, which together are
capable of constructing most popular neural networks. You may have noticed
that many operators implemented in MXNet have equivalent forms in NumPy, such as
<a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html">repeat</a> and
<a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html">tile</a>,
and wondered why we cannot simply use those NumPy operators in MXNet. One of the
major reasons is that operators in MXNet must support both CPU and GPU computing,
while NumPy operators have no GPU computing capability.
In addition, MXNet performs many
optimizations across its components, such as the tensor data structure (<code>NDArray</code>),
the execution engine, and the computational graph, to maximize memory and runtime efficiency.
An operator implemented under the MXNet operator framework
automatically benefits from those optimizations.</p>
| |
| <p>In this tutorial, we are going to practice implementing an operator using |
| C++ in the MXNet backend. After finishing the implementation, |
| we will add unit tests using Python for the operator we just implemented.</p> |
| |
| <h2 id="implementation">Implementation</h2> |
| |
| <h3 id="an-operator-example">An Operator Example</h3> |
| |
<p>Let's take the <a href="https://en.wikipedia.org/wiki/Quadratic_function">quadratic function</a>
as an example: <code>f(x) = ax^2+bx+c</code>. We want to implement an operator called <code>quadratic</code>
that takes a tensor <code>x</code> as input and generates an output tensor <code>y</code>
satisfying <code>y.shape=x.shape</code>, where each element of <code>y</code> is calculated by feeding the
corresponding element of <code>x</code> into the quadratic function <code>f</code>.
Here the variables <code>a</code>, <code>b</code>, and <code>c</code> are user-supplied parameters.
On the frontend, the operator works like this:</p>
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">x</span> <span class="o">=</span> <span class="p">[[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]]</span> |
| <span class="n">y</span> <span class="o">=</span> <span class="n">quadratic</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> |
| <span class="n">y</span> <span class="o">=</span> <span class="p">[[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">11</span><span class="p">],</span> <span class="p">[</span><span class="mi">18</span><span class="p">,</span> <span class="mi">27</span><span class="p">]]</span> |
| </code></pre></div> |
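<p>As a quick sanity check of the math the operator must perform, the forward computation can be sketched in plain Python (a reference model only, not MXNet code):</p>

```python
def quadratic_ref(data, a=0.0, b=0.0, c=0.0):
    # Element-wise f(x) = a*x^2 + b*x + c; the output keeps the input's shape.
    return [[a * v * v + b * v + c for v in row] for row in data]

x = [[1, 2], [3, 4]]
y = quadratic_ref(data=x, a=1, b=2, c=3)
print(y)  # [[6, 11], [18, 27]]
```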
| <p>To implement this, we first create three files: <code>quadratic_op-inl.h</code>, |
| <code>quadratic_op.cc</code>, and <code>quadratic_op.cu</code>. The header file's name |
| is prefixed by the operator name and followed by <code>op</code> and <code>-inl</code> |
| indicating that this is an operator implementation with inline |
| functions shared by CPU and GPU computing. The CPU and GPU |
| specific implementations reside in their own <code>.cc</code> and <code>.cu</code> files, |
| respectively. We normally put pure tensor related operators |
| (e.g. <code>tile</code>, <code>repeat</code>, etc.) under |
| the directory <code>src/operator/tensor</code>, and neural network operators |
| (e.g. <code>Convolution</code>, <code>Pooling</code>, etc.) under <code>src/operator/nn</code>. |
| You may have noticed that many neural network operators including |
| <code>Convolution</code> and <code>Pooling</code> are currently saved under <code>src/operator</code>. |
| We plan to move them to <code>src/operator/nn</code> for better file organization |
| and clearer hierarchy in the future.</p> |
| |
<p>Next, we are going to:</p>
<ol>
  <li>Define the parameter struct for registering <code>a</code>, <code>b</code>, and <code>c</code> in <code>quadratic_op-inl.h</code>.</li>
  <li>Define type and shape inference functions in <code>quadratic_op-inl.h</code>.</li>
  <li>Define forward and backward functions in <code>quadratic_op-inl.h</code>.</li>
  <li>Register the operator using <a href="https://docs.tvm.ai/dev/nnvm_overview.html">nnvm</a> in <code>quadratic_op.cc</code> and <code>quadratic_op.cu</code> for CPU and GPU computing, respectively.</li>
</ol>
| |
| <p>Now let's walk through the process step by step.</p> |
| |
| <h3 id="parameter-registration">Parameter Registration</h3> |
| |
| <p>We first define <code>struct QuadraticParam</code> as a placeholder for the |
| parameters <code>a</code>, <code>b</code>, and <code>c</code> in <code>quadratic_op-inl.h</code>. |
| The struct inherits from a base template |
| struct named <code>dmlc::Parameter</code>, where the template argument is the derived struct |
| <code>QuadraticParam</code>. This technique, which is called <a href="https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern">curiously recurring template |
| pattern</a>, |
| achieves static polymorphism. It is similar to using a virtual function, |
| but without the cost associated with dynamic polymorphism.</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">struct</span> <span class="n">QuadraticParam</span> <span class="o">:</span> <span class="k">public</span> <span class="n">dmlc</span><span class="o">::</span><span class="n">Parameter</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span> <span class="p">{</span> |
| <span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">;</span> |
| <span class="n">DMLC_DECLARE_PARAMETER</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Coefficient of the quadratic term in the quadratic function."</span><span class="p">);</span> |
| <span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">b</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Coefficient of the linear term in the quadratic function."</span><span class="p">);</span> |
| <span class="n">DMLC_DECLARE_FIELD</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_default</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">"Constant term in the quadratic function."</span><span class="p">);</span> |
| <span class="p">}</span> |
| <span class="p">};</span> |
| </code></pre></div> |
<p>The function calls in the above parameter struct are self-explanatory from their names.
Note that for each parameter, we set the default value to <code>0.0</code> so that users can
skip passing 0-valued parameters through the quadratic operator interface. If you do
not define a default value for a parameter, it becomes a required argument
at runtime. Meanwhile, adding brief descriptions to the parameters enables
the documentation engine to display them on the
<a href="/versions/1.9.1/api/python/docs/api">MXNet documentation web page</a>.</p>
| |
| <h3 id="attribute-inference">Attribute Inference</h3> |
| |
<p>Attribute inference is the process of deducing the properties of <code>NDArray</code>s
in neural networks from user-provided information. The two most common attributes
of an <code>NDArray</code> are data shape and data type.
Let's take a look at the following example.
Given an input <code>NDArray</code> called <code>data</code>, you invoke the <code>quadratic</code> operator
like this: <code>output = mx.nd.quadratic(data, a=1, b=2, c=3)</code>. Before calculating
the <code>output</code> values, its shape and data type are inferred from the input
<code>data</code>'s shape and type, following
the rules you defined, in order to allocate memory space for the output tensor.</p>
| |
<p>One important thing to note is that inference functions should be capable of
performing <strong>mutual inference</strong>, i.e.
inferring one argument's attribute from another argument's attribute
wherever the definition of the operator allows it.
This is very useful for a computational graph to deduce unknown attributes
of a neural network in symbolic programming. Users can view the computational
graph as a symbol with every element initialized for running data
through the neural network, including memory allocation for each tensor,
device placement for each operator, etc. Users normally only need
to provide the minimum necessary information, such as input data shapes,
to the computational graph, and the graph fills in the unknown attributes
using the attribute inference functions defined by the operators that build up
the neural network.</p>
| |
| <p>Let's consider the following example.</p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">>>></span> <span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="n">mx</span> |
| <span class="o">>>></span> <span class="n">a</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s">'a'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> |
| <span class="o">>>></span> <span class="n">b</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s">'b'</span><span class="p">)</span> |
| <span class="o">>>></span> <span class="n">c</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s">'c'</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> |
| <span class="o">>>></span> <span class="n">d</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="n">b</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">c</span> |
| <span class="o">>>></span> <span class="k">print</span> <span class="n">d</span><span class="o">.</span><span class="n">infer_shape</span><span class="p">()</span> |
| <span class="p">([(</span><span class="il">2L</span><span class="p">,</span> <span class="il">3L</span><span class="p">),</span> <span class="p">(</span><span class="il">2L</span><span class="p">,</span> <span class="il">3L</span><span class="p">),</span> <span class="p">(</span><span class="il">2L</span><span class="p">,</span> <span class="il">3L</span><span class="p">)],</span> <span class="p">[(</span><span class="il">2L</span><span class="p">,</span> <span class="il">3L</span><span class="p">)],</span> <span class="p">[])</span> |
| </code></pre></div> |
<p>The last line of the above code snippet is a tuple of three lists returned
by <code>d.infer_shape()</code>. The first list contains the argument shapes
of <code>a</code>, <code>b</code>, and <code>c</code>. The second contains the output shape of <code>d</code>. The
third represents the shapes of auxiliary states, which are not used
in this case, and thus it is empty.
In this example, we only specified values for variable <code>a</code>'s first dimension
and <code>c</code>'s second dimension. A <code>0</code> in shape <code>(2, 0)</code> indicates that the size
of the second dimension is unknown; the same applies to shape <code>(0, 3)</code>.
However, the symbol <code>d</code> still successfully inferred the shapes
of all the variables and the final output. This is a result of mutual
inference. In MXNet, the whole process proceeds as follows:</p>
<ol>
  <li><code>a</code> and <code>b</code> are combined via an element-wise multiplication operator, so the shapes of <code>a</code> and <code>b</code> must be the same, and <code>b</code>'s first dimension size is therefore <code>2</code>.</li>
  <li><code>b</code> and <code>c</code> are combined via an element-wise multiplication operator too, so the shapes of <code>b</code> and <code>c</code> must be the same, and <code>b</code>'s second dimension size is therefore <code>3</code>.</li>
  <li>Now <code>b</code>'s shape is completely known, so the missing dimension sizes of <code>a</code> and <code>c</code> are known as well.</li>
  <li><code>d</code> is the result of adding <code>a * b</code> and <code>b * c</code>, so <code>d</code> also has the same shape as <code>b</code>.</li>
</ol>
| |
| <p>The above four steps illustrate how shape inference logic works in MXNet. |
| It is actually implemented in the shape inference functions of the operators for |
| element-wise multiplication and addition.</p> |
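<p>The four inference steps above can be sketched with a toy version of the element-wise merge rule, where (as in MXNet 1.x shape semantics) a dimension size of 0 marks an unknown. This is a plain-Python sketch, not MXNet's implementation:</p>

```python
def merge_ewise(s1, s2):
    # Mutual inference for an element-wise binary op: each pair of dimension
    # sizes must agree, and an unknown size (0) is filled in from the other
    # operand; any other mismatch is a shape error.
    assert len(s1) == len(s2)
    out = []
    for d1, d2 in zip(s1, s2):
        if d1 == 0:
            out.append(d2)
        elif d2 == 0 or d1 == d2:
            out.append(d1)
        else:
            raise ValueError("incompatible shapes")
    return tuple(out)

a, b, c = (2, 0), (0, 0), (0, 3)
ab = merge_ewise(a, b)    # common shape of a and b: (2, 0), still partial
abc = merge_ewise(ab, c)  # c pins the second dimension
print(abc)  # (2, 3)
```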
| |
<p>For our <code>quadratic</code> operator, the shape inference logic is quite similar.</p>
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">QuadraticOpShape</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span><span class="o">*</span> <span class="n">in_attrs</span><span class="p">,</span> |
| <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span><span class="o">*</span> <span class="n">out_attrs</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">in_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">out_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| |
| <span class="n">SHAPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">out_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">in_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span> |
| <span class="n">SHAPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span> |
| <span class="k">return</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0U</span> <span class="o">&&</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">Size</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">0U</span><span class="p">;</span> |
| <span class="p">}</span> |
| </code></pre></div> |
| <p>Here are a few things to note about the above function:</p> |
| |
| <ol> |
| <li><code>attrs</code> contains parameters <code>a</code>, <code>b</code>, and <code>c</code> from user input. |
| It's not used here since we don't rely on that information for shape inference.</li> |
<li><code>in_attrs</code> is a vector containing all input shapes. Since there is
only one input argument for the <code>quadratic</code> operator, we use the macro <code>CHECK_EQ</code>
to assert that the vector contains exactly one element.</li>
<li><code>out_attrs</code> is a vector containing all output shapes. We also use
<code>CHECK_EQ</code> to verify its size, since there is only one output.</li>
<li>We call the macro <code>SHAPE_ASSIGN_CHECK</code> twice for mutual inference:
once for inferring the output shape from the input shape, and once for inferring
the input shape from the output shape.
If the two shapes have unequal non-zero values in the same
dimension, such as <code>(2, 3)</code> and <code>(3, 3)</code>, the macro throws an
exception with a shape-inference error message.</li>
<li>At the end of the function body, we check whether the output shape
is completely known by testing that the shape is not empty and
that its total size is greater than <code>0</code>. Note that in MXNet, an empty shape
means that the shape is unknown, and
a <code>0</code> in a shape means that the size of that dimension is unknown. In both
situations, the missing shape information must
be inferred from other shapes. If it cannot be inferred,
the function returns <code>false</code> to notify the caller that shape inference failed.</li>
<li>MXNet provides a convenience function that implements this mutual-inference logic
for general element-wise operators, with the following interface. Users can
instantiate it with <code>n_in=1</code> and <code>n_out=1</code> to replace the above
function <code>QuadraticOpShape</code> in the operator registration (explained later).
The function <code>QuadraticOpShape</code> shown here is for illustration only.</li>
| </ol> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">n_in</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n_out</span><span class="o">></span> |
| <span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">ElemwiseShape</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span> |
| <span class="n">mxnet</span><span class="o">::</span><span class="n">ShapeVector</span> <span class="o">*</span><span class="n">out_attrs</span><span class="p">);</span> |
| </code></pre></div> |
| <p>The same logic goes for data type inference. We will leave the analysis of |
| the following code sample to users. Note that <code>-1</code> means the data type |
| is unknown and must be inferred from other input or output data types.</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">QuadraticOpType</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">in_attrs</span><span class="p">,</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">out_attrs</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">in_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">out_attrs</span><span class="o">-></span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| |
| <span class="n">TYPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">out_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">in_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span> |
| <span class="n">TYPE_ASSIGN_CHECK</span><span class="p">(</span><span class="o">*</span><span class="n">in_attrs</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">));</span> |
| <span class="k">return</span> <span class="n">out_attrs</span><span class="o">-></span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span> |
| <span class="p">}</span> |
| </code></pre></div> |
| <p>Again, MXNet provides the following convenience function for mutual |
| type inference of element-wise operators. Users can use that |
| in operator registration (explained later).</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">n_in</span><span class="p">,</span> <span class="kt">int</span> <span class="n">n_out</span><span class="o">></span> |
| <span class="kr">inline</span> <span class="kt">bool</span> <span class="nf">ElemwiseType</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">in_attrs</span><span class="p">,</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="kt">int</span><span class="o">>*</span> <span class="n">out_attrs</span><span class="p">);</span> |
| </code></pre></div> |
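<p>To make the type-inference analysis concrete, here is a hedged standalone sketch of the same mutual-inference pattern with <code>-1</code> as "unknown". The names <code>AssignType</code> and <code>QuadraticTypeSketch</code> are invented for this illustration; MXNet's real logic lives in <code>TYPE_ASSIGN_CHECK</code> and <code>ElemwiseType</code>:</p>

```cpp
#include <cassert>
#include <vector>

// Sketch only: propagate a known type into an unknown slot (-1),
// failing when two known types disagree.
bool AssignType(int* dst, int src) {
  if (src == -1) return true;                   // nothing to propagate
  if (*dst == -1) { *dst = src; return true; }  // fill in the unknown type
  return *dst == src;                           // both known: must agree
}

bool QuadraticTypeSketch(std::vector<int>* in_types,
                         std::vector<int>* out_types) {
  // Propagate input -> output, then output -> input (mutual inference).
  if (!AssignType(&out_types->at(0), in_types->at(0))) return false;
  if (!AssignType(&in_types->at(0), out_types->at(0))) return false;
  return out_types->at(0) != -1;  // succeed only if the output type is known
}
```

<p>Inference succeeds when either side is known (the known type flows to the unknown side) and fails only when both sides are still <code>-1</code> or when two known types conflict.</p>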
| <h3 id="forward-function">Forward Function</h3> |
| |
<p>The forward function defines the operator's behavior in the forward pass
of neural networks. For our <code>quadratic</code> operator, it simply implements
| the logic of running a tensor through the quadratic function by performing |
| a few element-wise operations. The forward function's signature is fixed |
| in MXNet as follows:</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">void</span> <span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">);</span> |
| </code></pre></div> |
| <p>We first paste the whole forward function code here |
| and then go through it line by line.</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span> <span class="c1">// 1</span> |
| <span class="kt">void</span> <span class="nf">QuadraticOpForward</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> <span class="c1">// 2</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> <span class="c1">// 3</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> <span class="c1">// 4</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> <span class="c1">// 5</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 6</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 7</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 8</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 9</span> |
| <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span> <span class="c1">// 10</span> |
| <span class="k">const</span> <span class="n">TBlob</span><span class="o">&</span> <span class="n">in_data</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 11</span> |
| <span class="k">const</span> <span class="n">TBlob</span><span class="o">&</span> <span class="n">out_data</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 12</span> |
| <span class="k">const</span> <span class="n">QuadraticParam</span><span class="o">&</span> <span class="n">param</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">get</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">(</span><span class="n">attrs</span><span class="p">.</span><span class="n">parsed</span><span class="p">);</span> <span class="c1">// 13</span> |
| <span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet_op</span><span class="p">;</span> <span class="c1">// 14</span> |
| <span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">out_data</span><span class="p">.</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 15</span> |
| <span class="n">MXNET_ASSIGN_REQ_SWITCH</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req_type</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 16</span> |
| <span class="n">Kernel</span><span class="o"><</span><span class="n">quadratic_forward</span><span class="o"><</span><span class="n">req_type</span><span class="o">></span><span class="p">,</span> <span class="n">xpu</span><span class="o">>::</span><span class="n">Launch</span><span class="p">(</span> <span class="c1">// 17</span> |
| <span class="n">s</span><span class="p">,</span> <span class="n">out_data</span><span class="p">.</span><span class="n">Size</span><span class="p">(),</span> <span class="n">out_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">in_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="c1">// 18</span> |
| <span class="n">param</span><span class="p">.</span><span class="n">a</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">b</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">c</span><span class="p">);</span> <span class="c1">// 19</span> |
| <span class="p">});</span> <span class="c1">// 20</span> |
| <span class="p">});</span> <span class="c1">// 21</span> |
| <span class="p">}</span> <span class="c1">// 22</span> |
| </code></pre></div> |
| <ul> |
| <li>Line 1: <code>xpu</code> stands for a generic device type so that the function can be instantiated |
| for both CPU and GPU computing using concrete types <code>cpu</code> and <code>gpu</code>. The instantiation happens |
| at the time when the operator is registered in <code>.cc</code> and <code>.cu</code> files.</li> |
| <li>Line 2: <code>attrs</code> is a node attribute containing the user input parameters <code>a</code>, <code>b</code>, and <code>c</code>. |
| Here the node represents a placeholder for the operator in the whole computational graph for |
| the neural network.</li> |
<li>Line 3: <code>ctx</code> holds a <code>stream</code>, which is used for
serializing asynchronous executions. To understand what a <code>stream</code> does,
suppose we launch several GPU kernels on the same <code>stream</code> from the CPU.
Even though each launch is non-blocking, the <code>stream</code> guarantees
that the kernels execute on the GPU in the same order as they were launched from the CPU.</li>
| <li>Line 4: <code>inputs</code> is a vector of input tensors (only one input tensor |
| for the <code>quadratic</code> operator).</li> |
| <li>Line 5: <code>req</code> is a vector of <code>OpReqType</code> values. Each value defines |
| the way of writing calculated values to the output tensors. |
| Therefore, the number of <code>req</code>s must be the same as the number of output tensors. |
MXNet currently supports three types of <code>req</code> in the frontend: <code>null</code>, <code>write</code>, and <code>add</code>.
<code>null</code> means skipping the calculation of the corresponding output tensor,
<code>write</code> means overwriting the values in the output tensor with the ones
calculated by this operator, and <code>add</code> means adding the calculated values
to the existing ones in the output tensor. Note that <code>null</code> and <code>add</code> are usually
| seen in backward passes. The former is for skipping calculating |
| the gradients of un-learnable parameters (such as index arrays), |
| and the latter is for accumulating gradients throughout networks.</li> |
| <li>Line 6: <code>outputs</code> is a vector of output tensors (only one |
| output tensor for the <code>quadratic</code> operator).</li> |
<li>Lines 7-9: Verify that each vector has the expected size;
otherwise, abort and print an error message.</li>
| <li>Line 10: Get the <code>stream</code> from the <code>ctx</code> for launching kernels.</li> |
| <li>Lines 11-12: Define the references of the input and output tensors |
| for later coding convenience. Note that <code>TBlob</code> can be understood |
| as a uniform data structure for tensors of various dimensions, such |
| that tensors of different dimensions can be put in a homogeneous container, |
| such as <code>std::vector</code> and <code>std::list</code>. You can still |
| get tensors of desired dimensions from a <code>TBlob</code> object through |
| the interface <code>get_with_shape</code>.</li> |
| <li>Line 13: Get user input parameters from the node attribute.</li> |
| <li>Lines 15-21: This is the place where the mathematical formula of the operator |
| is implemented. The macros <code>MSHADOW_TYPE_SWITCH</code> and <code>MXNET_ASSIGN_REQ_SWITCH</code> enable |
| the code block to work for all the supported data types and <code>req</code> types in MXNet. |
| Inside the inner-most macro, we launch the kernel for calculating |
| the output tensor such that each thread takes an element from |
| the input tensor, feeds it into the quadratic function, and assigns |
| the output element to the output tensor based on <code>req</code> type. Note that |
| <code>Kernel::Launch</code> serves as a universal interface for launching |
| parallel computation on both CPU and GPU. This allows most of |
| the simple operators to share the same piece of code for CPU and GPU as |
| parallelization approaches are often identical on both types of devices. |
The kernel function is defined as the following, where the function
<code>Map</code> is executed by each thread for one input element. The <code>out_data.Size()</code>
argument of <code>Kernel::Launch</code> determines how the workload is split among
threads; here it is the number of elements in the output array. To explain
the two macros used in the kernel struct: (1) <code>MSHADOW_XINLINE</code> is
a consolidated macro for inlining functions compiled by both CPU and GPU
compilers. It enables CPU and GPU computing to share the same piece of code.
(2) <code>KERNEL_ASSIGN</code> is a macro for unifying the statements for different <code>req</code>s
into a single line of code. It is named <code>KERNEL_ASSIGN</code> because we call
the code blocks that run the parallel computation kernels.
On CPUs, the kernels are normally wrapped in an OpenMP <code>parallel</code> directive,
while on GPUs they are kernel functions launched through the CUDA runtime.</li>
| </ul> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">req</span><span class="o">></span> |
| <span class="k">struct</span> <span class="n">quadratic_forward</span> <span class="p">{</span> |
| <span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="o">></span> |
| <span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="kt">void</span> <span class="n">Map</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="p">,</span> <span class="n">DType</span><span class="o">*</span> <span class="n">out_data</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_data</span><span class="p">,</span> |
| <span class="k">const</span> <span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">b</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">c</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">KERNEL_ASSIGN</span><span class="p">(</span><span class="n">out_data</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">req</span><span class="p">,</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span> <span class="o">*</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span> <span class="o">+</span> <span class="n">c</span><span class="p">);</span> |
| <span class="p">}</span> |
| <span class="p">};</span> |
| </code></pre></div> |
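<p>As a standalone illustration of what the forward kernel computes, here is a hedged CPU-only sketch of the quadratic computation with the three <code>req</code> semantics described above. The names <code>Req</code> and <code>QuadraticForwardSketch</code> are invented for this example; MXNet itself dispatches through <code>Kernel&lt;...&gt;::Launch</code> and <code>KERNEL_ASSIGN</code> instead of a plain loop:</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch only: mirrors the null/write/add request types.
enum Req { kNull, kWrite, kAdd };

void QuadraticForwardSketch(std::vector<float>* out,
                            const std::vector<float>& in,
                            float a, float b, float c, Req req) {
  if (req == kNull) return;  // skip computing this output entirely
  for (std::size_t i = 0; i < in.size(); ++i) {
    // Same algebraic form as the kernel: x * (a*x + b) + c.
    float y = in[i] * (a * in[i] + b) + c;
    if (req == kWrite) (*out)[i] = y;  // overwrite the output
    else (*out)[i] += y;               // accumulate into the output
  }
}
```

<p>For example, with <code>a=1, b=2, c=3</code> and input <code>{0, 1, 2}</code>, a <code>write</code> request produces <code>{3, 6, 11}</code>, while a subsequent <code>add</code> request accumulates the same values on top of the existing output.</p>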
| <h3 id="backward-function">Backward Function</h3> |
| |
<p>Backward functions play the role of propagating derivatives of the loss function
with respect to the outputs of the last layer back through the network to the first
layer. The whole process is often known as backpropagation. We are not
going to delineate the principles of backpropagation here, since users can find
them covered in great detail in other resources, such as
<a href="https://cs231n.github.io/optimization-2/">CS231n</a> and
<a href="https://neuralnetworksanddeeplearning.com/chap2.html">How the backpropagation algorithm works</a>.
| The problem we are going to solve here for the <code>quadratic</code> operator is that |
| given a tensor representing the gradient of the loss function with respect |
| to the output of the operator, calculate the gradient with respect to |
| the input of the operator. There is no need to calculate the derivatives |
| of loss function with respect to user input parameters <code>a</code>, <code>b</code>, and <code>c</code> |
| since they are not learnable parameters in the network. To formulate the problem: |
given <code>dL/dy</code> and <code>y = a*x^2 + b*x + c</code>, where <code>L</code> represents the loss function and
<code>y</code> stands for the output of the <code>quadratic</code> operator, we need to solve for
<code>dL/dx</code>. Using the chain rule, we find that</p>
| <div class="highlight"><pre><code class="language-" data-lang="">dL/dx = dL/dy * dy/dx = dL/dy * (2*a*x + b). |
| </code></pre></div> |
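<p>This chain-rule result can be checked numerically with a central finite difference. The sketch below is for illustration only (all names are invented here), and for simplicity it takes <code>L = y</code> so that <code>dL/dy = 1</code>:</p>

```cpp
#include <cassert>
#include <cmath>

// y = a*x^2 + b*x + c, the quadratic function from the text.
float Quadratic(float x, float a, float b, float c) {
  return a * x * x + b * x + c;
}

// Analytic gradient from the equation above: dL/dx = dL/dy * (2*a*x + b).
float AnalyticGrad(float x, float a, float b, float out_grad) {
  return out_grad * (2.0f * a * x + b);
}

// Central finite difference: dy/dx ~= (y(x+eps) - y(x-eps)) / (2*eps).
float NumericGrad(float x, float a, float b, float c, float eps) {
  return (Quadratic(x + eps, a, b, c) - Quadratic(x - eps, a, b, c)) /
         (2.0f * eps);
}
```

<p>At <code>x = 0.5</code> with <code>a = 2</code> and <code>b = 3</code>, both the analytic and the numeric gradient come out close to <code>2*a*x + b = 5</code>, which is a quick sanity check one can run before wiring the formula into the backward kernel.</p>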
<p>The above equation indicates that <code>dL/dx</code> depends on the gradient
of the output tensor and the value of the input tensor.
The backward function's signature is the same as the forward function's.
With the aforementioned information in mind,
let's break down the following backward function line by line.</p>
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">xpu</span><span class="o">></span> <span class="c1">// 1</span> |
| <span class="kt">void</span> <span class="nf">QuadraticOpBackward</span><span class="p">(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> <span class="c1">// 2</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> <span class="c1">// 3</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> <span class="c1">// 4</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> <span class="c1">// 5</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 6</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">2U</span><span class="p">);</span> <span class="c1">// 7</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 8</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">req</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> <span class="c1">// 9</span> |
| <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span> <span class="o">*</span><span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">xpu</span><span class="o">></span><span class="p">();</span> <span class="c1">// 10</span> |
| <span class="k">const</span> <span class="n">TBlob</span><span class="o">&</span> <span class="n">out_grad</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 11</span> |
| <span class="k">const</span> <span class="n">TBlob</span><span class="o">&</span> <span class="n">in_data</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span> <span class="c1">// 12</span> |
| <span class="k">const</span> <span class="n">TBlob</span><span class="o">&</span> <span class="n">in_grad</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">];</span> <span class="c1">// 13</span> |
| <span class="k">const</span> <span class="n">QuadraticParam</span><span class="o">&</span> <span class="n">param</span> <span class="o">=</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">get</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">(</span><span class="n">attrs</span><span class="p">.</span><span class="n">parsed</span><span class="p">);</span> <span class="c1">// 14</span> |
| <span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet_op</span><span class="p">;</span> <span class="c1">// 15</span> |
| <span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">out_grad</span><span class="p">.</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 16</span> |
| <span class="n">MXNET_ASSIGN_REQ_SWITCH</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">req_type</span><span class="p">,</span> <span class="p">{</span> <span class="c1">// 17</span> |
| <span class="n">Kernel</span><span class="o"><</span><span class="n">quadratic_backward</span><span class="o"><</span><span class="n">req_type</span><span class="o">></span><span class="p">,</span> <span class="n">xpu</span><span class="o">>::</span><span class="n">Launch</span><span class="p">(</span> <span class="c1">// 18</span> |
| <span class="n">s</span><span class="p">,</span> <span class="n">in_grad</span><span class="p">.</span><span class="n">Size</span><span class="p">(),</span> <span class="n">in_grad</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">out_grad</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="c1">// 19</span> |
| <span class="n">in_data</span><span class="p">.</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">param</span><span class="p">.</span><span class="n">a</span><span class="p">,</span> <span class="n">param</span><span class="p">.</span><span class="n">b</span><span class="p">);</span> <span class="c1">// 20</span> |
| <span class="p">});</span> <span class="c1">// 21</span> |
| <span class="p">});</span> <span class="c1">// 22</span> |
| <span class="p">}</span> <span class="c1">// 23</span> |
| </code></pre></div> |
| <ul> |
| <li>Lines 1-6: The backward function has the same signature as the forward function.</li> |
| <li>Lines 7-9: Check the sizes of the function arguments. One thing to note |
| is that since the gradient of the input depends on both the gradient of the output and |
| the input tensor itself, <code>inputs</code> must contain two <code>TBlob</code> objects.</li> |
| <li>Line 10: Get the <code>stream</code> of the context for serializing asynchronous executions.</li> |
| <li>Lines 11-13: Convenience reference variables for later use. We refer to the gradient |
| of the operator output as <code>out_grad</code>, the input of the operator as <code>in_data</code>, |
| and the gradient of the operator input as <code>in_grad</code>.</li> |
| <li>Line 14: Get the parameter object of <code>QuadraticParam</code>.</li> |
| <li>Lines 16-22: As in the forward function, this is where the parallel |
| computation of <code>in_grad</code> happens. The struct <code>quadratic_backward</code> implements |
| the formula for calculating each element of <code>in_grad</code> in one thread, as follows.</li> |
| </ul> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="k">template</span><span class="o"><</span><span class="kt">int</span> <span class="n">req</span><span class="o">></span> |
| <span class="k">struct</span> <span class="n">quadratic_backward</span> <span class="p">{</span> |
| <span class="k">template</span><span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="o">></span> |
| <span class="n">MSHADOW_XINLINE</span> <span class="k">static</span> <span class="kt">void</span> <span class="n">Map</span><span class="p">(</span><span class="kt">int</span> <span class="n">i</span><span class="p">,</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_grad</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">out_grad</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">DType</span><span class="o">*</span> <span class="n">in_data</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="kt">float</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span> |
| <span class="n">KERNEL_ASSIGN</span><span class="p">(</span><span class="n">in_grad</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">req</span><span class="p">,</span> <span class="n">out_grad</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">in_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">b</span><span class="p">));</span> |
| <span class="p">}</span> |
| <span class="p">};</span> |
| </code></pre></div> |
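| <p>As a quick sanity check of the formula used by <code>quadratic_backward</code>, the analytic gradient <code>2*a*x + b</code> can be compared against a central finite difference. The following is a plain-Python illustration only, not part of the MXNet sources:</p> |

```python
# Illustration only (not MXNet code): verify that the gradient formula
# implemented by quadratic_backward, d/dx (a*x^2 + b*x + c) = 2*a*x + b,
# agrees with a central finite-difference estimate.
def f(x, a, b, c):
    return a * x * x + b * x + c

def analytic_grad(x, a, b):
    # the expression inside KERNEL_ASSIGN (before multiplying by out_grad)
    return 2 * a * x + b

def numeric_grad(x, a, b, c, eps=1e-6):
    # central difference; exact for quadratics up to rounding error
    return (f(x + eps, a, b, c) - f(x - eps, a, b, c)) / (2 * eps)

a, b, c = 1.5, -2.0, 3.0
for x in [-1.0, 0.0, 0.5, 2.0]:
    assert abs(analytic_grad(x, a, b) - numeric_grad(x, a, b, c)) < 1e-5
```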
| <h3 id="operator-registration">Operator Registration</h3> |
| |
| <p>So far, we have implemented the necessary data structures and functions for the operator <code>quadratic</code>. |
| Now let's register them using <code>nnvm</code> to expose the operator <code>quadratic</code> |
| to the frontend. You can think of the registration process as creating the operator object |
| instance, saving it in the operator manager (a singleton), |
| and setting attributes for the operator instance.</p> |
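| <p>Conceptually, the registration machinery behaves like the following toy sketch (hypothetical names, not MXNet's actual implementation): a singleton registry maps operator names to operator objects, and attribute setters return the object so calls can be chained, mirroring <code>NNVM_REGISTER_OP(...).set_attr(...)</code>:</p> |

```python
# Hypothetical sketch of an operator registry (NOT MXNet's real code),
# illustrating the singleton-manager-plus-chained-setters pattern.
class Op:
    def __init__(self, name):
        self.name = name
        self.attrs = {}

    def set_attr(self, key, value):
        # mirrors .set_attr<...>(...): store the attribute, return self
        self.attrs[key] = value
        return self  # returning self enables call chaining

class OpRegistry:
    _instance = None  # the singleton instance

    def __init__(self):
        self.ops = {}

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register(self, name):
        # create the operator object on first registration and return
        # a reference to it so attributes can be attached
        return self.ops.setdefault(name, Op(name))

# Usage mirroring NNVM_REGISTER_OP(quadratic).set_num_inputs(1)...
reg = OpRegistry.get()
reg.register("quadratic").set_attr("num_inputs", 1).set_attr("num_outputs", 1)
```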
| |
| <p>The following code is from <code>quadratic_op.cc</code>, which is responsible |
| for registering the operator working on CPU.</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="n">DMLC_REGISTER_PARAMETER</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="p">);</span> <span class="c1">// 1</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">quadratic</span><span class="p">)</span> <span class="c1">// 2</span> |
| <span class="p">.</span><span class="n">describe</span><span class="p">(</span><span class="s">R"code(This operator implements the quadratic function: // 3 |
| .. math:: |
| |
| f(x) = ax^2+bx+c |
| |
| where :math:`x` is an input tensor and all operations |
| in the function are element-wise. |
| |
| Example: |
| |
| .. code-block:: python |
| :emphasize-lines: 1,3 |
| x = [[1, 2], [3, 4]] |
| y = quadratic(data=x, a=1, b=2, c=3) |
| y = [[6, 11], [18, 27]] |
| |
| )code"</span> <span class="n">ADD_FILELINE</span><span class="p">)</span> <span class="c1">// 4</span> |
| <span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">)</span> <span class="c1">// 5</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 6</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 7</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FListInputNames</span><span class="o">></span><span class="p">(</span><span class="s">"FListInputNames"</span><span class="p">,</span> <span class="c1">// 8</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 9</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">></span><span class="p">{</span><span class="s">"data"</span><span class="p">};</span> <span class="c1">// 10</span> |
| <span class="p">})</span> <span class="c1">// 11</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferShape</span><span class="o">></span><span class="p">(</span><span class="s">"FInferShape"</span><span class="p">,</span> <span class="n">QuadraticOpShape</span><span class="p">)</span> <span class="c1">// 12</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInferType</span><span class="o">></span><span class="p">(</span><span class="s">"FInferType"</span><span class="p">,</span> <span class="n">QuadraticOpType</span><span class="p">)</span> <span class="c1">// 13</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">QuadraticOpForward</span><span class="o"><</span><span class="n">cpu</span><span class="o">></span><span class="p">)</span> <span class="c1">// 14</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FGradient</span><span class="o">></span><span class="p">(</span><span class="s">"FGradient"</span><span class="p">,</span> <span class="n">ElemwiseGradUseIn</span><span class="p">{</span><span class="s">"_backward_quadratic"</span><span class="p">})</span> <span class="c1">// 15</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">FInplaceOption</span><span class="o">></span><span class="p">(</span><span class="s">"FInplaceOption"</span><span class="p">,</span> <span class="c1">// 16</span> |
| <span class="p">[](</span><span class="k">const</span> <span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// 17</span> |
| <span class="k">return</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">std</span><span class="o">::</span><span class="n">pair</span><span class="o"><</span><span class="kt">int</span><span class="p">,</span> <span class="kt">int</span><span class="o">></span> <span class="o">></span><span class="p">{{</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">}};</span> <span class="c1">// 18</span> |
| <span class="p">})</span> <span class="c1">// 19</span> |
| <span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"data"</span><span class="p">,</span> <span class="s">"NDArray-or-Symbol"</span><span class="p">,</span> <span class="s">"Input ndarray"</span><span class="p">)</span> <span class="c1">// 20</span> |
| <span class="p">.</span><span class="n">add_arguments</span><span class="p">(</span><span class="n">QuadraticParam</span><span class="o">::</span><span class="n">__FIELDS__</span><span class="p">());</span> <span class="c1">// 21</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_quadratic</span><span class="p">)</span> <span class="c1">// 22</span> |
| <span class="p">.</span><span class="n">set_attr_parser</span><span class="p">(</span><span class="n">ParamParser</span><span class="o"><</span><span class="n">QuadraticParam</span><span class="o">></span><span class="p">)</span> <span class="c1">// 23</span> |
| <span class="p">.</span><span class="n">set_num_inputs</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="c1">// 24</span> |
| <span class="p">.</span><span class="n">set_num_outputs</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1">// 25</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">nnvm</span><span class="o">::</span><span class="n">TIsBackward</span><span class="o">></span><span class="p">(</span><span class="s">"TIsBackward"</span><span class="p">,</span> <span class="nb">true</span><span class="p">)</span> <span class="c1">// 26</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">QuadraticOpBackward</span><span class="o"><</span><span class="n">cpu</span><span class="o">></span><span class="p">);</span> <span class="c1">// 27</span> |
| </code></pre></div> |
| <ul> |
| <li>Line 1: Register the parameter struct.</li> |
| <li>Line 2: Register an operator named <code>quadratic</code> by creating an instance |
| of the <code>Op</code> type, saving it in the operator manager, and returning a reference |
| to the newly created operator object.</li> |
| <li>Lines 3-4: Add a description as an operator attribute, |
| including examples of the operator. The documentation engine will extract |
| this description and display it on the documentation web page. |
| <code>emphasize-lines</code> is optional. |
| For more examples and troubleshooting with doc strings, refer to the <a href="https://cwiki.apache.org/confluence/display/MXNET/Documentation+Guide">MXNet |
| developer wiki's Documentation Guide</a>.</li> |
| <li>Line 5: Set the parameter struct parser for the operator. It is used for parsing |
| the parameters <code>a</code>, <code>b</code>, and <code>c</code> passed from the frontend.</li> |
| <li>Line 6: Set the number of inputs for the operator.</li> |
| <li>Line 7: Set the number of outputs for the operator.</li> |
| <li>Lines 8-11: Define a function that generates a vector of names of |
| the operator's input arguments. This function is used to add missing |
| arguments that users did not specify when creating a symbolic operator. |
| For example, <code>quad_func=mx.sym.quadratic()</code> is still a valid symbol |
| since we have added the attribute <code>FListInputNames</code> to the operator node |
| in the computational graph. MXNet would |
| add the missing argument with the name <code>quadratic0_data</code>, where the prefix |
| <code>quadratic0</code> is the operator name appended with an index and the postfix |
| <code>data</code> comes from the return value of the user-defined <code>FListInputNames</code> function. |
| Users can still generate an executor for <code>quad_func</code> as follows: |
| <div class="highlight"><pre><code class="language-python" data-lang="python">quad_exe = quad_func.simple_bind(ctx=mx.cpu(), quadratic0_data=(1,)) |
| </code></pre></div></li> |
| <li>Line 12: Register shape inference function.</li> |
| <li>Line 13: Register type inference function.</li> |
| <li>Line 14: Register forward function.</li> |
| <li>Line 15: Register the function for creating the node of the operator in |
| the backward pass. Note that we use the convenience functor struct <code>ElemwiseGradUseIn</code>. |
| As the name suggests, the registered functor creates the node for gradient computation |
| with dependencies on the output gradient node and the input node. Similarly, three |
| other functors, <code>ElemwiseGradUseOut</code>, <code>ElemwiseGradUseInOut</code>, |
| and <code>ElemwiseGradUseNone</code>, are defined for developers' convenience. In order to add |
| this attribute, we also need to register a backward operator for <code>quadratic</code> with |
| several basic attributes, since it can share attribute inference |
| functions with the forward operator and is not exposed to the frontend.</li> |
| <li>Lines 16-19: This registered function indicates which output tensor can reuse |
| which input tensor's memory space, instead of allocating a new memory space for the output. |
| For the operator <code>quadratic</code>, there is only one input and one output, and the output can reuse |
| the input's memory space, so we store a pair of zeros in the returned vector, |
| indicating that <code>inputs[0]</code>'s memory space can be reused by <code>outputs[0]</code>. |
| Note that this function only provides a hint to the computational graph initializer: |
| if other nodes depend on the input tensor, the memory space |
| of the input tensor will not be overwritten by the output.</li> |
| <li>Line 20: Define the input argument name as <code>data</code> for the operator.</li> |
| <li>Line 21: Add user input parameters <code>a</code>, <code>b</code>, and <code>c</code> as the attributes of the operator.</li> |
| <li>Line 22: Register an operator named <code>_backward_quadratic</code> for backward pass |
| of the operator <code>quadratic</code>. The underscore prefix in the operator name indicates |
| that this is an operator not exposed to users. The convention |
| of naming an internally used backward operator is prepending the prefix <code>_backward_</code> |
| to the corresponding forward operator name.</li> |
| <li>Line 23: Set the parameter parser for the operator <code>_backward_quadratic</code>.</li> |
| <li>Line 24: Set the number of inputs.</li> |
| <li>Line 25: Set the number of outputs.</li> |
| <li>Line 26: Add <code>TIsBackward</code> attribute for the operator. The shape and type |
| inference passes use this attribute to determine whether a node in the graph is a |
| forward or backward node.</li> |
| <li>Line 27: Register backward function.</li> |
| </ul> |
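| <p>The in-place option described in lines 16-19 above is only a hint, and the decision logic can be pictured with a toy sketch (purely hypothetical, not MXNet's actual memory planner): the output reuses the input buffer only when no other consumer still needs that input.</p> |

```python
# Hypothetical sketch (NOT MXNet's planner) of how an in-place hint is
# honored only when it is safe to overwrite the input.
def plan_memory(inplace_pairs, other_readers):
    # inplace_pairs: list of (input_idx, output_idx) hints, e.g. [(0, 0)]
    # other_readers: dict mapping input_idx -> number of other nodes
    #                that also read that input tensor
    reuse = {}
    for in_idx, out_idx in inplace_pairs:
        if other_readers.get(in_idx, 0) == 0:
            # safe: nobody else reads the input, so the output may
            # overwrite its memory space
            reuse[out_idx] = in_idx
    return reuse

# quadratic's hint {0, 0}: reuse granted only without other readers
assert plan_memory([(0, 0)], {0: 0}) == {0: 0}
assert plan_memory([(0, 0)], {0: 1}) == {}
```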
| |
| <p>So far, we have an operator that works on CPUs and is accessible from the frontend. |
| To register the operator for GPUs, we just need to add the following |
| code to <code>quadratic_op.cu</code>. Note that the forward and backward functions |
| are registered with the attribute key <code>FCompute<gpu></code>, rather than <code>FCompute<cpu></code>.</p> |
| <div class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">quadratic</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">QuadraticOpForward</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">);</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">_backward_quadratic</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">QuadraticOpBackward</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">);</span> |
| </code></pre></div> |
| <h3 id="unit-test">Unit Test</h3> |
| |
| <p>Now we have finished implementing the operator <code>quadratic</code> in the MXNet backend. |
| If you use Python, then when you type <code>import mxnet as mx</code>, two Python |
| functions for invoking your backend implementation are |
| generated on the fly: one for imperative programming, |
| registered as <code>mxnet.ndarray.quadratic</code> (<code>mx.nd.quadratic</code> for short); |
| the other for symbolic programming, registered as |
| <code>mxnet.symbol.quadratic</code> (<code>mx.sym.quadratic</code> for short).</p> |
| |
| <p>In order to unit test it in the frontend, we need to add the following code |
| to the Python file <code>test_operator.py</code>. A typical operator test |
| covers both the <code>symbol</code> API and the <code>ndarray</code> API, and the |
| following test does both. The imperative test exercises the <code>ndarray</code> API, |
| <code>mx.nd.contrib.quadratic</code>. The <code>symbol</code> API test covers the complete |
| functionality of the operator: the forward pass and the backward |
| pass. To facilitate testing these functionalities, we use three |
| helper functions available in the <code>mxnet.test_utils</code> module: |
| - <code>check_symbolic_forward</code> |
| - <code>check_symbolic_backward</code> |
| - <code>check_numeric_gradient</code></p> |
| <div class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">test_quadratic_function</span><span class="p">():</span> |
| <span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">):</span> |
| <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">c</span> |
| |
| <span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span> |
| <span class="n">b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span> |
| <span class="n">c</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random_sample</span><span class="p">()</span> |
| <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s">'data'</span><span class="p">)</span> |
| <span class="n">quad_sym</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">quadratic</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">c</span><span class="p">)</span> |
| <span class="k">for</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">]:</span> |
| <span class="k">for</span> <span class="n">ndim</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">):</span> |
| <span class="n">shape</span> <span class="o">=</span> <span class="n">rand_shape_nd</span><span class="p">(</span><span class="n">ndim</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span> |
| <span class="n">data_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> |
| <span class="n">expected</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">data_np</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> |
| <span class="n">backward_expected</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">a</span> <span class="o">*</span> <span class="n">data_np</span> <span class="o">+</span> <span class="n">b</span> |
| |
| <span class="c1"># check imperative forward |
| </span> <span class="n">output</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">quadratic</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">data_np</span><span class="p">),</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">c</span><span class="p">)</span> |
| <span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span><span class="n">expected</span><span class="p">,</span> |
| <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">,</span> |
| <span class="n">atol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">)</span> |
| <span class="c1"># check forward |
| </span> <span class="n">check_symbolic_forward</span><span class="p">(</span><span class="n">quad_sym</span><span class="p">,</span> <span class="p">[</span><span class="n">data_np</span><span class="p">],</span> <span class="p">[</span><span class="n">expected</span><span class="p">],</span> |
| <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">,</span> |
| <span class="n">atol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">)</span> |
| <span class="c1"># check backward |
| </span> <span class="n">check_symbolic_backward</span><span class="p">(</span><span class="n">quad_sym</span><span class="p">,</span> <span class="p">[</span><span class="n">data_np</span><span class="p">],</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">expected</span><span class="o">.</span><span class="n">shape</span><span class="p">)],</span> |
| <span class="p">[</span><span class="n">backward_expected</span><span class="p">],</span> |
| <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">,</span> |
| <span class="n">atol</span><span class="o">=</span><span class="mf">1e-2</span> <span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> <span class="k">else</span> <span class="mf">1e-5</span><span class="p">)</span> |
| <span class="c1"># check backward using finite difference |
| </span> <span class="n">check_numeric_gradient</span><span class="p">(</span><span class="n">quad_sym</span><span class="p">,</span> <span class="p">[</span><span class="n">data_np</span><span class="p">],</span> <span class="n">atol</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> |
| </code></pre></div> |
| <p>In the above test we create a <code>quadratic</code> symbol and feed it into the three |
| utility functions. <code>check_symbolic_forward</code> and <code>check_symbolic_backward</code> |
| test the computed values against the expected values that we pass |
| as arguments. The <code>check_numeric_gradient</code> utility function |
| performs <a href="http://ufldl.stanford.edu/tutorial/supervised/DebuggingGradientChecking/">gradient checking</a> |
| to verify the implementation of the operator's backward function. |
| It perturbs the input and measures the response of the output using the |
| <a href="https://en.wikipedia.org/wiki/Finite_difference_method">finite difference method</a>, |
| then compares the gradient from the backward pass against the values |
| obtained from the finite difference method. All three tests succeed |
| once the comparison satisfies the user-specified <code>rtol</code> and <code>atol</code> values. Here <code>rtol</code> |
| and <code>atol</code> stand for relative tolerance and absolute tolerance respectively. They |
| specify how far the computed values may deviate from the expected values, |
| and are defined as follows:</p> |
| <div class="highlight"><pre><code class="language-" data-lang="">abs(Expected_Value - Computed_Value) < RTOL * abs(Expected_Value) + ATOL |
| </code></pre></div> |
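| <p>The tolerance check above can be written out directly in Python (a small illustration of the comparison rule, not the exact implementation used by <code>mxnet.test_utils</code>):</p> |

```python
# Illustration of the rtol/atol comparison rule shown above:
# abs(expected - computed) < rtol * abs(expected) + atol
def values_close(expected, computed, rtol=1e-5, atol=1e-5):
    return abs(expected - computed) < rtol * abs(expected) + atol

# With rtol = atol = 1e-5 and expected value 1.5623145, the allowed
# deviation is 1e-5 * 1.5623145 + 1e-5, roughly 2.56e-5.
assert values_close(1.5623145, 1.5623200)       # within tolerance
assert not values_close(1.5623145, 1.5624000)   # outside tolerance
```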
| <p>For example, if <code>rtol</code> is <code>1e-5</code>, <code>atol</code> is <code>1e-5</code>, and the expected value is |
| <code>1.5623145</code>, then the computed value must lie within the range |
| <code>(1.562288876855, 1.562340123145)</code>; otherwise the test will fail. Make sure you |
| tune the <code>rtol</code> and <code>atol</code> values accordingly: very low values for <code>rtol</code> |
| and <code>atol</code> will likely make the test flaky. It is recommended that you |
| use the flakiness checker tool to check whether the test you have written is |
| flaky. You can run the flakiness checker tool on the above test with the |
| following command:</p> |
| <div class="highlight"><pre><code class="language-bash" data-lang="bash">python tools/flakiness_checker.py test_operator.test_quadratic_function |
| </code></pre></div> |
| <p>Please note that for <code>check_symbolic_forward</code> and <code>check_symbolic_backward</code> we pass |
| both the operator symbols and the expected results for comparison, while for |
| <code>check_numeric_gradient</code> we pass only the operator symbol, since |
| <code>check_numeric_gradient</code> computes the expected value using the finite difference |
| method. This is why it is highly recommended to add a <code>check_numeric_gradient</code> |
| test for every operator with a backward function implemented: it eliminates |
| the possibility of passing incorrect expected results into <code>check_symbolic_backward</code>.</p> |
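| <p>The essence of such a numeric gradient check can be sketched in a few lines of plain Python (assuming, for simplicity, an element-wise operator and an output gradient of all ones; this is not how <code>check_numeric_gradient</code> is actually implemented):</p> |

```python
# Minimal sketch of a numeric gradient check for an element-wise operator
# (NOT the real check_numeric_gradient): perturb each input element,
# estimate the derivative by central differences, and compare with the
# backward formula out_grad * (2*a*x + b).
def forward(xs, a, b, c):
    return [a * x * x + b * x + c for x in xs]

def backward(xs, out_grad, a, b):
    return [g * (2 * a * x + b) for x, g in zip(xs, out_grad)]

def numeric_grad(xs, a, b, c, i, eps=1e-6):
    xs_plus = list(xs)
    xs_plus[i] += eps
    xs_minus = list(xs)
    xs_minus[i] -= eps
    # summing the outputs corresponds to an output gradient of all ones
    return (sum(forward(xs_plus, a, b, c))
            - sum(forward(xs_minus, a, b, c))) / (2 * eps)

xs, a, b, c = [0.3, -1.2, 2.0], 1.0, 2.0, 3.0
grads = backward(xs, [1.0] * len(xs), a, b)
for i in range(len(xs)):
    assert abs(grads[i] - numeric_grad(xs, a, b, c, i)) < 1e-3
```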
| |
| <h2 id="summary">Summary</h2> |
| |
| <p>In this tutorial, we practiced implementing the operator <code>quadratic</code> in the MXNet backend |
| and unit testing the implementation in the frontend. More specifically, we added a parameter |
| struct for user-input parameters, walked through the shape and type inference workflow, |
| implemented forward and backward functions, and registered the operator |
| using nnvm. Congratulations! You now know how to add operators. |
| We welcome your contributions to MXNet.</p> |
| |
| <p><strong>Note</strong>: Source code in the tutorial can be found in |
| <a href="https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/quadratic_op-inl.h">quadratic_op-inl.h</a>, |
| <a href="https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/quadratic_op.cc">quadratic_op.cc</a>, |
| <a href="https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/quadratic_op.cu">quadratic_op.cu</a>, |
| and |
| <a href="https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L6514">test_operator.py</a>.</p> |
| |
| </div> |
| </div> |
| |
| </div> |
| </div> |
| |
| </article> |
| |
| </main><footer class="site-footer h-card"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-4"> |
| <h4 class="footer-category-title">Resources</h4> |
| <ul class="contact-list"> |
| <li><a href="/versions/1.9.1/community/contribute#mxnet-dev-communications">Mailing lists</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/labels/Roadmap">Github Roadmap</a></li> |
| <li><a href="https://medium.com/apache-mxnet">Blog</a></li> |
| <li><a href="https://discuss.mxnet.io">Forum</a></li> |
| <li><a href="/versions/1.9.1/community/contribute">Contribute</a></li> |
| </ul> |
| </div> |
| |
| <div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul> |
| </div> |
| |
| <div class="col-4 footer-text"> |
| <p>A flexible and efficient library for deep learning.</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| <footer class="site-footer2"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <img src="/versions/1.9.1/assets/img/apache_incubator_logo.png" class="footer-logo col-2"> |
| </div> |
| <div class="footer-bottom-warning col-9"> |
| <p>Apache MXNet is an effort undergoing incubation at <a href="http://www.apache.org/">The Apache Software Foundation</a> (ASF), <span |
| style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required |
| of all newly accepted projects until a further review indicates that the infrastructure, |
| communications, and decision making process have stabilized in a manner consistent with other |
| successful ASF projects. While incubation status is not necessarily a reflection of the completeness |
| or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p><p>"Copyright © 2017-2022, The Apache Software Foundation Apache MXNet, MXNet, Apache, the Apache |
| feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the |
| Apache Software Foundation."</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| |
| |
| |
| |
| </body> |
| |
| </html> |