<!-- blob: 77811635845844670e3d147a2bde2fa2ed1a3850 [file] [log] [blame] -->
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>How to optimize convolution on GPU &mdash; tvm 0.17.dev0 documentation</title>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" />
<link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/>
<script id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="How to optimize convolution using TensorCores" href="opt_conv_tensorcore.html" />
<link rel="prev" title="How to optimize GEMM on CPU" href="opt_gemm.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<header class="header">
<div class="innercontainer">
<div class="headerInner d-flex justify-content-between align-items-center">
<div class="headerLogo">
<a href="https://tvm.apache.org/"><img src="https://tvm.apache.org/assets/images/logo.svg" alt="Apache TVM logo"></a>
</div>
<div id="headMenu" class="headerNav">
<button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button>
<ul class="nav">
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/community">Community</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/download">Download</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/vta">VTA</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/blog">Blog</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/docs">Docs</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvmconf.org">Conference</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://github.com/apache/tvm/">GitHub</a>
</li>
</ul>
<div class="responsivetlcdropdown">
<button type="button" class="btn-link">
ASF
</button>
<ul>
<li>
<a href="https://apache.org/">Apache Homepage</a>
</li>
<li>
<a href="https://www.apache.org/licenses/">License</a>
</li>
<li>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
</li>
<li>
<a href="https://www.apache.org/security/">Security</a>
</li>
<li>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</li>
<li>
<a href="https://www.apache.org/events/current-event">Events</a>
</li>
</ul>
</div>
</div>
<div class="responsiveMenuIcon">
<button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button>
</div>
<div class="tlcDropdown">
<div class="dropdown">
<button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">
ASF
</button>
<div class="dropdown-menu dropdown-menu-right">
<ul>
<li>
<a href="https://apache.org/">Apache Homepage</a>
</li>
<li>
<a href="https://www.apache.org/licenses/">License</a>
</li>
<li>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
</li>
<li>
<a href="https://www.apache.org/security/">Security</a>
</li>
<li>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</li>
<li>
<a href="https://www.apache.org/events/current-event">Events</a>
</li>
</ul>
</div>
</div>
</div>
</div>
</div>
</header>
<nav data-toggle="wy-nav-shift" class="wy-nav-side fixed">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html">
<img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/>
</a>
<input type="checkbox" class="version-toggle-box" hidden id="version-toggle">
<label for="version-toggle" class="version-toggle-label">
<div tabindex="0" class="version version-selector version-selector-show">
0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span>
</div>
</label>
<div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="version navigation">
<p class="caption" role="heading"><span class="caption-text">Versions</span></p>
<ol style="text-align: left">
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li>
</ol>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="index.html">Optimize Tensor Operators</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="opt_gemm.html">How to optimize GEMM on CPU</a></li>
<li class="toctree-l3 current"><a class="current reference internal" href="#">How to optimize convolution on GPU</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#preparation-and-algorithm">Preparation and Algorithm</a></li>
<li class="toctree-l4"><a class="reference internal" href="#memory-hierarchy">Memory Hierarchy</a></li>
<li class="toctree-l4"><a class="reference internal" href="#blocking">Blocking</a></li>
<li class="toctree-l4"><a class="reference internal" href="#virtual-thread-split">Virtual Thread Split</a></li>
<li class="toctree-l4"><a class="reference internal" href="#cooperative-fetching">Cooperative Fetching</a></li>
<li class="toctree-l4"><a class="reference internal" href="#generate-cuda-kernel">Generate CUDA Kernel</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="opt_conv_tensorcore.html">How to optimize convolution using TensorCores</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top">
<div class="togglemenu">
</div>
<div class="nav-content">
<!-- tvm -->
Table of Contents
</div>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html">Docs</a> <span class="br-arrow">&gt;</span></li>
<li><a href="../index.html">How To Guides</a> <span class="br-arrow">&gt;</span></li>
<li><a href="index.html">Optimize Tensor Operators</a> <span class="br-arrow">&gt;</span></li>
<li>How to optimize convolution on GPU</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/apache/tvm/edit/main/docs/how_to/optimize_operators/opt_conv_cuda.rst" class="fa fa-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>This tutorial can be used interactively with Google Colab! You can also click
<a class="reference internal" href="#sphx-glr-download-how-to-optimize-operators-opt-conv-cuda-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p>
<a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/854257a66df713b1f3f82eb3577f95e3/opt_conv_cuda.ipynb"><img alt="Open this tutorial in Google Colab" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300" /></a>
</div>
<div class="sphx-glr-example-title section" id="how-to-optimize-convolution-on-gpu">
<span id="opt-conv-gpu"></span><span id="sphx-glr-how-to-optimize-operators-opt-conv-cuda-py"></span><h1>How to optimize convolution on GPU<a class="headerlink" href="#how-to-optimize-convolution-on-gpu" title="Permalink to this headline"></a></h1>
<p><strong>Author</strong>: <a class="reference external" href="https://homes.cs.washington.edu/~haichen/">Haichen Shen</a></p>
<p>In this tutorial, we will demonstrate how to write a high-performance
convolution implementation in TVM. We use square input tensors and filters
as an example, and assume the input to the convolution has a large batch size. In this
example, we use a different layout to store the data in order to achieve better
data locality. The buffer layout is HWCN, which stands for height, width,
channel, batch.</p>
<div class="section" id="preparation-and-algorithm">
<h2>Preparation and Algorithm<a class="headerlink" href="#preparation-and-algorithm" title="Permalink to this headline"></a></h2>
<p>We use fixed-size input tensors with 256 channels and a spatial size of
14 x 14. The batch size is 256. The convolution applies 512 filters
of size 3 x 3, with a stride of 1 and a padding of 1.
The following code defines the convolution algorithm in TVM.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tvm</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span>
<span class="c1"># The sizes of inputs and filters</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a> <span class="o">=</span> <span class="mi">256</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a> <span class="o">=</span> <span class="mi">256</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a> <span class="o">=</span> <span class="mi">512</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">=</span> <span class="mi">14</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a> <span class="o">=</span> <span class="mi">3</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">=</span> <span class="mi">1</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">=</span> <span class="mi">1</span>
<span class="c1"># Algorithm</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;W&quot;</span><span class="p">)</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a> <span class="o">=</span> <span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">)</span> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <span class="mi">1</span>
<span class="c1"># Pad input</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span>
<span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span>
<span class="k">lambda</span> <span class="n">yy</span><span class="p">,</span> <span class="n">xx</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">nn</span><span class="p">:</span> <a href="../../reference/api/python/tir.html#tvm.tir.if_then_else" title="tvm.tir.if_then_else" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">if_then_else</span></a><span class="p">(</span>
<a href="../../reference/api/python/tir.html#tvm.tir.all" title="tvm.tir.all" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">all</span></a><span class="p">(</span><span class="n">yy</span> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">yy</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">),</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><span class="n">yy</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">nn</span><span class="p">],</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">const</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">),</span>
<span class="p">),</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;Apad&quot;</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Create reduction variables</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;rc&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;ry&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;rx&quot;</span><span class="p">)</span>
<span class="c1"># Compute the convolution</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span>
<span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span>
<span class="k">lambda</span> <span class="n">yy</span><span class="p">,</span> <span class="n">xx</span><span class="p">,</span> <span class="n">ff</span><span class="p">,</span> <span class="n">nn</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">[</span><span class="n">yy</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">nn</span><span class="p">]</span> <span class="o">*</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class 
sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">ff</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">]</span>
<span class="p">),</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;B&quot;</span><span class="p">,</span>
<span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="memory-hierarchy">
<h2>Memory Hierarchy<a class="headerlink" href="#memory-hierarchy" title="Permalink to this headline"></a></h2>
<p>We first specify the memory hierarchy for buffers. The figure below shows the
GPU memory hierarchy. One important difference from the CPU memory hierarchy is
that the GPU provides a cache buffer called shared memory, which is managed by
programmers. Thus, maximizing data reuse in shared memory is
critical to achieving high performance in GPU kernels.</p>
<a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/gpu_memory_hierarchy.png"><img alt="Diagram of the GPU memory hierarchy" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/gpu_memory_hierarchy.png" style="width: 271px; height: 319px;" /></a>
<p>In this example, we load both Apad and W into buffers AA and WW, which are
stored in the shared memory. These buffers will later be shared by all
threads within the same thread block to compute the convolution. Each thread
then loads its own part from the shared buffers into its local registers, AL and
WL. BL is a local cache of output B, which is also stored in the thread-local
registers.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Designate the memory hierarchy</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_inline</span><span class="p">()</span> <span class="c1"># compute Apad inline</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">,</span> <span class="s2">&quot;shared&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <span class="s2">&quot;shared&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">,</span> <span class="s2">&quot;local&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">,</span> <span class="s2">&quot;local&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_write" title="tvm.te.Schedule.cache_write" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_write</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">,</span> <span class="s2">&quot;local&quot;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="blocking">
<h2>Blocking<a class="headerlink" href="#blocking" title="Permalink to this headline"></a></h2>
<p>The following code splits the workload into thread blocks and individual
threads. We follow the same blocking scheme as in matrix multiplication. As shown in the
figure below, given a pixel coordinate (y, x), a thread block is responsible
for computing a region of block_factor x block_factor (64 x 64) for output
channels and batch. Due to the limited shared memory space, we only load step
x block_factor (8 x 64) data from Apad and W each time into buffers in the
shared memory.</p>
<a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_blocking.png"><img alt="Diagram of the blocking scheme for convolution on GPU" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_blocking.png" style="width: 317px; height: 308px;" /></a>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># tile consts</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tile</span></a> <span class="o">=</span> <span class="mi">8</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> <span class="o">=</span> <span class="mi">8</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a> <span class="o">=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tile</span></a> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">step</span></a> <span class="o">=</span> <span class="mi">8</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a> <span class="o">=</span> <span class="mi">2</span>
<span class="c1"># Get the GPU thread indices</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.x&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.y&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.z&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">),</span> <span class="s2">&quot;threadIdx.x&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">),</span> <span class="s2">&quot;threadIdx.y&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_xz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">),</span> <span class="s2">&quot;vthread&quot;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;vx&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_yz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">),</span> <span class="s2">&quot;vthread&quot;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;vy&quot;</span><span class="p">)</span>
<span class="c1"># Split the workloads</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wi</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a><span class="p">)</span>
<span class="c1"># Bind the iteration variables to GPU thread indices</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="virtual-thread-split">
<h2>Virtual Thread Split<a class="headerlink" href="#virtual-thread-split" title="Permalink to this headline"></a></h2>
<p>We further split the workload from a thread block to individual threads. To
avoid <em>memory bank conflicts</em>, we use virtual threads to split the area into 4
parts, and then tile into 8x8 grids. Therefore, as shown in the figure below,
each thread computes 4 strided grids, where the size of each grid is 4 x 4.</p>
<a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_vthread.png"><img alt="Diagram of the virtual thread split for convolution on GPU" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_vthread.png" style="width: 268px; height: 188px;" /></a>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">)</span> <span class="c1"># virtual thread split</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">)</span> <span class="c1"># virtual thread split</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a 
href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_yz</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_xz</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="cooperative-fetching">
<h2>Cooperative Fetching<a class="headerlink" href="#cooperative-fetching" title="Permalink to this headline"></a></h2>
<p>As mentioned before, at each time step we need to transfer step x block_factor
data from GPU global memory to shared memory. In order to reduce the memory
transfer per thread, the following code lets threads in the same thread block
cooperatively fetch dependent data from global memory.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Schedule BL local write</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rco</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">step</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rco</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span>
<span class="c1"># Attach computation to iteration variables</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">)</span>
<span class="c1"># Schedule for A&#39;s shared memory load</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">_</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span> <span class="c1"># vectorize memory load</span>
<span class="c1"># Schedule for W&#39;s shared memory load</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">_</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">)</span> <span class="c1"># vectorize memory load</span>
</pre></div>
</div>
</div>
<div class="section" id="generate-cuda-kernel">
<h2>Generate CUDA Kernel<a class="headerlink" href="#generate-cuda-kernel" title="Permalink to this headline"></a></h2>
<p>Finally we use TVM to generate and compile the CUDA kernel, and evaluate the
latency of convolution.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">func</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
<span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">cuda</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">a_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span>
<span class="n">w_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">a_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">w_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="n">evaluator</span> <span class="o">=</span> <a href="../../reference/api/python/runtime.html#tvm.runtime.Module.time_evaluator" title="tvm.runtime.Module.time_evaluator" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-method"><span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span></a><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Convolution: </span><span class="si">%f</span><span class="s2"> ms&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span> <span class="o">*</span> <span class="mf">1e3</span><span class="p">))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Convolution: 40.042495 ms
</pre></div>
</div>
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-optimize-operators-opt-conv-cuda-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/3c5c85c3954f3110f16ca084e286f03a/opt_conv_cuda.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">opt_conv_cuda.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/854257a66df713b1f3f82eb3577f95e3/opt_conv_cuda.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">opt_conv_cuda.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<!-- Previous/next page links. The icon-font arrows are purely decorative, so they are hidden from assistive technology; the visible "Next"/"Previous" text carries the link name. -->
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="opt_conv_tensorcore.html" class="btn btn-neutral float-right" title="How to optimize convolution using TensorCores" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="opt_gemm.html" class="btn btn-neutral float-left" title="How to optimize GEMM on CPU" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="Back to top"/> </div>
<section class="footerSec">
<div class="footerHeader">
<div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row">
<div class="copywrite d-flex align-items-center">
<h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5>
</div>
</div>
</div>
<div>
<div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div>
</div>
</section>
</footer>
</div>
</div>
</section>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
<!-- Theme navigation bootstrap. This script stays inside <body>, which is closed once at the end of the document. -->
<script>
// jQuery(fn) defers fn until the DOM is ready, so the navigation markup
// exists before the theme's navigation handler is enabled.
// NOTE(review): the `true` argument is presumably the theme's sticky-nav
// flag; confirm against _static/js/theme.js.
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
<!-- Theme Analytics -->
<script>
// Google Analytics (analytics.js) async loader, kept in its vendor-supplied
// form. The IIFE records the global name 'ga' on window, installs a stub
// `ga` function that pushes each call's arguments onto ga.q, stamps ga.l
// with the current time, then creates an async <script> pointing at
// analytics.js and inserts it before the first <script> in the document.
// NOTE(review): Universal Analytics ("UA-" properties served by
// analytics.js) has been sunset by Google; confirm whether this snippet
// still reports data or should be migrated.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue the tracker creation for property UA-75982049-2, then a pageview hit.
ga('create', 'UA-75982049-2', 'auto');
ga('send', 'pageview');
</script>
</body>
</html>