blob: 1826e1265de79e08f5c2bbc2d541e4a7dab9c1d2 [file] [log] [blame]
import re
import sys
sys.path.insert(0, "../../python")
import time
import logging
import os.path
import mxnet as mx
import numpy as np
from lstm_proj import lstm_unroll
from io_util import BucketSentenceIter, TruncatedSentenceIter, SimpleIter, DataReadStream
from config_util import parse_args, get_checkpoint_path, parse_contexts
from io_func.feat_readers.writer_kaldi import KaldiWriteOut
# some constants
METHOD_BUCKETING = 'bucketing'
METHOD_TBPTT = 'truncated-bptt'
METHOD_SIMPLE = 'simple'
def prepare_data(args):
batch_size = args.config.getint('train', 'batch_size')
num_hidden = args.config.getint('arch', 'num_hidden')
num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
if num_hidden_proj > 0:
init_h = [('l%d_init_h'%l, (batch_size, num_hidden_proj)) for l in range(num_lstm_layer)]
else:
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h
file_test = args.config.get('data', 'test')
file_label_mean = args.config.get('data', 'label_mean')
file_format = args.config.get('data', 'format')
feat_dim = args.config.getint('data', 'xdim')
label_dim = args.config.getint('data', 'ydim')
test_data_args = {
"gpu_chunk": 32768,
"lst_file": file_test,
"file_format": file_format,
"separate_lines":True,
"has_labels":False
}
label_mean_args = {
"gpu_chunk": 32768,
"lst_file": file_label_mean,
"file_format": file_format,
"separate_lines":True,
"has_labels":False
}
test_sets = DataReadStream(test_data_args, feat_dim)
label_mean_sets = DataReadStream(label_mean_args, label_dim)
return (init_states, test_sets, label_mean_sets)
if __name__ == '__main__':
args = parse_args()
args.config.write(sys.stderr)
decoding_method = args.config.get('train', 'method')
contexts = parse_contexts(args)
init_states, test_sets, label_mean_sets = prepare_data(args)
state_names = [x[0] for x in init_states]
batch_size = args.config.getint('train', 'batch_size')
num_hidden = args.config.getint('arch', 'num_hidden')
num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
feat_dim = args.config.getint('data', 'xdim')
label_dim = args.config.getint('data', 'ydim')
out_file = args.config.get('data', 'out_file')
num_epoch = args.config.getint('train', 'num_epoch')
model_name = get_checkpoint_path(args)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
# load the model
sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epoch)
if decoding_method == METHOD_BUCKETING:
buckets = args.config.get('train', 'buckets')
buckets = list(map(int, re.split(r'\W+', buckets)))
data_test = BucketSentenceIter(test_sets, buckets, batch_size, init_states, feat_dim=feat_dim, has_label=False)
def sym_gen(seq_len):
sym = lstm_unroll(num_lstm_layer, seq_len, feat_dim, num_hidden=num_hidden,
num_label=label_dim, take_softmax=True, num_hidden_proj=num_hidden_proj)
data_names = ['data'] + state_names
label_names = ['softmax_label']
return (sym, data_names, label_names)
module = mx.mod.BucketingModule(sym_gen,
default_bucket_key=data_test.default_bucket_key,
context=contexts)
elif decoding_method == METHOD_SIMPLE:
data_test = SimpleIter(test_sets, batch_size, init_states, feat_dim=feat_dim, label_dim=label_dim,
label_mean_sets=label_mean_sets, has_label=False)
def sym_gen(seq_len):
sym = lstm_unroll(num_lstm_layer, seq_len, feat_dim, num_hidden=num_hidden,
num_label=label_dim, take_softmax=False, num_hidden_proj=num_hidden_proj)
data_names = ['data'] + state_names
label_names = []
return (sym, data_names, label_names)
module = mx.mod.BucketingModule(sym_gen,
default_bucket_key=data_test.default_bucket_key,
context=contexts)
else:
truncate_len=20
data_test = TruncatedSentenceIter(test_sets, batch_size, init_states,
truncate_len, feat_dim=feat_dim,
do_shuffling=False, pad_zeros=True, has_label=True)
sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim, num_hidden=num_hidden,
num_label=label_dim, output_states=True, num_hidden_proj=num_hidden_proj)
data_names = [x[0] for x in data_test.provide_data]
label_names = ['softmax_label']
module = mx.mod.Module(sym, context=contexts, data_names=data_names,
label_names=label_names)
# set the parameters
module.bind(data_shapes=data_test.provide_data, label_shapes=None, for_training=False)
module.set_params(arg_params=arg_params, aux_params=aux_params)
kaldiWriter = KaldiWriteOut(None, out_file)
kaldiWriter.open_or_fd()
for preds, i_batch, batch in module.iter_predict(data_test):
label = batch.label[0].asnumpy().astype('int32')
posteriors = preds[0].asnumpy().astype('float32')
# copy over states
if decoding_method == METHOD_BUCKETING:
for (ind, utt) in enumerate(batch.utt_id):
if utt != "GAP_UTT":
posteriors = np.log(posteriors[:label[0][0],1:] + 1e-20) - np.log(data_train.label_mean).T
kaldiWriter.write(utt, posteriors)
elif decoding_method == METHOD_SIMPLE:
for (ind, utt) in enumerate(batch.utt_id):
if utt != "GAP_UTT":
posteriors = posteriors[:batch.utt_len,1:] - np.log(data_test.label_mean[1:]).T
kaldiWriter.write(utt, posteriors)
else:
outputs = module.get_outputs()
# outputs[0] is softmax, 1:end are states
for i in range(1, len(outputs)):
outputs[i].copyto(data_test.init_state_arrays[i-1])
for (ind, utt) in enumerate(batch.utt_id):
if utt != "GAP_UTT":
posteriors = np.log(posteriors[:,1:])# - np.log(data_train.label_mean).T
kaldiWriter.write(utt, posteriors)
kaldiWriter.close()
args.config.write(sys.stderr)