blob: 99eaf4d02a4f5b01864e1fca596f7a3c0b1c9868 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="Predict with pre-trained models" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="Predict with pre-trained models" property="og:description"/>
<title>Predict with pre-trained models — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Global, page-level settings object for this generated documentation page.
// URL_ROOT holds the same relative prefix ('../../') that this page uses for
// its _static assets.
// NOTE(review): presumably read by the doctools/searchtools scripts loaded
// after this block — confirm against those files.
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- When the DOM is ready (jQuery shorthand), load the 0.11.0 search index and then initialise Search. NOTE(review): Search is presumably defined by searchtools_custom.js loaded above; confirm. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/0.11.0/searchindex.js"); Search.init();}); </script>
<script>
// Google Analytics (analytics.js) loader snippet.
// Defines window.ga as a stub that queues its arguments in ga.q, records a
// timestamp in ga.l, then creates an async <script> element pointing at
// analytics.js and inserts it before the first <script> in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue tracker creation for property UA-96378503-1, then a pageview hit.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="../index.html" rel="up" title="Tutorials"/>
<link href="../vision/large_scale_classification.html" rel="next" title="Large Scale Image Classification"/>
<link href="mnist.html" rel="prev" title="Handwritten Digit Recognition"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/0.11.0/get_started/install.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/0.11.0/how_to/faq.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0">Github</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/community/contribute.html">Contribute</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">0.11.0<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7/">1.7</a></li><li><a href="/versions/1.6/">1.6</a></li><li><a href="/versions/1.5.0/">1.5.0</a></li><li><a href="/versions/1.4.1/">1.4.1</a></li><li><a href="/versions/1.3.1/">1.3.1</a></li><li><a href="/versions/1.2.1/">1.2.1</a></li><li><a href="/versions/1.1.0/">1.1.0</a></li><li><a href="/versions/1.0.0/">1.0.0</a></li><li><a href="/versions/0.12.1/">0.12.1</a></li><li><a href="/versions/0.11.0/">0.11.0</a></li></ul></span></nav>
<!-- getRootPath(): returns the string "../../", the same relative prefix this page uses for DOCUMENTATION_OPTIONS.URL_ROOT and its _static assets. -->
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/0.11.0/get_started/install.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/0.11.0/how_to/faq.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/0.11.0/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/0.11.0/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/0.11.0/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0" tabindex="-1">Github</a></li>
<li><a href="/versions/0.11.0/community/contribute.html" tabindex="-1">Contribute</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">0.11.0</a><ul class="dropdown-menu"><li><a tabindex="-1" href="/">master</a></li><li><a tabindex="-1" href="/versions/1.6/">1.6</a></li><li><a tabindex="-1" href="/versions/1.5.0/">1.5.0</a></li><li><a tabindex="-1" href="/versions/1.4.1/">1.4.1</a></li><li><a tabindex="-1" href="/versions/1.3.1/">1.3.1</a></li><li><a tabindex="-1" href="/versions/1.2.1/">1.2.1</a></li><li><a tabindex="-1" href="/versions/1.1.0/">1.1.0</a></li><li><a tabindex="-1" href="/versions/1.0.0/">1.0.0</a></li><li><a tabindex="-1" href="/versions/0.12.1/">0.12.1</a></li><li><a tabindex="-1" href="/versions/0.11.0/">0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
// Override the page background with plain white; this replaces the image
// requested by the background attribute on the <body> tag.
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="../index.html#python">Python</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../index.html#basic">Basic</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html#training-and-inference">Training and Inference</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="linear-regression.html">Linear Regression</a></li>
<li class="toctree-l4"><a class="reference internal" href="mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l4 current"><a class="current reference internal" href="#">Predict with pre-trained models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../vision/large_scale_classification.html">Large Scale Image Classification</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../index.html#contributing-tutorials">Contributing Tutorials</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="predict-with-pre-trained-models">
<span id="predict-with-pre-trained-models"></span><h1>Predict with pre-trained models<a class="headerlink" href="#predict-with-pre-trained-models" title="Permalink to this headline"></a></h1>
<p>This tutorial explains how to recognize objects in an image with a
pre-trained model, and how to perform feature extraction.</p>
<div class="section" id="prerequisites">
<span id="prerequisites"></span><h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>To complete this tutorial, we need:</p>
<ul class="simple">
<li>MXNet. See the instructions for your operating system in <a class="reference external" href="/versions/0.11.0/get_started/install.html">Setup and Installation</a></li>
<li><a class="reference external" href="http://docs.python-requests.org/en/master/">Python Requests</a>, <a class="reference external" href="https://matplotlib.org/">Matplotlib</a> and <a class="reference external" href="http://jupyter.org/index.html">Jupyter Notebook</a>.</li>
</ul>
<div class="highlight-default"><div class="highlight"><pre><span></span>$ pip install requests matplotlib jupyter
</pre></div>
</div>
</div>
<div class="section" id="loading">
<span id="loading"></span><h2>Loading<a class="headerlink" href="#loading" title="Permalink to this headline"></a></h2>
<p>We first download a pre-trained 152-layer ResNet that is trained on the full
ImageNet dataset with over 10 million images and 10 thousand classes. A
pre-trained model contains two parts, a json file containing the model
definition and a binary file containing the parameters. In addition, there may be
a text file for the labels.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="n">path</span><span class="o">=</span><span class="s1">'http://data.mxnet.io/models/imagenet-11k/'</span>
<span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'resnet-152/resnet-152-symbol.json'</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'resnet-152/resnet-152-0000.params'</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'synset.txt'</span><span class="p">)]</span>
</pre></div>
</div>
<p>Next, we load the downloaded model. <em>Note:</em> If GPU is available, we can replace all
occurrences of <code class="docutils literal"><span class="pre">mx.cpu()</span></code> with <code class="docutils literal"><span class="pre">mx.gpu()</span></code> to accelerate the computation.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'resnet-152'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">label_names</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">for_training</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">data_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'data'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">224</span><span class="p">,</span><span class="mi">224</span><span class="p">))],</span>
<span class="n">label_shapes</span><span class="o">=</span><span class="n">mod</span><span class="o">.</span><span class="n">_label_shapes</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span> <span class="n">allow_missing</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s1">'synset.txt'</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">l</span><span class="o">.</span><span class="n">rstrip</span><span class="p">()</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="n">f</span><span class="p">]</span>
</pre></div>
</div>
</div>
<div class="section" id="predicting">
<span id="predicting"></span><h2>Predicting<a class="headerlink" href="#predicting" title="Permalink to this headline"></a></h2>
<p>We first define helper functions for downloading an image and performing the
prediction:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">cv2</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="c1"># define a simple data batch</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">namedtuple</span>
<span class="n">Batch</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span><span class="s1">'Batch'</span><span class="p">,</span> <span class="p">[</span><span class="s1">'data'</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">get_image</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">show</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
<span class="c1"># download and show the image</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">fname</span><span class="p">),</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span>
<span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">None</span>
<span class="k">if</span> <span class="n">show</span><span class="p">:</span>
<span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">'off'</span><span class="p">)</span>
<span class="c1"># convert into format (batch, RGB, height, width)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">return</span> <span class="n">img</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">url</span><span class="p">):</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">get_image</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">show</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># compute the predicted probabilities</span>
<span class="n">mod</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">Batch</span><span class="p">([</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="p">)]))</span>
<span class="n">prob</span> <span class="o">=</span> <span class="n">mod</span><span class="o">.</span><span class="n">get_outputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="c1"># print the top-5</span>
<span class="n">prob</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">prob</span><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">prob</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">a</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">5</span><span class="p">]:</span>
<span class="k">print</span><span class="p">(</span><span class="s1">'probability=</span><span class="si">%f</span><span class="s1">, class=</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span><span class="p">(</span><span class="n">prob</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">labels</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
</pre></div>
</div>
<p>Now, we can perform prediction with any downloadable URL:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg'</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="feature-extraction">
<span id="feature-extraction"></span><h2>Feature extraction<a class="headerlink" href="#feature-extraction" title="Permalink to this headline"></a></h2>
<p>By feature extraction, we mean representing the input images by the output of an
internal layer rather than the last softmax layer. These outputs, which can be
viewed as the features of the raw input image, can then be used by other
applications such as object detection.</p>
<p>We can use the <code class="docutils literal"><span class="pre">get_internals</span></code> method to get all internal layers from a
Symbol.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># list the last 10 layers</span>
<span class="n">all_layers</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">get_internals</span><span class="p">()</span>
<span class="n">all_layers</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">()[</span><span class="o">-</span><span class="mi">10</span><span class="p">:]</span>
</pre></div>
</div>
<p>An often used layer for feature extraction is the one before the last fully
connected layer. For ResNet, and also Inception, it is the flattened layer with
name <code class="docutils literal"><span class="pre">flatten0</span></code> which reshapes the 4-D convolutional layer output into 2-D for
the fully connected layer. The following source code extracts a new Symbol which
outputs the flattened layer and creates a model.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">fe_sym</span> <span class="o">=</span> <span class="n">all_layers</span><span class="p">[</span><span class="s1">'flatten0_output'</span><span class="p">]</span>
<span class="n">fe_mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">fe_sym</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">label_names</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">for_training</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">data_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'data'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">224</span><span class="p">,</span><span class="mi">224</span><span class="p">))])</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now invoke <code class="docutils literal"><span class="pre">forward</span></code> to obtain the features:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">img</span> <span class="o">=</span> <span class="n">get_image</span><span class="p">(</span><span class="s1">'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'</span><span class="p">)</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">Batch</span><span class="p">([</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="p">)]))</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">fe_mod</span><span class="o">.</span><span class="n">get_outputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2048</span><span class="p">)</span>
</pre></div>
</div>
<div class="btn-group" role="group">
<div class="download-btn"><a download="predict_image.ipynb" href="predict_image.ipynb"><span class="glyphicon glyphicon-download-alt"></span> predict_image.ipynb</a></div></div></div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Predict with pre-trained models</a><ul>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#loading">Loading</a></li>
<li><a class="reference internal" href="#predicting">Predicting</a></li>
<li><a class="reference internal" href="#feature-extraction">Feature extraction</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img alt="Apache Incubator logo" height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script src="../../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
// Once jQuery's ready callback fires, set the body's visibility to 'visible'.
// NOTE(review): presumably a stylesheet hides <body> until scripts have run —
// confirm in mxnet.css.
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>