blob: 7e882d11dbbd71f1d60f11934578fd610164d48c [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/master/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Using runtime compilation (RTC) to write CUDA kernels in MXNet | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="Using runtime compilation (RTC) to write CUDA kernels in MXNet" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/faq/using_rtc" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/faq/using_rtc" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/faq/using_rtc","headline":"Using runtime compilation (RTC) to write CUDA kernels in MXNet","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/master/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
// Matomo analytics snippet (tracker served from analytics.apache.org).
// `_paq` is a command queue: every pushed array is [methodName, ...args].
// Reusing an existing window._paq keeps any commands already queued earlier on
// the page; presumably matomo.js drains this queue once it loads (per Matomo's
// JS tracker docs — confirm if upgrading the tracker).
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
// IIFE keeps the helper variables (u, d, g, s) out of the global scope.
(function() {
var u="https://analytics.apache.org/";
// Tracking-request endpoint and the Matomo site id this page reports under.
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
// Load matomo.js without blocking parsing: create an async <script> element and
// insert it before the first <script> element currently in the document.
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/master/assets/js/docsearch.min.js"></script><script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(document).ready(function () {
  // HEADER OPACITY LOGIC
  // Fade the header background in as the page scrolls: alpha starts at 0.4
  // at the top of the page and reaches 1.0 after 180px of scrolling.
  // The alpha is clamped to [0, 1] so the generated CSS value is always in
  // range (e.g. on negative overscroll or long pages).
  function opacity_header() {
    var alpha = Math.min(1, Math.max(0, $(window).scrollTop() / 300 + 0.4));
    var value = "rgba(4,140,204," + alpha + ")";
    $('.site-header').css("background-color", value);
  }
  // `.on("scroll", ...)` replaces the `.scroll(handler)` event shorthand,
  // which is deprecated as of jQuery 3.3 (this page loads jQuery 3.3.1).
  $(window).on("scroll", opacity_header);
  // Apply once immediately so the header is correct before any scroll event.
  opacity_header();

  // MENU SELECTOR LOGIC
  // Mark every top-level nav link whose absolute URL appears in the current
  // page URL as current, e.g. ".../api" is highlighted on ".../api/faq/...".
  $('.page-link').each(function () {
    if (window.location.href.includes(this.href)) {
      $(this).addClass("page-current");
    }
  });
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet home"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text" aria-label="Search"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown" aria-label="Select documentation version">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/master/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/master/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Using runtime compilation (RTC) to write CUDA kernels in MXNet</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<h3 style="text-transform: capitalize; padding-left:10px">faq</h3>
<ul>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/add_op_in_backend">A Beginner's Guide to Implementing Operators in MXNet Backend</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/cloud">MXNet on the Cloud</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/distributed_training">Distributed Training in MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/env_var">Environment Variables</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/float16">Float16</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/large_tensor_support">Using MXNet with Large Tensor Support</a></li>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/model_parallel_lstm">Model Parallel</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/new_op">Create New Operators</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/perf">Some Tips for Improving MXNet Performance</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/recordio">Create a Dataset Using RecordIO</a></li>
<!-- page-category -->
<li><a href="/versions/master/api/faq/s3_integration">Use data from S3 for training</a></li>
<!-- page-category -->
<li><a href="/versions/master/api/faq/security">MXNet Security Best Practices</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/tensor_inspector_tutorial">Use TensorInspector to Help Debug Operators</a></li>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/using_rtc">Using runtime compilation (RTC) to write CUDA kernels in MXNet</a></li>
<!-- page-category -->
<!-- page-category -->
<li><a href="/versions/master/api/faq/why_mxnet">Why MXNet came to be?</a></li>
<!-- page-category -->
<!-- page-category -->
<!-- page-category -->
<!-- resource-p -->
</ul>
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="using-runtime-compilation-rtc-to-write-cuda-kernels-in-mxnet">Using runtime compilation (RTC) to write CUDA kernels in MXNet</h1>
<h2 id="introduction">Introduction</h2>
<p>A CUDA kernel is a function that runs on the GPU to perform computation. This tutorial assumes the
reader has basic knowledge of how to write such kernels.</p>
<p>There are currently 2 typical ways of writing and launching CUDA kernels in MXNet. The first one is
to use the <code class="highlighter-rouge">Kernel&lt;...&gt;::Launch()</code> API, which is suitable for simple elementwise operations and
enables writing only a portion of the kernel, leaving the launch mechanism to MXNet. The
other one is to write a kernel from scratch and launch it using the <code class="highlighter-rouge">&lt;&lt;&lt;...&gt;&gt;&gt;</code> method from CUDA.
Starting from MXNet 2.0, there is a third option - runtime compilation (RTC). This differs from the
previous methods (which use kernels compiled ahead of time), as it compiles the needed kernels
during runtime of the user script.</p>
<p>In this tutorial we will cover the reasons for using RTC instead of the other methods, show how to
use it, and give tips on what to keep in mind when doing so.</p>
<h2 id="why-rtc">Why RTC?</h2>
<h3 id="problems-with-kernels-compiled-ahead-of-time">Problems with kernels compiled ahead of time</h3>
<p>The use of kernels compiled ahead of time in MXNet leads to a few problems, which unfortunately
are mostly invisible in any single PR, but grow over the course of many contributions and result in
serious issues.</p>
<p>In order to understand them, let us look at the typical way kernels are launched in MXNet. This
example shows a launch of a simple kernel, taking a single input of type <code class="highlighter-rouge">DType</code> and producing
a single output of type <code class="highlighter-rouge">OType</code>:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">DType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">MSHADOW_TYPE_SWITCH</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">,</span> <span class="n">OType</span><span class="p">,</span> <span class="p">{</span>
<span class="n">Kernel</span><span class="o">&lt;</span><span class="p">...</span><span class="o">&gt;::</span><span class="n">Launch</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr</span><span class="o">&lt;</span><span class="n">DType</span><span class="o">&gt;</span><span class="p">(),</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr</span><span class="o">&lt;</span><span class="n">OType</span><span class="o">&gt;</span><span class="p">());</span>
<span class="p">});</span>
<span class="p">});</span>
</code></pre></div></div>
<p>This launch mechanism uses the <code class="highlighter-rouge">MSHADOW_TYPE_SWITCH</code> macro, which produces a version of the kernel
for every possible type. In the case of nested usage (as is the case in the example shown) it
produces a version of the kernel for every combination of types. This results in a large number of
kernels being generated.</p>
<p>Another factor that multiplies the number of kernels is that different GPU architectures require
different compiled binaries. Therefore, for MXNet to support all of them with a single binary, that
binary needs to contain copies of those kernels for each architecture.</p>
<p>This proliferation of CUDA kernels in the binary leads to multiple issues. The first problem is the
size of the MXNet library: each compiled version of the kernel takes some space in the binary,
which is small on its own but is multiplied by the number of versions (which could reach thousands per
GPU architecture) and by the number of GPU architectures. This increase in size has led to multiple reported issues with the
distribution of the MXNet package,
<a href="https://github.com/apache/incubator-mxnet/issues/17045">building the library</a> as well as
<a href="https://github.com/apache/incubator-mxnet/pull/18205">limiting the number of architectures natively
supported</a>.</p>
<p>The second issue is the “idle” memory consumption of the MXNet library. In order to efficiently
launch kernels when they are called, the CUDA driver needs to transfer them to the GPU memory ahead of
time. Since it cannot anticipate which kernels will actually be used, all of the kernels are
transferred when the CUDA context is created on a GPU. This means that even if a user never uses,
e.g., a kernel that adds <code class="highlighter-rouge">int8</code> and <code class="highlighter-rouge">float16</code> tensors, that kernel still occupies memory on their GPU,
reducing the amount of memory available for useful work.</p>
<p>The third issue, mostly affecting MXNet developers, is the compilation time of the MXNet library.
The more kernel versions need to be compiled, the more time and hardware resources are needed.</p>
<h3 id="rtc-to-the-rescue">RTC to the rescue!</h3>
<p>All of the issues mentioned in the previous section are solved when using runtime compilation.
Using this paradigm, only the kernels actually invoked in the user script are compiled. They do not
occupy space in the MXNet binary and there are no unused kernels stored in users’ GPU memory.</p>
<p>RTC also enables more features:</p>
<ul>
<li>using more information about specific usage of the kernel when compiling it (e.g. using shape
information of the inputs) to optimize it better</li>
<li>writing kernels accepting any combination of input and output types</li>
<li>(in the future) fusing more operations into the generated kernels.</li>
</ul>
<h2 id="rtc-for-kernel-developers">RTC for kernel developers</h2>
<h3 id="example-unary-operators">Example: unary operators</h3>
<p>Let us start with an example of a simple kernel written using RTC: a kernel which performs a unary
operation (with a concrete example of sigmoid) on its input. It is not a toy example, though: it is
a fully generic kernel, capable of operating on any combination of input and output types, as well
as applying any unary operator:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">struct</span> <span class="n">UnaryRTCCompute</span> <span class="p">{</span>
<span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">OP</span><span class="p">;</span>
<span class="kt">void</span> <span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">outputs</span><span class="p">);</span>
<span class="p">};</span>
<span class="k">const</span> <span class="kt">char</span> <span class="n">unary_kernel_fwd</span><span class="p">[]</span> <span class="o">=</span> <span class="s">R"code(
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void unary_kernel(const InputType* input,
OutputType* output,
const index_t N) {
using IType = AccType&lt;InputType&gt;;
using OType = AccType&lt;OutputType&gt;;
for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid &lt; N;
tid += gridDim.x * blockDim.x) {
const auto value = IType::from(input[tid]);
const auto temp = OP(value); // enables returning different type
if (req == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
// or OType
const auto temp2 = op::add(temp, OType::from(output[tid]));
output[tid] = OType::to(temp2);
} else {
output[tid] = OType::to(temp);
}
}
}
)code"</span><span class="p">;</span>
<span class="kt">void</span> <span class="n">UnaryRTCCompute</span><span class="o">::</span><span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">rtc</span><span class="p">;</span>
<span class="k">if</span> <span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">kNullOp</span><span class="p">)</span> <span class="k">return</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o">&lt;</span><span class="n">gpu</span><span class="o">&gt;*</span> <span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o">&lt;</span><span class="n">gpu</span><span class="o">&gt;</span><span class="p">();</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span>
<span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"#define OP op::"</span> <span class="o">+</span>
<span class="n">OP</span> <span class="o">+</span>
<span class="s">"</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using InputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using OutputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="k">const</span> <span class="kt">void</span><span class="o">*&gt;</span> <span class="n">args</span><span class="p">;</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">size</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">Size</span><span class="p">();</span>
<span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">));</span>
<span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">));</span>
<span class="n">args</span><span class="p">.</span><span class="n">emplace_back</span><span class="p">(</span><span class="o">&amp;</span><span class="n">size</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">get_function</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span> <span class="n">unary_kernel_fwd</span><span class="p">,</span>
<span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span>
<span class="k">const</span> <span class="kt">int</span> <span class="n">n_threads</span> <span class="o">=</span> <span class="mi">512</span><span class="p">;</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">n_blocks</span> <span class="o">=</span> <span class="p">(</span><span class="n">size</span> <span class="o">+</span> <span class="n">n_threads</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">n_threads</span><span class="p">;</span>
<span class="k">const</span> <span class="kt">int</span> <span class="n">shared_memory_size</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
<span class="n">launch</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">{</span><span class="n">n_blocks</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> <span class="p">{</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span>
<span class="n">shared_memory_size</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">args</span><span class="p">);</span>
<span class="p">}</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;gpu&gt;"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span>
</code></pre></div></div>
<h3 id="kernels-are-text">Kernels are text…</h3>
<p>The main difference when writing kernels using RTC is that the kernel code becomes a text string.
This means that it is possible to change or compose the code at runtime, as is done here:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span>
<span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"#define OP op::"</span> <span class="o">+</span>
<span class="n">OP</span> <span class="o">+</span>
<span class="s">"</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using InputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using OutputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span>
</code></pre></div></div>
<p>where the operation <code class="highlighter-rouge">OP</code> is also provided as a string in the operator declaration:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;gpu&gt;"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span>
</code></pre></div></div>
<h3 id="and-do-not-know-mxnet-source-code">and do not know MXNet source code</h3>
<p>How does the kernel know what operation it should perform? The kernel’s source code uses <code class="highlighter-rouge">OP</code>,
which is defined in the <code class="highlighter-rouge">code</code> variable as a macro expanding to <code class="highlighter-rouge">op::sigmoid</code>. Let us compare this to how the
same operator is defined for CPU:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">MXNET_OPERATOR_REGISTER_UNARY</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;cpu&gt;"</span><span class="p">,</span> <span class="n">UnaryOp</span><span class="o">::</span><span class="n">Compute</span><span class="o">&lt;</span><span class="n">cpu</span><span class="p">,</span> <span class="n">mshadow_op</span><span class="o">::</span><span class="n">sigmoid</span><span class="o">&gt;</span><span class="p">)</span>
</code></pre></div></div>
<p>Since the kernel is compiled at runtime, it does not have access to the rest of the MXNet source
code, including <code class="highlighter-rouge">mshadow_op.h</code>, which defines <code class="highlighter-rouge">mshadow_op::sigmoid</code>. This means that we need to
provide the kernel with definitions of those functions (again, in text string form). Every
RTC-compiled kernel is prepended with a common header, built from the strings found in the
<code class="highlighter-rouge">src/common/cuda/rtc/</code> directory. The <code class="highlighter-rouge">src/common/cuda/rtc/forward_functions-inl.h</code> file contains
the definition of <code class="highlighter-rouge">op::sigmoid</code>:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">DType</span><span class="o">&gt;</span>
<span class="n">__device__</span> <span class="kr">inline</span> <span class="n">DType</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="k">const</span> <span class="n">DType</span> <span class="n">val</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="n">type_util</span><span class="o">::</span><span class="n">has_double_or_integral</span><span class="o">&lt;</span><span class="n">DType</span><span class="o">&gt;::</span><span class="n">value</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="mf">1.</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="o">::</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">val</span><span class="p">));</span>
<span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
<span class="k">return</span> <span class="mf">1.</span><span class="n">f</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">expf</span><span class="p">(</span><span class="o">-</span><span class="n">val</span><span class="p">));</span>
<span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<h3 id="handling-of-data-types">Handling of data types</h3>
<p>MXNet supports many datatypes. Some of them, like <code class="highlighter-rouge">float16</code>, <code class="highlighter-rouge">int8</code> or <code class="highlighter-rouge">bool</code>, are
useful for storing results, but in many computations they are too limiting, as they can easily
overflow or lose precision in the intermediate stages. That is why the example uses the <code class="highlighter-rouge">AccType&lt;T&gt;</code> class, which
provides an accumulation type that is potentially larger than the storage type (for example,
<code class="highlighter-rouge">AccType&lt;float16&gt;::type</code> is <code class="highlighter-rouge">float32</code>). It also provides special loading and storing functions:
<code class="highlighter-rouge">AccType&lt;T&gt;::from()</code> and <code class="highlighter-rouge">AccType&lt;T&gt;::to()</code>.</p>
<p>One of the features of RTC-enabled kernels is the ability to accommodate any combination of
input and output datatypes. Using <code class="highlighter-rouge">auto</code> as the type of the intermediate results helps with that,
especially since many binary operators return a mixed type:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">template</span> <span class="o">&lt;</span><span class="k">typename</span> <span class="n">DType</span><span class="p">,</span> <span class="k">typename</span> <span class="n">DType2</span><span class="o">&gt;</span>
<span class="n">__device__</span> <span class="kr">inline</span> <span class="k">typename</span> <span class="n">type_util</span><span class="o">::</span><span class="n">mixed_type</span><span class="o">&lt;</span><span class="n">DType</span><span class="p">,</span> <span class="n">DType2</span><span class="o">&gt;::</span><span class="n">type</span>
<span class="nf">add</span><span class="p">(</span><span class="k">const</span> <span class="n">DType</span> <span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">DType2</span> <span class="n">b</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>
<p><code class="highlighter-rouge">mixed_type&lt;T, U&gt;::type</code> is a type capable of storing the result of an operation between two types <code class="highlighter-rouge">T</code> and
<code class="highlighter-rouge">U</code>, e.g. <code class="highlighter-rouge">mixed_type&lt;float64, float32&gt;::type = float64</code> and <code class="highlighter-rouge">mixed_type&lt;float32, int32&gt;::type =
float32</code>.</p>
<h3 id="compiling-and-launching-rtc-kernels">Compiling and launching RTC kernels</h3>
<p>The kernel code stored in <code class="highlighter-rouge">unary_kernel_fwd</code> is generic and relies on multiple names being defined,
like <code class="highlighter-rouge">req</code>, <code class="highlighter-rouge">OP</code> or <code class="highlighter-rouge">InputType</code>. This is handled in the specific operator using the kernel by
defining a set of parameters that will be concatenated to the code during compilation:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span>
<span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"#define OP op::"</span> <span class="o">+</span>
<span class="n">OP</span> <span class="o">+</span>
<span class="s">"</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using InputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"using OutputType = "</span> <span class="o">+</span>
<span class="n">common</span><span class="o">::</span><span class="n">mshadow_type_info</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span><span class="p">).</span><span class="n">name</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span>
</code></pre></div></div>
<p>In order to compile the kernel, the <code class="highlighter-rouge">mxnet::common::cuda::rtc::get_function</code> method is used:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="k">auto</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">get_function</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span> <span class="n">unary_kernel_fwd</span><span class="p">,</span>
<span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span>
</code></pre></div></div>
<p>To eliminate compilation overhead, it uses a cache of kernels, keyed by the name of the kernel
(<code class="highlighter-rouge">"unary_kernel"</code> in our case) and the set of parameters (<code class="highlighter-rouge">code</code> in our
case). If the kernel is already in the cache, it is returned; otherwise, compilation takes place. If
compilation fails, the full source code is saved to disk and an MXNet error containing the compilation log is
reported.</p>
<p>To launch the kernel, the <code class="highlighter-rouge">mxnet::common::cuda::rtc::launch</code> method is used:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">launch</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">{</span><span class="n">n_blocks</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span> <span class="p">{</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span>
<span class="n">shared_memory_size</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="o">&amp;</span><span class="n">args</span><span class="p">);</span>
</code></pre></div></div>
<p>It takes the kernel object, the grid and block dimensions, the size of dynamic shared memory, the stream and
the kernel parameters.</p>
<h2 id="other-features-enabled-by-rtc">Other features enabled by RTC</h2>
<h3 id="vectorization">Vectorization</h3>
<p>The actual kernel used to apply a unary operator in MXNet looks slightly different from
the simple example shown in the previous section. The differences come from using vectorization.
This means that instead of reading (or writing) one element at a time, the kernel accesses
multiple array elements at once. This is especially beneficial when dealing with smaller
types like <code class="highlighter-rouge">float16</code> or <code class="highlighter-rouge">int8</code>. Accessing those small types one by one is inefficient and does not
saturate the memory bandwidth of the GPU, so using vector accesses improves achieved memory
bandwidth.</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op.h</span>
<span class="k">struct</span> <span class="n">UnaryRTCCompute</span> <span class="p">{</span>
<span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">OP</span><span class="p">;</span>
<span class="kt">void</span> <span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">outputs</span><span class="p">);</span>
<span class="p">};</span>
<span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op.cc</span>
<span class="k">struct</span> <span class="n">unary_kernel_params</span> <span class="p">{</span>
<span class="k">const</span> <span class="kt">void</span> <span class="o">*</span><span class="n">inputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="kt">void</span> <span class="o">*</span><span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">];</span>
<span class="p">};</span>
<span class="k">const</span> <span class="kt">char</span> <span class="n">unary_kernel_fwd</span><span class="p">[]</span> <span class="o">=</span> <span class="s">R"code(
struct unary_kernel_params {
const void *inputs[1];
void *outputs[1];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void unary_kernel(const unary_kernel_params params,
const index_t lead_dim,
const index_t other_dim,
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
VectorizedLoader&lt;InputType0, nvec, aligned&gt; loader(
reinterpret_cast&lt;const InputType0*&gt;(params.inputs[0]), N);
VectorizedStorer&lt;OutputType0, nvec, aligned&gt; storer(
reinterpret_cast&lt;OutputType0*&gt;(params.outputs[0]), N);
using IType = AccType&lt;InputType0&gt;;
using OType = AccType&lt;OutputType0&gt;;
const index_t M = num_aligned_elements;
for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid &lt; M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
if (req == OpReqType::kAddTo) {
storer.load(tid, N);
}
#pragma unroll
for (int i = 0; i &lt; nvec; ++i) {
const auto input = IType::from(loader.separate()[i]);
const auto temp = OP(input); // enables returning different type
if (req == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
// or OType
const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
storer.separate()[i] = OType::to(temp2);
} else {
storer.separate()[i] = OType::to(temp);
}
}
storer.store(tid, N);
}
}
)code"</span><span class="p">;</span>
<span class="kt">void</span> <span class="n">UnaryRTCCompute</span><span class="o">::</span><span class="k">operator</span><span class="p">()(</span><span class="k">const</span> <span class="n">nnvm</span><span class="o">::</span><span class="n">NodeAttrs</span><span class="o">&amp;</span> <span class="n">attrs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">OpContext</span><span class="o">&amp;</span> <span class="n">ctx</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">inputs</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">OpReqType</span><span class="o">&gt;&amp;</span> <span class="n">req</span><span class="p">,</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="n">TBlob</span><span class="o">&gt;&amp;</span> <span class="n">outputs</span><span class="p">)</span> <span class="p">{</span>
<span class="k">using</span> <span class="k">namespace</span> <span class="n">mxnet</span><span class="o">::</span><span class="n">common</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">rtc</span><span class="p">;</span>
<span class="k">if</span> <span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">kNullOp</span><span class="p">)</span> <span class="k">return</span><span class="p">;</span>
<span class="n">mshadow</span><span class="o">::</span><span class="n">Stream</span><span class="o">&lt;</span><span class="n">gpu</span><span class="o">&gt;*</span> <span class="n">s</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">get_stream</span><span class="o">&lt;</span><span class="n">gpu</span><span class="o">&gt;</span><span class="p">();</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">inputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="n">CHECK_EQ</span><span class="p">(</span><span class="n">outputs</span><span class="p">.</span><span class="n">size</span><span class="p">(),</span> <span class="mi">1U</span><span class="p">);</span>
<span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span> <span class="n">code</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"const OpReqType req = "</span><span class="p">)</span> <span class="o">+</span>
<span class="n">util</span><span class="o">::</span><span class="n">to_string</span><span class="p">(</span><span class="n">req</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">+</span>
<span class="s">";</span><span class="se">\n</span><span class="s">"</span>
<span class="s">"#define OP op::"</span> <span class="o">+</span>
<span class="n">OP</span> <span class="o">+</span>
<span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">;</span>
<span class="k">const</span> <span class="kt">int</span> <span class="n">nvec</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">type_flag_</span> <span class="o">==</span> <span class="n">mshadow</span><span class="o">::</span><span class="n">kFloat64</span> <span class="o">?</span> <span class="mi">2</span> <span class="o">:</span> <span class="mi">4</span><span class="p">;</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">size</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">Size</span><span class="p">();</span>
<span class="n">unary_kernel_params</span> <span class="n">params</span> <span class="o">=</span> <span class="p">{</span> <span class="p">{</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">},</span>
<span class="p">{</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">dptr_</span><span class="p">}</span> <span class="p">};</span>
<span class="n">VectorizedKernelRTCLauncher</span><span class="p">(</span><span class="n">code</span><span class="p">,</span> <span class="s">"unary_kernel"</span><span class="p">,</span>
<span class="n">unary_kernel_fwd</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span>
<span class="n">size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">s</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span>
<span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span>
<span class="n">ctx</span><span class="p">.</span><span class="n">run_ctx</span><span class="p">.</span><span class="n">get_ctx</span><span class="p">().</span><span class="n">dev_id</span><span class="p">);</span>
<span class="p">}</span>
<span class="c1">// excerpt from src/operator/tensor/elemwise_unary_op_basic.cu</span>
<span class="n">NNVM_REGISTER_OP</span><span class="p">(</span><span class="n">sigmoid</span><span class="p">)</span>
<span class="p">.</span><span class="n">set_attr</span><span class="o">&lt;</span><span class="n">FCompute</span><span class="o">&gt;</span><span class="p">(</span><span class="s">"FCompute&lt;gpu&gt;"</span><span class="p">,</span> <span class="n">UnaryRTCCompute</span><span class="p">{</span><span class="s">"sigmoid"</span><span class="p">});</span>
</code></pre></div></div>
<p>The RTC implementation in MXNet provides a few useful helper functions and classes that simplify the
process of writing and launching vectorized kernels. For vectorized memory access, two classes are
provided; this kernel uses them to access its input and output arrays:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">VectorizedLoader</span><span class="o">&lt;</span><span class="n">InputType0</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span> <span class="n">aligned</span><span class="o">&gt;</span> <span class="n">loader</span><span class="p">(</span>
<span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="k">const</span> <span class="n">InputType0</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">params</span><span class="p">.</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">N</span><span class="p">);</span>
<span class="n">VectorizedStorer</span><span class="o">&lt;</span><span class="n">OutputType0</span><span class="p">,</span> <span class="n">nvec</span><span class="p">,</span> <span class="n">aligned</span><span class="o">&gt;</span> <span class="n">storer</span><span class="p">(</span>
<span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">OutputType0</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">params</span><span class="p">.</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">N</span><span class="p">);</span>
</code></pre></div></div>
<p>The <code class="highlighter-rouge">loader</code> object accesses the <code class="highlighter-rouge">params.inputs[0]</code> pointer to an array of N elements of type
<code class="highlighter-rouge">InputType0</code> (the name that the helper launcher function
<code class="highlighter-rouge">VectorizedKernelRTCLauncher</code> assigns to the type of the first input). It loads <code class="highlighter-rouge">nvec</code> elements at
a time and takes an additional <code class="highlighter-rouge">aligned</code> template parameter, which is also set by <code class="highlighter-rouge">VectorizedKernelRTCLauncher</code>.
Similarly, the <code class="highlighter-rouge">storer</code> object is used to write data of type <code class="highlighter-rouge">OutputType0</code> to <code class="highlighter-rouge">params.outputs[0]</code>.</p>
<p>A kernel launched with <code class="highlighter-rouge">VectorizedKernelRTCLauncher</code> must have the following parameters:</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">unary_kernel</span><span class="p">(</span><span class="k">const</span> <span class="n">unary_kernel_params</span> <span class="n">params</span><span class="p">,</span> <span class="c1">// kernel-specific parameters</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">lead_dim</span><span class="p">,</span> <span class="c1">// lead dimension of the tensor</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">other_dim</span><span class="p">,</span> <span class="c1">// size of the other dimensions</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">N</span><span class="p">,</span> <span class="c1">// total number of elements</span>
<span class="k">const</span> <span class="n">index_t</span> <span class="n">num_aligned_elements</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// number of vector elements in</span>
<span class="c1">// lead dimension</span>
</code></pre></div></div>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
<li><a href="/versions/master/community#github-issues">Github Issues</a></li>
<li><a href="https://github.com/apache/mxnet/projects">Projects</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/master/community">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/master/assets/img/asf_logo.svg" class="footer-logo col-2" alt="The Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>"Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>