| <!DOCTYPE html> |
| <html class="writer-html5" lang="en" > |
| <head> |
| <meta charset="utf-8"> |
| |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| |
| <title>5. Training Vision Models for microTVM on Arduino — tvm 0.17.dev0 documentation</title> |
| |
| |
| |
| <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"> |
| <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-binder.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-dataframe.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/sg_gallery-rendered-html.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> |
| <link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" /> |
| |
| |
| |
| <link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/> |
| |
| |
| |
| |
| |
| |
| |
| <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script> |
| <script src="../../_static/jquery.js"></script> |
| <script src="../../_static/underscore.js"></script> |
| <script src="../../_static/doctools.js"></script> |
| |
| <script type="text/javascript" src="../../_static/js/theme.js"></script> |
| |
| |
| <script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script> |
| <link rel="index" title="Index" href="../../genindex.html" /> |
| <link rel="search" title="Search" href="../../search.html" /> |
| <link rel="next" title="6. Model Tuning with microTVM" href="micro_autotune.html" /> |
| <link rel="prev" title="4. microTVM PyTorch Tutorial" href="micro_pytorch.html" /> |
| </head> |
| |
| <body class="wy-body-for-nav"> |
| |
| |
| <div class="wy-grid-for-nav"> |
| |
| |
| <header class="header"> |
| <div class="innercontainer"> |
| <div class="headerInner d-flex justify-content-between align-items-center"> |
| <div class="headerLogo"> |
| <a href="https://tvm.apache.org/"><img src=https://tvm.apache.org/assets/images/logo.svg alt="logo"></a> |
| </div> |
| |
| <div id="headMenu" class="headerNav"> |
| <button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button> |
| <ul class="nav"> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/community>Community</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/download>Download</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/vta>VTA</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/blog>Blog</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvm.apache.org/docs>Docs</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://tvmconf.org>Conference</a> |
| </li> |
| <li class="nav-item"> |
| <a class="nav-link" href=https://github.com/apache/tvm/>Github</a> |
| </li> |
| </ul> |
| <div class="responsivetlcdropdown"> |
| <button type="button" class="btn-link"> |
| ASF |
| </button> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| <div class="responsiveMenuIcon"> |
| <button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button> |
| </div> |
| |
| <div class="tlcDropdown"> |
| <div class="dropdown"> |
| <button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false"> |
| ASF |
| </button> |
| <div class="dropdown-menu dropdown-menu-right"> |
| <ul> |
| <li> |
| <a href="https://apache.org/">Apache Homepage</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/licenses/">License</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/security/">Security</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </li> |
| <li> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div> |
| </div> |
| </div> |
| </header> |
| |
| <nav data-toggle="wy-nav-shift" class="wy-nav-side fixed"> |
| <div class="wy-side-scroll"> |
| <div class="wy-side-nav-search" > |
| |
| |
| |
| <a href="../../index.html"> |
| |
| |
| |
| |
| <img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/> |
| |
| </a> |
| |
| |
| |
| |
| <input type="checkbox" class="version-toggle-box" hidden id="version-toggle"> |
| <label for="version-toggle" class="version-toggle-label"> |
| <div tabindex="0" class="version version-selector version-selector-show"> |
| 0.17.dev0 <span class="chevron versions-hidden"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m8 4 8 8-8 8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span><span class="chevron versions-shown"><svg fill="none" height="24" viewBox="0 0 24 24" width="24" xmlns="http://www.w3.org/2000/svg"><path d="m4 8 8 8 8-8" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="2"/></svg></span> |
| </div> |
| </label> |
| <div class="version-details wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| <p class="caption" role="heading"><span class="caption-text">Versions</span></p> |
| <ol style="text-align: left"> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="/">0.17.dev0 (main)</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.8.0/">v0.8.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.9.0/">v0.9.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.10.0/">v0.10.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.11.0/">v0.11.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.12.0/">v0.12.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.13.0/">v0.13.0</a></div></li> |
| |
| |
| |
| |
| <li><div class="version"><a style="font-size: 0.8em; padding: 4px" href="v0.14.0/">v0.14.0</a></div></li> |
| |
| </ol> |
| </div> |
| |
| |
| |
| |
| <div role="search"> |
| <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get"> |
| <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" /> |
| <input type="hidden" name="check_keywords" value="yes" /> |
| <input type="hidden" name="area" value="default" /> |
| </form> |
| </div> |
| |
| |
| </div> |
| |
| |
| <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation"> |
| |
| |
| |
| |
| |
| |
| <p class="caption" role="heading"><span class="caption-text">Getting Started</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing TVM</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">Contributor Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">User Guide</span></p> |
| <ul class="current"> |
| <li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li> |
| <li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current"> |
| <li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">Compile Deep Learning Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">Deploy Models and Integrate TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../optimize_operators/index.html">Optimize Tensor Operators</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li> |
| <li class="toctree-l2 current"><a class="reference internal" href="index.html">Work With microTVM</a><ul class="current"> |
| <li class="toctree-l3"><a class="reference internal" href="micro_tvmc.html">1. microTVM CLI Tool</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_tflite.html">2. microTVM TFLite Tutorial</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_aot.html">3. microTVM Ahead-of-Time (AOT) Compilation</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_pytorch.html">4. microTVM PyTorch Tutorial</a></li> |
| <li class="toctree-l3 current"><a class="current reference internal" href="#">5. Training Vision Models for microTVM on Arduino</a><ul> |
| <li class="toctree-l4"><a class="reference internal" href="#motivation">Motivation</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#downloading-the-data">Downloading the Data</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#loading-the-data">Loading the Data</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#id1">Loading the Data</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#quantization">Quantization</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#compiling-with-tvm-for-arduino">Compiling With TVM For Arduino</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#testing-our-arduino-project">Testing our Arduino Project</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#writing-our-arduino-script">Writing our Arduino Script</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#uploading-to-our-device">Uploading to Our Device</a></li> |
| <li class="toctree-l4"><a class="reference internal" href="#summary">Summary</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_autotune.html">6. Model Tuning with microTVM</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_ethosu.html">7. Running TVM on bare metal Arm(R) Cortex(R)-M55 CPU and Ethos(TM)-U55 NPU with CMSIS-NN</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_mlperftiny.html">8. Creating Your MLPerfTiny Submission with microTVM</a></li> |
| <li class="toctree-l3"><a class="reference internal" href="micro_custom_ide.html">9. Bring microTVM to your own development environment</a></li> |
| </ul> |
| </li> |
| <li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li> |
| <li class="toctree-l2"><a class="reference internal" href="../../faq.html">Frequently Asked Questions</a></li> |
| </ul> |
| </li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Developer Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">Developer How-To Guide</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Architecture Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Topic Guides</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM: TVM on bare-metal</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li> |
| </ul> |
| <p class="caption" role="heading"><span class="caption-text">Reference Guide</span></p> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">Language Reference</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../genindex.html">Index</a></li> |
| </ul> |
| |
| |
| |
| </div> |
| |
| </div> |
| </nav> |
| |
| <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> |
| |
| <nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top"> |
| |
| <div class="togglemenu"> |
| |
| </div> |
| <div class="nav-content"> |
| <!-- tvm --> |
| Table of Contents |
| </div> |
| |
| </nav> |
| |
| |
| <div class="wy-nav-content"> |
| |
| <div class="rst-content"> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <div role="navigation" aria-label="breadcrumbs navigation"> |
| |
| <ul class="wy-breadcrumbs"> |
| |
| <li><a href="../../index.html">Docs</a> <span class="br-arrow">></span></li> |
| |
| <li><a href="../index.html">How To Guides</a> <span class="br-arrow">></span></li> |
| |
| <li><a href="index.html">Work With microTVM</a> <span class="br-arrow">></span></li> |
| |
| <li>5. Training Vision Models for microTVM on Arduino</li> |
| |
| |
| |
| |
| |
| |
| |
| |
| <li class="wy-breadcrumbs-aside"> |
| |
| |
| |
| <a href="https://github.com/apache/tvm/edit/main/docs/how_to/work_with_microtvm/micro_train.rst" class="fa fa-github"> Edit on GitHub</a> |
| |
| |
| |
| </li> |
| |
| </ul> |
| |
| |
| <hr/> |
| </div> |
| <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article"> |
| <div itemprop="articleBody"> |
| |
| <div class="sphx-glr-download-link-note admonition note"> |
| <p class="admonition-title">Note</p> |
| <p>This tutorial can be used interactively with Google Colab! You can also click |
| <a class="reference internal" href="#sphx-glr-download-how-to-work-with-microtvm-micro-train-py"><span class="std std-ref">here</span></a> to run the Jupyter notebook locally.</p> |
| <a class="reference external image-reference" href="https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/a7c7ea4b5017ae70db1f51dd8e6dcd82/micro_train.ipynb"><img alt="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg" width="300px" /></a> |
| </div> |
| <div class="sphx-glr-example-title section" id="training-vision-models-for-microtvm-on-arduino"> |
| <span id="tutorial-micro-train-arduino"></span><span id="sphx-glr-how-to-work-with-microtvm-micro-train-py"></span><h1>5. Training Vision Models for microTVM on Arduino<a class="headerlink" href="#training-vision-models-for-microtvm-on-arduino" title="Permalink to this headline">¶</a></h1> |
| <p><strong>Author</strong>: <a class="reference external" href="https://github.com/guberti">Gavin Uberti</a></p> |
| <p>This tutorial shows how MobileNetV1 models can be trained |
| to fit on embedded devices, and how those models can be |
| deployed to Arduino using TVM.</p> |
| <div class="section" id="motivation"> |
| <h2>Motivation<a class="headerlink" href="#motivation" title="Permalink to this headline">¶</a></h2> |
| <p>When building IoT devices, we often want them to <strong>see and understand</strong> the world around them. |
| This can take many forms, but often a device will want to know if a certain <strong>kind of |
| object</strong> is in its field of vision.</p> |
| <p>For example, a security camera might look for <strong>people</strong>, so it can decide whether to save a video |
| to memory. A traffic light might look for <strong>cars</strong>, so it can judge which lights should change |
| first. Or a forest camera might look for a <strong>kind of animal</strong>, so researchers can estimate how large |
| the animal population is.</p> |
| <p>To make these devices affordable, we would like them to need only a low-cost processor like the |
| <a class="reference external" href="https://www.nordicsemi.com/Products/nRF52840">nRF52840</a> (costing five dollars each on Mouser) or the <a class="reference external" href="https://www.raspberrypi.com/products/rp2040/">RP2040</a> (just $1.45 each!).</p> |
| <p>These devices have very little memory (~250 KB RAM), meaning that no conventional edge AI |
| vision model (like MobileNet or EfficientNet) will be able to run. In this tutorial, we will |
| show how these models can be modified to fit within these constraints. Then, we will use TVM |
| to compile and deploy the model to an Arduino that uses one of these processors.</p> |
| <div class="section" id="installing-the-prerequisites"> |
| <h3>Installing the Prerequisites<a class="headerlink" href="#installing-the-prerequisites" title="Permalink to this headline">¶</a></h3> |
| <p>This tutorial will use TensorFlow, a widely used machine learning library created by Google, |
| to train the model. TensorFlow is a fairly low-level library, however, so we will use the Keras |
| interface to talk to it. We will also use TensorFlow Lite to perform quantization on |
| our model, as TensorFlow by itself does not support this.</p> |
| <p>Once we have our generated model, we will use TVM to compile and test it. To avoid having to |
| build from source, we’ll install <code class="docutils literal notranslate"><span class="pre">tlcpack</span></code> - a community build of TVM. Lastly, we’ll also |
| install <code class="docutils literal notranslate"><span class="pre">imagemagick</span></code> and <code class="docutils literal notranslate"><span class="pre">curl</span></code> to preprocess data:</p> |
| <blockquote> |
| <div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>pip<span class="w"> </span>install<span class="w"> </span>-q<span class="w"> </span>tensorflow<span class="w"> </span>tflite |
| pip<span class="w"> </span>install<span class="w"> </span>-q<span class="w"> </span>tlcpack-nightly<span class="w"> </span>-f<span class="w"> </span>https://tlcpack.ai/wheels |
| apt-get<span class="w"> </span>-qq<span class="w"> </span>install<span class="w"> </span>imagemagick<span class="w"> </span>curl |
| |
| <span class="c1"># Install Arduino CLI and library for Nano 33 BLE</span> |
| curl<span class="w"> </span>-fsSL<span class="w"> </span>https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh<span class="w"> </span><span class="p">|</span><span class="w"> </span>sh |
| /content/bin/arduino-cli<span class="w"> </span>core<span class="w"> </span>update-index |
| /content/bin/arduino-cli<span class="w"> </span>core<span class="w"> </span>install<span class="w"> </span>arduino:mbed_nano |
| </pre></div> |
| </div> |
| </div></blockquote> |
| </div> |
| <div class="section" id="using-the-gpu"> |
| <h3>Using the GPU<a class="headerlink" href="#using-the-gpu" title="Permalink to this headline">¶</a></h3> |
| <p>This tutorial demonstrates training a neural network, which requires a lot of computing power |
| and will go much faster if you have a GPU. If you are viewing this tutorial on Google Colab, you |
| can enable a GPU by going to <strong>Runtime->Change runtime type</strong> and selecting “GPU” as the hardware |
| accelerator. If you are running locally, you can <a class="reference external" href="https://www.tensorflow.org/guide/gpu">follow TensorFlow’s guide</a> instead.</p> |
| <p>We can test our GPU installation with the following code:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> |
| |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">tf</span><span class="o">.</span><span class="n">test</span><span class="o">.</span><span class="n">gpu_device_name</span><span class="p">():</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"No GPU was detected!"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Model training will take much longer (~30 minutes instead of ~5)"</span><span class="p">)</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"GPU detected - you're good to go."</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>No GPU was detected! |
| Model training will take much longer (~30 minutes instead of ~5) |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="choosing-our-work-dir"> |
| <h3>Choosing Our Work Dir<a class="headerlink" href="#choosing-our-work-dir" title="Permalink to this headline">¶</a></h3> |
| <p>We need to pick a directory where our image datasets, trained model, and eventual Arduino sketch |
| will all live. If running on Google Colab, we’ll save everything in <code class="docutils literal notranslate"><span class="pre">/root</span></code> (aka <code class="docutils literal notranslate"><span class="pre">~</span></code>), but you’ll |
| probably want to store it elsewhere if running locally. Note that this variable only affects the Python |
| code - if you change it, you’ll have to adjust the Bash commands to match.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span> |
| |
| <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a> <span class="o">=</span> <span class="s2">"/root"</span> |
| </pre></div> |
| </div> |
| </div> |
| </div> |
| <div class="section" id="downloading-the-data"> |
| <h2>Downloading the Data<a class="headerlink" href="#downloading-the-data" title="Permalink to this headline">¶</a></h2> |
| <p>Convolutional neural networks usually learn by looking at many images, along with labels telling |
| the network what those images are. To get these images, we’ll need a publicly available dataset |
| with thousands of images of all sorts of objects and labels of what’s in each image. We’ll also |
| need a bunch of images that <strong>aren’t</strong> of cars, as we’re trying to distinguish these two classes.</p> |
| <p>In this tutorial, we’ll create a model to detect if an image contains a <strong>car</strong>, but you can use |
| whatever category you like! Just change the source URL below to one containing images of another |
| type of object.</p> |
| <p>To get our car images, we’ll be downloading the <a class="reference external" href="http://ai.stanford.edu/~jkrause/cars/car_dataset.html">Stanford Cars dataset</a>, |
| which contains 16,185 full-color images of cars (we’ll use its 8,144-image training split). We’ll |
| also need images of random things that aren’t cars, so we’ll use the <a class="reference external" href="https://cocodataset.org/#home">COCO 2017</a> validation set (it’s |
| smaller, and thus faster to download, than the full training set; training on the full data set |
| would yield better results). Note that there are some cars in the COCO 2017 data set, but it’s |
| a small enough fraction not to matter - just keep in mind that this will drive down our perceived |
| accuracy slightly.</p> |
| <p>We could use the TensorFlow dataloader utilities, but we’ll instead do it manually to make sure |
| it’s easy to change the datasets being used. We’ll end up with the following file hierarchy:</p> |
| <blockquote> |
| <div><div class="highlight-default notranslate"><div class="highlight"><pre><span></span>/root |
| ├── images |
| │ ├── object |
| │ │ ├── 000001.jpg |
| │ │ │ ... |
| │ │ └── 016185.jpg |
| │ ├── object.tgz |
| │ ├── random |
| │ │ ├── 000000000139.jpg |
| │ │ │ ... |
| │ │ └── 000000581781.jpg |
| │ └── random.zip |
| </pre></div> |
| </div> |
| </div></blockquote> |
| <p>We should also note that the Stanford Cars training set has ~8k images, while the COCO 2017 |
| validation set has 5k images - it is not a 50/50 split! If we wanted to, we could weight these |
| classes differently during training to correct for this, but training will still work if we |
| ignore it. It should take about <strong>2 minutes</strong> to download the Stanford Cars dataset, while the |
| COCO 2017 validation set will take about <strong>1 minute</strong>.</p> |
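| <p>As an aside (this is a sketch, not part of the tutorial’s training code), one way to correct for |
| the imbalance would be to compute per-class weights and pass the resulting dict to |
| <code class="docutils literal notranslate"><span class="pre">model.fit(class_weight=...)</span></code>:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Hypothetical class weights for our ~8k car vs ~5k random image split. |
| # Keys are the class indices Keras assigns from the folder order. |
| NUM_TARGET, NUM_RANDOM = 8144, 5000 |
| total = NUM_TARGET + NUM_RANDOM |
| class_weight = { |
|     0: total / (2 * NUM_TARGET),  # label [1, 0] - our target class (cars) |
|     1: total / (2 * NUM_RANDOM),  # label [0, 1] - everything else |
| } |
| </pre></div> |
| </div> |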
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span> |
| <span class="kn">import</span> <span class="nn">shutil</span> |
| <span class="kn">import</span> <span class="nn">urllib.request</span> |
| |
| <span class="c1"># Download datasets</span> |
| <a href="https://docs.python.org/3/library/os.html#os.makedirs" title="os.makedirs" class="sphx-glr-backref-module-os sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">makedirs</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/os.html#os.makedirs" title="os.makedirs" class="sphx-glr-backref-module-os sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">makedirs</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve" title="urllib.request.urlretrieve" class="sphx-glr-backref-module-urllib-request sphx-glr-backref-type-py-function"><span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span></a><span class="p">(</span> |
| <span class="s2">"https://data.deepai.org/stanfordcars.zip"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/target.zip"</span> |
| <span class="p">)</span> |
| <a href="https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve" title="urllib.request.urlretrieve" class="sphx-glr-backref-module-urllib-request sphx-glr-backref-type-py-function"><span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span></a><span class="p">(</span> |
| <span class="s2">"http://images.cocodataset.org/zips/val2017.zip"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/random.zip"</span> |
| <span class="p">)</span> |
| |
| <span class="c1"># Extract them and rename their folders</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.unpack_archive" title="shutil.unpack_archive" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">unpack_archive</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/target.zip"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.unpack_archive" title="shutil.unpack_archive" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">unpack_archive</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/random.zip"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.move" title="shutil.move" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">move</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/cars_train/cars_train"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/target"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.move" title="shutil.move" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">move</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/downloads/val2017"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/random"</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>'/tmp/tmpsxx5kk25/images/random' |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="loading-the-data"> |
| <h2>Loading the Data<a class="headerlink" href="#loading-the-data" title="Permalink to this headline">¶</a></h2> |
| <p>Currently, our data is stored on-disk as JPG files of various sizes. To train with it, we’ll have |
| to load the images into memory, resize them to be 64x64, and convert them to raw, uncompressed |
| data. Keras’s <code class="docutils literal notranslate"><span class="pre">image_dataset_from_directory</span></code> will take care of most of this, though it loads |
| images such that each pixel value is a float from 0 to 255.</p> |
| <p>We’ll also need to load labels, though Keras will help with this. From our subdirectory structure, |
| it knows the images in <code class="docutils literal notranslate"><span class="pre">/target</span></code> are one class, and those in <code class="docutils literal notranslate"><span class="pre">/random</span></code> another. Setting |
| <code class="docutils literal notranslate"><span class="pre">label_mode='categorical'</span></code> tells Keras to convert these into <strong>categorical labels</strong> - a 2x1 vector |
| that’s either <code class="docutils literal notranslate"><span class="pre">[1,</span> <span class="pre">0]</span></code> for an object of our target class, or <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1]</span></code> for anything else. |
| We’ll also set <code class="docutils literal notranslate"><span class="pre">shuffle=True</span></code> to randomize the order of our examples.</p> |
| <p>We will also <strong>batch</strong> the data - grouping samples into clumps to make our training go faster. |
| Setting <code class="docutils literal notranslate"><span class="pre">batch_size</span> <span class="pre">=</span> <span class="pre">32</span></code> is a decent choice.</p> |
| <p>Lastly, in machine learning we generally want our inputs to be small numbers. We’ll thus use a |
| <code class="docutils literal notranslate"><span class="pre">Rescaling</span></code> layer to change our images such that each pixel is a float between <code class="docutils literal notranslate"><span class="pre">0.0</span></code> and <code class="docutils literal notranslate"><span class="pre">1.0</span></code>, |
| instead of <code class="docutils literal notranslate"><span class="pre">0</span></code> to <code class="docutils literal notranslate"><span class="pre">255</span></code>. We need to be careful not to rescale our categorical labels though, so |
| we’ll use a <code class="docutils literal notranslate"><span class="pre">lambda</span></code> function.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">IMAGE_SIZE</span></a> <span class="o">=</span> <span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> |
| <span class="n">unscaled_dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">image_dataset_from_directory</span><span class="p">(</span> |
| <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images"</span><span class="p">,</span> |
| <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> |
| <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> |
| <span class="n">label_mode</span><span class="o">=</span><span class="s2">"categorical"</span><span class="p">,</span> |
| <span class="n">image_size</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">IMAGE_SIZE</span></a><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">],</span> |
| <span class="p">)</span> |
| <span class="n">rescale</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)</span> |
| <span class="n">full_dataset</span> <span class="o">=</span> <span class="n">unscaled_dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">im</span><span class="p">,</span> <span class="n">lbl</span><span class="p">:</span> <span class="p">(</span><span class="n">rescale</span><span class="p">(</span><span class="n">im</span><span class="p">),</span> <span class="n">lbl</span><span class="p">))</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Found 13144 files belonging to 2 classes. |
| </pre></div> |
| </div> |
| <div class="section" id="what-s-inside-our-dataset"> |
| <h3>What’s Inside Our Dataset?<a class="headerlink" href="#what-s-inside-our-dataset" title="Permalink to this headline">¶</a></h3> |
| <p>Before giving this data set to our neural network, we ought to give it a quick visual inspection. |
| Does the data look properly transformed? Do the labels seem appropriate? And what’s our ratio of |
| objects to other stuff? We can display some examples from our datasets using <code class="docutils literal notranslate"><span class="pre">matplotlib</span></code>:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> |
| |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_target_class</span></a> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><a href="https://docs.python.org/3/library/os.html#os.listdir" title="os.listdir" class="sphx-glr-backref-module-os sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">listdir</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/target/"</span><span class="p">))</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_random_class</span></a> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><a href="https://docs.python.org/3/library/os.html#os.listdir" title="os.listdir" class="sphx-glr-backref-module-os sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">listdir</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/random/"</span><span class="p">))</span> |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/target contains </span><span class="si">{</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_target_class</span></a><span class="si">}</span><span class="s2"> images"</span><span class="p">)</span> |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/images/random contains </span><span class="si">{</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_random_class</span></a><span class="si">}</span><span class="s2"> images"</span><span class="p">)</span> |
| |
| <span class="c1"># Show some samples and their labels</span> |
| <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">SAMPLES_TO_SHOW</span></a> <span class="o">=</span> <span class="mi">10</span> |
| <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> |
| <span class="k">for</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a><span class="p">,</span> <span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">unscaled_dataset</span><span class="o">.</span><span class="n">unbatch</span><span class="p">()):</span> |
| <span class="k">if</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a> <span class="o">>=</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">SAMPLES_TO_SHOW</span></a><span class="p">:</span> |
| <span class="k">break</span> |
| <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">SAMPLES_TO_SHOW</span></a><span class="p">,</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">i</span></a> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> |
| <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> |
| <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">label</span><span class="o">.</span><span class="n">numpy</span><span class="p">()))</span> |
| <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <img src="../../_images/sphx_glr_micro_train_001.png" srcset="../../_images/sphx_glr_micro_train_001.png" alt="[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]" class = "sphx-glr-single-img"/><div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/tmp/tmpsxx5kk25/images/target contains 8144 images |
| /tmp/tmpsxx5kk25/images/random contains 5000 images |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="validating-our-accuracy"> |
| <h3>Validating our Accuracy<a class="headerlink" href="#validating-our-accuracy" title="Permalink to this headline">¶</a></h3> |
| <p>While developing our model, we’ll often want to check how accurate it is (e.g. to see if it |
| improves during training). How do we do this? We could just train it on <em>all</em> of the data, and |
| then ask it to classify that same data. However, our model could cheat by just memorizing all of |
| the samples, which would make it <em>appear</em> to have very high accuracy, but perform very badly in |
| reality. In practice, this “memorizing” is called <strong>overfitting</strong>.</p> |
| <p>To prevent this, we will set aside some of the data (we’ll use 20%) as a <strong>validation set</strong>. Our |
| model will never be trained on validation data - we’ll only use it to check our model’s accuracy.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_batches</span></a> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">full_dataset</span><span class="p">)</span> |
| <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">full_dataset</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">num_batches</span></a> <span class="o">*</span> <span class="mf">0.8</span><span class="p">))</span> |
| <span class="n">validation_dataset</span> <span class="o">=</span> <span class="n">full_dataset</span><span class="o">.</span><span class="n">skip</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span> |
| </pre></div> |
| </div> |
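| <p>As a quick sanity check (a minimal sketch; the exact counts depend on your dataset), we can print |
| the batch counts to confirm the roughly 80/20 split:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Counts are in batches of 32 images, not individual images |
| print(f"Total batches: {num_batches}") |
| print(f"Training batches: {len(train_dataset)}") |
| print(f"Validation batches: {len(validation_dataset)}") |
| </pre></div> |
| </div> |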
| </div> |
| </div> |
| <div class="section" id="id1"> |
| <h2>Loading the Data<a class="headerlink" href="#id1" title="Permalink to this headline">¶</a></h2> |
| <p>In the past decade, <a class="reference external" href="https://en.wikipedia.org/wiki/Convolutional_neural_network">convolutional neural networks</a> have been widely |
| adopted for image classification tasks. State-of-the-art models like <a class="reference external" href="https://arxiv.org/abs/2104.00298">EfficientNet V2</a> are able |
| to perform image classification better than even humans! Unfortunately, these models have tens of |
| millions of parameters, and thus won’t fit on cheap security camera computers.</p> |
| <p>Our applications generally don’t need perfect accuracy - 90% is good enough. We can thus use the |
| older and smaller MobileNet V1 architecture. But this <em>still</em> won’t be small enough - by default, |
| MobileNet V1 with 224x224 inputs and alpha 1.0 takes ~50 MB to just <strong>store</strong>. To reduce the size |
| of the model, there are three knobs we can turn. First, we can reduce the size of the input images |
| from 224x224 to 96x96 or 64x64, and Keras makes it easy to do this. We can also reduce the <strong>alpha</strong> |
| of the model, from 1.0 to 0.25, which downscales the width of the network (and the number of |
| filters) by a factor of four. And if we were really strapped for space, we could reduce the |
| number of <strong>channels</strong> by making our model take grayscale images instead of RGB ones.</p> |
| <p>In this tutorial, we will use an RGB 64x64 input image and alpha 0.25. This is not quite |
| ideal, but it allows the finished model to fit in 192 KB of RAM, while still letting us perform |
| transfer learning using the official TensorFlow source models (if we used alpha &lt; 0.25 or a |
| grayscale input, we wouldn’t be able to do this).</p> |
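| <p>To see how these knobs change model size, here is a minimal sketch (assuming only that TensorFlow |
| is installed; weights are left uninitialized since we only care about the architecture) comparing a |
| few MobileNet V1 configurations:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>import tensorflow as tf |
| |
| # Alpha drives the weight count; the input resolution mainly affects |
| # activation (RAM) usage rather than weight storage. |
| for size, alpha in [(224, 1.0), (96, 0.25), (64, 0.25)]: |
|     m = tf.keras.applications.MobileNet( |
|         input_shape=(size, size, 3), alpha=alpha, weights=None |
|     ) |
|     print(f"{size}x{size}, alpha={alpha}: {m.count_params():,} parameters") |
| </pre></div> |
| </div> |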
| <div class="section" id="what-is-transfer-learning"> |
| <h3>What is Transfer Learning?<a class="headerlink" href="#what-is-transfer-learning" title="Permalink to this headline">¶</a></h3> |
| <p>Deep learning has <a class="reference external" href="https://paperswithcode.com/sota/image-classification-on-imagenet">dominated image classification</a> for a long time, |
| but training neural networks takes a lot of time. When a neural network is trained “from scratch”, |
| its parameters start out randomly initialized, forcing it to learn very slowly how to tell images |
| apart.</p> |
| <p>With transfer learning, we instead start with a neural network that’s <strong>already</strong> good at a |
| specific task. In this example, that task is classifying images from <a class="reference external" href="https://www.image-net.org/">the ImageNet database</a>. This |
| means the network already has some object detection capabilities, and is likely closer to what you |
| want than a random model would be.</p> |
| <p>This works especially well with image processing neural networks like MobileNet. In practice, it |
| turns out the convolutional layers of the model (i.e. the first 90% of the layers) are used for |
| identifying low-level features like lines and shapes - only the last few fully connected layers |
| are used to determine how those shapes make up the objects the network is trying to detect.</p> |
| <p>We can take advantage of this by starting training with a MobileNet model that was trained on |
| ImageNet, and already knows how to identify those lines and shapes. We can then just remove the |
| last few layers from this pretrained model, and add our own final layers. We’ll then train this |
| conglomerate model for a few epochs on our cars vs. non-cars dataset, to adjust the first layers |
| and train the last layers from scratch. This process of training an already-partially-trained |
| model is called <em>fine-tuning</em>.</p> |
| <p>Source MobileNets for transfer learning have been <a class="reference external" href="https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md">pretrained by the TensorFlow folks</a>, so we |
| can just download the one closest to what we want (the 128x128 input model with 0.25 depth scale).</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/os.html#os.makedirs" title="os.makedirs" class="sphx-glr-backref-module-os sphx-glr-backref-type-py-function"><span class="n">os</span><span class="o">.</span><span class="n">makedirs</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models"</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WEIGHTS_PATH</span></a> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/mobilenet_2_5_128_tf.h5"</span> |
| <a href="https://docs.python.org/3/library/urllib.request.html#urllib.request.urlretrieve" title="urllib.request.urlretrieve" class="sphx-glr-backref-module-urllib-request sphx-glr-backref-type-py-function"><span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span></a><span class="p">(</span> |
| <span class="s2">"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_2_5_128_tf.h5"</span><span class="p">,</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WEIGHTS_PATH</span></a><span class="p">,</span> |
| <span class="p">)</span> |
| |
| <span class="n">pretrained</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">MobileNet</span><span class="p">(</span> |
| <span class="n">input_shape</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">IMAGE_SIZE</span></a><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">WEIGHTS_PATH</span></a><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span> |
| <span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| <div class="section" id="modifying-our-network"> |
| <h3>Modifying Our Network<a class="headerlink" href="#modifying-our-network" title="Permalink to this headline">¶</a></h3> |
| <p>As mentioned above, our pretrained model is designed to classify the 1,000 ImageNet categories, |
| but we want to convert it to classify cars. Since only the bottom few layers are task-specific, |
| we’ll <strong>cut off the last five layers</strong> of our original model. In their place we’ll build our own |
| “tail” to the model by performing reshape, dropout, flatten, and softmax operations.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span> |
| |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">InputLayer</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#tuple" title="builtins.tuple" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">IMAGE_SIZE</span></a><span class="p">))</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">pretrained</span><span class="o">.</span><span class="n">inputs</span></a><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">pretrained</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="o">-</span><span class="mi">5</span><span class="p">]</span><span class="o">.</span><span class="n">output</span><span class="p">))</span> |
| |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,)))</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.1</span><span class="p">))</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">())</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">))</span> |
| </pre></div> |
| </div> |
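| <p>If you’re curious exactly which five layers were dropped, a quick optional inspection of the |
| pretrained model shows them - a minimal sketch, assuming the <code class="docutils literal notranslate"><span class="pre">pretrained</span></code> model from above is still in scope:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Optional: list the task-specific layers that were cut off above |
| for layer in pretrained.layers[-5:]: |
|     print(layer.name, layer.output_shape) |
| </pre></div> |
| </div> |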
| </div> |
| <div class="section" id="fine-tuning-our-network"> |
| <h3>Fine Tuning Our Network<a class="headerlink" href="#fine-tuning-our-network" title="Permalink to this headline">¶</a></h3> |
| <p>When training neural networks, we must set a parameter called the <strong>learning rate</strong> that controls |
| how fast our network learns. It must be set carefully - too low, and our network will take |
| forever to train; too high, and our network won’t be able to learn some fine details. Generally |
| for Adam (the optimizer we’re using), <code class="docutils literal notranslate"><span class="pre">0.001</span></code> is a pretty good learning rate (and is what’s |
| recommended in the <a class="reference external" href="https://arxiv.org/abs/1412.6980">original paper</a>). However, in this case |
| <code class="docutils literal notranslate"><span class="pre">0.0005</span></code> seems to work a little better.</p> |
| <p>We’ll also pass the validation set from earlier to <code class="docutils literal notranslate"><span class="pre">model.fit</span></code>. This will evaluate the |
| model on held-out data after each epoch, and let us track how our model is improving. Once training is |
| finished, the model should have a validation accuracy around <code class="docutils literal notranslate"><span class="pre">0.96</span></code> (meaning it was right about 96% of |
| the time on our validation set).</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> |
| <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0005</span><span class="p">),</span> |
| <span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> |
| <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> |
| <span class="p">)</span> |
| <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Epoch 1/3 |
| 328/328 - 48s - loss: 0.2301 - accuracy: 0.9223 - val_loss: 0.1238 - val_accuracy: 0.9535 - 48s/epoch - 146ms/step |
| Epoch 2/3 |
| 328/328 - 45s - loss: 0.1032 - accuracy: 0.9618 - val_loss: 0.1061 - val_accuracy: 0.9607 - 45s/epoch - 137ms/step |
| Epoch 3/3 |
| 328/328 - 45s - loss: 0.0665 - accuracy: 0.9757 - val_loss: 0.1202 - val_accuracy: 0.9603 - 45s/epoch - 136ms/step |
| |
| <keras.callbacks.History object at 0x7f07227a3520> |
| </pre></div> |
| </div> |
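| <p>The <code class="docutils literal notranslate"><span class="pre">History</span></code> object printed above is the return value of <code class="docutils literal notranslate"><span class="pre">model.fit</span></code>, and it records these |
| per-epoch metrics. If you want to visualize them, a minimal sketch (assuming <code class="docutils literal notranslate"><span class="pre">matplotlib</span></code> is |
| installed, and that you capture the return value when calling <code class="docutils literal notranslate"><span class="pre">model.fit</span></code>) might look like:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>import matplotlib.pyplot as plt |
| |
| # Capture the History object instead of discarding it |
| history = model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2) |
| |
| plt.plot(history.history["accuracy"], label="training accuracy") |
| plt.plot(history.history["val_accuracy"], label="validation accuracy") |
| plt.xlabel("epoch") |
| plt.legend() |
| plt.show() |
| </pre></div> |
| </div> |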
| </div> |
| </div> |
| <div class="section" id="quantization"> |
| <h2>Quantization<a class="headerlink" href="#quantization" title="Permalink to this headline">¶</a></h2> |
| <p>We’ve done a decent job of reducing our model’s size so far - changing the input dimension, |
| along with removing the bottom layers, reduced the model to just 219k parameters. However, each of |
| these parameters is a <code class="docutils literal notranslate"><span class="pre">float32</span></code> that takes four bytes, so our model will take up almost one MB!</p> |
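| <p>As a quick sanity check on that estimate - a minimal back-of-the-envelope sketch, assuming the |
| rounded 219k parameter count:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>num_params = 219_000 |
| print(f"float32 size: {num_params * 4 / 1024:.0f} KiB")  # ~855 KiB - almost one MB |
| print(f"int8 size:    {num_params / 1024:.0f} KiB")      # ~214 KiB after quantization |
| </pre></div> |
| </div> |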
| <p>Additionally, it might be the case that our hardware doesn’t have built-in support for floating |
| point numbers. While most high-memory Arduinos (like the Nano 33 BLE) do have hardware support, |
| some others (like the Arduino Due) do not. On any boards <em>without</em> dedicated hardware support, |
| floating point multiplication will be extremely slow.</p> |
| <p>To address both issues we will <strong>quantize</strong> the model - representing the weights as eight bit |
| integers. It’s more complex than just rounding, though - to get the best performance, TensorFlow |
| tracks how each neuron in our model activates, so we can figure out how to most accurately simulate |
| the neuron’s original activations with integer operations.</p> |
| <p>We will help TensorFlow do this by creating a representative dataset - a subset of the original |
| that is used for tracking how those neurons activate. We’ll then pass this into a <code class="docutils literal notranslate"><span class="pre">TFLiteConverter</span></code> |
| (Keras itself does not have quantization support) with an <code class="docutils literal notranslate"><span class="pre">Optimize</span></code> flag to tell TFLite to perform |
| the conversion. By default, TFLite keeps the inputs and outputs of our model as floats, so we must |
| explicitly tell it to avoid this behavior.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">representative_dataset</span><span class="p">():</span> |
| <span class="k">for</span> <span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="ow">in</span> <span class="n">full_dataset</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> |
| <span class="k">yield</span> <span class="p">[</span><span class="n">image_batch</span><span class="p">]</span> |
| |
| |
| <span class="n">converter</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">lite</span><span class="o">.</span><span class="n">TFLiteConverter</span><span class="o">.</span><span class="n">from_keras_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">converter</span><span class="o">.</span><span class="n">optimizations</span></a> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">lite</span><span class="o">.</span><span class="n">Optimize</span><span class="o">.</span><span class="n">DEFAULT</span><span class="p">]</span> |
| <span class="n">converter</span><span class="o">.</span><span class="n">representative_dataset</span> <span class="o">=</span> <span class="n">representative_dataset</span> |
| <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">converter</span><span class="o">.</span><span class="n">target_spec</span><span class="o">.</span><span class="n">supported_ops</span></a> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">lite</span><span class="o">.</span><span class="n">OpsSet</span><span class="o">.</span><span class="n">TFLITE_BUILTINS_INT8</span><span class="p">]</span> |
| <span class="n">converter</span><span class="o">.</span><span class="n">inference_input_type</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">uint8</span> |
| <span class="n">converter</span><span class="o">.</span><span class="n">inference_output_type</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">uint8</span> |
| |
| <a href="https://docs.python.org/3/library/stdtypes.html#bytes" title="builtins.bytes" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">quantized_model</span></a> <span class="o">=</span> <span class="n">converter</span><span class="o">.</span><span class="n">convert</span><span class="p">()</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>/venv/apache-tvm-py3.8/lib/python3.8/site-packages/tensorflow/lite/python/convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. |
| warnings.warn("Statistics for quantized inputs were expected, but not " |
| </pre></div> |
| </div> |
| <div class="section" id="download-the-model-if-desired"> |
| <h3>Download the Model if Desired<a class="headerlink" href="#download-the-model-if-desired" title="Permalink to this headline">¶</a></h3> |
| <p>We’ve now got a finished model that you can use locally or in other tutorials (try autotuning |
| this model or viewing it on <a class="reference external" href="https://netron.app/">https://netron.app/</a>). But before we do |
| those things, we’ll have to write it to a file (<code class="docutils literal notranslate"><span class="pre">quantized.tflite</span></code>). If you’re running this |
| tutorial on Google Colab, you’ll have to uncomment the last two lines to download the file |
| after writing it.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">QUANTIZED_MODEL_PATH</span></a> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/quantized.tflite"</span> |
| <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">QUANTIZED_MODEL_PATH</span></a><span class="p">,</span> <span class="s2">"wb"</span><span class="p">)</span> <span class="k">as</span> <a href="https://docs.python.org/3/library/io.html#io.BufferedWriter" title="io.BufferedWriter" class="sphx-glr-backref-module-io sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">f</span></a><span class="p">:</span> |
| <a href="https://docs.python.org/3/library/io.html#io.BufferedWriter" title="io.BufferedWriter" class="sphx-glr-backref-module-io sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">f</span></a><span class="o">.</span><span class="n">write</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#bytes" title="builtins.bytes" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">quantized_model</span></a><span class="p">)</span> |
| <span class="c1"># from google.colab import files</span> |
| <span class="c1"># files.download(QUANTIZED_MODEL_PATH)</span> |
| </pre></div> |
| </div> |
| </div> |
| </div> |
| <div class="section" id="compiling-with-tvm-for-arduino"> |
| <h2>Compiling With TVM For Arduino<a class="headerlink" href="#compiling-with-tvm-for-arduino" title="Permalink to this headline">¶</a></h2> |
| <p>TensorFlow has a built-in framework for deploying to microcontrollers - <a class="reference external" href="https://www.tensorflow.org/lite/microcontrollers">TFLite Micro</a>. However, |
| it’s poorly supported by development boards and does not support autotuning. We will use Apache |
| TVM instead.</p> |
| <p>TVM can be used either with its command line interface (<code class="docutils literal notranslate"><span class="pre">tvmc</span></code>) or with its Python interface. The |
| Python interface is fully-featured and more stable, so we’ll use it here.</p> |
| <p>TVM is an optimizing compiler, and optimizations to our model are performed in stages via |
| <strong>intermediate representations</strong>. The first of these is <a class="reference external" href="https://arxiv.org/abs/1810.00952">Relay</a>, a high-level intermediate |
| representation emphasizing portability. The conversion from <code class="docutils literal notranslate"><span class="pre">.tflite</span></code> to Relay is done without any |
| knowledge of our “end goal” - the fact we intend to run this model on an Arduino.</p> |
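| <p>Once the conversion code below has run, you can print the resulting module to see this |
| hardware-agnostic representation for yourself - a minimal sketch, assuming <code class="docutils literal notranslate"><span class="pre">mod</span></code> is the module |
| returned by <code class="docutils literal notranslate"><span class="pre">tvm.relay.frontend.from_tflite</span></code>:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span># The Relay IR prints as text - note it says nothing about Arduinos or microcontrollers |
| print(mod) |
| </pre></div> |
| </div> |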
| <div class="section" id="choosing-an-arduino-board"> |
| <h3>Choosing an Arduino Board<a class="headerlink" href="#choosing-an-arduino-board" title="Permalink to this headline">¶</a></h3> |
| <p>Next, we’ll have to decide exactly which Arduino board to use. The Arduino sketch that we |
| ultimately generate should be compatible with any board, but knowing which board we are using in |
| advance allows TVM to adjust its compilation strategy to get better performance.</p> |
| <p>There is one catch - we need enough <strong>memory</strong> (flash and RAM) to be able to run our model. We |
| won’t ever be able to run a complex vision model like a MobileNet on an Arduino Uno - that board |
| only has 2 KB of RAM and 32 KB of flash! Our model has ~200,000 parameters, so there is just no |
| way it could fit.</p> |
| <p>For this tutorial, we will use the Nano 33 BLE, which has 1 MB of flash memory and 256 KB of RAM. |
| However, any other Arduino with those specs or better should also work.</p> |
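| <p>A rough fit check makes the point - a minimal sketch, using the board specs quoted in this |
| section and one byte per weight after quantization:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>model_bytes = 219_000            # ~219k int8 weights |
| uno_flash = 32 * 1024            # Arduino Uno |
| nano33ble_flash = 1024 * 1024    # Arduino Nano 33 BLE |
| print("Fits on Uno:        ", model_bytes &lt;= uno_flash)         # False |
| print("Fits on Nano 33 BLE:", model_bytes &lt;= nano33ble_flash)   # True |
| </pre></div> |
| </div> |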
| </div> |
| <div class="section" id="generating-our-project"> |
| <h3>Generating our project<a class="headerlink" href="#generating-our-project" title="Permalink to this headline">¶</a></h3> |
| <p>Next, we’ll compile the model to TVM’s MLF (model library format) intermediate representation, |
| which consists of C/C++ code and is designed for autotuning. To improve performance, we’ll tell |
| TVM that we’re compiling for the <code class="docutils literal notranslate"><span class="pre">nrf52840</span></code> microprocessor (the one the Nano 33 BLE uses). We’ll |
| also tell it to use the C runtime (abbreviated <code class="docutils literal notranslate"><span class="pre">crt</span></code>) and to use ahead-of-time memory allocation |
| (abbreviated <code class="docutils literal notranslate"><span class="pre">aot</span></code>, which helps reduce the model’s memory footprint). Lastly, we will disable |
| vectorization with <code class="docutils literal notranslate"><span class="pre">"tir.disable_vectorize":</span> <span class="pre">True</span></code>, as C has no native vectorized types.</p> |
| <p>Once we have set these configuration parameters, we will call <code class="docutils literal notranslate"><span class="pre">tvm.relay.build</span></code> to compile our |
| Relay model into the MLF intermediate representation. From here, we just need to call |
| <code class="docutils literal notranslate"><span class="pre">tvm.micro.generate_project</span></code> and pass in the Arduino template project to finish compilation.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">shutil</span> |
| <span class="kn">import</span> <span class="nn">tvm</span> |
| <span class="kn">import</span> <span class="nn">tvm.micro.testing</span> |
| |
| <span class="c1"># Method to load model is different in TFLite 1 vs 2</span> |
| <span class="k">try</span><span class="p">:</span> <span class="c1"># TFLite 2.1 and above</span> |
| <span class="kn">import</span> <span class="nn">tflite</span> |
| |
| <span class="n">tflite_model</span> <span class="o">=</span> <span class="n">tflite</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">GetRootAsModel</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#bytes" title="builtins.bytes" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">quantized_model</span></a><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> |
| <span class="k">except</span> <span class="ne">AttributeError</span><span class="p">:</span> <span class="c1"># Fall back to TFLite 1.14 method</span> |
| <span class="kn">import</span> <span class="nn">tflite.Model</span> |
| |
| <span class="n">tflite_model</span> <span class="o">=</span> <span class="n">tflite</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">GetRootAsModel</span><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#bytes" title="builtins.bytes" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">quantized_model</span></a><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> |
| |
| <span class="c1"># Convert to the Relay intermediate representation</span> |
| <span class="n">mod</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a> <span class="o">=</span> <a href="../../reference/api/python/relay/frontend.html#tvm.relay.frontend.from_tflite" title="tvm.relay.frontend.from_tflite" class="sphx-glr-backref-module-tvm-relay-frontend sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">relay</span><span class="o">.</span><span class="n">frontend</span><span class="o">.</span><span class="n">from_tflite</span></a><span class="p">(</span><span class="n">tflite_model</span><span class="p">)</span> |
| |
| <span class="c1"># Set configuration flags to improve performance</span> |
| <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">micro</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">get_target</span><span class="p">(</span><span class="s2">"zephyr"</span><span class="p">,</span> <span class="s2">"nrf5340dk_nrf5340_cpuapp"</span><span class="p">)</span> |
| <span class="n">runtime</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">relay</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="s2">"crt"</span><span class="p">)</span> |
| <span class="n">executor</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">relay</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">Executor</span><span class="p">(</span><span class="s2">"aot"</span><span class="p">,</span> <span class="p">{</span><span class="s2">"unpacked-api"</span><span class="p">:</span> <span class="kc">True</span><span class="p">})</span> |
| |
| <span class="c1"># Convert to the MLF intermediate representation</span> |
| <span class="k">with</span> <a href="../../reference/api/python/ir.html#tvm.transform.PassContext" title="tvm.transform.PassContext" class="sphx-glr-backref-module-tvm-transform sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">tvm</span><span class="o">.</span><span class="n">transform</span><span class="o">.</span><span class="n">PassContext</span></a><span class="p">(</span><span class="n">opt_level</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">config</span><span class="o">=</span><span class="p">{</span><span class="s2">"tir.disable_vectorize"</span><span class="p">:</span> <span class="kc">True</span><span class="p">}):</span> |
| <span class="n">mod</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">relay</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <a href="../../reference/api/python/target.html#tvm.target.Target" title="tvm.target.Target" class="sphx-glr-backref-module-tvm-target sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">target</span></a><span class="p">,</span> <span class="n">runtime</span><span class="o">=</span><span class="n">runtime</span><span class="p">,</span> <span class="n">executor</span><span class="o">=</span><span class="n">executor</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="o">=</span><a href="https://docs.python.org/3/library/stdtypes.html#dict" title="builtins.dict" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">params</span></a><span class="p">)</span> |
| |
| <span class="c1"># Generate an Arduino project from the MLF intermediate representation</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.rmtree" title="shutil.rmtree" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">rmtree</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/project"</span><span class="p">,</span> <span class="n">ignore_errors</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/unittest.mock.html#unittest.mock.MagicMock" title="unittest.mock.MagicMock" class="sphx-glr-backref-module-unittest-mock sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arduino_project</span></a> <span class="o">=</span> <a href="../../reference/api/python/micro.html#tvm.micro.generate_project" title="tvm.micro.generate_project" class="sphx-glr-backref-module-tvm-micro sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">micro</span><span class="o">.</span><span class="n">generate_project</span></a><span class="p">(</span> |
| <a href="../../reference/api/python/micro.html#tvm.micro.get_microtvm_template_projects" title="tvm.micro.get_microtvm_template_projects" class="sphx-glr-backref-module-tvm-micro sphx-glr-backref-type-py-function"><span class="n">tvm</span><span class="o">.</span><span class="n">micro</span><span class="o">.</span><span class="n">get_microtvm_template_projects</span></a><span class="p">(</span><span class="s2">"arduino"</span><span class="p">),</span> |
| <span class="n">mod</span><span class="p">,</span> |
| <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/project"</span><span class="p">,</span> |
| <span class="p">{</span> |
| <span class="s2">"board"</span><span class="p">:</span> <span class="s2">"nano33ble"</span><span class="p">,</span> |
| <span class="s2">"arduino_cli_cmd"</span><span class="p">:</span> <span class="s2">"/content/bin/arduino-cli"</span><span class="p">,</span> |
| <span class="s2">"project_type"</span><span class="p">:</span> <span class="s2">"example_project"</span><span class="p">,</span> |
| <span class="p">},</span> |
| <span class="p">)</span> |
| </pre></div> |
| </div> |
| </div> |
| </div> |
| <div class="section" id="testing-our-arduino-project"> |
| <h2>Testing our Arduino Project<a class="headerlink" href="#testing-our-arduino-project" title="Permalink to this headline">¶</a></h2> |
| <p>Consider the following two 224x224 images from the author’s camera roll - one of a car, one not. |
| We will test our Arduino project by loading both of these images and executing the compiled model |
| on them.</p> |
| <a class="reference internal image-reference" href="https://raw.githubusercontent.com/tlc-pack/web-data/main/testdata/microTVM/data/model_train_images_combined.png"><img alt="https://raw.githubusercontent.com/tlc-pack/web-data/main/testdata/microTVM/data/model_train_images_combined.png" class="align-center" src="https://raw.githubusercontent.com/tlc-pack/web-data/main/testdata/microTVM/data/model_train_images_combined.png" style="width: 600px; height: 200px;" /></a> |
| <p>Currently, these are 224x224 PNG images we can download from Imgur. Before we can feed in these |
| images, we’ll need to resize and convert them to raw data, which can be done with <code class="docutils literal notranslate"><span class="pre">imagemagick</span></code>.</p> |
| <p>It’s also challenging to load raw data onto an Arduino, as only C/CPP files (and similar) are |
| compiled. We can work around this by embedding our raw data in a hard-coded C array with the |
| built-in utility <code class="docutils literal notranslate"><span class="pre">bin2c</span></code> that will output a file like below:</p> |
| <blockquote> |
| <div><div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="k">static</span><span class="w"> </span><span class="k">const</span><span class="w"> </span><span class="kt">unsigned</span><span class="w"> </span><span class="kt">char</span><span class="w"> </span><span class="n">CAR_IMAGE</span><span class="p">[]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">{</span> |
| <span class="w"> </span><span class="mh">0x22</span><span class="p">,</span><span class="mh">0x23</span><span class="p">,</span><span class="mh">0x14</span><span class="p">,</span><span class="mh">0x22</span><span class="p">,</span> |
| <span class="w"> </span><span class="p">...</span> |
| <span class="w"> </span><span class="mh">0x07</span><span class="p">,</span><span class="mh">0x0e</span><span class="p">,</span><span class="mh">0x08</span><span class="p">,</span><span class="mh">0x08</span> |
| <span class="p">};</span> |
| </pre></div> |
| </div> |
| </div></blockquote> |
| <p>We can do both of these things with a few lines of Bash code:</p> |
| <blockquote> |
| <div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>mkdir<span class="w"> </span>-p<span class="w"> </span>~/tests |
| curl<span class="w"> </span><span class="s2">"https://i.imgur.com/JBbEhxN.png"</span><span class="w"> </span>-o<span class="w"> </span>~/tests/car_224.png |
| convert<span class="w"> </span>~/tests/car_224.png<span class="w"> </span>-resize<span class="w"> </span><span class="m">64</span><span class="w"> </span>~/tests/car_64.png |
| stream<span class="w"> </span>~/tests/car_64.png<span class="w"> </span>~/tests/car.raw |
| bin2c<span class="w"> </span>-c<span class="w"> </span>-st<span class="w"> </span>~/tests/car.raw<span class="w"> </span>--name<span class="w"> </span>CAR_IMAGE<span class="w"> </span>><span class="w"> </span>~/models/project/car.c |
| |
| curl<span class="w"> </span><span class="s2">"https://i.imgur.com/wkh7Dx2.png"</span><span class="w"> </span>-o<span class="w"> </span>~/tests/catan_224.png |
| convert<span class="w"> </span>~/tests/catan_224.png<span class="w"> </span>-resize<span class="w"> </span><span class="m">64</span><span class="w"> </span>~/tests/catan_64.png |
| stream<span class="w"> </span>~/tests/catan_64.png<span class="w"> </span>~/tests/catan.raw |
| bin2c<span class="w"> </span>-c<span class="w"> </span>-st<span class="w"> </span>~/tests/catan.raw<span class="w"> </span>--name<span class="w"> </span>CATAN_IMAGE<span class="w"> </span>><span class="w"> </span>~/models/project/catan.c |
| </pre></div> |
| </div> |
| </div></blockquote> |
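| <p>If you don’t have <code class="docutils literal notranslate"><span class="pre">imagemagick</span></code> or <code class="docutils literal notranslate"><span class="pre">bin2c</span></code> handy, a rough Python equivalent is sketched below. |
| It assumes Pillow is installed and mirrors the resize-and-embed pipeline above (64x64 raw RGB |
| bytes dumped into a C array):</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>from PIL import Image |
| |
| def png_to_c_array(png_path, c_path, array_name): |
|     # Resize to 64x64 RGB and embed the raw bytes in a C array, |
|     # mirroring what convert, stream, and bin2c do above |
|     raw = Image.open(png_path).convert("RGB").resize((64, 64)).tobytes() |
|     with open(c_path, "w") as f: |
|         f.write(f"static const unsigned char {array_name}[] = {{\n") |
|         f.write(",".join(f"0x{b:02x}" for b in raw)) |
|         f.write("\n};\n") |
| |
| png_to_c_array("tests/car_224.png", "models/project/car.c", "CAR_IMAGE") |
| png_to_c_array("tests/catan_224.png", "models/project/catan.c", "CATAN_IMAGE") |
| </pre></div> |
| </div> |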
| </div> |
| <div class="section" id="writing-our-arduino-script"> |
| <h2>Writing our Arduino Script<a class="headerlink" href="#writing-our-arduino-script" title="Permalink to this headline">¶</a></h2> |
| <p>We now need a little bit of Arduino code to read the two binary arrays we just generated, run the |
| model on them, and log the output to the serial monitor. This file will replace <code class="docutils literal notranslate"><span class="pre">arduino_sketch.ino</span></code> |
| as the main file of our sketch. You’ll have to copy this code in manually.</p> |
| <blockquote> |
| <div><div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="cp">#include</span><span class="w"> </span><span class="cpf">"src/model.h"</span> |
| <span class="cp">#include</span><span class="w"> </span><span class="cpf">"car.c"</span> |
| <span class="cp">#include</span><span class="w"> </span><span class="cpf">"catan.c"</span> |
| |
| <span class="kt">void</span><span class="w"> </span><span class="nf">setup</span><span class="p">()</span><span class="w"> </span><span class="p">{</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">begin</span><span class="p">(</span><span class="mi">9600</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">TVMInitialize</span><span class="p">();</span> |
| <span class="p">}</span> |
| |
| <span class="kt">void</span><span class="w"> </span><span class="nf">loop</span><span class="p">()</span><span class="w"> </span><span class="p">{</span> |
| <span class="w"> </span><span class="kt">uint8_t</span><span class="w"> </span><span class="n">result_data</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">println</span><span class="p">(</span><span class="s">"Car results:"</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">TVMExecute</span><span class="p">(</span><span class="n">const_cast</span><span class="o"><</span><span class="kt">uint8_t</span><span class="o">*></span><span class="p">(</span><span class="n">CAR_IMAGE</span><span class="p">),</span><span class="w"> </span><span class="n">result_data</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="n">result_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span><span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="s">", "</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="n">result_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]);</span><span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">println</span><span class="p">();</span> |
| |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">println</span><span class="p">(</span><span class="s">"Other object results:"</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">TVMExecute</span><span class="p">(</span><span class="n">const_cast</span><span class="o"><</span><span class="kt">uint8_t</span><span class="o">*></span><span class="p">(</span><span class="n">CATAN_IMAGE</span><span class="p">),</span><span class="w"> </span><span class="n">result_data</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="n">result_data</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span><span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="s">", "</span><span class="p">);</span> |
| <span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">print</span><span class="p">(</span><span class="n">result_data</span><span class="p">[</span><span class="mi">1</span><span class="p">]);</span><span class="w"> </span><span class="n">Serial</span><span class="p">.</span><span class="n">println</span><span class="p">();</span> |
| |
| <span class="w"> </span><span class="n">delay</span><span class="p">(</span><span class="mi">1000</span><span class="p">);</span> |
| <span class="p">}</span> |
| </pre></div> |
| </div> |
| </div></blockquote> |
| <div class="section" id="compiling-our-code"> |
| <h3>Compiling Our Code<a class="headerlink" href="#compiling-our-code" title="Permalink to this headline">¶</a></h3> |
| <p>Now that our project has been generated, TVM’s job is mostly done! We can still call |
| <code class="docutils literal notranslate"><span class="pre">arduino_project.build()</span></code> and <code class="docutils literal notranslate"><span class="pre">arduino_project.upload()</span></code>, but these just use <code class="docutils literal notranslate"><span class="pre">arduino-cli</span></code>’s |
| compile and flash commands underneath. We could also begin autotuning our model, but that’s a |
| subject for a different tutorial. To finish up, we’ll verify no compiler errors are thrown |
| by our project:</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/shutil.html#shutil.rmtree" title="shutil.rmtree" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">rmtree</span></a><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/project/build"</span><span class="p">,</span> <span class="n">ignore_errors</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> |
| <a href="https://docs.python.org/3/library/unittest.mock.html#unittest.mock.MagicMock" title="unittest.mock.MagicMock" class="sphx-glr-backref-module-unittest-mock sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">arduino_project</span><span class="o">.</span><span class="n">build</span></a><span class="p">()</span> |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Compilation succeeded!"</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Compilation succeeded! |
| </pre></div> |
| </div> |
| </div> |
| </div> |
| <div class="section" id="uploading-to-our-device"> |
| <h2>Uploading to Our Device<a class="headerlink" href="#uploading-to-our-device" title="Permalink to this headline">¶</a></h2> |
| <p>The very last step is uploading our sketch to an Arduino to make sure our code works properly. |
| Unfortunately, we can’t do that from Google Colab, so we’ll have to download our sketch. This is |
| simple enough to do - we’ll just turn our project into a <cite>.zip</cite> archive, and call <cite>files.download</cite>. |
| If you’re running on Google Colab, you’ll have to uncomment the last two lines to download the file |
| after writing it.</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ZIP_FOLDER</span></a> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">FOLDER</span></a><span class="si">}</span><span class="s2">/models/project"</span> |
| <a href="https://docs.python.org/3/library/shutil.html#shutil.make_archive" title="shutil.make_archive" class="sphx-glr-backref-module-shutil sphx-glr-backref-type-py-function"><span class="n">shutil</span><span class="o">.</span><span class="n">make_archive</span></a><span class="p">(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ZIP_FOLDER</span></a><span class="p">,</span> <span class="s2">"zip"</span><span class="p">,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str" class="sphx-glr-backref-module-builtins sphx-glr-backref-type-py-class sphx-glr-backref-instance"><span class="n">ZIP_FOLDER</span></a><span class="p">)</span> |
| <span class="c1"># from google.colab import files</span> |
| <span class="c1"># files.download(f"{FOLDER}/models/project.zip")</span> |
| </pre></div> |
| </div> |
| <p>From here, we’ll need to open it in the Arduino IDE. You’ll have to download the IDE as well as |
| the SDK for whichever board you are using. For certain boards like the Sony SPRESENSE, you may |
| have to change settings to control how much memory you want the board to use.</p> |
| <div class="section" id="expected-results"> |
| <h3>Expected Results<a class="headerlink" href="#expected-results" title="Permalink to this headline">¶</a></h3> |
| <p>If all works as expected, you should see the following output on a Serial monitor:</p> |
| <blockquote> |
| <div><div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Car</span> <span class="n">results</span><span class="p">:</span> |
| <span class="mi">255</span><span class="p">,</span> <span class="mi">0</span> |
| <span class="n">Other</span> <span class="nb">object</span> <span class="n">results</span><span class="p">:</span> |
| <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span> |
| </pre></div> |
| </div> |
| </div></blockquote> |
| <p>The first number represents the model’s confidence that the object <strong>is</strong> a car and ranges from |
| 0-255. The second number represents the model’s confidence that the object <strong>is not</strong> a car and |
| is also 0-255. These results mean the model is very sure that the first image is a car, and the |
| second image is not (which is correct). Hence, our model is working!</p> |
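| <p>Each output byte is just the quantized form of a softmax probability. A minimal sketch of |
| mapping the bytes back to approximate probabilities (assuming the usual linear uint8 scaling |
| onto [0, 1] for this model’s outputs):</p> |
| <div class="highlight-default notranslate"><div class="highlight"><pre><span></span>result_data = [255, 0]  # the values logged for the car image above |
| print(f"P(car)     ~= {result_data[0] / 255:.2f}")  # ~1.00 |
| print(f"P(not car) ~= {result_data[1] / 255:.2f}")  # ~0.00 |
| </pre></div> |
| </div> |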
| </div> |
| </div> |
| <div class="section" id="summary"> |
| <h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline">¶</a></h2> |
| <p>In this tutorial, we used transfer learning to quickly train an image recognition model to |
| identify cars. We modified its input dimensions and last few layers to make it better at this, |
| and to make it faster and smaller. We then quantized the model and compiled it using TVM to |
| create an Arduino sketch. Lastly, we tested the model using two static images to prove it works |
| as intended.</p> |
| <div class="section" id="next-steps"> |
| <h3>Next Steps<a class="headerlink" href="#next-steps" title="Permalink to this headline">¶</a></h3> |
| <p>From here, we could modify the model to read live images from the camera - we have another |
| Arduino tutorial for how to do that <a class="reference external" href="https://github.com/guberti/tvm-arduino-demos/tree/master/examples/person_detection">on GitHub</a>. Alternatively, we could also |
| <a class="reference external" href="https://tvm.apache.org/docs/how_to/work_with_microtvm/micro_autotune.html">use TVM’s autotuning capabilities</a> to dramatically improve the model’s performance.</p> |
| <p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 5 minutes 9.477 seconds)</p> |
| <div class="sphx-glr-footer sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-work-with-microtvm-micro-train-py"> |
| <div class="sphx-glr-download sphx-glr-download-python docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/b52cec46baf4f78d6bcd94cbe269c8a6/micro_train.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">micro_train.py</span></code></a></p> |
| </div> |
| <div class="sphx-glr-download sphx-glr-download-jupyter docutils container"> |
| <p><a class="reference download internal" download="" href="../../_downloads/a7c7ea4b5017ae70db1f51dd8e6dcd82/micro_train.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">micro_train.ipynb</span></code></a></p> |
| </div> |
| </div> |
| <p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p> |
| </div> |
| </div> |
| </div> |
| |
| |
| </div> |
| |
| </div> |
| |
| |
| <footer> |
| |
| <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation"> |
| |
| <a href="micro_autotune.html" class="btn btn-neutral float-right" title="6. Model Tuning with microTVM" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a> |
| |
| |
| <a href="micro_pytorch.html" class="btn btn-neutral float-left" title="4. microTVM PyTorch Tutorial" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a> |
| |
| </div> |
| |
| <div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="backtop"/> </div> |
| <section class="footerSec"> |
| <div class="footerHeader"> |
| <div class="d-flex align-md-items-center justify-content-between flex-column flex-md-row"> |
| <div class="copywrite d-flex align-items-center"> |
| <h5 id="copy-right-info">© 2023 Apache Software Foundation | All rights reserved</h5> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div> |
| <div class="footernote">Copyright © 2023 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</div> |
| </div> |
| |
| </section> |
| </footer> |
| </div> |
| </div> |
| |
| </section> |
| |
| </div> |
| |
| |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script> |
| <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script> |
| |
| <script type="text/javascript"> |
| jQuery(function () { |
| SphinxRtdTheme.Navigation.enable(true); |
| }); |
| </script> |
| |
| |
| |
| |
| <!-- Theme Analytics --> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-75982049-2', 'auto'); |
| ga('send', 'pageview'); |
| </script> |
| |
| |
| |
| |
| </body> |
| </html> |