blob: 61687124266d4ef66a60977464aa6863c7af9aa0 [file] [log] [blame]
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
"h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
"init_states", "last_states", "forward_state", "backward_state",
"seq_data", "seq_labels", "seq_outputs",
"param_blocks"])
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
"""LSTM Cell symbol"""
if dropout > 0.:
indata = mx.sym.Dropout(data=indata, p=dropout)
i2h = mx.sym.FullyConnected(data=indata,
weight=param.i2h_weight,
bias=param.i2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_i2h" % (seqidx, layeridx))
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.h2h_weight,
bias=param.h2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_h2h" % (seqidx, layeridx))
gates = i2h + h2h
slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
name="t%d_l%d_slice" % (seqidx, layeridx))
in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
return LSTMState(c=next_c, h=next_h)
def bi_lstm_unroll(seq_len, input_size,
num_hidden, num_embed, num_label, dropout=0.):
embed_weight = mx.sym.Variable("embed_weight")
cls_weight = mx.sym.Variable("cls_weight")
cls_bias = mx.sym.Variable("cls_bias")
last_states = []
last_states.append(LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")))
last_states.append(LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h")))
forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
i2h_bias=mx.sym.Variable("l0_i2h_bias"),
h2h_weight=mx.sym.Variable("l0_h2h_weight"),
h2h_bias=mx.sym.Variable("l0_h2h_bias"))
backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
i2h_bias=mx.sym.Variable("l1_i2h_bias"),
h2h_weight=mx.sym.Variable("l1_h2h_weight"),
h2h_bias=mx.sym.Variable("l1_h2h_bias"))
# embeding layer
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')
embed = mx.sym.Embedding(data=data, input_dim=input_size,
weight=embed_weight, output_dim=num_embed, name='embed')
wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
forward_hidden = []
for seqidx in range(seq_len):
hidden = wordvec[seqidx]
next_state = lstm(num_hidden, indata=hidden,
prev_state=last_states[0],
param=forward_param,
seqidx=seqidx, layeridx=0, dropout=dropout)
hidden = next_state.h
last_states[0] = next_state
forward_hidden.append(hidden)
backward_hidden = []
for seqidx in range(seq_len):
k = seq_len - seqidx - 1
hidden = wordvec[k]
next_state = lstm(num_hidden, indata=hidden,
prev_state=last_states[1],
param=backward_param,
seqidx=k, layeridx=1,dropout=dropout)
hidden = next_state.h
last_states[1] = next_state
backward_hidden.insert(0, hidden)
hidden_all = []
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
weight=cls_weight, bias=cls_bias, name='pred')
label = mx.sym.transpose(data=label)
label = mx.sym.Reshape(data=label, target_shape=(0,))
sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
return sm
def bi_lstm_inference_symbol(input_size, seq_len,
num_hidden, num_embed, num_label, dropout=0.):
seqidx = 0
embed_weight=mx.sym.Variable("embed_weight")
cls_weight = mx.sym.Variable("cls_weight")
cls_bias = mx.sym.Variable("cls_bias")
last_states = [LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")),
LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h"))]
forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
i2h_bias=mx.sym.Variable("l0_i2h_bias"),
h2h_weight=mx.sym.Variable("l0_h2h_weight"),
h2h_bias=mx.sym.Variable("l0_h2h_bias"))
backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
i2h_bias=mx.sym.Variable("l1_i2h_bias"),
h2h_weight=mx.sym.Variable("l1_h2h_weight"),
h2h_bias=mx.sym.Variable("l1_h2h_bias"))
data = mx.sym.Variable("data")
embed = mx.sym.Embedding(data=data, input_dim=input_size,
weight=embed_weight, output_dim=num_embed, name='embed')
wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
forward_hidden = []
for seqidx in range(seq_len):
next_state = lstm(num_hidden, indata=wordvec[seqidx],
prev_state=last_states[0],
param=forward_param,
seqidx=seqidx, layeridx=0, dropout=0.0)
hidden = next_state.h
last_states[0] = next_state
forward_hidden.append(hidden)
backward_hidden = []
for seqidx in range(seq_len):
k = seq_len - seqidx - 1
next_state = lstm(num_hidden, indata=wordvec[k],
prev_state=last_states[1],
param=backward_param,
seqidx=k, layeridx=1, dropout=0.0)
hidden = next_state.h
last_states[1] = next_state
backward_hidden.insert(0, hidden)
hidden_all = []
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
weight=cls_weight, bias=cls_bias, name='pred')
sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
output = [sm]
for state in last_states:
output.append(state.c)
output.append(state.h)
return mx.sym.Group(output)