blob: a253e862fcce6a130bba7b06e90cc3ab75d27a05 [file] [log] [blame]
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
from lstm import LSTMState, LSTMParam, lstm, bi_lstm_inference_symbol
class BiLSTMInferenceModel(object):
    """Stateful inference wrapper around ``bi_lstm_inference_symbol``.

    The symbol is bound once with a batch size of 1. After each call to
    :meth:`forward`, the executor outputs following the first one are copied
    back into the ``l0_init_c`` / ``l0_init_h`` / ``l1_init_c`` /
    ``l1_init_h`` input arrays, so state carries over between calls until it
    is reset with ``new_seq=True``.
    """

    def __init__(self,
                 seq_len,
                 input_size,
                 num_hidden,
                 num_embed,
                 num_label,
                 arg_params,
                 ctx=mx.cpu(),
                 dropout=0.):
        """Build the inference symbol, bind it and load the given weights.

        Parameters
        ----------
        seq_len, input_size, num_hidden, num_embed, num_label, dropout
            Forwarded to ``bi_lstm_inference_symbol``. ``seq_len`` and
            ``num_hidden`` also determine the shapes used for binding.
        arg_params : mapping of str -> array
            Weights keyed by argument name. Entries whose key matches an
            executor argument are copied into it; other executor arguments
            are left as bound.
        ctx : mxnet context
            Device on which the executor is bound.
        """
        self.sym = bi_lstm_inference_symbol(input_size, seq_len,
                                            num_hidden,
                                            num_embed,
                                            num_label,
                                            dropout)
        batch_size = 1

        # Names of the state inputs, interleaved as c, h for indices 0 and 1.
        # This order is what the executor outputs are paired with below.
        state_names = []
        for i in range(2):
            state_names.append("l%d_init_c" % i)
            state_names.append("l%d_init_h" % i)

        data_shape = (batch_size, seq_len)
        input_shapes = dict((name, (batch_size, num_hidden))
                            for name in state_names)
        input_shapes["data"] = data_shape

        # Bind on the caller-supplied context; this used to be hard-coded to
        # mx.cpu(), which silently ignored the ``ctx`` argument.
        self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes)

        # Load whichever provided parameters the bound executor knows about.
        for key, target in self.executor.arg_dict.items():
            if key in arg_params:
                arg_params[key].copyto(target)

        # Outputs after the first are paired, in order, with the state names.
        self.states_dict = dict(zip(state_names, self.executor.outputs[1:]))
        # Not used by this class itself; kept because callers may rely on it.
        self.input_arr = mx.nd.zeros(data_shape)

    def forward(self, input_data, new_seq=False):
        """Run one forward pass and return the first output as a numpy array.

        Parameters
        ----------
        input_data
            Array copied into the executor's ``"data"`` argument via its
            ``copyto`` method (presumably an ``mx.nd.NDArray`` matching the
            bound data shape -- confirm against callers).
        new_seq : bool
            When true, all state inputs are zeroed before the pass.

        Returns
        -------
        The result of ``asnumpy()`` on the executor's first output.
        """
        if new_seq:
            # Start from a clean state for a new sequence.
            for key in self.states_dict:
                self.executor.arg_dict[key][:] = 0.
        input_data.copyto(self.executor.arg_dict["data"])
        self.executor.forward()
        # Feed the produced states back in as inputs for the next call.
        for key, state_out in self.states_dict.items():
            state_out.copyto(self.executor.arg_dict[key])
        return self.executor.outputs[0].asnumpy()