This example implements embedding learning based on a margin-based loss with distance-weighted sampling (Wu et al., 2017). The model obtains a validation Recall@1 of ~64% on the Caltech-UCSD Birds-200-2011 dataset.
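For a pair (i, j) with Euclidean distance D_ij between the embeddings and pair label y_ij ∈ {−1, 1} (+1 for same-class pairs, −1 otherwise), the margin-based loss is (α + y_ij(D_ij − β))₊, where the boundary β is learned (see the `--beta`, `--lr-beta`, and `--nu` options below). Below is a minimal NumPy sketch of this loss; the function and argument names are illustrative, not the repository's API:

```python
import numpy as np

def margin_loss(dist, is_positive, beta=1.2, margin=0.2):
    """Margin-based loss (Wu et al., 2017): (margin + y * (d - beta))_+,
    with y = +1 for positive pairs and y = -1 for negative pairs.

    dist        : (N,) distances between anchors and their sampled pairs
    is_positive : (N,) boolean, True where the pair shares a class label
    beta        : learnable boundary between positive and negative pairs
    margin      : separation enforced on both sides of beta
    """
    y = np.where(is_positive, 1.0, -1.0)
    per_pair = np.maximum(margin + y * (dist - beta), 0.0)
    # Average over the pairs that actually violate the margin, so the
    # gradient scale stays stable when most pairs are already separated.
    num_active = np.maximum(np.count_nonzero(per_pair > 0), 1)
    return per_pair.sum() / num_active
```

Positive pairs are pushed inside distance β − α and negative pairs outside β + α, a relaxed form of the contrastive loss.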
Download the data
Note: the dataset is Caltech-UCSD Birds 200 (CUB-200-2011). It is copyright of the Caltech Computational Vision Group and licensed under CC BY 4.0 (Attribution). See the original dataset source for details.
./get_cub200_data.sh
Example run:
python3 train.py --data-path=data/CUB_200_2011 --gpus=0,1 --use-pretrained
python train.py --help
gives the following arguments:
optional arguments:
  -h, --help            show this help message and exit
  --data-path DATA_PATH
                        path of data.
  --embed-dim EMBED_DIM
                        dimensionality of image embedding. default is 128.
  --batch-size BATCH_SIZE
                        training batch size per device (CPU/GPU). default is 70.
  --batch-k BATCH_K     number of images per class in a batch. default is 5.
  --gpus GPUS           list of gpus to use, e.g. 0 or 0,2,5. empty means using cpu.
  --epochs EPOCHS       number of training epochs. default is 20.
  --optimizer OPTIMIZER
                        optimizer. default is adam.
  --lr LR               learning rate. default is 0.0001.
  --lr-beta LR_BETA     learning rate for the beta in margin based loss. default is 0.1.
  --margin MARGIN       margin for the margin based loss. default is 0.2.
  --beta BETA           initial value for beta. default is 1.2.
  --nu NU               regularization parameter for beta. default is 0.0.
  --factor FACTOR       learning rate schedule factor. default is 0.5.
  --steps STEPS         epochs to update learning rate. default is 12,14,16,18.
  --wd WD               weight decay rate. default is 0.0001.
  --seed SEED           random seed to use. default=123.
  --model MODEL         type of model to use. see vision_model for options.
  --save-model-prefix SAVE_MODEL_PREFIX
                        prefix of models to be saved.
  --use-pretrained      enable using pretrained model from gluon.
  --kvstore KVSTORE     kvstore to use for trainer.
  --log-interval LOG_INTERVAL
                        number of batches to wait before logging.
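Distance-weighted sampling, the paper's second ingredient, is why each batch is built with `--batch-k` images per class: negatives are drawn from within the batch with probability proportional to the inverse of the distance density q(d) ∝ d^(n−2)(1 − d²/4)^((n−3)/2) of points on the unit sphere, with distances clipped from below to bound the variance. A NumPy sketch under the assumptions of L2-normalized embeddings and batches containing more than one class (names are illustrative, not the repository's API):

```python
import numpy as np

def distance_weighted_negatives(embeddings, labels, cutoff=0.5, rng=None):
    """Sample one negative index per anchor with probability ~ 1/q(d),
    where q(d) ~ d^(n-2) * (1 - d^2/4)^((n-3)/2) is the distribution of
    pairwise distances on the unit (n-1)-sphere.

    embeddings : (N, n) L2-normalized embeddings of one batch
    labels     : (N,) integer class labels (numpy array)
    cutoff     : distances are clipped from below to control variance
    """
    rng = rng or np.random.default_rng()
    n = embeddings.shape[1]
    # Pairwise Euclidean distances for unit vectors: d^2 = 2 - 2 cos.
    d = np.sqrt(np.maximum(2.0 - 2.0 * embeddings @ embeddings.T, 1e-12))
    d = np.maximum(d, cutoff)
    # log(1/q(d)) = -(n-2) log d - (n-3)/2 * log(1 - d^2/4)
    log_w = -(n - 2) * np.log(d) \
            - 0.5 * (n - 3) * np.log(np.maximum(1.0 - 0.25 * d * d, 1e-12))
    # Exclude self and same-class pairs from the negative pool
    # (assumes the batch contains more than one class).
    same = labels[:, None] == labels[None, :]
    log_w[same] = -np.inf
    w = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return np.array([rng.choice(len(labels), p=w[i])
                     for i in range(len(labels))])
```

Uniform sampling concentrates on uninformative distances near √2; weighting by 1/q(d) spreads the sampled negatives evenly across the distance range.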
The learned embeddings can be visualized with t-SNE, for example as sketched below.
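A sketch of how such a plot can be produced with scikit-learn and matplotlib; the `embeddings` and `labels` arrays stand in for validation-set embeddings from the trained model, and random data is used here only to keep the snippet self-contained:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Placeholders: in practice, compute these on the validation set with the
# trained model. embeddings is (N, embed_dim), labels is (N,) class ids.
embeddings = np.random.randn(500, 128)
labels = np.random.randint(0, 20, size=500)

# Project the embeddings to 2-D and color each point by its class.
points = TSNE(n_components=2, init="pca", random_state=123).fit_transform(embeddings)
plt.figure(figsize=(8, 8))
plt.scatter(points[:, 0], points[:, 1], c=labels, cmap="tab20", s=8)
plt.title("t-SNE of learned embeddings")
plt.savefig("tsne_embeddings.png", dpi=150)
```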
References
Chao-Yuan Wu, R. Manmatha, Alexander J. Smola, and Philipp Krähenbühl. Sampling Matters in Deep Embedding Learning. ICCV 2017. [paper] [project]