import tools.find_mxnet  # ensure mxnet is importable (falls back to the local source tree)
import mxnet as mx
import logging
import sys
import os
import importlib
from initializer import ScaleInitializer
from metric import MultiBoxMetric
from dataset.iterator import DetIter
from dataset.pascal_voc import PascalVoc
from dataset.concat_db import ConcatDB
from config.config import cfg


def load_pascal(image_set, year, devkit_path, shuffle=False):
"""
wrapper function for loading pascal voc dataset
Parameters:
----------
image_set : str
train, trainval...
year : str
2007, 2012 or combinations splitted by comma
devkit_path : str
root directory of dataset
shuffle : bool
whether to shuffle initial list
Returns:
----------
Imdb
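    Examples:
    ----------
    >>> # hypothetical devkit path; combines VOC2007 and VOC2012 trainval sets
    >>> imdb = load_pascal('trainval,trainval', '2007,2012', '/data/VOCdevkit')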
"""
    image_set = [s.strip() for s in image_set.split(',') if s.strip()]
    assert image_set, "No image_set specified"
    year = [y.strip() for y in year.split(',') if y.strip()]
    assert year, "No year specified"
# make sure (# sets == # years)
if len(image_set) > 1 and len(year) == 1:
year = year * len(image_set)
if len(image_set) == 1 and len(year) > 1:
image_set = image_set * len(year)
assert len(image_set) == len(year), "Number of sets and year mismatch"
imdbs = []
for s, y in zip(image_set, year):
imdbs.append(PascalVoc(s, y, devkit_path, shuffle, is_train=True))
if len(imdbs) > 1:
return ConcatDB(imdbs, shuffle)
else:
return imdbs[0]


def convert_pretrained(name, args):
"""
    Special operations need to be made due to name inconsistencies, etc.
    Parameters:
    ---------
    name : str
        name of the pretrained model, e.g. 'vgg16_reduced'
    args : dict
        loaded arguments
    Returns:
    ---------
    processed arguments as dict
"""
if name == 'vgg16_reduced':
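        # in the reduced VGG16 model, the original fc6/fc7 layers were
        # re-implemented as convolutions; rename them to match the SSD symbol,
        # and drop the fc8 classifier weights, which SSD does not use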
args['conv6_bias'] = args.pop('fc6_bias')
args['conv6_weight'] = args.pop('fc6_weight')
args['conv7_bias'] = args.pop('fc7_bias')
args['conv7_weight'] = args.pop('fc7_weight')
del args['fc8_weight']
del args['fc8_bias']
return args


def train_net(net, dataset, image_set, year, devkit_path, batch_size,
data_shape, mean_pixels, resume, finetune, pretrained, epoch, prefix,
ctx, begin_epoch, end_epoch, frequent, learning_rate,
momentum, weight_decay, val_set, val_year,
lr_refactor_epoch, lr_refactor_ratio,
iter_monitor=0, log_file=None):
"""
Wrapper for training module
Parameters:
---------
    net : str
        name of the training network; resolved via symbol/symbol_<net>.py
dataset : str
pascal, imagenet...
image_set : str
train, trainval...
year : str
        2007, 2012, or combinations separated by commas
devkit_path : str
root directory of dataset
batch_size : int
training batch size
data_shape : int or (int, int)
        input image shape after resizing; a single int implies a square shape
mean_pixels : tuple (float, float, float)
mean pixel values in (R, G, B)
resume : int
if > 0, will load trained epoch with name given by prefix
finetune : int
        if > 0, will load trained epoch with name given by prefix; in this mode
        all convolutional layers except the prediction layers are fixed
pretrained : str
prefix of pretrained model name
epoch : int
epoch of pretrained model
prefix : str
prefix of new model
    ctx : mx.context.Context or list of mx.context.Context
        training context(s), e.g. [mx.gpu(0), mx.gpu(1)]
begin_epoch : int
begin epoch, default should be 0
end_epoch : int
when to stop training
frequent : int
        logging frequency (in batches) of the batch_end_callback
    learning_rate : float
        initial training learning rate
momentum : float
(0, 1), training momentum
weight_decay : float
decay weights regardless of gradient
val_set : str
similar to image_set, used for validation
val_year : str
similar to year, used for validation
    lr_refactor_epoch : int
        number of epochs between learning rate reductions
lr_refactor_ratio : float
new_lr = old_lr * lr_refactor_ratio
iter_monitor : int
if larger than 0, will print weights/gradients every iter_monitor iters
log_file : str
log to file if not None
Returns:
---------
None
"""
# set up logger
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if log_file:
fh = logging.FileHandler(log_file)
logger.addHandler(fh)
    # kvstore: 'device' aggregates gradients on the GPUs when using multiple contexts
kv = mx.kvstore.create("device")
# check args
if isinstance(data_shape, int):
data_shape = (data_shape, data_shape)
assert len(data_shape) == 2, "data_shape must be (h, w) tuple or list or int"
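    # append the input resolution to the checkpoint prefix, e.g. <prefix>_300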
prefix += '_' + str(data_shape[0])
if isinstance(mean_pixels, (int, float)):
mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
assert len(mean_pixels) == 3, "must provide all RGB mean values"
# load dataset
if dataset == 'pascal':
imdb = load_pascal(image_set, year, devkit_path, cfg.TRAIN.INIT_SHUFFLE)
if val_set and val_year:
val_imdb = load_pascal(val_set, val_year, devkit_path, False)
else:
val_imdb = None
else:
raise NotImplementedError("Dataset " + dataset + " not supported")
# init data iterator
train_iter = DetIter(imdb, batch_size, data_shape, mean_pixels,
cfg.TRAIN.RAND_SAMPLERS, cfg.TRAIN.RAND_MIRROR,
cfg.TRAIN.EPOCH_SHUFFLE, cfg.TRAIN.RAND_SEED,
is_train=True)
    # stretch one logical epoch over RESIZE_EPOCH passes of the data, so that
    # per-epoch checkpointing happens less often
resize_epoch = int(cfg.TRAIN.RESIZE_EPOCH)
if resize_epoch > 1:
batches_per_epoch = ((imdb.num_images - 1) // batch_size + 1) * resize_epoch
train_iter = mx.io.ResizeIter(train_iter, batches_per_epoch)
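    # prefetch batches in a background thread to overlap I/O with computation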
train_iter = mx.io.PrefetchingIter(train_iter)
if val_imdb:
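        # is_train=True so the iterator keeps providing labels, which the
        # evaluation metric requires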
val_iter = DetIter(val_imdb, batch_size, data_shape, mean_pixels,
cfg.VALID.RAND_SAMPLERS, cfg.VALID.RAND_MIRROR,
cfg.VALID.EPOCH_SHUFFLE, cfg.VALID.RAND_SEED,
is_train=True)
val_iter = mx.io.PrefetchingIter(val_iter)
else:
val_iter = None
    # load the training symbol: symbol/symbol_<net>.py must define get_symbol_train()
sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
net = importlib.import_module("symbol_" + net).get_symbol_train(imdb.num_classes)
    # define layers with fixed weight/bias: by default freeze the shallow
    # VGG layers (conv1_*, conv2_*)
    fixed_param_names = [name for name in net.list_arguments()
                         if name.startswith('conv1_') or name.startswith('conv2_')]
# load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
if resume > 0:
logger.info("Resume training with {} from epoch {}"
.format(ctx_str, resume))
_, args, auxs = mx.model.load_checkpoint(prefix, resume)
begin_epoch = resume
elif finetune > 0:
logger.info("Start finetuning with {} from epoch {}"
.format(ctx_str, finetune))
_, args, auxs = mx.model.load_checkpoint(prefix, finetune)
begin_epoch = finetune
        # the prediction convolution layers are named after their relu inputs,
        # so freezing every 'conv*' argument leaves them trainable
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
elif pretrained:
logger.info("Start training with {} from pretrained model {}"
.format(ctx_str, pretrained))
_, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
args = convert_pretrained(pretrained, args)
else:
logger.info("Experimental: start training from scratch with {}"
.format(ctx_str))
args = None
auxs = None
fixed_param_names = None
# helper information
if fixed_param_names:
logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')
# init training module
mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
fixed_param_names=fixed_param_names)
# fit
batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
epoch_end_callback = mx.callback.do_checkpoint(prefix)
iter_refactor = lr_refactor_epoch * imdb.num_images // train_iter.batch_size
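    # multiply the learning rate by lr_refactor_ratio every iter_refactor batches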
lr_scheduler = mx.lr_scheduler.FactorScheduler(iter_refactor, lr_refactor_ratio)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0}
monitor = mx.mon.Monitor(iter_monitor, pattern=".*") if iter_monitor > 0 else None
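    # parameters matching '.*scale' (the learned L2-normalization scale factors)
    # use ScaleInitializer; everything else falls back to Xavier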
    initializer = mx.init.Mixed([".*scale", ".*"],
                                [ScaleInitializer(), mx.init.Xavier(magnitude=1)])
mod.fit(train_iter,
eval_data=val_iter,
eval_metric=MultiBoxMetric(),
batch_end_callback=batch_end_callback,
epoch_end_callback=epoch_end_callback,
optimizer='sgd',
optimizer_params=optimizer_params,
            kvstore=kv,
begin_epoch=begin_epoch,
num_epoch=end_epoch,
initializer=initializer,
arg_params=args,
aux_params=auxs,
allow_missing=True,
monitor=monitor)
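

# A minimal sketch of how train_net might be invoked (hypothetical paths and
# hyperparameters; in this repo the values normally come from the command-line
# parser in train.py):
#
# train_net('vgg16_reduced', 'pascal', 'trainval', '2007,2012', '/data/VOCdevkit',
#           batch_size=32, data_shape=300, mean_pixels=(123, 117, 104),
#           resume=-1, finetune=-1, pretrained='model/vgg16_reduced', epoch=1,
#           prefix='model/ssd', ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240,
#           frequent=20, learning_rate=0.002, momentum=0.9, weight_decay=0.0005,
#           val_set='test', val_year='2007', lr_refactor_epoch=80,
#           lr_refactor_ratio=0.1)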