import sys
sys.path.insert(0, "../../python")
import os.path
import mxnet as mx
from config_util import get_checkpoint_path, parse_contexts
from stt_metric import STTMetric
#tensorboard setting
from tensorboard import SummaryWriter
import numpy as np
def get_initializer(args):
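    """Build the weight initializer named in the 'train' section of the config:
    Xavier gets `init_scale` as its magnitude (plus `factor_type`); any other
    initializer receives `init_scale` as its single positional argument."""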
init_type = getattr(mx.initializer, args.config.get('train', 'initializer'))
init_scale = args.config.getfloat('train', 'init_scale')
if init_type is mx.initializer.Xavier:
return mx.initializer.Xavier(magnitude=init_scale, factor_type=args.config.get('train', 'factor_type'))
return init_type(init_scale)
class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
"""A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
dynamically based on performance on the validation set.
"""
def __init__(self, learning_rate=0.001):
super(SimpleLRScheduler, self).__init__()
self.learning_rate = learning_rate
def __call__(self, num_update):
return self.learning_rate
def do_training(args, module, data_train, data_val, begin_epoch=0):
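    """Train `module` on `data_train` for the configured number of epochs, run a
    validation pass over `data_val` after every epoch, log CER/loss to TensorBoard,
    and save checkpoints along the way."""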
from distutils.dir_util import mkpath
from log_util import LogUtil
log = LogUtil().getlogger()
mkpath(os.path.dirname(get_checkpoint_path(args)))
    seq_len = args.config.getint('arch', 'max_t_count')
batch_size = args.config.getint('common', 'batch_size')
save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')
contexts = parse_contexts(args)
num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_train_metric, is_epoch_end=False)
optimizer = args.config.get('train', 'optimizer')
momentum = args.config.getfloat('train', 'momentum')
learning_rate = args.config.getfloat('train', 'learning_rate')
learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')
mode = args.config.get('common', 'mode')
num_epoch = args.config.getint('train', 'num_epoch')
clip_gradient = args.config.getfloat('train', 'clip_gradient')
weight_decay = args.config.getfloat('train', 'weight_decay')
save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
show_every = args.config.getint('train', 'show_every')
    n_epoch = begin_epoch
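    # a clip_gradient of 0 in the config means "do not clip gradients"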
if clip_gradient == 0:
clip_gradient = None
module.bind(data_shapes=data_train.provide_data,
label_shapes=data_train.provide_label,
for_training=True)
if begin_epoch == 0 and mode == 'train':
module.init_params(initializer=get_initializer(args))
lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
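    # reset_optimizer (re)initializes the module's optimizer; force_init=True discards any
    # existing optimizer state, which is what we want when starting a fresh 'train' run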
def reset_optimizer(force_init=False):
if optimizer == "sgd":
module.init_optimizer(kvstore='device',
optimizer=optimizer,
optimizer_params={'lr_scheduler': lr_scheduler,
'momentum': momentum,
'clip_gradient': clip_gradient,
'wd': weight_decay},
force_init=force_init)
elif optimizer == "adam":
module.init_optimizer(kvstore='device',
optimizer=optimizer,
optimizer_params={'lr_scheduler': lr_scheduler,
#'momentum': momentum,
'clip_gradient': clip_gradient,
'wd': weight_decay},
force_init=force_init)
else:
            raise Exception('Supported optimizers are sgd and adam. To use another optimizer, add it in train.py.')
if mode == "train":
reset_optimizer(force_init=True)
else:
reset_optimizer(force_init=False)
#tensorboard setting
tblog_dir = args.config.get('common', 'tensorboard_log_dir')
summary_writer = SummaryWriter(tblog_dir)
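    # main loop: one full pass over data_train per epoch, followed by a validation pass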
    while n_epoch < num_epoch:
loss_metric.reset()
log.info('---------train---------')
for nbatch, data_batch in enumerate(data_train):
module.forward_backward(data_batch)
module.update()
# tensorboard setting
if (nbatch + 1) % show_every == 0:
module.update_metric(loss_metric, data_batch.label)
#summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
if (nbatch+1) % save_checkpoint_every_n_batch == 0:
log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=int((nbatch + 1) / save_checkpoint_every_n_batch) - 1,
                                       save_optimizer_states=save_optimizer_states)
        # for the Libri_sample data set the validation pass below can be commented out to report only the train CER
log.info('---------validation---------')
data_val.reset()
eval_metric.reset()
for nbatch, data_batch in enumerate(data_val):
            # forward with is_train=True: running with is_train=False leads to a high CER when batch norm is used
module.forward(data_batch, is_train=True)
module.update_metric(eval_metric, data_batch.label)
# tensorboard setting
val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
curr_acc = val_cer
summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find CER in eval metric'
data_train.reset()
# tensorboard setting
train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
summary_writer.add_scalar('CER train', train_cer, n_epoch)
# save checkpoints
if n_epoch % save_checkpoint_every_n_epoch == 0:
log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)
n_epoch += 1
        # anneal the learning rate after every epoch
        lr_scheduler.learning_rate /= learning_rate_annealing
log.info('FINISH')