<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="Visualizing Decisions of Convolutional Neural Networks" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="Visualizing Decisions of Convolutional Neural Networks" property="og:description"/>
<title>Visualizing Decisions of Convolutional Neural Networks — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.3.1/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></meta></meta></meta></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/1.3.1/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/1.3.1/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/index.html">Tutorials</a>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.3.1/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</li></ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.3.1">Github</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/ecosystem.html">Ecosystem</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">1.3.1<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/1.3.1/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/1.3.1/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/1.3.1/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.3.1/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/1.3.1/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/1.3.1/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.3.1" tabindex="-1">Github</a></li>
<li><a href="/versions/1.3.1/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/1.3.1/community/ecosystem.html" tabindex="-1">Ecosystem</a></li>
<li><a href="/versions/1.3.1/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">1.3.1</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6.0/>1.6.0</a></li><li><a tabindex="-1" href=/versions/1.5.0/>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/>0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/contribute.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="visualizing-decisions-of-convolutional-neural-networks">
<span id="visualizing-decisions-of-convolutional-neural-networks"></span><h1>Visualizing Decisions of Convolutional Neural Networks<a class="headerlink" href="#visualizing-decisions-of-convolutional-neural-networks" title="Permalink to this headline"></a></h1>
<p>Convolutional neural networks (CNNs) have driven much of the recent progress in computer vision, and on some tasks their accuracy matches that of humans. However, their predictions remain hard to explain, because CNNs lack the interpretability offered by other models such as decision trees.</p>
<p>It is often helpful to be able to explain why a model made a particular prediction. For example, when a model misclassifies an image, it is hard to say why without visualizing what the network based its decision on.</p>
<p><img align="right" alt="Explaining the misclassification of volcano as spider" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/volcano_barn_spider.png" width="500"/></p>
<p>Visualizations also help build confidence in a model’s predictions. For example, even if a model correctly classifies birds as birds, we want to confirm that it bases its decision on features of the bird itself and not on features of other objects that often appear alongside birds in the dataset (such as leaves).</p>
<p>In this tutorial, we show how to visualize the predictions made by convolutional neural networks using <a class="reference external" href="https://arxiv.org/abs/1610.02391">Gradient-weighted Class Activation Mapping</a> (Grad-CAM). Unlike many other visualization methods, Grad-CAM can be applied to a wide variety of CNN model families without architectural changes or re-training: CNNs with fully connected layers, CNNs used for structured outputs (e.g. captioning), CNNs used in tasks with multimodal input (e.g. visual question answering, VQA), and CNNs used in reinforcement learning.</p>
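<p>To make the idea concrete: as described in the linked paper, Grad-CAM weights each feature map \(A^k\) of a chosen convolutional layer by the spatially averaged gradient of the class score \(y^c\) with respect to that map, \(\alpha_k^c = \frac{1}{Z}\sum_i\sum_j \frac{\partial y^c}{\partial A^k_{ij}}\), where \(Z\) is the number of spatial positions, and then keeps the positive part of the weighted sum, \(L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)\). The NumPy function below is a minimal sketch of that computation, not the <code class="docutils literal"><span class="pre">gradcam</span></code> module’s implementation. The name <code class="docutils literal"><span class="pre">grad_cam</span></code> is hypothetical, and the sketch assumes the activations and gradients for a single image have already been captured as arrays of shape <code class="docutils literal"><span class="pre">(K,</span> <span class="pre">H,</span> <span class="pre">W)</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre>import numpy as np


def grad_cam(activations, gradients):
    """Compute a Grad-CAM heatmap for one image (hypothetical helper).

    activations: feature maps A of one conv layer, shape (K, H, W)
    gradients:   d(class score)/dA for the same layer, shape (K, H, W)
    Returns an (H, W) array scaled to [0, 1].
    """
    # alpha_k: average each channel's gradient over all spatial positions
    weights = gradients.mean(axis=(1, 2))               # shape (K,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=1)    # shape (H, W)
    # ReLU keeps only regions that increase the class score
    cam = np.maximum(cam, 0)
    # Scale to [0, 1] for display, leaving an all-zero map unchanged
    peak = cam.max()
    return cam / peak if peak else cam


# Example with random data standing in for a real layer's outputs
np.random.seed(0)
A = np.random.rand(8, 7, 7)
dA = np.random.randn(8, 7, 7)
heatmap = grad_cam(A, dA)
print(heatmap.shape)  # (7, 7)
</pre></div></div>
<p>The resulting map has the spatial resolution of the chosen layer, so it is upsampled to the input image size before being overlaid on the image.</p>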
<p>In the rest of this tutorial, we explain how to visualize predictions made by <a class="reference external" href="https://arxiv.org/abs/1409.1556">VGG-16</a>. We begin by importing the required dependencies. The <code class="docutils literal"><span class="pre">gradcam</span></code> module contains the implementation of the visualization techniques used in this tutorial.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">gluon</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="c1"># Download gradcam.py into the current directory so that it can be imported as a module</span>
<span class="n">gradcam_file</span> <span class="o">=</span> <span class="s2">"gradcam.py"</span>
<span class="n">base_url</span> <span class="o">=</span> <span class="s2">"https://raw.githubusercontent.com/indhub/mxnet/cnnviz/example/cnn_visualization/{}?raw=true"</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">base_url</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">gradcam_file</span><span class="p">),</span> <span class="n">fname</span><span class="o">=</span><span class="n">gradcam_file</span><span class="p">)</span>
<span class="kn">import</span> <span class="nn">gradcam</span>
</pre></div>
</div>
<div class="section" id="building-the-network-to-visualize">
<span id="building-the-network-to-visualize"></span><h2>Building the network to visualize<a class="headerlink" href="#building-the-network-to-visualize" title="Permalink to this headline"></a></h2>
<p>Next, we build the network we want to visualize. For this example, we use the <a class="reference external" href="https://arxiv.org/abs/1409.1556">VGG-16</a> network. This code was taken from the Gluon <a class="reference external" href="https://github.com/apache/incubator-mxnet/blob/1.3.1/python/mxnet/gluon/model_zoo/vision/vgg.py">model zoo</a> and refactored to make it easy to switch between <code class="docutils literal"><span class="pre">gradcam</span></code>’s and Gluon’s implementations of ReLU and Conv2D. The same code can be used for both training and visualization with a one-line change.</p>
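<p>Notice that we import <code class="docutils literal"><span class="pre">Activation</span></code> (used here for ReLU) and <code class="docutils literal"><span class="pre">Conv2D</span></code> from the <code class="docutils literal"><span class="pre">gradcam</span></code> module instead of from <code class="docutils literal"><span class="pre">mxnet.gluon.nn</span></code>:</p>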
<ul class="simple">
<li>We use a modified ReLU because we visualize with guided backpropagation, which requires the ReLU layers to block the backward flow of negative gradients, that is, gradients corresponding to neurons that decrease the activation of the higher-layer unit we aim to visualize. Check <a class="reference external" href="https://arxiv.org/abs/1412.6806">this</a> paper to learn more about guided backpropagation.</li>
<li>We use a modified Conv2D (a wrapper on top of Gluon’s Conv2D) because we want to capture the output of a given convolutional layer and the gradients of the class score with respect to that output, both of which Grad-CAM needs. Check <a class="reference external" href="https://arxiv.org/abs/1610.02391">this</a> paper to learn more about Grad-CAM.</li>
</ul>
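<p>The difference between an ordinary ReLU backward pass and the guided one can be sketched in a few lines of NumPy. This illustrates the guided-backpropagation rule from the linked paper and is not the <code class="docutils literal"><span class="pre">gradcam</span></code> module’s code; both function names are hypothetical. A standard ReLU passes the incoming gradient wherever its forward input was positive, while the guided version additionally zeroes out negative incoming gradients.</p>
<div class="highlight-python"><div class="highlight"><pre>import numpy as np


def relu_backward(x, grad_out):
    """Ordinary ReLU backward pass (hypothetical helper)."""
    positive_input = np.maximum(np.sign(x), 0)    # 1.0 where x is positive, else 0.0
    return grad_out * positive_input


def guided_relu_backward(x, grad_out):
    """Guided-backpropagation ReLU backward pass (hypothetical helper)."""
    positive_input = np.maximum(np.sign(x), 0)
    # Also block negative gradients flowing back from the layer above
    return np.maximum(grad_out, 0) * positive_input


x = np.array([-1.0, 2.0, 3.0, 0.5])          # forward inputs to the ReLU
grad_out = np.array([0.7, -0.4, 0.9, 0.2])   # gradients arriving from above
print(relu_backward(x, grad_out))
print(guided_relu_backward(x, grad_out))
</pre></div></div>
<p>Both calls zero the gradient at the first position, where the forward input was negative; only the guided version also zeroes the second position, where the incoming gradient is negative.</p>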
<p>When you train the network, you can simply import <code class="docutils literal"><span class="pre">Activation</span></code> and <code class="docutils literal"><span class="pre">Conv2D</span></code> from <code class="docutils literal"><span class="pre">gluon.nn</span></code> instead. No other part of the code has to change to switch between training and visualization.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon.model_zoo</span> <span class="kn">import</span> <span class="n">model_store</span>
<span class="kn">from</span> <span class="nn">mxnet.initializer</span> <span class="kn">import</span> <span class="n">Xavier</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon.nn</span> <span class="kn">import</span> <span class="n">MaxPool2D</span><span class="p">,</span> <span class="n">Flatten</span><span class="p">,</span> <span class="n">Dense</span><span class="p">,</span> <span class="n">Dropout</span><span class="p">,</span> <span class="n">BatchNorm</span>
<span class="c1"># For training, import Activation and Conv2D from mxnet.gluon.nn instead</span>
<span class="kn">from</span> <span class="nn">gradcam</span> <span class="kn">import</span> <span class="n">Activation</span><span class="p">,</span> <span class="n">Conv2D</span>

<span class="k">class</span> <span class="nc">VGG</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">HybridBlock</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">VGG</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">filters</span><span class="p">)</span>
        <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">name_scope</span><span class="p">():</span>
            <span class="c1"># Convolutional feature extractor</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_features</span><span class="p">(</span><span class="n">layers</span><span class="p">,</span> <span class="n">filters</span><span class="p">)</span>
            <span class="c1"># Two fully connected layers, each followed by dropout</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4096</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">,</span>
                                    <span class="n">weight_initializer</span><span class="o">=</span><span class="s1">'normal'</span><span class="p">,</span>
                                    <span class="n">bias_initializer</span><span class="o">=</span><span class="s1">'zeros'</span><span class="p">))</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">))</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4096</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">,</span>
                                    <span class="n">weight_initializer</span><span class="o">=</span><span class="s1">'normal'</span><span class="p">,</span>
                                    <span class="n">bias_initializer</span><span class="o">=</span><span class="s1">'zeros'</span><span class="p">))</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">))</span>
            <span class="c1"># Classification layer producing one score per class</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">Dense</span><span class="p">(</span><span class="n">classes</span><span class="p">,</span>
                                <span class="n">weight_initializer</span><span class="o">=</span><span class="s1">'normal'</span><span class="p">,</span>
                                <span class="n">bias_initializer</span><span class="o">=</span><span class="s1">'zeros'</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_make_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">filters</span><span class="p">):</span>
        <span class="n">featurizer</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">(</span><span class="n">prefix</span><span class="o">=</span><span class="s1">''</span><span class="p">)</span>
        <span class="c1"># One block per entry in `layers`: `num` conv + ReLU pairs, then max pooling</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">num</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">layers</span><span class="p">):</span>
            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num</span><span class="p">):</span>
                <span class="n">featurizer</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                                      <span class="n">weight_initializer</span><span class="o">=</span><span class="n">Xavier</span><span class="p">(</span><span class="n">rnd_type</span><span class="o">=</span><span class="s1">'gaussian'</span><span class="p">,</span>
                                                                <span class="n">factor_type</span><span class="o">=</span><span class="s1">'out'</span><span class="p">,</span>
                                                                <span class="n">magnitude</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
                                      <span class="n">bias_initializer</span><span class="o">=</span><span class="s1">'zeros'</span><span class="p">))</span>
                <span class="n">featurizer</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Activation</span><span class="p">(</span><span class="s1">'relu'</span><span class="p">))</span>
            <span class="n">featurizer</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">MaxPool2D</span><span class="p">(</span><span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">featurizer</span>

    <span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
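<p>To see how the <code class="docutils literal"><span class="pre">layers</span></code> and <code class="docutils literal"><span class="pre">filters</span></code> arguments shape the network, the pure-Python sketch below traces the output shape of each convolutional block. It is a hypothetical helper, not part of MXNet or <code class="docutils literal"><span class="pre">gradcam</span></code>, and it assumes a square 224×224 input (the crop size VGG was trained on). Each 3×3 convolution with <code class="docutils literal"><span class="pre">padding=1</span></code> keeps the spatial size, and each <code class="docutils literal"><span class="pre">MaxPool2D(strides=2)</span></code> with Gluon’s default 2×2 window halves it.</p>
<div class="highlight-python"><div class="highlight"><pre>def vgg_feature_shapes(layers, filters, image_size=224):
    """Trace the (channels, height, width) output of each VGG block.

    Hypothetical helper: assumes a square input, 3x3 convolutions with
    padding=1 and one 2x2, stride-2 max pooling layer after every block.
    Returns the per-block shapes and the total number of convolution layers.
    """
    size = image_size
    shapes = []
    num_conv_layers = 0
    for num_convs, channels in zip(layers, filters):
        num_conv_layers += num_convs     # padded 3x3 convs keep height and width
        size = (size - 2) // 2 + 1       # 2x2 pooling, stride 2, no padding
        shapes.append((channels, size, size))
    return shapes, num_conv_layers


# VGG-16 configuration (the same numbers as vgg_spec[16] in the next section)
shapes, num_conv_layers = vgg_feature_shapes([2, 2, 3, 3, 3], [64, 128, 256, 512, 512])
channels, height, width = shapes[-1]
print(shapes[-1])                 # (512, 7, 7)
print(channels * height * width)  # 25088 values fed to the first Dense layer
print(num_conv_layers + 3)        # 16 weight layers: 13 conv + 3 dense
</pre></div></div>
<p>Because Gluon’s <code class="docutils literal"><span class="pre">Dense</span></code> flattens its input by default, the first <code class="docutils literal"><span class="pre">Dense(4096)</span></code> layer receives these 25,088 values per image; the 13 convolution layers plus the three dense layers are what give VGG-16 its name.</p>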
</div>
<div class="section" id="loading-pretrained-weights">
<span id="loading-pretrained-weights"></span><h2>Loading pretrained weights<a class="headerlink" href="#loading-pretrained-weights" title="Permalink to this headline"></a></h2>
<p>Instead of training the model from scratch, we load weights pretrained on ImageNet from the Gluon model zoo.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># For each VGG depth: the number of convolution layers in each of the five blocks, and the number of filters per block.</span>
<span class="c1"># See the VGG paper (https://arxiv.org/abs/1409.1556) for more details on the different architectures.</span>
<span class="n">vgg_spec</span> <span class="o">=</span> <span class="p">{</span><span class="mi">11</span><span class="p">:</span> <span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">]),</span>
<span class="mi">13</span><span class="p">:</span> <span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">]),</span>
<span class="mi">16</span><span class="p">:</span> <span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">]),</span>
<span class="mi">19</span><span class="p">:</span> <span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">])}</span>
<span class="k">def</span> <span class="nf">get_vgg</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">root</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s1">'~'</span><span class="p">,</span> <span class="s1">'.mxnet'</span><span class="p">,</span> <span class="s1">'models'</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    <span class="c1"># Get the number of convolution layers and filters</span>
    <span class="n">layers</span><span class="p">,</span> <span class="n">filters</span> <span class="o">=</span> <span class="n">vgg_spec</span><span class="p">[</span><span class="n">num_layers</span><span class="p">]</span>
    <span class="c1"># Build the VGG network</span>
    <span class="n">net</span> <span class="o">=</span> <span class="n">VGG</span><span class="p">(</span><span class="n">layers</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="c1"># Load pretrained weights from model zoo</span>
    <span class="kn">from</span> <span class="nn">mxnet.gluon.model_zoo.model_store</span> <span class="kn">import</span> <span class="n">get_model_file</span>
    <span class="n">net</span><span class="o">.</span><span class="n">load_params</span><span class="p">(</span><span class="n">get_model_file</span><span class="p">(</span><span class="s1">'vgg</span><span class="si">%d</span><span class="s1">'</span> <span class="o">%</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">net</span>

<span class="k">def</span> <span class="nf">vgg16</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">get_vgg</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</pre></div>
</div>
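<p>As a quick check on <code class="docutils literal"><span class="pre">vgg_spec</span></code>: the number in a VGG model's name counts its weight layers, that is, the convolutional layers listed in the spec plus the three fully connected layers of the classifier. The plain-Python sketch below copies the spec (it does not import the tutorial's code) and confirms that the counts add up for every depth. <code class="docutils literal"><span class="pre">weight_layer_count</span></code> is a hypothetical helper introduced only for this illustration.</p>

```python
# Plain-Python copy of vgg_spec: depth -> (convs per stage, filters per stage)
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
            13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
            16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
            19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}

NUM_FC_LAYERS = 3  # VGG's classifier has three fully connected layers

def weight_layer_count(depth):
    """Hypothetical helper: convolutional plus fully connected layers for a VGG depth."""
    layers, filters = vgg_spec[depth]
    assert len(layers) == len(filters), "expected one filter count per stage"
    return sum(layers) + NUM_FC_LAYERS

for depth in sorted(vgg_spec):
    print(depth, weight_layer_count(depth))  # each depth matches its layer count
```

<p>For example, VGG16 has 2 + 2 + 3 + 3 + 3 = 13 convolutional layers and 3 fully connected layers.</p>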
</div>
<div class="section" id="preprocessing-and-other-helpers">
<span id="preprocessing-and-other-helpers"></span><h2>Preprocessing and other helpers<a class="headerlink" href="#preprocessing-and-other-helpers" title="Permalink to this headline"></a></h2>
<p>We’ll resize the input image to 224x224 before feeding it to the network, scale the pixel values to [0, 1], and normalize each channel with the same ImageNet mean and standard deviation that were used to train the pretrained model. These parameters are published <a class="reference external" href="/api/python/gluon/model_zoo.html">here</a>. Finally, we use <code class="docutils literal"><span class="pre">transpose</span></code> to convert the image from channel-last (HWC) to channel-first (CHW) format, which is the layout Gluon’s convolutional layers expect by default.</p>
<p>Note that we do not hybridize the network. Hybridizing would convert the network into a static graph, whereas we want <code class="docutils literal"><span class="pre">gradcam.Activation</span></code> and <code class="docutils literal"><span class="pre">gradcam.Conv2D</span></code> to behave differently at different stages of execution. For example, <code class="docutils literal"><span class="pre">gradcam.Activation</span></code> performs regular backpropagation when computing the gradient with respect to the topmost convolutional layer, but guided backpropagation when computing the gradient with respect to the image.</p>
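<p>To make the difference between the two modes concrete, here is a minimal NumPy sketch of the two backward rules at a ReLU. Regular backpropagation passes the incoming gradient wherever the forward input was positive; guided backpropagation additionally discards negative incoming gradients. The function names are hypothetical, and this is an illustration of the rule, not the code inside <code class="docutils literal"><span class="pre">gradcam.Activation</span></code>.</p>

```python
import numpy as np

def relu_backward(x, grad_out):
    """Regular backprop through ReLU: pass the gradient where the input was positive."""
    return np.where(x > 0, grad_out, 0.0)

def guided_relu_backward(x, grad_out):
    """Guided backprop through ReLU: also drop negative incoming gradients."""
    return np.where(np.logical_and(x > 0, grad_out > 0), grad_out, 0.0)

x = np.array([-1.0, 2.0, 3.0, -4.0])        # input seen by the ReLU in the forward pass
grad_out = np.array([0.5, -0.5, 1.0, 1.0])  # gradient arriving from the layer above
print(relu_backward(x, grad_out).tolist())         # [0.0, -0.5, 1.0, 0.0]
print(guided_relu_backward(x, grad_out).tolist())  # [0.0, 0.0, 1.0, 0.0]
```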
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">image_sz</span> <span class="o">=</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">preprocess</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">imresize</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">image_sz</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">image_sz</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">/</span><span class="mi">255</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">color_normalize</span><span class="p">(</span><span class="n">data</span><span class="p">,</span>
        <span class="n">mean</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">]),</span>
        <span class="n">std</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.229</span><span class="p">,</span> <span class="mf">0.224</span><span class="p">,</span> <span class="mf">0.225</span><span class="p">]))</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">data</span>
<span class="n">network</span> <span class="o">=</span> <span class="n">vgg16</span><span class="p">(</span><span class="n">ctx</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
</pre></div>
</div>
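<p>For readers who want to see what <code class="docutils literal"><span class="pre">preprocess</span></code> does to the pixel values, the following NumPy-only sketch mirrors its scaling, per-channel normalization, and channel-first transpose (it leaves out the resize). It is an illustration under that simplification, not a replacement for <code class="docutils literal"><span class="pre">mx.image.color_normalize</span></code>; <code class="docutils literal"><span class="pre">normalize_hwc</span></code> is a hypothetical name.</p>

```python
import numpy as np

# ImageNet per-channel statistics, the same values passed to color_normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_hwc(img_uint8):
    """Mirror preprocess() minus the resize: scale an HxWx3 uint8 image to
    [0, 1], standardize each channel, then reorder HWC -> CHW."""
    data = img_uint8.astype(np.float32) / 255
    data = (data - MEAN) / STD            # broadcasts over the last (channel) axis
    return np.transpose(data, (2, 0, 1))  # channels first

white = np.full((224, 224, 3), 255, dtype=np.uint8)  # placeholder all-white image
chw = normalize_hwc(white)
print(chw.shape)  # (3, 224, 224)
```

<p>A white pixel therefore ends up with the per-channel value (1 - mean) / std, and a black pixel with -mean / std.</p>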
<p>We define a helper that displays multiple images in a row in a Jupyter notebook.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">show_images</span><span class="p">(</span><span class="n">pred_str</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span>
    <span class="n">titles</span> <span class="o">=</span> <span class="p">[</span><span class="n">pred_str</span><span class="p">,</span> <span class="s1">'Grad-CAM'</span><span class="p">,</span> <span class="s1">'Guided Grad-CAM'</span><span class="p">,</span> <span class="s1">'Saliency Map'</span><span class="p">]</span>
    <span class="n">num_images</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span><span class="mi">15</span><span class="p">))</span>
    <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_images</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_images</span><span class="p">):</span>
        <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="n">rows</span><span class="p">,</span> <span class="n">cols</span><span class="p">,</span> <span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="n">titles</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">'gray'</span> <span class="k">if</span> <span class="n">i</span><span class="o">==</span><span class="n">num_images</span><span class="o">-</span><span class="mi">1</span> <span class="k">else</span> <span class="bp">None</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
<p>Given an image, the network outputs one score for each of the 1000 ImageNet categories; applying softmax to these scores gives a probability distribution over the categories. The most probable category can be found by applying the <code class="docutils literal"><span class="pre">argmax</span></code> operation, which returns the integer index of that category. To know which category the network predicted, we still need to convert this index into a human-readable name. The <a class="reference external" href="http://data.mxnet.io/models/imagenet/synset.txt">synset</a> file maps each ImageNet category index to its category name. We download the synset file and load it into a list that converts category indices to human-readable category names.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">synset_url</span> <span class="o">=</span> <span class="s2">"http://data.mxnet.io/models/imagenet/synset.txt"</span>
<span class="n">synset_file_name</span> <span class="o">=</span> <span class="s2">"synset.txt"</span>
<span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">synset_url</span><span class="p">,</span> <span class="n">fname</span><span class="o">=</span><span class="n">synset_file_name</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">synset_file_name</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="c1"># Drop the WordNet ID and keep only the first name on each line</span>
    <span class="n">synset</span> <span class="o">=</span> <span class="p">[</span><span class="n">l</span><span class="o">.</span><span class="n">rstrip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">' '</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">','</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">l</span> <span class="ow">in</span> <span class="n">f</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">get_class_name</span><span class="p">(</span><span class="n">cls_id</span><span class="p">):</span>
    <span class="k">return</span> <span class="s2">"</span><span class="si">%s</span><span class="s2"> (</span><span class="si">%d</span><span class="s2">)"</span> <span class="o">%</span> <span class="p">(</span><span class="n">synset</span><span class="p">[</span><span class="n">cls_id</span><span class="p">],</span> <span class="n">cls_id</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">run_inference</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    <span class="c1"># Index of the highest-scoring class for the first image in the batch</span>
    <span class="k">return</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
</pre></div>
</div>
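<p>The parsing rule and the argmax step can be tried without downloading anything. The sketch below uses two sample lines written in the synset file's format (a WordNet ID followed by comma-separated names) and a made-up score vector standing in for the network output; <code class="docutils literal"><span class="pre">lines</span></code> and <code class="docutils literal"><span class="pre">scores</span></code> are placeholders, not real data. It also checks that argmax over raw scores picks the same class as argmax over softmax probabilities, because softmax preserves the ordering of the scores.</p>

```python
import numpy as np

# Sample lines in the synset.txt format: WordNet ID, a space, comma-separated names
lines = ["n01440764 tench, Tinca tinca\n",
         "n01443537 goldfish, Carassius auratus\n"]

# Same parsing rule as above: drop the WordNet ID, keep the first name
synset = [l.rstrip().split(' ', 1)[1].split(',')[0] for l in lines]

def get_class_name(cls_id):
    return "%s (%d)" % (synset[cls_id], cls_id)

# Placeholder network output for a batch of one image: one raw score per class
scores = np.array([[0.2, 1.7]])
probs = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax
assert scores.argmax(axis=1)[0] == probs.argmax(axis=1)[0]  # same winning class
print(get_class_name(int(scores.argmax(axis=1)[0])))  # goldfish (1)
```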
</div>
<div class="section" id="visualizing-cnn-decisions">
<span id="visualizing-cnn-decisions"></span><h2>Visualizing CNN decisions<a class="headerlink" href="#visualizing-cnn-decisions" title="Permalink to this headline"></a></h2>
<p>Next, we’ll write a function that loads an image, preprocesses it, predicts its category, and visualizes the prediction. We’ll use <code class="docutils literal"><span class="pre">gradcam.visualize()</span></code> to create the visualizations. <code class="docutils literal"><span class="pre">gradcam.visualize</span></code> returns a tuple with the following visualizations:</p>
<ol class="simple">
<li><strong>Grad-CAM:</strong> This is a heatmap superimposed on the input image showing which part(s) of the image contributed most to the CNN’s decision.</li>
<li><strong>Guided Grad-CAM:</strong> Guided Grad-CAM combines the Grad-CAM heatmap with guided backpropagation to show, at pixel level, which details of the image contributed the most to the CNN’s decision.</li>
<li><strong>Saliency map:</strong> The saliency map is a monochrome image showing which pixels contributed the most to the CNN’s decision. The areas that most influence the output are sometimes easier to see in a monochrome image than in a color image.</li>
</ol>
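<p>The following NumPy sketch shows the core of the standard Grad-CAM recipe. Each channel of the last convolutional layer's output is weighted by the spatial average of the class score's gradient for that channel, the weighted channels are summed, and a ReLU keeps only positive evidence. Guided Grad-CAM then multiplies the upsampled heatmap element-wise with the guided-backpropagation gradients of the image. This is a minimal illustration under stated assumptions, not the code of the tutorial's <code class="docutils literal"><span class="pre">gradcam</span></code> module: the function names are hypothetical, the inputs are random placeholders rather than real network outputs, and nearest-neighbour upsampling stands in for whatever interpolation the module actually uses.</p>

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM heatmap from the last conv layer's output (C, H, W) and the
    gradient of the class score with respect to that output (C, H, W)."""
    # Channel importance: global average pooling of the gradients over space
    weights = gradients.mean(axis=(1, 2))            # shape (C,)
    # Weighted sum of the feature maps over channels, shape (H, W)
    cam = np.tensordot(weights, activations, axes=1)
    # ReLU: keep only features with a positive influence on the class score
    cam = np.maximum(cam, 0)
    # Scale to [0, 1] for display (skipped if the map is all zeros)
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

def upsample_nearest(cam, factor):
    """Nearest-neighbour upsampling, e.g. 14x14 to 224x224 with factor=16."""
    return np.kron(cam, np.ones((factor, factor)))

def guided_grad_cam(cam_upsampled, guided_grads):
    """Mask guided-backprop image gradients (3, H, W) with the upsampled
    heatmap (H, W), broadcast over the color channels."""
    return guided_grads * cam_upsampled[np.newaxis, :, :]

# Random placeholders with VGG16's last-conv shape for a 224x224 input
rng = np.random.RandomState(0)
acts = rng.rand(512, 14, 14)
grads = rng.randn(512, 14, 14)
heatmap = upsample_nearest(grad_cam(acts, grads), 16)
guided = guided_grad_cam(heatmap, rng.randn(3, 224, 224))
print(heatmap.shape, guided.shape)  # (224, 224) (3, 224, 224)
```

<p>The heatmap is coarse because it lives on the last convolutional layer's grid; multiplying it with the pixel-level guided gradients is what gives Guided Grad-CAM its sharper detail.</p>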
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">visualize</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">img_path</span><span class="p">,</span> <span class="n">conv_layer_name</span><span class="p">):</span>
    <span class="n">orig_img</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>
    <span class="n">preprocessed_img</span> <span class="o">=</span> <span class="n">preprocess</span><span class="p">(</span><span class="n">orig_img</span><span class="p">)</span>
    <span class="n">preprocessed_img</span> <span class="o">=</span> <span class="n">preprocessed_img</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">pred_str</span> <span class="o">=</span> <span class="n">get_class_name</span><span class="p">(</span><span class="n">run_inference</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">preprocessed_img</span><span class="p">))</span>
    <span class="n">orig_img</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">imresize</span><span class="p">(</span><span class="n">orig_img</span><span class="p">,</span> <span class="n">image_sz</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">image_sz</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
    <span class="n">vizs</span> <span class="o">=</span> <span class="n">gradcam</span><span class="o">.</span><span class="n">visualize</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">preprocessed_img</span><span class="p">,</span> <span class="n">orig_img</span><span class="p">,</span> <span class="n">conv_layer_name</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">pred_str</span><span class="p">,</span> <span class="p">(</span><span class="n">orig_img</span><span class="p">,</span> <span class="o">*</span><span class="n">vizs</span><span class="p">))</span>
</pre></div>
</div>
<p>Next, we need the name of the last convolutional layer that extracts features from the image. We use the gradient information flowing into the last convolutional layer of the CNN to understand how important each neuron is for a decision of interest. We focus on the last convolutional layer because convolutional features naturally retain spatial information, which is lost in fully connected layers. We therefore expect the last convolutional layer to offer the best compromise between high-level semantics and detailed spatial information: its neurons look for semantic, class-specific information in the image (such as object parts).</p>
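<p>To put "detailed spatial information" into numbers: in VGG16 the 3x3 convolutions use padding 1 and keep the spatial size, while each stage ends with a 2x2 max pool of stride 2. The last convolutional layer comes after four of those pools, so for a 224x224 input it works on a 14x14 grid, one cell per 16x16 pixel patch. This is also the resolution of the Grad-CAM heatmap before it is upsampled. <code class="docutils literal"><span class="pre">feature_map_size</span></code> is a hypothetical helper that spells out this arithmetic.</p>

```python
def feature_map_size(input_size, num_pools, pool_stride=2):
    """Spatial side length after num_pools stride-2 max pools, assuming the
    convolutions in between preserve size (3x3 kernels with padding 1)."""
    size = input_size
    for _ in range(num_pools):
        size //= pool_stride
    return size

side = feature_map_size(224, num_pools=4)  # the last conv layer precedes the 5th pool
print(side, 224 // side)                   # 14 16
```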
<p>In our network, the feature extraction layers are added to a HybridSequential block named <code class="docutils literal"><span class="pre">features</span></code>. You can list the layers in that block by printing <code class="docutils literal"><span class="pre">network.features</span></code>; the output shows that the topmost convolutional layer is at index 28. <code class="docutils literal"><span class="pre">network.features[28]._name</span></code> gives the name of that layer.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">last_conv_layer_name</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">features</span><span class="p">[</span><span class="mi">28</span><span class="p">]</span><span class="o">.</span><span class="n">_name</span>
<span class="k">print</span><span class="p">(</span><span class="n">last_conv_layer_name</span><span class="p">)</span>
</pre></div>
</div>
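<p>The index 28 can also be reproduced from <code class="docutils literal"><span class="pre">vgg_spec</span></code>, assuming that the <code class="docutils literal"><span class="pre">VGG</span></code> class defined earlier adds each convolution to <code class="docutils literal"><span class="pre">features</span></code> as a (Conv2D, Activation) pair and closes each stage with one max-pooling layer, as the Gluon model zoo's VGG without batch normalization does. The hypothetical helper below lists the positions of the convolutional layers under that assumption. For VGG16 it finds 13 of them with the last at index 28, which is consistent with the <code class="docutils literal"><span class="pre">conv2d12</span></code> suffix above (Gluon numbers the Conv2D blocks from 0). If the layout differs, printing <code class="docutils literal"><span class="pre">network.features</span></code> remains the authoritative check.</p>

```python
def conv_indices(convs_per_stage):
    """Hypothetical helper: indices of the conv layers in a features block laid
    out as (Conv2D, Activation) per convolution plus one MaxPool per stage.
    The layout is an assumption about the VGG class, not read from the network."""
    indices, pos = [], 0
    for n in convs_per_stage:
        for _ in range(n):
            indices.append(pos)  # position of this Conv2D
            pos += 2             # skip the Conv2D and its Activation
        pos += 1                 # skip the MaxPool that closes the stage
    return indices

idx = conv_indices([2, 2, 3, 3, 3])  # VGG16 entry of vgg_spec
print(len(idx), idx[-1])             # 13 28
```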
<p>vgg0_conv2d12<!--notebook-skip-line--></p>
<p>Let’s download some images we can use for visualization.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">images</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"hummingbird.jpg"</span><span class="p">,</span> <span class="s2">"jellyfish.jpg"</span><span class="p">,</span> <span class="s2">"snow_leopard.jpg"</span><span class="p">,</span> <span class="s2">"volcano.jpg"</span><span class="p">]</span>
<span class="n">base_url</span> <span class="o">=</span> <span class="s2">"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/{}?raw=true"</span>
<span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">images</span><span class="p">:</span>
    <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">base_url</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">fname</span><span class="o">=</span><span class="n">image</span><span class="p">)</span>
</pre></div>
</div>
<p>We now have everything we need to start visualizing. Let’s visualize the CNN decision for the images we downloaded.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">show_images</span><span class="p">(</span><span class="o">*</span><span class="n">visualize</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="s2">"hummingbird.jpg"</span><span class="p">,</span> <span class="n">last_conv_layer_name</span><span class="p">))</span>
</pre></div>
</div>
<p><img alt="Visualizing CNN decision" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/hummingbird.png"/><!--notebook-skip-line--></p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">show_images</span><span class="p">(</span><span class="o">*</span><span class="n">visualize</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="s2">"jellyfish.jpg"</span><span class="p">,</span> <span class="n">last_conv_layer_name</span><span class="p">))</span>
</pre></div>
</div>
<p><img alt="Visualizing CNN decision" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/jellyfish.png"/><!--notebook-skip-line--></p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">show_images</span><span class="p">(</span><span class="o">*</span><span class="n">visualize</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="s2">"snow_leopard.jpg"</span><span class="p">,</span> <span class="n">last_conv_layer_name</span><span class="p">))</span>
</pre></div>
</div>
<p><img alt="Visualizing CNN decision" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/snow_leopard.png"/><!--notebook-skip-line--></p>
<p>Shown above are some images the network predicted correctly. The visualizations show that the network bases its decisions on the appropriate features. Now, let’s look at an example where the network’s prediction is wrong and visualize why.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">show_images</span><span class="p">(</span><span class="o">*</span><span class="n">visualize</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="s2">"volcano.jpg"</span><span class="p">,</span> <span class="n">last_conv_layer_name</span><span class="p">))</span>
</pre></div>
</div>
<p><img alt="Visualizing CNN decision" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/cnn_visualization/volcano.png"/><!--notebook-skip-line--></p>
<p>It is not immediately evident why the network thinks this volcano is a spider. After looking at the Grad-CAM visualization, however, it is hard to look at the volcano and not see the spider!</p>
<p>Being able to visualize why a CNN predicts a specific class is a powerful tool for diagnosing prediction failures. Even when the network makes correct predictions, visualizing its decisions is an important step to verify that it relies on the right features and not on some spurious correlation that happens to exist in the training data.</p>
<p>The visualization method demonstrated in this tutorial applies to a wide variety of network architectures and to tasks beyond classification, such as visual question answering (VQA) and image captioning. Any differentiable output can be used to create the visualizations shown above. Visualization techniques like these help solve, at least partially, the long-standing problem of interpreting neural networks.</p>
<div class="btn-group" role="group">
<div class="download-btn"><a download="cnn_visualization.ipynb" href="cnn_visualization.ipynb"><span class="glyphicon glyphicon-download-alt"></span> cnn_visualization.ipynb</a></div></div></div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Visualizing Decisions of Convolutional Neural Networks</a><ul>
<li><a class="reference internal" href="#building-the-network-to-visualize">Building the network to visualize</a></li>
<li><a class="reference internal" href="#loading-pretrained-weights">Loading pretrained weights</a></li>
<li><a class="reference internal" href="#preprocessing-and-other-helpers">Preprocessing and other helpers</a></li>
<li><a class="reference internal" href="#visualizing-cnn-decisions">Visualizing CNN decisions</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script src="../../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>