blob: 4547f1c1bdba6472343c80259850b98beaf0f962 [file] [log] [blame]
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>How to optimize convolution using TensorCores &mdash; tvm 0.17.dev0 documentation</title>
<!-- Stylesheets. Order matters: for rules of equal specificity the sheet linked
     later wins, so each sheet is linked exactly once, at its effective position. -->
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<link rel="stylesheet" href="../../_static/css/theme.css">
<link rel="stylesheet" href="../../_static/sg_gallery.css">
<link rel="stylesheet" href="../../_static/sg_gallery-binder.css">
<link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css">
<link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css">
<link rel="stylesheet" href="../../_static/pygments.css">
<link rel="stylesheet" href="../../_static/css/tlcpack_theme.css">
<link rel="shortcut icon" href="../../_static/tvm-logo-square.png">
<!-- Scripts, in load order. documentation_options.js is included once: a second
     copy would re-run it and duplicate the element id "documentation_options". -->
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script src="../../_static/js/theme.js"></script>
<script src="../../_static/js/tlcpack_theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html">
<link rel="search" title="Search" href="../../search.html">
<link rel="next" title="Auto-Tune with Templates and AutoTVM" href="../tune_with_autotvm/index.html">
<link rel="prev" title="How to optimize convolution on GPU" href="opt_conv_cuda.html">
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<header class="header">
<!-- Site header: logo link, primary site links, and the "ASF" link list.
     The ASF list is present twice: once inside #headMenu (.responsivetlcdropdown)
     and once in .tlcDropdown. Keep both copies in sync when editing links. -->
<div class="innercontainer">
<div class="headerInner d-flex justify-content-between align-items-center">
<div class="headerLogo">
<a href="https://tvm.apache.org/"><img src="https://tvm.apache.org/assets/images/logo.svg" alt="Apache TVM home"></a>
</div>
<div id="headMenu" class="headerNav">
<button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button>
<ul class="nav">
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/community">Community</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/download">Download</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/vta">VTA</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/blog">Blog</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvm.apache.org/docs">Docs</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://tvmconf.org">Conference</a>
</li>
<li class="nav-item">
<a class="nav-link" href="https://github.com/apache/tvm/">Github</a>
</li>
</ul>
<div class="responsivetlcdropdown">
<button type="button" class="btn-link">
ASF
</button>
<ul>
<li>
<a href="https://apache.org/">Apache Homepage</a>
</li>
<li>
<a href="https://www.apache.org/licenses/">License</a>
</li>
<li>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
</li>
<li>
<a href="https://www.apache.org/security/">Security</a>
</li>
<li>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</li>
<li>
<a href="https://www.apache.org/events/current-event">Events</a>
</li>
</ul>
</div>
</div>
<div class="responsiveMenuIcon">
<button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu"></button>
</div>
<div class="tlcDropdown">
<div class="dropdown">
<button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">
ASF
</button>
<div class="dropdown-menu dropdown-menu-right">
<ul>
<li>
<a href="https://apache.org/">Apache Homepage</a>
</li>
<li>
<a href="https://www.apache.org/licenses/">License</a>
</li>
<li>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
</li>
<li>
<a href="https://www.apache.org/security/">Security</a>
</li>
<li>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</li>
<li>
<a href="https://www.apache.org/events/current-event">Events</a>
</li>
</ul>
</div>
</div>
</div>
</div>
</div>
</header>
<nav data-toggle="wy-nav-shift" class="wy-nav-side fixed" aria-label="Sidebar">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html">
<img src="../../_static/tvm-logo-small.png" class="logo" alt="TVM documentation home">
</a>
<!-- Version selector: the hidden checkbox is toggled by clicking the label below.
     NOTE(review): a <div> is not permitted content of <label>, and the focusable div
     does not toggle the checkbox from the keyboard unless a theme script handles it;
     consider a native control. -->
<input type="checkbox" class="version-toggle-box" hidden id="version-toggle">
<label for="version-toggle" class="version-toggle-label">
<div tabindex="0" class="version version-selector version-selector-show">
0.17.dev0 <span class="chevron versions-hidden"><svg aria-hidden="true" focusable="false" fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg aria-hidden="true" focusable="false" fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span>
</div>
</label>
<!-- Labelled "Versions" so it is distinguishable from the "main navigation" menu below. -->
<div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Versions">
<p class="caption" role="heading"><span class="caption-text">Versions</span></p>
<!-- NOTE(review): the versioned hrefs ("v0.8.0/", ...) are page-relative, so from this
     page (two levels below the docs root) they resolve under how_to/optimize_operators/;
     "/" points at the site root. Confirm they are rewritten at runtime, or fix in the template. -->
<ol style="text-align: left">
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li>
</ol>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="index.html">Optimize Tensor Operators</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="opt_gemm.html">How to optimize GEMM on CPU</a></li>
<li class="toctree-l3"><a class="reference internal" href="opt_conv_cuda.html">How to optimize convolution on GPU</a></li>
<li class="toctree-l3 current"><a class="current reference internal" href="#" aria-current="page">How to optimize convolution using TensorCores</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#tensorcore-introduction">TensorCore Introduction</a></li>
<li class="toctree-l4"><a class="reference internal" href="#preparation-and-algorithm">Preparation and Algorithm</a></li>
<li class="toctree-l4"><a class="reference internal" href="#memory-scope">Memory Scope</a></li>
<li class="toctree-l4"><a class="reference internal" href="#define-tensor-intrinsic">Define Tensor Intrinsic</a></li>
<li class="toctree-l4"><a class="reference internal" href="#scheduling-the-computation">Scheduling the Computation</a></li>
<li class="toctree-l4"><a class="reference internal" href="#lowering-computation-to-intrinsics">Lowering Computation to Intrinsics</a></li>
<li class="toctree-l4"><a class="reference internal" href="#generate-cuda-kernel">Generate CUDA Kernel</a></li>
<li class="toctree-l4"><a class="reference internal" href="#summary">Summary</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top">
<!-- Top bar containing a toggle container and a "Table of Contents" label.
     NOTE(review): .togglemenu has no content in this markup; presumably it is filled
     or styled by the theme scripts. Confirm before removing it. -->
<div class="togglemenu">
</div>
<div class="nav-content">
<!-- tvm -->
Table of Contents
</div>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Breadcrumbs">
<!-- Breadcrumb trail. The ">" separators are decorative, so they are hidden from
     assistive technology, and the last item is marked as the current page. -->
<ul class="wy-breadcrumbs">
<li><a href="../../index.html">Docs</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li>
<li><a href="../index.html">How To Guides</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li>
<li><a href="index.html">Optimize Tensor Operators</a> <span class="br-arrow" aria-hidden="true">&gt;</span></li>
<li aria-current="page">How to optimize convolution using TensorCores</li>
<li class="wy-breadcrumbs-aside">
<!-- NOTE(review): this page is generated by sphinx-gallery from a .py source
     (see the sphx-glr ids), so the .rst path below may not exist in the repository.
     Verify the edit link target. -->
<a href="https://github.com/apache/tvm/edit/main/docs/how_to/optimize_operators/opt_conv_tensorcore.rst" class="fa fa-github"> Edit on GitHub</a>
</li>
</ul>
<hr>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<!-- Download/Colab note. Link text and image alt describe their destinations,
     and width is an integer pixel count, as the img width attribute requires. -->
<p class="admonition-title">Note</p>
<p>This tutorial can be used interactively with Google Colab! You can also
<a class="reference internal" href="#sphx-glr-download-how-to-optimize-operators-opt-conv-tensorcore-py"><span class="std std-ref">download the Jupyter notebook</span></a> to run it locally.</p>
<a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/7455981870c23c8c76482dedf33d8a42/opt_conv_tensorcore.ipynb"><img alt="Open in Colab" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300"></a>
</div>
<div class="sphx-glr-example-title section" id="how-to-optimize-convolution-using-tensorcores">
<span id="opt-conv-tensorcore"></span><span id="sphx-glr-how-to-optimize-operators-opt-conv-tensorcore-py"></span><h1>How to optimize convolution using TensorCores<a class="headerlink" href="#how-to-optimize-convolution-using-tensorcores" title="Permalink to this headline"></a></h1>
<p><strong>Author</strong>: <a class="reference external" href="https://github.com/Hzfengsy">Siyuan Feng</a></p>
<p>In this tutorial, we will demonstrate how to write a high-performance convolution
schedule using TensorCores in TVM. In this example, we assume the input to the
convolution has a large batch size. We strongly recommend working through the <a class="reference internal" href="opt_conv_cuda.html#opt-conv-gpu"><span class="std std-ref">How to optimize convolution on GPU</span></a> tutorial first.</p>
<div class="section" id="tensorcore-introduction">
<h2>TensorCore Introduction<a class="headerlink" href="#tensorcore-introduction" title="Permalink to this headline"></a></h2>
<p>Each Tensor Core provides a 4x4x4 matrix processing array that performs the operation
<code class="code docutils literal notranslate"><span class="pre">D</span> <span class="pre">=</span> <span class="pre">A</span> <span class="pre">*</span> <span class="pre">B</span> <span class="pre">+</span> <span class="pre">C</span></code>, where A, B, C and D are 4x4 matrices.
The matrix multiplication inputs A and B are FP16 matrices, while the accumulation
matrices C and D may be FP16 or FP32 matrices.</p>
<p>However, CUDA programmers can only use the warp-level primitive
<code class="code docutils literal notranslate"><span class="pre">wmma::mma_sync(acc_frag,</span> <span class="pre">a_frag,</span> <span class="pre">b_frag,</span> <span class="pre">acc_frag)</span></code> to perform
a 16x16x16 half-precision matrix multiplication on Tensor Cores. Before invoking
the matrix multiplication, programmers must explicitly load data from memory into registers
with the primitive <code class="code docutils literal notranslate"><span class="pre">wmma::load_matrix_sync</span></code>. The NVCC compiler translates
that primitive into multiple memory load instructions. At run time, every thread loads
16 elements from matrix A and 16 elements from matrix B.</p>
</div>
<div class="section" id="preparation-and-algorithm">
<h2>Preparation and Algorithm<a class="headerlink" href="#preparation-and-algorithm" title="Permalink to this headline"></a></h2>
<p>We use a fixed size for the input tensor: 256 channels and 14 x 14 spatial dimensions.
The batch size is 256. The convolution uses 512 filters of size 3 x 3, with a stride
of 1 and a padding of 1. In this example, we use the NHWCnc memory layout.
The following code defines the convolution algorithm in TVM.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tvm</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">tvm.contrib</span> <span class="kn">import</span> <span class="n">nvcc</span>
<span class="c1"># The sizes of inputs and filters</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch_size</span></a> <span class="o">=</span> <span class="mi">256</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">height</span></a> <span class="o">=</span> <span class="mi">14</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">width</span></a> <span class="o">=</span> <span class="mi">14</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">=</span> <span class="mi">256</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channels</span></a> <span class="o">=</span> <span class="mi">512</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_h</span></a> <span class="o">=</span> <span class="mi">3</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_w</span></a> <span class="o">=</span> <span class="mi">3</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_h</span></a> <span class="o">=</span> <span class="mi">1</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_w</span></a> <span class="o">=</span> <span class="mi">1</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride_h</span></a> <span class="o">=</span> <span class="mi">1</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride_w</span></a> <span class="o">=</span> <span class="mi">1</span>
<span class="c1"># TensorCore shape</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a> <span class="o">=</span> <span class="mi">16</span>
<span class="k">assert</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch_size</span></a> <span class="o">%</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a> <span class="o">==</span> <span class="mi">0</span>
<span class="k">assert</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">%</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a> <span class="o">==</span> <span class="mi">0</span>
<span class="k">assert</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channels</span></a> <span class="o">%</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a> <span class="o">==</span> <span class="mi">0</span>
<span class="c1"># Input feature map: (N, H, W, IC, n, ic)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">data_shape</span></a> <span class="o">=</span> <span class="p">(</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch_size</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">height</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">width</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Kernel: (H, W, IC, OC, ic, oc)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_shape</span></a> <span class="o">=</span> <span class="p">(</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_h</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_w</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Output feature map: (N, H, W, OC, n, oc)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">output_shape</span></a> <span class="o">=</span> <span class="p">(</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch_size</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">height</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">width</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Reduction axes</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_h</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;kh&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_w</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;kw&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;ic&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;ii&quot;</span><span class="p">)</span>
<span class="c1"># Algorithm</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">data_shape</span></a><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_shape</span></a><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;W&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span>
<span class="p">(</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">batch_size</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">height</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_h</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">width</span></a> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_w</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_channels</span></a> <span class="o">//</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_size</span></a><span class="p">,</span>
<span class="p">),</span>
<span class="k">lambda</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">:</span> <a href="../../reference/api/python/tir.html#tvm.tir.if_then_else" title="tvm.tir.if_then_else" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">if_then_else</span></a><span class="p">(</span>
<a href="../../reference/api/python/tir.html#tvm.tir.all" title="tvm.tir.all" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">all</span></a><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_h</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_h</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">height</span></a><span class="p">,</span> <span class="n">w</span> <span class="o">&gt;=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_w</span></a><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins 
sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_w</span></a> <span class="o">&lt;</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">width</span></a><span class="p">),</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_h</span></a><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pad_w</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">],</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">const</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="s2">&quot;float16&quot;</span><span class="p">),</span>
<span class="p">),</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;Apad&quot;</span><span class="p">,</span>
<span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">output_shape</span></a><span class="p">,</span>
<span class="k">lambda</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oo</span></a><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride_h</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">,</span> <span class="n">w</span> <span class="o">*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">stride_w</span></a> <span class="o">+</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a><span class="p">,</span> <a 
href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="o">*</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oo</span></a><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">),</span>
<span class="n">axis</span><span class="o">=</span><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">],</span>
<span class="p">),</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;Conv&quot;</span><span class="p">,</span>
<span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">Conv</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_inline</span><span class="p">()</span>
</pre></div>
</div>
</div>
<div class="section" id="memory-scope">
<h2>Memory Scope<a class="headerlink" href="#memory-scope" title="Permalink to this headline"></a></h2>
<p>In a traditional GPU schedule, there are global, shared, and local memory scopes.
To support TensorCores, we add three more special memory scopes: <code class="code docutils literal notranslate"><span class="pre">wmma.matrix_a</span></code>,
<code class="code docutils literal notranslate"><span class="pre">wmma.matrix_b</span></code> and <code class="code docutils literal notranslate"><span class="pre">wmma.accumulator</span></code>. On hardware, all fragment scopes
are stored at the on-chip register level, the same place as local memory.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Designate the memory hierarchy</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Apad</span></a><span class="p">,</span> <span class="s2">&quot;shared&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <span class="s2">&quot;shared&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AF</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">,</span> <span class="s2">&quot;wmma.matrix_a&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WF</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">,</span> <span class="s2">&quot;wmma.matrix_b&quot;</span><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">])</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_write" title="tvm.te.Schedule.cache_write" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_write</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">,</span> <span class="s2">&quot;wmma.accumulator&quot;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="define-tensor-intrinsic">
<h2>Define Tensor Intrinsic<a class="headerlink" href="#define-tensor-intrinsic" title="Permalink to this headline"></a></h2>
<p>A TensorCore operation is a special hardware instruction, so we can use tensorize
to replace a unit of computation with the TensorCore instruction. The first step is
to define the tensor intrinsics.</p>
<p>There are four basic operations in TensorCore: <code class="code docutils literal notranslate"><span class="pre">fill_fragment</span></code>, <code class="code docutils literal notranslate"><span class="pre">load_matrix</span></code>,
<code class="code docutils literal notranslate"><span class="pre">mma_sync</span></code> and <code class="code docutils literal notranslate"><span class="pre">store_matrix</span></code>. Since <code class="code docutils literal notranslate"><span class="pre">fill_fragment</span></code> and <code class="code docutils literal notranslate"><span class="pre">mma_sync</span></code>
are both used in matrix multiplication, we only need to write the following three intrinsics.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">intrin_wmma_load_matrix</span><span class="p">(</span><span class="n">scope</span><span class="p">):</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <span class="mi">16</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<span class="n">BA</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;shared&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="k">lambda</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;C&quot;</span><span class="p">)</span>
<span class="n">BC</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">C</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="n">scope</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">intrin_func</span><span class="p">(</span><span class="n">ins</span><span class="p">,</span> <span class="n">outs</span><span class="p">):</span>
<span class="n">ib</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">ir_builder</span><span class="o">.</span><span class="n">create</span><span class="p">()</span>
<span class="n">BA</span> <span class="o">=</span> <span class="n">ins</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">BC</span> <span class="o">=</span> <span class="n">outs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">ib</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span>
<a href="../../reference/api/python/tir.html#tvm.tir.call_intrin" title="tvm.tir.call_intrin" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_intrin</span></a><span class="p">(</span>
<span class="s2">&quot;handle&quot;</span><span class="p">,</span>
<span class="s2">&quot;tir.tvm_load_matrix_sync&quot;</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">BA</span><span class="o">.</span><span class="n">access_ptr</span><span class="p">(</span><span class="s2">&quot;r&quot;</span><span class="p">),</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<span class="s2">&quot;row_major&quot;</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">ib</span><span class="o">.</span><span class="n">get</span><span class="p">()</span>
<span class="k">return</span> <a href="../../reference/api/python/te.html#tvm.te.decl_tensor_intrin" title="tvm.te.decl_tensor_intrin" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">decl_tensor_intrin</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">,</span> <span class="n">intrin_func</span><span class="p">,</span> <span class="n">binds</span><span class="o">=</span><span class="p">{</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">:</span> <span class="n">BA</span><span class="p">,</span> <span class="n">C</span><span class="p">:</span> <span class="n">BC</span><span class="p">})</span>
<span class="k">def</span> <span class="nf">intrin_wmma_gemm</span><span class="p">():</span><!-- intrin_wmma_gemm: declares a tensor intrinsic for one n x n x n WMMA multiply-accumulate (float16 operands A and B, products cast to "float" before accumulation). -->
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <span class="mi">16</span><!-- n = 16: edge length of every operand tile. -->
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;B&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;k&quot;</span><span class="p">)</span><!-- k: reduction axis over [0, n). -->
<span class="n">C</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span><!-- C[ii, jj] = sum over k of A[ii, k] * B[k, jj], with both factors cast to "float" first. -->
<span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span>
<span class="k">lambda</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">,</span> <span class="n">jj</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">,</span> <span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float&quot;</span><span class="p">)</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">jj</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float&quot;</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span>
<span class="n">name</span><span class="o">=</span><span class="s2">&quot;C&quot;</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">BA</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span><!-- BA, BB, BC: buffers for A, B and C in the wmma.matrix_a, wmma.matrix_b and wmma.accumulator scopes. offset_factor=256 (= n * n) requires each elem_offset to be a multiple of one n x n tile, which is why the intrinsic calls below pass elem_offset // 256. -->
<a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;BA&quot;</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;wmma.matrix_a&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span>
<span class="p">)</span>
<span class="n">BB</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span>
<span class="n">B</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">B</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;BB&quot;</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;wmma.matrix_b&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span>
<span class="p">)</span>
<span class="n">BC</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span>
<span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">C</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;BC&quot;</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;wmma.accumulator&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">intrin_func</span><span class="p">(</span><span class="n">ins</span><span class="p">,</span> <span class="n">outs</span><span class="p">):</span><!-- intrin_func: unpacks the bound input buffers (BA, BB) from ins and the output buffer (BC,) from outs, then builds the lowered statements. -->
<span class="n">BA</span><span class="p">,</span> <span class="n">BB</span> <span class="o">=</span> <span class="n">ins</span>
<span class="p">(</span><span class="n">BC</span><span class="p">,)</span> <span class="o">=</span> <span class="n">outs</span>
<span class="k">def</span> <span class="nf">init</span><span class="p">():</span><!-- init: emits tir.tvm_fill_fragment to fill the BC tile with 0.0 (the reduction reset). -->
<span class="n">ib</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">ir_builder</span><span class="o">.</span><span class="n">create</span><span class="p">()</span>
<span class="n">ib</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span>
<a href="../../reference/api/python/tir.html#tvm.tir.call_intrin" title="tvm.tir.call_intrin" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_intrin</span></a><span class="p">(</span>
<span class="s2">&quot;handle&quot;</span><span class="p">,</span> <span class="s2">&quot;tir.tvm_fill_fragment&quot;</span><span class="p">,</span> <span class="n">BC</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <span class="n">BC</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span> <span class="mf">0.0</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">ib</span><span class="o">.</span><span class="n">get</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">update</span><span class="p">():</span><!-- update: emits tir.tvm_mma_sync with BC passed as both the first (destination) and last (accumulator input) fragment, presumably accumulating BA x BB into BC in place to match the te.sum definition of C. -->
<span class="n">ib</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">ir_builder</span><span class="o">.</span><span class="n">create</span><span class="p">()</span>
<span class="n">ib</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span>
<a href="../../reference/api/python/tir.html#tvm.tir.call_intrin" title="tvm.tir.call_intrin" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_intrin</span></a><span class="p">(</span>
<span class="s2">&quot;handle&quot;</span><span class="p">,</span>
<span class="s2">&quot;tir.tvm_mma_sync&quot;</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">BA</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="n">BA</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">BB</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="n">BB</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">ib</span><span class="o">.</span><span class="n">get</span><span class="p">()</span>
<span class="k">return</span> <span class="n">update</span><span class="p">(),</span> <span class="n">init</span><span class="p">(),</span> <span class="n">update</span><span class="p">()</span><!-- te.decl_tensor_intrin reads a 3-tuple as (body, reduce_reset, reduce_update): the full body and the update step are both the mma_sync call, and the reset is the zero fill. -->
<span class="k">return</span> <a href="../../reference/api/python/te.html#tvm.te.decl_tensor_intrin" title="tvm.te.decl_tensor_intrin" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">decl_tensor_intrin</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">,</span> <span class="n">intrin_func</span><span class="p">,</span> <span class="n">binds</span><span class="o">=</span><span class="p">{</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">:</span> <span class="n">BA</span><span class="p">,</span> <span class="n">B</span><span class="p">:</span> <span class="n">BB</span><span class="p">,</span> <span class="n">C</span><span class="p">:</span> <span class="n">BC</span><span class="p">})</span><!-- binds maps each tensor of the compute definition to its WMMA-scoped buffer. -->
<span class="k">def</span> <span class="nf">intrin_wmma_store_matrix</span><span class="p">():</span><!-- intrin_wmma_store_matrix: declares a tensor intrinsic that copies one n x n float32 tile from a wmma.accumulator buffer into a buffer in the "global" scope. -->
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <span class="mi">16</span>
<a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">BA</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span><!-- BA: source buffer in the wmma.accumulator scope; offset_factor=256 (= n * n) makes elem_offset a multiple of one tile, so elem_offset // 256 below is a whole tile index. -->
<a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;wmma.accumulator&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span>
<span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">),</span> <span class="k">lambda</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;C&quot;</span><span class="p">)</span><!-- C[i, j] = A[i, j]: the element-wise copy whose op (C.op) this intrinsic replaces. -->
<span class="n">BC</span> <span class="o">=</span> <a href="../../reference/api/python/tir.html#tvm.tir.decl_buffer" title="tvm.tir.decl_buffer" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">decl_buffer</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">C</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s2">&quot;global&quot;</span><span class="p">,</span> <span class="n">data_alignment</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">offset_factor</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span><!-- BC: destination buffer in the "global" scope. -->
<span class="k">def</span> <span class="nf">intrin_func</span><span class="p">(</span><span class="n">ins</span><span class="p">,</span> <span class="n">outs</span><span class="p">):</span><!-- intrin_func: takes the accumulator buffer from ins[0] and the destination buffer from outs[0], and emits a single store call. -->
<span class="n">ib</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">ir_builder</span><span class="o">.</span><span class="n">create</span><span class="p">()</span>
<span class="n">BA</span> <span class="o">=</span> <span class="n">ins</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">BC</span> <span class="o">=</span> <span class="n">outs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">ib</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span><!-- tir.tvm_store_matrix_sync arguments: source fragment BA.data, n, n, n, tile index of BA, write pointer into BC, n, "row_major". The second-to-last n is presumably the row stride of BC (TODO confirm against the intrinsic signature). -->
<a href="../../reference/api/python/tir.html#tvm.tir.call_intrin" title="tvm.tir.call_intrin" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_intrin</span></a><span class="p">(</span>
<span class="s2">&quot;handle&quot;</span><span class="p">,</span>
<span class="s2">&quot;tir.tvm_store_matrix_sync&quot;</span><span class="p">,</span>
<span class="n">BA</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<span class="n">BA</span><span class="o">.</span><span class="n">elem_offset</span> <span class="o">//</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">BC</span><span class="o">.</span><span class="n">access_ptr</span><span class="p">(</span><span class="s2">&quot;w&quot;</span><span class="p">),</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span>
<span class="s2">&quot;row_major&quot;</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">ib</span><span class="o">.</span><span class="n">get</span><span class="p">()</span>
<span class="k">return</span> <a href="../../reference/api/python/te.html#tvm.te.decl_tensor_intrin" title="tvm.te.decl_tensor_intrin" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">decl_tensor_intrin</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">,</span> <span class="n">intrin_func</span><span class="p">,</span> <span class="n">binds</span><span class="o">=</span><span class="p">{</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">:</span> <span class="n">BA</span><span class="p">,</span> <span class="n">C</span><span class="p">:</span> <span class="n">BC</span><span class="p">})</span>
</pre></div>
</div>
</div>
<div class="section" id="scheduling-the-computation">
<h2>Scheduling the Computation<a class="headerlink" href="#scheduling-the-computation" title="Permalink to this headline"></a></h2>
<p>To use TensorCores in TVM, we must schedule the computation into a specific structure
that matches the tensor intrinsic. As in traditional GPU programs, we can also use
shared memory to boost performance. If you have any questions about blocking and shared
memory, please refer to <a class="reference internal" href="opt_conv_cuda.html#opt-conv-gpu"><span class="std std-ref">How to optimize convolution on GPU</span></a>.</p>
<p>In this example, each block contains 4x2 warps (block_row_warps x block_col_warps), and each
warp calls 2x4 TensorCore instructions (warp_row_tiles x warp_col_tiles). Thus, the output shape
of each warp is 32x64 and each block outputs a 128x128 tile. Due to limited shared memory space,
we only load 2 blocks (2x128x128 tiles) at a time.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p><em>Warp-level Operation</em></p>
<p>Note that all TensorCore instructions are warp-level instructions, which means all 32 threads
in a warp must execute the instruction simultaneously. Setting the extent of threadIdx.x to 32 is
one of the easiest ways to achieve this. Then we can bind threadIdx.x to any loop except those
that contain TensorCore intrinsics directly or indirectly. Note also that this is not the only solution.
The only requirement is that all threads in a warp can call TensorCore instructions at the same time.</p>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Define tiling sizes</span><!-- Tiling parameters: warps per block (block_row_warps x block_col_warps) and 16x16 WMMA tiles per warp (warp_row_tiles x warp_col_tiles), so one block covers (4 * 2 * 16) x (2 * 4 * 16) = 128 x 128 outputs. -->
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_row_warps</span></a> <span class="o">=</span> <span class="mi">4</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_col_warps</span></a> <span class="o">=</span> <span class="mi">2</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_row_tiles</span></a> <span class="o">=</span> <span class="mi">2</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_col_tiles</span></a> <span class="o">=</span> <span class="mi">4</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_size</span></a> <span class="o">=</span> <span class="mi">32</span><!-- warp_size = 32 threads, matching the warp-level requirement described in the note above. -->
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">chunk</span></a> <span class="o">=</span> <span class="mi">2</span><!-- chunk: its role is not visible in this excerpt; presumably a tiling factor for the reduction axis (TODO confirm against its use in the schedule). -->
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.x&quot;</span><span class="p">)</span><!-- te.thread_axis handles for the CUDA grid dimensions blockIdx.x, blockIdx.y and blockIdx.z. -->
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.y&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;blockIdx.z&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;threadIdx.x&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;threadIdx.y&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_z</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">&quot;threadIdx.z&quot;</span><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ooc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">Conv</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_k</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">hc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">wc</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_k</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_z</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_row_tiles</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_row_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oci</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_col_tiles</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_j</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_col_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_k</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_j</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nci</span></a><span class="p">,</span> <a 
href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oci</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ooc</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_x</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_j</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_z</span></a><span class="p">)</span>
<span class="c1"># Schedule local computation</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oc</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnf</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oof</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">ConvF</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ko</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ki</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">chunk</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ko</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ki</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <a 
href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnf</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oof</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">)</span>
<span class="c1"># Move intermediate computation into each output compute tile</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AF</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WF</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">)</span>
<span class="c1"># Schedule for A&#39;s shared memory</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">h</span></a><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">AS</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xo</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_row_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yo</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xo</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_col_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">t</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nn</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">to</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ti</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">t</span></a><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_size</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_z</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ti</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span>
<span class="c1"># Schedule for W&#39;s shared memory</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">],</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kh</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kw</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ic</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oo</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">WS</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xo</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">o</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_row_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">yo</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">xo</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">block_col_warps</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">t</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ii</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">oo</span></a><span class="p">)</span>
<a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">to</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ti</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">t</span></a><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">warp_size</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_y</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ty</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_z</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">to</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">thread_x</span></a><span class="p">)</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WS</span></a><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ti</span></a><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/driver.html#tvm.lower" title="tvm.lower" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span># from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((16, 14, 14, 16, 16, 16), &quot;float16&quot;), W: T.Buffer((3, 3, 16, 32, 16, 16), &quot;float16&quot;), Conv: T.Buffer((16, 14, 14, 32, 16, 16), &quot;float32&quot;)):
T.func_attr({&quot;from_legacy_te_schedule&quot;: T.bool(True), &quot;tir.noalias&quot;: T.bool(True)})
blockIdx_z = T.launch_thread(&quot;blockIdx.z&quot;, 196)
Conv_wmma_accumulator = T.allocate([2048], &quot;float32&quot;, &quot;wmma.accumulator&quot;)
Apad_shared = T.allocate([12288], &quot;float16&quot;, &quot;shared&quot;)
W_shared = T.allocate([12288], &quot;float16&quot;, &quot;shared&quot;)
Apad_shared_wmma_matrix_a = T.allocate([512], &quot;float16&quot;, &quot;wmma.matrix_a&quot;)
W_shared_wmma_matrix_b = T.allocate([1024], &quot;float16&quot;, &quot;wmma.matrix_b&quot;)
blockIdx_x = T.launch_thread(&quot;blockIdx.x&quot;, 2)
blockIdx_y = T.launch_thread(&quot;blockIdx.y&quot;, 4)
threadIdx_y = T.launch_thread(&quot;threadIdx.y&quot;, 4)
threadIdx_z = T.launch_thread(&quot;threadIdx.z&quot;, 2)
Conv_wmma_accumulator_1 = T.Buffer((2048,), data=Conv_wmma_accumulator, scope=&quot;wmma.accumulator&quot;)
for n_c_init, o_c_init, nn_c_init, oo_c_init in T.grid(2, 4, 16, 16):
Conv_wmma_accumulator_1[n_c_init * 1024 + o_c_init * 256 + nn_c_init * 16 + oo_c_init] = T.float32(0)
for ic_outer, kh in T.grid(8, 3):
threadIdx_x = T.env_thread(&quot;threadIdx.x&quot;)
Apad_shared_1 = T.Buffer((12288,), &quot;float16&quot;, data=Apad_shared, scope=&quot;shared&quot;)
for ax2, ax3, ax4_ax5_fused_outer in T.grid(3, 2, 8):
cse_var_2: T.int32 = ax3 * 256
cse_var_1: T.int32 = ax4_ax5_fused_outer * 32
T.launch_thread(threadIdx_x, 32)
A_1 = T.Buffer((12845056,), &quot;float16&quot;, data=A.data)
Apad_shared_1[threadIdx_y * 3072 + threadIdx_z * 1536 + ax2 * 512 + cse_var_2 + cse_var_1 + threadIdx_x] = T.if_then_else(1 &lt;= blockIdx_z // 14 + kh and blockIdx_z // 14 + kh &lt; 15 and 1 &lt;= ax2 + blockIdx_z % 14 and ax2 + blockIdx_z % 14 &lt; 15, A_1[blockIdx_x * 6422528 + threadIdx_y * 1605632 + threadIdx_z * 802816 + kh * 57344 + blockIdx_z * 4096 + ax2 * 4096 + ic_outer * 512 + cse_var_2 + cse_var_1 + threadIdx_x - 61440], T.float16(0))
W_shared_1 = T.Buffer((12288,), &quot;float16&quot;, data=W_shared, scope=&quot;shared&quot;)
for ax1, ax2 in T.grid(3, 2):
T.launch_thread(threadIdx_x, 32)
W_1 = T.Buffer((1179648,), &quot;float16&quot;, data=W.data)
W_shared_1[ax1 * 4096 + ax2 * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8:ax1 * 4096 + ax2 * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8 + 8] = W_1[kh * 393216 + ax1 * 131072 + ic_outer * 16384 + ax2 * 8192 + blockIdx_y * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8:kh * 393216 + ax1 * 131072 + ic_outer * 16384 + ax2 * 8192 + blockIdx_y * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8 + 8]
for ic_inner, kw in T.grid(2, 3):
Apad_shared_wmma_matrix_a_1 = T.Buffer((512,), &quot;float16&quot;, data=Apad_shared_wmma_matrix_a, scope=&quot;wmma.matrix_a&quot;)
for ax0, ax4, ax5 in T.grid(2, 16, 16):
cse_var_3: T.int32 = ax4 * 16
Apad_shared_wmma_matrix_a_1[ax0 * 256 + cse_var_3 + ax5] = Apad_shared_1[threadIdx_y * 3072 + ax0 * 1536 + kw * 512 + ic_inner * 256 + cse_var_3 + ax5]
W_shared_wmma_matrix_b_1 = T.Buffer((1024,), &quot;float16&quot;, data=W_shared_wmma_matrix_b, scope=&quot;wmma.matrix_b&quot;)
for ax3, ax4, ax5 in T.grid(4, 16, 16):
cse_var_5: T.int32 = ax3 * 256
cse_var_4: T.int32 = ax4 * 16
W_shared_wmma_matrix_b_1[cse_var_5 + cse_var_4 + ax5] = W_shared_1[kw * 4096 + ic_inner * 2048 + threadIdx_z * 1024 + cse_var_5 + cse_var_4 + ax5]
for n_c, o_c, nn_c, oo_c, ii in T.grid(2, 4, 16, 16, 16):
cse_var_8: T.int32 = o_c * 256
cse_var_7: T.int32 = nn_c * 16
cse_var_6: T.int32 = n_c * 1024 + cse_var_8 + cse_var_7 + oo_c
Conv_wmma_accumulator_1[cse_var_6] = Conv_wmma_accumulator_1[cse_var_6] + T.Cast(&quot;float32&quot;, Apad_shared_wmma_matrix_a_1[n_c * 256 + cse_var_7 + ii]) * T.Cast(&quot;float32&quot;, W_shared_wmma_matrix_b_1[cse_var_8 + ii * 16 + oo_c])
for n_inner, o_inner, nn, oo in T.grid(2, 4, 16, 16):
cse_var_10: T.int32 = o_inner * 256
cse_var_9: T.int32 = nn * 16
Conv_1 = T.Buffer((25690112,), data=Conv.data)
Conv_1[blockIdx_x * 12845056 + threadIdx_y * 3211264 + n_inner * 1605632 + blockIdx_z * 8192 + blockIdx_y * 2048 + threadIdx_z * 1024 + cse_var_10 + cse_var_9 + oo] = Conv_wmma_accumulator_1[n_inner * 1024 + cse_var_10 + cse_var_9 + oo]
</pre></div>
</div>
</div>
<div class="section" id="lowering-computation-to-intrinsics">
<h2>Lowering Computation to Intrinsics<a class="headerlink" href="#lowering-computation-to-intrinsics" title="Permalink to this headline"></a></h2>
<p>The last phase is to lower the computation loops to TensorCore hardware intrinsics
by mapping the 2D convolution onto tensor intrinsics.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">AF</span></a><span class="p">]</span><span class="o">.</span><span class="n">tensorize</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">AF</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">intrin_wmma_load_matrix</span><span class="p">(</span><span class="s2">&quot;wmma.matrix_a&quot;</span><span class="p">))</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WF</span></a><span class="p">]</span><span class="o">.</span><span class="n">tensorize</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">WF</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">intrin_wmma_load_matrix</span><span class="p">(</span><span class="s2">&quot;wmma.matrix_b&quot;</span><span class="p">))</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">]</span><span class="o">.</span><span class="n">tensorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnc</span></a><span class="p">,</span> <span class="n">intrin_wmma_store_matrix</span><span class="p">())</span>
<a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ConvF</span></a><span class="p">]</span><span class="o">.</span><span class="n">tensorize</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">nnf</span></a><span class="p">,</span> <span class="n">intrin_wmma_gemm</span><span class="p">())</span>
<span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/driver.html#tvm.lower" title="tvm.lower" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span># from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((16, 14, 14, 16, 16, 16), &quot;float16&quot;), W: T.Buffer((3, 3, 16, 32, 16, 16), &quot;float16&quot;), Conv: T.Buffer((16, 14, 14, 32, 16, 16), &quot;float32&quot;)):
T.func_attr({&quot;from_legacy_te_schedule&quot;: T.bool(True), &quot;tir.noalias&quot;: T.bool(True)})
blockIdx_z = T.launch_thread(&quot;blockIdx.z&quot;, 196)
Conv_wmma_accumulator = T.allocate([2048], &quot;float32&quot;, &quot;wmma.accumulator&quot;)
Apad_shared = T.allocate([12288], &quot;float16&quot;, &quot;shared&quot;)
W_shared = T.allocate([12288], &quot;float16&quot;, &quot;shared&quot;)
Apad_shared_wmma_matrix_a = T.allocate([512], &quot;float16&quot;, &quot;wmma.matrix_a&quot;)
W_shared_wmma_matrix_b = T.allocate([1024], &quot;float16&quot;, &quot;wmma.matrix_b&quot;)
blockIdx_x = T.launch_thread(&quot;blockIdx.x&quot;, 2)
blockIdx_y = T.launch_thread(&quot;blockIdx.y&quot;, 4)
threadIdx_y = T.launch_thread(&quot;threadIdx.y&quot;, 4)
threadIdx_z = T.launch_thread(&quot;threadIdx.z&quot;, 2)
for n_c_init, o_c_init in T.grid(2, 4):
T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, n_c_init * 4 + o_c_init, T.float32(0))
for ic_outer, kh in T.grid(8, 3):
threadIdx_x = T.env_thread(&quot;threadIdx.x&quot;)
for ax2, ax3, ax4_ax5_fused_outer in T.grid(3, 2, 8):
cse_var_2: T.int32 = ax3 * 256
cse_var_1: T.int32 = ax4_ax5_fused_outer * 32
T.launch_thread(threadIdx_x, 32)
Apad_shared_1 = T.Buffer((12288,), &quot;float16&quot;, data=Apad_shared, scope=&quot;shared&quot;)
A_1 = T.Buffer((12845056,), &quot;float16&quot;, data=A.data)
Apad_shared_1[threadIdx_y * 3072 + threadIdx_z * 1536 + ax2 * 512 + cse_var_2 + cse_var_1 + threadIdx_x] = T.if_then_else(1 &lt;= blockIdx_z // 14 + kh and blockIdx_z // 14 + kh &lt; 15 and 1 &lt;= ax2 + blockIdx_z % 14 and ax2 + blockIdx_z % 14 &lt; 15, A_1[blockIdx_x * 6422528 + threadIdx_y * 1605632 + threadIdx_z * 802816 + kh * 57344 + blockIdx_z * 4096 + ax2 * 4096 + ic_outer * 512 + cse_var_2 + cse_var_1 + threadIdx_x - 61440], T.float16(0))
for ax1, ax2 in T.grid(3, 2):
T.launch_thread(threadIdx_x, 32)
W_shared_1 = T.Buffer((12288,), &quot;float16&quot;, data=W_shared, scope=&quot;shared&quot;)
W_1 = T.Buffer((1179648,), &quot;float16&quot;, data=W.data)
W_shared_1[ax1 * 4096 + ax2 * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8:ax1 * 4096 + ax2 * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8 + 8] = W_1[kh * 393216 + ax1 * 131072 + ic_outer * 16384 + ax2 * 8192 + blockIdx_y * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8:kh * 393216 + ax1 * 131072 + ic_outer * 16384 + ax2 * 8192 + blockIdx_y * 2048 + threadIdx_y * 512 + threadIdx_z * 256 + threadIdx_x * 8 + 8]
for ic_inner, kw in T.grid(2, 3):
for ax0 in range(2):
T.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, ax0, T.tvm_access_ptr(T.type_annotation(&quot;float16&quot;), Apad_shared, threadIdx_y * 3072 + ax0 * 1536 + kw * 512 + ic_inner * 256, 256, 1), 16, &quot;row_major&quot;)
for ax3 in range(4):
T.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, ax3, T.tvm_access_ptr(T.type_annotation(&quot;float16&quot;), W_shared, kw * 4096 + ic_inner * 2048 + threadIdx_z * 1024 + ax3 * 256, 256, 1), 16, &quot;row_major&quot;)
for n_c, o_c in T.grid(2, 4):
cse_var_3: T.int32 = n_c * 4 + o_c
T.tvm_mma_sync(Conv_wmma_accumulator, cse_var_3, Apad_shared_wmma_matrix_a, n_c, W_shared_wmma_matrix_b, o_c, Conv_wmma_accumulator, cse_var_3)
for n_inner, o_inner in T.grid(2, 4):
T.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, n_inner * 4 + o_inner, T.tvm_access_ptr(T.type_annotation(&quot;float32&quot;), Conv.data, blockIdx_x * 12845056 + threadIdx_y * 3211264 + n_inner * 1605632 + blockIdx_z * 8192 + blockIdx_y * 2048 + threadIdx_z * 1024 + o_inner * 256, 256, 2), 16, &quot;row_major&quot;)
</pre></div>
</div>
</div>
<div class="section" id="generate-cuda-kernel">
<h2>Generate CUDA Kernel<a class="headerlink" href="#generate-cuda-kernel" title="Permalink to this headline"></a></h2>
<p>Finally, we use TVM to generate and compile the CUDA kernel and evaluate the latency of the convolution.
Since TensorCores are only supported on NVIDIA GPUs with Compute Capability 7.0 or higher, this code may
not run on our build server.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">cuda</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <a href="../../reference/api/python/contrib.html#tvm.contrib.nvcc.have_tensorcore" title="tvm.contrib.nvcc.have_tensorcore" class="sphx-glr-backref-module-tvm-contrib-nvcc sphx-glr-backref-type-py-function"><span class="n">nvcc</span><span class="o">.</span><span class="n">have_tensorcore</span></a><span class="p">(</span><span class="n">dev</span><span class="o">.</span><span class="n">compute_version</span><span class="p">):</span>
<span class="k">with</span> <a href="../../reference/api/python/ir.html#tvm.transform.PassContext" title="tvm.transform.PassContext" class="sphx-glr-backref-module-tvm-transform sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tvm</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">PassContext</span></a><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;tir.UnrollLoop&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;auto_max_step&quot;</span><span class="p">:</span> <span class="mi">16</span><span class="p">}}):</span>
<span class="n">func</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span></a><span class="p">],</span> <span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
<span class="n">a_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">data_shape</span></a><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span>
<span class="n">w_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">kernel_shape</span></a><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">a_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">w_np</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">c</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">output_shape</span></a><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">Conv</span><span class="o">.</span><span class="n">dtype</span></a><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">evaluator</span> <span class="o">=</span> <a href="../../reference/api/python/runtime.html#tvm.runtime.Module.time_evaluator" title="tvm.runtime.Module.time_evaluator" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-method"><span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span></a><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;conv2d with tensor core: </span><span class="si">%f</span><span class="s2"> ms&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span> <span class="o">*</span> <span class="mf">1e3</span><span class="p">))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>conv2d with tensor core: 12.214825 ms
</pre></div>
</div>
</div>
<div class="section" id="summary">
<h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline"></a></h2>
<p>This tutorial demonstrates how TVM scheduling primitives can be used to
call TensorCores on specific GPUs.</p>
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-optimize-operators-opt-conv-tensorcore-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/7372db5919b5619bc34fde3434862bca/opt_conv_tensorcore.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">opt_conv_tensorcore.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/7455981870c23c8c76482dedf33d8a42/opt_conv_tensorcore.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">opt_conv_tensorcore.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../tune_with_autotvm/index.html" class="btn btn-neutral float-right" title="Auto-Tune with Templates and AutoTVM" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="opt_conv_cuda.html" class="btn btn-neutral float-left" title="How to optimize convolution on GPU" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
<div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="Back to top"/> </div>
<section class="footerSec">
<div class="footerHeader">
<div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row">
<div class="copywrite d-flex align-items-center">
<h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5>
</div>
</div>
</div>
<div>
<div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div>
</div>
</section>
</footer>
</div>
</div>
</section>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
<!-- Once the DOM is ready (jQuery ready handler), initialise the RTD theme's
     navigation; presumably the `true` argument enables sticky sidebar navigation,
     see theme.js to confirm. The document body is closed once, further below. -->
<script>
  jQuery(function () {
    SphinxRtdTheme.Navigation.enable(true);
  });
</script>
<!-- Theme Analytics -->
<!-- Google Analytics (analytics.js) loader snippet.
     NOTE(review): "UA-" property IDs belong to Universal Analytics, which Google
     has sunset; this tag may no longer record page views. Confirm and migrate to
     a GA4 measurement ID if analytics is still wanted. -->
<script>
// Loader: registers the global command-queue name ('ga'), defines a stub
// function that pushes every call's arguments onto ga.q until the real library
// loads, stamps the load time in ga.l, then creates an async <script> for
// analytics.js and inserts it before the first <script> element in the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue tracker creation for property UA-75982049-2 ('auto' config), then queue
// a pageview hit for the current page.
ga('create', 'UA-75982049-2', 'auto');
ga('send', 'pageview');
</script>
</body>
</html>