| # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme |
| # pylint: disable=superfluous-parens, no-member, invalid-name |
| import sys |
| sys.path.insert(0, "../../python") |
| import numpy as np |
| import mxnet as mx |
| |
| from lstm import bi_lstm_unroll |
| from sort_io import BucketSentenceIter, default_build_vocab |
| |
def Perplexity(label, pred):
    """Compute perplexity of `pred` against integer class labels.

    Parameters
    ----------
    label : np.ndarray
        2-D array of true class indices (may be float-typed); it is
        transposed and flattened so its row order matches `pred`'s rows.
    pred : np.ndarray
        2-D array of shape (num_samples, num_classes) holding predicted
        class probabilities per row.

    Returns
    -------
    float
        exp(mean negative log-likelihood), with probabilities clipped
        below at 1e-10 to avoid log(0).
    """
    label = label.T.reshape((-1,))
    n = pred.shape[0]
    # Vectorized gather of the probability assigned to each true label
    # (replaces the original per-row Python loop). astype truncates
    # toward zero exactly like the original int() cast.
    idx = label[:n].astype(np.int64)
    probs = np.maximum(1e-10, pred[np.arange(n), idx])
    loss = -np.log(probs).sum()
    # Note: divisor is label.size (not n), preserving original behavior.
    return np.exp(loss / label.size)
| |
if __name__ == '__main__':
    # ---- Hyperparameters ----
    batch_size = 100
    # Empty list: presumably the iterator derives default buckets from the
    # data itself — TODO confirm against BucketSentenceIter.
    buckets = []
    num_hidden = 300      # LSTM hidden state size
    num_embed = 512       # word-embedding dimension
    num_lstm_layer = 2    # number of stacked LSTM layers (used for init states)

    num_epoch = 1
    learning_rate = 0.1
    momentum = 0.9

    # Train on a single GPU (range(1) -> [gpu(0)]).
    contexts = [mx.context.gpu(i) for i in range(1)]

    # Vocabulary built from the training file; used for both input and label
    # spaces (num_label == len(vocab) below).
    vocab = default_build_vocab("./data/sort.train.txt")

    def sym_gen(seq_len):
        # Factory producing an unrolled bidirectional LSTM symbol for a
        # given sequence length; required by bucketing training.
        return bi_lstm_unroll(seq_len, len(vocab),
                              num_hidden=num_hidden, num_embed=num_embed,
                              num_label=len(vocab))

    # Initial cell (c) and hidden (h) state shapes, one pair per LSTM layer.
    init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    # Bucketed data iterators over the train/validation text files.
    data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
                                    buckets, batch_size, init_states)
    data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
                                  buckets, batch_size, init_states)

    # With exactly one bucket a fixed symbol suffices; otherwise pass the
    # factory so FeedForward can build one symbol per bucket.
    if len(buckets) == 1:
        symbol = sym_gen(buckets[0])
    else:
        symbol = sym_gen

    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    # Log training progress (Speedometer callbacks emit at DEBUG/INFO level).
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    # Train, reporting perplexity on the validation set; log throughput
    # every 50 batches.
    model.fit(X=data_train, eval_data=data_val,
              eval_metric = mx.metric.np(Perplexity),
              batch_end_callback=mx.callback.Speedometer(batch_size, 50),)

    # Persist parameters/symbol under the "sort" prefix.
    model.save("sort")