blob: d2a7a27442532001c9f65a626dd42d88197ff543 [file] [log] [blame]
import re
import sys
sys.path.insert(0, "../../python")
import time
import logging
import os.path
import mxnet as mx
import numpy as np
from speechSGD import speechSGD
from lstm_proj import lstm_unroll
from io_util import BucketSentenceIter, TruncatedSentenceIter, DataReadStream
from config_util import parse_args, get_checkpoint_path, parse_contexts
# some constants
METHOD_BUCKETING = 'bucketing'
METHOD_TBPTT = 'truncated-bptt'
def prepare_data(args):
batch_size = args.config.getint('train', 'batch_size')
num_hidden = args.config.getint('arch', 'num_hidden')
num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
if num_hidden_proj > 0:
init_h = [('l%d_init_h'%l, (batch_size, num_hidden_proj)) for l in range(num_lstm_layer)]
else:
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h
file_train = args.config.get('data', 'train')
file_dev = args.config.get('data', 'dev')
file_format = args.config.get('data', 'format')
feat_dim = args.config.getint('data', 'xdim')
train_data_args = {
"gpu_chunk": 32768,
"lst_file": file_train,
"file_format": file_format,
"separate_lines": True
}
dev_data_args = {
"gpu_chunk": 32768,
"lst_file": file_dev,
"file_format": file_format,
"separate_lines": True
}
train_sets = DataReadStream(train_data_args, feat_dim)
dev_sets = DataReadStream(dev_data_args, feat_dim)
return (init_states, train_sets, dev_sets)
def CrossEntropy(labels, preds):
labels = labels.reshape((-1,))
preds = preds.reshape((-1, preds.shape[1]))
loss = 0.
num_inst = 0
for i in range(preds.shape[0]):
label = labels[i]
if label > 0:
loss += -np.log(max(1e-10, preds[i][int(label)]))
num_inst += 1
return loss , num_inst
def Acc_exclude_padding(labels, preds):
labels = labels.reshape((-1,))
preds = preds.reshape((-1, preds.shape[1]))
sum_metric = 0
num_inst = 0
for i in range(preds.shape[0]):
pred_label = np.argmax(preds[i], axis=0)
label = labels[i]
ind = np.nonzero(label.flat)
pred_label_real = pred_label.flat[ind]
label_real = label.flat[ind]
sum_metric += (pred_label_real == label_real).sum()
num_inst += len(pred_label_real)
return sum_metric, num_inst
class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
"""A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
dynamically based on performance on the validation set.
"""
def __init__(self, dynamic_lr, effective_sample_count=1, momentum=0.9, optimizer="sgd"):
super(SimpleLRScheduler, self).__init__()
self.dynamic_lr = dynamic_lr
self.effective_sample_count = effective_sample_count
self.momentum = momentum
self.optimizer = optimizer
def __call__(self, num_update):
if self.optimizer == "speechSGD":
return self.dynamic_lr / self.effective_sample_count, self.momentum
else:
return self.dynamic_lr / self.effective_sample_count
def score_with_state_forwarding(module, eval_data, eval_metric):
eval_data.reset()
eval_metric.reset()
for eval_batch in eval_data:
module.forward(eval_batch, is_train=False)
module.update_metric(eval_metric, eval_batch.label)
# copy over states
outputs = module.get_outputs()
# outputs[0] is softmax, 1:end are states
for i in range(1, len(outputs)):
outputs[i].copyto(eval_data.init_state_arrays[i-1])
def get_initializer(args):
init_type = getattr(mx.initializer, args.config.get('train', 'initializer'))
init_scale = args.config.getfloat('train', 'init_scale')
if init_type is mx.initializer.Xavier:
return mx.initializer.Xavier(magnitude=init_scale)
return init_type(init_scale)
def do_training(training_method, args, module, data_train, data_val):
from distutils.dir_util import mkpath
mkpath(os.path.dirname(get_checkpoint_path(args)))
batch_size = data_train.batch_size
batch_end_callbacks = [mx.callback.Speedometer(batch_size,
args.config.getint('train', 'show_every'))]
eval_allow_extra = True if training_method == METHOD_TBPTT else False
eval_metric = [mx.metric.np(CrossEntropy, allow_extra_outputs=eval_allow_extra),
mx.metric.np(Acc_exclude_padding, allow_extra_outputs=eval_allow_extra)]
eval_metric = mx.metric.create(eval_metric)
optimizer = args.config.get('train', 'optimizer')
momentum = args.config.getfloat('train', 'momentum')
learning_rate = args.config.getfloat('train', 'learning_rate')
lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum, optimizer=optimizer)
if training_method == METHOD_TBPTT:
lr_scheduler.seq_len = data_train.truncate_len
n_epoch = 0
num_epoch = args.config.getint('train', 'num_epoch')
learning_rate = args.config.getfloat('train', 'learning_rate')
decay_factor = args.config.getfloat('train', 'decay_factor')
decay_bound = args.config.getfloat('train', 'decay_lower_bound')
clip_gradient = args.config.getfloat('train', 'clip_gradient')
weight_decay = args.config.getfloat('train', 'weight_decay')
if clip_gradient == 0:
clip_gradient = None
last_acc = -float("Inf")
last_params = None
module.bind(data_shapes=data_train.provide_data,
label_shapes=data_train.provide_label,
for_training=True)
module.init_params(initializer=get_initializer(args))
def reset_optimizer():
if optimizer == "sgd" or optimizer == "speechSGD":
module.init_optimizer(kvstore='device',
optimizer=args.config.get('train', 'optimizer'),
optimizer_params={'lr_scheduler': lr_scheduler,
'momentum': momentum,
'rescale_grad': 1.0,
'clip_gradient': clip_gradient,
'wd': weight_decay},
force_init=True)
else:
module.init_optimizer(kvstore='device',
optimizer=args.config.get('train', 'optimizer'),
optimizer_params={'lr_scheduler': lr_scheduler,
'rescale_grad': 1.0,
'clip_gradient': clip_gradient,
'wd': weight_decay},
force_init=True)
reset_optimizer()
while True:
tic = time.time()
eval_metric.reset()
for nbatch, data_batch in enumerate(data_train):
if training_method == METHOD_TBPTT:
lr_scheduler.effective_sample_count = data_train.batch_size * truncate_len
lr_scheduler.momentum = np.power(np.power(momentum, 1.0/(data_train.batch_size * truncate_len)), data_batch.effective_sample_count)
else:
if data_batch.effective_sample_count is not None:
lr_scheduler.effective_sample_count = 1#data_batch.effective_sample_count
module.forward_backward(data_batch)
module.update()
module.update_metric(eval_metric, data_batch.label)
batch_end_params = mx.model.BatchEndParam(epoch=n_epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=None)
for callback in batch_end_callbacks:
callback(batch_end_params)
if training_method == METHOD_TBPTT:
# copy over states
outputs = module.get_outputs()
# outputs[0] is softmax, 1:end are states
for i in range(1, len(outputs)):
outputs[i].copyto(data_train.init_state_arrays[i-1])
for name, val in eval_metric.get_name_value():
logging.info('Epoch[%d] Train-%s=%f', n_epoch, name, val)
toc = time.time()
logging.info('Epoch[%d] Time cost=%.3f', n_epoch, toc-tic)
data_train.reset()
# test on eval data
score_with_state_forwarding(module, data_val, eval_metric)
# test whether we should decay learning rate
curr_acc = None
for name, val in eval_metric.get_name_value():
logging.info("Epoch[%d] Dev-%s=%f", n_epoch, name, val)
if name == 'CrossEntropy':
curr_acc = val
assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'
if n_epoch > 0 and lr_scheduler.dynamic_lr > decay_bound and curr_acc > last_acc:
logging.info('Epoch[%d] !!! Dev set performance drops, reverting this epoch',
n_epoch)
logging.info('Epoch[%d] !!! LR decay: %g => %g', n_epoch,
lr_scheduler.dynamic_lr, lr_scheduler.dynamic_lr / float(decay_factor))
lr_scheduler.dynamic_lr /= decay_factor
# we reset the optimizer because the internal states (e.g. momentum)
# might already be exploded, so we want to start from fresh
reset_optimizer()
module.set_params(*last_params)
else:
last_params = module.get_params()
last_acc = curr_acc
n_epoch += 1
# save checkpoints
mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
module.symbol, *last_params)
if n_epoch == num_epoch:
break
if __name__ == '__main__':
args = parse_args()
args.config.write(sys.stdout)
training_method = args.config.get('train', 'method')
contexts = parse_contexts(args)
init_states, train_sets, dev_sets = prepare_data(args)
state_names = [x[0] for x in init_states]
batch_size = args.config.getint('train', 'batch_size')
num_hidden = args.config.getint('arch', 'num_hidden')
num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
feat_dim = args.config.getint('data', 'xdim')
label_dim = args.config.getint('data', 'ydim')
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
if training_method == METHOD_BUCKETING:
buckets = args.config.get('train', 'buckets')
buckets = list(map(int, re.split(r'\W+', buckets)))
data_train = BucketSentenceIter(train_sets, buckets, batch_size, init_states, feat_dim=feat_dim)
data_val = BucketSentenceIter(dev_sets, buckets, batch_size, init_states, feat_dim=feat_dim)
def sym_gen(seq_len):
sym = lstm_unroll(num_lstm_layer, seq_len, feat_dim, num_hidden=num_hidden,
num_label=label_dim, num_hidden_proj=num_hidden_proj)
data_names = ['data'] + state_names
label_names = ['softmax_label']
return (sym, data_names, label_names)
module = mx.mod.BucketingModule(sym_gen,
default_bucket_key=data_train.default_bucket_key,
context=contexts)
do_training(training_method, args, module, data_train, data_val)
elif training_method == METHOD_TBPTT:
truncate_len = args.config.getint('train', 'truncate_len')
data_train = TruncatedSentenceIter(train_sets, batch_size, init_states,
truncate_len=truncate_len, feat_dim=feat_dim)
data_val = TruncatedSentenceIter(dev_sets, batch_size, init_states,
truncate_len=truncate_len, feat_dim=feat_dim,
do_shuffling=False, pad_zeros=True)
sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim, num_hidden=num_hidden,
num_label=label_dim, output_states=True, num_hidden_proj=num_hidden_proj)
data_names = [x[0] for x in data_train.provide_data]
label_names = [x[0] for x in data_train.provide_label]
module = mx.mod.Module(sym, context=contexts, data_names=data_names,
label_names=label_names)
do_training(training_method, args, module, data_train, data_val)
else:
raise RuntimeError('Unknown training method: %s' % training_method)
print("="*80)
print("Finished Training")
print("="*80)
args.config.write(sys.stdout)