|  | # Licensed to the Apache Software Foundation (ASF) under one | 
|  | # or more contributor license agreements.  See the NOTICE file | 
|  | # distributed with this work for additional information | 
|  | # regarding copyright ownership.  The ASF licenses this file | 
|  | # to you under the Apache License, Version 2.0 (the | 
|  | # "License"); you may not use this file except in compliance | 
|  | # with the License.  You may obtain a copy of the License at | 
|  | # | 
|  | #   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | # | 
|  | # Unless required by applicable law or agreed to in writing, | 
|  | # software distributed under the License is distributed on an | 
|  | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | # KIND, either express or implied.  See the License for the | 
|  | # specific language governing permissions and limitations | 
|  | # under the License. | 
|  |  | 
|  | # pylint: disable=missing-docstring | 
|  | from __future__ import print_function | 
|  |  | 
|  | from collections import namedtuple | 
|  |  | 
|  | import mxnet as mx | 
|  | from nce import nce_loss | 
|  |  | 
# Pair of symbols carried from one LSTM step to the next: ``c`` and ``h``.
LSTMState = namedtuple("LSTMState", "c h")
# Weight/bias symbols of one LSTM layer: input-to-hidden (i2h) and
# hidden-to-hidden (h2h) projections.
LSTMParam = namedtuple("LSTMParam", "i2h_weight i2h_bias h2h_weight h2h_bias")
|  |  | 
|  |  | 
def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """Build the symbol graph of one LSTM cell for one time step and layer.

    Parameters
    ----------
    num_hidden : int
        Width of each gate; both projections produce ``num_hidden * 4``
        units that are then split into four equal slices.
    indata : Symbol
        Input to the cell for this step.
    prev_state : LSTMState
        ``c`` and ``h`` symbols from the previous step of the same layer.
    param : LSTMParam
        Weight and bias symbols used by the two projections.
    seqidx, layeridx : int
        Time-step and layer indices; only used to name the operators
        (``t<seqidx>_l<layeridx>_i2h`` / ``_h2h`` / ``_slice``).
    dropout : float, optional
        When greater than zero, ``indata`` goes through ``Dropout`` with this
        probability before the input projection.

    Returns
    -------
    LSTMState
        The new ``c`` and ``h`` symbols.
    """
    prefix = "t%d_l%d" % (seqidx, layeridx)
    gate_width = num_hidden * 4

    cell_input = indata
    if dropout > 0.:
        cell_input = mx.sym.Dropout(data=indata, p=dropout)

    from_input = mx.sym.FullyConnected(data=cell_input,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=gate_width,
                                       name=prefix + "_i2h")
    from_hidden = mx.sym.FullyConnected(data=prev_state.h,
                                        weight=param.h2h_weight,
                                        bias=param.h2h_bias,
                                        num_hidden=gate_width,
                                        name=prefix + "_h2h")
    pre_activation = from_input + from_hidden
    pieces = mx.sym.SliceChannel(pre_activation, num_outputs=4,
                                 name=prefix + "_slice")

    # Slice order is: input gate, candidate transform, forget gate, output gate.
    # The activations are created in that same order.
    activations = ("sigmoid", "tanh", "sigmoid", "sigmoid")
    in_gate, in_transform, forget_gate, out_gate = [
        mx.sym.Activation(pieces[idx], act_type=kind)
        for idx, kind in enumerate(activations)
    ]

    kept = forget_gate * prev_state.c
    written = in_gate * in_transform
    cell = kept + written
    squashed_cell = mx.sym.Activation(cell, act_type="tanh")
    return LSTMState(c=cell, h=out_gate * squashed_cell)
|  |  | 
|  |  | 
def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden, num_embed=100):
    """Build an unrolled multi-layer LSTM symbol with one NCE loss per step.

    The graph reads the variables ``data``, ``label`` and ``label_weight``,
    the embedding weights ``embed_weight`` and ``label_embed_weight``, and,
    for every layer ``i``, ``l<i>_init_c`` / ``l<i>_init_h`` plus the four
    ``l<i>_{i2h,h2h}_{weight,bias}`` parameters.

    Parameters
    ----------
    vocab_size : int
        ``input_dim`` of the input embedding; also forwarded to ``nce_loss``.
    seq_len : int
        Number of time steps to unroll. ``data`` (after embedding), ``label``
        and ``label_weight`` are each split into ``seq_len`` slices.
    num_lstm_layer : int
        Number of stacked LSTM layers; parameters are shared across time
        steps within a layer.
    num_hidden : int
        Hidden width of every LSTM layer. Also forwarded to ``nce_loss`` so
        the loss is configured for the width of the LSTM output it receives.
    num_embed : int, optional
        ``output_dim`` of the input embedding. Defaults to 100, the value
        that used to be hard-coded.

    Returns
    -------
    Symbol
        ``mx.sym.Group`` of the ``seq_len`` per-step ``nce_loss`` outputs.
    """
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                                     h=mx.sym.Variable("l%d_init_h" % i)))

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')

    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                  weight=embed_weight,
                                  output_dim=num_embed, name='data_embed')

    # One slice per time step for the inputs, the labels and their weights.
    datavec = mx.sym.SliceChannel(data=data_embed,
                                  num_outputs=seq_len,
                                  squeeze_axis=True, name='data_slice')
    labelvec = mx.sym.SliceChannel(data=label,
                                   num_outputs=seq_len,
                                   squeeze_axis=True, name='label_slice')
    labelweightvec = mx.sym.SliceChannel(data=label_weight,
                                         num_outputs=seq_len,
                                         squeeze_axis=True, name='label_weight_slice')

    probs = []
    for seqidx in range(seq_len):
        hidden = datavec[seqidx]

        # Feed the step input up through the layer stack, carrying each
        # layer's state over to the next time step.
        for i in range(num_lstm_layer):
            next_state = _lstm(num_hidden, indata=hidden,
                               prev_state=last_states[i],
                               param=param_cells[i],
                               seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state

        # ``hidden`` has width ``num_hidden``, so the loss must be told that
        # width; a fixed 100 here only worked when num_hidden == 100.
        probs.append(nce_loss(data=hidden,
                              label=labelvec[seqidx],
                              label_weight=labelweightvec[seqidx],
                              embed_weight=label_embed_weight,
                              vocab_size=vocab_size,
                              num_hidden=num_hidden))
    return mx.sym.Group(probs)