# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

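# Tests for the Gluon recurrent cells and layers (AI::MXNet::Gluon::RNN):
# symbolic unrolling, parameter naming, shape inference, and consistency of
# imperative vs. hybridized forward/backward passes.
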
use strict;
use warnings;
use Test::More tests => 77;
use AI::MXNet 'mx';
use AI::MXNet::Gluon 'gluon';
use AI::MXNet::TestUtils qw/allclose almost_equal/;
use AI::MXNet::Base;
use Scalar::Util 'blessed';

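# Unroll a plain RNNCell for three time steps and check its parameter names,
# output names, and inferred output shapes.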
sub test_rnn
{
    my $cell = gluon->rnn->RNNCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_rnn();

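# Same unrolling checks for LSTMCell; the gate weights are fused, so the
# parameter set is still just the i2h/h2h weight and bias pairs.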
sub test_lstm
{
    my $cell = gluon->rnn->LSTMCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_lstm();

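# Two stacked LSTMCells whose i2h biases are initialized with LSTMBias:
# after init_params() only the forget-gate quarter of each bias vector
# (MXNet orders the gates input, forget, cell, output) should hold the
# custom forget bias, everything else stays zero.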
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $stack = gluon->rnn->SequentialRNNCell();
    $stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l0_'));
    $stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l1_'));

    my $dshape = [32, 1, 200];
    my $data = mx->sym->Variable('data');

    my ($sym) = $stack->unroll(1, $data, merge_outputs=>1);
    my $mod = mx->mod->Module($sym, context=>mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);

    $mod->init_params();

    my ($bias_argument) = grep { /i2h_bias$/ } @{ $sym->list_arguments() };
    my $expected_bias = pdl((0)x100, ($forget_bias)x100, (0)x200);
    ok(allclose(($mod->get_params())[0]->{$bias_argument}->aspdl, $expected_bias));
}

test_lstm_forget_bias();

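# The same unrolling, parameter-name and shape checks for GRUCell.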
sub test_gru
{
    my $cell = gluon->rnn->GRUCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_gru();

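# ResidualCell around a GRUCell: with all weights and biases evaluated as
# zeros the wrapped cell contributes nothing, so the unrolled outputs must
# equal the inputs passed in.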
sub test_residual
{
    my $cell = gluon->rnn->ResidualCell(gluon->rnn->GRUCell(50, prefix=>'rnn_'));
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50]),
                                        rnn_t1_data=>mx->nd->ones([10, 50]),
                                        rnn_i2h_weight=>mx->nd->zeros([150, 50]),
                                        rnn_i2h_bias=>mx->nd->zeros([150]),
                                        rnn_h2h_weight=>mx->nd->zeros([150, 50]),
                                        rnn_h2h_bias=>mx->nd->zeros([150]) });
    my $expected_outputs = mx->nd->ones([10, 50]);
    ok(($outputs->[0] == $expected_outputs)->aspdl->all);
    ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}

test_residual();

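# ResidualCell around a BidirectionalCell of two GRUCells; with zeroed
# parameters both directions output zeros, so the residual connection should
# return the inputs (ones + 5) unchanged.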
sub test_residual_bidirectional
{
    my $cell = gluon->rnn->ResidualCell(
        gluon->rnn->BidirectionalCell(
            gluon->rnn->GRUCell(25, prefix=>'rnn_l_'),
            gluon->rnn->GRUCell(25, prefix=>'rnn_r_')
        )
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, $inputs, merge_outputs => 0);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()],
              ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
               'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50])+5,
                                        rnn_t1_data=>mx->nd->ones([10, 50])+5,
                                        rnn_l_i2h_weight=>mx->nd->zeros([75, 50]),
                                        rnn_l_i2h_bias=>mx->nd->zeros([75]),
                                        rnn_l_h2h_weight=>mx->nd->zeros([75, 25]),
                                        rnn_l_h2h_bias=>mx->nd->zeros([75]),
                                        rnn_r_i2h_weight=>mx->nd->zeros([75, 50]),
                                        rnn_r_i2h_bias=>mx->nd->zeros([75]),
                                        rnn_r_h2h_weight=>mx->nd->zeros([75, 25]),
                                        rnn_r_h2h_bias=>mx->nd->zeros([75]),
                                      });
    my $expected_outputs = mx->nd->ones([10, 50])+5;
    ok(($outputs->[0] == $expected_outputs)->aspdl->all);
    ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}

test_residual_bidirectional();

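# SequentialRNNCell of five LSTMCells, the second one wrapped in a
# ResidualCell: every layer must expose its four parameters and the grouped
# outputs must come from the last layer in the stack.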
sub test_stack
{
    my $cell = gluon->rnn->SequentialRNNCell();
    for my $i (0..4)
    {
        if($i == 1)
        {
            $cell->add(gluon->rnn->ResidualCell(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_")));
        }
        else
        {
            $cell->add(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_"));
        }
    }
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    my %keys = map { $_ => 1 } $cell->collect_params()->keys();
    for my $i (0..4)
    {
        ok($keys{"rnn_stack${i}_h2h_weight"});
        ok($keys{"rnn_stack${i}_h2h_bias"});
        ok($keys{"rnn_stack${i}_i2h_weight"});
        ok($keys{"rnn_stack${i}_i2h_bias"});
    }
    is_deeply($outputs->list_outputs(), ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_stack();

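# BidirectionalCell of two LSTMCells: each output is the concatenation of the
# forward and backward hidden states, hence width 200 for hidden size 100.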
sub test_bidirectional
{
    my $cell = gluon->rnn->BidirectionalCell(
        gluon->rnn->LSTMCell(100, prefix=>'rnn_l0_'),
        gluon->rnn->LSTMCell(100, prefix=>'rnn_r0_'),
        output_prefix=>'rnn_bi_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply($outputs->list_outputs(), ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 200], [10, 200], [10, 200]]);
}

test_bidirectional();

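# ZoneoutCell wrapping an RNNCell; zoneout is stochastic, so only the
# inferred output shapes are checked here.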
sub test_zoneout
{
    my $cell = gluon->rnn->ZoneoutCell(gluon->rnn->RNNCell(100, prefix=>'rnn_'), zoneout_outputs=>0.5,
                                       zoneout_states=>0.5);
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_zoneout();

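# Unroll a cell imperatively, run backward, then hybridize and repeat.
# For deterministic cells the outputs and input gradients of the two runs
# must agree within tolerance; for stochastic cells (dropout, zoneout) the
# comparison is skipped.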
sub check_rnn_forward
{
    my ($layer, $inputs, $deterministic) = @_;
    $deterministic //= 1;
    $inputs->attach_grad();
    $layer->collect_params()->initialize();
    my $out;
    mx->autograd->record(sub {
        $out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
        mx->autograd->backward($out);
        $out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
        $out->backward;
    });

    my $pdl_out = $out->aspdl;
    my $pdl_dx = $inputs->grad->aspdl;

    $layer->hybridize;

    mx->autograd->record(sub {
        $out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
        mx->autograd->backward($out);
        $out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
        $out->backward;
    });

    if($deterministic)
    {
        ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
        ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
    }
}

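# Forward/backward consistency for the individual cells, a bidirectional
# wrapper, dropout and zoneout wrappers (non-deterministic), and a
# three-cell sequential stack.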
sub test_rnn_cells
{
    check_rnn_forward(gluon->rnn->LSTMCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->RNNCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->GRUCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    my $bilayer = gluon->rnn->BidirectionalCell(
        gluon->rnn->LSTMCell(100, input_size=>200),
        gluon->rnn->LSTMCell(100, input_size=>200)
    );
    check_rnn_forward($bilayer, mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->DropoutCell(0.5), mx->nd->ones([8, 3, 200]), 0);
    check_rnn_forward(
        gluon->rnn->ZoneoutCell(
            gluon->rnn->LSTMCell(100, input_size=>200),
            0.5, 0.2
        ),
        mx->nd->ones([8, 3, 200]),
        0
    );
    my $net = gluon->rnn->SequentialRNNCell();
    $net->add(gluon->rnn->LSTMCell(100, input_size=>200));
    $net->add(gluon->rnn->RNNCell(100, input_size=>100));
    $net->add(gluon->rnn->GRUCell(100, input_size=>100));
    check_rnn_forward($net, mx->nd->ones([8, 3, 200]));
}

test_rnn_cells();

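# Run a fused recurrent layer with and without explicit initial states,
# then hybridize and run again: outputs and input gradients from the two
# passes must match within tolerance.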
sub check_rnn_layer_forward
{
    my ($layer, $inputs, $states) = @_;
    $layer->collect_params()->initialize();
    $inputs->attach_grad;
    my $out;
    mx->autograd->record(sub {
        if(defined $states)
        {
            $out = $layer->($inputs, $states);
            ok(@$out == 2);
            $out = $out->[0];
        }
        else
        {
            $out = $layer->($inputs);
            ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
        }
        $out->backward();
    });

    my $pdl_out = $out->aspdl;
    my $pdl_dx = $inputs->grad->aspdl;
    $layer->hybridize;
    mx->autograd->record(sub {
        if(defined $states)
        {
            # with explicit states the layer again returns [output, states];
            # unpack it the same way as in the non-hybridized pass above
            $out = $layer->($inputs, $states);
            ok(@$out == 2);
            $out = $out->[0];
        }
        else
        {
            $out = $layer->($inputs);
            ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
        }
        $out->backward();
    });
    ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
    ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
}

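# Fused RNN, LSTM and GRU layers, both unidirectional (no explicit states)
# and bidirectional (with explicit initial states).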
sub test_rnn_layers
{
    check_rnn_layer_forward(gluon->rnn->RNN(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->RNN(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
    check_rnn_layer_forward(gluon->rnn->LSTM(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->LSTM(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), [mx->nd->ones([4, 3, 10]), mx->nd->ones([4, 3, 10])]);
    check_rnn_layer_forward(gluon->rnn->GRU(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->GRU(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
}

test_rnn_layers();