# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
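"""Unit tests for ``mxnet.gluon.rnn`` cells and fused RNN layers."""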
import mxnet as mx
from mxnet import gluon, np
import numpy as _np
import copy
from functools import partial
from numpy.testing import assert_allclose
from mxnet.gluon.utils import split_rnn_params
import pytest
from mxnet.test_utils import almost_equal, assert_almost_equal, default_device
from common import assert_raises_cudnn_not_satisfied, retry


def check_rnn_states(fused_states, stack_states, num_layers, bidirectional=False, is_lstm=True):
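    """Compare the output states of a fused RNN layer against those of an
    equivalent stack of RNN cells, layer by layer and direction by direction."""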
    directions = 2 if bidirectional else 1
    assert len(stack_states) / len(fused_states) == num_layers * directions

    fused_states = [state.asnumpy() for state in fused_states]
    stack_states = [_np.expand_dims(state.asnumpy(), axis=0) for state in stack_states]
    if is_lstm:
        stack_states_h = stack_states[0::2]
        stack_states_c = stack_states[1::2]
        stack_states = [_np.concatenate(stack_states_h, axis=0), _np.concatenate(stack_states_c, axis=0)]
    else:
        stack_states = [_np.concatenate(stack_states, axis=0)]

    for f, s in zip(fused_states, stack_states):
        assert f.shape == s.shape
        assert_almost_equal(f, s, atol=1e-4, rtol=1e-4)


@mx.util.use_np
def test_rnn():
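    """RNNCell: a three-step unroll should expose the four i2h/h2h parameters
    and produce one (batch, hidden) output per step."""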
    cell = gluon.rnn.RNNCell(100)
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)
    assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight',
                                                    'i2h_bias', 'i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]


@mx.util.use_np
def test_lstm():
    cell = gluon.rnn.LSTMCell(100)
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)
    assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]


@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
@pytest.mark.serial
def test_lstmp():
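    """Fused LSTM with projection vs. a stack of LSTMPCells: once both share
    the same weights, their outputs and states should agree."""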
    hidden_size, projection_size = 512, 256
    rtol, atol = 1e-4, 1e-4
    batch_size, seq_len = 5, 3
    input_size = 128
    lstm_input = mx.np.random.uniform(size=(seq_len, batch_size, input_size))

    # ==== Unidirectional Layer ====
    for num_layers in [1, 3]:
        fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
                                     num_layers=num_layers, layout='TNC', bidirectional=False)
        stack_layer = mx.gluon.rnn.HybridSequentialRNNCell()
        for _ in range(num_layers):
            stack_layer.add(gluon.rnn.LSTMPCell(hidden_size,
                                                projection_size=projection_size))
        fused_layer.initialize()

        fused_begin_state = fused_layer.begin_state(batch_size)
        stack_begin_state = stack_layer.begin_state(batch_size=batch_size)
        fused_layer.infer_shape(lstm_input, fused_begin_state)
        stack_layer.infer_shape(0, lstm_input, False)
        stack_layer.initialize()

        fused_layer_params = fused_layer.collect_params()
        stack_layer_params = stack_layer.collect_params()

        fused_weight_shape = fused_layer_params['rnn_param'].shape
        print(fused_weight_shape)
        w = mx.np.zeros(shape=fused_weight_shape)
        fused_layer_params['rnn_param'].set_data(w)
        fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,
                                                    hidden_size, False, projection_size=projection_size)
        for name, value in fused_layer_params_split.items():
            stack_layer_params[name[1:].replace('_', '.', 1)].set_data(value)

        fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state)
        stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(),
                                                        begin_state=stack_begin_state,
                                                        layout='TNC',
                                                        merge_outputs=True)

        assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol)
        check_rnn_states(fused_states, stack_states, num_layers, False)

    # ==== Bidirectional Layer ====
    for num_layers in [1, 3]:
        fused_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
                                     num_layers=num_layers, layout='TNC', bidirectional=True)
        stack_layer = mx.gluon.rnn.HybridSequentialRNNCell()
        for _ in range(num_layers):
            stack_layer.add(
                gluon.rnn.BidirectionalCell(gluon.rnn.LSTMPCell(hidden_size,
                                                                projection_size=projection_size),
                                            gluon.rnn.LSTMPCell(hidden_size,
                                                                projection_size=projection_size)))
        fused_layer.initialize()
        stack_layer.initialize()

        fused_begin_state = fused_layer.begin_state(batch_size)
        stack_begin_state = stack_layer.begin_state(batch_size=batch_size)
        fused_layer.infer_shape(lstm_input, fused_begin_state)
        stack_layer.infer_shape(0, lstm_input, False)
        stack_layer.initialize()

        fused_layer_params = fused_layer.collect_params()
        stack_layer_params = stack_layer.collect_params()

        fused_weight_shape = fused_layer_params['rnn_param'].shape
        w = mx.np.zeros(shape=fused_weight_shape)
        fused_layer_params['rnn_param'].set_data(w)
        fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,
                                                    hidden_size, True, projection_size=projection_size)
        for name, value in fused_layer_params_split.items():
            cur = name.split("_")[0]
            stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur) + 1:])].set_data(value)

        fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state)
        stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(),
                                                        begin_state=stack_begin_state,
                                                        layout='TNC',
                                                        merge_outputs=True)

        assert_almost_equal(fused_output.asnumpy(), stack_output.asnumpy(), rtol=rtol, atol=atol)
        check_rnn_states(fused_states, stack_states, num_layers, True)


@mx.util.use_np
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_lstm_cpu_inference():
    # should behave the same as lstm cell
    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
                                      [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
                                     [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
                                      [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
    x = mx.np.ones(shape=(2, 2, 2))
    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
    model.initialize(mx.init.One())
    y = model(x).asnumpy()

    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                      rtol=1e-3, atol=1e-5)


@mx.util.use_np
def test_gru():
    cell = gluon.rnn.GRUCell(100, activation='relu', recurrent_activation='tanh')
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)
    assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]


@mx.util.use_np
@pytest.mark.serial
def test_residual():
    cell = gluon.rnn.ResidualCell(gluon.rnn.GRUCell(50))
    inputs = [mx.np.ones((10, 50)) for i in range(2)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(2, inputs)

    params = cell.collect_params()
    assert sorted(params.keys()) == \
        ['base_cell.h2h_bias', 'base_cell.h2h_weight', 'base_cell.i2h_bias', 'base_cell.i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [(10, 50), (10, 50)]


@mx.util.use_np
@pytest.mark.serial
def test_residual_bidirectional():
    cell = gluon.rnn.ResidualCell(
        gluon.rnn.BidirectionalCell(
            gluon.rnn.GRUCell(25),
            gluon.rnn.GRUCell(25)))
    inputs = [mx.np.ones((10, 50)) for i in range(2)]
    cell.infer_shape(0, inputs[0], True)
    cell.initialize()
    outputs, _ = cell.unroll(2, inputs, merge_outputs=False)

    params = cell.collect_params()
    assert sorted(params.keys()) == \
        ['base_cell.l_cell.h2h_bias', 'base_cell.l_cell.h2h_weight',
         'base_cell.l_cell.i2h_bias', 'base_cell.l_cell.i2h_weight',
         'base_cell.r_cell.h2h_bias', 'base_cell.r_cell.h2h_weight',
         'base_cell.r_cell.i2h_bias', 'base_cell.r_cell.i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [(10, 50), (10, 50)]


@mx.util.use_np
def test_stack():
    cell = gluon.rnn.SequentialRNNCell()
    for i in range(5):
        if i == 1:
            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100)))
        else:
            cell.add(gluon.rnn.LSTMCell(100))
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)

    keys = sorted(cell.collect_params().keys())
    for i in range(5):
        if i == 1:
            continue
        assert '%d.h2h_weight' % i in keys
        assert '%d.h2h_bias' % i in keys
        assert '%d.i2h_weight' % i in keys
        assert '%d.i2h_bias' % i in keys
    assert '1.base_cell.h2h_weight' in keys
    assert '1.base_cell.h2h_bias' in keys
    assert '1.base_cell.i2h_weight' in keys
    assert '1.base_cell.i2h_bias' in keys

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]


@mx.util.use_np
@pytest.mark.serial
def test_hybridstack():
    cell = gluon.rnn.HybridSequentialRNNCell()
    for i in range(5):
        if i == 1:
            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100)))
        else:
            cell.add(gluon.rnn.LSTMCell(100))
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)

    keys = sorted(cell.collect_params().keys())
    for i in range(5):
        if i == 1:
            continue
        assert '%d.h2h_weight' % i in keys
        assert '%d.h2h_bias' % i in keys
        assert '%d.i2h_weight' % i in keys
        assert '%d.i2h_bias' % i in keys
    assert '1.base_cell.h2h_weight' in keys
    assert '1.base_cell.h2h_bias' in keys
    assert '1.base_cell.i2h_weight' in keys
    assert '1.base_cell.i2h_bias' in keys

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]

    # Test HybridSequentialRNNCell nested in nn.HybridBlock; SequentialRNNCell would fail in this case
    class BidirectionalOfSequential(gluon.HybridBlock):
        def __init__(self):
            super(BidirectionalOfSequential, self).__init__()
            cell0 = gluon.rnn.HybridSequentialRNNCell()
            cell0.add(gluon.rnn.LSTMCell(100))
            cell0.add(gluon.rnn.LSTMCell(100))

            cell1 = gluon.rnn.HybridSequentialRNNCell()
            cell1.add(gluon.rnn.LSTMCell(100))
            cell1.add(gluon.rnn.LSTMCell(100))

            self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1)

        def forward(self, x):
            return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)

        def infer_shape(self, x, *args):
            self.rnncell.infer_shape(0, x, True)

    x = mx.np.random.uniform(size=(10, 3, 100))
    net = BidirectionalOfSequential()
    net.infer_shape(x)
    net.initialize()
    outs, _ = net(x)

    assert outs.shape == (10, 3, 200)


@mx.util.use_np
def test_bidirectional():
    cell = gluon.rnn.BidirectionalCell(
        gluon.rnn.LSTMCell(100),
        gluon.rnn.LSTMCell(100))
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)

    outs = [o.shape for o in outputs]
    assert outs == [(10, 200), (10, 200), (10, 200)]


@mx.util.use_np
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@pytest.mark.serial
def test_layer_bidirectional():
    class RefBiLSTM(gluon.Block):
        def __init__(self, size, **kwargs):
            super(RefBiLSTM, self).__init__(**kwargs)
            self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False)
            self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False)

        def forward(self, inpt):
            fwd = self._lstm_fwd(inpt)
            bwd_inpt = np.flip(inpt, 0)
            bwd = self._lstm_bwd(bwd_inpt)
            bwd = np.flip(bwd, 0)
            return np.concatenate([fwd, bwd], axis=2)

    size = 7
    in_size = 5
    weights = {}
    for d in ['l', 'r']:
        weights['{}0_i2h_weight'.format(d)] = mx.np.random.uniform(size=(size * 4, in_size))
        weights['{}0_h2h_weight'.format(d)] = mx.np.random.uniform(size=(size * 4, size))
        weights['{}0_i2h_bias'.format(d)] = mx.np.random.uniform(size=(size * 4,))
        weights['{}0_h2h_bias'.format(d)] = mx.np.random.uniform(size=(size * 4,))

    # Build flattened parameter vectors as lists (not generators) so that
    # mx.np.concatenate always receives a sequence.
    params = [weights['{}0_{}_{}'.format(d, g, t)].reshape(-1)
              for t in ['weight', 'bias']
              for d in ['l', 'r']
              for g in ['i2h', 'h2h']]
    net_params_concat = mx.np.concatenate(params)
    params_left = [weights['l0_{}_{}'.format(g, t)].reshape(-1)
                   for t in ['weight', 'bias']
                   for g in ['i2h', 'h2h']]
    params_right = [weights['r0_{}_{}'.format(g, t)].reshape(-1)
                    for t in ['weight', 'bias']
                    for g in ['i2h', 'h2h']]
    net_ref_left_params = mx.np.concatenate(params_left)
    net_ref_right_params = mx.np.concatenate(params_right)

    net = gluon.rnn.LSTM(size, bidirectional=True)
    ref_net = RefBiLSTM(size)
    net.initialize()
    ref_net.initialize()
    net_params = net.collect_params()
    ref_net_params = ref_net.collect_params()
    net_params['rnn_param'].set_data(net_params_concat)
    ref_net_params['_lstm_fwd.rnn_param'].set_data(net_ref_left_params)
    ref_net_params['_lstm_bwd.rnn_param'].set_data(net_ref_right_params)

    data = mx.np.random.uniform(size=(11, 10, in_size))
    assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy(), rtol=1e-04, atol=1e-02)


def test_zoneout():
    cell = gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100), zoneout_outputs=0.5,
                                 zoneout_states=0.5)
    inputs = [mx.np.ones((10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)

    outs = [o.shape for o in outputs]
    assert outs == [(10, 100), (10, 100), (10, 100)]


@pytest.mark.serial
def test_unroll_layout():
    cell = gluon.rnn.HybridSequentialRNNCell()
    for i in range(5):
        if i == 1:
            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100)))
        else:
            cell.add(gluon.rnn.LSTMCell(100))
    inputs = [mx.np.random.uniform(size=(10, 50)) for _ in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()

    outputs, _ = cell.unroll(3, inputs, layout='TNC')
    assert outputs[0].shape == (10, 100)
    assert outputs[1].shape == (10, 100)
    assert outputs[2].shape == (10, 100)

    outputs, _ = cell.unroll(3, inputs, layout='NTC')
    assert outputs[0].shape == (10, 100)
    assert outputs[1].shape == (10, 100)
    assert outputs[2].shape == (10, 100)


def check_rnn_forward_backward(layer, merged_inputs, hybridize, merge_outputs, deterministic):
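    """Unroll ``layer`` twice (with merged and unmerged outputs), backprop both,
    and check that deterministic cells give identical outputs and input grads."""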
    input_size = 5
    if merged_inputs:
        inputs = mx.np.ones((8, 3, 5))
        inputs.attach_grad()
    else:
        inputs = [mx.np.ones((8, 5)) for _ in range(3)]
        for x in inputs:
            x.attach_grad()

    if hybridize:
        layer.hybridize()
    layer.initialize()
    with mx.autograd.record():
        out = layer.unroll(3, inputs, merge_outputs=merge_outputs)[0]
        mx.autograd.backward(out)

    if hasattr(layer, 'i2h_weight'):
        assert layer.i2h_weight.shape[1] == input_size, (layer.i2h_weight.shape[1], input_size)

    if merge_outputs:
        np_out = out.asnumpy()
    else:
        np_out = _np.stack([x.asnumpy() for x in out], axis=1)
    if merged_inputs:
        np_dx = inputs.grad.asnumpy()
    else:
        np_dx = _np.stack([x.grad.asnumpy() for x in inputs], axis=1)

    with mx.autograd.record():
        out = layer.unroll(3, inputs, merge_outputs=not merge_outputs)[0]
        mx.autograd.backward(out)

    if merged_inputs:
        input_grads = inputs.grad.asnumpy()
    else:
        input_grads = _np.stack([x.grad.asnumpy() for x in inputs], axis=1)

    if deterministic:
        if not merge_outputs:
            ref_np_out = out.asnumpy()
        else:
            ref_np_out = _np.stack([x.asnumpy() for x in out], axis=1)
        mx.test_utils.assert_almost_equal(np_out, ref_np_out, rtol=1e-3, atol=1e-5)
        mx.test_utils.assert_almost_equal(np_dx, input_grads, rtol=1e-3, atol=1e-5)


@retry(3)
@pytest.mark.parametrize('layer,determinism', [
    (gluon.rnn.LSTMCell(10, input_size=5), True),
    (gluon.rnn.RNNCell(10, input_size=5), True),
    (gluon.rnn.GRUCell(10, input_size=5), True),
    (gluon.rnn.BidirectionalCell(
        gluon.rnn.LSTMCell(10, input_size=5),
        gluon.rnn.LSTMCell(10, input_size=5)
    ), True),
    (gluon.rnn.DropoutCell(0.5), False),
])
@pytest.mark.parametrize('merged_inputs', [True, False])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('merge_outputs', [True, False, None])
@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18225')
def test_rnn_forward_backward(layer, merged_inputs, hybridize, merge_outputs, determinism):
    check_rnn_forward_backward(layer, merged_inputs, hybridize, merge_outputs, determinism)


@pytest.mark.parametrize('seq_rnn_type', [
    gluon.rnn.SequentialRNNCell,
    gluon.rnn.HybridSequentialRNNCell
])
@pytest.mark.parametrize('determinism', [True, False])
@pytest.mark.parametrize('merged_inputs', [True, False])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('merge_outputs', [True, False, None])
@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18291')
def test_sequential_rnn_cells(seq_rnn_type, determinism, merged_inputs, hybridize, merge_outputs):
    # Use the parametrized sequential-cell type rather than hard-coding one.
    net = seq_rnn_type()
    net.add(gluon.rnn.LSTMCell(10, input_size=5))
    net.add(gluon.rnn.RNNCell(10, input_size=10))
    net.add(gluon.rnn.GRUCell(10, input_size=10))
    if not determinism:
        net.add(gluon.rnn.DropoutCell(0.5))
    check_rnn_forward_backward(net, merged_inputs, hybridize, merge_outputs, determinism)


@mx.util.use_np
def test_rnn_cells_export_import():
    class RNNLayer(gluon.HybridBlock):
        def __init__(self):
            super(RNNLayer, self).__init__()
            self.cell = gluon.rnn.RNNCell(hidden_size=1)

        def forward(self, seq):
            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
            return outputs

        def infer_shape(self, input):
            self.cell.infer_shape(0, input, False)

    class LSTMLayer(gluon.HybridBlock):
        def __init__(self):
            super(LSTMLayer, self).__init__()
            self.cell = gluon.rnn.LSTMCell(hidden_size=1)

        def forward(self, seq):
            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
            return outputs

        def infer_shape(self, input):
            self.cell.infer_shape(0, input, False)

    class GRULayer(gluon.HybridBlock):
        def __init__(self):
            super(GRULayer, self).__init__()
            self.cell = gluon.rnn.GRUCell(hidden_size=1)

        def forward(self, seq):
            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
            return outputs

        def infer_shape(self, input):
            self.cell.infer_shape(0, input, False)

    for hybrid in [RNNLayer(), LSTMLayer(), GRULayer()]:
        input = mx.np.ones(shape=(1, 2, 1), device=mx.device.current_device())
        hybrid.infer_shape(input)
        hybrid.initialize()
        hybrid.hybridize()
        output1 = hybrid(input)
        hybrid.export(path="./model", epoch=0)
        symbol = mx.gluon.SymbolBlock.imports(
            symbol_file="./model-symbol.json",
            input_names=["data"],
            param_file="./model-0000.params",
            device=mx.device.current_device()
        )
        output2 = symbol(input)
        assert_almost_equal(output1.asnumpy(), output2.asnumpy())


def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, device=mx.cpu()):
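    """Run ``layer`` imperatively and hybridized, with and without explicit
    states, and check that outputs and input gradients match."""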
    layer.initialize(device=device)
    inputs = inputs.to_device(device)
    inputs.attach_grad()
    if states is not None:
        if isinstance(states, (list, tuple)):
            states = [s.to_device(device) for s in states]
        else:
            states = states.to_device(device)

    with mx.autograd.record():
        if states is None:
            out = layer(inputs)
        else:
            out = layer(inputs, states)
        if states is not None:
            assert isinstance(out, (list, tuple)) and len(out) == 2
            out = out[0]
        else:
            assert isinstance(out, mx.np.ndarray)
        out.backward()

    np_out = out.asnumpy()
    np_dx = inputs.grad.asnumpy()

    layer.hybridize()

    with mx.autograd.record():
        if states is not None:
            out = layer(inputs, states)
            assert isinstance(out, (list, tuple)) and len(out) == 2
            out = out[0]
        else:
            out = layer(inputs)
            assert isinstance(out, mx.np.ndarray)
        out.backward()

    if states is not None:
        layer(inputs, states)  # test is_training = false
    else:
        layer(inputs)

    if not run_only:
        mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)


@mx.util.use_np
def run_rnn_layers(dtype, dtype2):
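    """Exercise fused RNN/LSTM/GRU layers (uni- and bidirectional, with and
    without dropout) in ``dtype``, then build small nets that mix ``dtype``
    and ``dtype2`` through ``cast`` to cover mixed-precision paths."""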
    device = default_device()
    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype), device=device)
    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            mx.np.ones((4, 3, 10), dtype=dtype), device=device)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype), device=device)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype, bidirectional=True),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            [mx.np.ones((4, 3, 10), dtype=dtype), mx.np.ones((4, 3, 10), dtype=dtype)],
                            device=device)
    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype), device=device)
    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            mx.np.ones((4, 3, 10), dtype=dtype), device=device)

    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            run_only=True, device=device)
    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype), mx.np.ones((4, 3, 10), dtype=dtype),
                            run_only=True, device=device)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            run_only=True, device=device)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            [mx.np.ones((4, 3, 10), dtype=dtype), mx.np.ones((4, 3, 10), dtype=dtype)],
                            run_only=True, device=device)
    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype),
                            run_only=True, device=device)
    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
                            mx.np.ones((8, 3, 20), dtype=dtype), mx.np.ones((4, 3, 10), dtype=dtype),
                            run_only=True, device=device)

    net = gluon.nn.Sequential()
    net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
    net.add(gluon.nn.BatchNorm(axis=2))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(3, activation='relu'))
    net.initialize(device=device)
    net.cast(dtype)
    with mx.autograd.record():
        out = net(mx.np.ones((2, 3, 10), dtype=dtype, device=device))
        out.backward()
        out = out.asnumpy()

    net2 = gluon.nn.HybridSequential()
    net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
    net2.add(gluon.nn.BatchNorm(axis=2))
    net2.add(gluon.nn.Flatten())
    net2.add(gluon.nn.Dense(3, activation='relu'))
    net2.hybridize()
    net2.initialize(device=device)
    net2.cast(dtype)
    with mx.autograd.record():
        out = net2(mx.np.ones((2, 3, 10), dtype=dtype, device=device))
        out.backward()
        out = out.asnumpy()

    net3 = gluon.nn.HybridSequential()
    net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
    net3.add(gluon.nn.BatchNorm(axis=2))
    net3.add(gluon.nn.Flatten())
    net3.add(gluon.nn.Dense(3, activation='relu'))
    net3.hybridize()
    net3.initialize(device=device)
    net3.cast(dtype2)
    with mx.autograd.record():
        out = net3(mx.np.ones((2, 3, 10), dtype=dtype2, device=device))
        out.backward()
        out = out.asnumpy()


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@pytest.mark.serial
def test_rnn_layers_fp32():
    run_rnn_layers('float32', 'float32')


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@pytest.mark.serial
def test_rnn_layers_fp16():
    # Dynamic skip condition is best handled this way, rather than with pytest.mark.skipIf
    if default_device().device_type == 'cpu':
        pytest.skip('RNN FP16 only implemented for GPU for now')
    run_rnn_layers('float16', 'float32')


def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4):
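    """Load identical weights into a fused RNN layer and a stack of cells, then
    compare outputs, input gradients, weight gradients, and final states."""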
    x = mx.np.random.normal(size=(1, 5, input_size))
    fused_begin_state = fused_layer.begin_state(1)
    stack_states = stack_layer.begin_state(batch_size=1)
    fused_layer.infer_shape(x, fused_begin_state)
    fused_layer_params = fused_layer.collect_params()
    stack_layer.infer_shape(0, x, False)
    stack_layer.initialize()
    stack_layer_params = stack_layer.collect_params()

    fused_weight_shape = fused_layer_params['rnn_param'].shape
    w = mx.np.zeros(shape=fused_weight_shape)
    fused_layer_params_split = split_rnn_params(w, mode, num_layers, input_size, hidden_size, bidirectional)
    for name, value in fused_layer_params_split.items():
        if 'bias' in name:
            fused_layer_params_split[name] = mx.np.random.normal(size=value.shape)

    _dir = 2 if bidirectional else 1
    # Flatten the per-layer parameters in the order expected by the fused layer;
    # a list (rather than a generator) guarantees concatenate gets a sequence.
    params = [fused_layer_params_split['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
              for t in ['weight', 'bias']
              for l in range(num_layers)
              for d in ['l', 'r'][:_dir]
              for g in ['i2h', 'h2h']]
    fused_params = mx.np.concatenate(params)
    fused_layer_params['rnn_param'].set_data(fused_params)

    for name, value in fused_layer_params_split.items():
        cur = name.split('_')[0]
        num = cur[1:]
        stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.') + name[len(cur) + 1:]
        stack_layer_params[stack_name].set_data(value)

    fx = x.copy()
    sx = x.copy()
    y = mx.np.random.uniform(size=(1, 5, hidden_size * 2 if bidirectional else hidden_size))

    fx.attach_grad()
    with mx.autograd.record():
        fused_out, fused_states = fused_layer(fx, fused_begin_state)
        l = loss(fused_out, y).mean()
        l.backward()
    mx.npx.waitall()
    fused_grads = split_rnn_params(fused_layer.collect_params()['rnn_param'].data().grad,
                                   mode, num_layers, input_size, hidden_size, bidirectional)
    fused_input_grad = fx.grad.asnumpy()

    sx.attach_grad()
    with mx.autograd.record():
        stack_out, stack_states = stack_layer.unroll(5, sx, begin_state=stack_states, merge_outputs=True)
        l = loss(stack_out, y).mean()
        l.backward()
    mx.npx.waitall()
    stack_grads = {name: p.grad() for name, p in stack_layer.collect_params().items()}
    stack_input_grad = sx.grad.asnumpy()

    assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol)
    assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
    for name, value in fused_grads.items():
        cur = name.split('_')[0]
        num = cur[1:]
        stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.') + name[len(cur) + 1:]
        assert_allclose(value.asnumpy(), stack_grads[stack_name].asnumpy(), rtol=rtol, atol=atol)

    num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1)
    check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2)


def create_op_by_mode(mode):
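    """Map a mode string to the matching fused layer class, cell class, and
    parameter-name prefix."""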
    if mode == 'lstm':
        fused_op = gluon.rnn.LSTM
        stack_op = gluon.rnn.LSTMCell
        recurrent_block_prefix = 'lstm0_'
    elif mode == 'gru':
        fused_op = gluon.rnn.GRU
        stack_op = gluon.rnn.GRUCell
        recurrent_block_prefix = 'gru0_'
    elif mode == 'rnn_relu':
        fused_op = partial(gluon.rnn.RNN, activation='relu')
        stack_op = partial(gluon.rnn.RNNCell, activation='relu')
        recurrent_block_prefix = 'rnn0_'
    elif mode == 'rnn_tanh':
        fused_op = partial(gluon.rnn.RNN, activation='tanh')
        stack_op = partial(gluon.rnn.RNNCell, activation='tanh')
        recurrent_block_prefix = 'rnn0_'
    return fused_op, stack_op, recurrent_block_prefix


def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
    fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False)
    fused_layer.initialize()

    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell()
    for _ in range(num_layers):
        stack_layer.add(stack_op(hidden_size))
    stack_layer.initialize()

    check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size)


def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
    fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True)
    fused_layer.initialize()

    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell()
    for _ in range(num_layers):
        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size),
                                                    stack_op(hidden_size)))
    stack_layer.initialize()

    check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=True)


@mx.util.use_np
@pytest.mark.parametrize('input_size', [8])
@pytest.mark.parametrize('hidden_size', [8, 16])
@pytest.mark.parametrize('num_layers', [1, 2, 3, 4])
@pytest.mark.parametrize('func', [check_rnn_unidir_layer_gradients,
                                  check_rnn_bidir_layer_gradients])
@pytest.mark.parametrize('mode', ['lstm', 'gru', 'rnn_relu', 'rnn_tanh'])
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_fused_layer(input_size, hidden_size, num_layers, func, mode):
    loss = mx.gluon.loss.L2Loss()
    func(mode, input_size, hidden_size, num_layers, loss)


@pytest.mark.serial
def test_rnn_unroll_variant_length():
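    """Unroll various cells with per-sample valid_length and compare against an
    explicit per-sequence unroll; padded positions must be all zero."""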
    # Test for imperative usage
    cell_list = []
    for base_cell_class in [gluon.rnn.RNNCell, gluon.rnn.LSTMCell, gluon.rnn.GRUCell]:
        cell_list.append(base_cell_class(20))
        cell_list.append(gluon.rnn.BidirectionalCell(
            l_cell=base_cell_class(20),
            r_cell=base_cell_class(20)))
        cell_list.append(gluon.rnn.VariationalDropoutCell(base_cell=base_cell_class(20)))
    stack_res_rnn_cell = gluon.rnn.SequentialRNNCell()
    stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
    stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
    cell_list.append(stack_res_rnn_cell)

    batch_size = 4
    max_length = 10
    valid_length = [3, 10, 5, 6]
    valid_length_nd = mx.np.array(valid_length)
    for cell in cell_list:
        # Test for NTC layout
        print(type(cell))
        data_nd = mx.np.random.normal(0, 1, size=(batch_size, max_length, 20))
        # Only a bare BidirectionalCell takes the bidirectional flag here;
        # sequential containers and plain cells use False.
        cell.infer_shape(0, data_nd, isinstance(cell, gluon.rnn.BidirectionalCell))
        cell.initialize()
        cell.hybridize()
        outs, states = cell.unroll(length=max_length, inputs=data_nd,
                                   valid_length=valid_length_nd,
                                   merge_outputs=True,
                                   layout='NTC')
        for i, ele_length in enumerate(valid_length):
            # Explicitly unroll each sequence and compare the final states and output
            ele_out, ele_states = cell.unroll(length=ele_length,
                                              inputs=data_nd[i:(i+1), :ele_length, :],
                                              merge_outputs=True,
                                              layout='NTC')
            assert_almost_equal(ele_out, outs[i:(i+1), :ele_length, :])
            if ele_length < max_length:
                # Check that the padded outputs are all zero
                assert_almost_equal(outs[i:(i+1), ele_length:max_length, :], 0)
            for valid_out_state, gt_state in zip(states, ele_states):
                assert_almost_equal(valid_out_state[i:(i+1)], gt_state)

        # Test for TNC layout
        data_nd = mx.np.random.normal(0, 1, size=(max_length, batch_size, 20))
        outs, states = cell.unroll(length=max_length, inputs=data_nd,
                                   valid_length=valid_length_nd,
                                   layout='TNC')
        for i, ele_length in enumerate(valid_length):
            # Explicitly unroll each sequence and compare the final states and output
            ele_out, ele_states = cell.unroll(length=ele_length,
                                              inputs=data_nd[:ele_length, i:(i+1), :],
                                              merge_outputs=True,
                                              layout='TNC')
            assert_almost_equal(ele_out, outs[:ele_length, i:(i+1), :])
            if ele_length < max_length:
                # Check that the padded outputs are all zero
                assert_almost_equal(outs[ele_length:max_length, i:(i+1), :], 0)
            for valid_out_state, gt_state in zip(states, ele_states):
                assert_almost_equal(valid_out_state[i:(i+1)], gt_state)


def test_cell_fill_shape():
    cell = gluon.rnn.LSTMCell(10, input_size=7)
    cell.hybridize()
    assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]


def test_layer_fill_shape():
    layer = gluon.rnn.LSTM(10)
    layer.hybridize()
    check_rnn_layer_forward(layer, mx.np.ones((3, 2, 7)))
    print(layer)
    assert layer.rnn_param.shape[0] == 760


@pytest.mark.serial
def test_bidirectional_unroll_valid_length():
    def _check_bidirectional_unroll_valid_length(length):
        class BiLSTM(gluon.nn.HybridBlock):
            def __init__(self, rnn_size, time_step, **kwargs):
                super(BiLSTM, self).__init__(**kwargs)
                self.time_step = time_step
                self.bi_lstm = gluon.rnn.BidirectionalCell(
                    gluon.rnn.LSTMCell(rnn_size),
                    gluon.rnn.LSTMCell(rnn_size))

            def forward(self, inputs, valid_len):
                outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
                                                      layout='NTC', merge_outputs=True)
                return outputs, states

            def infer_shape(self, x, *args):
                self.bi_lstm.infer_shape(0, x, True)

        rnn_size = 100
        net = BiLSTM(rnn_size, length)
        inputs_data = mx.np.random.uniform(size=(10, length, 50))
        net.infer_shape(inputs_data)
        net.initialize()
        net.hybridize()
        valid_len = mx.np.array([length] * 10)
        outputs, _ = net(inputs_data, valid_len)
        assert outputs.shape == (10, length, 200)

    _check_bidirectional_unroll_valid_length(1)
    _check_bidirectional_unroll_valid_length(3)


def check_rnn_forward(layer, inputs):
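    """Smoke-test forward/backward of ``layer.unroll`` with merged and
    unmerged outputs."""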
    inputs.attach_grad()
    layer.initialize()
    with mx.autograd.record():
        layer.unroll(3, inputs, merge_outputs=True)[0].backward()
        mx.autograd.backward(layer.unroll(3, inputs, merge_outputs=False)[0])
    mx.npx.waitall()


def test_rnn_cells():
    check_rnn_forward(gluon.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)),
                      mx.np.ones((8, 3, 5, 7)))
    check_rnn_forward(gluon.rnn.Conv1DRNNCell((5, 7), 10, (3,), (3,)),
                      mx.np.ones((8, 3, 5, 7)))
    check_rnn_forward(gluon.rnn.Conv1DGRUCell((5, 7), 10, (3,), (3,)),
                      mx.np.ones((8, 3, 5, 7)))

    net = mx.gluon.rnn.SequentialRNNCell()
    net.add(gluon.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)))
    net.add(gluon.rnn.Conv1DRNNCell((10, 5), 11, (3,), (3,)))
    net.add(gluon.rnn.Conv1DGRUCell((11, 3), 12, (3,), (3,)))
    check_rnn_forward(net, mx.np.ones((8, 3, 5, 7)))


@mx.util.use_np
def check_rnn_cell(cell, in_shape=(10, 50), out_shape=(10, 100), begin_state=None):
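    """Generic single-cell check: verify parameter names and per-step output
    shapes after a three-step unroll."""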
    inputs = [mx.np.ones(shape=in_shape) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs, begin_state=begin_state)
    assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight',
                                                    'i2h_bias', 'i2h_weight']

    outs = [o.shape for o in outputs]
    assert outs == [out_shape] * 3


@mx.util.use_np
def test_convrnn():
    cell = gluon.rnn.Conv1DRNNCell((10, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48))
    cell = gluon.rnn.Conv2DRNNCell((10, 20, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))
    cell = gluon.rnn.Conv3DRNNCell((10, 20, 30, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@mx.util.use_np
def test_convlstm():
    cell = gluon.rnn.Conv1DLSTMCell((10, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48))
    cell = gluon.rnn.Conv2DLSTMCell((10, 20, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))
    cell = gluon.rnn.Conv3DLSTMCell((10, 20, 30, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@mx.util.use_np
def test_convgru():
    cell = gluon.rnn.Conv1DGRUCell((10, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 50), out_shape=(1, 100, 48))
    cell = gluon.rnn.Conv2DGRUCell((10, 20, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))
    cell = gluon.rnn.Conv3DGRUCell((10, 20, 30, 50), 100, 3, 3)
    check_rnn_cell(cell, in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@mx.util.use_np
def test_conv_fill_shape():
    input = mx.np.ones((8, 3, 5, 7))
    cell = gluon.rnn.Conv1DLSTMCell((-1, 7), 10, (3,), (3,))
    cell.infer_shape(0, input, False)
    cell.hybridize()
    check_rnn_forward(cell, input)
    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]


@mx.util.use_np
def test_lstmp_cell():
    nhid = 100
    nproj = 64
    cell = gluon.rnn.LSTMPCell(nhid, nproj)
    inputs = [mx.np.ones(shape=(10, 50)) for i in range(3)]
    cell.infer_shape(0, inputs[0], False)
    cell.initialize()
    outputs, _ = cell.unroll(3, inputs)
    expected_params = ['h2h_bias', 'h2h_weight', 'h2r_weight', 'i2h_bias', 'i2h_weight']
    assert sorted(cell.collect_params().keys()) == expected_params

    outs = [o.shape for o in outputs]
    assert outs == [(10, nproj)] * 3


@mx.util.use_np
def test_vardrop():
    def check_vardrop(drop_inputs, drop_states, drop_outputs):
        cell = gluon.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100),
                                                drop_outputs=drop_outputs,
                                                drop_states=drop_states,
                                                drop_inputs=drop_inputs)
        input_data = mx.np.random.uniform(size=(10, 3, 50), device=mx.device.current_device())
        cell.infer_shape(0, input_data, False)
        cell.initialize(init='xavier')
        with mx.autograd.record():
            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.npx.waitall()
            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())

        inputs = [mx.np.ones(shape=(10, 50)) for i in range(3)]
        cell.infer_shape(0, inputs[0], False)
        cell.initialize()
        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
        outs = [o.shape for o in outputs]
        assert outs == [(10, 100), (10, 100), (10, 100)]

        cell.reset()
        cell.hybridize()
        with mx.autograd.record():
            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.npx.waitall()
            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())

    check_vardrop(0.5, 0.5, 0.5)
    check_vardrop(0.5, 0, 0.5)


@mx.util.use_np
@pytest.mark.parametrize('cell_type,num_states', [
    (gluon.rnn.RNNCell, 1),
    (gluon.rnn.LSTMCell, 2),
    (gluon.rnn.GRUCell, 1)
])
@pytest.mark.parametrize('layout', ['NTC', 'TNC'])
def test_unroll(cell_type, num_states, layout):
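    """Compare ``cell.unroll`` against ``dynamic_unroll`` inside a HybridBlock
    under several hybridization configs, including weights after one SGD step."""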
    class RNNLayer(gluon.HybridBlock):
        def __init__(self, cell_type, hidden_size, layout):
            super(RNNLayer, self).__init__()
            self.cell = cell_type(hidden_size)
            self.layout = layout

        def forward(self, inputs, states, valid_length):
            if isinstance(valid_length, list) and len(valid_length) == 0:
                valid_length = None
            return gluon.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
                                                     valid_length=valid_length,
                                                     layout=self.layout)

        def infer_shape(self, x, *args):
            self.cell.infer_shape(0, x, False)

    batch_size = 20
    input_size = 50
    hidden_size = 30
    seq_len = 10
    device = default_device()
    if layout == 'TNC':
        rnn_data = mx.np.random.normal(loc=0, scale=1, size=(seq_len, batch_size, input_size), device=device)
    elif layout == 'NTC':
        rnn_data = mx.np.random.normal(loc=0, scale=1, size=(batch_size, seq_len, input_size), device=device)
    else:
        print("Wrong layout")
        return
    valid_length = mx.np.round(mx.np.random.uniform(low=1, high=10, size=(batch_size,), device=device))
    state_shape = (batch_size, hidden_size)
    states = [mx.np.random.normal(loc=0, scale=1, size=state_shape, device=device) for i in range(num_states)]

    cell = cell_type(hidden_size)
    if layout == 'TNC':
        cell.infer_shape(0, rnn_data[0], False)
        cell.initialize(device=default_device())
        cell(rnn_data[0], states)
    else:
        cell.infer_shape(0, rnn_data[:, 0, :], False)
        cell.initialize(device=default_device())
        cell(rnn_data[:, 0, :], states)
    params1 = cell.collect_params()
    orig_params1 = copy.deepcopy(params1)
    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate': 0.03})
    with mx.autograd.record():
        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
                                    layout=layout, merge_outputs=True)
        res1.backward()
    trainer.step(batch_size)

    configs = [
        lambda layer: None,
        lambda layer: layer.hybridize(),
        lambda layer: layer.hybridize({'inline_limit': 0}),
        lambda layer: layer.hybridize({'static_alloc': True}),
        lambda layer: layer.hybridize({'static_alloc': True, 'static_shape': True})]
    # We can't pass None to a hybrid block, but it accepts an empty list,
    # so we use an empty list to represent valid_length when it is None.
    if valid_length is None:
        valid_length = []
    for config in configs:
        layer = RNNLayer(cell_type, hidden_size, layout)
        layer.infer_shape(rnn_data)
        layer.initialize(device=default_device())
        config(layer)
        res2, states2 = layer(rnn_data, states, valid_length)
        params2 = layer.collect_params()
        for key, val in orig_params1.items():
            params2['cell.' + key].set_data(copy.deepcopy(val.data()))

        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate': 0.03})
        with mx.autograd.record():
            res2, states2 = layer(rnn_data, states, valid_length)
        assert_almost_equal(res1, res2, rtol=0.001, atol=0.0001)
        assert len(states1) == len(states2)
        for i in range(len(states1)):
            assert_almost_equal(states1[i], states2[i], rtol=0.001, atol=0.0001)
        res2.backward()
        trainer.step(batch_size)

        for key, val in params1.items():
            weight1 = val.data()
            weight2 = params2['cell.' + key].data()
            # The subgraph created by npx.foreach in deferred compute differs
            # slightly from the legacy foreach operator, hence the loose tolerances.
            assert_almost_equal(weight1, weight2, rtol=0.1, atol=0.1)