blob: 743e49630c81dd65a4bec4ac2a7cb9ffa060797f [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Predict with pre-trained models — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
  // Global page settings; declared before the _static scripts below are loaded.
  // URL_ROOT is the relative path from this page back to the documentation root.
  var DOCUMENTATION_OPTIONS = {
    URL_ROOT: '../../',
    VERSION: '',
    COLLAPSE_INDEX: false,
    FILE_SUFFIX: '.html',
    HAS_SOURCE: true,
    SOURCELINK_SUFFIX: ''
  };
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<!-- cdn.mathjax.org was retired in 2017; load the same MathJax 2.x build (same config) from cdnjs. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript">
  // Once the DOM is ready, load the search index and initialise the search UI.
  // The index is resolved against URL_ROOT (like every other local asset on this page)
  // instead of the host root, so it is also found when the docs are served from a sub-path.
  jQuery(function() {
    Search.loadIndex(DOCUMENTATION_OPTIONS.URL_ROOT + "searchindex.js");
    Search.init();
  });
</script>
<script>
// Google Analytics bootstrap snippet (kept verbatim).
// Defines a global `ga` function that queues its arguments in `ga.q`, records a timestamp in
// `ga.l`, then creates an async <script> for analytics.js and inserts it before the first
// <script> element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue tracker creation for property UA-96378503-1, then queue a pageview hit.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<!-- Document relations (up / next / prev) and the favicon. <link> is a void element, so it takes no end tag. -->
<link href="../index.html" rel="up" title="Tutorials"/>
<link href="../vision/large_scale_classification.html" rel="next" title="Large Scale Image Classification"/>
<link href="mnist.html" rel="prev" title="Handwritten Digit Recognition"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document"><!-- Previous Navbar Layout
<div class="navbar navbar-default navbar-fixed-top">
<div class="container">
<div class="navbar-header">
<button type="button" class="navbar-toggle collapsed" data-toggle="collapse" data-target="#navbar" aria-expanded="false" aria-controls="navbar">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a href="../../" class="navbar-brand">
<img src="http://data.mxnet.io/theme/mxnet.png">
</a>
</div>
<div id="navbar" class="navbar-collapse collapse">
<ul id="navbar" class="navbar navbar-left">
<li> <a href="../../get_started/index.html">Get Started</a> </li>
<li> <a href="../../tutorials/index.html">Tutorials</a> </li>
<li> <a href="../../how_to/index.html">How To</a> </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Packages <span class="caret"></span></a>
<ul class="dropdown-menu">
<li><a href="../../packages/python/index.html">
Python
</a></li>
<li><a href="../../packages/r/index.html">
R
</a></li>
<li><a href="../../packages/julia/index.html">
Julia
</a></li>
<li><a href="../../packages/c++/index.html">
C++
</a></li>
<li><a href="../../packages/scala/index.html">
Scala
</a></li>
<li><a href="../../packages/perl/index.html">
Perl
</a></li>
</ul>
</li>
<li> <a href="../../system/index.html">System</a> </li>
<li>
<form class="" role="search" action="../../search.html" method="get" autocomplete="off">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input type="text" name="q" class="form-control" placeholder="Search">
</div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form> </li>
</ul>
<ul id="navbar" class="navbar navbar-right">
<li> <a href="../../index.html"><span class="flag-icon flag-icon-us"></span></a> </li>
<li> <a href="../..//zh/index.html"><span class="flag-icon flag-icon-cn"></span></a> </li>
</ul>
</div>
</div>
</div>
Previous Navbar Layout End -->
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<!-- The image is the only content of the home link, so its alt text is the link's accessible name. -->
<a href="../../" id="logo"><img alt="MXNet" src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<!-- Desktop navigation bar: top-level links plus two Bootstrap dropdowns (API, Versions). -->
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<!-- Dropdowns start collapsed, so aria-expanded starts as "false". -->
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<!-- <a class="main-nav-link" href="../../community/index.html">Community</a> -->
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative">
<a aria-expanded="false" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Versions(master)<span class="caret"></span></a>
<!-- NOTE(review): this id repeats the API menu's id; kept unchanged because styles/scripts may target it. -->
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/">v0.10.14</a></li>
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html">0.10</a></li>
<li><a class="main-nav-link" href="http://mxnet.incubator.apache.org/test/versions/master/index.html">master</a></li>
</ul>
</span>
</nav>
<script>
  // Returns the relative prefix "../../" (the same value as DOCUMENTATION_OPTIONS.URL_ROOT).
  function getRootPath() {
    return "../../";
  }
</script>
<!-- Collapsed ("burger") menu: same destinations as the main nav, as nested Bootstrap dropdowns. -->
<div class="burgerIcon dropdown">
<!-- The toggle has no text content, so aria-label supplies its accessible name. -->
<a aria-label="Menu" class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li class="dropdown-submenu" id="dropdown-menu-position-anchor-version-mobile" style="position: relative">
<a href="#" tabindex="-1">Versions(master)</a>
<ul class="dropdown-menu">
<li><a href="http://mxnet.incubator.apache.org/test/" tabindex="-1">v0.10.14</a></li>
<li><a href="http://mxnet.incubator.apache.org/test/versions/0.10/index.html" tabindex="-1">0.10</a></li>
<li><a href="http://mxnet.incubator.apache.org/test/versions/master/index.html" tabindex="-1">master</a></li>
</ul>
</li>
</ul>
</div>
<!-- "Plus" dropdown; its list is empty in the markup (presumably filled by a script - confirm). -->
<div class="plusIcon dropdown">
<!-- The icon is aria-hidden, so the toggle needs an explicit accessible name. -->
<a aria-label="More" class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<!-- Site search: submits q (plus two fixed hidden fields) to search.html via GET. -->
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i aria-hidden="true" class="glyphicon glyphicon-search"></i>
<!-- A placeholder is not a label; aria-label gives the field an accessible name. -->
<input aria-label="Search" class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="reference internal" href="../index.html#python">Python</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../index.html#basics">Basics</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html#training-and-inference">Training and Inference</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="linear-regression.html">Linear Regression</a></li>
<li class="toctree-l4"><a class="reference internal" href="mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l4 current"><a class="current reference internal" href="">Predict with pre-trained models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../vision/large_scale_classification.html">Large Scale Image Classification</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../index.html#contributing-tutorials">Contributing Tutorials</a></li>
</ul>
</li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="predict-with-pre-trained-models">
<span id="predict-with-pre-trained-models"></span><h1>Predict with pre-trained models<a class="headerlink" href="#predict-with-pre-trained-models" title="Permalink to this headline"></a></h1>
<p>This tutorial explains how to recognize objects in an image with a
pre-trained model, and how to perform feature extraction.</p>
<div class="section" id="prerequisites">
<span id="prerequisites"></span><h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>To complete this tutorial, we need:</p>
<ul class="simple">
<li>MXNet. See the instructions for your operating system in <a class="reference external" href="http://mxnet.io/get_started/install.html">Setup and Installation</a></li>
<li><a class="reference external" href="http://docs.python-requests.org/en/master/">Python Requests</a>, <a class="reference external" href="https://matplotlib.org/">Matplotlib</a> and <a class="reference external" href="http://jupyter.org/index.html">Jupyter Notebook</a>.</li>
</ul>
<div class="highlight-python"><div class="highlight"><pre><span></span>$ pip install requests matplotlib jupyter
</pre></div>
</div>
</div>
<div class="section" id="loading">
<span id="loading"></span><h2>Loading<a class="headerlink" href="#loading" title="Permalink to this headline"></a></h2>
<p>We first download a pre-trained 152-layer ResNet model that is trained on the full
ImageNet dataset with over 10 million images and 10 thousand classes. A
pre-trained model contains two parts, a json file containing the model
definition and a binary file containing the parameters. In addition, there may be
a text file for the labels.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="n">path</span><span class="o">=</span><span class="s1">'http://data.mxnet.io/models/imagenet-11k/'</span>
<span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'resnet-152/resnet-152-symbol.json'</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'resnet-152/resnet-152-0000.params'</span><span class="p">),</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">path</span><span class="o">+</span><span class="s1">'synset.txt'</span><span class="p">)]</span>
</pre></div>
</div>
<p>Next, we load the downloaded model. <em>Note:</em> If GPU is available, we can replace all
occurrences of <code class="docutils literal"><span class="pre">mx.cpu()</span></code> with <code class="docutils literal"><span class="pre">mx.gpu()</span></code> to accelerate the computation.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'resnet-152'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">sym</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">label_names</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">for_training</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">data_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'data'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">224</span><span class="p">,</span><span class="mi">224</span><span class="p">))],</span>
         <span class="n">label_shapes</span><span class="o">=</span><span class="n">mod</span><span class="o">.</span><span class="n">_label_shapes</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span> <span class="n">allow_missing</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1"># read the class labels, one per line</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s1">'synset.txt'</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">l</span><span class="o">.</span><span class="n">rstrip</span><span class="p">()</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="n">f</span><span class="p">]</span>
</pre></div>
</div>
</div>
<div class="section" id="predicting">
<span id="predicting"></span><h2>Predicting<a class="headerlink" href="#predicting" title="Permalink to this headline"></a></h2>
<p>We first define helper functions for downloading an image and performing the
prediction:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">cv2</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="c1"># define a simple data batch</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">namedtuple</span>
<span class="n">Batch</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span><span class="s1">'Batch'</span><span class="p">,</span> <span class="p">[</span><span class="s1">'data'</span><span class="p">])</span>

<span class="k">def</span> <span class="nf">get_image</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">show</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
    <span class="c1"># download and show the image</span>
    <span class="n">fname</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
    <span class="c1"># cv2.imread returns None if the file cannot be read, so check before converting colors</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">fname</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">None</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">show</span><span class="p">:</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">'off'</span><span class="p">)</span>
    <span class="c1"># convert into format (batch, RGB, height, width)</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span>
    <span class="k">return</span> <span class="n">img</span>

<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">url</span><span class="p">):</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">get_image</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">show</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="c1"># compute the predict probabilities</span>
    <span class="n">mod</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">Batch</span><span class="p">([</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="p">)]))</span>
    <span class="n">prob</span> <span class="o">=</span> <span class="n">mod</span><span class="o">.</span><span class="n">get_outputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
    <span class="c1"># print the top-5</span>
    <span class="n">prob</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">prob</span><span class="p">)</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">prob</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">a</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">5</span><span class="p">]:</span>
        <span class="k">print</span><span class="p">(</span><span class="s1">'probability=</span><span class="si">%f</span><span class="s1">, class=</span><span class="si">%s</span><span class="s1">'</span> <span class="o">%</span><span class="p">(</span><span class="n">prob</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">labels</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
</pre></div>
</div>
<p>Now, we can perform prediction with any downloadable URL:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg'</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="feature-extraction">
<span id="feature-extraction"></span><h2>Feature extraction<a class="headerlink" href="#feature-extraction" title="Permalink to this headline"></a></h2>
<p>By feature extraction, we mean representing the input images by the output of an
internal layer rather than the last softmax layer. These outputs, which can be
viewed as the feature of the raw input image, can then be used by other
applications such as object detection.</p>
<p>We can use the <code class="docutils literal"><span class="pre">get_internals</span></code> method to get all internal layers from a
Symbol.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># list the last 10 layers</span>
<span class="n">all_layers</span> <span class="o">=</span> <span class="n">sym</span><span class="o">.</span><span class="n">get_internals</span><span class="p">()</span>
<span class="n">all_layers</span><span class="o">.</span><span class="n">list_outputs</span><span class="p">()[</span><span class="o">-</span><span class="mi">10</span><span class="p">:]</span>
</pre></div>
</div>
<p>An often used layer for feature extraction is the one before the last fully
connected layer. For ResNet, and also Inception, it is the flattened layer with
name <code class="docutils literal"><span class="pre">flatten0</span></code> which reshapes the 4-D convolutional layer output into 2-D for
the fully connected layer. The following source code extracts a new Symbol which
outputs the flattened layer and creates a model.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">fe_sym</span> <span class="o">=</span> <span class="n">all_layers</span><span class="p">[</span><span class="s1">'flatten0_output'</span><span class="p">]</span>
<span class="n">fe_mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">fe_sym</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">label_names</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">for_training</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">data_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'data'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">224</span><span class="p">,</span><span class="mi">224</span><span class="p">))])</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now invoke <code class="docutils literal"><span class="pre">forward</span></code> to obtain the features:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">img</span> <span class="o">=</span> <span class="n">get_image</span><span class="p">(</span><span class="s1">'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'</span><span class="p">)</span>
<span class="n">fe_mod</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">Batch</span><span class="p">([</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="p">)]))</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">fe_mod</span><span class="o">.</span><span class="n">get_outputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2048</span><span class="p">)</span>
</pre></div>
</div>
<div class="btn-group" role="group">
<div class="download_btn"><a download="predict_image_python.ipynb" href="predict_image_python.ipynb"><span class="glyphicon glyphicon-download-alt"></span> predict_image_python.ipynb</a></div></div></div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<!-- In-page table of contents; labelled separately from the left sidebar ("main navigation")
     so the two navigation landmarks have different accessible names. -->
<div aria-label="table of contents" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Predict with pre-trained models</a><ul>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#loading">Loading</a></li>
<li><a class="reference internal" href="#predicting">Predicting</a></li>
<li><a class="reference internal" href="#feature-extraction">Feature extraction</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
  // Once the DOM is ready, force <body> to be visible
  // (presumably a stylesheet hides it until then - confirm in mxnet.css).
  $(document).ready(function () {
    var pageBody = $('body');
    pageBody.css('visibility', 'visible');
  });
</script>
</div></body>
</html>