# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
use Test::More tests => 77;
use AI::MXNet 'mx';
use AI::MXNet::Gluon 'gluon';
use AI::MXNet::TestUtils qw/allclose almost_equal/;
use AI::MXNet::Base;
use Scalar::Util 'blessed';
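# Unroll a single RNNCell for 3 time steps on symbolic inputs and check the
# registered parameter names, the output names, and the inferred output shapes.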
sub test_rnn
{
my $cell = gluon->rnn->RNNCell(100, prefix=>'rnn_');
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}
test_rnn();
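# Same unroll and shape checks as test_rnn, but for an LSTMCell.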
sub test_lstm
{
my $cell = gluon->rnn->LSTMCell(100, prefix=>'rnn_');
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}
test_lstm();
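# LSTMBias initializer: the forget-gate slice of the i2h bias must be set to
# $forget_bias while the input, cell, and output gate slices stay at zero.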
sub test_lstm_forget_bias
{
my $forget_bias = 2;
my $stack = gluon->rnn->SequentialRNNCell();
$stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l0_'));
$stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l1_'));
my $dshape = [32, 1, 200];
my $data = mx->sym->Variable('data');
my ($sym) = $stack->unroll(1, $data, merge_outputs=>1);
my $mod = mx->mod->Module($sym, context=>mx->cpu(0));
$mod->bind(data_shapes=>[['data', $dshape]]);
$mod->init_params();
my ($bias_argument) = grep { /i2h_bias$/ } @{ $sym->list_arguments() };
my $expected_bias = pdl((0)x100, ($forget_bias)x100, (0)x200);
ok(allclose(($mod->get_params())[0]->{$bias_argument}->aspdl, $expected_bias));
}
test_lstm_forget_bias();
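# Same unroll and shape checks as test_rnn, but for a GRUCell.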
sub test_gru
{
my $cell = gluon->rnn->GRUCell(100, prefix=>'rnn_');
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}
test_gru();
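# ResidualCell wrapping a GRUCell: with zero weights, biases, and initial state
# the inner cell outputs zeros, so the residual output equals the input.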
sub test_residual
{
my $cell = gluon->rnn->ResidualCell(gluon->rnn->GRUCell(50, prefix=>'rnn_'));
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
my ($outputs) = $cell->unroll(2, $inputs);
$outputs = mx->sym->Group($outputs);
is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
is_deeply($outs, [[10, 50], [10, 50]]);
$outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50]),
rnn_t1_data=>mx->nd->ones([10, 50]),
rnn_i2h_weight=>mx->nd->zeros([150, 50]),
rnn_i2h_bias=>mx->nd->zeros([150]),
rnn_h2h_weight=>mx->nd->zeros([150, 50]),
rnn_h2h_bias=>mx->nd->zeros([150]) });
my $expected_outputs = mx->nd->ones([10, 50]);
ok(($outputs->[0] == $expected_outputs)->aspdl->all);
ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}
test_residual();
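# Same identity check as test_residual, but the residual connection wraps a
# BidirectionalCell of two 25-unit GRUCells whose concatenated output matches
# the 50-dimensional input.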
sub test_residual_bidirectional
{
my $cell = gluon->rnn->ResidualCell(
gluon->rnn->BidirectionalCell(
gluon->rnn->GRUCell(25, prefix=>'rnn_l_'),
gluon->rnn->GRUCell(25, prefix=>'rnn_r_')
)
);
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
my ($outputs) = $cell->unroll(2, $inputs, merge_outputs => 0);
$outputs = mx->sym->Group($outputs);
is_deeply([sort $cell->collect_params()->keys()],
['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
is_deeply($outs, [[10, 50], [10, 50]]);
$outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50])+5,
rnn_t1_data=>mx->nd->ones([10, 50])+5,
rnn_l_i2h_weight=>mx->nd->zeros([75, 50]),
rnn_l_i2h_bias=>mx->nd->zeros([75]),
rnn_l_h2h_weight=>mx->nd->zeros([75, 25]),
rnn_l_h2h_bias=>mx->nd->zeros([75]),
rnn_r_i2h_weight=>mx->nd->zeros([75, 50]),
rnn_r_i2h_bias=>mx->nd->zeros([75]),
rnn_r_h2h_weight=>mx->nd->zeros([75, 25]),
rnn_r_h2h_bias=>mx->nd->zeros([75]),
});
my $expected_outputs = mx->nd->ones([10, 50])+5;
ok(($outputs->[0] == $expected_outputs)->aspdl->all);
ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}
test_residual_bidirectional();
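# SequentialRNNCell of five stacked LSTMCells (the second wrapped in a
# ResidualCell); every layer's parameters must be registered and the final
# outputs keep the expected names and shapes.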
sub test_stack
{
my $cell = gluon->rnn->SequentialRNNCell();
for my $i (0..4)
{
if($i == 1)
{
$cell->add(gluon->rnn->ResidualCell(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_")));
}
else
{
$cell->add(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_"));
}
}
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
my %keys = map { $_ => 1 } $cell->collect_params()->keys();
for my $i (0..4)
{
ok($keys{"rnn_stack${i}_h2h_weight"});
ok($keys{"rnn_stack${i}_h2h_bias"});
ok($keys{"rnn_stack${i}_i2h_weight"});
ok($keys{"rnn_stack${i}_i2h_bias"});
}
is_deeply($outputs->list_outputs(), ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}
test_stack();
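# BidirectionalCell of two LSTMCells: each step's output is the concatenation
# of both directions, so 100 hidden units per direction give an output size of 200.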
sub test_bidirectional
{
my $cell = gluon->rnn->BidirectionalCell(
gluon->rnn->LSTMCell(100, prefix=>'rnn_l0_'),
gluon->rnn->LSTMCell(100, prefix=>'rnn_r0_'),
output_prefix=>'rnn_bi_');
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
is_deeply($outputs->list_outputs(), ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output']);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 200], [10, 200], [10, 200]]);
}
test_bidirectional();
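# ZoneoutCell wrapping an RNNCell; only shape inference is checked because
# zoneout is stochastic.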
sub test_zoneout
{
my $cell = gluon->rnn->ZoneoutCell(gluon->rnn->RNNCell(100, prefix=>'rnn_'), zoneout_outputs=>0.5,
zoneout_states=>0.5);
my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
my ($outputs) = $cell->unroll(3, $inputs);
$outputs = mx->sym->Group($outputs);
my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}
test_zoneout();
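# Unroll a cell imperatively, then again after hybridize(); for deterministic
# cells the outputs and input gradients of the two runs must agree.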
sub check_rnn_forward
{
my ($layer, $inputs, $deterministic) = @_;
$deterministic //= 1;
$inputs->attach_grad();
$layer->collect_params()->initialize();
my $out;
mx->autograd->record(sub {
$out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
mx->autograd->backward($out);
$out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
$out->backward;
});
my $pdl_out = $out->aspdl;
my $pdl_dx = $inputs->grad->aspdl;
$layer->hybridize;
mx->autograd->record(sub {
$out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
mx->autograd->backward($out);
$out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
$out->backward;
});
if($deterministic)
{
ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
}
}
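# Forward/backward consistency checks for the individual cell types and the
# common wrappers (bidirectional, dropout, zoneout, sequential stack).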
sub test_rnn_cells
{
check_rnn_forward(gluon->rnn->LSTMCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
check_rnn_forward(gluon->rnn->RNNCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
check_rnn_forward(gluon->rnn->GRUCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
my $bilayer = gluon->rnn->BidirectionalCell(
gluon->rnn->LSTMCell(100, input_size=>200),
gluon->rnn->LSTMCell(100, input_size=>200)
);
check_rnn_forward($bilayer, mx->nd->ones([8, 3, 200]));
check_rnn_forward(gluon->rnn->DropoutCell(0.5), mx->nd->ones([8, 3, 200]), 0);
check_rnn_forward(
gluon->rnn->ZoneoutCell(
gluon->rnn->LSTMCell(100, input_size=>200),
0.5, 0.2
),
mx->nd->ones([8, 3, 200]),
0
);
my $net = gluon->rnn->SequentialRNNCell();
$net->add(gluon->rnn->LSTMCell(100, input_size=>200));
$net->add(gluon->rnn->RNNCell(100, input_size=>100));
$net->add(gluon->rnn->GRUCell(100, input_size=>100));
check_rnn_forward($net, mx->nd->ones([8, 3, 200]));
}
test_rnn_cells();
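# Run a fused RNN layer (optionally with explicit initial states) before and
# after hybridize() and compare the outputs and input gradients.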
sub check_rnn_layer_forward
{
my ($layer, $inputs, $states) = @_;
$layer->collect_params()->initialize();
$inputs->attach_grad;
my $out;
mx->autograd->record(sub {
if(defined $states)
{
$out = $layer->($inputs, $states);
ok(@$out == 2);
$out = $out->[0];
}
else
{
$out = $layer->($inputs);
ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
}
$out->backward();
});
my $pdl_out = $out->aspdl;
my $pdl_dx = $inputs->grad->aspdl;
$layer->hybridize;
mx->autograd->record(sub {
if(defined $states)
{
$out = $layer->($inputs, $states);
ok(@$out == 2);
$out = $out->[0];
}
else
{
$out = $layer->($inputs);
ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
}
$out->backward();
});
ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
}
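# Fused RNN, LSTM, and GRU layers, unidirectional and bidirectional.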
sub test_rnn_layers
{
check_rnn_layer_forward(gluon->rnn->RNN(10, 2), mx->nd->ones([8, 3, 20]));
check_rnn_layer_forward(gluon->rnn->RNN(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
check_rnn_layer_forward(gluon->rnn->LSTM(10, 2), mx->nd->ones([8, 3, 20]));
check_rnn_layer_forward(gluon->rnn->LSTM(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), [mx->nd->ones([4, 3, 10]), mx->nd->ones([4, 3, 10])]);
check_rnn_layer_forward(gluon->rnn->GRU(10, 2), mx->nd->ones([8, 3, 20]));
check_rnn_layer_forward(gluon->rnn->GRU(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
}
test_rnn_layers();