blob: 795eb6ef3bc11b84ad378ca454339e2bcd50703d [file] [log] [blame]
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math
# Per-layer recurrent state: cell state `c` and hidden state `h`.
LSTMState = namedtuple("LSTMState", "c h")
# Per-layer parameter symbols: input-to-hidden and hidden-to-hidden
# weights and biases.
LSTMParam = namedtuple(
    "LSTMParam",
    "i2h_weight i2h_bias h2h_weight h2h_bias",
)
# Everything needed to drive one bound network: the executor, its symbol,
# the state/input/label/output arrays and the trainable parameter blocks.
LSTMModel = namedtuple(
    "LSTMModel",
    "rnn_exec symbol "
    "init_states last_states "
    "seq_data seq_labels seq_outputs "
    "param_blocks",
)
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """Build the symbol graph of one LSTM cell for one time step and layer.

    Parameters
    ----------
    num_hidden : number of hidden units; the two projections produce
        ``4 * num_hidden`` outputs that are split into the four gates.
    indata : input symbol of the cell.
    prev_state : LSTMState with the previous ``c`` and ``h`` symbols.
    param : LSTMParam holding the i2h / h2h weight and bias symbols.
    seqidx, layeridx : time-step and layer indices, used only to name ops
        (``t<seqidx>_l<layeridx>_...``).
    dropout : if greater than zero, dropout with this probability is applied
        to ``indata`` before the input projection.

    Returns
    -------
    LSTMState with the next cell state ``c`` and hidden state ``h`` symbols.
    """
    op_prefix = "t%d_l%d" % (seqidx, layeridx)
    gate_units = num_hidden * 4

    cell_input = indata
    if dropout > 0.:
        cell_input = mx.sym.Dropout(data=cell_input, p=dropout)

    # Input-to-hidden and hidden-to-hidden projections, summed into the
    # pre-activation values of all four gates at once.
    input_proj = mx.sym.FullyConnected(data=cell_input,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=gate_units,
                                       name=op_prefix + "_i2h")
    hidden_proj = mx.sym.FullyConnected(data=prev_state.h,
                                        weight=param.h2h_weight,
                                        bias=param.h2h_bias,
                                        num_hidden=gate_units,
                                        name=op_prefix + "_h2h")
    pre_activation = input_proj + hidden_proj
    pieces = mx.sym.SliceChannel(pre_activation, num_outputs=4,
                                 name=op_prefix + "_slice")

    # Slice order: input gate, input transform, forget gate, output gate.
    gate_kinds = ("sigmoid", "tanh", "sigmoid", "sigmoid")
    in_gate, in_transform, forget_gate, out_gate = [
        mx.sym.Activation(pieces[slot], act_type=kind)
        for slot, kind in enumerate(gate_kinds)
    ]

    kept_memory = forget_gate * prev_state.c
    new_memory = in_gate * in_transform
    cell = kept_memory + new_memory
    hidden = out_gate * mx.sym.Activation(cell, act_type="tanh")
    return LSTMState(c=cell, h=hidden)
def lstm_unroll(num_lstm_layer, seq_len, input_size,
                num_hidden, num_embed, num_label, dropout=0.,
                concat_decode=True, use_loss=False):
    """Build the symbol of an LSTM network unrolled over ``seq_len`` steps.

    Variables and ops are tagged with ``ctx_group`` attributes: ``'embed'``
    for the embedding, ``'layer<i>'`` for LSTM layer ``i`` and ``'decode'``
    for the classifier.

    Parameters
    ----------
    num_lstm_layer : number of stacked LSTM layers.
    seq_len : number of unrolled time steps.
    input_size : ``input_dim`` of the embedding.
    num_hidden : hidden size of every LSTM layer.
    num_embed : ``output_dim`` of the embedding.
    num_label : number of outputs of the classifier layer.
    dropout : dropout probability; applied to the input of every layer
        except the first and to the output of the last layer.
    concat_decode : if true, the per-step hidden states are concatenated
        along dim 0 and decoded by a single classifier with one ``label``
        variable; otherwise every step gets its own classifier head and
        ``t<step>_label`` variable.
    use_loss : if true the head is ``softmax_cross_entropy``, otherwise
        ``SoftmaxOutput``.

    Returns
    -------
    A grouped symbol: the output head(s), followed by the gradient-blocked
    last cell states of all layers, followed by the gradient-blocked last
    hidden states of all layers.
    """
    # Parameter symbols shared by every time step.
    with mx.AttrScope(ctx_group='embed'):
        embed_weight = mx.sym.Variable("embed_weight")
    with mx.AttrScope(ctx_group='decode'):
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")

    param_cells = []
    last_states = []
    for layer in range(num_lstm_layer):
        with mx.AttrScope(ctx_group='layer%d' % layer):
            layer_param = LSTMParam(
                i2h_weight=mx.sym.Variable("l%d_i2h_weight" % layer),
                i2h_bias=mx.sym.Variable("l%d_i2h_bias" % layer),
                h2h_weight=mx.sym.Variable("l%d_h2h_weight" % layer),
                h2h_bias=mx.sym.Variable("l%d_h2h_bias" % layer))
            layer_init = LSTMState(c=mx.sym.Variable("l%d_init_c" % layer),
                                   h=mx.sym.Variable("l%d_init_h" % layer))
        param_cells.append(layer_param)
        last_states.append(layer_init)
    assert(len(last_states) == num_lstm_layer)

    last_hidden = []
    for step in range(seq_len):
        # embedding layer
        with mx.AttrScope(ctx_group='embed'):
            token = mx.sym.Variable("t%d_data" % step)
            hidden = mx.sym.Embedding(data=token, weight=embed_weight,
                                      input_dim=input_size,
                                      output_dim=num_embed,
                                      name="t%d_embed" % step)
        # stacked LSTM layers; the first layer never gets input dropout
        for layer in range(num_lstm_layer):
            layer_dropout = 0. if layer == 0 else dropout
            with mx.AttrScope(ctx_group='layer%d' % layer):
                new_state = lstm(num_hidden, indata=hidden,
                                 prev_state=last_states[layer],
                                 param=param_cells[layer],
                                 seqidx=step, layeridx=layer,
                                 dropout=layer_dropout)
                hidden = new_state.h
                last_states[layer] = new_state
        # dropout on the top layer's output before decoding
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        last_hidden.append(hidden)

    out_prob = []
    if concat_decode:
        # one classifier over all time steps stacked along dim 0
        with mx.AttrScope(ctx_group='decode'):
            stacked = mx.sym.Concat(*last_hidden, dim=0)
            logits = mx.sym.FullyConnected(data=stacked,
                                           weight=cls_weight,
                                           bias=cls_bias,
                                           num_hidden=num_label)
            target = mx.sym.Variable("label")
            if use_loss:
                head = mx.sym.softmax_cross_entropy(logits, target, name="sm")
            else:
                head = mx.sym.SoftmaxOutput(data=logits, label=target, name="sm")
            out_prob = [head]
    else:
        # one classifier head per time step
        for step in range(seq_len):
            with mx.AttrScope(ctx_group='decode'):
                logits = mx.sym.FullyConnected(data=last_hidden[step],
                                               weight=cls_weight,
                                               bias=cls_bias,
                                               num_hidden=num_label,
                                               name="t%d_cls" % step)
                target = mx.sym.Variable("t%d_label" % step)
                head_name = "t%d_sm" % step
                if use_loss:
                    head = mx.sym.softmax_cross_entropy(logits, target,
                                                        name=head_name)
                else:
                    head = mx.sym.SoftmaxOutput(data=logits, label=target,
                                                name=head_name)
                out_prob.append(head)

    # Expose the final states as outputs without letting gradients flow
    # back through them.
    for layer, final in enumerate(list(last_states)):
        last_states[layer] = LSTMState(
            c=mx.sym.BlockGrad(final.c, name="l%d_last_c" % layer),
            h=mx.sym.BlockGrad(final.h, name="l%d_last_h" % layer))

    final_c = [st.c for st in last_states]
    final_h = [st.h for st in last_states]
    return mx.sym.Group(out_prob + final_c + final_h)
def is_param_name(name):
    """Return True if ``name`` ends with "weight", "bias", "gamma" or "beta"."""
    param_suffixes = ("weight", "bias", "gamma", "beta")
    return name.endswith(param_suffixes)
def setup_rnn_model(default_ctx,
                    num_lstm_layer, seq_len,
                    num_hidden, num_embed, num_label,
                    batch_size, input_size,
                    initializer, dropout=0.,
                    group2ctx=None, concat_decode=True,
                    use_loss=False, buckets=None):
    """Bind one LSTM executor per bucket and return them keyed by bucket.

    The largest bucket is bound first; every other bucket is bound with
    ``shared_exec`` set to that first executor.

    Parameters
    ----------
    default_ctx : context used for arguments without a ``__ctx_group__``
        attribute and passed to ``bind``.
    num_lstm_layer, seq_len, num_hidden, num_embed, num_label, input_size,
    dropout, concat_decode, use_loss : forwarded to ``lstm_unroll``.
    batch_size : batch dimension used for all input shapes.
    initializer : callable invoked as ``initializer(name, array)`` for every
        parameter array (see ``is_param_name``).
    group2ctx : mapping from ctx-group name to context.
    buckets : non-empty list of bucket keys. The list is not modified.

    Returns
    -------
    dict mapping each bucket key to an ``LSTMModel``. ``last_states`` of
    every model is ``None``.

    Raises
    ------
    ValueError
        If ``buckets`` is ``None`` or empty.
    """
    if not buckets:
        raise ValueError("buckets must be a non-empty list of bucket keys")

    max_len = max(buckets)
    max_rnn_exec = None
    models = {}
    # Visit the largest bucket first so its executor exists before the
    # others try to share it. sorted() works on a copy, so the caller's
    # list is never mutated and unsorted input is handled too.
    for bucket_key in sorted(buckets, reverse=True):
        # NOTE(review): the symbol is unrolled to `seq_len`, not to
        # `bucket_key`, so all buckets get the same unrolled length --
        # confirm this is intended when more than one bucket is used.
        rnn_sym = lstm_unroll(num_lstm_layer=num_lstm_layer,
                              num_hidden=num_hidden,
                              seq_len=seq_len,
                              input_size=input_size,
                              num_embed=num_embed,
                              num_label=num_label,
                              dropout=dropout,
                              concat_decode=concat_decode,
                              use_loss=use_loss)
        arg_names = rnn_sym.list_arguments()
        internals = rnn_sym.get_internals()

        # Shapes of states, inputs and labels; parameter shapes are inferred.
        input_shapes = {}
        for name in arg_names:
            if name.endswith("init_c") or name.endswith("init_h"):
                input_shapes[name] = (batch_size, num_hidden)
            elif name.endswith("data"):
                input_shapes[name] = (batch_size, )
            elif name == "label":
                input_shapes[name] = (batch_size * seq_len, )
            elif name.endswith("label"):
                input_shapes[name] = (batch_size, )
        arg_shape, out_shape, aux_shape = rnn_sym.infer_shape(**input_shapes)

        # Allocate every argument on the context of its ctx group, and a
        # gradient buffer for every parameter.
        arg_arrays = []
        args_grad = {}
        for shape, name in zip(arg_shape, arg_names):
            group = internals[name].attr("__ctx_group__")
            ctx = group2ctx[group] if group is not None else default_ctx
            arg_arrays.append(mx.nd.zeros(shape, ctx))
            if is_param_name(name):
                args_grad[name] = mx.nd.zeros(shape, ctx)
            if not name.startswith("t"):
                print("%s group=%s, ctx=%s" % (name, group, str(ctx)))

        # bind with shared executor
        if max_len == bucket_key:
            rnn_exec = rnn_sym.bind(default_ctx, args=arg_arrays,
                                    args_grad=args_grad,
                                    grad_req="add", group2ctx=group2ctx)
            max_rnn_exec = rnn_exec
        else:
            assert max_rnn_exec is not None
            rnn_exec = rnn_sym.bind(default_ctx, args=arg_arrays,
                                    args_grad=args_grad,
                                    grad_req="add", group2ctx=group2ctx,
                                    shared_exec=max_rnn_exec)

        # Initialise parameters and collect
        # (index, weight, grad, name) blocks for the optimizer.
        param_blocks = []
        arg_dict = dict(zip(arg_names, rnn_exec.arg_arrays))
        for i, name in enumerate(arg_names):
            if is_param_name(name):
                initializer(name, arg_dict[name])
                param_blocks.append((i, arg_dict[name], args_grad[name], name))
            else:
                assert name not in args_grad

        out_dict = dict(zip(rnn_sym.list_outputs(), rnn_exec.outputs))
        init_states = [LSTMState(c=arg_dict["l%d_init_c" % i],
                                 h=arg_dict["l%d_init_h" % i])
                       for i in range(num_lstm_layer)]
        seq_data = [rnn_exec.arg_dict["t%d_data" % i] for i in range(seq_len)]
        # we don't need to store the last state
        last_states = None

        if concat_decode:
            seq_outputs = [out_dict["sm_output"]]
            seq_labels = [rnn_exec.arg_dict["label"]]
        else:
            seq_outputs = [out_dict["t%d_sm_output" % i] for i in range(seq_len)]
            seq_labels = [rnn_exec.arg_dict["t%d_label" % i] for i in range(seq_len)]

        models[bucket_key] = LSTMModel(rnn_exec=rnn_exec, symbol=rnn_sym,
                                       init_states=init_states,
                                       last_states=last_states,
                                       seq_data=seq_data,
                                       seq_labels=seq_labels,
                                       seq_outputs=seq_outputs,
                                       param_blocks=param_blocks)
    return models
def set_rnn_inputs(m, X, begin):
    """Copy a window of rows of ``X`` into the model's input and label arrays.

    For every time step ``s`` of the model, row ``(begin + s) % X.shape[0]``
    of ``X`` becomes the step's input and the following row (also wrapped
    around) becomes its label. When the model has a single label array, the
    label of step ``s`` is written to the slice
    ``[s * batch_size, (s + 1) * batch_size)`` of that array.
    """
    num_steps = len(m.seq_data)
    batch_size = m.seq_data[0].shape[0]
    for step in range(num_steps):
        num_rows = X.shape[0]
        cur_row = X[(begin + step) % num_rows, :]
        next_row = X[(begin + step + 1) % X.shape[0], :]
        mx.nd.array(cur_row).copyto(m.seq_data[step])
        if len(m.seq_labels) != 1:
            # one label array per time step
            m.seq_labels[step][:] = next_row
        else:
            # single flat label array covering all time steps
            lo = step * batch_size
            m.seq_labels[0][lo : lo + batch_size] = next_row
def set_rnn_inputs_from_batch(m, batch, batch_seq_length, batch_size):
    """Load ``batch.data`` into the model's input and label arrays.

    Row ``s`` of ``batch.data`` is the input of step ``s``; its label is row
    ``(s + 1) % batch_seq_length``, so the last step's label wraps around to
    the first row. Only the first ``batch_seq_length`` steps are written.
    When the model has a single label array, the label of step ``s`` is
    written to the slice ``[s * batch_size, (s + 1) * batch_size)``.
    """
    rows = batch.data
    for step in range(batch_seq_length):
        following = (step + 1) % batch_seq_length
        step_input = rows[step, :]
        step_label = rows[following, :]
        mx.nd.array(step_input).copyto(m.seq_data[step])
        if len(m.seq_labels) != 1:
            # one label array per time step
            m.seq_labels[step][:] = step_label
        else:
            # single flat label array covering all time steps
            lo = step * batch_size
            m.seq_labels[0][lo : lo + batch_size] = step_label
def calc_nll_concat(seq_label_probs, batch_size):
    """Return the negative log-likelihood of the concatenated label probs.

    ``seq_label_probs`` must provide ``asnumpy()`` returning the probability
    assigned to each target label. Probabilities are floored at 1e-10 before
    taking the log (the same floor ``calc_nll`` uses) so that a zero
    probability yields a large finite penalty instead of ``inf``.
    The summed negative log-probability is divided by ``batch_size``.
    """
    eps = 1e-10
    probs = np.maximum(seq_label_probs.asnumpy(), eps)
    return -np.sum(np.log(probs)) / batch_size
def calc_nll(seq_label_probs, batch_size, seq_len):
    """Return the negative log-likelihood summed over ``seq_len`` steps.

    ``seq_label_probs[t].asnumpy()`` gives the target-label probabilities of
    step ``t``. They are floored at 1e-10 before the log, and each step's
    summed negative log-probability is divided by ``batch_size``.
    """
    floor = 1e-10
    per_step = (
        -np.sum(np.log(np.maximum(seq_label_probs[step].asnumpy(), floor))) / batch_size
        for step in range(seq_len)
    )
    # Start from 0. so that an empty sequence still yields a float.
    return sum(per_step, 0.)
def train_lstm(model, X_train_batch, X_val_batch,
               num_round, update_period, concat_decode, batch_size, use_loss,
               optimizer='sgd', half_life=2,max_grad_norm = 5.0, **kwargs):
    """Train the bucketed LSTM models and report train/validation perplexity.

    Parameters
    ----------
    model : mapping from bucket key to ``LSTMModel``.
    X_train_batch, X_val_batch : resettable iterators of batches exposing
        ``bucket_key`` and ``data``.
    num_round : number of passes over the training iterator.
    update_period : parameters are updated every ``update_period`` batches.
    concat_decode, use_loss : must match how the models were built; they
        select how the negative log-likelihood is read from the outputs.
    batch_size : used to scale gradients and the NLL.
    optimizer, **kwargs : passed to ``mx.optimizer.create``.
    half_life : accepted but not used by this function.
    max_grad_norm : gradients are rescaled when their global L2 norm
        exceeds this value.

    After each round the validation perplexity is computed; if it did not
    improve by more than 1.0 over the previous round, the learning rate is
    halved. Both iterators are reset at the end of every round.
    """
    opt = mx.optimizer.create(optimizer,
                              **kwargs)
    updater = mx.optimizer.get_updater(opt)
    epoch_counter = 0
    log_period = 28
    prev_perp = 10000000.0
    train_report = "Iter [%d] Train: Time: %.3f sec, NLL=%.3f, Perp=%.3f"

    for iteration in range(num_round):
        nbatch = 0
        train_nll = 0
        tic = time.time()
        for data_batch in X_train_batch:
            bucket_len = data_batch.bucket_key
            m = model[bucket_len]
            # start every batch from a zero initial state
            for init in m.init_states:
                init.c[:] = 0.0
                init.h[:] = 0.0

            head_grad = []
            if use_loss:
                out_ctx = m.seq_outputs[0].context
                head_grad = [mx.nd.ones((1,), out_ctx) for _ in m.seq_outputs]

            set_rnn_inputs_from_batch(m, data_batch, bucket_len, batch_size)
            m.rnn_exec.forward(is_train=True)

            # Grab what is needed for the NLL before running backward.
            if use_loss:
                seq_loss = [out.copyto(mx.cpu()) for out in m.seq_outputs]
                m.rnn_exec.backward(head_grad)
            else:
                # probability of each label class, used to evaluate nll
                if concat_decode:
                    seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                else:
                    seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                       for out, label in zip(m.seq_outputs, m.seq_labels)]
                m.rnn_exec.backward()

            # update epoch counter
            epoch_counter += 1
            if epoch_counter % update_period == 0:
                # update parameters: scale by batch size, clip by global norm
                sq_norm = 0.
                for _, _, grad, _ in m.param_blocks:
                    grad /= batch_size
                    block_norm = mx.nd.norm(grad).asscalar()
                    sq_norm += block_norm*block_norm
                global_norm = math.sqrt(sq_norm)
                needs_clip = global_norm > max_grad_norm
                for idx, weight, grad, _ in m.param_blocks:
                    if needs_clip:
                        grad *= (max_grad_norm / global_norm)
                    updater(idx, grad, weight)
                    # reset gradient to zero
                    grad[:] = 0.0

            if use_loss:
                train_nll += sum([x.asscalar() for x in seq_loss]) / batch_size
            elif concat_decode:
                train_nll += calc_nll_concat(seq_label_probs, batch_size)
            else:
                train_nll += calc_nll(seq_label_probs, batch_size, bucket_len)

            nbatch += batch_size
            toc = time.time()
            if epoch_counter % log_period == 0:
                mean_nll = train_nll / nbatch
                print(train_report % (
                    epoch_counter, toc - tic, mean_nll, np.exp(train_nll / nbatch)))
        # end of training loop
        toc = time.time()
        print(train_report % (
            iteration, toc - tic, train_nll / nbatch, np.exp(train_nll / nbatch)))

        val_nll = 0.0
        nbatch = 0
        for data_batch in X_val_batch:
            bucket_len = data_batch.bucket_key
            m = model[bucket_len]
            # validation set, reset states
            for init in m.init_states:
                init.h[:] = 0.0
                init.c[:] = 0.0

            set_rnn_inputs_from_batch(m, data_batch, bucket_len, batch_size)
            m.rnn_exec.forward(is_train=False)

            if use_loss:
                seq_loss = [out.copyto(mx.cpu()) for out in m.seq_outputs]
                val_nll += sum([x.asscalar() for x in seq_loss]) / batch_size
            elif concat_decode:
                # probability of each label class, used to evaluate nll
                seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                val_nll += calc_nll_concat(seq_label_probs, batch_size)
            else:
                seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                   for out, label in zip(m.seq_outputs, m.seq_labels)]
                val_nll += calc_nll(seq_label_probs, batch_size, bucket_len)
            nbatch += batch_size

        perp = np.exp(val_nll / nbatch)
        print("Iter [%d] Val: NLL=%.3f, Perp=%.3f" % (
            iteration, val_nll / nbatch, np.exp(val_nll / nbatch)))
        # halve the learning rate when validation perplexity stalls
        if perp > prev_perp - 1.0:
            opt.lr *= 0.5
            print("Reset learning rate to %g" % opt.lr)
        prev_perp = perp
        X_val_batch.reset()
        X_train_batch.reset()
# is this function being used?
def setup_rnn_sample_model(ctx,
                           params,
                           num_lstm_layer,
                           num_hidden, num_embed, num_label,
                           batch_size, input_size,
                           concat_decode=True):
    """Bind a one-step, inference-only LSTM model loaded with ``params``.

    Parameters
    ----------
    ctx : context on which all arrays are allocated and the symbol is bound.
    params : dict mapping argument name to the array copied into it.
    num_lstm_layer, num_hidden, num_embed, num_label, input_size :
        forwarded to ``lstm_unroll`` (with ``seq_len`` fixed to 1).
    batch_size : batch dimension of the state and data inputs.
    concat_decode : forwarded to ``lstm_unroll``; selects which label and
        output names are looked up. Defaults to ``True``, the value
        ``lstm_unroll`` itself defaults to.

    Returns
    -------
    ``LSTMModel`` bound with ``grad_req="null"``; ``last_states`` holds the
    ``l<i>_last_c_output`` / ``l<i>_last_h_output`` executor outputs and the
    gradient slot of every param block is ``None``.
    """
    seq_len = 1
    rnn_sym = lstm_unroll(num_lstm_layer=num_lstm_layer,
                          input_size=input_size,
                          num_hidden=num_hidden,
                          seq_len=seq_len,
                          num_embed=num_embed,
                          num_label=num_label,
                          concat_decode=concat_decode)
    arg_names = rnn_sym.list_arguments()

    # Only state and data shapes are given; the rest is inferred.
    input_shapes = {}
    for name in arg_names:
        if name.endswith("init_c") or name.endswith("init_h"):
            input_shapes[name] = (batch_size, num_hidden)
        elif name.endswith("data"):
            input_shapes[name] = (batch_size, )
    arg_shape, out_shape, aux_shape = rnn_sym.infer_shape(**input_shapes)

    arg_arrays = [mx.nd.zeros(s, ctx) for s in arg_shape]
    arg_dict = dict(zip(arg_names, arg_arrays))
    for name, arr in params.items():
        arg_dict[name][:] = arr

    rnn_exec = rnn_sym.bind(ctx=ctx, args=arg_arrays, args_grad=None, grad_req="null")
    out_dict = dict(zip(rnn_sym.list_outputs(), rnn_exec.outputs))

    # No gradients at inference time, hence None in the grad slot.
    param_blocks = [(i, arr, None, name)
                    for i, (name, arr) in enumerate(params.items())]
    init_states = [LSTMState(c=arg_dict["l%d_init_c" % i],
                             h=arg_dict["l%d_init_h" % i])
                   for i in range(num_lstm_layer)]

    if concat_decode:
        seq_labels = [rnn_exec.arg_dict["label"]]
        seq_outputs = [out_dict["sm_output"]]
    else:
        seq_labels = [rnn_exec.arg_dict["t%d_label" % i] for i in range(seq_len)]
        # Output names carry the "_output" suffix, as in setup_rnn_model.
        seq_outputs = [out_dict["t%d_sm_output" % i] for i in range(seq_len)]

    seq_data = [rnn_exec.arg_dict["t%d_data" % i] for i in range(seq_len)]
    last_states = [LSTMState(c=out_dict["l%d_last_c_output" % i],
                             h=out_dict["l%d_last_h_output" % i])
                   for i in range(num_lstm_layer)]
    return LSTMModel(rnn_exec=rnn_exec, symbol=rnn_sym,
                     init_states=init_states, last_states=last_states,
                     seq_data=seq_data, seq_labels=seq_labels, seq_outputs=seq_outputs,
                     param_blocks=param_blocks)
# Under Python 3, np.random.choice is too strict when checking float probabilities, so we use our own weighted choice instead.
import random
import bisect
import collections
def _cdf(weights):
total = sum(weights)
result = []
cumsum = 0
for w in weights:
cumsum += w
result.append(cumsum / total)
return result
def _choice(population, weights):
    """Pick one element of ``population`` at random, weighted by ``weights``.

    A uniform draw from ``random.random()`` is located in the cumulative
    distribution built by ``_cdf`` and the element at that position is
    returned. ``population`` and ``weights`` must have the same length.
    """
    assert len(population) == len(weights)
    cumulative = _cdf(weights)
    draw = random.random()
    position = bisect.bisect(cumulative, draw)
    return population[position]
def sample_lstm(model, X_input_batch, seq_len, temperature=1., sample=True):
    """Generate ``seq_len`` steps from a sampling model, feeding outputs back.

    Parameters
    ----------
    model : ``LSTMModel`` whose ``last_states`` is populated (as built by
        ``setup_rnn_sample_model``). ``seq_outputs`` may be either a single
        output array or a list/tuple whose first element is that array.
    X_input_batch : numpy array with the initial input; it is overwritten
        in place with every generated step.
    seq_len : number of steps to generate.
    temperature : divisor applied to the log-probabilities before sampling.
    sample : if true, draw from the rescaled distribution; otherwise take
        the argmax.

    Returns
    -------
    list of ``seq_len`` numpy arrays shaped like ``X_input_batch``, one per
    generated step.
    """
    m = model
    # setup_rnn_sample_model stores the output wrapped in a list; accept
    # both that and a bare array instead of failing on `list.shape`.
    seq_out = m.seq_outputs
    probs_nd = seq_out[0] if isinstance(seq_out, (list, tuple)) else seq_out

    vocab = probs_nd.shape[1]
    batch_size = m.seq_data[0].shape[0]
    outputs_ndarray = mx.nd.zeros(probs_nd.shape)
    candidates = list(range(vocab))
    outputs_batch = [np.zeros(X_input_batch.shape) for _ in range(seq_len)]

    for i in range(seq_len):
        set_rnn_inputs(m, X_input_batch, 0)
        m.rnn_exec.forward(is_train=False)
        outputs_ndarray[:] = probs_nd
        # carry the recurrent state over to the next step
        for init, last in zip(m.init_states, m.last_states):
            last.c.copyto(init.c)
            last.h.copyto(init.h)
        # keep probabilities away from 0 and 1 so the log below is finite
        prob = np.clip(outputs_ndarray.asnumpy(), 1e-6, 1 - 1e-6)
        if sample:
            rescale = np.exp(np.log(prob) / temperature)
            for j in range(batch_size):
                p = rescale[j, :]
                p[:] /= p.sum()
                outputs_batch[i][j] = _choice(candidates, p)
        else:
            outputs_batch[i][:] = np.argmax(prob, axis=1)
        # feed the generated step back in as the next input
        X_input_batch[:] = outputs_batch[i]
    return outputs_batch