# blob: ec6e700a854a8ec01bc306fc01b618fdf384733f
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn, rnn
class RNNModel(gluon.HybridBlock):
    """A model with an encoder, recurrent layer, and a decoder.

    Inputs are embedded by the encoder, passed through dropout, fed to the
    recurrent layer, passed through the same dropout layer again, and finally
    projected to ``vocab_size`` outputs by a dense decoder.

    Parameters
    ----------
    mode : str
        Recurrent layer to build: ``'rnn_relu'``, ``'rnn_tanh'``, ``'lstm'``
        or ``'gru'``.
    vocab_size : int
        Size of the embedding table and number of decoder output units.
    num_embed : int
        Embedding dimension; also the ``input_size`` of the recurrent layer.
    num_hidden : int
        Hidden size of the recurrent layer and ``in_units`` of the decoder.
    num_layers : int
        Number of stacked recurrent layers.
    dropout : float, default 0.5
        Dropout rate, used both by the standalone dropout layer and by the
        recurrent layer itself.
    tie_weights : bool, default False
        If True, the decoder is created with the encoder's parameter
        dictionary so that both layers share parameters. Requires
        ``num_hidden == num_embed``.
    **kwargs
        Forwarded unchanged to ``gluon.HybridBlock.__init__``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported options, or if
        ``tie_weights`` is requested while ``num_hidden != num_embed``.
    """

    def __init__(self, mode, vocab_size, num_embed, num_hidden,
                 num_layers, dropout=0.5, tie_weights=False, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        with self.name_scope():
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(vocab_size, num_embed,
                                        weight_initializer=mx.init.Uniform(0.1))

            if mode == 'rnn_relu':
                # 'relu' is spelled out (rather than relying on the layer's
                # default) so this branch mirrors the 'rnn_tanh' one.
                self.rnn = rnn.RNN(num_hidden, num_layers, 'relu',
                                   dropout=dropout, input_size=num_embed)
            elif mode == 'rnn_tanh':
                self.rnn = rnn.RNN(num_hidden, num_layers, 'tanh',
                                   dropout=dropout, input_size=num_embed)
            elif mode == 'lstm':
                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
                                    input_size=num_embed)
            elif mode == 'gru':
                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            else:
                # Wrap mode in a 1-tuple so that a tuple-valued mode still
                # yields this ValueError instead of a formatting TypeError.
                raise ValueError("Invalid mode %s. Options are rnn_relu, "
                                 "rnn_tanh, lstm, and gru" % (mode,))

            if tie_weights:
                # The decoder reuses the encoder's parameters, so its
                # (vocab_size, num_hidden) weight must have the same shape as
                # the (vocab_size, num_embed) embedding table. Fail early with
                # a clear message instead of a parameter-shape error later.
                if num_hidden != num_embed:
                    raise ValueError(
                        "When tie_weights is True, num_hidden (%s) must be "
                        "equal to num_embed (%s)" % (num_hidden, num_embed))
                self.decoder = nn.Dense(vocab_size, in_units=num_hidden,
                                        params=self.encoder.params)
            else:
                self.decoder = nn.Dense(vocab_size, in_units=num_hidden)

            # Kept for the reshape performed in hybrid_forward.
            self.num_hidden = num_hidden

    def hybrid_forward(self, F, inputs, hidden):
        """Run the encoder, recurrent layer and decoder.

        Parameters
        ----------
        F
            Operator namespace supplied by Gluon; not used directly here.
        inputs
            Input passed to the embedding encoder.
        hidden
            Recurrent state passed to the recurrent layer.

        Returns
        -------
        tuple
            ``(decoded, hidden)`` where ``decoded`` is the decoder output for
            the recurrent output flattened to rows of ``num_hidden`` features,
            and ``hidden`` is the state returned by the recurrent layer.
        """
        emb = self.drop(self.encoder(inputs))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Collapse all leading axes so the dense decoder sees a 2-D input.
        decoded = self.decoder(output.reshape((-1, self.num_hidden)))
        return decoded, hidden

    def begin_state(self, *args, **kwargs):
        """Return the initial state of the recurrent layer.

        All arguments are forwarded unchanged to ``self.rnn.begin_state``.
        """
        return self.rnn.begin_state(*args, **kwargs)