blob: 5ee86d0d99e5a5c06708d38364db083e90a179b0 [file]
import mxnet as mx
import numpy as np
def test_rnn():
cell = mx.rnn.RNNCell(100, prefix='rnn_')
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
def test_lstm():
cell = mx.rnn.LSTMCell(100, prefix='rnn_')
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
def test_gru():
cell = mx.rnn.GRUCell(100, prefix='rnn_')
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
def test_stack():
cell = mx.rnn.SequentialRNNCell()
for i in range(5):
cell.add(mx.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
keys = sorted(cell.params._params.keys())
for i in range(5):
assert 'rnn_stack%d_h2h_weight'%i in keys
assert 'rnn_stack%d_h2h_bias'%i in keys
assert 'rnn_stack%d_i2h_weight'%i in keys
assert 'rnn_stack%d_i2h_bias'%i in keys
assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
def test_bidirectional():
cell = mx.rnn.BidirectionalCell(
mx.rnn.LSTMCell(100, prefix='rnn_l0_'),
mx.rnn.LSTMCell(100, prefix='rnn_r0_'),
output_prefix='rnn_bi_')
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
assert outputs.list_outputs() == ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 200), (10, 200), (10, 200)]
def test_unfuse():
cell = mx.rnn.FusedRNNCell(100, num_layers=1, mode='lstm',
prefix='test_', bidirectional=True).unfuse()
outputs, _ = cell.unroll(3, input_prefix='rnn_')
outputs = mx.sym.Group(outputs)
assert outputs.list_outputs() == ['test_bi_lstm_0t0_output', 'test_bi_lstm_0t1_output', 'test_bi_lstm_0t2_output']
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 200), (10, 200), (10, 200)]
if __name__ == '__main__':
test_rnn()
test_lstm()
test_gru()
test_stack()
test_bidirectional()
test_unfuse()