# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
import os
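# make the in-repo mxnet python package and the sibling rnn example
# (lstm.py, bucket_io.py) importable before anything else is imported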
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "python")))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "rnn")))
import numpy as np
import mxnet as mx
from lstm import lstm_unroll
from bucket_io import BucketSentenceIter, DummyIter, default_build_vocab
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'rnn', 'data'))
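# perplexity = exp(mean cross-entropy) of the predicted distributions over
# the true tokens (lower is better); mx.metric.np() below wraps this numpy
# function as an MXNet evaluation metric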
def Perplexity(label, pred):
    label = label.T.reshape((-1,))
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
    return np.exp(loss / label.size)
if __name__ == '__main__':
    batch_size = 32
    buckets = [10, 20, 30, 40, 50, 60]
    #buckets = [32]
    num_hidden = 200
    num_embed = 200
    num_lstm_layer = 2

    #num_epoch = 25
    num_epoch = 2
    # hyperparameters passed to the optimizer below
    learning_rate = 0.01
    momentum = 0.9
    # dummy data is used to test speed without IO
    dummy_data = False

    # train on a single GPU; switch to [mx.cpu()] if no GPU is available
    contexts = [mx.context.gpu(i) for i in range(1)]
    vocab = default_build_vocab(os.path.join(data_dir, "ptb.train.txt"))
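    # each LSTM layer contributes an initial cell state (c) and hidden state (h);
    # they are declared as extra iterator inputs of shape (batch_size, num_hidden)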
    init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h
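    # BucketSentenceIter assigns each sentence to the smallest bucket that
    # fits it and pads the sentence to that bucket's length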
    data_train = BucketSentenceIter(os.path.join(data_dir, "ptb.train.txt"), vocab,
                                    buckets, batch_size, init_states)
    data_val = BucketSentenceIter(os.path.join(data_dir, "ptb.valid.txt"), vocab,
                                  buckets, batch_size, init_states)
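    # DummyIter (from bucket_io) replays a single cached batch, so an epoch
    # measures pure compute speed with no data-loading cost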
    if dummy_data:
        data_train = DummyIter(data_train)
        data_val = DummyIter(data_val)

    state_names = [x[0] for x in init_states]
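    # sym_gen builds a symbol unrolled to seq_len time steps; BucketingModule
    # calls it once per bucket key and shares parameters across the executors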
    def sym_gen(seq_len):
        sym = lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                          num_hidden=num_hidden, num_embed=num_embed,
                          num_label=len(vocab))
        data_names = ['data'] + state_names
        label_names = ['softmax_label']
        return (sym, data_names, label_names)
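    # a single bucket needs no bucketing machinery, so fall back to a plain Module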
    if len(buckets) == 1:
        mod = mx.mod.Module(*sym_gen(buckets[0]), context=contexts)
    else:
        mod = mx.mod.BucketingModule(sym_gen,
                                     default_bucket_key=data_train.default_bucket_key,
                                     context=contexts)
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
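    # Xavier initialization sized for the LSTM gates; plain SGD with momentum
    # and a small weight decay, logging throughput every 50 batches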
    mod.fit(data_train, eval_data=data_val, num_epoch=num_epoch,
            eval_metric=mx.metric.np(Perplexity),
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            optimizer='sgd',
            optimizer_params={'learning_rate': learning_rate,
                              'momentum': momentum,
                              'wd': 0.00001})
    # with the bucketing module in place it is easy to score new data or
    # collect prediction outputs from the trained model
    metric = mx.metric.np(Perplexity)
    mod.score(data_val, metric)
    for name, val in metric.get_name_value():
        logging.info('Validation-%s=%f', name, val)
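    # a minimal sketch of collecting raw outputs instead of a metric score
    # (mod.predict concatenates the per-batch softmax outputs; num_batch
    # limits how many batches are run), left commented out by default:
    # outputs = mod.predict(data_val, num_batch=10)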