import sys
sys.path.insert(0, "../../python")
import os.path
import mxnet as mx
from config_util import get_checkpoint_path, parse_contexts
from stt_metric import STTMetric


class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
"""A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
dynamically based on performance on the validation set.
"""
def __init__(self, dynamic_lr, effective_sample_count=1, momentum=0.9, optimizer="nag"):
super(SimpleLRScheduler, self).__init__()
self.dynamic_lr = dynamic_lr
self.effective_sample_count = effective_sample_count
self.momentum = momentum
self.optimizer = optimizer
def __call__(self, num_update):
return self.dynamic_lr / self.effective_sample_count
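
# Minimal usage sketch (illustrative values, not project defaults): the scheduler
# divides `dynamic_lr` by `effective_sample_count`, which do_training() refreshes
# from every batch before the optimizer step.
#
#   scheduler = SimpleLRScheduler(dynamic_lr=0.001)
#   scheduler.effective_sample_count = 4   # e.g. samples aggregated in this update
#   scheduler(num_update=0)                # -> 0.00025
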
def score_with_state_forwarding(module, eval_data, eval_metric):
    """Evaluate `module` on `eval_data`, copying the RNN states produced by each
    batch into the iterator's initial-state arrays before the next batch."""
eval_data.reset()
eval_metric.reset()
for eval_batch in eval_data:
module.forward(eval_batch, is_train=False)
module.update_metric(eval_metric, eval_batch.label)
# copy over states
outputs = module.get_outputs()
# outputs[0] is softmax, 1:end are states
for i in range(1, len(outputs)):
outputs[i].copyto(eval_data.init_state_arrays[i - 1])
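
# Hypothetical usage sketch: this helper can stand in for the plain validation
# loop in do_training() when the network carries RNN states between batches;
# `module`, `data_val` and the metric arguments come from do_training()'s scope.
#
#   metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len)
#   score_with_state_forwarding(module, data_val, metric)
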
def get_initializer(args):
    """Build the weight initializer named in the [train] section of the config."""
init_type = getattr(mx.initializer, args.config.get('train', 'initializer'))
init_scale = args.config.getfloat('train', 'init_scale')
if init_type is mx.initializer.Xavier:
return mx.initializer.Xavier(magnitude=init_scale, factor_type=args.config.get('train', 'factor_type'))
return init_type(init_scale)
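
# get_initializer() reads these [train] keys from the experiment config; the
# values below are only an illustration (factor_type is consulted only for Xavier):
#
#   [train]
#   initializer = Xavier
#   init_scale = 2.0
#   factor_type = in
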
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train `module` on `data_train`, report CER on `data_val` after each epoch,
    and save checkpoints every `save_checkpoint_every_n_epoch` epochs."""
from distutils.dir_util import mkpath
from log_util import LogUtil
log = LogUtil().getlogger()
mkpath(os.path.dirname(get_checkpoint_path(args)))
    seq_len = args.config.getint('arch', 'max_t_count')
batch_size = args.config.getint('common', 'batch_size')
save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
contexts = parse_contexts(args)
num_gpu = len(contexts)
eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len)
optimizer = args.config.get('train', 'optimizer')
momentum = args.config.getfloat('train', 'momentum')
learning_rate = args.config.getfloat('train', 'learning_rate')
lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum, optimizer=optimizer)
n_epoch = begin_epoch
num_epoch = args.config.getint('train', 'num_epoch')
clip_gradient = args.config.getfloat('train', 'clip_gradient')
weight_decay = args.config.getfloat('train', 'weight_decay')
save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
show_every = args.config.getint('train', 'show_every')
if clip_gradient == 0:
clip_gradient = None
module.bind(data_shapes=data_train.provide_data,
label_shapes=data_train.provide_label,
for_training=True)
if begin_epoch == 0:
module.init_params(initializer=get_initializer(args))
    def reset_optimizer():
        # wire in the SimpleLRScheduler so the dynamic, per-batch learning rate is
        # actually used by the optimizer
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        if optimizer in ('sgd', 'nag'):
            # momentum is only a valid parameter for SGD-family optimizers
            optimizer_params['momentum'] = momentum
        module.init_optimizer(kvstore='device',
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=True)
reset_optimizer()
    while n_epoch < num_epoch:
eval_metric.reset()
log.info('---------train---------')
for nbatch, data_batch in enumerate(data_train):
            # feed the batch's effective sample count into the scheduler so the
            # learning rate is rescaled per update (see SimpleLRScheduler.__call__)
            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count
module.forward_backward(data_batch)
module.update()
if (nbatch+1) % show_every == 0:
module.update_metric(eval_metric, data_batch.label)
        # module.score below is left commented out for the Libri_sample data set so
        # that only the train CER is reported
        log.info('---------validation---------')
        data_val.reset()
        for nbatch, data_batch in enumerate(data_val):
            # run a forward pass on the validation batch before updating the metric;
            # otherwise the metric would be computed on stale training outputs
            module.forward(data_batch, is_train=False)
            module.update_metric(eval_metric, data_batch.label)
        # module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)
data_train.reset()
# save checkpoints
if n_epoch % save_checkpoint_every_n_epoch == 0:
log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)
n_epoch += 1
log.info('FINISH')
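
# For reference, do_training() pulls the following entries from the experiment
# config (in addition to whatever get_checkpoint_path() and parse_contexts() read):
#
#   [common] batch_size, save_checkpoint_every_n_epoch
#   [arch]   max_t_count
#   [train]  optimizer, momentum, learning_rate, num_epoch, clip_gradient,
#            weight_decay, save_optimizer_states, show_every, initializer,
#            init_scale, factor_type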