| import argparse |
| import pprint |
| import mxnet as mx |
| import numpy as np |
| |
| from rcnn.logger import logger |
| from rcnn.config import config, default, generate_config |
| from rcnn.symbol import * |
| from rcnn.core import callback, metric |
| from rcnn.core.loader import AnchorLoader |
| from rcnn.core.module import MutableModule |
| from rcnn.utils.load_data import load_gt_roidb, merge_roidb, filter_roidb |
| from rcnn.utils.load_model import load_param |
| |
| |
| def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, |
| lr=0.001, lr_step='5'): |
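    """Train a Faster R-CNN network end to end (approximate joint training).

    ctx is the list of devices to train on; pretrained and epoch identify the
    backbone checkpoint to start from; prefix is the path prefix under which
    new checkpoints are saved; lr_step is a comma-separated list of epochs at
    which the learning rate is multiplied by 0.1.
    """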
| # setup config |
| config.TRAIN.BATCH_IMAGES = 1 |
| config.TRAIN.BATCH_ROIS = 128 |
| config.TRAIN.END2END = True |
| config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True |
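    # END2END trains the RPN and RCNN stages jointly; with
    # BBOX_NORMALIZATION_PRECOMPUTED the bbox regression targets are normalized
    # by fixed means/stds instead of statistics computed from the data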
| |
    # load symbol: get_<network>_train is resolved from rcnn.symbol (imported via *)
    sym = eval('get_' + args.network + '_train')(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    # the RPN score map determines the feature-map size that anchors are laid over
    feat_sym = sym.get_internals()['rpn_cls_score_output']
| |
    # setup multi-gpu: each device processes BATCH_IMAGES images per iteration
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size
| |
| # print config |
| logger.info(pprint.pformat(config)) |
| |
| # load dataset and prepare imdb for training |
    image_sets = args.image_set.split('+')
| roidbs = [load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path, |
| flip=not args.no_flip) |
| for image_set in image_sets] |
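    # merge the per-set roidbs into a single list and drop images that provide
    # no usable training examples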
| roidb = merge_roidb(roidbs) |
| roidb = filter_roidb(roidb) |
| |
| # load training data |
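    # AnchorLoader yields, per batch, the image tensor plus RPN labels (anchor
    # classification and bbox-regression targets) computed on the fly from the
    # feature-map shape given by feat_sym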
| train_data = AnchorLoader(feat_sym, roidb, batch_size=input_batch_size, shuffle=not args.no_shuffle, |
| ctx=ctx, work_load_list=args.work_load_list, |
| feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES, |
| anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING) |
| |
| # infer max shape |
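    # compute the largest shapes any batch can take so the module can bound
    # its memory allocation up front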
    max_data_shape = [('data', (input_batch_size, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
| max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape) |
| max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5))) |
| logger.info('providing maximum shape %s %s' % (max_data_shape, max_label_shape)) |
| |
| # infer shape |
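    # infer argument and auxiliary shapes from the data/label shapes so the
    # newly added layers can be initialized with correctly sized arrays below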
| data_shape_dict = dict(train_data.provide_data + train_data.provide_label) |
| arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict) |
| arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape)) |
| out_shape_dict = dict(zip(sym.list_outputs(), out_shape)) |
| aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape)) |
| logger.info('output shape %s' % pprint.pformat(out_shape_dict)) |
| |
| # load and initialize params |
| if args.resume: |
| arg_params, aux_params = load_param(prefix, begin_epoch, convert=True) |
| else: |
| arg_params, aux_params = load_param(pretrained, epoch, convert=True) |
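        # randomly initialize the detection heads that the pretrained backbone
        # lacks: weights ~ N(0, 0.01), except bbox_pred which uses std 0.001 as
        # in Fast R-CNN; all biases start at zero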
| arg_params['rpn_conv_3x3_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_conv_3x3_weight']) |
| arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_conv_3x3_bias']) |
| arg_params['rpn_cls_score_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_cls_score_weight']) |
| arg_params['rpn_cls_score_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_cls_score_bias']) |
| arg_params['rpn_bbox_pred_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_bbox_pred_weight']) |
| arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_bbox_pred_bias']) |
| arg_params['cls_score_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['cls_score_weight']) |
| arg_params['cls_score_bias'] = mx.nd.zeros(shape=arg_shape_dict['cls_score_bias']) |
| arg_params['bbox_pred_weight'] = mx.random.normal(0, 0.001, shape=arg_shape_dict['bbox_pred_weight']) |
| arg_params['bbox_pred_bias'] = mx.nd.zeros(shape=arg_shape_dict['bbox_pred_bias']) |
| |
| # check parameter shapes |
| for k in sym.list_arguments(): |
| if k in data_shape_dict: |
| continue |
| assert k in arg_params, k + ' not initialized' |
| assert arg_params[k].shape == arg_shape_dict[k], \ |
| 'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape) |
| for k in sym.list_auxiliary_states(): |
| assert k in aux_params, k + ' not initialized' |
| assert aux_params[k].shape == aux_shape_dict[k], \ |
| 'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape) |
| |
| # create solver |
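    # MutableModule rebinds its executors when input shapes change between
    # batches, which is needed because image sizes (and hence label shapes) vary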
| fixed_param_prefix = config.FIXED_PARAMS |
| data_names = [k[0] for k in train_data.provide_data] |
| label_names = [k[0] for k in train_data.provide_label] |
| mod = MutableModule(sym, data_names=data_names, label_names=label_names, |
| logger=logger, context=ctx, work_load_list=args.work_load_list, |
| max_data_shapes=max_data_shape, max_label_shapes=max_label_shape, |
| fixed_param_prefix=fixed_param_prefix) |
| |
| # decide training params |
| # metric |
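    # track accuracy, classification log loss and bbox-regression loss
    # separately for the RPN and RCNN stages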
| rpn_eval_metric = metric.RPNAccMetric() |
| rpn_cls_metric = metric.RPNLogLossMetric() |
| rpn_bbox_metric = metric.RPNL1LossMetric() |
| eval_metric = metric.RCNNAccMetric() |
| cls_metric = metric.RCNNLogLossMetric() |
| bbox_metric = metric.RCNNL1LossMetric() |
| eval_metrics = mx.metric.CompositeEvalMetric() |
| for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]: |
| eval_metrics.add(child_metric) |
| # callback |
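    # do_checkpoint folds the bbox normalization means/stds back into the
    # bbox_pred weights so saved checkpoints predict unnormalized offsets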
| batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent) |
| means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES) |
| stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES) |
| epoch_end_callback = callback.do_checkpoint(prefix, means, stds) |
| # decide learning rate |
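    # discount decay steps already passed when resuming, then convert the
    # remaining epoch boundaries into iteration counts for the scheduler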
| base_lr = lr |
| lr_factor = 0.1 |
    lr_epoch = [int(ep) for ep in lr_step.split(',')]
    lr_epoch_diff = [ep - begin_epoch for ep in lr_epoch if ep > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(ep * len(roidb) / input_batch_size) for ep in lr_epoch_diff]
| logger.info('lr %f lr_epoch_diff %s lr_iters %s' % (lr, lr_epoch_diff, lr_iters)) |
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor) if lr_iters else None
| # optimizer |
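    # rescale_grad averages gradients over the per-iteration batch; gradient
    # clipping guards against instability early in training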
| optimizer_params = {'momentum': 0.9, |
| 'wd': 0.0005, |
| 'learning_rate': lr, |
| 'lr_scheduler': lr_scheduler, |
| 'rescale_grad': (1.0 / batch_size), |
| 'clip_gradient': 5} |
| |
| # train |
| mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback, |
| batch_end_callback=batch_end_callback, kvstore=args.kvstore, |
| optimizer='sgd', optimizer_params=optimizer_params, |
| arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch) |
| |
| |
| def parse_args(): |
| parser = argparse.ArgumentParser(description='Train Faster R-CNN network') |
| # general |
| parser.add_argument('--network', help='network name', default=default.network, type=str) |
| parser.add_argument('--dataset', help='dataset name', default=default.dataset, type=str) |
| args, rest = parser.parse_known_args() |
| generate_config(args.network, args.dataset) |
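    # network and dataset are parsed first so generate_config can set the
    # remaining network- and dataset-dependent defaults before the other
    # arguments are declared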
| parser.add_argument('--image_set', help='image_set name', default=default.image_set, type=str) |
| parser.add_argument('--root_path', help='output data folder', default=default.root_path, type=str) |
| parser.add_argument('--dataset_path', help='dataset path', default=default.dataset_path, type=str) |
| # training |
| parser.add_argument('--frequent', help='frequency of logging', default=default.frequent, type=int) |
| parser.add_argument('--kvstore', help='the kv-store type', default=default.kvstore, type=str) |
    parser.add_argument('--work_load_list', help='comma separated work load for different devices',
                        default=None, type=lambda s: [float(x) for x in s.split(',')])
| parser.add_argument('--no_flip', help='disable flip images', action='store_true') |
| parser.add_argument('--no_shuffle', help='disable random shuffle', action='store_true') |
| parser.add_argument('--resume', help='continue training', action='store_true') |
| # e2e |
    parser.add_argument('--gpus', help='GPU devices to train with, e.g. 0 or 0,1', default='0', type=str)
| parser.add_argument('--pretrained', help='pretrained model prefix', default=default.pretrained, type=str) |
| parser.add_argument('--pretrained_epoch', help='pretrained model epoch', default=default.pretrained_epoch, type=int) |
| parser.add_argument('--prefix', help='new model prefix', default=default.e2e_prefix, type=str) |
| parser.add_argument('--begin_epoch', help='begin epoch of training, use with resume', default=0, type=int) |
| parser.add_argument('--end_epoch', help='end epoch of training', default=default.e2e_epoch, type=int) |
| parser.add_argument('--lr', help='base learning rate', default=default.e2e_lr, type=float) |
| parser.add_argument('--lr_step', help='learning rate steps (in epoch)', default=default.e2e_lr_step, type=str) |
| args = parser.parse_args() |
| return args |
| |
| |
| def main(): |
| args = parse_args() |
    logger.info('Called with arguments: %s' % args)
| ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] |
| train_net(args, ctx, args.pretrained, args.pretrained_epoch, args.prefix, args.begin_epoch, args.end_epoch, |
| lr=args.lr, lr_step=args.lr_step) |
| |
| if __name__ == '__main__': |
| main() |