<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Deploy the Pretrained Model on Adreno™ &mdash; tvm 0.17.dev0 documentation</title>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" />
<link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/>
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Deploy the Pretrained Model on Adreno™ with tvmc Interface" href="deploy_model_on_adreno_tvmc.html" />
<link rel="prev" title="Deploy Deep Learning Models" href="index.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<header class="header">
<div class="innercontainer">
<div class="headerInner d-flex justify-content-between align-items-center">
<div class="headerLogo">
<a href="https://tvm.apache.org/"><img src=https://tvm.apache.org/assets/images/logo.svg alt="logo"></a>
</div>
<div id="headMenu" class="headerNav">
<button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button>
<ul class="nav">
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/community>Community</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/download>Download</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/vta>VTA</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/blog>Blog</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/docs>Docs</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvmconf.org>Conference</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://github.com/apache/tvm/>Github</a>
</li>
</ul>
<div class="responsivetlcdropdown">
<button type="button" class="btn-link">
ASF
</button>
<ul>
<li>
<a href=https://apache.org/>Apache Homepage</a>
</li>
<li>
<a href=https://www.apache.org/licenses/>License</a>
</li>
<li>
<a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
</li>
<li>
<a href=https://www.apache.org/security/>Security</a>
</li>
<li>
<a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
</li>
<li>
<a href=https://www.apache.org/events/current-event>Events</a>
</li>
</ul>
</div>
</div>
<div class="responsiveMenuIcon">
<button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button>
</div>
<div class="tlcDropdown">
<div class="dropdown">
<button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">
ASF
</button>
<div class="dropdown-menu dropdown-menu-right">
<ul>
<li>
<a href=https://apache.org/>Apache Homepage</a>
</li>
<li>
<a href=https://www.apache.org/licenses/>License</a>
</li>
<li>
<a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
</li>
<li>
<a href=https://www.apache.org/security/>Security</a>
</li>
<li>
<a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
</li>
<li>
<a href=https://www.apache.org/events/current-event>Events</a>
</li>
</ul>
</div>
</div>
</div>
</div>
</div>
</header>
<nav data-toggle="wy-nav-shift" class="wy-nav-side fixed">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html">
<img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/>
</a>
<input type="checkbox" class="version-toggle-box" hidden id="version-toggle">
<label for="version-toggle" class="version-toggle-label">
<div tabindex="0" class="version version-selector version-selector-show">
0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span>
</div>
</label>
<div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Versions</span></p>
<ol style="text-align: left">
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li>
</ol>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#build-the-tvm-runtime-library">Build the TVM runtime library</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#cross-compile-the-tvm-runtime-for-other-architectures">Cross compile the TVM runtime for other architectures</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#optimize-and-tune-models-for-target-devices">Optimize and tune models for target devices</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#deploy-optimized-model-on-target-devices">Deploy optimized model on target devices</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../deploy/index.html#additional-deployment-how-tos">Additional Deployment How-Tos</a><ul class="current">
<li class="toctree-l4 current"><a class="reference internal" href="index.html">Deploy Deep Learning Models</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li>
<li class="toctree-l2"><a class="reference internal" href="../optimize_operators/index.html">Optimize Tensor Operators</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top">
<div class="togglemenu">
</div>
<div class="nav-content">
<!-- tvm -->
Table of Contents
</div>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html">Docs</a> <span class="br-arrow">></span></li>
<li><a href="../index.html">How To Guides</a> <span class="br-arrow">></span></li>
<li><a href="../deploy/index.html">Deploy Models and Integrate TVM</a> <span class="br-arrow">></span></li>
<li><a href="index.html">Deploy Deep Learning Models</a> <span class="br-arrow">></span></li>
<li>Deploy the Pretrained Model on Adreno™</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/apache/tvm/edit/main/docs/how_to/deploy_models/deploy_model_on_adreno.rst" class="fa fa-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>This tutorial can be used interactively with Google Colab! You can also click
<a class="reference internal" href="#sphx-glr-download-how-to-deploy-models-deploy-model-on-adreno-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p>
<a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/b9e7311d8c56eb6e6aca08f0be35ff03/deploy_model_on_adreno.ipynb"><img alt="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300px" /></a>
</div>
<div class="sphx-glr-example-title section" id="deploy-the-pretrained-model-on-adreno">
<span id="tutorial-deploy-model-on-adreno"></span><span id="sphx-glr-how-to-deploy-models-deploy-model-on-adreno-py"></span><h1>Deploy the Pretrained Model on Adreno™<a class="headerlink" href="#deploy-the-pretrained-model-on-adreno" title="Permalink to this headline"></a></h1>
<p><strong>Authors</strong>: Daniil Barinov, Siva Rama Krishna</p>
<p>This article is a step-by-step tutorial on deploying a pretrained PyTorch ResNet-18 model on Adreno (with different precisions).</p>
<p>To begin, PyTorch must be installed.
TorchVision is also required, since we will use it as our model zoo.</p>
<p>A quick solution is to install it via pip:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip<span class="w"> </span>install<span class="w"> </span>torch
pip<span class="w"> </span>install<span class="w"> </span>torchvision
</pre></div>
</div>
<p>Besides that, you should have TVM built for Android.
See the following instructions on how to build it:</p>
<p><a class="reference external" href="https://tvm.apache.org/docs/how_to/deploy/adreno.html">Deploy to Adreno GPU</a></p>
<p>After the build, there should be two files in the <em>build</em> directory: "libtvm_runtime.so" and "tvm_rpc".
Let’s push them to the device and run the TVM RPC Server.</p>
<div class="section" id="tvm-rpc-server">
<h2>TVM RPC Server<a class="headerlink" href="#tvm-rpc-server" title="Permalink to this headline"></a></h2>
<p>To get the serial number (device hash) of the device, use:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>adb<span class="w"> </span>devices
</pre></div>
</div>
<p>If you have several devices connected to your computer, set which Android device to use:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span><span class="w"> </span><span class="nv">ANDROID_SERIAL</span><span class="o">=</span>&lt;device-hash&gt;
</pre></div>
</div>
<p>Then, to upload these two files to the device, use:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>adb<span class="w"> </span>push<span class="w"> </span><span class="o">{</span>libtvm_runtime.so,tvm_rpc<span class="o">}</span><span class="w"> </span>/data/local/tmp
</pre></div>
</div>
<p>At this point you will have "libtvm_runtime.so" and "tvm_rpc" under /data/local/tmp on your device.
Sometimes cmake cannot find "libc++_shared.so". Use:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>find<span class="w"> </span><span class="si">${</span><span class="nv">ANDROID_NDK_HOME</span><span class="si">}</span><span class="w"> </span>-name<span class="w"> </span>libc++_shared.so
</pre></div>
</div>
<p>to locate it, then push it to the desired device with adb as well:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>adb<span class="w"> </span>push<span class="w"> </span>libc++_shared.so<span class="w"> </span>/data/local/tmp
</pre></div>
</div>
<p>We are now ready to run the TVM RPC Server.
Launch the RPC tracker with the following command in the first console:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>python3<span class="w"> </span>-m<span class="w"> </span>tvm.exec.rpc_tracker<span class="w"> </span>--port<span class="w"> </span><span class="m">9190</span>
</pre></div>
</div>
<p>Then run the tvm_rpc server on the device from a second console:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>adb<span class="w"> </span>reverse<span class="w"> </span>tcp:9190<span class="w"> </span>tcp:9190
adb<span class="w"> </span>forward<span class="w"> </span>tcp:5000<span class="w"> </span>tcp:5000
adb<span class="w"> </span>forward<span class="w"> </span>tcp:5002<span class="w"> </span>tcp:5001
adb<span class="w"> </span>forward<span class="w"> </span>tcp:5003<span class="w"> </span>tcp:5002
adb<span class="w"> </span>forward<span class="w"> </span>tcp:5004<span class="w"> </span>tcp:5003
adb<span class="w"> </span>shell<span class="w"> </span><span class="nv">LD_LIBRARY_PATH</span><span class="o">=</span>/data/local/tmp<span class="w"> </span>/data/local/tmp/tvm_rpc<span class="w"> </span>server<span class="w"> </span>--host<span class="o">=</span><span class="m">0</span>.0.0.0<span class="w"> </span>--port<span class="o">=</span><span class="m">5000</span><span class="w"> </span>--tracker<span class="o">=</span><span class="m">127</span>.0.0.1:9190<span class="w"> </span>--key<span class="o">=</span>android<span class="w"> </span>--port-end<span class="o">=</span><span class="m">5100</span>
</pre></div>
</div>
<p>Before proceeding to compile and run the model, specify TVM_TRACKER_HOST and TVM_TRACKER_PORT:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span><span class="w"> </span><span class="nv">TVM_TRACKER_HOST</span><span class="o">=</span><span class="m">0</span>.0.0.0
<span class="nb">export</span><span class="w"> </span><span class="nv">TVM_TRACKER_PORT</span><span class="o">=</span><span class="m">9190</span>
</pre></div>
</div>
<p>Check that the tracker is running and the device is available:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>python<span class="w"> </span>-m<span class="w"> </span>tvm.exec.query_rpc_tracker<span class="w"> </span>--port<span class="w"> </span><span class="m">9190</span>
</pre></div>
</div>
<p>For example, if we have one Android device,
the output may look like this:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>Queue<span class="w"> </span>Status
----------------------------------
key<span class="w"> </span>total<span class="w"> </span>free<span class="w"> </span>pending
----------------------------------
android<span class="w"> </span><span class="m">1</span><span class="w"> </span><span class="m">1</span><span class="w"> </span><span class="m">0</span>
----------------------------------
</pre></div>
</div>
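<p>Optionally, the same check can be done from Python. The snippet below is a minimal sketch, assuming the tracker is reachable at the host and port exported above and that the device registered itself under the "android" key:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Hedged sketch (not part of the original tutorial flow): request a session from
# the tracker and verify that an OpenCL (Adreno) device is visible on the phone.
import os
from tvm import rpc

tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
tracker = rpc.connect_tracker(tracker_host, tracker_port)
remote = tracker.request("android", priority=0, session_timeout=60)
print(remote.cl().exist)  # True when the remote OpenCL device is available
</pre></div>
</div>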
</div>
<div class="section" id="configuration">
<h2>Configuration<a class="headerlink" href="#configuration" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torchvision</span>
<span class="kn">import</span> <span class="nn">tvm</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">te</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">relay</span><span class="p">,</span> <span class="n">rpc</span>
<span class="kn">from</span> <span class="nn">tvm.contrib</span> <span class="kn">import</span> <span class="n">utils</span><span class="p">,</span> <span class="n">ndk</span>
<span class="kn">from</span> <span class="nn">tvm.contrib</span> <span class="kn">import</span> <span class="n">graph_executor</span>
<span class="kn">from</span> <span class="nn">tvm.relay.op.contrib</span> <span class="kn">import</span> <span class="n">clml</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">autotvm</span>
<span class="c1"># Below are set of configuration that controls the behaviour of this script like</span>
<span class="c1"># local run or device run, target definitions, dtype setting and auto tuning enablement.</span>
<span class="c1"># Change these settings as needed if required.</span>
<span class="c1"># Adreno devices are efficient with float16 compared to float32</span>
<span class="c1"># Given the expected output doesn&#39;t effect by lowering precision</span>
<span class="c1"># it&#39;s advisable to use lower precision.</span>
<span class="c1"># We have a helper API to make the precision conversion simple and</span>
<span class="c1"># it supports dtype with &quot;float16&quot; and &quot;float16_acc32&quot; modes.</span>
<span class="c1"># Let&#39;s choose &quot;float16&quot; for calculation and &quot;float32&quot; for accumulation.</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">calculation_dtype</span></a> <span class="o">=</span> <span class="s2">&quot;float16&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">acc_dtype</span></a> <span class="o">=</span> <span class="s2">&quot;float32&quot;</span>
<span class="c1"># Specify Adreno target before compiling to generate texture</span>
<span class="c1"># leveraging kernels and get all the benefits of textures</span>
<span class="c1"># Note: This generated example running on our x86 server for demonstration.</span>
<span class="c1"># If running it on the Android device, we need to</span>
<span class="c1"># specify its instruction set. Set :code:`local_demo` to False if you want</span>
<span class="c1"># to run this tutorial with a real device over rpc.</span>
<a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a> <span class="o">=</span> <span class="kc">True</span>
<span class="c1"># by default on CPU target will execute.</span>
<span class="c1"># select &#39;cpu&#39;, &#39;opencl&#39; and &#39;opencl -device=adreno&#39;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">test_target</span></a> <span class="o">=</span> <span class="s2">&quot;cpu&quot;</span>
<span class="c1"># Change target configuration.</span>
<span class="c1"># Run `adb shell cat /proc/cpuinfo` to find the arch.</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arch</span></a> <span class="o">=</span> <span class="s2">&quot;arm64&quot;</span>
<a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a> <span class="o">=</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">target</span><span class="o">.</span><span class="n">Target</span></a><span class="p">(</span><span class="s2">&quot;llvm -mtriple=</span><span class="si">%s</span><span class="s2">-linux-android&quot;</span> <span class="o">%</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arch</span></a><span class="p">)</span>
<span class="c1"># Auto tuning is compute intensive and time taking task,</span>
<span class="c1"># hence disabling for default run. Please enable it if required.</span>
<a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">is_tuning</span></a> <span class="o">=</span> <span class="kc">False</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tune_log</span></a> <span class="o">=</span> <span class="s2">&quot;adreno-resnet18.log&quot;</span>
<span class="c1"># To enable OpenCLML accelerated operator library.</span>
<a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">enable_clml</span></a> <span class="o">=</span> <span class="kc">False</span>
</pre></div>
</div>
</div>
<div class="section" id="get-a-pytorch-model">
<h2>Get a PyTorch Model<a class="headerlink" href="#get-a-pytorch-model" title="Permalink to this headline"></a></h2>
<p>Get resnet18 from torchvision models</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">model_name</span></a> <span class="o">=</span> <span class="s2">&quot;resnet18&quot;</span>
<span class="n">model</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">torchvision</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">model_name</span></a><span class="p">)(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="c1"># We grab the TorchScripted model via tracing</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_shape</span></a> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">]</span>
<span class="n">input_data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_shape</span></a><span class="p">)</span>
<span class="n">scripted_model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">input_data</span><span class="p">)</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter &#39;pretrained&#39; is deprecated since 0.13 and may be removed in the future, please use &#39;weights&#39; instead.
warnings.warn(
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for &#39;weights&#39; are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
</pre></div>
</div>
</div>
<div class="section" id="load-a-test-image">
<h2>Load a test image<a class="headerlink" href="#load-a-test-image" title="Permalink to this headline"></a></h2>
<p>As an example, we will use the classic cat image from ImageNet.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="kn">from</span> <span class="nn">tvm.contrib.download</span> <span class="kn">import</span> <span class="n">download_testdata</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_url</span></a> <span class="o">=</span> <span class="s2">&quot;https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_path</span></a> <span class="o">=</span> <span class="n">download_testdata</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_url</span></a><span class="p">,</span> <span class="s2">&quot;cat.png&quot;</span><span class="p">,</span> <span class="n">module</span><span class="o">=</span><span class="s2">&quot;data&quot;</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_path</span></a><span class="p">)</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
<span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="c1"># Preprocess the image and convert to tensor</span>
<span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span>
<span class="n">my_preprocess</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span>
<span class="p">[</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="mi">256</span><span class="p">),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="mi">224</span><span class="p">),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="p">[</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">],</span> <span class="n">std</span><span class="o">=</span><span class="p">[</span><span class="mf">0.229</span><span class="p">,</span> <span class="mf">0.224</span><span class="p">,</span> <span class="mf">0.225</span><span class="p">]),</span>
<span class="p">]</span>
<span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">my_preprocess</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</pre></div>
</div>
<img src="../../_images/sphx_glr_deploy_model_on_adreno_001.png" srcset="../../_images/sphx_glr_deploy_model_on_adreno_001.png" alt="deploy model on adreno" class = "sphx-glr-single-img"/></div>
<div class="section" id="convert-pytorch-model-to-relay-module">
<h2>Convert PyTorch model to Relay module<a class="headerlink" href="#convert-pytorch-model-to-relay-module" title="Permalink to this headline"></a></h2>
<p>TVM has frontend APIs for various frameworks under relay.frontend; for PyTorch model import
we use the relay.frontend.from_pytorch API.
The input name can be arbitrary.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a> <span class="o">=</span> <span class="s2">&quot;input0&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">shape_list</span></a> <span class="o">=</span> <span class="p">[(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img</span><span class="o">.</span><span class="n">shape</span></a><span class="p">)]</span>
<span class="n">mod</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a> <span class="o">=</span> <a href="../../reference/api/python/relay/frontend.html#tvm.relay.frontend.from_pytorch" title="tvm.relay.frontend.from_pytorch" class="sphx-glr-backref-module-tvm-relay-frontend sphx-glr-backref-type-py-function"><span class="n">relay</span><span class="o">.</span><span class="n">frontend</span><span class="o">.</span><span class="n">from_pytorch</span></a><span class="p">(</span><span class="n">scripted_model</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">shape_list</span></a><span class="p">)</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
return LooseVersion(torch_ver) &gt; ver
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
</pre></div>
</div>
</div>
<div class="section" id="precisions">
<h2>Precisions<a class="headerlink" href="#precisions" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Adreno devices are efficient with float16 compared to float32</span>
<span class="c1"># Given the expected output doesn&#39;t effect by lowering precision</span>
<span class="c1"># it&#39;s advisable to use lower precision.</span>
<span class="c1"># TVM support Mixed Precision through ToMixedPrecision transformation pass.</span>
<span class="c1"># We may need to register precision rules like precision type, accumultation</span>
<span class="c1"># datatype ...etc. for the required operators to override the default settings.</span>
<span class="c1"># The below helper api simplifies the precision conversions across the module.</span>
<span class="c1"># Calculation dtype is set to &quot;float16&quot; and accumulation dtype is set to &quot;float32&quot;</span>
<span class="c1"># in configuration section above.</span>
<span class="kn">from</span> <span class="nn">tvm.driver.tvmc.transform</span> <span class="kn">import</span> <span class="n">apply_graph_transforms</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">apply_graph_transforms</span><span class="p">(</span>
<span class="n">mod</span><span class="p">,</span>
<span class="p">{</span>
<span class="s2">&quot;mixed_precision&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
<span class="s2">&quot;mixed_precision_ops&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;nn.conv2d&quot;</span><span class="p">,</span> <span class="s2">&quot;nn.dense&quot;</span><span class="p">],</span>
<span class="s2">&quot;mixed_precision_calculation_type&quot;</span><span class="p">:</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">calculation_dtype</span></a><span class="p">,</span>
<span class="s2">&quot;mixed_precision_acc_type&quot;</span><span class="p">:</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">acc_dtype</span></a><span class="p">,</span>
<span class="p">},</span>
<span class="p">)</span>
</pre></div>
</div>
<p>As you can see in the IR, the architecture now contains cast operations, which are
needed to convert to FP16 precision.
You can also use “float16” or “float32” precisions as other dtype options.</p>
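<p>To see this, you can print the transformed module. The snippet below is a minimal sketch; the second part additionally shows (as an assumption, not used by this tutorial) the generic <code class="docutils literal notranslate"><span class="pre">ToMixedPrecision</span></code> pass that such helpers build upon:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Inspect the converted IR; cast operations should now appear around conv2d / dense.
print(mod["main"])

# Alternative sketch (assumption): the generic ToMixedPrecision pass could be applied
# directly to the module returned by relay.frontend.from_pytorch; per-operator
# accumulation-dtype overrides would then require registering conversion rules.
with tvm.transform.PassContext(opt_level=3):
    mod_fp16 = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
</pre></div>
</div>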
</div>
<div class="section" id="prepare-tvm-target">
<h2>Prepare TVM Target<a class="headerlink" href="#prepare-tvm-target" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># This generated example running on our x86 server for demonstration.</span>
<span class="c1"># To deply and tun on real target over RPC please set :code:`local_demo` to False in above configuration sestion.</span>
<span class="c1"># Also, :code:`test_target` is set to :code:`llvm` as this example to make compatible for x86 demonstration.</span>
<span class="c1"># Please change it to :code:`opencl` or :code:`opencl -device=adreno` for RPC target in configuration above.</span>
<span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a><span class="p">:</span>
<a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a> <span class="o">=</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">target</span><span class="o">.</span><span class="n">Target</span></a><span class="p">(</span><span class="s2">&quot;llvm&quot;</span><span class="p">)</span>
<span class="k">elif</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">test_target</span></a><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;opencl&quot;</span><span class="p">):</span>
<a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a> <span class="o">=</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class"><span class="n">tvm</span><span class="o">.</span><span class="n">target</span><span class="o">.</span><span class="n">Target</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">test_target</span></a><span class="p">,</span> <span class="n">host</span><span class="o">=</span><a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">)</span>
</pre></div>
</div>
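<p>For reference, the sketch below illustrates (as an assumption, not executed in this local demo) what the target pair would look like for a real Adreno device over RPC, with the LLVM Android triple from the configuration section as the host target:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Hedged illustration only: targets typically used for a real Adreno device over RPC,
# i.e. local_demo = False and test_target = "opencl -device=adreno".
host_target = tvm.target.Target("llvm -mtriple=arm64-linux-android")
adreno_target = tvm.target.Target("opencl -device=adreno", host=host_target)
print(adreno_target)
</pre></div>
</div>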
</div>
<div class="section" id="autotuning">
<h2>AutoTuning<a class="headerlink" href="#autotuning" title="Permalink to this headline"></a></h2>
<p>The few instructions below auto-tune the Relay module, using XGBoost as the tuner algorithm.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Auto Tuning process involces stages of extracting the tasks, defining tuning congiguration and</span>
<span class="c1"># tuning each task for best performing kernel configuration.</span>
<span class="c1"># Get RPC related settings.</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_host</span></a> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;TVM_TRACKER_HOST&quot;</span><span class="p">,</span> <span class="s2">&quot;127.0.0.1&quot;</span><span class="p">)</span>
<a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_port</span></a> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;TVM_TRACKER_PORT&quot;</span><span class="p">,</span> <span class="mi">9190</span><span class="p">))</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">key</span></a> <span class="o">=</span> <span class="s2">&quot;android&quot;</span>
<span class="c1"># Auto tuning is compute intensive and time taking task.</span>
<span class="c1"># It is set to False in above configuration as this script runs in x86 for demonstration.</span>
<span class="c1"># Please to set :code:`is_tuning` to True to enable auto tuning.</span>
<span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">is_tuning</span></a><span class="p">:</span>
<span class="c1"># Auto Tuning Stage 1: Extract tunable tasks</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">task</span><span class="o">.</span><span class="n">extract_from_program</span><span class="p">(</span>
<span class="n">mod</span><span class="p">,</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">test_target</span></a><span class="p">,</span> <span class="n">target_host</span><span class="o">=</span><a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a>
<span class="p">)</span>
<span class="c1"># Auto Tuning Stage 2: Define tuning configuration</span>
<span class="n">tmp_log_file</span> <span class="o">=</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tune_log</span></a> <span class="o">+</span> <span class="s2">&quot;.tmp&quot;</span>
<span class="n">measure_option</span> <span class="o">=</span> <span class="n">autotvm</span><span class="o">.</span><span class="n">measure_option</span><span class="p">(</span>
<span class="n">builder</span><span class="o">=</span><span class="n">autotvm</span><span class="o">.</span><span class="n">LocalBuilder</span><span class="p">(</span>
<span class="n">build_func</span><span class="o">=</span><a href="../../reference/api/python/contrib.html#tvm.contrib.ndk.create_shared" title="tvm.contrib.ndk.create_shared" class="sphx-glr-backref-module-tvm-contrib-ndk sphx-glr-backref-type-py-function"><span class="n">ndk</span><span class="o">.</span><span class="n">create_shared</span></a><span class="p">,</span> <span class="n">timeout</span><span class="o">=</span><span class="mi">15</span>
<span class="p">),</span> <span class="c1"># Build the test kernel locally</span>
<span class="n">runner</span><span class="o">=</span><span class="n">autotvm</span><span class="o">.</span><span class="n">RPCRunner</span><span class="p">(</span> <span class="c1"># The runner would be on a remote device.</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">key</span></a><span class="p">,</span> <span class="c1"># RPC Key</span>
<span class="n">host</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_host</span></a><span class="p">,</span> <span class="c1"># Tracker host</span>
<span class="n">port</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_port</span></a><span class="p">),</span> <span class="c1"># Tracker port</span>
<span class="n">number</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="c1"># Number of runs before averaging</span>
<span class="n">timeout</span><span class="o">=</span><span class="mi">600</span><span class="p">,</span> <span class="c1"># RPC Timeout</span>
<span class="p">),</span>
<span class="p">)</span>
<span class="n">n_trial</span> <span class="o">=</span> <span class="mi">1024</span> <span class="c1"># Number of iteration of training before choosing the best kernel config</span>
<span class="n">early_stopping</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Can be enabled to stop tuning while the loss is not minimizing.</span>
<span class="c1"># Auto Tuning Stage 3: Iterate through the tasks and tune.</span>
<span class="kn">from</span> <span class="nn">tvm.autotvm.tuner</span> <span class="kn">import</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">tsk</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">reversed</span><span class="p">(</span><span class="n">tasks</span><span class="p">[:</span><span class="mi">3</span><span class="p">])):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Task:&quot;</span><span class="p">,</span> <span class="n">tsk</span><span class="p">)</span>
<span class="n">prefix</span> <span class="o">=</span> <span class="s2">&quot;[Task </span><span class="si">%2d</span><span class="s2">/</span><span class="si">%2d</span><span class="s2">] &quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">tasks</span><span class="p">))</span>
<span class="c1"># choose tuner</span>
<span class="n">tuner</span> <span class="o">=</span> <span class="s2">&quot;xgb&quot;</span>
<span class="c1"># create tuner</span>
<span class="k">if</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;reg&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_knob&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;reg&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;knob&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_itervar&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;reg&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;itervar&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_curve&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;reg&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;curve&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_knob&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;knob&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_itervar&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;itervar&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_curve&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;curve&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_binary&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank-binary&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_binary_knob&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank-binary&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;knob&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_binary_itervar&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank-binary&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;itervar&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;xgb_rank_binary_curve&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.tuner.XGBTuner" title="tvm.autotvm.tuner.XGBTuner" class="sphx-glr-backref-module-tvm-autotvm-tuner sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">XGBTuner</span></a><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="s2">&quot;rank-binary&quot;</span><span class="p">,</span> <span class="n">feature_type</span><span class="o">=</span><span class="s2">&quot;curve&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;ga&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <span class="n">GATuner</span><span class="p">(</span><span class="n">tsk</span><span class="p">,</span> <span class="n">pop_size</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;random&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <span class="n">RandomTuner</span><span class="p">(</span><span class="n">tsk</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tuner</span> <span class="o">==</span> <span class="s2">&quot;gridsearch&quot;</span><span class="p">:</span>
<span class="n">tuner_obj</span> <span class="o">=</span> <span class="n">GridSearchTuner</span><span class="p">(</span><span class="n">tsk</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid tuner: &quot;</span> <span class="o">+</span> <span class="n">tuner</span><span class="p">)</span>
<span class="n">tsk_trial</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">n_trial</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">tsk</span><span class="o">.</span><span class="n">config_space</span><span class="p">))</span>
<span class="n">tuner_obj</span><span class="o">.</span><span class="n">tune</span><span class="p">(</span>
<span class="n">n_trial</span><span class="o">=</span><span class="n">tsk_trial</span><span class="p">,</span>
<span class="n">early_stopping</span><span class="o">=</span><span class="n">early_stopping</span><span class="p">,</span>
<span class="n">measure_option</span><span class="o">=</span><span class="n">measure_option</span><span class="p">,</span>
<span class="n">callbacks</span><span class="o">=</span><span class="p">[</span>
<span class="n">autotvm</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">progress_bar</span><span class="p">(</span><span class="n">tsk_trial</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="n">prefix</span><span class="p">),</span>
<span class="n">autotvm</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">log_to_file</span><span class="p">(</span><span class="n">tmp_log_file</span><span class="p">),</span>
<span class="p">],</span>
<span class="p">)</span>
<span class="c1"># Auto Tuning Stage 4: Pick the best performing configurations from the overall log.</span>
<a href="../../reference/api/python/autotvm.html#tvm.autotvm.record.pick_best" title="tvm.autotvm.record.pick_best" class="sphx-glr-backref-module-tvm-autotvm-record sphx-glr-backref-type-py-function"><span class="n">autotvm</span><span class="o">.</span><span class="n">record</span><span class="o">.</span><span class="n">pick_best</span></a><span class="p">(</span><span class="n">tmp_log_file</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tune_log</span></a><span class="p">)</span>
</pre></div>
</div>
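<p>The loop above is the core of the per-task tuning stage. Assuming a wrapper function named <code class="docutils literal notranslate"><span class="pre">tune_tasks</span></code> that contains this loop (the name and exact keyword arguments are assumptions here; the full script may differ), it could be driven roughly as follows, reusing the <code class="docutils literal notranslate"><span class="pre">key</span></code>, tracker host/port and log file names used elsewhere in this tutorial:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A hedged sketch, not the tutorial's literal code: extract tunable tasks from
# the Relay module and feed them to the hypothetical tune_tasks wrapper that
# contains the tuner-selection loop shown above.
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)
tune_tasks(
    tasks,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15),
        runner=autotvm.RPCRunner(
            key, host=rpc_tracker_host, port=rpc_tracker_port, number=3, timeout=600
        ),
    ),
    tuner="xgb",
    n_trial=1024,
    early_stopping=None,
    log_filename=tune_log,
)
</pre></div>
</div>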
</div>
<div class="section" id="enable-openclml-offloading">
<h2>Enable OpenCLML Offloading<a class="headerlink" href="#enable-openclml-offloading" title="Permalink to this headline"></a></h2>
<p>OpenCLML offloading tries to accelerate supported operators
by using the OpenCLML proprietary operator library.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># By default :code:`enable_clml` is set to False in above configuration section.</span>
<span class="k">if</span> <span class="ow">not</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a> <span class="ow">and</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">enable_clml</span></a><span class="p">:</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">clml</span><span class="o">.</span><span class="n">partition_for_clml</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="p">)</span>
</pre></div>
</div>
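<p>After partitioning you can check how much of the graph was actually handed over to OpenCLML. The snippet below is only a sketch, assuming the partitioned <code class="docutils literal notranslate"><span class="pre">mod</span></code> from the step above and the <code class="docutils literal notranslate"><span class="pre">Compiler</span></code> attribute that BYOC partitioning attaches to offloaded functions:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A sketch: count the Relay functions annotated for the "clml" external compiler.
clml_funcs = [
    gv.name_hint
    for gv in mod.get_global_vars()
    if mod[gv].attrs and "Compiler" in mod[gv].attrs and mod[gv].attrs["Compiler"] == "clml"
]
print("Subgraphs offloaded to OpenCLML:", len(clml_funcs))
</pre></div>
</div>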
</div>
<div class="section" id="compilation">
<h2>Compilation<a class="headerlink" href="#compilation" title="Permalink to this headline"></a></h2>
<p>Use the tuning cache if it exists.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <a href="https://docs.python.org/3/library/os.path.html#os.path.exists" title="os.path.exists" class="sphx-glr-backref-module-os-path sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tune_log</span></a><span class="p">):</span>
<span class="k">with</span> <a href="../../reference/api/python/autotvm.html#tvm.autotvm.apply_history_best" title="tvm.autotvm.apply_history_best" class="sphx-glr-backref-module-tvm-autotvm sphx-glr-backref-type-py-function"><span class="n">autotvm</span><span class="o">.</span><span class="n">apply_history_best</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tune_log</span></a><span class="p">):</span>
<span class="k">with</span> <a href="../../reference/api/python/ir.html#tvm.transform.PassContext" title="tvm.transform.PassContext" class="sphx-glr-backref-module-tvm-transform sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tvm</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">PassContext</span></a><span class="p">(</span><span class="n">opt_level</span><span class="o">=</span><span class="mi">3</span><span class="p">):</span>
<span class="n">lib</span> <span class="o">=</span> <span class="n">relay</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="o">=</span><a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">with</span> <a href="../../reference/api/python/ir.html#tvm.transform.PassContext" title="tvm.transform.PassContext" class="sphx-glr-backref-module-tvm-transform sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tvm</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">PassContext</span></a><span class="p">(</span><span class="n">opt_level</span><span class="o">=</span><span class="mi">3</span><span class="p">):</span>
<span class="n">lib</span> <span class="o">=</span> <span class="n">relay</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="o">=</span><a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="p">)</span>
</pre></div>
</div>
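<p>As an optional sanity check (a sketch, assuming <code class="docutils literal notranslate"><span class="pre">lib</span></code> is the factory module returned by <code class="docutils literal notranslate"><span class="pre">relay.build</span></code> above), you can list the device modules imported into the compiled library; an OpenCL module should be present for the Adreno target, and a CLML runtime module when offloading is enabled:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A sketch: inspect which device-specific runtime modules were generated.
print([m.type_key for m in lib.get_lib().imported_modules])
</pre></div>
</div>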
</div>
<div class="section" id="deploy-the-model-remotely-by-rpc">
<h2>Deploy the Model Remotely by RPC<a class="headerlink" href="#deploy-the-model-remotely-by-rpc" title="Permalink to this headline"></a></h2>
<p>Using RPC you can deploy the model from the host
machine to the remote Adreno™ device.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a><span class="p">:</span>
<a href="../../reference/api/python/rpc.html#tvm.rpc.LocalSession" title="tvm.rpc.LocalSession" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">remote</span></a> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.LocalSession" title="tvm.rpc.LocalSession" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-class"><span class="n">rpc</span><span class="o">.</span><span class="n">LocalSession</span></a><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tracker</span> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.connect_tracker" title="tvm.rpc.connect_tracker" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-function"><span class="n">rpc</span><span class="o">.</span><span class="n">connect_tracker</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_host</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rpc_tracker_port</span></a><span class="p">)</span>
<span class="c1"># When running a heavy model, we should increase the `session_timeout`</span>
<a href="../../reference/api/python/rpc.html#tvm.rpc.LocalSession" title="tvm.rpc.LocalSession" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">remote</span></a> <span class="o">=</span> <span class="n">tracker</span><span class="o">.</span><span class="n">request</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">key</span></a><span class="p">,</span> <span class="n">priority</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">session_timeout</span><span class="o">=</span><span class="mi">60</span><span class="p">)</span>
<span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a><span class="p">:</span>
<span class="n">dev</span> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.RPCSession.cpu" title="tvm.rpc.RPCSession.cpu" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-method"><span class="n">remote</span><span class="o">.</span><span class="n">cpu</span></a><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">elif</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">test_target</span></a><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;opencl&quot;</span><span class="p">):</span>
<span class="n">dev</span> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.RPCSession.cl" title="tvm.rpc.RPCSession.cl" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-method"><span class="n">remote</span><span class="o">.</span><span class="n">cl</span></a><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">dev</span> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.RPCSession.cpu" title="tvm.rpc.RPCSession.cpu" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-method"><span class="n">remote</span><span class="o">.</span><span class="n">cpu</span></a><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<a href="../../reference/api/python/contrib.html#tvm.contrib.utils.TempDirectory" title="tvm.contrib.utils.TempDirectory" class="sphx-glr-backref-module-tvm-contrib-utils sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">temp</span></a> <span class="o">=</span> <a href="../../reference/api/python/contrib.html#tvm.contrib.utils.tempdir" title="tvm.contrib.utils.tempdir" class="sphx-glr-backref-module-tvm-contrib-utils sphx-glr-backref-type-py-function"><span class="n">utils</span><span class="o">.</span><span class="n">tempdir</span></a><span class="p">()</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary</span></a> <span class="o">=</span> <span class="s2">&quot;dev_lib_cl.so&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary_path</span></a> <span class="o">=</span> <a href="../../reference/api/python/contrib.html#tvm.contrib.utils.TempDirectory.relpath" title="tvm.contrib.utils.TempDirectory.relpath" class="sphx-glr-backref-module-tvm-contrib-utils sphx-glr-backref-type-py-method"><span class="n">temp</span><span class="o">.</span><span class="n">relpath</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary</span></a><span class="p">)</span>
<span class="n">fcompile</span> <span class="o">=</span> <a href="../../reference/api/python/contrib.html#tvm.contrib.ndk.create_shared" title="tvm.contrib.ndk.create_shared" class="sphx-glr-backref-module-tvm-contrib-ndk sphx-glr-backref-type-py-function"><span class="n">ndk</span><span class="o">.</span><span class="n">create_shared</span></a> <span class="k">if</span> <span class="ow">not</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">local_demo</span></a> <span class="k">else</span> <span class="kc">None</span>
<span class="n">lib</span><span class="o">.</span><span class="n">export_library</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary_path</span></a><span class="p">,</span> <span class="n">fcompile</span><span class="o">=</span><span class="n">fcompile</span><span class="p">)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">remote_path</span></a> <span class="o">=</span> <span class="s2">&quot;/data/local/tmp/&quot;</span> <span class="o">+</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary</span></a>
<a href="../../reference/api/python/rpc.html#tvm.rpc.RPCSession.upload" title="tvm.rpc.RPCSession.upload" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-method"><span class="n">remote</span><span class="o">.</span><span class="n">upload</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary_path</span></a><span class="p">)</span>
<a href="../../reference/api/python/runtime.html#tvm.runtime.Module" title="tvm.runtime.Module" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rlib</span></a> <span class="o">=</span> <a href="../../reference/api/python/rpc.html#tvm.rpc.RPCSession.load_module" title="tvm.rpc.RPCSession.load_module" class="sphx-glr-backref-module-tvm-rpc sphx-glr-backref-type-py-method"><span class="n">remote</span><span class="o">.</span><span class="n">load_module</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">dso_binary</span></a><span class="p">)</span>
<a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule" title="tvm.contrib.graph_executor.GraphModule" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">m</span></a> <span class="o">=</span> <a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule" title="tvm.contrib.graph_executor.GraphModule" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-class"><span class="n">graph_executor</span><span class="o">.</span><span class="n">GraphModule</span></a><span class="p">(</span><a href="../../reference/api/python/runtime.html#tvm.runtime.Module" title="tvm.runtime.Module" class="sphx-glr-backref-module-tvm-runtime sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">rlib</span></a><span class="p">[</span><span class="s2">&quot;default&quot;</span><span class="p">](</span><span class="n">dev</span><span class="p">))</span>
</pre></div>
</div>
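<p>Before running anything it can be useful to confirm that the requested device is reachable through the RPC session. This optional check only uses the <code class="docutils literal notranslate"><span class="pre">dev</span></code> handle created above:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Optional sanity check: `exist` is False if the runtime cannot see the device,
# and the extra attributes help confirm you are talking to the Adreno GPU
# rather than a CPU fallback.
print("Device exists:", dev.exist)
print("Device name:", dev.device_name)
print("Max threads per block:", dev.max_threads_per_block)
</pre></div>
</div>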
</div>
<div class="section" id="run-inference">
<h2>Run inference<a class="headerlink" href="#run-inference" title="Permalink to this headline"></a></h2>
<p>We can now set the inputs, run inference on our model and get the predictions as output.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule.set_input" title="tvm.contrib.graph_executor.GraphModule.set_input" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-method"><span class="n">m</span><span class="o">.</span><span class="n">set_input</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a><span class="p">,</span> <a href="../../reference/api/python/ndarray.html#tvm.nd.array" title="tvm.nd.array" class="sphx-glr-backref-module-tvm-nd sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span></a><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)))</span>
<a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule.run" title="tvm.contrib.graph_executor.GraphModule.run" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-method"><span class="n">m</span><span class="o">.</span><span class="n">run</span></a><span class="p">()</span>
<span class="n">tvm_output</span> <span class="o">=</span> <a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule.get_output" title="tvm.contrib.graph_executor.GraphModule.get_output" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-method"><span class="n">m</span><span class="o">.</span><span class="n">get_output</span></a><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
</pre></div>
</div>
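<p>The value returned by <code class="docutils literal notranslate"><span class="pre">get_output</span></code> is a <code class="docutils literal notranslate"><span class="pre">tvm.nd.NDArray</span></code>; for post-processing it is usually converted to NumPy, as the next section does. A minimal sketch:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A minimal sketch: for an ImageNet classifier the output is a batch of
# class scores; convert it to NumPy to inspect its shape and values.
scores = tvm_output.numpy()[0]
print("Output shape:", scores.shape)
print("Highest score:", scores.max())
</pre></div>
</div>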
</div>
<div class="section" id="get-predictions-and-performance-statistic">
<h2>Get predictions and performance statistics<a class="headerlink" href="#get-predictions-and-performance-statistic" title="Permalink to this headline"></a></h2>
<p>This piece of code displays the top-1 and top-5 predictions, as
well as information about the model’s performance.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">os.path</span> <span class="kn">import</span> <a href="https://docs.python.org/3/library/os.path.html#os.path.join" title="os.path.join" class="sphx-glr-backref-module-os-path sphx-glr-backref-type-py-function"><span class="n">join</span></a><span class="p">,</span> <span class="n">isfile</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">tvm.contrib</span> <span class="kn">import</span> <span class="n">download</span>
<span class="c1"># Download ImageNet categories</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_url</span></a> <span class="o">=</span> <span class="s2">&quot;https://github.com/uwsampl/web-data/raw/main/vta/models/&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_fn</span></a> <span class="o">=</span> <span class="s2">&quot;synset.txt&quot;</span>
<span class="n">download</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><a href="https://docs.python.org/3/library/os.path.html#os.path.join" title="os.path.join" class="sphx-glr-backref-module-os-path sphx-glr-backref-type-py-function"><span class="n">join</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_url</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_fn</span></a><span class="p">),</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_fn</span></a><span class="p">)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="nb">open</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">categ_fn</span></a><span class="p">)</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
<span class="n">top_categories</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">tvm_output</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">top5</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">top_categories</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)[:</span><span class="mi">5</span><span class="p">]</span>
<span class="c1"># Report top-1 classification result</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Top-1 id: </span><span class="si">{}</span><span class="s2">, class name: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">top5</span><span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span><span class="p">],</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]]))</span>
<span class="c1"># Report top-5 classification results</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Top5 predictions: </span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">#1:&quot;</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">#2:&quot;</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">2</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">#3:&quot;</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">3</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">#4:&quot;</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">4</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">#5:&quot;</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">top5</span><span class="p">[</span><span class="mi">5</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">top5</span><span class="p">)</span>
<a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ImageNetClassifier</span></a> <span class="o">=</span> <span class="kc">False</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">top_categories</span><span class="p">[</span><span class="o">-</span><span class="mi">5</span><span class="p">:]:</span>
<span class="k">if</span> <span class="s2">&quot;cat&quot;</span> <span class="ow">in</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">synset</span></a><span class="p">[</span><span class="n">k</span><span class="p">]:</span>
<a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ImageNetClassifier</span></a> <span class="o">=</span> <span class="kc">True</span>
<span class="k">assert</span> <a href="https://docs.python.org/3/library/functions.html#bool" title="builtins.bool" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ImageNetClassifier</span></a><span class="p">,</span> <span class="s2">&quot;Failed ImageNet classifier validation check&quot;</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Evaluate inference time cost...&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><a href="../../reference/api/python/graph_executor.html#tvm.contrib.graph_executor.GraphModule.benchmark" title="tvm.contrib.graph_executor.GraphModule.benchmark" class="sphx-glr-backref-module-tvm-contrib-graph_executor sphx-glr-backref-type-py-method"><span class="n">m</span><span class="o">.</span><span class="n">benchmark</span></a><span class="p">(</span><span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">repeat</span><span class="o">=</span><span class="mi">10</span><span class="p">))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/workspace/python/tvm/runtime/ndarray.py:217: DeprecationWarning: NDArray.asnumpy() will be deprecated in TVM v0.8 release. Please use NDArray.numpy() instead.
warnings.warn(
Top-1 id: 281, class name: tabby, tabby cat
Top5 predictions:
#1: tabby, tabby cat
#2: tiger cat
#3: lynx, catamount
#4: red fox, Vulpes vulpes
#5: Egyptian cat
[281 282 287 277 285]
Evaluate inference time cost...
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
4127.8531 4127.8278 4130.5083 4124.9284 1.6534
</pre></div>
</div>
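<p>The summary above is printed from the return value of <code class="docutils literal notranslate"><span class="pre">benchmark</span></code>, which is a <code class="docutils literal notranslate"><span class="pre">BenchmarkResult</span></code>; its fields can also be used programmatically, for example:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># A sketch, assuming `m` and `dev` from the previous steps: BenchmarkResult
# exposes the raw timings plus mean, median, max, min and std (in seconds).
result = m.benchmark(dev, number=1, repeat=10)
print("Mean inference time: %.2f ms" % (result.mean * 1000))
print("Std dev: %.2f ms" % (result.std * 1000))
</pre></div>
</div>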
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 22.850 seconds)</p>
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-deploy-models-deploy-model-on-adreno-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/2387d8448da213eb625e6b3d916327d4/deploy_model_on_adreno.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">deploy_model_on_adreno.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/b9e7311d8c56eb6e6aca08f0be35ff03/deploy_model_on_adreno.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">deploy_model_on_adreno.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="deploy_model_on_adreno_tvmc.html" class="btn btn-neutral float-right" title="Deploy the Pretrained Model on Adreno™ with tvmc Interface" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="index.html" class="btn btn-neutral float-left" title="Deploy Deep Learning Models" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
<div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="backtop"/> </div>
<section class="footerSec">
<div class="footerHeader">
<div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row">
<div class="copywrite d-flex align-items-center">
<h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5>
</div>
</div>
</div>
<div>
<div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div>
</div>
</section>
</footer>
</div>
</div>
</section>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>