# blob: 24c658aeca4db876da82e42af4a34f0addc4b15f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import pprint
import mxnet as mx
import numpy as np
from rcnn.logger import logger
from rcnn.config import config, default, generate_config
from rcnn.symbol import *
from rcnn.core import callback, metric
from rcnn.core.loader import AnchorLoader
from rcnn.core.module import MutableModule
from rcnn.utils.load_data import load_gt_roidb, merge_roidb, filter_roidb
from rcnn.utils.load_model import load_param
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
              lr=0.001, lr_step='5'):
    """Train a Faster R-CNN network end to end.

    Args:
        args: parsed command-line namespace. This function reads ``network``,
            ``image_set``, ``dataset``, ``root_path``, ``dataset_path``,
            ``no_flip``, ``no_shuffle``, ``work_load_list``, ``resume``,
            ``frequent`` and ``kvstore`` from it.
        ctx: list of MXNet contexts; its length is used as the device count.
        pretrained: checkpoint prefix loaded when ``args.resume`` is false.
        epoch: epoch of the ``pretrained`` checkpoint to load.
        prefix: checkpoint prefix, used both for resuming and for saving.
        begin_epoch: epoch training starts from (also the epoch loaded from
            ``prefix`` when resuming).
        end_epoch: forwarded to ``fit`` as ``num_epoch``.
        lr: base learning rate before any step decay is applied.
        lr_step: comma-separated epochs at which the learning rate is
            multiplied by 0.1.
    """
    # Settings this script always forces, whatever the generated config says.
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # Build the training symbol.
    # NOTE(review): the factory is looked up with eval() on a name assembled
    # from --network, so arbitrary expressions on the command line would be
    # evaluated. Acceptable for a local training script; do not expose this
    # to untrusted input.
    symbol_factory = eval('get_' + args.network + '_train')
    sym = symbol_factory(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # One slice of the batch per device.
    num_devices = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * num_devices

    # Dump the effective configuration.
    logger.info(pprint.pformat(config))

    # Load every '+'-separated image set and merge them into one roidb.
    roidbs = []
    for image_set in args.image_set.split('+'):
        roidbs.append(load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path,
                                    flip=not args.no_flip))
    roidb = filter_roidb(merge_roidb(roidbs))

    # Data iterator.
    train_data = AnchorLoader(feat_sym, roidb, batch_size=input_batch_size, shuffle=not args.no_shuffle,
                              ctx=ctx, work_load_list=args.work_load_list,
                              feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
                              anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # Largest shapes the module has to be able to hold.
    largest_first = max(scale[0] for scale in config.SCALES)
    largest_second = max(scale[1] for scale in config.SCALES)
    max_data_shape = [('data', (input_batch_size, 3, largest_first, largest_second))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5)))
    logger.info('providing maximum shape %s %s' % (max_data_shape, max_label_shape))

    # Infer argument / output / auxiliary shapes from the loader's shapes.
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    arg_shapes, out_shapes, aux_shapes = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shapes))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shapes))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shapes))
    logger.info('output shape %s' % pprint.pformat(out_shape_dict))

    # Load parameters; when starting from the pretrained model, the layers
    # listed below are (re)initialized: Gaussian weights, zero biases.
    if args.resume:
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        new_layers = (
            ('rpn_conv_3x3_weight', 'rpn_conv_3x3_bias', 0.01),
            ('rpn_cls_score_weight', 'rpn_cls_score_bias', 0.01),
            ('rpn_bbox_pred_weight', 'rpn_bbox_pred_bias', 0.01),
            ('cls_score_weight', 'cls_score_bias', 0.01),
            ('bbox_pred_weight', 'bbox_pred_bias', 0.001),
        )
        for weight_key, bias_key, sigma in new_layers:
            arg_params[weight_key] = mx.random.normal(0, sigma, shape=arg_shape_dict[weight_key])
            arg_params[bias_key] = mx.nd.zeros(shape=arg_shape_dict[bias_key])

    # Every non-input argument and every auxiliary state must be present
    # with the inferred shape.
    for name in sym.list_arguments():
        if name not in data_shape_dict:
            assert name in arg_params, name + ' not initialized'
            assert arg_params[name].shape == arg_shape_dict[name], \
                'shape inconsistent for ' + name + ' inferred ' + str(arg_shape_dict[name]) + ' provided ' + str(arg_params[name].shape)
    for name in sym.list_auxiliary_states():
        assert name in aux_params, name + ' not initialized'
        assert aux_params[name].shape == aux_shape_dict[name], \
            'shape inconsistent for ' + name + ' inferred ' + str(aux_shape_dict[name]) + ' provided ' + str(aux_params[name].shape)

    # Module (solver).
    data_names = [desc[0] for desc in train_data.provide_data]
    label_names = [desc[0] for desc in train_data.provide_label]
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=config.FIXED_PARAMS)

    # Metrics: three for the RPN, then three for the R-CNN head.
    child_metrics = [
        metric.RPNAccMetric(),
        metric.RPNLogLossMetric(),
        metric.RPNL1LossMetric(),
        metric.RCNNAccMetric(),
        metric.RCNNLogLossMetric(),
        metric.RCNNL1LossMetric(),
    ]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in child_metrics:
        eval_metrics.add(child_metric)

    # Callbacks.
    batch_end_callback = mx.callback.Speedometer(train_data.batch_size, frequent=args.frequent, auto_reset=False)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)

    # Learning-rate schedule: decay steps at or before begin_epoch are applied
    # to the starting rate up front; the remaining ones become iteration
    # counts for the scheduler.
    decay_factor = 0.1
    step_epochs = [int(step) for step in lr_step.split(',')]
    remaining_steps = [step - begin_epoch for step in step_epochs if step > begin_epoch]
    steps_passed = len(step_epochs) - len(remaining_steps)
    current_lr = lr * (decay_factor ** steps_passed)
    step_iters = [int(step * len(roidb) / num_devices) for step in remaining_steps]
    logger.info('lr %f lr_epoch_diff %s lr_iters %s' % (current_lr, remaining_steps, step_iters))
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step_iters, decay_factor)

    # Optimizer settings.
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': current_lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': (1.0 / num_devices),
        'clip_gradient': 5,
    }

    # Run training.
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=args.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
def parse_args():
    """Parse the command-line options of the end-to-end training script.

    Parsing happens in two passes: ``--network`` and ``--dataset`` are read
    first and handed to ``generate_config`` before the remaining options are
    declared, so the ``default.*`` values used as defaults below are read
    only after ``generate_config`` has run (presumably it adjusts ``default``
    for the chosen network/dataset -- confirm in ``rcnn.config``).

    Returns:
        argparse.Namespace: the fully parsed arguments.
    """
    def _work_load_list(text):
        """Turn a comma-separated string such as ``1,1,2`` into a list of floats."""
        try:
            return [float(item) for item in text.split(',')]
        except ValueError:
            raise argparse.ArgumentTypeError(
                'expected comma-separated numbers, got %r' % text)

    parser = argparse.ArgumentParser(description='Train Faster R-CNN network')
    # general
    parser.add_argument('--network', help='network name', default=default.network, type=str)
    parser.add_argument('--dataset', help='dataset name', default=default.dataset, type=str)
    # First pass: only network/dataset are known here; everything else on the
    # command line is ignored until the second pass.
    args, _ = parser.parse_known_args()
    generate_config(args.network, args.dataset)
    parser.add_argument('--image_set', help='image_set name', default=default.image_set, type=str)
    parser.add_argument('--root_path', help='output data folder', default=default.root_path, type=str)
    parser.add_argument('--dataset_path', help='dataset path', default=default.dataset_path, type=str)
    # training
    parser.add_argument('--frequent', help='frequency of logging', default=default.frequent, type=int)
    parser.add_argument('--kvstore', help='the kv-store type', default=default.kvstore, type=str)
    # ``type=list`` would split the raw string into single characters
    # ("1,2" -> ['1', ',', '2']); parse it as a list of numbers instead.
    parser.add_argument('--work_load_list', help='work load for different devices', default=None,
                        type=_work_load_list)
    parser.add_argument('--no_flip', help='disable flip images', action='store_true')
    parser.add_argument('--no_shuffle', help='disable random shuffle', action='store_true')
    parser.add_argument('--resume', help='continue training', action='store_true')
    # e2e
    parser.add_argument('--gpus', help='GPU device to train with', default='0', type=str)
    parser.add_argument('--pretrained', help='pretrained model prefix', default=default.pretrained, type=str)
    parser.add_argument('--pretrained_epoch', help='pretrained model epoch', default=default.pretrained_epoch, type=int)
    parser.add_argument('--prefix', help='new model prefix', default=default.e2e_prefix, type=str)
    parser.add_argument('--begin_epoch', help='begin epoch of training, use with resume', default=0, type=int)
    parser.add_argument('--end_epoch', help='end epoch of training', default=default.e2e_epoch, type=int)
    parser.add_argument('--lr', help='base learning rate', default=default.e2e_lr, type=float)
    parser.add_argument('--lr_step', help='learning rate steps (in epoch)', default=default.e2e_lr_step, type=str)
    # Second pass: full parse with every option declared.
    return parser.parse_args()
def main():
    """Parse the command line, build one GPU context per id and train."""
    args = parse_args()
    logger.info('Called with argument: %s' % args)
    # --gpus is a comma-separated list of device ids, e.g. "0,1".
    gpu_ids = [int(token) for token in args.gpus.split(',')]
    ctx = [mx.gpu(gpu_id) for gpu_id in gpu_ids]
    train_net(args, ctx,
              pretrained=args.pretrained,
              epoch=args.pretrained_epoch,
              prefix=args.prefix,
              begin_epoch=args.begin_epoch,
              end_epoch=args.end_epoch,
              lr=args.lr,
              lr_step=args.lr_step)
# Start training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()