| <!DOCTYPE html> |
| |
| <!--- |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| http://www.apache.org/licenses/LICENSE-2.0 |
| Unless required by applicable law or agreed to in writing, |
| software distributed under the License is distributed on an |
| "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| KIND, either express or implied. See the License for the |
| specific language governing permissions and limitations |
| under the License. |
| --> |
| |
<html lang="en"><head>
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <link href="/versions/master/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 --> |
| <title>Using runtime compilation (RTC) to write CUDA kernels in MXNet | Apache MXNet</title> |
| <meta name="generator" content="Jekyll v4.0.0" /> |
| <meta property="og:title" content="Using runtime compilation (RTC) to write CUDA kernels in MXNet" /> |
| <meta property="og:locale" content="en_US" /> |
| <meta name="description" content="A flexible and efficient library for deep learning." /> |
| <meta property="og:description" content="A flexible and efficient library for deep learning." /> |
| <link rel="canonical" href="https://mxnet.apache.org/versions/master/api/faq/using_rtc" /> |
| <meta property="og:url" content="https://mxnet.apache.org/versions/master/api/faq/using_rtc" /> |
| <meta property="og:site_name" content="Apache MXNet" /> |
| <script type="application/ld+json"> |
| {"url":"https://mxnet.apache.org/versions/master/api/faq/using_rtc","headline":"Using runtime compilation (RTC) to write CUDA kernels in MXNet","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script> |
| <!-- End Jekyll SEO tag --> |
| <link rel="stylesheet" href="/versions/master/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><!-- Matomo --> |
| <script> |
| var _paq = window._paq = window._paq || []; |
| /* tracker methods like "setCustomDimension" should be called before "trackPageView" */ |
| /* We explicitly disable cookie tracking to avoid privacy issues */ |
| _paq.push(['disableCookies']); |
| _paq.push(['trackPageView']); |
| _paq.push(['enableLinkTracking']); |
| (function() { |
| var u="https://analytics.apache.org/"; |
| _paq.push(['setTrackerUrl', u+'matomo.php']); |
| _paq.push(['setSiteId', '23']); |
| var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0]; |
| g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s); |
| })(); |
| </script> |
| <!-- End Matomo Code --> |
| |
| <script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script> |
| <script src="/versions/master/assets/js/docsearch.min.js"></script><script src="/versions/master/assets/js/globalSearch.js" defer></script> |
| <script src="/versions/master/assets/js/clipboard.js" defer></script> |
| <script src="/versions/master/assets/js/copycode.js" defer></script></head> |
| <body><header class="site-header" role="banner"> |
| |
| <script> |
| $(document).ready(function () { |
| |
| // HEADER OPACITY LOGIC |
| |
| function opacity_header() { |
| var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")" |
| $('.site-header').css("background-color", value) |
| } |
| |
| $(window).scroll(function () { |
| opacity_header() |
| }) |
| opacity_header(); |
| |
| // MENU SELECTOR LOGIC |
| $('.page-link').each( function () { |
| if (window.location.href.includes(this.href)) { |
| $(this).addClass("page-current"); |
| } |
| }); |
| }) |
| </script> |
| <div class="wrapper"> |
| <a class="site-title" rel="author" href="/versions/master/"><img |
| src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo"></a> |
| <nav class="site-nav"> |
| <input type="checkbox" id="nav-trigger" class="nav-trigger"/> |
| <label for="nav-trigger"> |
| <span class="menu-icon"> |
| <svg viewBox="0 0 18 15" width="18px" height="15px"> |
| <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/> |
| </svg> |
| </span> |
| </label> |
| <div class="gs-search-border"> |
| <div id="gs-search-icon"></div> |
| <form id="global-search-form"> |
| <input id="global-search" type="text" title="Search" placeholder="Search" /> |
| <div id="global-search-dropdown-container"> |
| <button class="gs-current-version btn" type="button" data-toggle="dropdown"> |
| <span id="gs-current-version-label">master</span> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown"> |
| |
| |
| <li class="gs-opt gs-versions active">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| <span id="global-search-close">x</span> |
| </form> |
| </div> |
| <div class="trigger"> |
| <div id="global-search-mobile-border"> |
| <div id="gs-search-icon-mobile"></div> |
| <input id="global-search-mobile" placeholder="Search..." type="text"/> |
| <div id="global-search-dropdown-container-mobile"> |
| <button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown"> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown-mobile"> |
| |
| |
| <li class="gs-opt gs-versions active">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| </div> |
| <a class="page-link" href="/versions/master/get_started">Get Started</a> |
| <a class="page-link" href="/versions/master/features">Features</a> |
| <a class="page-link" href="/versions/master/ecosystem">Ecosystem</a> |
| <a class="page-link" href="/versions/master/api">Docs & Tutorials</a> |
| <a class="page-link" href="/versions/master/trusted_by">Trusted By</a> |
| <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a> |
| <div class="dropdown" style="min-width:100px"> |
| <span class="dropdown-header">Apache |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content" style="min-width:250px"> |
| <a href="https://www.apache.org/foundation/">Apache Software Foundation</a> |
| <a href="https://incubator.apache.org/">Apache Incubator</a> |
| <a href="https://www.apache.org/licenses/">License</a> |
| <a href="/versions/master/api/faq/security.html">Security</a> |
| <a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </div> |
| </div> |
| <div class="dropdown"> |
| <span class="dropdown-header">master |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content"> |
| |
| |
| <a class="dropdown-option-active" href="/">master</a> |
| |
| |
| |
| <a href="/versions/1.9.1/">1.9.1</a> |
| |
| |
| |
| <a href="/versions/1.8.0/">1.8.0</a> |
| |
| |
| |
| <a href="/versions/1.7.0/">1.7.0</a> |
| |
| |
| |
| <a href="/versions/1.6.0/">1.6.0</a> |
| |
| |
| |
| <a href="/versions/1.5.0/">1.5.0</a> |
| |
| |
| |
| <a href="/versions/1.4.1/">1.4.1</a> |
| |
| |
| |
| <a href="/versions/1.3.1/">1.3.1</a> |
| |
| |
| |
| <a href="/versions/1.2.1/">1.2.1</a> |
| |
| |
| |
| <a href="/versions/1.1.0/">1.1.0</a> |
| |
| |
| |
| <a href="/versions/1.0.0/">1.0.0</a> |
| |
| |
| |
| <a href="/versions/0.12.1/">0.12.1</a> |
| |
| |
| |
| <a href="/versions/0.11.0/">0.11.0</a> |
| |
| |
| </div> |
| </div> |
| </div> |
| </nav> |
| </div> |
| </header> |
| <main class="page-content" aria-label="Content"> |
| <script> |
| |
| </script> |
| <article class="post"> |
| |
| <header class="post-header wrapper"> |
| <h1 class="post-title">Using runtime compilation (RTC) to write CUDA kernels in MXNet</h1> |
| <h3></h3></header> |
| |
| <div class="post-content"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3 docs-side-bar"> |
| <h3 style="text-transform: capitalize; padding-left:10px">faq</h3> |
| <ul> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/add_op_in_backend">A Beginner's Guide to Implementing Operators in MXNet Backend</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/cloud">MXNet on the Cloud</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/distributed_training">Distributed Training in MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/env_var">Environment Variables</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/float16">Float16</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/large_tensor_support">Using MXNet with Large Tensor Support</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/model_parallel_lstm">Model Parallel</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/new_op">Create New Operators</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/perf">Some Tips for Improving MXNet Performance</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/recordio">Create a Dataset Using RecordIO</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/s3_integration">Use data from S3 for training</a></li> |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/security">MXNet Security Best Practices</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/tensor_inspector_tutorial">Use TensorInspector to Help Debug Operators</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/using_rtc">Using runtime compilation (RTC) to write CUDA kernels in MXNet</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| |
| <li><a href="/versions/master/api/faq/why_mxnet">Why MXNet came to be?</a></li> |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| </ul> |
| </div> |
| <div class="col-9"> |
| |
| <h1 id="using-runtime-compilation-rtc-to-write-cuda-kernels-in-mxnet">Using runtime compilation (RTC) to write CUDA kernels in MXNet</h1> |
| |
| <h2 id="introduction">Introduction</h2> |
| |
<p>A CUDA kernel is a function that runs on the GPU to perform computation. This tutorial assumes the
reader has basic knowledge of how to write such kernels.</p>
| |
<p>There are currently two typical ways of writing and launching CUDA kernels in MXNet. The first is
to use the <code class="highlighter-rouge">Kernel<...>::Launch()</code> API, which is suitable for simple elementwise operations and
enables writing only a portion of the kernel, leaving the launch mechanism to MXNet. The
other is to write a kernel from scratch and launch it using the <code class="highlighter-rouge"><<<...>>></code> syntax from CUDA.
Starting from MXNet 2.0, there is a third option: runtime compilation (RTC). It differs from the
previous methods (which use kernels compiled ahead of time) in that it compiles the needed kernels
while the user script runs.</p>
| |
<p>In this tutorial we cover the reasons for using RTC instead of the other methods, show how to
do it, and give tips on what to keep in mind when doing so.</p>
| |
| <h2 id="why-rtc">Why RTC?</h2> |
| |
| <h3 id="problems-with-kernels-compiled-ahead-of-time">Problems with kernels compiled ahead of time</h3> |
| |
<p>The use of kernels compiled ahead of time in MXNet leads to a few problems which, unfortunately,
are mostly invisible in any single PR but grow over the course of many contributions into
serious issues.</p>
| |
<p>To understand them, let us look at the typical way kernels are launched in MXNet. This
example shows a launch of a simple kernel, taking a single input of type <code class="highlighter-rouge">DType</code> and producing a
single output of type <code class="highlighter-rouge">OType</code>:</p>
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span> |
| <span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">OType</span><span class="p">,</span> <span class="p">{</span> |
| <span class="n">Kernel</span><span class="o"><</span><span class="p">...</span><span class="o">>::</span><span class="n">Launch</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr</span><span class="o"><</span><span class="n">DType</span><span class="o">></span><span class="p">(),</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr</span><span class="o"><</span><span class="n">OType</span><span class="o">></span><span class="p">());</span> |
| <span class="p">});</span> |
| <span class="p">});</span> |
| </code></pre></div></div> |
| |
<p>This launch mechanism uses the <code class="highlighter-rouge">MSHADOW_TYPE_SWITCH</code> macro, which produces a version of the kernel
for every possible type. When such switches are nested (as in the example shown), it
produces a version of the kernel for every combination of types. This results in a large number of
kernels being generated.</p>
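<p>To see why the number of instantiations multiplies, here is a simplified, standalone sketch of the dispatch pattern (the names <code class="highlighter-rouge">Dispatch</code>, <code class="highlighter-rouge">KernelName</code> and the two-type enum are illustrative, not the actual mshadow macros): each nested runtime switch forces the compiler to emit every template combination, even though any single call uses only one.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include <cassert>
#include <string>
#include <typeinfo>

// Stand-in for a compiled kernel: one instantiation per
// (DType, OType) combination ends up in the binary.
template <typename DType, typename OType>
std::string KernelName() {
  return std::string(typeid(DType).name()) + "->" + typeid(OType).name();
}

// Simplified analogue of a nested MSHADOW_TYPE_SWITCH: runtime flags
// select among template instantiations; with M input types and
// N output types, M x N kernels are compiled ahead of time.
enum TypeFlag { kFloat32 = 0, kInt8 = 1 };

template <typename DType>
std::string DispatchOutput(int otype_flag) {
  switch (otype_flag) {
    case kFloat32: return KernelName<DType, float>();
    case kInt8:    return KernelName<DType, signed char>();
  }
  return "";
}

std::string Dispatch(int dtype_flag, int otype_flag) {
  switch (dtype_flag) {
    case kFloat32: return DispatchOutput<float>(otype_flag);
    case kInt8:    return DispatchOutput<signed char>(otype_flag);
  }
  return "";
}

int main() {
  // With 2 types per switch, 2 x 2 = 4 instantiations exist in the
  // binary, even if a given run only ever exercises one of them.
  assert(Dispatch(kFloat32, kInt8) == KernelName<float, signed char>());
  return 0;
}
</code></pre></div></div>

<p>With mshadow's roughly ten supported types per switch, a single nested switch already yields on the order of a hundred kernel bodies.</p>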
| |
<p>Another factor that multiplies the number of kernels is that different GPU architectures require
different compiled binaries. Therefore, for MXNet to support all of them with a single binary, that
binary needs to contain copies of those kernels for each architecture.</p>
| |
<p>This proliferation of CUDA kernels in the binary leads to multiple issues. The first is the
size of the MXNet library: each compiled version of a kernel takes only a small amount of space in
the binary, but this is multiplied by the number of versions (which could reach thousands per
GPU architecture) and by the number of GPU architectures. This increase in size led to multiple
issues reported with distribution of the MXNet package,
<a href="https://github.com/apache/incubator-mxnet/issues/17045">building the library</a> as well as
<a href="https://github.com/apache/incubator-mxnet/pull/18205">limiting the number of architectures natively
supported</a>.</p>
| |
<p>The second issue is the “idle” memory consumption of the MXNet library. In order to efficiently
launch kernels when they are called, the CUDA driver needs to transfer them to GPU memory ahead of
time. Since it cannot anticipate which kernels will actually be used, all of the kernels are
transferred when the CUDA context is created on a GPU. This means that, even if a user never uses
e.g. a kernel which adds <code class="highlighter-rouge">int8</code> and <code class="highlighter-rouge">float16</code> tensors, that kernel still occupies memory on their GPU,
reducing the amount of memory available for useful work.</p>
| |
<p>The third issue, mostly affecting MXNet developers, is the compilation time of the MXNet library.
The more kernel versions need to be compiled, the more time and hardware resources are needed.</p>
| |
| <h3 id="rtc-to-the-rescue">RTC to the rescue!</h3> |
| |
<p>All of the issues mentioned in the previous paragraphs are solved by runtime compilation.
Using this paradigm, only the kernels actually invoked in the user script are compiled. They do not
occupy space in the MXNet binary, and no unused kernels are stored in users’ GPU memory.</p>
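<p>The core idea can be sketched in a few lines of standalone C++ (the helper names here are illustrative, not MXNet's actual API): instead of instantiating every type combination at build time, a preamble defining the types is prepended to a generic kernel source string once the runtime types are known, and only that single variant is handed to the device compiler (NVRTC in MXNet's case).</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include <cassert>
#include <string>

// Generic kernel source with type names left as placeholders that a
// runtime-generated preamble will define.
const char kKernelBody[] = R"code(
__global__ void unary_kernel(const InputType* input,
                             OutputType* output,
                             const index_t N);
)code";

// Assemble source for exactly the (input, output) type pair that the
// user's script actually hit; no other combination is ever compiled.
std::string BuildKernelSource(const std::string& in_type,
                              const std::string& out_type) {
  return "using InputType = " + in_type + ";\n" +
         "using OutputType = " + out_type + ";\n" +
         kKernelBody;
}

int main() {
  const std::string src = BuildKernelSource("float", "double");
  // Only the requested specialization appears in the generated source.
  assert(src.find("using InputType = float;") != std::string::npos);
  assert(src.find("using OutputType = double;") != std::string::npos);
  return 0;
}
</code></pre></div></div>

<p>The full example later in this tutorial builds its <code class="highlighter-rouge">code</code> string in exactly this fashion before passing it to the RTC machinery.</p>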
| |
| <p>RTC also enables more features:</p> |
| |
<ul>
<li>using more information about the specific usage of a kernel when compiling it (e.g. the shape
information of its inputs) to optimize it better</li>
<li>writing kernels that accept any combination of input and output types</li>
<li>(in the future) fusing more operations into the generated kernels.</li>
</ul>
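<p>As a hypothetical illustration of the first point (not actual MXNet code): because the source is assembled at runtime, quantities known only then, such as the input size, can be baked into the generated source as compile-time constants, letting the device compiler unroll loops or specialize bounds checks.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include <cassert>
#include <string>

// Hypothetical helper: embed the runtime-known tensor size as a
// constant in the generated source, so the device compiler sees it
// as a compile-time value when the kernel is compiled.
std::string SpecializeForShape(const long long n) {
  return "constexpr index_t N = " + std::to_string(n) + ";\n"
         "// ... kernel body can now use N as a compile-time constant,\n"
         "// e.g. a #pragma unroll loop with a known trip count ...\n";
}

int main() {
  const std::string src = SpecializeForShape(1024);
  assert(src.find("constexpr index_t N = 1024;") != std::string::npos);
  return 0;
}
</code></pre></div></div>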
| |
| <h2 id="rtc-for-kernel-developers">RTC for kernel developers</h2> |
| |
| <h3 id="example-unary-operators">Example: unary operators</h3> |
| |
<p>Let us start with an example of a simple kernel written using RTC: a kernel which performs a unary
operation (with sigmoid as a concrete example) on its input. It is not a toy example though: it is
a fully generic kernel, capable of operating on any combination of input and output types, as well
as applying any unary operator:</p>
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">struct</span> <span class="n">UnaryRTCCompute</span> <span class="p">{</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">OP</span><span class="p">;</span> |
| |
| <span class="kt">void</span> <span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">);</span> |
| <span class="p">};</span> |
| |
| <span class="k">const</span> <span class="kt">char</span> <span class="n">unary_kernel_fwd</span><span class="p">[]</span> <span class="o">=</span> <span class="s">R"code( |
| |
| __launch_bounds__(kRTCMaxThreadsPerBlock) |
| __global__ void unary_kernel(const InputType* input, |
                             OutputType* output,
| const index_t N) { |
| using IType = AccType<InputType>; |
| using OType = AccType<OutputType>; |
| |
| for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; |
| tid < N; |
| tid += gridDim.x * blockDim.x) { |
    const auto inp = IType::from(input[tid]);
    const auto temp = OP(inp);  // enables returning different type

    if (req == OpReqType::kAddTo) {
      // temp2 may have a wider type than either temp
      // or OType
      const auto temp2 = op::add(temp, OType::from(output[tid]));
      output[tid] = OType::to(temp2);
    } else {
      output[tid] = OType::to(temp);
    }
| } |
| } |
| |
| )code"</span><span class="p">;</span> |
| |
| <span class="kt">void</span> <span class="n">UnaryRTCCompute</span><span class="o">::</span><span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> |
| <span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">rtc</span><span class="p">;</span> |
| <span class="k">if</span> <span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">kNullOp</span><span class="p">)</span> <span class="k">return</span><span class="p">;</span> |
| <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">gpu</span><span class="o">>*</span> <span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">();</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span> |
| <span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"#define OP op::"</span> <span class="o">+</span> |
| <span class="n">OP</span> <span class="o">+</span> |
| <span class="s">"</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using InputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using OutputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span> |
| |
| <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="k">const</span> <span class="kt">void</span><span class="o">*></span> <span class="n">args</span><span class="p">;</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">size</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">Size</span><span class="p">();</span> |
| <span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">));</span> |
| <span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">));</span> |
| <span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&</span><span class="n">size</span><span class="p">);</span> |
| |
| <span class="k">auto</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">get_function</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span> <span class="n">unary_kernel_fwd</span><span class="p">,</span> |
| <span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span> |
| |
| <span class="k">const</span> <span class="kt">int</span> <span class="n">n_threads</span> <span class="o">=</span> <span class="mi">512</span><span class="p">;</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">n_blocks</span> <span class="o">=</span> <span class="p">(</span><span class="n">size</span> <span class="o">+</span> <span class="n">n_threads</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">n_threads</span><span class="p">;</span> |
| <span class="k">const</span> <span class="kt">int</span> <span class="n">shared_memory_size</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> |
| <span class="n">launch</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">{</span><span class="n">n_blocks</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> <span class="p">{</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> |
| <span class="n">shared_memory_size</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="o">&</span><span class="n">args</span><span class="p">);</span> |
| <span class="p">}</span> |
| |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span> |
| </code></pre></div></div> |
| |
| <h3 id="kernels-are-text">Kernels are text…</h3> |
| |
| <p>The main difference when writing kernels using RTC is that the kernel code becomes a text string. |
| This means that it is possible to change or compose the code at runtime, as is done here:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span> |
| <span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"#define OP op::"</span> <span class="o">+</span> |
| <span class="n">OP</span> <span class="o">+</span> |
| <span class="s">"</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using InputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using OutputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span> |
| </code></pre></div></div> |
| |
| <p>where the operation <code class="highlighter-rouge">OP</code> is also provided as a string in the operator declaration:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span> |
| </code></pre></div></div> |
| |
| <h3 id="and-do-not-know-mxnet-source-code">and do not know MXNet source code</h3> |
| |
| <p>How does the kernel know which operation it should perform? The kernel’s source code uses <code class="highlighter-rouge">OP</code>, |
| which is defined in the <code class="highlighter-rouge">code</code> variable to be <code class="highlighter-rouge">op::sigmoid</code>. Let us compare this to how the |
| same operator is defined for the CPU:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">MXNET_OPERATOR_REGISTER_UNARY</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<cpu>"</span><span class="p">,</span> <span class="n">UnaryOp</span><span class="o">::</span><span class="n">Compute</span><span class="o"><</span><span class="n">cpu</span><span class="p">,</span> <span class="n">mshadow_op</span><span class="o">::</span><span class="n">sigmoid</span><span class="o">></span><span class="p">)</span> |
| </code></pre></div></div> |
| |
| <p>Since the kernel is compiled at runtime, it does not have access to the rest of the MXNet source |
| code, including <code class="highlighter-rouge">mshadow_op.h</code>, which defines <code class="highlighter-rouge">mshadow_op::sigmoid</code>. This means that we need to |
| provide the kernel with definitions of those functions (again, in text string form). Every |
| RTC-compiled kernel is prepended with a common header, containing strings found in the |
| <code class="highlighter-rouge">src/common/cuda/rtc/</code> directory. The <code class="highlighter-rouge">src/common/cuda/rtc/forward_functions-inl.h</code> file contains |
| the definition of <code class="highlighter-rouge">op::sigmoid</code>:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">template</span> <span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="o">></span> |
| <span class="n">__device__</span> <span class="kr">inline</span> <span class="n">DType</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="k">const</span> <span class="n">DType</span> <span class="n">val</span><span class="p">)</span> <span class="p">{</span> |
| <span class="k">if</span> <span class="p">(</span><span class="n">type_util</span><span class="o">::</span><span class="n">has_double_or_integral</span><span class="o"><</span><span class="n">DType</span><span class="o">>::</span><span class="n">value</span><span class="p">)</span> <span class="p">{</span> |
| <span class="k">return</span> <span class="mf">1.</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="o">::</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">val</span><span class="p">));</span> |
| <span class="p">}</span> <span class="k">else</span> <span class="p">{</span> |
| <span class="k">return</span> <span class="mf">1.</span><span class="n">f</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">expf</span><span class="p">(</span><span class="o">-</span><span class="n">val</span><span class="p">));</span> |
| <span class="p">}</span> |
| <span class="p">}</span> |
| </code></pre></div></div> |
| |
| <h3 id="handling-of-data-types">Handling of data types</h3> |
| |
| <p>MXNet supports many datatypes. Some of them, like <code class="highlighter-rouge">float16</code>, <code class="highlighter-rouge">int8</code> or <code class="highlighter-rouge">bool</code>, are |
| useful for storing results, but in many computations they are too limiting, as they can easily |
| overflow in the intermediate stages. That is why the example uses the <code class="highlighter-rouge">AccType<T></code> class: it |
| provides an accumulation type that is potentially larger than the storage type; for example, |
| <code class="highlighter-rouge">AccType<float16>::type</code> is <code class="highlighter-rouge">float32</code>. It also provides special loading and storing functions: |
| <code class="highlighter-rouge">AccType<T>::from()</code> and <code class="highlighter-rouge">AccType<T>::to()</code>.</p> |
| |
| <p>One of the features of RTC-enabled kernels is the ability to accommodate any combination of |
| input and output datatypes. Using <code class="highlighter-rouge">auto</code> as the type of intermediate results helps with this, |
| especially since many binary operators return a mixed type:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">template</span> <span class="o"><</span><span class="k">typename</span> <span class="n">DType</span><span class="p">,</span> <span class="k">typename</span> <span class="n">DType2</span><span class="o">></span> |
| <span class="n">__device__</span> <span class="kr">inline</span> <span class="k">typename</span> <span class="n">type_util</span><span class="o">::</span><span class="n">mixed_type</span><span class="o"><</span><span class="n">DType</span><span class="p">,</span> <span class="n">DType2</span><span class="o">>::</span><span class="n">type</span> |
| <span class="nf">add</span><span class="p">(</span><span class="k">const</span> <span class="n">DType</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType2</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span> |
| <span class="k">return</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span><span class="p">;</span> |
| <span class="p">}</span> |
| </code></pre></div></div> |
| |
| <p><code class="highlighter-rouge">mixed_type<T, U>::type</code> is a type capable of storing the result of an operation between two types <code class="highlighter-rouge">T</code> and |
| <code class="highlighter-rouge">U</code> - e.g. <code class="highlighter-rouge">mixed_type<float64, float32>::type = float64</code> and <code class="highlighter-rouge">mixed_type<float32, int32>::type = float32</code>.</p> |
| |
| <h3 id="compiling-and-launching-rtc-kernels">Compiling and launching RTC kernels</h3> |
| |
| <p>The kernel code stored in <code class="highlighter-rouge">unary_kernel_fwd</code> is generic and relies on multiple names being defined, |
| such as <code class="highlighter-rouge">req</code>, <code class="highlighter-rouge">OP</code> or <code class="highlighter-rouge">InputType</code>. Each operator using the kernel handles this by |
| defining a set of parameters that is concatenated with the code during compilation:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span> |
| <span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"#define OP op::"</span> <span class="o">+</span> |
| <span class="n">OP</span> <span class="o">+</span> |
| <span class="s">"</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using InputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"using OutputType = "</span> <span class="o">+</span> |
| <span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span> |
| </code></pre></div></div> |
| |
| <p>In order to compile the kernel, the <code class="highlighter-rouge">mxnet::common::cuda::rtc::get_function</code> method is used:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">auto</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">get_function</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span> <span class="n">unary_kernel_fwd</span><span class="p">,</span> |
| <span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span> |
| </code></pre></div></div> |
| |
| <p>To eliminate the overhead of compilation, it uses a cache of kernels, keyed by the name of the |
| kernel (<code class="highlighter-rouge">"unary_kernel"</code> in our case) and the set of parameters (<code class="highlighter-rouge">code</code> in our |
| case). If the kernel is already in the cache, it is returned; otherwise, compilation takes place. If |
| compilation fails, the full source code is saved to disk and an MXNet error containing the |
| compilation log is raised.</p> |
| |
| <p>To launch the kernel, the <code class="highlighter-rouge">mxnet::common::cuda::rtc::launch</code> method is used:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">launch</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">{</span><span class="n">n_blocks</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> <span class="p">{</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> |
| <span class="n">shared_memory_size</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="o">&</span><span class="n">args</span><span class="p">);</span> |
| </code></pre></div></div> |
| |
| <p>It takes the kernel object, the grid and block dimensions, the size of the dynamic shared memory, |
| the stream, and the kernel parameters.</p> |
| |
| <h2 id="other-features-enabled-by-rtc">Other features enabled by RTC</h2> |
| |
| <h3 id="vectorization">Vectorization</h3> |
| |
| <p>The actual kernel used for applying a unary operator in MXNet looks slightly different from |
| the simple example shown in the previous section. The differences come from vectorization: |
| instead of reading (or writing) one element at a time, the kernel accesses |
| multiple array elements at once. This is especially beneficial when dealing with smaller |
| types like <code class="highlighter-rouge">float16</code> or <code class="highlighter-rouge">int8</code>. Accessing those small types one by one does not |
| saturate the memory bandwidth of the GPU, so using vector accesses improves the achieved memory |
| bandwidth.</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> |
| <span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op.h</span> |
| <span class="k">struct</span> <span class="n">UnaryRTCCompute</span> <span class="p">{</span> |
| <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">OP</span><span class="p">;</span> |
| |
| <span class="kt">void</span> <span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">);</span> |
| <span class="p">};</span> |
| |
| <span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op.cc</span> |
| <span class="k">struct</span> <span class="n">unary_kernel_params</span> <span class="p">{</span> |
| <span class="k">const</span> <span class="kt">void</span> <span class="o">*</span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span> |
| <span class="kt">void</span> <span class="o">*</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span> |
| <span class="p">};</span> |
| |
| <span class="k">const</span> <span class="kt">char</span> <span class="n">unary_kernel_fwd</span><span class="p">[]</span> <span class="o">=</span> <span class="s">R"code( |
| |
| struct unary_kernel_params { |
| const void *inputs[1]; |
| void *outputs[1]; |
| }; |
| |
| __launch_bounds__(kRTCMaxThreadsPerBlock) |
| __global__ void unary_kernel(const unary_kernel_params params, |
| const index_t lead_dim, |
| const index_t other_dim, |
| const index_t N, |
| const index_t num_aligned_elements) { |
| using namespace vector; |
| VectorizedLoader<InputType0, nvec, aligned> loader( |
| reinterpret_cast<const InputType0*>(params.inputs[0]), N); |
| VectorizedStorer<OutputType0, nvec, aligned> storer( |
| reinterpret_cast<OutputType0*>(params.outputs[0]), N); |
| |
| using IType = AccType<InputType0>; |
| using OType = AccType<OutputType0>; |
| |
| const index_t M = num_aligned_elements; |
| |
| for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; |
| tid < M; |
| tid += gridDim.x * blockDim.x) { |
| loader.load(tid, N); |
| if (req == OpReqType::kAddTo) { |
| storer.load(tid, N); |
| } |
| #pragma unroll |
| for (int i = 0; i < nvec; ++i) { |
| const auto input = IType::from(loader.separate()[i]); |
| const auto temp = OP(input); // enables returning different type |
| |
| if (req == OpReqType::kAddTo) { |
| // temp2 may have a wider type than either temp |
| // or OType |
| const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); |
| storer.separate()[i] = OType::to(temp2); |
| } else { |
| storer.separate()[i] = OType::to(temp); |
| } |
| } |
| storer.store(tid, N); |
| } |
| } |
| |
| )code"</span><span class="p">;</span> |
| |
| <span class="kt">void</span> <span class="n">UnaryRTCCompute</span><span class="o">::</span><span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&</span> <span class="n">attrs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">OpContext</span><span class="o">&</span> <span class="n">ctx</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">inputs</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">OpReqType</span><span class="o">>&</span> <span class="n">req</span><span class="p">,</span> |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">TBlob</span><span class="o">>&</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span> |
| <span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">rtc</span><span class="p">;</span> |
| <span class="k">if</span> <span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">kNullOp</span><span class="p">)</span> <span class="k">return</span><span class="p">;</span> |
| <span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o"><</span><span class="n">gpu</span><span class="o">>*</span> <span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o"><</span><span class="n">gpu</span><span class="o">></span><span class="p">();</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| <span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span> |
| |
| <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span> |
| <span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span> |
| <span class="s">";</span><span class="se">\n</span><span class="s">"</span> |
| <span class="s">"#define OP op::"</span> <span class="o">+</span> |
| <span class="n">OP</span> <span class="o">+</span> |
| <span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span> |
| <span class="k">const</span> <span class="kt">int</span> <span class="n">nvec</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span> <span class="o">==</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">kFloat64</span> <span class="o">?</span> <span class="mi">2</span> <span class="o">:</span> <span class="mi">4</span><span class="p">;</span> |
| |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">size</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">Size</span><span class="p">();</span> |
| <span class="n">unary_kernel_params</span> <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> <span class="p">{</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">},</span> |
| <span class="p">{</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">}</span> <span class="p">};</span> |
| |
| <span class="n">VectorizedKernelRTCLauncher</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span> |
| <span class="n">unary_kernel_fwd</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span> |
| <span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> |
| <span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> |
| <span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span> |
| <span class="p">}</span> |
| |
| <span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op_basic.cu</span> |
| <span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span> |
| <span class="p">.</span><span class="n">set_attr</span><span class="o"><</span><span class="n">FCompute</span><span class="o">></span><span class="p">(</span><span class="s">"FCompute<gpu>"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span> |
| </code></pre></div></div> |
| |
| <p>The RTC implementation in MXNet provides a few useful helper functions and classes that simplify the |
| process of writing and launching vectorized kernels. For vectorized memory access, |
| two classes are provided, used in this kernel to access the input and output arrays:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">VectorizedLoader</span><span class="o"><</span><span class="n">InputType0</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span> <span class="n">aligned</span><span class="o">></span> <span class="n">loader</span><span class="p">(</span> |
| <span class="k">reinterpret_cast</span><span class="o"><</span><span class="k">const</span> <span class="n">InputType0</span><span class="o">*></span><span class="p">(</span><span class="n">params</span><span class="p">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">N</span><span class="p">);</span> |
| <span class="n">VectorizedStorer</span><span class="o"><</span><span class="n">OutputType0</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span> <span class="n">aligned</span><span class="o">></span> <span class="n">storer</span><span class="p">(</span> |
| <span class="k">reinterpret_cast</span><span class="o"><</span><span class="n">OutputType0</span><span class="o">*></span><span class="p">(</span><span class="n">params</span><span class="p">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">N</span><span class="p">);</span> |
| </code></pre></div></div> |
| |
| <p>The <code class="highlighter-rouge">loader</code> object accesses <code class="highlighter-rouge">params.inputs[0]</code>, a pointer to an array of N elements of type |
| <code class="highlighter-rouge">InputType0</code> (the name assigned to the type of the first input by the helper launcher function |
| <code class="highlighter-rouge">VectorizedKernelRTCLauncher</code>). It loads <code class="highlighter-rouge">nvec</code> elements at |
| a time and has an additional <code class="highlighter-rouge">aligned</code> option, which is also set by <code class="highlighter-rouge">VectorizedKernelRTCLauncher</code>. |
| Similarly, the <code class="highlighter-rouge">storer</code> object is used to write data of type <code class="highlighter-rouge">OutputType0</code> to <code class="highlighter-rouge">params.outputs[0]</code>.</p> |
| |
| <p>A kernel launched via <code class="highlighter-rouge">VectorizedKernelRTCLauncher</code> needs to have a specific set of parameters:</p> |
| |
| <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">unary_kernel</span><span class="p">(</span><span class="k">const</span> <span class="n">unary_kernel_params</span> <span class="n">params</span><span class="p">,</span> <span class="c1">// kernel-specific parameters</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">lead_dim</span><span class="p">,</span> <span class="c1">// lead dimension of the tensor</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">other_dim</span><span class="p">,</span> <span class="c1">// size of the other dimensions</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">N</span><span class="p">,</span> <span class="c1">// total number of elements</span> |
| <span class="k">const</span> <span class="n">index_t</span> <span class="n">num_aligned_elements</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// number of vector elements in</span> |
| <span class="c1">// lead dimension</span> |
| </code></pre></div></div> |
| |
| </div> |
| </div> |
| |
| </div> |
| </div> |
| |
| </article> |
| |
| </main><footer class="site-footer h-card"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-4"> |
| <h4 class="footer-category-title">Resources</h4> |
| <ul class="contact-list"> |
| <li><a href="/versions/master/community#stay-connected">Mailing lists</a></li> |
| <li><a href="/versions/master/community#github-issues">Github Issues</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/projects">Projects</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a href="https://discuss.mxnet.io">Forum</a></li> |
| <li><a href="/versions/master/community">Contribute To MXNet</a></li> |
| </ul> |
| </div> |
| |
| <div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul> |
| </div> |
| |
| <div class="col-4 footer-text"> |
| <p>A flexible and efficient library for deep learning.</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| <footer class="site-footer2"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <img src="/versions/master/assets/img/apache_incubator_logo.png" class="footer-logo col-2"> |
| </div> |
| <div class="footer-bottom-warning col-9"> |
| <p>Apache MXNet is an effort undergoing incubation at <a href="http://www.apache.org/">The Apache Software Foundation</a> (ASF), <span |
| style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required |
| of all newly accepted projects until a further review indicates that the infrastructure, |
| communications, and decision making process have stabilized in a manner consistent with other |
| successful ASF projects. While incubation status is not necessarily a reflection of the completeness |
| or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p><p>"Copyright © 2017-2022, The Apache Software Foundation Apache MXNet, MXNet, Apache, the Apache |
| feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the |
| Apache Software Foundation."</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| |
| |
| |
| |
| </body> |
| |
| </html> |