blob: 304a43b3d9496f25fb773e10a58d07d3bf696b9f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train
def convert_pretrained(name, args):
    """
    Hook for adapting the arguments of a pretrained model, e.g. to work
    around parameter-name inconsistencies between checkpoints and symbols.
    No conversion is currently performed: the arguments are handed back
    untouched.

    Parameters:
    ---------
    name : str
        pretrained model name (currently unused)
    args : dict
        loaded arguments

    Returns:
    ---------
    the very same `args` object that was passed in
    """
    return args
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the starting learning rate and the step-decay scheduler.

    The learning rate is multiplied by `lr_refactor_ratio` once for every
    refactor epoch that has already been reached at `begin_epoch`; the
    remaining refactor epochs are converted into iteration counts (relative
    to `begin_epoch`) and handed to a MultiFactorScheduler.

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate, e.g. '30, 60, 90'; empty items are
        ignored and the order of the items does not matter
    lr_refactor_ratio : float
        lr *= ratio at certain steps, must be > 0; a ratio >= 1 disables
        the schedule altogether
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple; the scheduler is None when
    the ratio is >= 1 or when no refactor epoch lies after `begin_epoch`
    """
    assert lr_refactor_ratio > 0, "lr_refactor_ratio must be positive"
    # Sorted because the scheduler only accepts strictly increasing steps;
    # the number of already-passed epochs (and thus the adjusted learning
    # rate) does not depend on the order.
    iter_refactor = sorted(
        int(r) for r in lr_refactor_step.split(',') if r.strip())
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)

    lr = learning_rate
    # A dataset smaller than one batch would otherwise yield 0 iterations per
    # epoch and collapse every step boundary to 0, which the scheduler rejects.
    epoch_size = max(1, num_example // batch_size)
    for s in iter_refactor:
        if begin_epoch >= s:
            lr *= lr_refactor_ratio
    if lr != learning_rate:
        logging.getLogger().info(
            "Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
    # Boundaries are expressed in iterations counted from begin_epoch.
    steps = [epoch_size * (x - begin_epoch)
             for x in iter_refactor if x > begin_epoch]
    if not steps:
        return (lr, None)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
        step=steps, factor=lr_refactor_ratio)
    return (lr, lr_scheduler)
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for the training phase: builds the data iterators and the
    training symbol, optionally loads parameters from a checkpoint, and
    runs `Module.fit` with SGD.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : number or tuple of floats
        mean pixel values for red, green and blue; a single number is
        replicated for all three channels
    resume : int
        if > 0, load the checkpoint saved under the (suffixed) `prefix` at
        this epoch and continue training from it
    finetune : int
        if > 0 (and not resuming), load the checkpoint saved under the
        (suffixed) `prefix` at this epoch and freeze every argument whose
        name starts with 'conv'
    pretrained : str
        prefix of pretrained model, including path; used when neither
        `resume` nor `finetune` is active
    epoch : int
        epoch of the pretrained checkpoint to load
    prefix : str
        prefix for saving checkpoints; '_<net>_<data_shape[1]>' is appended
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified;
        replaced by `resume` / `finetune` when one of them is active
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed; matched against the
        beginning of each argument name
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold, forwarded to the symbol factory
    force_nms : boolean
        suppress overlapped objects from different classes
        (NOTE: not read anywhere in this function; `force_suppress` is the
        value actually forwarded -- confirm whether this is intentional)
    ovp_thresh : float
        forwarded as first argument to the validation mAP metric
    use_difficult : boolean
        forwarded as second argument to the validation mAP metric
    class_names : list of str or None
        forwarded as third argument to the validation mAP metric
    voc07_metric : boolean
        use VOC07MApMetric instead of MApMetric for validation
    nms_topk : int
        forwarded to the symbol factory
    force_suppress : boolean
        forwarded to the symbol factory
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation; no validation is run when empty
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # ---- logging -----------------------------------------------------------
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        logger.addHandler(logging.FileHandler(log_file))

    # ---- argument normalisation --------------------------------------------
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # `net` is still the symbol *name* here; it becomes part of the checkpoint prefix
    prefix = prefix + '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels] * 3
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # ---- data iterators ----------------------------------------------------
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    val_iter = None
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)

    # ---- training symbol ---------------------------------------------------
    symbol = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                              nms_thresh=nms_thresh,
                              force_suppress=force_suppress,
                              nms_topk=nms_topk)

    # ---- parameters kept fixed during training -----------------------------
    fixed_param_names = None
    if freeze_layer_pattern.strip():
        freeze_regex = re.compile(freeze_layer_pattern)
        fixed_param_names = [arg_name for arg_name in symbol.list_arguments()
                             if freeze_regex.match(arg_name)]

    # ---- initial parameters: resume > finetune > pretrained > scratch ------
    ctx_str = '(' + ','.join(str(device) for device in ctx) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # prediction convolution layer names start with 'relu', so freezing
        # everything that starts with 'conv' leaves them trainable
        fixed_param_names = [arg_name for arg_name in symbol.list_arguments()
                             if arg_name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = auxs = None
        fixed_param_names = None

    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # ---- module ------------------------------------------------------------
    mod = mx.mod.Module(symbol, label_names=('label',), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # ---- callbacks and optimizer settings ----------------------------------
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(
        learning_rate, lr_refactor_step, lr_refactor_ratio,
        num_example, batch_size, begin_epoch)

    num_devices = len(ctx)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / num_devices if num_devices > 0 else 1.0,
    }

    if iter_monitor > 0:
        monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern)
    else:
        monitor = None

    # ---- validation metric (mAP, evaluated by fit on val_iter) -------------
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    valid_metric = metric_cls(ovp_thresh, use_difficult, class_names, pred_idx=3)

    # ---- train -------------------------------------------------------------
    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)