<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Fine-tune with Pretrained Models — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body role="document">
<div class="container">
<div class="row">
<div class="content">
<div class="section" id="fine-tune-with-pretrained-models">
<span id="fine-tune-with-pretrained-models"></span><h1>Fine-tune with Pretrained Models<a class="headerlink" href="#fine-tune-with-pretrained-models" title="Permalink to this headline"></a></h1>
<p>Many of the exciting deep learning algorithms for computer vision require
massive datasets for training. The most popular benchmark dataset,
<a class="reference external" href="http://www.image-net.org/">ImageNet</a>, for example, contains one million images
from one thousand categories. But for any practical problem, we typically have
access to comparatively small datasets. In these cases, if we were to train a
neural network’s weights from scratch, starting from randomly initialized
parameters, we would overfit the training set badly.</p>
<p>One approach to get around this problem is to first pretrain a deep net on a
large-scale dataset, like ImageNet. Then, given a new dataset, we can start
with these pretrained weights when training on our new task. This process is
commonly called <em>fine-tuning</em>. There are a number of variations of fine-tuning.
Sometimes, the initial neural network is used only as a <em>feature extractor</em>.
That means that we freeze every layer prior to the output layer and simply learn
a new output layer. In <a class="reference external" href="https://github.com/dmlc/mxnet-notebooks/blob/master/python/how_to/predict.ipynb">another document</a>, we explained how to
do this kind of feature extraction. Another approach is to update all of
the network’s weights for the new task, and that’s the approach we demonstrate in
this document.</p>
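<p>For reference, the feature-extractor variant can be expressed with the <code class="docutils literal"><span class="pre">Module</span></code> API by freezing the pretrained layers. The snippet below is only a minimal sketch: it assumes a modified symbol <code class="docutils literal"><span class="pre">new_sym</span></code> and the pretrained <code class="docutils literal"><span class="pre">arg_params</span></code> prepared as in the later sections of this tutorial, and it is not used anywhere else in this document.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Feature-extractor sketch: freeze every pretrained weight so that only the
# newly added output layer ('fc1' in the code later in this tutorial) is updated.
# new_sym and arg_params are assumed to come from get_fine_tune_model below.
import mxnet as mx

frozen = [name for name in arg_params if 'fc1' not in name]
mod = mx.mod.Module(symbol=new_sym, context=mx.cpu(),
                    fixed_param_names=frozen)
</pre></div>
</div>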
<p>To fine-tune a network, we must first replace the last fully-connected layer
with a new one that outputs the desired number of classes. We initialize its
weights randomly. Then we continue training as normal, often with a smaller
learning rate, based on the intuition that the pretrained weights may already be
close to a good result.</p>
<p>In this demonstration, we’ll fine-tune a model pretrained on ImageNet to the
smaller Caltech-256 dataset. Following this example, you can fine-tune on other
datasets, even for strikingly different applications such as face
identification.</p>
<p>We will show that, even with a simple hyper-parameter setting, we can match and
even outperform state-of-the-art results on Caltech-256.</p>
<table border="1" class="docutils">
<colgroup>
<col width="50%"/>
<col width="50%"/>
</colgroup>
<thead valign="bottom">
<tr class="row-odd"><th class="head">Network</th>
<th class="head">Validation accuracy</th>
</tr>
</thead>
<tbody valign="top">
<tr class="row-even"><td>ResNet-50</td>
<td>77.4%</td>
</tr>
<tr class="row-odd"><td>ResNet-152</td>
<td>86.4%</td>
</tr>
</tbody>
</table>
<div class="section" id="prepare-data">
<span id="prepare-data"></span><h2>Prepare data<a class="headerlink" href="#prepare-data" title="Permalink to this headline"></a></h2>
<p>We follow the standard protocol of sampling 60 images from each class as the
training set and using the rest as the validation set. We resize the images to
256x256 pixels and pack them into rec files. The script to prepare the data is as
follows.</p>
<div class="highlight-sh"><div class="highlight"><pre><span></span>wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar
tar -xf 256_ObjectCategories.tar
mkdir -p caltech_256_train_60
<span class="k">for</span> i in 256_ObjectCategories/*<span class="p">;</span> <span class="k">do</span>
<span class="nv">c</span><span class="o">=</span><span class="sb">`</span>basename <span class="nv">$i</span><span class="sb">`</span>
mkdir -p caltech_256_train_60/<span class="nv">$c</span>
<span class="k">for</span> j in <span class="sb">`</span>ls <span class="nv">$i</span>/*.jpg <span class="p">|</span> shuf <span class="p">|</span> head -n <span class="m">60</span><span class="sb">`</span><span class="p">;</span> <span class="k">do</span>
mv <span class="nv">$j</span> caltech_256_train_60/<span class="nv">$c</span>/
<span class="k">done</span>
<span class="k">done</span>
python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-train caltech_256_train_60/
python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize <span class="m">256</span> --quality <span class="m">90</span> --num-thread <span class="m">16</span> caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize <span class="m">256</span> --quality <span class="m">90</span> --num-thread <span class="m">16</span> caltech-256-60-train caltech_256_train_60/
</pre></div>
</div>
<p>The following code downloads the pregenerated rec files. It may take a few minutes.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span><span class="o">,</span> <span class="nn">urllib</span>
<span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">url</span><span class="p">):</span>
<span class="n">filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
<span class="n">urllib</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="n">download</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec'</span><span class="p">)</span>
<span class="n">download</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec'</span><span class="p">)</span>
</pre></div>
</div>
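<p>The snippet above targets Python 2, where <code class="docutils literal"><span class="pre">urlretrieve</span></code> lives directly in <code class="docutils literal"><span class="pre">urllib</span></code>. On Python 3 the same helper can be written as follows (a small sketch, not part of the original notebook):</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Python 3 version of the download helper.
import os
from urllib.request import urlretrieve

def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urlretrieve(url, filename)
</pre></div>
</div>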
<p>Next, we define the function which returns the data iterators:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="k">def</span> <span class="nf">get_iterators</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">data_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)):</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">ImageRecordIter</span><span class="p">(</span>
<span class="n">path_imgrec</span> <span class="o">=</span> <span class="s1">'./caltech-256-60-train.rec'</span><span class="p">,</span>
<span class="n">data_name</span> <span class="o">=</span> <span class="s1">'data'</span><span class="p">,</span>
<span class="n">label_name</span> <span class="o">=</span> <span class="s1">'softmax_label'</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">data_shape</span> <span class="o">=</span> <span class="n">data_shape</span><span class="p">,</span>
<span class="n">shuffle</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
<span class="n">rand_crop</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
<span class="n">rand_mirror</span> <span class="o">=</span> <span class="bp">True</span><span class="p">)</span>
<span class="n">val</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">ImageRecordIter</span><span class="p">(</span>
<span class="n">path_imgrec</span> <span class="o">=</span> <span class="s1">'./caltech-256-60-val.rec'</span><span class="p">,</span>
<span class="n">data_name</span> <span class="o">=</span> <span class="s1">'data'</span><span class="p">,</span>
<span class="n">label_name</span> <span class="o">=</span> <span class="s1">'softmax_label'</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">data_shape</span> <span class="o">=</span> <span class="n">data_shape</span><span class="p">,</span>
<span class="n">rand_crop</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
<span class="n">rand_mirror</span> <span class="o">=</span> <span class="bp">False</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span>
</pre></div>
</div>
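<p>If you want to verify that the rec files are read correctly, you can pull a single batch from the training iterator and check its shape; a quick sanity check, assuming the rec files from the previous step sit in the current directory:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Sanity check: fetch one batch and inspect its shapes.
train, val = get_iterators(batch_size=128)
batch = train.next()
print(batch.data[0].shape)   # expected: (128, 3, 224, 224)
print(batch.label[0].shape)  # expected: (128,)
train.reset()                # rewind before training
</pre></div>
</div>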
<p>We then download a pretrained 50-layer ResNet model and load it into memory. Note
that if <code class="docutils literal"><span class="pre">load_checkpoint</span></code> reports an error, we can remove the downloaded files
and try <code class="docutils literal"><span class="pre">get_model</span></code> again.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_model</span><span class="p">(</span><span class="n">prefix</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
<span class="n">download</span><span class="p">(</span><span class="n">prefix</span><span class="o">+</span><span class="s1">'-symbol.json'</span><span class="p">)</span>
<span class="n">download</span><span class="p">(</span><span class="n">prefix</span><span class="o">+</span><span class="s1">'-</span><span class="si">%04d</span><span class="s1">.params'</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span><span class="p">,))</span>
<span class="n">get_model</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'resnet-50'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</pre></div>
</div>
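<p>To see which layer name to pass as <code class="docutils literal"><span class="pre">layer_name</span></code> in the next section, you can list the last few internal outputs of the loaded symbol; for the ResNet-50 model used here, you should see <code class="docutils literal"><span class="pre">flatten0_output</span></code> near the end of the list:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Print the last few internal outputs to locate the layer just before
# the final fully-connected layer (e.g. 'flatten0_output').
print(sym.get_internals().list_outputs()[-10:])
</pre></div>
</div>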
</div>
<div class="section" id="train">
<span id="train"></span><h2>Train<a class="headerlink" href="#train" title="Permalink to this headline"></a></h2>
<p>We first define a function which replaces the last fully-connected layer for a given network.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_fine_tune_model</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">layer_name</span><span class="o">=</span><span class="s1">'flatten0'</span><span class="p">):</span>
<span class="sd">"""</span>
<span class="sd"> symbol: the pretrained network symbol</span>
<span class="sd"> arg_params: the argument parameters of the pretrained model</span>
<span class="sd"> num_classes: the number of classes for the fine-tune datasets</span>
<span class="sd"> layer_name: the layer name before the last fully-connected layer</span>
<span class="sd"> """</span>
<span class="n">all_layers</span> <span class="o">=</span> <span class="n">symbol</span><span class="o">.</span><span class="n">get_internals</span><span class="p">()</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">all_layers</span><span class="p">[</span><span class="n">layer_name</span><span class="o">+</span><span class="s1">'_output'</span><span class="p">]</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">net</span><span class="p">,</span> <span class="n">num_hidden</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'fc1'</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">net</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">)</span>
<span class="n">new_args</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">({</span><span class="n">k</span><span class="p">:</span><span class="n">arg_params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">arg_params</span> <span class="k">if</span> <span class="s1">'fc1'</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span><span class="p">})</span>
<span class="k">return</span> <span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">new_args</span><span class="p">)</span>
</pre></div>
</div>
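<p>A quick way to confirm that the new head is wired up correctly is to infer the output shape of the modified symbol; with 256 classes, the softmax output should have shape <code class="docutils literal"><span class="pre">(1,</span> <span class="pre">256)</span></code> for a single 224x224 image. This is only a sketch that re-runs <code class="docutils literal"><span class="pre">get_fine_tune_model</span></code> on the symbol loaded above:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Check that the replaced output layer produces 256 class scores.
check_sym, _ = get_fine_tune_model(sym, arg_params, num_classes=256)
arg_shapes, out_shapes, aux_shapes = check_sym.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)  # expected: [(1, 256)]
</pre></div>
</div>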
<p>Now we create a module and train it. We pass the pretrained weights to <code class="docutils literal"><span class="pre">mod.fit</span></code> via <code class="docutils literal"><span class="pre">arg_params</span></code> and <code class="docutils literal"><span class="pre">aux_params</span></code> with <code class="docutils literal"><span class="pre">allow_missing=True</span></code>: every layer starts from its pretrained values, while the new fully-connected layer, which has no pretrained weights, is randomly initialized by the <code class="docutils literal"><span class="pre">initializer</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">logging</span>
<span class="n">head</span> <span class="o">=</span> <span class="s1">'</span><span class="si">%(asctime)-15s</span><span class="s1"> </span><span class="si">%(message)s</span><span class="s1">'</span>
<span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">format</span><span class="o">=</span><span class="n">head</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">symbol</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">):</span>
<span class="n">devs</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_gpus</span><span class="p">)]</span>
<span class="n">mod</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">symbol</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">devs</span><span class="p">)</span>
<span class="n">mod</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span>
<span class="n">num_epoch</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="n">arg_params</span><span class="o">=</span><span class="n">arg_params</span><span class="p">,</span>
<span class="n">aux_params</span><span class="o">=</span><span class="n">aux_params</span><span class="p">,</span>
<span class="n">allow_missing</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">batch_end_callback</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">Speedometer</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span>
<span class="n">kvstore</span><span class="o">=</span><span class="s1">'device'</span><span class="p">,</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span><span class="s1">'learning_rate'</span><span class="p">:</span><span class="mf">0.01</span><span class="p">},</span>
<span class="n">initializer</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Xavier</span><span class="p">(</span><span class="n">rnd_type</span><span class="o">=</span><span class="s1">'gaussian'</span><span class="p">,</span> <span class="n">factor_type</span><span class="o">=</span><span class="s2">"in"</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="n">eval_metric</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">)</span>
<span class="n">metric</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">metric</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">mod</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">metric</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># score() returns [(name, value)]; return just the accuracy value</span>
</pre></div>
</div>
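<p>The <code class="docutils literal"><span class="pre">fit</span></code> function above only returns the validation score and does not save the fine-tuned weights. If you want to keep them, one option (not used in the rest of this tutorial) is to pass an <code class="docutils literal"><span class="pre">epoch_end_callback</span></code> to <code class="docutils literal"><span class="pre">mod.fit</span></code>:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span># Optional sketch: save the fine-tuned weights after every epoch by passing
# this callback to mod.fit(...) as epoch_end_callback inside fit().
# It writes resnet-50-caltech256-symbol.json plus resnet-50-caltech256-0001.params,
# -0002.params, and so on; the prefix name here is just an example.
checkpoint = mx.callback.do_checkpoint('resnet-50-caltech256')
</pre></div>
</div>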
<p>Then we can start training. We use an AWS EC2 g2.8xlarge instance, which has 8 GPUs.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">num_classes</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">batch_per_gpu</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">num_gpus</span> <span class="o">=</span> <span class="mi">8</span>
<span class="p">(</span><span class="n">new_sym</span><span class="p">,</span> <span class="n">new_args</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_fine_tune_model</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_per_gpu</span> <span class="o">*</span> <span class="n">num_gpus</span>
<span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_iterators</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">mod_score</span> <span class="o">=</span> <span class="n">fit</span><span class="p">(</span><span class="n">new_sym</span><span class="p">,</span> <span class="n">new_args</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mod_score</span> <span class="o">></span> <span class="mf">0.77</span><span class="p">,</span> <span class="s2">"Low training accuracy."</span>
</pre></div>
</div>
<p>You will see that, after only 8 epochs, we can get 78% validation accuracy. This
matches the state-of-the-art results obtained by training on Caltech-256 alone,
e.g. <a class="reference external" href="http://www.robots.ox.ac.uk/~vgg/research/deep_eval/">VGG</a>.</p>
<p>Next, we try another pretrained model. This model was trained on the full
ImageNet-11k dataset, which is about 10x larger than the 1,000-class ImageNet
subset, and uses a roughly 3x deeper ResNet architecture (152 layers instead of 50).</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">get_model</span><span class="p">(</span><span class="s1">'http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">aux_params</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'resnet-152'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="p">(</span><span class="n">new_sym</span><span class="p">,</span> <span class="n">new_args</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_fine_tune_model</span><span class="p">(</span><span class="n">sym</span><span class="p">,</span> <span class="n">arg_params</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">mod_score</span> <span class="o">=</span> <span class="n">fit</span><span class="p">(</span><span class="n">new_sym</span><span class="p">,</span> <span class="n">new_args</span><span class="p">,</span> <span class="n">aux_params</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">val</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mod_score</span> <span class="o">></span> <span class="mf">0.86</span><span class="p">,</span> <span class="s2">"Low training accuracy."</span>
</pre></div>
</div>
<p>As you can see, even after a single data epoch, it reaches 83% validation
accuracy. After 8 epochs, the validation accuracy increases to 86.4%.</p>
<div class="btn-group" role="group">
<div class="download_btn"><a download="finetune_python.ipynb" href="finetune_python.ipynb"><span class="glyphicon glyphicon-download-alt"></span> finetune_python.ipynb</a></div></div></div>
</div>
</div>
</div>
</div></body>
</html>