# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import numpy as np
import mxnet as mx
import random
from random import randint
from mxnet.contrib.amp import amp


def prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence):
    """Generate random train/val sentences and wrap them in bucketed iterators."""
    train_sent = []
    val_sent = []
    for _ in range(num_sentence):
        # Lengths span 6 .. max(buckets) - 1, so every sentence fits inside
        # the largest bucket.
        len_sentence = randint(6, max(buckets) - 1)
        train_sentence = []
        val_sentence = []
        for _ in range(len_sentence):
            # randint's upper bound is inclusive, so cap token ids at
            # len_vocab - 1 to keep them within the embedding's input_dim
            # and the softmax's class range.
            train_sentence.append(randint(1, len_vocab - 1))
            val_sentence.append(randint(1, len_vocab - 1))
        train_sent.append(train_sentence)
        val_sent.append(val_sentence)

    data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets,
                                           invalid_label=invalid_label)
    data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
                                         invalid_label=invalid_label)
    return (data_train, data_val)
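

# A minimal sketch of what prepare_bucketing_data returns, assuming the
# hyperparameters used in train_model below (the shape is an expectation,
# not a verified output):
#
#   data_train, data_val = prepare_bucketing_data(
#       buckets=[5, 10, 20, 30, 40], len_vocab=50, batch_size=128,
#       invalid_label=-1, num_sentence=1000)
#   batch = next(iter(data_train))
#   # batch.bucket_key is the bucket length drawn for this batch;
#   # batch.data[0] has shape (128, batch.bucket_key), padded with -1.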


def train_model(context=mx.cpu()):
    """Train a small bucketed LSTM language model on the synthetic data."""
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logging.getLogger('').addHandler(console)

    batch_size = 128
    num_epochs = 5
    num_hidden = 25
    num_embed = 25
    num_layers = 2
    len_vocab = 50
    buckets = [5, 10, 20, 30, 40]
    invalid_label = -1
    num_sentence = 1000

    data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size,
                                                  invalid_label, num_sentence)

    stack = mx.rnn.SequentialRNNCell()
    for i in range(num_layers):
        stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))
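
    # SequentialRNNCell chains the cells in order, so the unroll() call in
    # sym_gen below yields a num_layers-deep LSTM whose per-layer parameters
    # (named by the 'lstm_l%d_' prefixes) are shared across all time steps.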

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len_vocab,
                                 output_dim=num_embed, name='embed')

        stack.reset()
        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        loss = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
        return loss, ('data',), ('softmax_label',)
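
    # BucketingModule invokes sym_gen once per distinct bucket key, binding an
    # unrolled graph per sequence length while sharing one parameter set.
    # Roughly (a simplified sketch, not the module's actual internals):
    #
    #   loss, data_names, label_names = sym_gen(data_train.default_bucket_key)
    #   # ...then, for each incoming batch, switch to (or build) the graph
    #   # for batch.bucket_key and run forward/backward through it.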

    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=data_train.default_bucket_key,
        context=context)

    logging.info('Begin fit...')
    model.fit(
        train_data=data_train,
        eval_data=data_val,
        # Perplexity suits this multiclass next-word task; invalid_label
        # entries (the padding) are excluded from the metric.
        eval_metric=mx.metric.Perplexity(invalid_label),
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01,
                          'momentum': 0,
                          'wd': 0.00001},
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        num_epoch=num_epochs,
        batch_end_callback=mx.callback.Speedometer(batch_size, 50))
    logging.info('Finished fit...')
    return model
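

# A minimal usage sketch (mx.gpu(0) is illustrative and assumes a CUDA build;
# the test itself runs on the default CPU context):
#
#   model = train_model(context=mx.gpu(0))
#   # The returned BucketingModule can then be evaluated with, e.g.,
#   # model.score(some_bucketed_iter, mx.metric.Perplexity(-1)).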


def test_bucket_module():
    # This test fits a model on random word sequences to exercise bucketing.
    # Accuracy cannot be guaranteed on such an unlearnable task, so the
    # assertion below is left commented out.
    # assert model.score(data_val, mx.metric.MSE())[0][1] < 350, "High mean square error."
    model = train_model()


if __name__ == "__main__":
    test_bucket_module()