| <div class="sphx-glr-example-title section" id="intrinsics-and-math-functions"> |
| <span id="sphx-glr-how-to-work-with-schedules-intrin-math-py"></span><h1>Intrinsics and Math Functions<a class="headerlink" href="#intrinsics-and-math-functions" title="Permalink to this headline">¶</a></h1> |
| <p><strong>Author</strong>: <a class="reference external" href="https://tqchen.github.io">Tianqi Chen</a></p> |
TVM supports basic arithmetic operations, but in many cases we need more complicated builtin functions; for example, `exp`, which takes the exponential of its input.

These functions are target dependent and may have different names on different target platforms. In this tutorial, we will learn how to invoke these target-specific functions, and how to unify the interface via TVM's intrinsic API.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">absolute_import</span><span class="p">,</span> <span class="n">print_function</span> |
| |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> |
| |
| <span class="kn">import</span> <span class="nn">tvm</span> |
| <span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span> |
| <span class="kn">from</span> <span class="nn">tvm.ir</span> <span class="kn">import</span> <a href="../../reference/api/python/ir.html#tvm.ir.register_op_attr" title="tvm.ir.register_op_attr" class="sphx-glr-backref-module-tvm-ir sphx-glr-backref-type-py-function"><span class="n">register_op_attr</span></a><span class="p">,</span> <a href="../../reference/api/python/ir.html#tvm.ir.register_intrin_lowering" title="tvm.ir.register_intrin_lowering" class="sphx-glr-backref-module-tvm-ir sphx-glr-backref-type-py-function"><span class="n">register_intrin_lowering</span></a> |
| </pre></div> |
| </div> |
| <div class="section" id="direct-declare-extern-math-call"> |
| <h2>Direct Declare Extern Math Call<a class="headerlink" href="#direct-declare-extern-math-call" title="Permalink to this headline">¶</a></h2> |
The most straightforward way to call a target-specific function is via the extern function call construct in TVM. In the following example, we use `tvm.tir.call_pure_extern` to call the `__expf` function, which is only available under CUDA.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.var" title="tvm.te.var" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">var</span></a><span class="p">(</span><span class="s2">"n"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_pure_extern" title="tvm.tir.call_pure_extern" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_pure_extern</span></a><span class="p">(</span><span class="s2">"float32"</span><span class="p">,</span> <span class="s2">"__expf"</span><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><span class="n">i</span><span class="p">]),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> <span class="o">=</span> <span class="mi">64</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.x"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.x"</span><span class="p">))</span> |
| <span class="n">f</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"cuda"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module.imported_modules" title="tvm.runtime.Module.imported_modules" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-property"><span class="n">f</span><span class="o">.</span><span class="n">imported_modules</span></a><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">get_source</span><span class="p">())</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| </pre></div> |
| </div> |
| </div> |
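To sanity-check the generated kernel, we can run it on device data and compare against NumPy. This is a minimal sketch, not part of the original tutorial; it assumes a CUDA-capable GPU is available:

```python
# Hypothetical verification run; assumes a CUDA device is present.
dev = tvm.cuda(0)
num = 1024
a = tvm.nd.array(np.random.uniform(size=num).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(num, dtype=B.dtype), dev)
f(a, b)
# __expf is a fast-math approximation, so use a loose tolerance.
np.testing.assert_allclose(b.numpy(), np.exp(a.numpy()), rtol=1e-3)
```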
| <div class="section" id="unified-intrinsic-call"> |
| <h2>Unified Intrinsic Call<a class="headerlink" href="#unified-intrinsic-call" title="Permalink to this headline">¶</a></h2> |
The above code verifies that a direct external call can be used to invoke device-specific functions. However, that approach only works for the CUDA target with the float type. Ideally, we want to write the same code for any device and any data type.

TVM intrinsics provide a mechanism to achieve this, and this is the recommended way to solve the problem. The following code uses `te.exp` instead, which creates an intrinsic call `tvm.te.exp()` to compute the exponential.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.var" title="tvm.te.var" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">var</span></a><span class="p">(</span><span class="s2">"n"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <a href="../../reference/api/python/te.html#tvm.te.exp" title="tvm.te.exp" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">exp</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><span class="n">i</span><span class="p">]),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> <span class="o">=</span> <span class="mi">64</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.x"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.x"</span><span class="p">))</span> |
| <span class="n">fcuda</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"cuda"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module.imported_modules" title="tvm.runtime.Module.imported_modules" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-property"><span class="n">fcuda</span><span class="o">.</span><span class="n">imported_modules</span></a><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">get_source</span><span class="p">())</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| </pre></div> |
| </div> |
The code works for both CUDA and OpenCL; below we build the same schedule for the OpenCL target. The same `te.exp` can also be used for float64 data types, as sketched after the OpenCL output.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">fopencl</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"opencl"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module.imported_modules" title="tvm.runtime.Module.imported_modules" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-property"><span class="n">fopencl</span><span class="o">.</span><span class="n">imported_modules</span></a><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">get_source</span><span class="p">())</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>// Function: myexp_kernel |
| __kernel void myexp_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1); |
| __kernel void myexp_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) { |
| if ((convert_int(get_group_id(0))) < (n >> 6)) { |
| B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = exp(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]); |
| } else { |
| if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) { |
| B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = exp(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]); |
| } |
| } |
| } |
| </pre></div> |
| </div> |
| </div> |
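To back up the float64 claim, here is a minimal sketch (not in the original tutorial) that builds the same computation with a float64 placeholder; `te.exp` then lowers to the double-precision `exp` call. The `A64`/`B64` names are illustrative:

```python
# Hypothetical float64 variant of the same computation.
A64 = te.placeholder((n,), dtype="float64", name="A64")
B64 = te.compute(A64.shape, lambda i: te.exp(A64[i]), name="B64")
s64 = te.create_schedule(B64.op)
bx64, tx64 = s64[B64].split(B64.op.axis[0], factor=num_thread)
s64[B64].bind(bx64, te.thread_axis("blockIdx.x"))
s64[B64].bind(tx64, te.thread_axis("threadIdx.x"))
fcuda64 = tvm.build(s64, [A64, B64], "cuda", name="myexp64")
# The generated kernel should call exp(double) rather than __expf(float).
print(fcuda64.imported_modules[0].get_source())
```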
| <div class="section" id="intrinsic-lowering-rule"> |
| <h2>Intrinsic Lowering Rule<a class="headerlink" href="#intrinsic-lowering-rule" title="Permalink to this headline">¶</a></h2> |
When `tvm.te.exp()` is called, TVM creates an intrinsic call expression. TVM then uses transformation rules to lower the intrinsic call to device-specific extern calls.

TVM also allows users to customize these rules at runtime. The following example customizes the CUDA lowering rule for `exp`.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">my_cuda_math_rule</span><span class="p">(</span><span class="n">op</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""Customized CUDA intrinsic lowering rule"""</span> |
| <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.Call" title="tvm.tir.Call" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">Call</span></a><span class="p">)</span> |
| <span class="n">name</span> <span class="o">=</span> <span class="n">op</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">name</span> |
| <span class="k">assert</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"tir."</span><span class="p">)</span> |
| <span class="n">dispatch_name</span> <span class="o">=</span> <span class="n">name</span><span class="p">[</span><span class="mi">4</span><span class="p">:]</span> |
| <span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"float32"</span><span class="p">:</span> |
| <span class="c1"># call float function</span> |
| <span class="k">return</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_pure_extern" title="tvm.tir.call_pure_extern" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_pure_extern</span></a><span class="p">(</span><span class="s2">"float32"</span><span class="p">,</span> <span class="s2">"</span><span class="si">%s</span><span class="s2">f"</span> <span class="o">%</span> <span class="n">dispatch_name</span><span class="p">,</span> <span class="n">op</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> |
| <span class="k">elif</span> <span class="n">op</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"float64"</span><span class="p">:</span> |
| <span class="c1"># call double function</span> |
| <span class="k">return</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_pure_extern" title="tvm.tir.call_pure_extern" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_pure_extern</span></a><span class="p">(</span><span class="s2">"float32"</span><span class="p">,</span> <span class="n">dispatch_name</span><span class="p">,</span> <span class="n">op</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="c1"># cannot do translation, return self.</span> |
| <span class="k">return</span> <span class="n">op</span> |
| |
| |
| <a href="../../reference/api/python/ir.html#tvm.ir.register_intrin_lowering" title="tvm.ir.register_intrin_lowering" class="sphx-glr-backref-module-tvm-ir sphx-glr-backref-type-py-function"><span class="n">register_intrin_lowering</span></a><span class="p">(</span><span class="s2">"tir.exp"</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">f</span><span class="o">=</span><span class="n">my_cuda_math_rule</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="mi">99</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span><function my_cuda_math_rule at 0x7f074862b3a0> |
| </pre></div> |
| </div> |
We registered the rule with a high priority level so that it overrides the existing rule. Notice the difference between the printed code and the previous one: the new rule uses the regular math function `expf` instead of the fast-math version `__expf`.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">fcuda</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"cuda"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module.imported_modules" title="tvm.runtime.Module.imported_modules" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-property"><span class="n">fcuda</span><span class="o">.</span><span class="n">imported_modules</span></a><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">get_source</span><span class="p">())</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) myexp_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| </pre></div> |
| </div> |
| </div> |
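The `level` argument controls rule priority: a registration with a higher level takes precedence over an existing one. As a minimal sketch (not from the original tutorial; the rule name and level value are illustrative), we could restore the fast-math behavior by registering yet another rule at a higher level:

```python
def my_cuda_fast_exp_rule(op):
    """Hypothetical rule restoring the fast-math __expf for float32."""
    if op.dtype == "float32":
        return tvm.tir.call_pure_extern("float32", "__expf", op.args[0])
    # leave other dtypes to the existing rules
    return op


# level=100 outranks the level=99 registration above, so this rule wins.
register_intrin_lowering("tir.exp", target="cuda", f=my_cuda_fast_exp_rule, level=100)
```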
| <div class="section" id="add-your-own-intrinsic"> |
| <h2>Add Your Own Intrinsic<a class="headerlink" href="#add-your-own-intrinsic" title="Permalink to this headline">¶</a></h2> |
If an intrinsic is not provided by TVM, the user can easily add a new one by using the intrinsic rule system. The following example adds an intrinsic `mylog` to the system.
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">mylog</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""customized log intrinsic function"""</span> |
| <span class="k">return</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_intrin" title="tvm.tir.call_intrin" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_intrin</span></a><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="s2">"tir.mylog"</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> |
| |
| |
| <span class="k">def</span> <span class="nf">my_cuda_mylog_rule</span><span class="p">(</span><span class="n">op</span><span class="p">):</span> |
| <span class="w"> </span><span class="sd">"""CUDA lowering rule for log"""</span> |
| <span class="k">if</span> <span class="n">op</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"float32"</span><span class="p">:</span> |
| <span class="k">return</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_pure_extern" title="tvm.tir.call_pure_extern" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_pure_extern</span></a><span class="p">(</span><span class="s2">"float32"</span><span class="p">,</span> <span class="s2">"logf"</span><span class="p">,</span> <span class="n">op</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> |
| <span class="k">elif</span> <span class="n">op</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"float64"</span><span class="p">:</span> |
| <span class="k">return</span> <a href="../../reference/api/python/tir.html#tvm.tir.call_pure_extern" title="tvm.tir.call_pure_extern" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">call_pure_extern</span></a><span class="p">(</span><span class="s2">"float64"</span><span class="p">,</span> <span class="s2">"log"</span><span class="p">,</span> <span class="n">op</span><span class="o">.</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">op</span> |
| |
| |
| <span class="c1"># new op registration is triggered by registering an attribute of the op</span> |
| <a href="../../reference/api/python/ir.html#tvm.ir.register_op_attr" title="tvm.ir.register_op_attr" class="sphx-glr-backref-module-tvm-ir sphx-glr-backref-type-py-function"><span class="n">register_op_attr</span></a><span class="p">(</span><span class="s2">"tir.mylog"</span><span class="p">,</span> <span class="s2">"TCallEffectKind"</span><span class="p">,</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">CallEffectKind</span><span class="o">.</span><span class="n">Pure</span><span class="p">)</span> |
| <a href="../../reference/api/python/ir.html#tvm.ir.register_intrin_lowering" title="tvm.ir.register_intrin_lowering" class="sphx-glr-backref-module-tvm-ir sphx-glr-backref-type-py-function"><span class="n">register_intrin_lowering</span></a><span class="p">(</span><span class="s2">"tir.mylog"</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">f</span><span class="o">=</span><span class="n">my_cuda_mylog_rule</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="mi">99</span><span class="p">)</span> |
| |
| <a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.var" title="tvm.te.var" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">var</span></a><span class="p">(</span><span class="s2">"n"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.placeholder" title="tvm.te.placeholder" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">placeholder</span></a><span class="p">((</span><a href="../../reference/api/python/tir.html#tvm.tir.Var" title="tvm.tir.Var" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">n</span></a><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"A"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.compute" title="tvm.te.compute" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">compute</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.shape" title="tvm.te.Tensor.shape" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">A</span><span class="o">.</span><span class="n">shape</span></a><span class="p">,</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <span class="n">mylog</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">[</span><span class="n">i</span><span class="p">]),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"B"</span><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.create_schedule" title="tvm.te.create_schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span></a><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a> <span class="o">=</span> <span class="mi">64</span> |
| <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a> <span class="o">=</span> <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Tensor.op" title="tvm.te.Tensor.op" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-property"><span class="n">B</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span></a><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">factor</span><span class="o">=</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_thread</span></a><span class="p">)</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">bx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"blockIdx.x"</span><span class="p">))</span> |
| <a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">]</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><a href="../../reference/api/python/tir.html#tvm.tir.IterVar" title="tvm.tir.IterVar" class="sphx-glr-backref-module-tvm-tir sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tx</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.thread_axis" title="tvm.te.thread_axis" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-function"><span class="n">te</span><span class="o">.</span><span class="n">thread_axis</span></a><span class="p">(</span><span class="s2">"threadIdx.x"</span><span class="p">))</span> |
| <span class="n">fcuda</span> <span class="o">=</span> <a href="../../reference/api/python/driver.html#tvm.build" title="tvm.build" class="sphx-glr-backref-module-tvm sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">build</span></a><span class="p">(</span><a href="../../reference/api/python/te.html#tvm.te.Schedule" title="tvm.te.Schedule" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">s</span></a><span class="p">,</span> <span class="p">[</span><a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">A</span></a><span class="p">,</span> <a href="../../reference/api/python/te.html#tvm.te.Tensor" title="tvm.te.Tensor" class="sphx-glr-backref-module-tvm-te sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">B</span></a><span class="p">],</span> <span class="s2">"cuda"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"mylog"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module.imported_modules" title="tvm.runtime.Module.imported_modules" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-property"><span class="n">fcuda</span><span class="o">.</span><span class="n">imported_modules</span></a><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">get_source</span><span class="p">())</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ |
| (__CUDACC_VER_MAJOR__ > 11)) |
| #define TVM_ENABLE_L2_PREFETCH 1 |
| #else |
| #define TVM_ENABLE_L2_PREFETCH 0 |
| #endif |
| |
| #ifdef _WIN32 |
| using uint = unsigned int; |
| using uchar = unsigned char; |
| using ushort = unsigned short; |
| using int64_t = long long; |
| using uint64_t = unsigned long long; |
| #else |
| #define uint unsigned int |
| #define uchar unsigned char |
| #define ushort unsigned short |
| #define int64_t long long |
| #define uint64_t unsigned long long |
| #endif |
| extern "C" __global__ void __launch_bounds__(64) mylog_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1); |
| extern "C" __global__ void __launch_bounds__(64) mylog_kernel(float* __restrict__ A, float* __restrict__ B, int n, int stride, int stride_1) { |
| if (((int)blockIdx.x) < (n >> 6)) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } else { |
| if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) { |
| B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]); |
| } |
| } |
| } |
| </pre></div> |
| </div> |
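| <p>As a quick sanity check, here is a minimal sketch (not part of the generated tutorial output, and it assumes a CUDA device is available) of running the compiled <code class="code docutils literal notranslate"><span class="pre">mylog</span></code> kernel and verifying it against <code class="code docutils literal notranslate"><span class="pre">numpy.log</span></code>:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A minimal sketch, assuming a CUDA device: run the compiled mylog kernel |
| # built above and compare its result against numpy.log. |
| import numpy as np |
|  |
| dev = tvm.cuda(0) |
| num = 1024 |
| a = tvm.nd.array(np.random.uniform(1.0, 10.0, size=num).astype("float32"), dev) |
| b = tvm.nd.array(np.zeros(num, dtype="float32"), dev) |
| fcuda(a, b)  # launches the mylog_kernel shown above |
| np.testing.assert_allclose(b.numpy(), np.log(a.numpy()), rtol=1e-5) |
| </pre></div> |
| </div> |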
| </div> |
| <div class="section" id="summary"> |
| <h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline">¶</a></h2> |
| <ul class="simple"> |
| <li><p>TVM can call extern, target-dependent math functions.</p></li> |
| <li><p>Use intrinsics to define a unified interface for these functions.</p></li> |
| <li><p>For more intrinsics available in TVM, take a look at <a class="reference internal" href="../../reference/api/python/tir.html#module-tvm.tir" title="tvm.tir"><code class="xref any py py-mod docutils literal notranslate"><span class="pre">tvm.tir</span></code></a>.</p></li> |
| <li><p>You can customize the intrinsic behavior by defining your own rules; a sketch follows below.</p></li> |
| </ul> |
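| <p>For the last point, a minimal sketch (an assumption, not part of this tutorial): the same rule mechanism could register a CPU lowering for <code class="code docutils literal notranslate"><span class="pre">tir.mylog</span></code>, so that building with the <code class="code docutils literal notranslate"><span class="pre">llvm</span></code> target would also work. The target name and the plain C <code class="code docutils literal notranslate"><span class="pre">logf</span></code>/<code class="code docutils literal notranslate"><span class="pre">log</span></code> externs are assumptions here.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A hedged sketch, reusing the registrations above: lower tir.mylog on the |
| # CPU as well, so tvm.build(s, [A, B], "llvm") would not fail on the op. |
| def my_llvm_mylog_rule(op): |
|     """Assumed LLVM (CPU) lowering rule for mylog""" |
|     if op.dtype == "float32": |
|         return tvm.tir.call_pure_extern("float32", "logf", op.args[0]) |
|     elif op.dtype == "float64": |
|         return tvm.tir.call_pure_extern("float64", "log", op.args[0]) |
|     return op |
|  |
|  |
| register_intrin_lowering("tir.mylog", target="llvm", f=my_llvm_mylog_rule, level=99) |
| </pre></div> |
| </div> |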
| <div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-work-with-schedules-intrin-math-py"> |
| <div class="sphx-glr-download sphx-glr-download-python docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/d9089082842c138d4c81335f88c60c82/intrin_math.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">intrin_math.py</span></code></a></p> |
| </div> |
| <div class="sphx-glr-download sphx-glr-download-jupyter docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/1e482ba1190961191e3a0bdbd0585faa/intrin_math.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">intrin_math.ipynb</span></code></a></p> |
| </div> |
| </div> |
| <p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p> |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| </div> |
| |
| |
| <footer> |
| |
| <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> |
| |
| <a href="scan.html" class="btn btn-neutral float-right" title="Scan and Recurrent Kernel" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a> |
| |
| |
| <a href="reduction.html" class="btn btn-neutral float-left" title="Reduction" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a> |
| |
| </div> |
| |
| <div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="backtop"/> </div> |
| <section class="footerSec"> |
| <div class="footerHeader"> |
| <div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row"> |
| <div class="copywrite d-flex align-items-center"> |
| <h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div> |
| <div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div> |
| </div> |
| |
| </section> |
| </footer> |
| </div> |
| </div> |
| |
| </section> |
| |
| </div> |
| |
| |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script> |
| <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script> |
| |
| <script type="text/javascript"> |
| jQuery(function () { |
| SphinxRtdTheme.Navigation.enable(true); |
| }); |
| </script> |
| |
| |
| |
| |
| <!-- Theme Analytics --> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-75982049-2', 'auto'); |
| ga('send', 'pageview'); |
| </script> |
| |
| |
| |
| |
| </body> |
| </html> |