# Fine-tune with Pretrained Models
Many of the exciting deep learning algorithms for computer vision require
massive datasets for training. The most popular benchmark dataset,
[ImageNet](http://www.image-net.org/), for example, contains one million images
from one thousand categories. But for many practical problems, we typically have
access to comparatively small datasets. In these cases, if we were to train a
neural network's weights from scratch, starting from randomly initialized
parameters, we would overfit the training set badly.
One approach to get around this problem is to first pretrain a deep net on a
large-scale dataset, like ImageNet. Then, given a new dataset, we can start
with these pretrained weights when training on our new task. This process is
commonly called _fine-tuning_. There are a number of variations of fine-tuning.
Sometimes, the initial neural network is used only as a _feature extractor_.
That means that we freeze every layer prior to the output layer and simply learn
a new output layer. In [another document](https://github.com/dmlc/mxnet-notebooks/blob/master/python/how_to/predict.ipynb), we explained how to
do this kind of feature extraction. Another approach is to update all of
the network's weights for the new task, and that's the approach we demonstrate in
this document.
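
The difference between the two variants comes down to which parameters the optimizer is allowed to update. The following is a minimal sketch, not part of the original tutorial: `split_params` is a hypothetical helper, the parameter names are illustrative, and it assumes the new output layer's parameters share a `fc1` prefix. With MXNet's `Module`, the frozen list could be passed as the `fixed_param_names` argument.

```python
def split_params(param_names, head_prefix='fc1'):
    """Partition parameter names into (frozen, trainable) lists.

    Parameters whose name starts with head_prefix belong to the new
    output layer and stay trainable; everything else is frozen.
    """
    frozen = [n for n in param_names if not n.startswith(head_prefix)]
    trainable = [n for n in param_names if n.startswith(head_prefix)]
    return frozen, trainable

# Illustrative names only; a real list would come from the keys of the
# pretrained model's parameter dictionary.
example_names = ['conv0_weight', 'bn0_gamma', 'bn0_beta', 'fc1_weight', 'fc1_bias']

# Feature extractor: freeze everything except the new output layer.
frozen, trainable = split_params(example_names)
print(frozen)     # ['conv0_weight', 'bn0_gamma', 'bn0_beta']
print(trainable)  # ['fc1_weight', 'fc1_bias']

# Full fine-tuning (the approach in this document): freeze nothing.
frozen_full = []
```
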
To fine-tune a network, we must first replace the last fully-connected layer
with a new one that outputs the desired number of classes. We initialize its
weights randomly. Then we continue training as normal. It is common to use a
smaller learning rate, based on the intuition that the pretrained weights may
already be close to a good result.
In this demonstration, we'll fine-tune a model pretrained on ImageNet to the
smaller Caltech-256 dataset. Following this example, you can fine-tune to other
datasets, even for strikingly different applications such as face
identification.
We will show that, even with simple hyperparameter settings, we can match and
even outperform state-of-the-art results on Caltech-256.
| Network | Accuracy |
| --- | --- |
| ResNet-50 | 77.4% |
| ResNet-152 | 86.4% |
## Prepare data
We follow the standard protocol and sample 60 images from each class as the
training set, using the rest as the validation set. We resize the images so that
the shorter edge is 256 pixels and pack them into rec files. The script to
prepare the data is as follows.
> To run the following bash script on Windows, please use Cygwin (https://cygwin.com/install.html).
```sh
wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar
tar -xf 256_ObjectCategories.tar
mkdir -p caltech_256_train_60
for i in 256_ObjectCategories/*; do
    c=`basename $i`
    mkdir -p caltech_256_train_60/$c
    for j in `ls $i/*.jpg | shuf | head -n 60`; do
        mv $j caltech_256_train_60/$c/
    done
done
python ~/mxnet/tools/im2rec.py --list --recursive caltech-256-60-train caltech_256_train_60/
python ~/mxnet/tools/im2rec.py --list --recursive caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-val 256_ObjectCategories/
python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-train caltech_256_train_60/
```
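
For readers who cannot run the shell script, the sampling step could be sketched in Python as well. This is not part of the original tooling: `split_train` is a hypothetical helper that mirrors the loop above by moving up to 60 randomly chosen `.jpg` files from each class folder into a parallel training folder, leaving the rest in place for validation. It assumes Python 3.

```python
import os
import random
import shutil

def split_train(src_root, dst_root, per_class=60, seed=None):
    """Move up to per_class random .jpg files from each class folder in
    src_root into a folder of the same name under dst_root."""
    rng = random.Random(seed)
    for cls in sorted(os.listdir(src_root)):
        src_dir = os.path.join(src_root, cls)
        if not os.path.isdir(src_dir):
            continue  # skip stray files at the top level
        dst_dir = os.path.join(dst_root, cls)
        os.makedirs(dst_dir, exist_ok=True)
        images = sorted(f for f in os.listdir(src_dir) if f.endswith('.jpg'))
        rng.shuffle(images)
        for name in images[:per_class]:
            shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, name))

# Usage, with the same paths as the shell script:
# split_train('256_ObjectCategories', 'caltech_256_train_60')
```

The `im2rec.py` steps would still be run afterwards to build the list and rec files.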
The following code downloads the pregenerated rec files. It may take a few minutes.
```python
import os, sys

if sys.version_info[0] >= 3:
    from urllib.request import urlretrieve
else:
    from urllib import urlretrieve

def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urlretrieve(url, filename)

download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')
download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')
```
Next, we define the function which returns the data iterators:
```python
import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec         = './caltech-256-60-train.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rand_mirror         = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = './caltech-256-60-val.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)
```
We then download a pretrained 50-layer ResNet model and load it into memory. Note
that if `load_checkpoint` reports an error, we can remove the downloaded files
and try `get_model` again.
```python
def get_model(prefix, epoch):
    download(prefix+'-symbol.json')
    download(prefix+'-%04d.params' % (epoch,))

get_model('http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)
```
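
A checkpoint consists of two files whose names are derived from the prefix and the epoch, which is why `get_model` downloads exactly those two. The small helper below makes the naming explicit. `checkpoint_files` is a hypothetical name introduced here for illustration. It can help identify which files to delete before retrying a failed load.

```python
def checkpoint_files(prefix, epoch):
    """Return the (symbol, params) file names for a checkpoint prefix and epoch."""
    return ('%s-symbol.json' % prefix, '%s-%04d.params' % (prefix, epoch))

print(checkpoint_files('resnet-50', 0))
# ('resnet-50-symbol.json', 'resnet-50-0000.params')
```
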
## Train
We first define a function which replaces the last fully-connected layer for a given network.
```python
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name+'_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args)
```
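
To see why the old `fc1` parameters are dropped, consider this sketch in which plain shape tuples stand in for the real weight arrays. The names and shapes are illustrative assumptions rather than values read from the checkpoint. The old head was sized for 1000 ImageNet classes, so its weights cannot be reused for a head with a different number of outputs.

```python
# Shape tuples stand in for the weight arrays of a real checkpoint (illustrative).
pretrained_shapes = {
    'conv0_weight': (64, 3, 7, 7),
    'fc1_weight': (1000, 2048),   # old head: 1000 ImageNet classes
    'fc1_bias': (1000,),
}

# Same filter as in get_fine_tune_model: drop anything belonging to 'fc1'.
new_args = {k: v for k, v in pretrained_shapes.items() if 'fc1' not in k}

# These parameters are absent from new_args and must be freshly initialized.
missing = sorted(set(pretrained_shapes) - set(new_args))
print(sorted(new_args))  # ['conv0_weight']
print(missing)           # ['fc1_bias', 'fc1_weight']
```
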
Now we create a module. Note we pass the existing parameters from the loaded model via the `arg_params` argument.
The parameters of the last fully-connected layer will be randomly initialized by the `initializer`.
```python
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.fit(train, val,
        num_epoch=8,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate':0.01},
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    metric = mx.metric.Accuracy()
    # score() returns a list of (name, value) pairs; return the accuracy value
    # itself so that callers can compare it against a float threshold.
    return mod.score(val, metric)[0][1]
```
Then we can start training. We use AWS EC2 g2.8xlarge, which has 8 GPUs.
```python
num_classes = 256
batch_per_gpu = 16
num_gpus = 8
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
assert mod_score > 0.77, "Low training accuracy."
```
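
The batch size passed to the iterators is the global batch, which is split across the GPUs. As a rough sketch of the resulting arithmetic, the snippet below assumes a training set of 256 classes with 60 images each. That count is an assumption made for illustration, not a figure measured from the rec file.

```python
import math

batch_per_gpu = 16
num_gpus = 8
batch_size = batch_per_gpu * num_gpus          # global batch: 128

# Assumption: 256 classes x 60 training images each.
num_train_images = 256 * 60                    # 15360
batches_per_epoch = math.ceil(num_train_images / batch_size)
print(batch_size, batches_per_epoch)           # 128 120
```

Under these assumptions, `Speedometer(batch_size, 10)` would log throughput about 12 times per epoch.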
You will see that, after only 8 epochs, we can get 78% validation accuracy. This
matches the state-of-the-art results obtained by training on Caltech-256 alone,
e.g. [VGG](http://www.robots.ox.ac.uk/~vgg/research/deep_eval/).
Next, we try another pretrained model. This model was trained on the
complete ImageNet dataset, which is 10x larger than the ImageNet 1K-class
version, and uses a 3x deeper ResNet architecture.
```python
get_model('http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
assert mod_score > 0.86, "Low training accuracy."
```
As can be seen, even after a single epoch, it reaches 83% validation
accuracy. After 8 epochs, the validation accuracy increases to 86.4%.
<!-- INSERT SOURCE DOWNLOAD BUTTONS -->