blob: d9059d6b065a03e7efc297b3b11d25334528927f [file] [log] [blame]
import argparse
import logging
import os
import mxnet as mx
from rcnn.callback import Speedometer
from rcnn.config import config
from rcnn.loader import ROIIter
from rcnn.metric import AccuracyMetric, LogLossMetric, SmoothL1LossMetric
from rcnn.module import MutableModule
from rcnn.symbol import get_vgg_rcnn
from utils.load_data import load_ss_roidb, load_rpn_roidb
from utils.load_model import load_checkpoint, load_param
from utils.save_model import save_checkpoint
def train_rcnn(image_set, year, root_path, devkit_path, pretrained, epoch,
               prefix, ctx, begin_epoch, end_epoch, frequent, kv_store,
               work_load_list=None, resume=False, proposal='rpn'):
    """Train the VGG R-CNN symbol on region proposals and rewrite its checkpoints.

    After ``fit`` finishes, every checkpoint written for epochs
    ``begin_epoch + 1 .. end_epoch`` is reloaded, its ``bbox_pred`` weight and
    bias are rescaled with the ``stds`` / ``means`` returned by the roidb
    loader, and the checkpoint is saved again under the same prefix and epoch.

    NOTE: this function multiplies ``config.TRAIN.BATCH_IMAGES`` and
    ``config.TRAIN.BATCH_SIZE`` by ``len(ctx)`` in place, so calling it more
    than once in the same process compounds the scaling.

    Args:
        image_set, year, root_path, devkit_path: forwarded unchanged to the
            selected roidb loader.
        pretrained: model prefix passed to ``load_param``.
        epoch: epoch of the pretrained model passed to ``load_param``.
        prefix: prefix used for the checkpoints written during and after
            training.
        ctx: list of device contexts; its length scales the batch settings.
        begin_epoch: forwarded to ``fit`` as ``begin_epoch``.
        end_epoch: forwarded to ``fit`` as ``num_epoch``.
        frequent: forwarded to ``Speedometer`` as ``frequent``.
        kv_store: forwarded to ``fit`` as ``kvstore``.
        work_load_list: forwarded to ``ROIIter`` and ``MutableModule``.
        resume: when true, the ``cls_score`` / ``bbox_pred`` parameters loaded
            from ``pretrained`` are kept instead of being re-initialized.
        proposal: ``'ss'`` (``load_ss_roidb``) or ``'rpn'``
            (``load_rpn_roidb``).

    Raises:
        ValueError: if ``proposal`` is not one of the supported names.
    """
    # Resolve the roidb loader through an explicit table rather than eval():
    # ``proposal`` arrives from the command line, and eval() would execute
    # whatever expression it happens to contain.
    roidb_loaders = {'ss': load_ss_roidb, 'rpn': load_rpn_roidb}
    if proposal not in roidb_loaders:
        raise ValueError('unknown proposal %r, expected one of %s'
                         % (proposal, sorted(roidb_loaders)))
    load_roidb = roidb_loaders[proposal]

    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # load symbol
    sym = get_vgg_rcnn()

    # setup multi-gpu
    config.TRAIN.BATCH_IMAGES *= len(ctx)
    config.TRAIN.BATCH_SIZE *= len(ctx)

    # load training data
    _, roidb, means, stds = load_roidb(image_set, year, root_path, devkit_path, flip=True)
    train_data = ROIIter(roidb, batch_size=config.TRAIN.BATCH_IMAGES, shuffle=True, mode='train',
                         ctx=ctx, work_load_list=work_load_list)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, 1000, 1000))]

    # load pretrained
    args, auxs, _ = load_param(pretrained, epoch, convert=True)

    # initialize params
    if not resume:
        # The output layers get fresh values sized from the symbol's inferred
        # argument shapes; everything else keeps the pretrained values.
        input_shapes = {k: v for k, v in train_data.provide_data + train_data.provide_label}
        arg_shape, _, _ = sym.infer_shape(**input_shapes)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        args['cls_score_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['cls_score_weight'])
        args['cls_score_bias'] = mx.nd.zeros(shape=arg_shape_dict['cls_score_bias'])
        args['bbox_pred_weight'] = mx.random.normal(0, 0.001, shape=arg_shape_dict['bbox_pred_weight'])
        args['bbox_pred_bias'] = mx.nd.zeros(shape=arg_shape_dict['bbox_pred_bias'])

    # prepare training
    # Prefixes handed to MutableModule as fixed_param_prefix (presumably the
    # parameters kept frozen during training -- confirm in rcnn.module).
    if config.TRAIN.FINETUNE:
        fixed_param_prefix = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    else:
        fixed_param_prefix = ['conv1', 'conv2']
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    batch_end_callback = Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    if config.TRAIN.HAS_RPN is True:
        eval_metric = AccuracyMetric(use_ignore=True, ignore=-1)
        cls_metric = LogLossMetric(use_ignore=True, ignore=-1)
    else:
        eval_metric = AccuracyMetric()
        cls_metric = LogLossMetric()
    bbox_metric = SmoothL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.0005,
                        'learning_rate': 0.001,
                        'lr_scheduler': mx.lr_scheduler.FactorScheduler(30000, 0.1),
                        'rescale_grad': (1.0 / config.TRAIN.BATCH_SIZE)}

    # train
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=work_load_list,
                        max_data_shapes=max_data_shape, fixed_param_prefix=fixed_param_prefix)
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=kv_store,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=end_epoch)

    # edit params and save
    # Scale bbox_pred by stds and shift its bias by means in every saved
    # checkpoint (presumably undoing the target normalization applied by the
    # roidb loader -- confirm in utils.load_data).
    stds_nd = mx.nd.array(stds)
    means_nd = mx.nd.array(means)
    # A distinct loop variable keeps the ``epoch`` parameter (the pretrained
    # model's epoch) from being shadowed.
    for saved_epoch in range(begin_epoch + 1, end_epoch + 1):
        arg_params, aux_params = load_checkpoint(prefix, saved_epoch)
        arg_params['bbox_pred_weight'] = (arg_params['bbox_pred_weight'].T * stds_nd).T
        arg_params['bbox_pred_bias'] = arg_params['bbox_pred_bias'] * stds_nd + means_nd
        save_checkpoint(prefix, saved_epoch, arg_params, aux_params)
def _parse_work_load_list(value):
    """Convert a comma-separated string such as ``'1,2'`` into a list of floats."""
    try:
        return [float(item) for item in value.split(',')]
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected comma-separated numbers (e.g. 1,2), got %r' % value)


def parse_args(argv=None):
    """Parse the command-line options of the Fast R-CNN training script.

    Args:
        argv: optional list of argument strings; when ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
        ``work_load_list`` is ``None`` when not given, otherwise a list of
        floats parsed from a comma-separated value.
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN Network')
    parser.add_argument('--image_set', dest='image_set', help='can be trainval or train',
                        default='trainval', type=str)
    parser.add_argument('--year', dest='year', help='can be 2007, 2010, 2012',
                        default='2007', type=str)
    parser.add_argument('--root_path', dest='root_path', help='output data folder',
                        default=os.path.join(os.getcwd(), 'data'), type=str)
    parser.add_argument('--devkit_path', dest='devkit_path', help='VOCdevkit path',
                        default=os.path.join(os.getcwd(), 'data', 'VOCdevkit'), type=str)
    parser.add_argument('--pretrained', dest='pretrained', help='pretrained model prefix',
                        default=os.path.join(os.getcwd(), 'model', 'vgg16'), type=str)
    parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
                        default=1, type=int)
    parser.add_argument('--prefix', dest='prefix', help='new model prefix',
                        default=os.path.join(os.getcwd(), 'model', 'rcnn'), type=str)
    parser.add_argument('--gpus', dest='gpu_ids', help='GPU device to train with',
                        default='0', type=str)
    parser.add_argument('--begin_epoch', dest='begin_epoch', help='begin epoch of training',
                        default=0, type=int)
    parser.add_argument('--end_epoch', dest='end_epoch', help='end epoch of training',
                        default=8, type=int)
    parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
                        default=20, type=int)
    parser.add_argument('--kv_store', dest='kv_store', help='the kv-store type',
                        default='device', type=str)
    # type=list would split the raw string into single characters
    # ('1,2' -> ['1', ',', '2']); parse it as comma-separated numbers instead.
    parser.add_argument('--work_load_list', dest='work_load_list', help='work load for different devices',
                        default=None, type=_parse_work_load_list)
    parser.add_argument('--finetune', dest='finetune', help='second round finetune', action='store_true')
    parser.add_argument('--resume', dest='resume', help='continue training', action='store_true')
    parser.add_argument('--proposal', dest='proposal', help='can be ss for selective search or rpn',
                        default='rpn', type=str)
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Script entry point: read the CLI options, build one GPU context per
    # comma-separated id, apply the finetune switch, then start training.
    cli = parse_args()

    devices = []
    for gpu_id in cli.gpu_ids.split(','):
        devices.append(mx.gpu(int(gpu_id)))

    if cli.finetune:
        config.TRAIN.FINETUNE = True

    train_rcnn(
        image_set=cli.image_set,
        year=cli.year,
        root_path=cli.root_path,
        devkit_path=cli.devkit_path,
        pretrained=cli.pretrained,
        epoch=cli.epoch,
        prefix=cli.prefix,
        ctx=devices,
        begin_epoch=cli.begin_epoch,
        end_epoch=cli.end_epoch,
        frequent=cli.frequent,
        kv_store=cli.kv_store,
        work_load_list=cli.work_load_list,
        resume=cli.resume,
        proposal=cli.proposal,
    )