Image Embedding Learning

This example implements embedding learning with a margin-based loss and distance-weighted sampling (Wu et al., 2017). The model obtains a validation Recall@1 of ~64% on the Caltech-UCSD Birds-200-2011 dataset.
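
As a rough illustration of the two quantities named above, the following standard-library sketch computes the margin-based loss for a single pair, max(0, margin + y * (distance - beta)) with y = +1 for same-class pairs and -1 otherwise, and a brute-force Recall@1. The function names are hypothetical and are not taken from train.py or model.py. The default margin and beta mirror the --margin and --beta defaults, and the use of plain Euclidean distance is an assumption.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def margin_loss(distance, same_class, margin=0.2, beta=1.2):
    """Margin-based loss for a single pair of embeddings (illustrative only).

    Same-class pairs are penalized when distance > beta - margin,
    different-class pairs when distance < beta + margin.
    """
    y = 1.0 if same_class else -1.0
    return max(0.0, margin + y * (distance - beta))


def recall_at_1(embeddings, labels):
    """Fraction of items whose nearest other item shares their label."""
    hits = 0
    for i, query in enumerate(embeddings):
        # Nearest neighbor among all other items (the query itself is excluded).
        nearest = min(
            (j for j in range(len(embeddings)) if j != i),
            key=lambda j: euclidean(query, embeddings[j]),
        )
        hits += labels[nearest] == labels[i]
    return hits / len(embeddings)


# Toy data: two tight clusters, one per class.
embeddings = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0]]
labels = [0, 0, 1, 1]
print(recall_at_1(embeddings, labels))
print(margin_loss(euclidean(embeddings[0], embeddings[2]), same_class=False))
```

In this toy case every item's nearest neighbor belongs to its own class, and the different-class pair is already farther apart than beta + margin, so it contributes no loss.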

Usage

Download the data

Note: the dataset is from Caltech-UCSD Birds 200. It is copyright of the Caltech Computational Vision Group and licensed under CC BY 4.0 (Attribution). See the original dataset source for details.

./get_cub200_data.sh

Example run:

python3 train.py --data-path=data/CUB_200_2011 --gpus=0,1 --use-pretrained

python train.py --help gives the following arguments:

optional arguments:
  -h, --help            show this help message and exit
  --data-path DATA_PATH
                        path of data.
  --embed-dim EMBED_DIM
                        dimensionality of image embedding. default is 128.
  --batch-size BATCH_SIZE
                        training batch size per device (CPU/GPU). default is
                        70.
  --batch-k BATCH_K     number of images per class in a batch. default is 5.
  --gpus GPUS           list of gpus to use, e.g. 0 or 0,2,5. empty means
                        using cpu.
  --epochs EPOCHS       number of training epochs. default is 20.
  --optimizer OPTIMIZER
                        optimizer. default is adam.
  --lr LR               learning rate. default is 0.0001.
  --lr-beta LR_BETA     learning rate for the beta in margin based loss.
                        default is 0.1.
  --margin MARGIN       margin for the margin based loss. default is 0.2.
  --beta BETA           initial value for beta. default is 1.2.
  --nu NU               regularization parameter for beta. default is 0.0.
  --factor FACTOR       learning rate schedule factor. default is 0.5.
  --steps STEPS         epochs to update learning rate. default is
                        12,14,16,18.
  --wd WD               weight decay rate. default is 0.0001.
  --seed SEED           random seed to use. default=123.
  --model MODEL         type of model to use. see vision_model for options.
  --save-model-prefix SAVE_MODEL_PREFIX
                        prefix of models to be saved.
  --use-pretrained      enable using pretrained model from gluon.
  --kvstore KVSTORE     kvstore to use for trainer.
  --log-interval LOG_INTERVAL
                        number of batches to wait before logging.
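
To make the interaction of some defaults above concrete, here is a small sketch of how --lr, --factor and --steps could combine into a step-decay schedule, and how --batch-size and --batch-k relate. The helper names are hypothetical and not part of train.py. It assumes the learning rate is multiplied by the factor once for each listed epoch that has been reached, and that a batch is filled with batch-k images from each sampled class.

```python
def parse_steps(steps):
    """Turn a comma-separated string such as '12,14,16,18' into a list of ints."""
    return [int(s) for s in steps.split(",") if s.strip()]


def learning_rate(epoch, base_lr=0.0001, factor=0.5, steps="12,14,16,18"):
    """Learning rate at a given epoch under a step-decay schedule.

    Assumption: the rate is multiplied by `factor` once for every step
    epoch that has been reached (epoch >= step).
    """
    reached = sum(1 for step in parse_steps(steps) if epoch >= step)
    return base_lr * factor ** reached


def classes_per_batch(batch_size=70, batch_k=5):
    """Number of distinct classes in a batch holding batch_k images per class."""
    return batch_size // batch_k


# Print the assumed learning rate at the start and at each decay epoch.
for epoch in (0, 12, 14, 16, 18):
    print(epoch, learning_rate(epoch))

# With the defaults, a batch of 70 images covers 14 classes.
print(classes_per_batch())
```

Under these assumptions the learning rate halves four times over the default 20 epochs.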

Learned embeddings

The following visualizes the learned embeddings with t-SNE.

[Figure: t-SNE visualization of the learned embeddings]

Citation

Chao-Yuan Wu, R. Manmatha, Alexander J. Smola and Philipp Krähenbühl. Sampling Matters in Deep Embedding Learning. 2017.