| |
| |
| |
| |
| |
| |
| <!DOCTYPE html> |
| <html class="writer-html5" lang="en" > |
| <head> |
| <meta charset="utf-8"> |
| |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| |
| <title>Tuning High Performance Convolution on NVIDIA GPUs — tvm 0.17.dev0 documentation</title> |
| |
| |
| |
| <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"> |
| <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" /> |
| |
| |
| |
| <link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/> |
| |
| |
| |
| |
| |
| |
| |
| <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> |
| <script src="../../_static/jquery.js"></script> |
| <script src="../../_static/underscore.js"></script> |
| <script src="../../_static/doctools.js"></script> |
| |
| <script type="text/javascript" src="../../_static/js/theme.js"></script> |
| |
| |
| <script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script> |
| <link rel="index" title="Index" href="../../genindex.html" /> |
| <link rel="search" title="Search" href="../../search.html" /> |
| <link rel="next" title="Auto-tuning a Convolutional Network for NVIDIA GPU" href="tune_relay_cuda.html" /> |
| <link rel="prev" title="Auto-Tune with Templates and AutoTVM" href="index.html" /> |
| </head> |
| |
| <body class="wy-body-for-nav"> |
| |
| |
| <div class="wy-grid-for-nav"> |
| |
| |
| <header class="header"> |
| <div class="innercontainer"> |
| <div class="headerInner d-flex justify-content-between align-items-center"> |
| <div class="headerLogo"> |
| <a href="https://tvm.apache.org/"><img src=https://tvm.apache.org/assets/images/logo.svg alt="logo"></a> |
| </div> |
| |
| <div id="headMenu" class="headerNav"> |
| <button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button> |
| <ul class="nav"> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/community>Community</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/download>Download</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/vta>VTA</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/blog>Blog</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/docs>Docs</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvmconf.org>Conference</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://github.com/apache/tvm/>Github</a> |
| </li> |
| </ul> |
| <div class="responsivetlcdropdown"> |
| <button type="button" class="btn-link"> |
| ASF |
| </button> |
| <ul> |
| <li> |
| <a href=https://apache.org/>Apache Homepage</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/licenses/>License</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/security/>Security</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/foundation/thanks.html>Thanks</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/events/current-event>Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| <div class="responsiveMenuIcon"> |
| <button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button> |
| </div> |
| |
| <div class="tlcDropdown"> |
| <div class="dropdown"> |
| <button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false"> |
| ASF |
| </button> |
| <div class="dropdown-menu dropdown-menu-right"> |
| <ul> |
| <li> |
| <a href=https://apache.org/>Apache Homepage</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/licenses/>License</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/security/>Security</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/foundation/thanks.html>Thanks</a> |
| </li> |
| <li> |
| <a href=https://www.apache.org/events/current-event>Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| </div> |
| </header> |
| |
| <nav data-toggle="wy-nav-shift" class="wy-nav-side fixed"> |
| <div class="wy-side-scroll"> |
| <div class="wy-side-nav-search" > |
| |
| |
| |
| <a href="../../index.html"> |
| |
| |
| |
| |
| <img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/> |
| |
| </a> |
| |
| |
| |
| |
| <input type="checkbox" class="version-toggle-box" hidden id="version-toggle"> |
| <label for="version-toggle" class="version-toggle-label"> |
| <div tabindex="0" class="version version-selector version-selector-show"> |
| 0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span> |
| </div> |
| </label> |
| <div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| <p class="caption" role="heading"><span class="caption-text">Versions</span></p> |
| <ol style="text-align: left"> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li> |
| |
| </ol> |
| </div> |
| |
| |
| |
| |
| <div role="search"> |
| <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get"> |
| <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" /> |
| <input type="hidden" name="check_keywords" value="yes" /> |
| <input type="hidden" name="area" value="default" /> |
| </form> |
| </div> |
| |
| |
| </div> |
| |
| |
| <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| |
| |
| |
| |
| |
| |
| <p class="caption" role="heading"><span class="caption-text">Getting Started</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">User Guide</span></p> |
| <ul class="current"> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li> |
| <li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current"> |
| <li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../optimize_operators/index.html">Optimize Tensor Operators</a></li> |
| <li class="toctree-l2 current"><a class="reference internal" href="index.html">Auto-Tune with Templates and AutoTVM</a><ul class="current"> |
| <li class="toctree-l3 current"><a class="current reference internal" href="#">Tuning High Performance Convolution on NVIDIA GPUs</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#install-dependencies">Install dependencies</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#step-1-define-the-search-space">Step 1: Define the search space</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#step-2-search-through-the-space">Step 2: Search through the space</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="tune_relay_cuda.html">Auto-tuning a Convolutional Network for NVIDIA GPU</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="tune_relay_x86.html">Auto-tuning a Convolutional Network for x86 CPU</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="tune_relay_arm.html">Auto-tuning a Convolutional Network for ARM CPU</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="tune_relay_mobile_gpu.html">Auto-tuning a Convolutional Network for Mobile GPU</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li> |
| </ul> |
| </li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li> |
| </ul> |
| |
| |
| |
| </div> |
| |
| </div> |
| </nav> |
| |
| <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> |
| |
| <nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top"> |
| |
| <div class="togglemenu"> |
| |
| </div> |
| <div class="nav-content"> |
| <!-- tvm --> |
| Table of Contents |
| </div> |
| |
| </nav> |
| |
| |
| <div class="wy-nav-content"> |
| |
| <div class="rst-content"> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <div role="navigation" aria-label="breadcrumbs navigation"> |
| |
| <ul class="wy-breadcrumbs"> |
| |
| <li><a href="../../index.html">Docs</a> <span class="br-arrow">></span></li> |
| |
| <li><a href="../index.html">How To Guides</a> <span class="br-arrow">></span></li> |
| |
| <li><a href="index.html">Auto-Tune with Templates and AutoTVM</a> <span class="br-arrow">></span></li> |
| |
| <li>Tuning High Performance Convolution on NVIDIA GPUs</li> |
| |
| |
| |
| |
| |
| |
| |
| |
| <li class="wy-breadcrumbs-aside"> |
| |
| |
| |
| <a href="https://github.com/apache/tvm/edit/main/docs/how_to/tune_with_autotvm/tune_conv2d_cuda.rst" class="fa fa-github"> Edit on GitHub</a> |
| |
| |
| |
| </li> |
| |
| </ul> |
| |
| |
| <hr/> |
| </div> |
| <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article"> |
| <div itemprop="articleBody"> |
| |
| <div class="sphx-glr-download-link-note admonition note"> |
| <p class="admonition-title">Note</p> |
| <p>This tutorial can be used interactively with Google Colab! You can also click |
| <a class="reference internal" href="#sphx-glr-download-how-to-tune-with-autotvm-tune-conv2d-cuda-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p> |
| <a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/732ed130cbc15432e737da8cc47e1734/tune_conv2d_cuda.ipynb"><img alt="Open in Colab" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300px" /></a> |
| </div> |
| <div class="sphx-glr-example-title section" id="tuning-high-performance-convolution-on-nvidia-gpus"> |
| <span id="sphx-glr-how-to-tune-with-autotvm-tune-conv2d-cuda-py"></span><h1>Tuning High Performance Convolution on NVIDIA GPUs<a class="headerlink" href="#tuning-high-performance-convolution-on-nvidia-gpus" title="Permalink to this headline">¶</a></h1> |
| <p><strong>Author</strong>: <a class="reference external" href="https://github.com/merrymercy">Lianmin Zheng</a></p> |
| <p>This is an advanced tutorial on writing a high-performance tunable template for |
| NVIDIA GPUs. By running the auto-tuner on this template, we can outperform the |
| vendor-provided library cuDNN in many cases.</p> |
| <p>Note that this tutorial will not run on Windows or recent versions of macOS. To |
| get it to run, you will need to wrap the body of this tutorial in an <code class="code docutils literal notranslate"><span class="pre">if</span> |
| <span class="pre">__name__</span> <span class="pre">==</span> <span class="pre">"__main__":</span></code> block.</p> |
| <div class="section" id="install-dependencies"> |
| <h2>Install dependencies<a class="headerlink" href="#install-dependencies" title="Permalink to this headline">¶</a></h2> |
| <p>To use the autotvm package in TVM, we need to install some extra dependencies |
| (change “3” to “2” if you use Python 2):</p> |
| <div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip3<span class="w"> </span>install<span class="w"> </span>--user<span class="w"> </span>psutil<span class="w"> </span>xgboost<span class="w"> </span>tornado<span class="w"> </span>cloudpickle |
| </pre></div> |
| </div> |
| <p>To make TVM run faster during tuning, it is recommended to use Cython |
| as the FFI of TVM. In the root directory of TVM, execute:</p> |
| <div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip3<span class="w"> </span>install<span class="w"> </span>--user<span class="w"> </span>cython |
| sudo<span class="w"> </span>make<span class="w"> </span>cython3 |
| </pre></div> |
| </div> |
| <p>Now return to the Python code and import the packages.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">logging</span> |
| <span class="kn">import</span> <span class="nn">sys</span> |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> |
| |
| <span class="kn">import</span> <span class="nn">tvm</span> |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span><span class="p">,</span> <span class="n">topi</span><span class="p">,</span> <span class="n">testing</span> |
| <span class="kn">from</span> <span class="nn">tvm.topi.testing</span> <span class="kn">import</span> <span class="n">conv2d_nchw_python</span> |
| <span class="kn">import</span> <span class="nn">tvm.testing</span> |
| |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">autotvm</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="step-1-define-the-search-space"> |
| <h2>Step 1: Define the search space<a class="headerlink" href="#step-1-define-the-search-space" title="Permalink to this headline">¶</a></h2> |
| <p>There are plenty of useful schedule primitives in TVM. You can also find |
| some tutorials that describe them in more detail, such as |
| (1) <a class="reference internal" href="../optimize_operators/opt_conv_cuda.html#opt-conv-gpu"><span class="std std-ref">How to optimize convolution on GPU</span></a> and |
| (2) <a class="reference external" href="https://tvm.apache.org/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example">Optimizing DepthwiseConv on NVIDIA GPU</a>.</p> |
| <p>However, their implementations are manually tuned for some special input |
| shapes. In this section, we build a large enough space to cover |
| the techniques used in these tutorials. Then we rely on the efficient auto-tuner |
| to search through this space and pick some good configurations.</p> |
| <p>If you are familiar with writing CUDA schedules, you will find that the following |
| template is very general. In fact, this template can easily be modified |
| to tune other operators such as depthwise convolution and GEMM. |
| In order to fully understand this template, you should be familiar with |
| the schedule primitives and the auto-tuning API. You can refer to the above |
| tutorials and the <a class="reference internal" href="../../tutorial/autotvm_matmul_x86.html#tutorial-autotvm-matmul-x86"><span class="std std-ref">autotvm tutorial</span></a>.</p> |
| <p>It is worth noting that the search space for a conv2d operator |
| can be very large (on the order of 10^9 for some input shapes).</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@autotvm</span><span class="o">.</span><span class="n">template</span><span class="p">(</span><span class="s2">"tutorial/conv2d_no_batching"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">conv2d_no_batching</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span 
class="n">padding</span></a><span class="p">):</span> |
| <span class="k">assert</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Only consider batch_size = 1 in this template"</span> |
| |
| <span class="n">data</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"data"</span><span class="p">)</span> |
| <span class="n">kernel</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"kernel"</span><span class="p">)</span> |
| <span class="n">conv</span> <span class="o">=</span> <a href="../../reference/api/python/topi.html#tvm.topi.nn.conv2d_nchw" title="tvm.topi.nn.conv2d_nchw" class="sphx-glr-backref-module-tvm-topi-nn sphx-glr-backref-type-py-function"><span class="n">topi</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">conv2d_nchw</span></a><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">kernel</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">padding</span></a><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">([</span><span class="n">conv</span><span class="o">.</span><span class="n">op</span><span class="p">])</span> |
| |
| <span class="c1">##### space definition begin #####</span> |
| <span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">conv</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">rc</span><span class="p">,</span> <span class="n">ry</span><span class="p">,</span> <span class="n">rx</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">conv</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span> |
| |
| <span class="n">cfg</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_f"</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_y"</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_x"</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_rc"</span><span class="p">,</span> <span class="n">rc</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_ry"</span><span class="p">,</span> <span class="n">ry</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_rx"</span><span class="p">,</span> <span class="n">rx</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_knob</span><span class="p">(</span><span class="s2">"auto_unroll_max_step"</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1500</span><span class="p">])</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_knob</span><span class="p">(</span><span class="s2">"unroll_explicit"</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> |
| <span class="c1">##### space definition end #####</span> |
| |
| <span class="c1"># inline padding</span> |
| <span class="n">pad_data</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">conv</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">input_tensors</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">pad_data</span><span class="p">]</span><span class="o">.</span><span class="n">compute_inline</span><span class="p">()</span> |
| <span class="n">data</span><span class="p">,</span> <span class="n">raw_data</span> <span class="o">=</span> <span class="n">pad_data</span><span class="p">,</span> <span class="n">data</span> |
| |
| <span class="n">output</span> <span class="o">=</span> <span class="n">conv</span> |
| <span class="n">OL</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_write" title="tvm.te.Schedule.cache_write" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_write</span></a><span class="p">(</span><span class="n">conv</span><span class="p">,</span> <span class="s2">"local"</span><span class="p">)</span> |
| |
| <span class="c1"># create cache stage</span> |
| <span class="n">AA</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="s2">"shared"</span><span class="p">,</span> <span class="p">[</span><span class="n">OL</span><span class="p">])</span> |
| <span class="n">WW</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="s2">"shared"</span><span class="p">,</span> <span class="p">[</span><span class="n">OL</span><span class="p">])</span> |
| <span class="n">AL</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><span class="n">AA</span><span class="p">,</span> <span class="s2">"local"</span><span class="p">,</span> <span class="p">[</span><span class="n">OL</span><span class="p">])</span> |
| <span class="n">WL</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule.cache_read" title="tvm.te.Schedule.cache_read" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-method"><span class="n">s</span><span class="o">.</span><span class="n">cache_read</span></a><span class="p">(</span><span class="n">WW</span><span class="p">,</span> <span class="s2">"local"</span><span class="p">,</span> <span class="p">[</span><span class="n">OL</span><span class="p">])</span> |
| |
| <span class="c1"># tile and bind spatial axes</span> |
| <span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">bf</span><span class="p">,</span> <span class="n">vf</span><span class="p">,</span> <span class="n">tf</span><span class="p">,</span> <span class="n">fi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_f"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span> |
| <span class="n">by</span><span class="p">,</span> <span class="n">vy</span><span class="p">,</span> <span class="n">ty</span><span class="p">,</span> <span class="n">yi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_y"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> |
| <span class="n">bx</span><span class="p">,</span> <span class="n">vx</span><span class="p">,</span> <span class="n">tx</span><span class="p">,</span> <span class="n">xi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_x"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> |
| <span class="n">kernel_scope</span> <span class="o">=</span> <span class="n">n</span> <span class="c1"># this is the scope to attach global config inside this kernel</span> |
| |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">bf</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.z"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">by</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.y"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">bx</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.x"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">vf</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"vthread"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">vy</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"vthread"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">vx</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"vthread"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">tf</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.z"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ty</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.y"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.x"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">bf</span><span class="p">,</span> <span class="n">by</span><span class="p">,</span> <span class="n">bx</span><span class="p">,</span> <span class="n">vf</span><span class="p">,</span> <span class="n">vy</span><span class="p">,</span> <span class="n">vx</span><span class="p">,</span> <span class="n">tf</span><span class="p">,</span> <span class="n">ty</span><span class="p">,</span> <span class="n">tx</span><span class="p">,</span> <span class="n">fi</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">xi</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">],</span> <span class="n">tx</span><span class="p">)</span> |
| |
| <span class="c1"># tile reduction axes</span> |
| <span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">rc</span><span class="p">,</span> <span class="n">ry</span><span class="p">,</span> <span class="n">rx</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span> |
| <span class="n">rco</span><span class="p">,</span> <span class="n">rcm</span><span class="p">,</span> <span class="n">rci</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_rc"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">OL</span><span class="p">,</span> <span class="n">rc</span><span class="p">)</span> |
| <span class="n">ryo</span><span class="p">,</span> <span class="n">rym</span><span class="p">,</span> <span class="n">ryi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_ry"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">OL</span><span class="p">,</span> <span class="n">ry</span><span class="p">)</span> |
| <span class="n">rxo</span><span class="p">,</span> <span class="n">rxm</span><span class="p">,</span> <span class="n">rxi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_rx"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">OL</span><span class="p">,</span> <span class="n">rx</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">rco</span><span class="p">,</span> <span class="n">ryo</span><span class="p">,</span> <span class="n">rxo</span><span class="p">,</span> <span class="n">rcm</span><span class="p">,</span> <span class="n">rym</span><span class="p">,</span> <span class="n">rxm</span><span class="p">,</span> <span class="n">rci</span><span class="p">,</span> <span class="n">ryi</span><span class="p">,</span> <span class="n">rxi</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> |
| |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">AA</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">],</span> <span class="n">rxo</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">WW</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">],</span> <span class="n">rxo</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">AL</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">],</span> <span class="n">rxm</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">WL</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">OL</span><span class="p">],</span> <span class="n">rxm</span><span class="p">)</span> |
| |
| <span class="c1"># cooperative fetching</span> |
| <span class="k">for</span> <span class="n">load</span> <span class="ow">in</span> <span class="p">[</span><span class="n">AA</span><span class="p">,</span> <span class="n">WW</span><span class="p">]:</span> |
| <span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">fused</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">fuse</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> |
| <span class="n">tz</span><span class="p">,</span> <span class="n">fused</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">fused</span><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_f"</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> |
| <span class="n">ty</span><span class="p">,</span> <span class="n">fused</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">fused</span><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_y"</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> |
| <span class="n">tx</span><span class="p">,</span> <span class="n">fused</span> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">fused</span><span class="p">,</span> <span class="n">nparts</span><span class="o">=</span><span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_x"</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">tz</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.z"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">ty</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.y"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">load</span><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">tx</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.x"</span><span class="p">))</span> |
| |
| <span class="c1"># tune unroll</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">pragma</span><span class="p">(</span><span class="n">kernel_scope</span><span class="p">,</span> <span class="s2">"auto_unroll_max_step"</span><span class="p">,</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"auto_unroll_max_step"</span><span class="p">]</span><span class="o">.</span><span class="n">val</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">output</span><span class="p">]</span><span class="o">.</span><span class="n">pragma</span><span class="p">(</span><span class="n">kernel_scope</span><span class="p">,</span> <span class="s2">"unroll_explicit"</span><span class="p">,</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"unroll_explicit"</span><span class="p">]</span><span class="o">.</span><span class="n">val</span><span class="p">)</span> |
| |
| <span class="k">return</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><span class="n">raw_data</span><span class="p">,</span> <span class="n">kernel</span><span class="p">,</span> <span class="n">conv</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="step-2-search-through-the-space"> |
| <h2>Step 2: Search through the space<a class="headerlink" href="#step-2-search-through-the-space" title="Permalink to this headline">¶</a></h2> |
| <p>We pick the last layer of ResNet as the test case. |
| Since our space is very large, <code class="code docutils literal notranslate"><span class="pre">XGBoostTuner</span></code> is the most suitable tuner |
| for our case. Here we only do 20 trials for demonstration. |
| In practice, running 1000 trials can usually find some good kernels |
| for this template.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># logging config (for printing tuning log to screen)</span> |
| <a href="https://docs.python.org/3/library/logging.html#logging.getLogger" title="logging.getLogger" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-function"><span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span></a><span class="p">(</span><span class="s2">"autotvm"</span><span class="p">)</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span></a><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/logging.html#logging.getLogger" title="logging.getLogger" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-function"><span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span></a><span class="p">(</span><span class="s2">"autotvm"</span><span class="p">)</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><a href="https://docs.python.org/3/library/logging.handlers.html#logging.StreamHandler" title="logging.StreamHandler" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-class"><span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/io.html#io.TextIOWrapper" title="io.TextIOWrapper" class="sphx-glr-backref-module-io sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">sys</span><span class="o">.</span><span class="n">stdout</span></a><span class="p">))</span> |
| |
| <span class="c1"># the last layer in resnet</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">strides</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" 
class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">padding</span></a> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> |
| <span class="n">task</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">task</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> |
| <span class="s2">"tutorial/conv2d_no_batching"</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span 
class="n">strides</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">padding</span></a><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="s2">"cuda"</span> |
| <span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="n">task</span><span class="o">.</span><span class="n">config_space</span><span class="p">)</span> |
| |
| <span class="c1"># Use local gpu, measure 3 times (repeat=3) for every config to reduce variance</span> |
| <span class="c1"># The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a><span class="p">(</span> |
| <span class="n">builder</span><span class="o">=</span><span class="n">autotvm</span><span class="o">.</span><span class="n">LocalBuilder</span><span class="p">(),</span> |
| <span class="n">runner</span><span class="o">=</span><span class="n">autotvm</span><span class="o">.</span><span class="n">LocalRunner</span><span class="p">(</span><span class="n">repeat</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">min_repeat_ms</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">timeout</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span> |
| <span class="p">)</span> |
| |
| <span class="n">record_file</span> <span class="o">=</span> <span class="kc">None</span> |
| <span class="c1"># Begin tuning, log records to file `conv2d.log`</span> |
| <span class="c1"># During tuning we will also try many invalid configs, so you are expected to</span> |
| <span class="c1"># see many error reports. As long as you can see non-zero GFLOPS, it is okay.</span> |
| |
| <span class="c1"># We do not run the tuning on our webpage server since it takes too long.</span> |
| <span class="c1"># Uncomment the following lines to run it yourself.</span> |
| |
| <span class="c1"># tuner = autotvm.tuner.XGBTuner(task)</span> |
| <span class="c1"># record_file = "conv2d.log"</span> |
| <span class="c1"># tuner.tune(</span> |
| <span class="c1"># n_trial=5,</span> |
| <span class="c1"># measure_option=measure_option,</span> |
| <span class="c1"># callbacks=[autotvm.callback.log_to_file(record_file)],</span> |
| <span class="c1"># )</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>ConfigSpace (len=10454400, range_length=10454400, space_map= |
| 0 tile_f: Split(policy=factors, product=512, num_outputs=4) len=220 |
| 1 tile_y: Split(policy=factors, product=7, num_outputs=4) len=4 |
| 2 tile_x: Split(policy=factors, product=7, num_outputs=4) len=4 |
| 3 tile_rc: Split(policy=factors, product=512, num_outputs=3) len=55 |
| 4 tile_ry: Split(policy=factors, product=3, num_outputs=3) len=3 |
| 5 tile_rx: Split(policy=factors, product=3, num_outputs=3) len=3 |
| 6 auto_unroll_max_step: OtherOption([0, 512, 1500]) len=3 |
| 7 unroll_explicit: OtherOption([0, 1]) len=2 |
| ) |
| </pre></div> |
| </div> |
| <p>Finally, we can inspect the best config from the log file, check correctness, |
| and measure running time.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># inspect the best config</span> |
| <span class="n">dispatch_context</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.apply_history_best" title="tvm.autotvm.apply_history_best" class="sphx-glr-backref-module-tvm-autotvm sphx-glr-backref-type-py-function"><span class="n">autotvm</span><span class="o">.</span><span class="n">apply_history_best</span></a><span class="p">(</span><span class="n">record_file</span><span class="p">)</span> |
| <a href="../../reference/api/python/autotvm.html#tvm.autotvm.task.space.FallbackConfigEntity" title="tvm.autotvm.task.space.FallbackConfigEntity" class="sphx-glr-backref-module-tvm-autotvm-task-space sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">best_config</span></a> <span class="o">=</span> <span class="n">dispatch_context</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">task</span><span class="o">.</span><span class="n">target</span></a><span class="p">,</span> <span class="n">task</span><span class="o">.</span><span class="n">workload</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Best config:"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/autotvm.html#tvm.autotvm.task.space.FallbackConfigEntity" title="tvm.autotvm.task.space.FallbackConfigEntity" class="sphx-glr-backref-module-tvm-autotvm-task-space sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">best_config</span></a><span class="p">)</span> |
| |
| <span class="c1"># apply history best from log file</span> |
| <span class="k">with</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.apply_history_best" title="tvm.autotvm.apply_history_best" class="sphx-glr-backref-module-tvm-autotvm sphx-glr-backref-type-py-function"><span class="n">autotvm</span><span class="o">.</span><span class="n">apply_history_best</span></a><span class="p">(</span><span class="n">record_file</span><span class="p">):</span> |
| <span class="k">with</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">target</span><span class="o">.</span><span class="n">Target</span></a><span class="p">(</span><span class="s2">"cuda"</span><span class="p">):</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arg_bufs</span></a> <span class="o">=</span> <span class="n">conv2d_no_batching</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a 
href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">strides</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">padding</span></a><span class="p">)</span> |
| <span class="n">func</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arg_bufs</span></a><span class="p">)</span> |
| |
| <span class="c1"># check correctness</span> |
| <span class="n">a_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">H</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">W</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> |
| <span class="n">w_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CO</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">CI</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KH</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">KW</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> |
| <span class="n">c_np</span> <span class="o">=</span> <span class="n">conv2d_nchw_python</span><span class="p">(</span><span class="n">a_np</span><span class="p">,</span> <span class="n">w_np</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">strides</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">padding</span></a><span class="p">)</span> |
| |
| <span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> |
| <span class="n">a_tvm</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">a_np</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">dev</span><span class="p">)</span> |
| <span class="n">w_tvm</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">w_np</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">dev</span><span class="p">)</span> |
| <span class="n">c_tvm</span> <span class="o">=</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.empty" title="tvm.nd.empty" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">empty</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">c_np</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">dev</span><span class="p">)</span> |
| <span class="n">func</span><span class="p">(</span><span class="n">a_tvm</span><span class="p">,</span> <span class="n">w_tvm</span><span class="p">,</span> <span class="n">c_tvm</span><span class="p">)</span> |
| |
| <span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c_np</span><span class="p">,</span> <span class="n">c_tvm</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">)</span> |
| |
| <span class="c1"># Evaluate running time. Here we choose a large number of runs (number=400) to reduce the noise</span> |
| <span class="c1"># and the overhead of kernel launch. You can also use nvprof to validate the result.</span> |
| <span class="n">evaluator</span> <span class="o">=</span> <a href="../../reference/api/python/runtime.html#tvm.runtime.Module.time_evaluator" title="tvm.runtime.Module.time_evaluator" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-method"><span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span></a><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">400</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Time cost of this operator: </span><span class="si">%f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a_tvm</span><span class="p">,</span> <span class="n">w_tvm</span><span class="p">,</span> <span class="n">c_tvm</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Cannot find config for target=cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32, workload=('tutorial/conv2d_no_batching', 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)). A fallback configuration is used, which may bring great performance regression. |
| |
| Best config: |
| ,None |
| Time cost of this operator: 0.037088 |
| </pre></div> |
| </div> |
| <div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-tune-with-autotvm-tune-conv2d-cuda-py"> |
| <div class="sphx-glr-download sphx-glr-download-python docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/6ad550da5092845382b1197f58a93816/tune_conv2d_cuda.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tune_conv2d_cuda.py</span></code></a></p> |
| </div> |
| <div class="sphx-glr-download sphx-glr-download-jupyter docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/732ed130cbc15432e737da8cc47e1734/tune_conv2d_cuda.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">tune_conv2d_cuda.ipynb</span></code></a></p> |
| </div> |
| </div> |
| <p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p> |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| </div> |
| |
| |
| <footer> |
| |
| <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> |
| |
| <a href="tune_relay_cuda.html" class="btn btn-neutral float-right" title="Auto-tuning a Convolutional Network for NVIDIA GPU" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a> |
| |
| |
| <a href="index.html" class="btn btn-neutral float-left" title="Auto-Tune with Templates and AutoTVM" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a> |
| |
| </div> |
| |
| <div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="backtop"/> </div> |
| <section class="footerSec"> |
| <div class="footerHeader"> |
| <div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row"> |
| <div class="copywrite d-flex align-items-center"> |
| <h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div> |
| <div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div> |
| </div> |
| |
| </section> |
| </footer> |
| </div> |
| </div> |
| |
| </section> |
| |
| </div> |
| |
| |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script> |
| <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script> |
| |
| |
| <!-- Once the DOM is ready, call SphinxRtdTheme.Navigation.enable(true). |
| NOTE(review): SphinxRtdTheme is presumably defined by _static/js/theme.js; confirm. --> |
| <script type="text/javascript"> |
| jQuery(function enableThemeNavigation() { |
| var themeNavigation = SphinxRtdTheme.Navigation; |
| themeNavigation.enable(true); |
| }); |
| </script> |
| |
| |
| |
| |
| <!-- Theme Analytics |
| Google Analytics (analytics.js) loader. The inline script below: |
| 1. stores the name 'ga' in window.GoogleAnalyticsObject and, unless window.ga already |
| exists, defines it as a stub that pushes each call's arguments onto ga.q; |
| 2. records the current timestamp in ga.l; |
| 3. creates an async script element pointing at analytics.js and inserts it before the |
| first script element in the document; |
| 4. issues a 'create' command for property UA-75982049-2 and a 'send' 'pageview' command. |
| NOTE(review): analytics.js is the legacy Universal Analytics library; confirm this |
| property still collects data. |
| --> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-75982049-2', 'auto'); |
| ga('send', 'pageview'); |
| </script> |
| |
| |
| |
| |
| </body> |
| </html> |