blob: a1cd8ea537d7c15579934f77894e3f1e19978af5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import mxnet as mx
from mxnet.gluon import contrib
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
from mxnet.test_utils import almost_equal
from common import setup_module, with_seed, teardown
import numpy as np
from numpy.testing import assert_allclose
def check_rnn_cell(cell, prefix, in_shape=(10, 50), out_shape=(10, 100), begin_state=None):
    """Symbolically unroll *cell* for 3 steps and verify its parameter names,
    output names, and inferred output shapes.

    Parameters
    ----------
    cell : gluon RNN cell to check.
    prefix : str, name prefix the cell was constructed with (e.g. 'rnn_').
    in_shape : tuple, shape fed to each of the 3 time-step inputs.
    out_shape : tuple, expected shape of each of the 3 outputs.
    begin_state : optional initial state passed through to ``unroll``.
    """
    num_steps = 3
    step_inputs = [mx.sym.Variable('rnn_t%d_data' % t) for t in range(num_steps)]
    step_outputs, _ = cell.unroll(num_steps, step_inputs, begin_state=begin_state)
    grouped = mx.sym.Group(step_outputs)
    # Cell must expose exactly the four gate parameters, prefixed.
    expected_params = [prefix + name for name in
                       ('h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight')]
    assert sorted(cell.collect_params().keys()) == expected_params
    expected_outputs = ['%st%d_out_output' % (prefix, t) for t in range(num_steps)]
    assert grouped.list_outputs() == expected_outputs
    # Shape inference over all three steps must yield the expected output shape.
    shape_args = {'rnn_t%d_data' % t: in_shape for t in range(num_steps)}
    _, outs, _ = grouped.infer_shape(**shape_args)
    assert outs == [out_shape] * num_steps
def check_rnn_forward(layer, inputs):
    """Smoke-test *layer* imperatively: unroll 3 steps with merged and
    unmerged outputs under autograd and run backward on both."""
    inputs.attach_grad()
    layer.collect_params().initialize()
    with mx.autograd.record():
        # Merged outputs: unroll returns (outputs, states); backprop the outputs.
        merged_out, _ = layer.unroll(3, inputs, merge_outputs=True)
        merged_out.backward()
        # Unmerged outputs: a list of per-step arrays, backprop them together.
        step_outs, _ = layer.unroll(3, inputs, merge_outputs=False)
        mx.autograd.backward(step_outs)
    # Block until async execution finishes so failures surface here.
    mx.nd.waitall()
@with_seed()
def test_rnn_cells():
    """Forward/backward smoke test for each 1D conv RNN cell and a stack of them."""
    for cell_cls in (contrib.rnn.Conv1DLSTMCell,
                     contrib.rnn.Conv1DRNNCell,
                     contrib.rnn.Conv1DGRUCell):
        check_rnn_forward(cell_cls((5, 7), 10, (3,), (3,)),
                          mx.nd.ones((8, 3, 5, 7)))
    # Stacked cells: each cell's input shape matches the previous cell's output.
    stacked = mx.gluon.rnn.SequentialRNNCell()
    stacked.add(contrib.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)))
    stacked.add(contrib.rnn.Conv1DRNNCell((10, 5), 11, (3,), (3,)))
    stacked.add(contrib.rnn.Conv1DGRUCell((11, 3), 12, (3,), (3,)))
    check_rnn_forward(stacked, mx.nd.ones((8, 3, 5, 7)))
@with_seed()
def test_convrnn():
    """Check ConvRNN cells of each spatial rank: params, outputs, shapes."""
    cases = [
        (contrib.rnn.Conv1DRNNCell, (10, 50), (1, 10, 50), (1, 100, 48)),
        (contrib.rnn.Conv2DRNNCell, (10, 20, 50), (1, 10, 20, 50), (1, 100, 18, 48)),
        (contrib.rnn.Conv3DRNNCell, (10, 20, 30, 50), (1, 10, 20, 30, 50), (1, 100, 18, 28, 48)),
    ]
    for cell_cls, input_shape, in_shape, out_shape in cases:
        cell = cell_cls(input_shape, 100, 3, 3, prefix='rnn_')
        check_rnn_cell(cell, prefix='rnn_', in_shape=in_shape, out_shape=out_shape)
@with_seed()
def test_convlstm():
    """Check ConvLSTM cells of each spatial rank: params, outputs, shapes."""
    cases = [
        (contrib.rnn.Conv1DLSTMCell, (10, 50), (1, 10, 50), (1, 100, 48)),
        (contrib.rnn.Conv2DLSTMCell, (10, 20, 50), (1, 10, 20, 50), (1, 100, 18, 48)),
        (contrib.rnn.Conv3DLSTMCell, (10, 20, 30, 50), (1, 10, 20, 30, 50), (1, 100, 18, 28, 48)),
    ]
    for cell_cls, input_shape, in_shape, out_shape in cases:
        cell = cell_cls(input_shape, 100, 3, 3, prefix='rnn_')
        check_rnn_cell(cell, prefix='rnn_', in_shape=in_shape, out_shape=out_shape)
@with_seed()
def test_convgru():
    """Check ConvGRU cells of each spatial rank: params, outputs, shapes."""
    cases = [
        (contrib.rnn.Conv1DGRUCell, (10, 50), (1, 10, 50), (1, 100, 48)),
        (contrib.rnn.Conv2DGRUCell, (10, 20, 50), (1, 10, 20, 50), (1, 100, 18, 48)),
        (contrib.rnn.Conv3DGRUCell, (10, 20, 30, 50), (1, 10, 20, 30, 50), (1, 100, 18, 28, 48)),
    ]
    for cell_cls, input_shape, in_shape, out_shape in cases:
        cell = cell_cls(input_shape, 100, 3, 3, prefix='rnn_')
        check_rnn_cell(cell, prefix='rnn_', in_shape=in_shape, out_shape=out_shape)
@with_seed()
def test_conv_fill_shape():
    # A Conv1DLSTMCell built with 0 input channels should defer channel
    # inference until it first sees data.
    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
    cell.hybridize()
    # Data has 5 channels; the forward/backward pass fills in the missing dim.
    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
    # i2h_weight's input-channel axis must now match the data's channel count.
    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
@with_seed()
def test_lstmp():
    """LSTMP cell exposes a projection weight (h2r) and emits nproj-sized outputs."""
    nhid, nproj = 100, 64
    cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
    seq = [mx.sym.Variable('rnn_t%d_data' % t) for t in range(3)]
    unrolled, _ = cell.unroll(3, seq)
    grouped = mx.sym.Group(unrolled)
    # h2r_weight is the extra projection parameter vs. a plain LSTM.
    expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight',
                       'rnn_i2h_bias', 'rnn_i2h_weight']
    assert sorted(cell.collect_params().keys()) == expected_params
    expected_outputs = ['rnn_t%d_out_output' % t for t in range(3)]
    assert grouped.list_outputs() == expected_outputs, grouped.list_outputs()
    # Outputs are projected down to nproj, not nhid.
    _, shapes, _ = grouped.infer_shape(rnn_t0_data=(10, 50),
                                       rnn_t1_data=(10, 50),
                                       rnn_t2_data=(10, 50))
    assert shapes == [(10, nproj)] * 3
@with_seed()
def test_vardrop():
    """Variational dropout: masks differ between unrolls, shapes are preserved,
    and hybridized behavior remains stochastic."""
    def check_vardrop(drop_inputs, drop_states, drop_outputs):
        cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
                                                  drop_outputs=drop_outputs,
                                                  drop_states=drop_states,
                                                  drop_inputs=drop_inputs)
        cell.collect_params().initialize(init='xavier')
        input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
        with mx.autograd.record():
            # Two unrolls on the same data inside one record() scope;
            # waitall() between them forces the first to complete.
            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
        # Dropout masks should be resampled per unroll, so outputs differ.
        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())
        # Symbolic unroll: dropout must not change the inferred output shapes.
        inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
        outputs = mx.sym.Group(outputs)
        args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
        assert outs == [(10, 100), (10, 100), (10, 100)]
        # Repeat the stochasticity check after reset + hybridize: cached graph
        # execution must still draw fresh dropout masks.
        cell.reset()
        cell.hybridize()
        with mx.autograd.record():
            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
        # Hybridized outputs also differ from the earlier imperative ones.
        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())
    check_vardrop(0.5, 0.5, 0.5)
    # drop_states=0 disables state dropout; output/input dropout still active.
    check_vardrop(0.5, 0, 0.5)
def test_concurrent():
    """Concurrent / HybridConcurrent concatenate branch outputs along axis 1.

    Three Dense branches (128 + 64 + 32 units) over a 10-unit input should
    yield a (batch, 224) result from both container types.
    """
    # NOTE(review): other randomized tests here use @with_seed(); this one
    # relies on Xavier init without seed control — confirm whether that is
    # intentional.
    def build(container):
        # Same three branches for both container flavors.
        container.add(nn.Dense(128, activation='tanh', in_units=10))
        container.add(nn.Dense(64, activation='tanh', in_units=10))
        container.add(nn.Dense(32, in_units=10))
        return container

    hybrid_model = build(HybridConcurrent(axis=1))
    plain_model = build(Concurrent(axis=1))
    # Symbolic graph: 1 data variable + weight/bias for each of 3 branches = 7.
    sym_out = hybrid_model(mx.sym.var('data'))
    assert len(sym_out.list_arguments()) == 7
    # Imperative forward on both containers.
    hybrid_model.initialize(mx.init.Xavier(magnitude=2.24))
    plain_model.initialize(mx.init.Xavier(magnitude=2.24))
    hybrid_out = hybrid_model(mx.nd.zeros((32, 10)))
    plain_out = plain_model(mx.nd.zeros((32, 10)))
    assert hybrid_out.shape == (32, 224)
    assert plain_out.shape == (32, 224)
    hybrid_out.wait_to_read()
    plain_out.wait_to_read()
@with_seed()
def test_identity():
    """Identity block must return its input unchanged."""
    block = Identity()
    data = mx.nd.random.uniform(shape=(128, 33, 64))
    mx.test_utils.assert_almost_equal(block(data).asnumpy(), data.asnumpy())
@with_seed()
def test_sparse_embedding():
    """Backward through SparseEmbedding puts gradient only on looked-up rows."""
    layer = SparseEmbedding(10, 100)
    layer.initialize()
    # Trainer attaches the sparse update machinery to the params.
    trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd')
    indices = mx.nd.array([3, 4, 2, 0, 1])
    with mx.autograd.record():
        embedded = layer(indices)
        embedded.backward()
    grad = layer.weight.grad().asnumpy()
    # Rows 0-4 were each looked up once -> gradient 1; rest untouched -> 0.
    assert (grad[:5] == 1).all()
    assert (grad[5:] == 0).all()
def test_datasets():
    """WikiText2 splits load with the expected lengths, vocab and frequency sizes."""
    train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
    # Validation split reuses the training vocabulary.
    val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',
                                      vocab=train.vocabulary)
    test_split = contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')
    # (dataset, num samples, vocab size, num distinct frequencies)
    expectations = [
        (train, 59305, 33278, 33277),
        (val, 6181, 33278, 13776),
        (test_split, 6974, 14143, 14142),
    ]
    for dataset, num_samples, vocab_size, freq_size in expectations:
        assert len(dataset) == num_samples, len(dataset)
        assert len(dataset.vocabulary) == vocab_size, len(dataset.vocabulary)
        assert len(dataset.frequencies) == freq_size, len(dataset.frequencies)
    assert test_split.frequencies['English'] == 32
def test_sampler():
    """IntervalSampler covers every index with rollover, a strided prefix without."""
    rollover_sampler = contrib.data.IntervalSampler(10, 3)
    # With rollover (default) all 10 indices are emitted, in interleaved order.
    assert sorted(rollover_sampler) == list(range(10))
    plain_sampler = contrib.data.IntervalSampler(10, 3, rollover=False)
    # Without rollover only every 3rd index from 0 is emitted.
    assert list(plain_sampler) == [0, 3, 6, 9]
if __name__ == '__main__':
    # Allow running this test module directly as a script via nose.
    import nose
    nose.runmodule()