| |
| |
| |
| |
| |
| |
| <!DOCTYPE html> |
| <html class="writer-html5" lang="en" > |
| <head> |
| <meta charset="utf-8"> |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| <title>Optimizing Operators with Schedule Templates and AutoTVM — tvm 0.17.dev0 documentation</title> |
| |
| <!-- Stylesheets: each sheet is linked exactly once. Later sheets override earlier ones, so tlcpack_theme.css stays last. --> |
| <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"> |
| <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/sg_gallery.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/sg_gallery-binder.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/sg_gallery-dataframe.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/sg_gallery-rendered-html.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../_static/css/tlcpack_theme.css" type="text/css" /> |
| |
| <link rel="shortcut icon" href="../_static/tvm-logo-square.png"/> |
| |
| <!-- Scripts: documentation_options.js is loaded once so the "documentation_options" id is unique in the document. --> |
| <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script> |
| <script src="../_static/jquery.js"></script> |
| <script src="../_static/underscore.js"></script> |
| <script src="../_static/doctools.js"></script> |
| <script type="text/javascript" src="../_static/js/theme.js"></script> |
| <script type="text/javascript" src="../_static/js/tlcpack_theme.js"></script> |
| |
| <!-- Document relations: index, search, and the next/previous tutorial pages. --> |
| <link rel="index" title="Index" href="../genindex.html" /> |
| <link rel="search" title="Search" href="../search.html" /> |
| <link rel="next" title="Optimizing Operators with Auto-scheduling" href="auto_scheduler_matmul_x86.html" /> |
| <link rel="prev" title="Working with Operators Using Tensor Expression" href="tensor_expr_get_started.html" /> |
| </head> |
| |
| <body class="wy-body-for-nav"> |
| |
| |
| <div class="wy-grid-for-nav"> |
| |
| |
| <header class="header"> |
| <!-- Site header: logo, primary site links, and the ASF link list. All attribute values are quoted. --> |
| <div class="innercontainer"> |
| <div class="headerInner d-flex justify-content-between align-items-center"> |
| <div class="headerLogo"> |
| <a href="https://tvm.apache.org/"><img src="https://tvm.apache.org/assets/images/logo.svg" alt="Apache TVM"></a> |
| </div> |
| |
| <div id="headMenu" class="headerNav"> |
| <button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../_static/img/close-icon.svg" alt="Close"></button> |
| <ul class="nav"> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/community">Community</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/download">Download</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/vta">VTA</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/blog">Blog</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvm.apache.org/docs">Docs</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://tvmconf.org">Conference</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href="https://github.com/apache/tvm/">Github</a> |
| </li> |
| </ul> |
| <!-- ASF links inside the head menu. NOTE(review): presumably the small-screen variant of the dropdown further down; confirm in the theme CSS. --> |
| <div class="responsivetlcdropdown"> |
| <button type="button" class="btn-link"> |
| ASF |
| </button> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| <div class="responsiveMenuIcon"> |
| <button type="button" id="menuBtn" class="btn-menu"><img src="../_static/img/menu-icon.svg" alt="Menu Icon"></button> |
| </div> |
| |
| <!-- The same ASF links again, wrapped in a dropdown toggled through data-toggle="dropdown". --> |
| <div class="tlcDropdown"> |
| <div class="dropdown"> |
| <button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false"> |
| ASF |
| </button> |
| <div class="dropdown-menu dropdown-menu-right"> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| </div> |
| </header> |
| |
| <nav data-toggle="wy-nav-shift" class="wy-nav-side fixed"> |
| <div class="wy-side-scroll"> |
| <div class="wy-side-nav-search" > |
| |
| |
| |
| <a href="../index.html"> |
| <!-- Sidebar logo image, wrapped in a link to ../index.html. --> |
| <img class="logo" src="../_static/tvm-logo-small.png" alt="Logo"> |
| </a> |
| |
| |
| |
| |
| <!-- Version selector: a hidden checkbox toggled through its label. NOTE(review): presumably the theme CSS uses the checked state to show or hide the version list that follows; confirm in tlcpack_theme.css. --> |
| <input type="checkbox" class="version-toggle-box" hidden id="version-toggle"> |
| <label for="version-toggle" class="version-toggle-label"> |
| <div tabindex="0" class="version version-selector version-selector-show"> |
| 0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span> |
| </div> |
| </label> |
| <div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="version navigation"> |
| <!-- Version list. It carries its own landmark label so assistive technology can tell it apart from the sidebar table of contents, which is labelled "main navigation". --> |
| <p class="caption" role="heading"><span class="caption-text">Versions</span></p> |
| <ol style="text-align: left"> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li> |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li> |
| </ol> |
| </div> |
| |
| |
| |
| |
| <div role="search"> |
| <!-- Documentation search: a GET request to ../search.html with the query in "q" plus two fixed hidden fields. --> |
| <form class="wy-form" id="rtd-search-form" action="../search.html" method="get"> |
| <input type="text" name="q" placeholder="Search docs" aria-label="Search docs"> |
| <input type="hidden" name="check_keywords" value="yes"> |
| <input type="hidden" name="area" value="default"> |
| </form> |
| </div> |
| |
| |
| </div> |
| |
| |
| <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| <!-- Sidebar table of contents: one caption paragraph per guide, each followed by its list of toctree links. --> |
| |
| |
| |
| |
| |
| |
| <p class="caption" role="heading"><span class="caption-text">Getting Started</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../install/index.html">Installing TVM</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../contribute/index.html">Contributor Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">User Guide</span></p> |
| <!-- The "current" classes mark the path to this page: User Guide, then User Tutorial, then this tutorial. --> |
| <ul class="current"> |
| <li class="toctree-l1 current"><a class="reference internal" href="index.html">User Tutorial</a><ul class="current"> |
| <li class="toctree-l2"><a class="reference internal" href="introduction.html">Introduction</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="introduction.html#an-overview-of-tvm-and-model-optimization">An Overview of TVM and Model Optimization</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="install.html">Installing TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="tvmc_command_line_driver.html">Compiling and Optimizing a Model with TVMC</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="tvmc_python.html">Getting Starting using TVMC Python: a high-level API for TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="autotvm_relay_x86.html">Compiling and Optimizing a Model with the Python Interface (AutoTVM)</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="tensor_expr_get_started.html">Working with Operators Using Tensor Expression</a></li> |
| <!-- Current page entry; its nested list links to the in-page section anchors. --> |
| <li class="toctree-l2 current"><a class="current reference internal" href="#">Optimizing Operators with Schedule Templates and AutoTVM</a><ul> |
| <li class="toctree-l3"><a class="reference internal" href="#install-dependencies">Install dependencies</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="#basic-matrix-multiplication-with-te">Basic Matrix Multiplication with TE</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="#matrix-multiplication-with-autotvm">Matrix Multiplication with AutoTVM</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="#a-basic-matrix-multiplication-template">A Basic Matrix Multiplication Template</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="#a-matrix-multiplication-template-with-the-advanced-parameter-api">A Matrix Multiplication Template with the Advanced Parameter API</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="#step-2-use-autotvm-to-optimize-the-matrix-multiplication">Step 2: Use AutoTVM to Optimize the Matrix Multiplication</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#auto-tuners-in-tvm">Auto-tuners in TVM</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#begin-tuning">Begin tuning</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="#final-notes-and-summary">Final Notes and Summary</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l2"><a class="reference internal" href="auto_scheduler_matmul_x86.html">Optimizing Operators with Auto-scheduling</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="tensor_ir_blitz_course.html">Blitz Course to TensorIR</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="cross_compilation_and_rpc.html">Cross Compilation and RPC</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="relay_quick_start.html">Quick Start Tutorial for Compiling Deep Learning Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="uma.html">Making your Hardware Accelerator TVM-ready with UMA</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="intro_topi.html">Introduction to TOPI</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l1"><a class="reference internal" href="../how_to/index.html">How To Guides</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../dev/tutorial/index.html">Developer Tutorial</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../dev/how_to/how_to.html">Developer How-To Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../arch/index.html">Design and Architecture</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../reference/langref/index.html">Language Reference</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../reference/api/python/index.html">Python API</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../reference/api/links.html">Other APIs</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../reference/publications.html">Publications</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../genindex.html">Index</a></li> |
| </ul> |
| |
| |
| |
| </div> |
| |
| </div> |
| </nav> |
| |
| <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> |
| |
| <nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top"> |
| <!-- Top bar holding an empty "togglemenu" container and the "Table of Contents" label. NOTE(review): the toggle control is presumably injected or styled by the theme; confirm in tlcpack_theme.js/css. --> |
| |
| <div class="togglemenu"> |
| |
| </div> |
| <div class="nav-content"> |
| <!-- tvm --> |
| Table of Contents |
| </div> |
| |
| </nav> |
| |
| |
| <div class="wy-nav-content"> |
| |
| <div class="rst-content"> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <div role="navigation" aria-label="breadcrumbs navigation"> |
| <!-- Breadcrumb trail for the current page, followed by the "Edit on GitHub" source link. --> |
| <ul class="wy-breadcrumbs"> |
| <li><a href="../index.html">Docs</a> <span class="br-arrow">&gt;</span></li> |
| <li><a href="index.html">User Tutorial</a> <span class="br-arrow">&gt;</span></li> |
| <li>Optimizing Operators with Schedule Templates and AutoTVM</li> |
| <li class="wy-breadcrumbs-aside"> |
| <a href="https://github.com/apache/tvm/edit/main/gallery/tutorial/autotvm_matmul_x86.py" class="fa fa-github"> Edit on GitHub</a> |
| </li> |
| </ul> |
| <hr> |
| </div> |
| <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article"> |
| <div itemprop="articleBody"> |
| |
| <div class="sphx-glr-download-link-note admonition note"> |
| <!-- Colab note: the badge image has a descriptive alt text (it is the only content of its link) and a unit-less integer width. --> |
| <p class="admonition-title">Note</p> |
| <p>This tutorial can be used interactively with Google Colab! You can also click |
| <a class="reference internal" href="#sphx-glr-download-tutorial-autotvm-matmul-x86-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p> |
| <a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/37bbf9e2065ec8deeb64a8d9fa0755bc/autotvm_matmul_x86.ipynb"><img alt="Open in Colab" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300" /></a> |
| </div> |
| <div class="sphx-glr-example-title section" id="optimizing-operators-with-schedule-templates-and-autotvm"> |
| <span id="tutorial-autotvm-matmul-x86"></span><span id="sphx-glr-tutorial-autotvm-matmul-x86-py"></span><h1>Optimizing Operators with Schedule Templates and AutoTVM<a class="headerlink" href="#optimizing-operators-with-schedule-templates-and-autotvm" title="Permalink to this headline">¶</a></h1> |
| <p><strong>Authors</strong>: |
| <a class="reference external" href="https://github.com/merrymercy">Lianmin Zheng</a>, |
| <a class="reference external" href="https://github.com/hogepodge">Chris Hoge</a></p> |
| <p>In this tutorial, we show how the TVM Tensor Expression (TE) language |
| can be used to write schedule templates that can be searched by AutoTVM to |
| find the optimal schedule. This process is called Auto-Tuning, which helps |
| automate the process of optimizing tensor computation.</p> |
| <p>This tutorial builds on the previous <a class="reference internal" href="tensor_expr_get_started.html"><span class="doc">tutorial on how to write a matrix |
| multiplication using TE</span></a>.</p> |
| <p>There are two steps in auto-tuning.</p> |
| <ul class="simple"> |
| <li><p>The first step is defining a search space.</p></li> |
| <li><p>The second step is running a search algorithm to explore this space.</p></li> |
| </ul> |
| <p>In this tutorial, you can learn how to perform these two steps in TVM. The whole |
| workflow is illustrated by a matrix multiplication example.</p> |
| <div class="admonition note"> |
| <p class="admonition-title">Note</p> |
| <p>Note that this tutorial will not run on Windows or recent versions of macOS. |
| To get it to run, you will need to wrap the body of this tutorial in an |
| <code class="code docutils literal notranslate"><span class="pre">if</span> <span class="pre">__name__</span> <span class="pre">==</span> <span class="pre">"__main__":</span></code> block.</p> |
| </div> |
| <div class="section" id="install-dependencies"> |
| <h2>Install dependencies<a class="headerlink" href="#install-dependencies" title="Permalink to this headline">¶</a></h2> |
| <p>To use the autotvm package in TVM, we need to install some extra dependencies.</p> |
| <div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip3<span class="w"> </span>install<span class="w"> </span>--user<span class="w"> </span>psutil<span class="w"> </span>xgboost<span class="w"> </span>cloudpickle |
| </pre></div> |
| </div> |
| <p>To make TVM run faster in tuning, it is recommended to use Cython as the FFI of |
| TVM. In the root directory of TVM, execute:</p> |
| <div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip3<span class="w"> </span>install<span class="w"> </span>--user<span class="w"> </span>cython |
| sudo<span class="w"> </span>make<span class="w"> </span>cython3 |
| </pre></div> |
| </div> |
| <p>Now return to the Python code. Begin by importing the required packages.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">logging</span> |
| <span class="kn">import</span> <span class="nn">sys</span> |
| |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> |
| <span class="kn">import</span> <span class="nn">tvm</span> |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span> |
| <span class="kn">import</span> <span class="nn">tvm.testing</span> |
| |
| <span class="c1"># the module is called `autotvm`</span> |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">autotvm</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="basic-matrix-multiplication-with-te"> |
| <h2>Basic Matrix Multiplication with TE<a class="headerlink" href="#basic-matrix-multiplication-with-te" title="Permalink to this headline">¶</a></h2> |
| <p>Recall the basic implementation of matrix multiplication using TE. We write |
| it down here with a few changes. We will wrap the multiplication in a Python |
| function definition. For simplicity, we will focus our attention on a split |
| optimization, using a fixed value that defines the block size of the |
| reordering.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">matmul_basic</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span> |
| |
| <span class="n">A</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| <span class="n">B</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"k"</span><span class="p">)</span> |
| <span class="n">C</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <a href="../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"C"</span><span class="p">)</span> |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span> |
| |
| <span class="c1"># schedule</span> |
| <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| |
| <span class="n">yo</span><span class="p">,</span> <span class="n">yi</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span> |
| <span class="n">xo</span><span class="p">,</span> <span class="n">xi</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span> |
| |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">yo</span><span class="p">,</span> <span class="n">xo</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">xi</span><span class="p">)</span> |
| |
| <span class="k">return</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="matrix-multiplication-with-autotvm"> |
| <h2>Matrix Multiplication with AutoTVM<a class="headerlink" href="#matrix-multiplication-with-autotvm" title="Permalink to this headline">¶</a></h2> |
| <p>In the previous schedule code, we used a constant “8” as the tiling factor. |
| However, it might not be the best one because the best tiling factor depends |
| on the real hardware environment and input shape.</p> |
| <p>If you want the schedule code to be portable across a wider range of input |
| shapes and target hardware, it is better to define a set of candidate values |
| and pick the best one according to the measurement results on target |
| hardware.</p> |
| <p>In autotvm, we can define a tunable parameter, or “knob”, for this kind of |
| value.</p> |
| </div> |
| <div class="section" id="a-basic-matrix-multiplication-template"> |
| <h2>A Basic Matrix Multiplication Template<a class="headerlink" href="#a-basic-matrix-multiplication-template" title="Permalink to this headline">¶</a></h2> |
| <p>We begin with an example of how to create a tunable parameter set for the |
| block size of the <cite>split</cite> scheduling operation.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Matmul V1: List candidate values</span> |
| <span class="nd">@autotvm</span><span class="o">.</span><span class="n">template</span><span class="p">(</span><span class="s2">"tutorial/matmul_v1"</span><span class="p">)</span> <span class="c1"># 1. use a decorator</span> |
| <span class="k">def</span> <span class="nf">matmul_v1</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span> |
| <span class="n">A</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| <span class="n">B</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"k"</span><span class="p">)</span> |
| <span class="n">C</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <a href="../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"C"</span><span class="p">)</span> |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span> |
| |
| <span class="c1"># schedule</span> |
| <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| |
| <span class="c1"># 2. get the config object</span> |
| <span class="n">cfg</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> |
| |
| <span class="c1"># 3. define search space</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_knob</span><span class="p">(</span><span class="s2">"tile_y"</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">])</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_knob</span><span class="p">(</span><span class="s2">"tile_x"</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">])</span> |
| |
| <span class="c1"># 4. schedule according to config</span> |
| <span class="n">yo</span><span class="p">,</span> <span class="n">yi</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_y"</span><span class="p">]</span><span class="o">.</span><span class="n">val</span><span class="p">)</span> |
| <span class="n">xo</span><span class="p">,</span> <span class="n">xi</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_x"</span><span class="p">]</span><span class="o">.</span><span class="n">val</span><span class="p">)</span> |
| |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">yo</span><span class="p">,</span> <span class="n">xo</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">xi</span><span class="p">)</span> |
| |
| <span class="k">return</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| <p>Here we make four modifications to the previous schedule code and get a |
| tunable “template”. We can explain the modifications one by one.</p> |
| <ol class="arabic"> |
| <li><p>Use a decorator to mark this function as a simple template.</p></li> |
| <li><p>Get a config object: You can regard this <code class="code docutils literal notranslate"><span class="pre">cfg</span></code> as an argument of |
| this function but we obtain it in a different way. With this argument, this |
| function is no longer a deterministic schedule. Instead, we can pass |
| different configurations to this function and get different schedules. A |
| function that uses a configuration object like this is called a “template”.</p> |
| <p>To make the template function more compact, we can do two things to define |
| the parameter search space within a single function.</p> |
| <ol class="arabic simple"> |
| <li><p>Define a search space across a set of values. This is done by making |
| <code class="code docutils literal notranslate"><span class="pre">cfg</span></code> a <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigSpace" title="tvm.autotvm.task.space.ConfigSpace"><code class="xref any py py-class docutils literal notranslate"><span class="pre">ConfigSpace</span></code></a> object. It will collect all of the |
| tunable knobs in this function and build a search space from it.</p></li> |
| <li><p>Schedule according to an entity in this space. This is done by making |
| <code class="code docutils literal notranslate"><span class="pre">cfg</span></code> a <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigEntity" title="tvm.autotvm.task.space.ConfigEntity"><code class="xref any py py-class docutils literal notranslate"><span class="pre">ConfigEntity</span></code></a> object. When it is a |
| <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigEntity" title="tvm.autotvm.task.space.ConfigEntity"><code class="xref any py py-class docutils literal notranslate"><span class="pre">ConfigEntity</span></code></a>, it will ignore all space definition API (namely, |
| <code class="code docutils literal notranslate"><span class="pre">cfg.define_XXXXX(...)</span></code>). Instead, it will store deterministic |
| values for all tunable knobs, and we schedule according to these values.</p></li> |
| </ol> |
| <p>During auto-tuning, we will first call this template with a |
| <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigSpace" title="tvm.autotvm.task.space.ConfigSpace"><code class="xref any py py-class docutils literal notranslate"><span class="pre">ConfigSpace</span></code></a> object to build the search space. Then we call this |
| template with different <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigEntity" title="tvm.autotvm.task.space.ConfigEntity"><code class="xref any py py-class docutils literal notranslate"><span class="pre">ConfigEntity</span></code></a> in the built space to get |
| different schedules. Finally we will measure the code generated by |
| different schedules and pick the best one.</p> |
| </li> |
| <li><p>Define two tunable knobs. The first one is <code class="code docutils literal notranslate"><span class="pre">tile_y</span></code> with 5 possible |
| values. The second one is <code class="code docutils literal notranslate"><span class="pre">tile_x</span></code> with the same list of possible values. |
| These two knobs are independent, so they span a search space with size 25 = |
| 5x5.</p></li> |
| <li><p>The configuration knobs are passed to the <code class="code docutils literal notranslate"><span class="pre">split</span></code> schedule |
| operation, allowing us to schedule according to the 5x5 deterministic values |
| we previously defined in <code class="code docutils literal notranslate"><span class="pre">cfg</span></code>.</p></li> |
| </ol> |
| </div> |
| <div class="section" id="a-matrix-multiplication-template-with-the-advanced-parameter-api"> |
| <h2>A Matrix Multiplication Template with the Advanced Parameter API<a class="headerlink" href="#a-matrix-multiplication-template-with-the-advanced-parameter-api" title="Permalink to this headline">¶</a></h2> |
| <p>In the previous template, we manually listed all of the possible values for a |
| knob. This is the lowest level API to define the space, and gives an explicit |
| enumeration of the parameter space to search. However, we also provide |
| another set of APIs that can make the definition of the search space easier |
| and smarter. Where possible, we recommend you use this higher-level API.</p> |
| <p>In the following example, we use <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigSpace.define_split" title="tvm.autotvm.task.space.ConfigSpace.define_split"><code class="xref any py py-meth docutils literal notranslate"><span class="pre">ConfigSpace.define_split</span></code></a> to define a |
| split knob. It will enumerate all the possible ways to split an axis and |
| construct the space.</p> |
| <p>We also have <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigSpace.define_reorder" title="tvm.autotvm.task.space.ConfigSpace.define_reorder"><code class="xref any py py-meth docutils literal notranslate"><span class="pre">ConfigSpace.define_reorder</span></code></a> for reorder knob and |
| <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.task.space.ConfigSpace.define_annotate" title="tvm.autotvm.task.space.ConfigSpace.define_annotate"><code class="xref any py py-meth docutils literal notranslate"><span class="pre">ConfigSpace.define_annotate</span></code></a> for annotation like unroll, vectorization, |
| thread binding. When the high level API cannot meet your requirements, you |
| can always fall back to using the low level API.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@autotvm</span><span class="o">.</span><span class="n">template</span><span class="p">(</span><span class="s2">"tutorial/matmul"</span><span class="p">)</span> |
| <span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span> |
| <span class="n">A</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| <span class="n">B</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> |
| |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.reduce_axis" title="tvm.te.reduce_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span></a><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"k"</span><span class="p">)</span> |
| <span class="n">C</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">),</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <a href="../reference/api/python/te.html#tvm.te.sum" title="tvm.te.sum" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">sum</span></a><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">j</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"C"</span><span class="p">)</span> |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span> |
| |
| <span class="c1"># schedule</span> |
| <span class="n">y</span><span class="p">,</span> <span class="n">x</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span> |
| <span class="n">k</span> <span class="o">=</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> |
| |
| <span class="c1">##### define space begin #####</span> |
| <span class="n">cfg</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_y"</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> |
| <span class="n">cfg</span><span class="o">.</span><span class="n">define_split</span><span class="p">(</span><span class="s2">"tile_x"</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">num_outputs</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> |
| <span class="c1">##### define space end #####</span> |
| |
| <span class="c1"># schedule according to config</span> |
| <span class="n">yo</span><span class="p">,</span> <span class="n">yi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_y"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> |
| <span class="n">xo</span><span class="p">,</span> <span class="n">xi</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"tile_x"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> |
| |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">yo</span><span class="p">,</span> <span class="n">xo</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">xi</span><span class="p">)</span> |
| |
| <span class="k">return</span> <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">]</span> |
| </pre></div> |
| </div> |
| <div class="admonition-more-explanation-on-code-cfg-define-split admonition"> |
| <p class="admonition-title">More Explanation on <code class="code docutils literal notranslate"><span class="pre">cfg.define_split</span></code></p> |
| <p>In this template, <code class="code docutils literal notranslate"><span class="pre">cfg.define_split("tile_y",</span> <span class="pre">y,</span> <span class="pre">num_outputs=2)</span></code> will |
| enumerate all possible combinations that can split axis y into two axes with |
| factors of the length of y. For example, if the length of y is 32 and we |
| want to split it into two axes using factors of 32, then there are 6 |
| possible values for (length of outer axis, length of inner axis) pair, |
| namely (32, 1), (16, 2), (8, 4), (4, 8), (2, 16) or (1, 32). These are all 6 |
| possible values of <cite>tile_y</cite>.</p> |
| <p>During scheduling, <code class="code docutils literal notranslate"><span class="pre">cfg["tile_y"]</span></code> is a <code class="code docutils literal notranslate"><span class="pre">SplitEntity</span></code> object. |
| We store the lengths of outer axes and inner axes in |
| <code class="code docutils literal notranslate"><span class="pre">cfg['tile_y'].size</span></code> (a tuple with two elements). In this template, |
| we apply it by using <code class="code docutils literal notranslate"><span class="pre">yo,</span> <span class="pre">yi</span> <span class="pre">=</span> <span class="pre">cfg['tile_y'].apply(s,</span> <span class="pre">C,</span> <span class="pre">y)</span></code>. |
| Actually, this is equivalent to <code class="code docutils literal notranslate"><span class="pre">yo,</span> <span class="pre">yi</span> <span class="pre">=</span> <span class="pre">s[C].split(y,</span> |
| <span class="pre">cfg["tile_y"].size[1])</span></code> or <code class="code docutils literal notranslate"><span class="pre">yo,</span> <span class="pre">yi</span> <span class="pre">=</span> <span class="pre">s[C].split(y,</span> |
| <span class="pre">nparts=cfg["tile_y"].size[0])</span></code>.</p> |
| <p>The advantage of using cfg.apply API is that it makes multi-level splits |
| (that is, when num_outputs >= 3) easier.</p> |
| </div> |
| </div> |
| <div class="section" id="step-2-use-autotvm-to-optimize-the-matrix-multiplication"> |
| <h2>Step 2: Use AutoTVM to Optimize the Matrix Multiplication<a class="headerlink" href="#step-2-use-autotvm-to-optimize-the-matrix-multiplication" title="Permalink to this headline">¶</a></h2> |
| <p>In Step 1, we wrote a matrix multiplication template that allowed us to |
| parameterize the block size used in the <cite>split</cite> schedule. We can now conduct |
| a search over this parameter space. The next step is to pick a tuner to guide |
| the exploration of this space.</p> |
| <div class="section" id="auto-tuners-in-tvm"> |
| <h3>Auto-tuners in TVM<a class="headerlink" href="#auto-tuners-in-tvm" title="Permalink to this headline">¶</a></h3> |
| <p>The job for a tuner can be described by the following pseudo code:</p> |
| <blockquote> |
| <div><div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">ct</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span> |
| <span class="k">while</span><span class="w"> </span><span class="n">ct</span><span class="w"> </span><span class="o">&lt;</span><span class="w"> </span><span class="n">max_number_of_trials</span><span class="o">:</span> |
| <span class="w"> </span><span class="n">propose</span><span class="w"> </span><span class="n">a</span><span class="w"> </span><span class="n">batch</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">configs</span> |
| <span class="w"> </span><span class="n">measure</span><span class="w"> </span><span class="n">this</span><span class="w"> </span><span class="n">batch</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">configs</span><span class="w"> </span><span class="n">on</span><span class="w"> </span><span class="n">real</span><span class="w"> </span><span class="n">hardware</span><span class="w"> </span><span class="n">and</span><span class="w"> </span><span class="n">get</span><span class="w"> </span><span class="n">results</span> |
| <span class="w"> </span><span class="n">ct</span><span class="w"> </span><span class="o">+=</span><span class="w"> </span><span class="n">batch_size</span> |
| </pre></div> |
| </div> |
| </div></blockquote> |
| <p>When proposing the next batch of configs, the tuner can take different |
| strategies. Some of the tuner strategies provided by TVM include:</p> |
| <ul class="simple"> |
| <li><p><a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.RandomTuner" title="tvm.autotvm.tuner.RandomTuner"><code class="xref any py py-class docutils literal notranslate"><span class="pre">tvm.autotvm.tuner.RandomTuner</span></code></a>: Enumerate the space in a random order</p></li> |
| <li><p><a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.GridSearchTuner" title="tvm.autotvm.tuner.GridSearchTuner"><code class="xref any py py-class docutils literal notranslate"><span class="pre">tvm.autotvm.tuner.GridSearchTuner</span></code></a>: Enumerate the space in a grid search order</p></li> |
| <li><p><a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.GATuner" title="tvm.autotvm.tuner.GATuner"><code class="xref any py py-class docutils literal notranslate"><span class="pre">tvm.autotvm.tuner.GATuner</span></code></a>: Uses a genetic algorithm to search through the space</p></li> |
| <li><p><a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner"><code class="xref any py py-class docutils literal notranslate"><span class="pre">tvm.autotvm.tuner.XGBTuner</span></code></a>: Uses a model-based method. It trains an XGBoost model to |
| predict the speed of lowered IR and picks the next batch according to the |
| prediction.</p></li> |
| </ul> |
| <p>You can choose the tuner according to the size of your space, your time |
| budget and other factors. For example, if your space is very small (less |
| than 1000), a grid-search tuner or a random tuner is good enough. If your |
| space is at the level of 10^9 (this is the space size of a conv2d operator on |
| CUDA GPU), XGBTuner can explore more efficiently and find better configs.</p> |
| </div> |
| <div class="section" id="begin-tuning"> |
| <h3>Begin tuning<a class="headerlink" href="#begin-tuning" title="Permalink to this headline">¶</a></h3> |
| <p>Here we continue our matrix multiplication example. First we create a tuning |
| task. We can also inspect the initialized search space. In this case, for a |
| 512x512 square matrix multiplication, the space size is 10x10=100. Note that |
| the task and search space are independent of the tuner picked.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span> |
| <span class="n">task</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">task</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="s2">"tutorial/matmul"</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">,</span> <span class="s2">"float32"</span><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="s2">"llvm"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="n">task</span><span class="o">.</span><span class="n">config_space</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>ConfigSpace (len=100, range_length=100, space_map= |
| 0 tile_y: Split(policy=factors, product=512, num_outputs=2) len=10 |
| 1 tile_x: Split(policy=factors, product=512, num_outputs=2) len=10 |
| ) |
| </pre></div> |
| </div> |
| <p>Then we need to define how to measure the generated code and pick a tuner. |
| Since our space is small, a random tuner is sufficient.</p> |
| <p>We only make 10 trials in this tutorial for demonstration. In practice, you |
| can do more trials according to your time budget. We will log the tuning |
| results into a log file. This file can be used to choose the best |
| configuration discovered by the tuner later.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># logging config (for printing tuning log to the screen)</span> |
| <a href="https://docs.python.org/3/library/logging.html#logging.getLogger" title="logging.getLogger" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-function"><span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span></a><span class="p">(</span><span class="s2">"autotvm"</span><span class="p">)</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span></a><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/logging.html#logging.getLogger" title="logging.getLogger" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-function"><span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span></a><span class="p">(</span><span class="s2">"autotvm"</span><span class="p">)</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><a href="https://docs.python.org/3/library/logging.handlers.html#logging.StreamHandler" title="logging.StreamHandler" class="sphx-glr-backref-module-logging sphx-glr-backref-type-py-class"><span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/io.html#io.TextIOWrapper" title="io.TextIOWrapper" class="sphx-glr-backref-module-io sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">sys</span><span class="o">.</span><span class="n">stdout</span></a><span class="p">))</span> |
| </pre></div> |
| </div> |
| <p>There are two steps for measuring a config: build and run. By default, we use |
| all CPU cores to compile the programs. We then measure them sequentially. To help |
| reduce variance, we take 5 measurements and average them.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a><span class="p">(</span><span class="n">builder</span><span class="o">=</span><span class="s2">"local"</span><span class="p">,</span> <span class="n">runner</span><span class="o">=</span><span class="n">autotvm</span><span class="o">.</span><span class="n">LocalRunner</span><span class="p">(</span><span class="n">number</span><span class="o">=</span><span class="mi">5</span><span class="p">))</span> |
| |
| <span class="c1"># Begin tuning with RandomTuner, log records to file `matmul.log`</span> |
| <span class="c1"># You can use alternatives like XGBTuner.</span> |
| <a href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.RandomTuner" title="tvm.autotvm.tuner.RandomTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tuner</span></a> <span class="o">=</span> <a href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.RandomTuner" title="tvm.autotvm.tuner.RandomTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class"><span class="n">autotvm</span><span class="o">.</span><span class="n">tuner</span><span class="o">.</span><span class="n">RandomTuner</span></a><span class="p">(</span><span class="n">task</span><span class="p">)</span> |
| <a href="../reference/api/python/autotvm.html#tvm.autotvm.tuner.RandomTuner.tune" title="tvm.autotvm.tuner.RandomTuner.tune" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-method"><span class="n">tuner</span><span class="o">.</span><span class="n">tune</span></a><span class="p">(</span> |
| <span class="n">n_trial</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">measure_option</span></a><span class="p">,</span> |
| <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">autotvm</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">log_to_file</span><span class="p">(</span><span class="s2">"matmul.log"</span><span class="p">)],</span> |
| <span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>waiting for device... |
| device available |
| Get devices for measurement successfully! |
| No: 1 GFLOPS: 8.65/8.65 result: MeasureResult(costs=(0.031042645599999995,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.7603747844696045, timestamp=1714833350.1770363) [('tile_y', [-1, 512]), ('tile_x', [-1, 32])],None,59 |
| No: 2 GFLOPS: 11.79/11.79 result: MeasureResult(costs=(0.0227671392,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.6369888782501221, timestamp=1714833350.8047419) [('tile_y', [-1, 256]), ('tile_x', [-1, 512])],None,98 |
| No: 3 GFLOPS: 14.94/14.94 result: MeasureResult(costs=(0.0179658476,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.6332783699035645, timestamp=1714833351.3520458) [('tile_y', [-1, 64]), ('tile_x', [-1, 128])],None,76 |
| No: 4 GFLOPS: 10.24/14.94 result: MeasureResult(costs=(0.0262016466,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.6723027229309082, timestamp=1714833352.0362012) [('tile_y', [-1, 512]), ('tile_x', [-1, 512])],None,99 |
| No: 5 GFLOPS: 12.24/14.94 result: MeasureResult(costs=(0.02193508,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.6962182521820068, timestamp=1714833352.8437154) [('tile_y', [-1, 64]), ('tile_x', [-1, 32])],None,56 |
| No: 6 GFLOPS: 7.47/14.94 result: MeasureResult(costs=(0.0359470202,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.8791561126708984, timestamp=1714833353.6731987) [('tile_y', [-1, 64]), ('tile_x', [-1, 16])],None,46 |
| No: 7 GFLOPS: 1.93/14.94 result: MeasureResult(costs=(0.1393351036,), error_no=MeasureErrorNo.NO_ERROR, all_cost=2.4879252910614014, timestamp=1714833356.1601758) [('tile_y', [-1, 4]), ('tile_x', [-1, 1])],None,2 |
| No: 8 GFLOPS: 10.39/14.94 result: MeasureResult(costs=(0.025823850599999996,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.6973447799682617, timestamp=1714833356.8340662) [('tile_y', [-1, 4]), ('tile_x', [-1, 16])],None,42 |
| No: 9 GFLOPS: 0.47/14.94 result: MeasureResult(costs=(0.5686715914,), error_no=MeasureErrorNo.NO_ERROR, all_cost=9.319947481155396, timestamp=1714833366.2875707) [('tile_y', [-1, 512]), ('tile_x', [-1, 1])],None,9 |
| No: 10 GFLOPS: 12.08/14.94 result: MeasureResult(costs=(0.0222167708,), error_no=MeasureErrorNo.NO_ERROR, all_cost=0.5997726917266846, timestamp=1714833366.8989227) [('tile_y', [-1, 32]), ('tile_x', [-1, 32])],None,55 |
| </pre></div> |
| </div> |
| <p>With tuning completed, we can choose the configuration from the log file that |
| has the best measured performance and compile the schedule with the |
| corresponding parameters. We also do a quick verification that the schedule is |
| producing correct answers. We can call the function <code class="code docutils literal notranslate"><span class="pre">matmul</span></code> directly |
| under the <a class="reference internal" href="../reference/api/python/autotvm.html#tvm.autotvm.apply_history_best" title="tvm.autotvm.apply_history_best"><code class="xref any py py-func docutils literal notranslate"><span class="pre">autotvm.apply_history_best</span></code></a> context. When we call this |
| function, it will query the dispatch context with its argument and get the |
| best config with the same argument.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># apply history best from log file</span> |
| <span class="k">with</span> <a href="../reference/api/python/autotvm.html#tvm.autotvm.apply_history_best" title="tvm.autotvm.apply_history_best" class="sphx-glr-backref-module-tvm-autotvm sphx-glr-backref-type-py-function"><span class="n">autotvm</span><span class="o">.</span><span class="n">apply_history_best</span></a><span class="p">(</span><span class="s2">"matmul.log"</span><span class="p">):</span> |
| <span class="k">with</span> <a href="../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">target</span><span class="o">.</span><span class="n">Target</span></a><span class="p">(</span><span class="s2">"llvm"</span><span class="p">):</span> |
| <a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arg_bufs</span></a> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span> |
| <span class="n">func</span> <span class="o">=</span> <a href="../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arg_bufs</span></a><span class="p">)</span> |
| |
| <span class="c1"># check correctness</span> |
| <span class="n">a_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">N</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> |
| <span class="n">b_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">L</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">M</span></a><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> |
| <span class="n">c_np</span> <span class="o">=</span> <span class="n">a_np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">b_np</span><span class="p">)</span> |
| |
| <span class="n">c_tvm</span> <span class="o">=</span> <a href="../reference/api/python/ndarray.html#tvm.nd.empty" title="tvm.nd.empty" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">empty</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">c_np</span><span class="o">.</span><span class="n">shape</span></a><span class="p">)</span> |
| <span class="n">func</span><span class="p">(</span><a href="../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">a_np</span><span class="p">),</span> <a href="../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">b_np</span><span class="p">),</span> <span class="n">c_tvm</span><span class="p">)</span> |
| |
| <span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c_np</span><span class="p">,</span> <span class="n">c_tvm</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Finish loading 10 records |
| </pre></div> |
| </div> |
| </div> |
| </div> |
| <div class="section" id="final-notes-and-summary"> |
| <h2>Final Notes and Summary<a class="headerlink" href="#final-notes-and-summary" title="Permalink to this headline">¶</a></h2> |
| <p>In this tutorial, we have shown how to build operator templates that allow |
| TVM to search a parameter space and choose optimized schedule configurations. |
| To gain a deeper understanding of how this works, we recommend expanding on |
| this example by adding new search parameters to the schedule based on |
| schedule operations demonstrated in the <a class="reference internal" href="tensor_expr_get_started.html"><span class="std std-ref">Getting Started With Tensor |
| Expressions</span></a> tutorial. In the upcoming sections, we |
| will demonstrate the AutoScheduler, a method for TVM to optimize common |
| operators without the need for the user to provide a user-defined template.</p> |
| <div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-tutorial-autotvm-matmul-x86-py"> |
| <div class="sphx-glr-download sphx-glr-download-python docutils container"> |
| <p><a class="reference download internal" download="" href="../_downloads/8e7bbc9dbdda76ac573b24606b41c006/autotvm_matmul_x86.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">autotvm_matmul_x86.py</span></code></a></p> |
| </div> |
| <div class="sphx-glr-download sphx-glr-download-jupyter docutils container"> |
| <p><a class="reference download internal" download="" href="../_downloads/37bbf9e2065ec8deeb64a8d9fa0755bc/autotvm_matmul_x86.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">autotvm_matmul_x86.ipynb</span></code></a></p> |
| </div> |
| </div> |
| <p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p> |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| </div> |
| |
| |
| <footer> |
| |
| <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> |
| |
| <a href="auto_scheduler_matmul_x86.html" class="btn btn-neutral float-right" title="Optimizing Operators with Auto-scheduling" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a> |
| |
| |
| <a href="tensor_expr_get_started.html" class="btn btn-neutral float-left" title="Working with Operators Using Tensor Expression" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a> |
| |
| </div> |
| |
| <div id="button" class="backtop"><img src="../_static/img/right.svg" alt="backtop"/> </div> |
| <section class="footerSec"> |
| <div class="footerHeader"> |
| <div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row"> |
| <div class="copywrite d-flex align-items-center"> |
| <h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div> |
| <div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div> |
| </div> |
| |
| </section> |
| </footer> |
| </div> |
| </div> |
| |
| </section> |
| |
| </div> |
| |
| |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script> |
| <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script> |
| |
| |
| <script type="text/javascript"> |
| jQuery(function () { |
| SphinxRtdTheme.Navigation.enable(true); |
| }); |
| </script> |
| |
| |
| |
| |
| <!-- Theme Analytics --> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-75982049-2', 'auto'); |
| ga('send', 'pageview'); |
| </script> |
| |
| |
| |
| |
| </body> |
| </html> |