blob: ae2271c800b70cfba2a1b07638e19e7398a8eb24 [file] [log] [blame]
# pylint:skip-file
import mxnet as mx
import numpy as np
from collections import namedtuple
# Recurrent state handed from one time step to the next:
#   c -- memory-cell symbol, h -- output (hidden) symbol.
LSTMState = namedtuple("LSTMState", ("c", "h"))

# Per-layer parameter symbols consumed by lstm().
#   i2h_weight / i2h_bias   -- input-to-gates fully-connected layer.
#   h2h_weight              -- hidden-to-gates layer (built with no_bias=True).
#   h2h_bias                -- declared by lstm_unroll but not read by lstm().
#   ph2h_weight             -- bias-free projection of h, used only when
#                              num_hidden_proj > 0.
#   c2i_bias / c2f_bias / c2o_bias -- elementwise (peephole) weights that
#                              lstm() broadcast-multiplies with the cell state.
LSTMParam = namedtuple(
    "LSTMParam",
    (
        "i2h_weight",
        "i2h_bias",
        "h2h_weight",
        "h2h_bias",
        "ph2h_weight",
        "c2i_bias",
        "c2f_bias",
        "c2o_bias",
    ),
)

# Bundle of executor / symbol / state handles for a bound model.
# NOTE(review): not constructed anywhere in this part of the file;
# presumably populated by calling code elsewhere -- confirm.
LSTMModel = namedtuple(
    "LSTMModel",
    (
        "rnn_exec",
        "symbol",
        "init_states",
        "last_states",
        "seq_data",
        "seq_labels",
        "seq_outputs",
        "param_blocks",
    ),
)
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., num_hidden_proj=0):
    """Build one time step of a peephole LSTM cell with optional output projection.

    Gate pre-activations are the sum of a biased affine map of the input and a
    bias-free affine map of the previous output ``h``. The sum is split into
    four chunks used, in order, as: input gate, candidate transform, forget
    gate, output gate. The input and forget gates also receive the previous
    cell state, and the output gate the updated cell state, each through an
    elementwise peephole weight (``param.c2i_bias``, ``param.c2f_bias``,
    ``param.c2o_bias``).

    Args:
        num_hidden: number of cell units; both gate layers emit ``4 * num_hidden``.
        indata: input symbol for this step.
        prev_state: ``LSTMState`` with the previous ``c`` and ``h`` symbols.
        param: ``LSTMParam`` with this layer's weight symbols.
        seqidx: time-step index, used only in operator names.
        layeridx: layer index, used only in operator names.
        dropout: when > 0, dropout probability applied to ``indata``.
        num_hidden_proj: when > 0, the output ``h`` is passed through a
            bias-free layer with ``param.ph2h_weight`` of this width.

    Returns:
        ``LSTMState`` holding the new cell state and the (possibly projected) output.
    """
    x = indata
    if dropout > 0.:
        x = mx.sym.Dropout(data=x, p=dropout)

    from_input = mx.sym.FullyConnected(data=x,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=num_hidden * 4,
                                       name="t%d_l%d_i2h" % (seqidx, layeridx))
    # The recurrent term carries no bias of its own (param.h2h_bias is
    # unused); the input-side bias above already covers the gate offsets.
    from_hidden = mx.sym.FullyConnected(data=prev_state.h,
                                        weight=param.h2h_weight,
                                        no_bias=True,
                                        num_hidden=num_hidden * 4,
                                        name="t%d_l%d_h2h" % (seqidx, layeridx))
    chunks = mx.sym.SliceChannel(from_input + from_hidden, num_outputs=4,
                                 name="t%d_l%d_slice" % (seqidx, layeridx))

    c_prev = prev_state.c
    # Input and forget gates peek at the *previous* cell state.
    i_pre = mx.sym.broadcast_mul(param.c2i_bias, c_prev) + chunks[0]
    i_gate = mx.sym.Activation(i_pre, act_type="sigmoid")
    candidate = mx.sym.Activation(chunks[1], act_type="tanh")
    f_pre = mx.sym.broadcast_mul(param.c2f_bias, c_prev) + chunks[2]
    f_gate = mx.sym.Activation(f_pre, act_type="sigmoid")
    c_new = (f_gate * c_prev) + (i_gate * candidate)

    # The output gate peeks at the *updated* cell state.
    o_pre = mx.sym.broadcast_mul(param.c2o_bias, c_new) + chunks[3]
    o_gate = mx.sym.Activation(o_pre, act_type="sigmoid")
    h_new = o_gate * mx.sym.Activation(c_new, act_type="tanh")

    if num_hidden_proj > 0:
        h_new = mx.sym.FullyConnected(data=h_new,
                                      weight=param.ph2h_weight,
                                      no_bias=True,
                                      num_hidden=num_hidden_proj,
                                      name="t%d_l%d_ph2h" % (seqidx, layeridx))
    return LSTMState(c=c_new, h=h_new)
def lstm_unroll(num_lstm_layer, seq_len, input_size,
                num_hidden, num_label, dropout=0., output_states=False, take_softmax=True, num_hidden_proj=0):
    """Unroll a stacked peephole LSTM (optionally projected) over ``seq_len`` steps.

    The graph reads the free variables ``data``, ``softmax_label`` and, for each
    layer ``i``, ``l<i>_init_c`` / ``l<i>_init_h``. ``data`` is split along axis 1
    into ``seq_len`` slices. Each slice goes through every layer via ``lstm()``.
    The top-layer outputs of all steps are then stacked into rows and fed to a
    shared ``pred`` classifier.

    Args:
        num_lstm_layer: number of stacked LSTM layers.
        seq_len: number of time steps to unroll.
        input_size: not used by this function; kept for interface compatibility.
        num_hidden: number of cell units per layer.
        num_label: number of output classes of the ``pred`` layer.
        dropout: when > 0, dropout probability applied to the input of every
            layer except the first and to the top-layer output at every step.
        output_states: when True, additionally output each layer's final ``c``
            and ``h`` wrapped in ``BlockGrad`` (named ``l<i>_last_c`` /
            ``l<i>_last_h``).
        take_softmax: when True, end in ``SoftmaxOutput`` with label 0 ignored;
            otherwise return the reshaped ``pred`` symbol directly.
        num_hidden_proj: when > 0, each layer's output is projected to this width.

    Returns:
        The output symbol. If ``output_states`` is set, a ``Group`` of
        [output] + final c of each layer + final h of each layer instead.
    """
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i),
                                     ph2h_weight=mx.sym.Variable("l%d_ph2h_weight" % i),
                                     c2i_bias=mx.sym.Variable("l%d_c2i_bias" % i, shape=(1, num_hidden)),
                                     c2f_bias=mx.sym.Variable("l%d_c2f_bias" % i, shape=(1, num_hidden)),
                                     c2o_bias=mx.sym.Variable("l%d_c2o_bias" % i, shape=(1, num_hidden))))
        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                                     h=mx.sym.Variable("l%d_init_h" % i)))
    assert len(last_states) == num_lstm_layer

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    dataSlice = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = dataSlice[seqidx]
        # stack LSTM
        for i in range(num_lstm_layer):
            # The first layer reads the raw input slice and gets no dropout.
            dp = 0. if i == 0 else dropout
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dp, num_hidden_proj=num_hidden_proj)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=1)
    # Width of each per-step output: the projection size when projecting,
    # otherwise the cell size.
    out_dim = num_hidden_proj if num_hidden_proj > 0 else num_hidden
    # Stack every time step as its own row. Under `shape`, -1 means "infer";
    # the deprecated `target_shape` spelled that as 0, which under `shape`
    # would instead mean "copy the input dimension".
    hidden_final = mx.sym.Reshape(hidden_concat, shape=(-1, out_dim))
    pred = mx.sym.FullyConnected(data=hidden_final, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')
    pred = mx.sym.Reshape(pred, shape=(-1, num_label))
    label = mx.sym.Reshape(label, shape=(-1,))
    if take_softmax:
        sm = mx.sym.SoftmaxOutput(data=pred, label=label, ignore_label=0,
                                  use_ignore=True, name='softmax')
    else:
        sm = pred

    if not output_states:
        return sm

    # Block the gradients of the output states; they are exposed only so a
    # caller can copy them over between truncated-BPTT segments.
    last_states = [LSTMState(c=mx.sym.BlockGrad(state.c, name="l%d_last_c" % i),
                             h=mx.sym.BlockGrad(state.h, name="l%d_last_h" % i))
                   for i, state in enumerate(last_states)]
    unpack_c = [state.c for state in last_states]
    unpack_h = [state.h for state in last_states]
    return mx.sym.Group([sm] + unpack_c + unpack_h)