blob: d300543503d900d28c7446bef7171b0e6085e579 [file] [log] [blame]
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Compile PyTorch Object Detection Models &mdash; tvm 0.17.dev0 documentation</title>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" />
<link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/>
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Deploy a Framework-prequantized Model with TVM" href="deploy_prequantized.html" />
<link rel="prev" title="Deploy the Pretrained Model on Raspberry Pi" href="deploy_model_on_rasp.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<header class="header">
<div class="innercontainer">
<div class="headerInner d-flex justify-content-between align-items-center">
<div class="headerLogo">
<a href="https://tvm.apache.org/"><img src=https://tvm.apache.org/assets/images/logo.svg alt="logo"></a>
</div>
<div id="headMenu" class="headerNav">
<button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button>
<ul class="nav">
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/community>Community</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/download>Download</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/vta>VTA</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/blog>Blog</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvm.apache.org/docs>Docs</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://tvmconf.org>Conference</a>
</li>
<li class="nav-item">
<a class="nav-link" href=https://github.com/apache/tvm/>Github</a>
</li>
</ul>
<div class="responsivetlcdropdown">
<button type="button" class="btn-link">
ASF
</button>
<ul>
<li>
<a href=https://apache.org/>Apache Homepage</a>
</li>
<li>
<a href=https://www.apache.org/licenses/>License</a>
</li>
<li>
<a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
</li>
<li>
<a href=https://www.apache.org/security/>Security</a>
</li>
<li>
<a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
</li>
<li>
<a href=https://www.apache.org/events/current-event>Events</a>
</li>
</ul>
</div>
</div>
<div class="responsiveMenuIcon">
<button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button>
</div>
<div class="tlcDropdown">
<div class="dropdown">
<button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">
ASF
</button>
<div class="dropdown-menu dropdown-menu-right">
<ul>
<li>
<a href=https://apache.org/>Apache Homepage</a>
</li>
<li>
<a href=https://www.apache.org/licenses/>License</a>
</li>
<li>
<a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
</li>
<li>
<a href=https://www.apache.org/security/>Security</a>
</li>
<li>
<a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
</li>
<li>
<a href=https://www.apache.org/events/current-event>Events</a>
</li>
</ul>
</div>
</div>
</div>
</div>
</div>
</header>
<nav data-toggle="wy-nav-shift" class="wy-nav-side fixed">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html">
<img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/>
</a>
<input type="checkbox" class="version-toggle-box" hidden id="version-toggle">
<label for="version-toggle" class="version-toggle-label">
<div tabindex="0" class="version version-selector version-selector-show">
0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span>
</div>
</label>
<div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Versions</span></p>
<ol style="text-align: left">
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li>
<li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li>
</ol>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#build-the-tvm-runtime-library">Build the TVM runtime library</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#cross-compile-the-tvm-runtime-for-other-architectures">Cross compile the TVM runtime for other architectures</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#optimize-and-tune-models-for-target-devices">Optimize and tune models for target devices</a></li>
<li class="toctree-l3"><a class="reference internal" href="../deploy/index.html#deploy-optimized-model-on-target-devices">Deploy optimized model on target devices</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../deploy/index.html#additional-deployment-how-tos">Additional Deployment How-Tos</a><ul class="current">
<li class="toctree-l4 current"><a class="reference internal" href="index.html">Deploy Deep Learning Models</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li>
<li class="toctree-l2"><a class="reference internal" href="../optimize_operators/index.html">Optimize Tensor Operators</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top">
<div class="togglemenu">
</div>
<div class="nav-content">
<!-- tvm -->
Table of Contents
</div>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html">Docs</a> <span class="br-arrow">></span></li>
<li><a href="../index.html">How To Guides</a> <span class="br-arrow">></span></li>
<li><a href="../deploy/index.html">Deploy Models and Integrate TVM</a> <span class="br-arrow">></span></li>
<li><a href="index.html">Deploy Deep Learning Models</a> <span class="br-arrow">></span></li>
<li>Compile PyTorch Object Detection Models</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/apache/tvm/edit/main/docs/how_to/deploy_models/deploy_object_detection_pytorch.rst" class="fa fa-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>This tutorial can be used interactively with Google Colab! You can also click
<a class="reference internal" href="#sphx-glr-download-how-to-deploy-models-deploy-object-detection-pytorch-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p>
<a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/399e1d7889ca66b69d51655784827503/deploy_object_detection_pytorch.ipynb"><img alt="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300px" /></a>
</div>
<div class="sphx-glr-example-title section" id="compile-pytorch-object-detection-models">
<span id="sphx-glr-how-to-deploy-models-deploy-object-detection-pytorch-py"></span><h1>Compile PyTorch Object Detection Models<a class="headerlink" href="#compile-pytorch-object-detection-models" title="Permalink to this headline"></a></h1>
<p>This article is an introductory tutorial to deploy PyTorch object
detection models with Relay VM.</p>
<p>For us to begin with, PyTorch should be installed.
TorchVision is also required since we will be using it as our model zoo.</p>
<p>A quick solution is to install via pip</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip<span class="w"> </span>install<span class="w"> </span>torch
pip<span class="w"> </span>install<span class="w"> </span>torchvision
</pre></div>
</div>
<p>or please refer to official site
<a class="reference external" href="https://pytorch.org/get-started/locally/">https://pytorch.org/get-started/locally/</a></p>
<p>PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.</p>
<p>Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
be unstable.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tvm</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">relay</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="kn">import</span> <span class="n">relay</span>
<span class="kn">from</span> <span class="nn">tvm.runtime.vm</span> <span class="kn">import</span> <span class="n">VirtualMachine</span>
<span class="kn">from</span> <span class="nn">tvm.contrib.download</span> <span class="kn">import</span> <span class="n">download_testdata</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">cv2</span>
<span class="c1"># PyTorch imports</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torchvision</span>
</pre></div>
</div>
<div class="section" id="load-pre-trained-maskrcnn-from-torchvision-and-do-tracing">
<h2>Load pre-trained maskrcnn from torchvision and do tracing<a class="headerlink" href="#load-pre-trained-maskrcnn-from-torchvision-and-do-tracing" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a> <span class="o">=</span> <span class="mi">300</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_shape</span></a> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">)</span>
<span class="k">def</span> <span class="nf">do_trace</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">inp</span><span class="p">):</span>
<span class="n">model_trace</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">inp</span><span class="p">)</span>
<span class="n">model_trace</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="k">return</span> <span class="n">model_trace</span>
<span class="k">def</span> <span class="nf">dict_to_tuple</span><span class="p">(</span><span class="n">out_dict</span><span class="p">):</span>
<span class="k">if</span> <span class="s2">&quot;masks&quot;</span> <span class="ow">in</span> <span class="n">out_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">return</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;boxes&quot;</span><span class="p">],</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;scores&quot;</span><span class="p">],</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;labels&quot;</span><span class="p">],</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;masks&quot;</span><span class="p">]</span>
<span class="k">return</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;boxes&quot;</span><span class="p">],</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;scores&quot;</span><span class="p">],</span> <span class="n">out_dict</span><span class="p">[</span><span class="s2">&quot;labels&quot;</span><span class="p">]</span>
<span class="k">class</span> <span class="nc">TraceWrapper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inp</span><span class="p">):</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out</span></a> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dict_to_tuple</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out</span></a><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">model_func</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">detection</span><span class="o">.</span><span class="n">maskrcnn_resnet50_fpn</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TraceWrapper</span><span class="p">(</span><span class="n">model_func</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">inp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">250.0</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">)))</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">out</span></a> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
<span class="n">script_module</span> <span class="o">=</span> <span class="n">do_trace</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">inp</span><span class="p">)</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter &#39;pretrained&#39; is deprecated since 0.13 and may be removed in the future, please use &#39;weights&#39; instead.
warnings.warn(
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for &#39;weights&#39; are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Downloading: &quot;https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth&quot; to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
0%| | 0.00/170M [00:00&lt;?, ?B/s]
4%|3 | 6.30M/170M [00:00&lt;00:03, 49.0MB/s]
6%|6 | 11.0M/170M [00:00&lt;00:03, 47.0MB/s]
9%|9 | 16.0M/170M [00:00&lt;00:03, 44.3MB/s]
13%|#3 | 22.3M/170M [00:00&lt;00:03, 48.7MB/s]
16%|#5 | 27.0M/170M [00:00&lt;00:03, 48.2MB/s]
19%|#8 | 32.0M/170M [00:00&lt;00:03, 41.3MB/s]
24%|##3 | 40.1M/170M [00:00&lt;00:02, 52.9MB/s]
27%|##7 | 46.3M/170M [00:00&lt;00:02, 55.9MB/s]
31%|### | 51.9M/170M [00:01&lt;00:02, 55.3MB/s]
34%|###3 | 57.6M/170M [00:01&lt;00:02, 56.6MB/s]
38%|###7 | 64.0M/170M [00:01&lt;00:02, 46.3MB/s]
42%|####2 | 72.0M/170M [00:01&lt;00:02, 44.6MB/s]
46%|####6 | 78.3M/170M [00:02&lt;00:03, 27.2MB/s]
48%|####8 | 81.9M/170M [00:02&lt;00:03, 25.2MB/s]
52%|#####1 | 88.0M/170M [00:02&lt;00:03, 27.8MB/s]
57%|#####6 | 96.0M/170M [00:02&lt;00:02, 33.1MB/s]
60%|###### | 102M/170M [00:02&lt;00:02, 34.8MB/s]
62%|######2 | 106M/170M [00:02&lt;00:01, 34.1MB/s]
66%|######5 | 112M/170M [00:03&lt;00:01, 37.1MB/s]
70%|######9 | 118M/170M [00:03&lt;00:01, 35.8MB/s]
72%|#######1 | 122M/170M [00:03&lt;00:01, 35.2MB/s]
74%|#######4 | 126M/170M [00:03&lt;00:01, 35.5MB/s]
76%|#######6 | 130M/170M [00:03&lt;00:01, 33.1MB/s]
79%|#######9 | 134M/170M [00:03&lt;00:01, 33.9MB/s]
81%|######## | 138M/170M [00:03&lt;00:01, 32.3MB/s]
85%|########4 | 144M/170M [00:04&lt;00:00, 34.7MB/s]
89%|########9 | 152M/170M [00:04&lt;00:00, 38.6MB/s]
94%|#########4| 160M/170M [00:04&lt;00:00, 40.2MB/s]
99%|#########8| 168M/170M [00:04&lt;00:00, 46.5MB/s]
100%|##########| 170M/170M [00:04&lt;00:00, 39.3MB/s]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/nn/functional.py:3912: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:157: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/ops/boxes.py:159: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/__init__.py:1209: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can&#39;t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert condition, message
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:298: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
torch.tensor(s, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/transform.py:299: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
/ torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py:389: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
</pre></div>
</div>
</div>
<div class="section" id="download-a-test-image-and-pre-process">
<h2>Download a test image and pre-process<a class="headerlink" href="#download-a-test-image-and-pre-process" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_url</span></a> <span class="o">=</span> <span class="p">(</span>
<span class="s2">&quot;https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg&quot;</span>
<span class="p">)</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_path</span></a> <span class="o">=</span> <span class="n">download_testdata</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_url</span></a><span class="p">,</span> <span class="s2">&quot;test_street_small.jpg&quot;</span><span class="p">,</span> <span class="n">module</span><span class="o">=</span><span class="s2">&quot;data&quot;</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">img_path</span></a><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">in_size</span></a><span class="p">))</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span></a><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">img</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="import-the-graph-to-relay">
<h2>Import the graph to Relay<a class="headerlink" href="#import-the-graph-to-relay" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a> <span class="o">=</span> <span class="s2">&quot;input0&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">shape_list</span></a> <span class="o">=</span> <span class="p">[(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_shape</span></a><span class="p">)]</span>
<span class="n">mod</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a> <span class="o">=</span> <a href="../../reference/api/python/relay/frontend.html#tvm.relay.frontend.from_pytorch" title="tvm.relay.frontend.from_pytorch" class="sphx-glr-backref-module-tvm-relay-frontend sphx-glr-backref-type-py-function"><span class="n">relay</span><span class="o">.</span><span class="n">frontend</span><span class="o">.</span><span class="n">from_pytorch</span></a><span class="p">(</span><span class="n">script_module</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">shape_list</span></a><span class="p">)</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
return LooseVersion(torch_ver) &gt; ver
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
/workspace/python/tvm/relay/build_module.py:345: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
warnings.warn(
</pre></div>
</div>
</div>
<div class="section" id="compile-with-relay-vm">
<h2>Compile with Relay VM<a class="headerlink" href="#compile-with-relay-vm" title="Permalink to this headline"></a></h2>
<p>Note: Currently only CPU target is supported. For x86 target, it is
highly recommended to build TVM with Intel MKL and Intel OpenMP to get
best performance, due to the existence of large dense operator in
torchvision rcnn models.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># Add &quot;-libs=mkl&quot; to get best performance on x86 target.</span>
<span class="c1"># For x86 machine supports AVX512, the complete target is</span>
<span class="c1"># &quot;llvm -mcpu=skylake-avx512 -libs=mkl&quot;</span>
<a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a> <span class="o">=</span> <span class="s2">&quot;llvm&quot;</span>
<span class="k">with</span> <a href="../../reference/api/python/ir.html#tvm.transform.PassContext" title="tvm.transform.PassContext" class="sphx-glr-backref-module-tvm-transform sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tvm</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">PassContext</span></a><span class="p">(</span><span class="n">opt_level</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">disabled_pass</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;FoldScaleAxis&quot;</span><span class="p">]):</span>
<span class="n">vm_exec</span> <span class="o">=</span> <span class="n">relay</span><span class="o">.</span><span class="n">vm</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="inference-with-relay-vm">
<h2>Inference with Relay VM<a class="headerlink" href="#inference-with-relay-vm" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="n">vm</span> <span class="o">=</span> <span class="n">VirtualMachine</span><span class="p">(</span><span class="n">vm_exec</span><span class="p">,</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">vm</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="s2">&quot;main&quot;</span><span class="p">,</span> <span class="o">**</span><span class="p">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">input_name</span></a><span class="p">:</span> <span class="n">img</span><span class="p">})</span>
<span class="n">tvm_res</span> <span class="o">=</span> <span class="n">vm</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
</pre></div>
</div>
</div>
<div class="section" id="get-boxes-with-score-larger-than-0-9">
<h2>Get boxes with score larger than 0.9<a class="headerlink" href="#get-boxes-with-score-larger-than-0-9" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">score_threshold</span></a> <span class="o">=</span> <span class="mf">0.9</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">boxes</span></a> <span class="o">=</span> <span class="n">tvm_res</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">valid_boxes</span></a> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">score</span></a> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tvm_res</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()):</span>
<span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">score</span></a> <span class="o">&gt;</span> <a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">score_threshold</span></a><span class="p">:</span>
<a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">valid_boxes</span></a><span class="o">.</span><span class="n">append</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">boxes</span></a><span class="p">[</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">break</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Get </span><span class="si">{}</span><span class="s2"> valid boxes&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">valid_boxes</span></a><span class="p">)))</span>
</pre></div>
</div>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Get 9 valid boxes
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 4 minutes 4.847 seconds)</p>
<div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-deploy-models-deploy-object-detection-pytorch-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/7795da4b258c8feff986668b95ef57ad/deploy_object_detection_pytorch.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">deploy_object_detection_pytorch.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/399e1d7889ca66b69d51655784827503/deploy_object_detection_pytorch.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">deploy_object_detection_pytorch.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="deploy_prequantized.html" class="btn btn-neutral float-right" title="Deploy a Framework-prequantized Model with TVM" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
<a href="deploy_model_on_rasp.html" class="btn btn-neutral float-left" title="Deploy the Pretrained Model on Raspberry Pi" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
</div>
<div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="backtop"/> </div>
<section class="footerSec">
<div class="footerHeader">
<div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row">
<div class="copywrite d-flex align-items-center">
<h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5>
</div>
</div>
</div>
<div>
<div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div>
</div>
</section>
</footer>
</div>
</div>
</section>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
</body>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
<!-- Theme Analytics -->
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-75982049-2', 'auto');
ga('send', 'pageview');
</script>
</body>
</html>