| |
| |
| |
| |
| |
| |
| <!DOCTYPE html> |
| <html class="writer-html5" lang="en" > |
| <head> |
| <meta charset="utf-8"> |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| <title>How to optimize convolution on GPU — tvm 0.17.dev0 documentation</title> |
| <!-- Stylesheets: each file is linked exactly once. Link order is the cascade order, so later sheets take precedence over earlier ones. --> |
| <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"> |
| <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" /> |
| <link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/> |
| <!-- Scripts: documentation_options.js is loaded once so the id "documentation_options" stays unique in the document. Scripts are left blocking and in their original order. --> |
| <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> |
| <script src="../../_static/jquery.js"></script> |
| <script src="../../_static/underscore.js"></script> |
| <script src="../../_static/doctools.js"></script> |
| <script src="../../_static/js/theme.js"></script> |
| <script src="../../_static/js/tlcpack_theme.js"></script> |
| <!-- Document relations: index, search, and neighbouring pages. --> |
| <link rel="index" title="Index" href="../../genindex.html" /> |
| <link rel="search" title="Search" href="../../search.html" /> |
| <link rel="next" title="How to optimize convolution using TensorCores" href="opt_conv_tensorcore.html" /> |
| <link rel="prev" title="How to optimize GEMM on CPU" href="opt_gemm.html" /> |
| </head> |
| |
| <body class="wy-body-for-nav"> |
| |
| |
| <div class="wy-grid-for-nav"> |
| |
| |
| <!-- Site header: logo, primary site links, and the ASF link list. The ASF list appears twice on purpose: once in the responsive menu and once as the dropdown. --> |
| <header class="header"> |
| <div class="innercontainer"> |
| <div class="headerInner d-flex justify-content-between align-items-center"> |
| <div class="headerLogo"> |
| <a href="https://tvm.apache.org/"><img src="https://tvm.apache.org/assets/images/logo.svg" alt="Apache TVM home page"></a> |
| </div> |
| <div id="headMenu" class="headerNav"> |
| <button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button> |
| <ul class="nav"> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/community">Community</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/download">Download</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/vta">VTA</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/blog">Blog</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/docs">Docs</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvmconf.org">Conference</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://github.com/apache/tvm/">Github</a> |
| </li> |
| </ul> |
| <!-- ASF links, responsive-menu variant. --> |
| <div class="responsivetlcdropdown"> |
| <button type="button" class="btn-link"> |
| ASF |
| </button> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| <div class="responsiveMenuIcon"> |
| <button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu"></button> |
| </div> |
| <!-- ASF links, dropdown variant. --> |
| <div class="tlcDropdown"> |
| <div class="dropdown"> |
| <button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false"> |
| ASF |
| </button> |
| <div class="dropdown-menu dropdown-menu-right"> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| </div> |
| </header> |
| |
| <!-- Sidebar: logo link, version selector, search form, and the documentation table of contents. --> |
| <nav data-toggle="wy-nav-shift" class="wy-nav-side fixed"> |
| <div class="wy-side-scroll"> |
| <div class="wy-side-nav-search" > |
| <a href="../../index.html"> |
| <img src="../../_static/tvm-logo-small.png" class="logo" alt="TVM documentation home"/> |
| </a> |
| <!-- Version selector: a hidden checkbox toggled through its label. Presumably the theme CSS shows or hides the list from the checkbox state; confirm against the theme stylesheet. --> |
| <input type="checkbox" class="version-toggle-box" hidden id="version-toggle"> |
| <label for="version-toggle" class="version-toggle-label"> |
| <div tabindex="0" class="version version-selector version-selector-show"> |
| 0.17.dev0 <span class="chevron versions-hidden"><svg aria-hidden="true" fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg aria-hidden="true" fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span> |
| </div> |
| </label> |
| <!-- Labelled separately from the table of contents below so the two navigation regions can be told apart by assistive technology. --> |
| <div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="version navigation"> |
| <p class="caption" role="heading"><span class="caption-text">Versions</span></p> |
| <!-- NOTE(review): the versioned hrefs below are relative ("v0.8.0/"), so they resolve against this page's directory rather than the docs root. Confirm whether a script rewrites them or whether they should be root-relative. --> |
| <ol style="text-align: left"> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li> |
| </ol> |
| </div> |
| <!-- Search form: submits the query to search.html via GET. The input stays type="text" so theme selectors on that type keep matching. --> |
| <div role="search"> |
| <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get"> |
| <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" /> |
| <input type="hidden" name="check_keywords" value="yes" /> |
| <input type="hidden" name="area" value="default" /> |
| </form> |
| </div> |
| </div> |
| <!-- Documentation table of contents; "current" classes mark the path to this page. --> |
| <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| <p class="caption" role="heading"><span class="caption-text">Getting Started</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">User Guide</span></p> |
| <ul class="current"> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li> |
| <li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current"> |
| <li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li> |
| <li class="toctree-l2 current"><a class="reference internal" href="index.html">Optimize Tensor Operators</a><ul class="current"> |
| <li class="toctree-l3"><a class="reference internal" href="opt_gemm.html">How to optimize GEMM on CPU</a></li> |
| <li class="toctree-l3 current"><a class="current reference internal" href="#">How to optimize convolution on GPU</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#preparation-and-algorithm">Preparation and Algorithm</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#memory-hierarchy">Memory Hierarchy</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#blocking">Blocking</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#virtual-thread-split">Virtual Thread Split</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#cooperative-fetching">Cooperative Fetching</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#generate-cuda-kernel">Generate CUDA Kernel</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="opt_conv_tensorcore.html">How to optimize convolution using TensorCores</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li> |
| </ul> |
| </li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li> |
| </ul> |
| </div> |
| </div> |
| </nav> |
| |
| <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> |
| |
| <!-- Top navigation bar: an empty "togglemenu" container followed by the "Table of Contents" label. NOTE(review): the toggle control is presumably injected by theme JavaScript; confirm. --> |
| <nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top"> |
| |
| <div class="togglemenu"> |
| |
| </div> |
| <div class="nav-content"> |
| <!-- tvm --> |
| Table of Contents |
| </div> |
| |
| </nav> |
| |
| |
| <div class="wy-nav-content"> |
| |
| <div class="rst-content"> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <!-- Breadcrumb trail for this page. The "&gt;" separators are decorative, so they are hidden from assistive technology. --> |
| <div role="navigation" aria-label="breadcrumbs navigation"> |
| <ul class="wy-breadcrumbs"> |
| <li><a href="../../index.html">Docs</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li> |
| <li><a href="../index.html">How To Guides</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li> |
| <li><a href="index.html">Optimize Tensor Operators</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li> |
| <li>How to optimize convolution on GPU</li> |
| <li class="wy-breadcrumbs-aside"> |
| <!-- NOTE(review): this page carries sphinx-gallery markers, so the .rst target below may be a generated file rather than one tracked in the repository; confirm the edit link resolves. --> |
| <a href="https://github.com/apache/tvm/edit/main/docs/how_to/optimize_operators/opt_conv_cuda.rst" class="fa fa-github"> Edit on GitHub</a> |
| </li> |
| </ul> |
| <hr/> |
| </div> |
| <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article"> |
| <div itemprop="articleBody"> |
| |
| <!-- Note pointing readers to the Colab notebook and the local download section. The badge alt text names the link's action instead of repeating the image URL. --> |
| <div class="sphx-glr-download-link-note admonition note"> |
| <p class="admonition-title">Note</p> |
| <p>This tutorial can be used interactively with Google Colab! You can also click |
| <a class="reference internal" href="#sphx-glr-download-how-to-optimize-operators-opt-conv-cuda-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p> |
| <a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/854257a66df713b1f3f82eb3577f95e3/opt_conv_cuda.ipynb"><img alt="Open this tutorial in Google Colab" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300" /></a> |
| </div> |
| <div class="sphx-glr-example-title section" id="how-to-optimize-convolution-on-gpu"> |
| <span id="opt-conv-gpu"></span><span id="sphx-glr-how-to-optimize-operators-opt-conv-cuda-py"></span><h1>How to optimize convolution on GPU<a class="headerlink" href="#how-to-optimize-convolution-on-gpu" title="Permalink to this headline">¶</a></h1> |
| <p><strong>Author</strong>: <a class="reference external" href="https://homes.cs.washington.edu/~haichen/">Haichen Shen</a></p> |
| <p>In this tutorial, we will demonstrate how to write a high performance |
| convolution implementation in TVM. We use square size input tensors and filters |
| as an example, and assume the input to convolution has a large batch. In this |
| example, we use a different layout to store the data in order to achieve better |
| data locality. The buffer layout is HWCN, which stands for height, width, |
| channel, batch.</p> |
| <div class="section" id="preparation-and-algorithm"> |
| <h2>Preparation and Algorithm<a class="headerlink" href="#preparation-and-algorithm" title="Permalink to this headline">¶</a></h2> |
| <p>We use fixed-size input tensors with 256 channels and 14 x 14 |
| spatial dimensions. The batch size is 256. The convolution has 512 filters |
| of size 3 x 3. We use stride size 1 and padding size 1 for the |
| convolution. The following code defines the convolution algorithm in TVM.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> |
| <span class="kn">import</span> <span class="nn">tvm</span> |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span> |
| |
| <span class="c1"># The sizes of inputs and filters</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a> <span class="o">=</span> <span class="mi">256</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a> <span class="o">=</span> <span class="mi">256</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a> <span class="o">=</span> <span class="mi">512</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">=</span> <span class="mi">14</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a> <span class="o">=</span> <span class="mi">3</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">=</span> <span class="mi">1</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">=</span> <span class="mi">1</span> |
| |
| <span class="c1"># Algorithm</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"W"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a> <span class="o">=</span> <span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">)</span> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <span class="mi">1</span> |
| <span class="c1"># Pad input</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span> |
| <span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> |
| <span class="k">lambda</span> <span class="n">yy</span><span class="p">,</span> <span class="n">xx</span><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">nn</span><span class="p">:</span> <a href="../../reference/api/python/tir.html#tvm.tir.if_then_else" title="tvm.tir.if_then_else" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">if_then_else</span></a><span class="p">(</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.all" title="tvm.tir.all" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">all</span></a><span class="p">(</span><span class="n">yy</span> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">yy</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">),</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><span class="n">yy</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad</span></a><span class="p">,</span> <span class="n">cc</span><span class="p">,</span> <span class="n">nn</span><span class="p">],</span> |
| <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">const</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">),</span> |
| <span class="p">),</span> |
| <span class="n">name</span><span class="o">=</span><span class="s2">"Apad"</span><span class="p">,</span> |
| <span class="p">)</span> |
| <span class="c1"># Create reduction variables</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"rc"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"ry"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"rx"</span><span class="p">)</span> |
| <span class="c1"># Compute the convolution</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span> |
| <span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> |
| <span class="k">lambda</span> <span class="n">yy</span><span class="p">,</span> <span class="n">xx</span><span class="p">,</span> <span class="n">ff</span><span class="p">,</span> <span class="n">nn</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">[</span><span class="n">yy</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <span class="n">xx</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">nn</span><span class="p">]</span> <span class="o">*</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class 
sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">ff</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">]</span> |
| <span class="p">),</span> |
| <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">,</span> |
| <span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="memory-hierarchy"> |
| <h2>Memory Hierarchy<a class="headerlink" href="#memory-hierarchy" title="Permalink to this headline">¶</a></h2> |
| <p>We first specify the memory hierarchy for buffers. The figure below shows the |
| GPU memory hierarchy. One important difference from the CPU memory hierarchy is |
| that the GPU provides a cache buffer called shared memory, which is managed by |
| programmers. Maximizing data reuse in the shared memory is therefore |
| critical to achieving high performance in GPU kernels.</p> |
| <a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/gpu_memory_hierarchy.png"><img alt="Diagram of the GPU memory hierarchy" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/gpu_memory_hierarchy.png" style="width: 271px; height: 319px;" /></a> |
| <p>In this example, we load both Apad and W into buffers AA and WW, which are |
| stored in the shared memory. These buffers will later be shared by all |
| threads within the same thread block to compute the convolution. Each thread |
| then loads its own part from the shared buffers into its local registers, AL and |
| WL. BL is a local cache of output B, which is also stored in the thread-local |
| registers.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Designate the memory hierarchy</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_inline</span><span class="p">()</span> <span class="c1"># compute Apad inline</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">,</span> <span class="s2">"shared"</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <span class="s2">"shared"</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">,</span> <span class="s2">"local"</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">,</span> <span class="s2">"local"</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">])</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_write" title="tvm.te.Schedule.cache_write" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_write</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">,</span> <span class="s2">"local"</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="blocking"> |
| <h2>Blocking<a class="headerlink" href="#blocking" title="Permalink to this headline">¶</a></h2> |
| <p>The following code splits the workload into thread blocks and individual |
| threads. We follow the blocking scheme in the matrix multiply. As shown in the |
| figure below, given a pixel coordinate (y, x), a thread block is responsible |
| for computing a region of block_factor x block_factor (64 x 64) for output |
| channels and batch. Due to the limit of shared memory space, we only load step |
| x block_factor (8 x 64) data from Apad and W each time to buffers in the |
| shared memory.</p> |
| <a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_blocking.png"><img alt="Diagram of the blocking scheme for the GPU convolution" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_blocking.png" style="width: 317px; height: 308px;" /></a> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># tile consts</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tile</span></a> <span class="o">=</span> <span class="mi">8</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> <span class="o">=</span> <span class="mi">8</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a> <span class="o">=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tile</span></a> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">step</span></a> <span class="o">=</span> <span class="mi">8</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a> <span class="o">=</span> <span class="mi">2</span> |
| |
| <span class="c1"># Get the GPU thread indices</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.x"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.y"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.z"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">),</span> <span class="s2">"threadIdx.x"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">),</span> <span class="s2">"threadIdx.y"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_xz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">),</span> <span class="s2">"vthread"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"vx"</span><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_yz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">),</span> <span class="s2">"vthread"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"vy"</span><span class="p">)</span> |
| |
| <span class="c1"># Split the workloads</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wi</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_factor</span></a><span class="p">)</span> |
| |
| <span class="c1"># Bind the iteration variables to GPU thread indices</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a><span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="virtual-thread-split"> |
| <h2>Virtual Thread Split<a class="headerlink" href="#virtual-thread-split" title="Permalink to this headline">¶</a></h2> |
| <p>We further split the workload from a thread block to individual threads. To |
| avoid <em>memory bank conflict</em>, we use virtual threads to split the area into 4 |
| parts, and then tile into 8x8 grids. Therefore, as shown in the figure below, |
| each thread computes 4 strided grids, where the size of each grid is 4 x 4.</p> |
| <a class="reference internal image-reference" href="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_vthread.png"><img alt="Diagram of the virtual thread split within a thread block" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/conv_gpu_vthread.png" style="width: 268px; height: 188px;" /></a> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">)</span> <span class="c1"># virtual thread split</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">vthread</span></a><span class="p">)</span> <span class="c1"># virtual thread split</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">by</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a 
href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span> |
| |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tyz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_yz</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">txz</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_xz</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="cooperative-fetching"> |
| <h2>Cooperative Fetching<a class="headerlink" href="#cooperative-fetching" title="Permalink to this headline">¶</a></h2> |
| <p>As mentioned before, at each time step we need to transfer step x block_factor
| data from GPU global memory to shared memory. In order to reduce the memory
| transfer per thread, the following code lets threads in the same thread block
| cooperatively fetch dependent data from global memory.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Schedule BL local write</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rco</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">step</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rco</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ry</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span> |
| |
| <span class="c1"># Attach computation to iteration variables</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rx</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WL</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">BL</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rci</span></a><span class="p">)</span> |
| |
| <span class="c1"># Schedule for A's shared memory load</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">_</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AA</span></a><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ni</span></a><span class="p">)</span> <span class="c1"># vectorize memory load</span> |
| |
| <span class="c1"># Schedule for W's shared memory load</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">_</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xi</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WW</span></a><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">fi</span></a><span class="p">)</span> <span class="c1"># vectorize memory load</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="generate-cuda-kernel"> |
| <h2>Generate CUDA Kernel<a class="headerlink" href="#generate-cuda-kernel" title="Permalink to this headline">¶</a></h2> |
| <p>Finally, we use TVM to generate and compile the CUDA kernel and evaluate the |
| latency of the convolution.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">func</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"cuda"</span><span class="p">)</span> |
| <span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">cuda</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> |
| <span class="n">a_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span> |
| <span class="n">w_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span> |
| <span class="n">a</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">a_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span> |
| <span class="n">w</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">w_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span> |
| <span class="n">b</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channel</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch</span></a><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">),</span> <span class="n">dev</span><span class="p">)</span> |
| <span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <a href="../../reference/api/python/runtime.html#tvm.runtime.Module.time_evaluator" title="tvm.runtime.Module.time_evaluator" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-method"><span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span></a><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Convolution: </span><span class="si">%f</span><span class="s2"> ms"</span> <span class="o">%</span> <span class="p">(</span><span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span> <span class="o">*</span> <span class="mf">1e3</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Convolution: 40.042495 ms |
| </pre></div> |
| </div> |
| <div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-optimize-operators-opt-conv-cuda-py"> |
| <div class="sphx-glr-download sphx-glr-download-python docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/3c5c85c3954f3110f16ca084e286f03a/opt_conv_cuda.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">opt_conv_cuda.py</span></code></a></p> |
| </div> |
| <div class="sphx-glr-download sphx-glr-download-jupyter docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/854257a66df713b1f3f82eb3577f95e3/opt_conv_cuda.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">opt_conv_cuda.ipynb</span></code></a></p> |
| </div> |
| </div> |
| <p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p> |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| </div> |
| |
| |
| <!-- Page footer: next/previous tutorial links, the back-to-top control, and the copyright notices. --> |
| <footer> |
| |
| <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> |
| |
| <a href="opt_conv_tensorcore.html" class="btn btn-neutral float-right" title="How to optimize convolution using TensorCores" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a> |
| |
| |
| <a href="opt_gemm.html" class="btn btn-neutral float-left" title="How to optimize GEMM on CPU" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a> |
| |
| </div> |
| |
| <div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="Back to top"/> </div> |
| <section class="footerSec"> |
| <div class="footerHeader"> |
| <div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row"> |
| <div class="copywrite d-flex align-items-center"> |
| <h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div> |
| <div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div> |
| </div> |
| |
| </section> |
| </footer> |
| </div> |
| </div> |
| |
| </section> |
| |
| </div> |
| |
| |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script> |
| <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script> |
| |
| <!-- Navigation bootstrap: jQuery's document-ready callback calls SphinxRtdTheme.Navigation.enable(true). The script is kept inside the body element, which is closed exactly once by the closing body tag at the end of the file. --> |
| <script type="text/javascript"> |
| jQuery(function () { |
| SphinxRtdTheme.Navigation.enable(true); |
| }); |
| </script> |
| |
| |
| |
| |
| <!-- Theme Analytics: Google Analytics loader. The immediately-invoked function stores the name 'ga' in window.GoogleAnalyticsObject, defines window.ga (unless it already exists) as a function that queues its arguments on ga.q, records the current time in ga.l, and inserts an async script element for https://www.google-analytics.com/analytics.js before the first script element of the document. The two ga() calls then queue creation of tracker UA-75982049-2 and a pageview hit. NOTE(review): analytics.js with a UA- property ID looks like the legacy Universal Analytics integration; confirm whether it still reports data. --> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-75982049-2', 'auto'); |
| ga('send', 'pageview'); |
| </script> |
| |
| |
| |
| |
| </body> |
| </html> |