| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| package AI::MXNet::RNN::Params; |
| use Mouse; |
| use AI::MXNet::Function::Parameters; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::Params - A container for holding variables. |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| A container for holding variables. |
| Used by RNN cells for parameter sharing between cells. |
| |
| Parameters |
| ---------- |
| prefix : str |
| All variables name created by this container will |
| be prepended with the prefix |
| =cut |
# Prefix prepended to the name of every variable created by this container.
has '_prefix' => (is => 'ro', init_arg => 'prefix', isa => 'Str', default => '');
# Name => AI::MXNet::Symbol cache; populated lazily by get().
has '_params' => (is => 'rw', init_arg => undef);
# Allow a single positional argument as shorthand: ->new($prefix).
around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
| |
sub BUILD
{
    my ($self) = @_;
    # Start with an empty variable cache.
    $self->_params(+{});
}
| |
| |
| =head2 get |
| |
| Get a variable with the name or create a new one if does not exist. |
| |
| Parameters |
| ---------- |
| $name : str |
| name of the variable |
| @kwargs: |
| more arguments that are passed to mx->sym->Variable call |
| =cut |
| |
method get(Str $name, @kwargs)
{
    # Every variable created by this container carries the prefix.
    my $full_name = $self->_prefix . $name;
    # Create and cache the variable on first request; later calls with
    # the same name return the cached symbol (parameter sharing).
    unless(exists $self->_params->{$full_name})
    {
        $self->_params->{$full_name} = AI::MXNet::Symbol->Variable($full_name, @kwargs);
    }
    return $self->_params->{$full_name};
}
| |
| package AI::MXNet::RNN::Cell::Base; |
| =head1 NAME |
| |
AI::MXNet::RNN::Cell::Base
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Abstract base class for RNN cells |
| |
| Parameters |
| ---------- |
| prefix : str |
| prefix for name of layers |
| (and name of weight if params is undef) |
| params : AI::MXNet::RNN::Params or undef |
| container for weight sharing between cells. |
| created if undef. |
| =cut |
| |
use AI::MXNet::Base;
use Mouse;
# Calling the cell object as a code ref, &{$cell}(...), dispatches to ->call.
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } };
# Name prefix for layers created by this cell (and its weights when owned).
has '_prefix' => (is => 'rw', init_arg => 'prefix', isa => 'Str', default => '');
# Shared parameter container; created in BUILD when not supplied.
has '_params' => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::RNN::Params]');
# _own_params   : whether this cell created (owns) its parameter container
# _modified     : set by modifier cells (e.g. Dropout); blocks direct use
# _init_counter : counter used to generate unique begin_state names
# _counter      : counter used to generate unique per-step layer names
has [qw/_own_params
        _modified
        _init_counter
        _counter
    /] => (is => 'rw', init_arg => undef);
| |
# Allow a single positional argument as shorthand: ->new($prefix).
around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
| |
sub BUILD
{
    my $self = shift;
    # When no params container is supplied, create one owned by this cell;
    # otherwise share the caller's container (weight sharing between cells).
    if(not defined $self->_params)
    {
        $self->_own_params(1);
        $self->_params(AI::MXNet::RNN::Params->new($self->_prefix));
    }
    else
    {
        $self->_own_params(0);
    }
    $self->_modified(0);
    # Initialize the step/state counters.
    $self->reset;
}
| |
| =head2 reset |
| |
| Reset before re-using the cell for another graph |
| =cut |
| |
method reset()
{
    # -1 means "no step taken yet"; both counters are incremented before use.
    $self->_counter(-1);
    $self->_init_counter(-1);
}
| |
| =head2 call |
| |
| Construct symbol for one step of RNN. |
| |
| Parameters |
| ---------- |
| $inputs : mx->sym->Variable |
| input symbol, 2D, batch * num_units |
| $states : mx->sym->Variable or ArrayRef[AI::MXNet::Symbol] |
| state from previous step or begin_state(). |
| |
| Returns |
| ------- |
| $output : AI::MXNet::Symbol |
| output symbol |
| $states : ArrayRef[AI::MXNet::Symbol] |
| state to next step of RNN. |
| Can be called via overloaded &{}: &{$cell}($inputs, $states); |
| =cut |
| |
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # Abstract: concrete cells build the symbol for one RNN step here.
    confess("Not Implemented");
}
| |
method _gate_names()
{
    # The base cell has a single, unnamed gate group.
    return [''];
}
| |
| =head2 params |
| |
| Parameters of this cell |
| =cut |
| |
method params()
{
    # Once the parameters are exposed the cell no longer assumes exclusive
    # ownership of them (they may now be shared externally).
    $self->_own_params(0);
    return $self->_params;
}
| |
| =head2 state_shape |
| |
| shape(s) of states |
| =cut |
| |
method state_shape()
{
    # Project just the 'shape' field out of each state descriptor.
    my @shapes = map { $_->{shape} } @{ $self->state_info };
    return \@shapes;
}
| |
| =head2 state_info |
| |
| shape and layout information of states |
| =cut |
| |
method state_info()
{
    # Abstract: concrete cells return an array ref of state descriptors
    # (hash refs with 'shape' and optionally '__layout__').
    confess("Not Implemented");
}
| |
| =head2 begin_state |
| |
| Initial state for this cell. |
| |
| Parameters |
| ---------- |
| :$func : sub ref, default is AI::MXNet::Symbol->can('zeros') |
| Function for creating initial state. |
| Can be AI::MXNet::Symbol->can('zeros'), |
| AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc. |
| Use AI::MXNet::Symbol->can('Variable') if you want to directly |
| feed the input as states. |
| @kwargs : |
| more keyword arguments passed to func. For example |
| mean, std, dtype, etc. |
| |
| Returns |
| ------- |
| $states : ArrayRef[AI::MXNet::Symbol] |
| starting states for first RNN step |
| =cut |
| |
method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my @states;
    # Variable() takes the name as its first positional argument; every
    # other creator (zeros, uniform, ...) expects it as a named 'name' arg.
    my $func_needs_named_name = $func ne AI::MXNet::Symbol->can('Variable');
    for my $info (@{ $self->state_info })
    {
        $self->_init_counter($self->_init_counter + 1);
        my @name = (sprintf("%sbegin_state_%d", $self->_prefix, $self->_init_counter));
        my %info = %{ $info//{} };
        if($func_needs_named_name)
        {
            unshift(@name, 'name');
        }
        else
        {
            # Variable() passes symbol attributes through its 'kwargs'
            # argument, so tuck the layout attribute inside it.
            if(exists $info{__layout__})
            {
                $info{kwargs} = { __layout__ => delete $info{__layout__} };
            }
        }
        # Per-state info (shape, layout) overrides caller-supplied kwargs.
        my %kwargs = (@kwargs, %info);
        my $state = $func->(
            'AI::MXNet::Symbol',
            @name,
            %kwargs
        );
        push @states, $state;
    }
    return \@states;
}
| |
| =head2 unpack_weights |
| |
| Unpack fused weight matrices into separate |
| weight matrices |
| |
| Parameters |
| ---------- |
| $args : HashRef[AI::MXNet::NDArray] |
| hash ref containing packed weights. |
| usually from AI::MXNet::Module->get_output() |
| |
| Returns |
| ------- |
| $args : HashRef[AI::MXNet::NDArray] |
| hash ref with weights associated with |
| this cell, unpacked. |
| =cut |
| |
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Split the fused i2h/h2h weight and bias matrices into one slice per
    # gate, returning a NEW hash; the input hash ref is left untouched.
    #
    # BUGFIX: the previous version wrote the unpacked slices into the
    # caller's hash ref ($args->{...}) instead of the local copy (%args),
    # so the returned hash was missing every unpacked weight and the
    # caller's input hash was mutated as a side effect.
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        # Remove the fused matrices; they are replaced by per-gate slices.
        my $weight = delete $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) };
        my $bias   = delete $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) };
        enumerate(sub {
            my ($j, $name) = @_;
            # Gate $j occupies rows [$j*$h, ($j+1)*$h) of the fused matrix.
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            $args{$wname} = $weight->slice([$j*$h, ($j+1)*$h-1])->copy;
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            $args{$bname} = $bias->slice([$j*$h, ($j+1)*$h-1])->copy;
        }, $self->_gate_names);
    }
    return \%args;
}
| |
| =head2 pack_weights |
| |
| Pack fused weight matrices into common |
| weight matrices |
| |
| Parameters |
| ---------- |
| args : HashRef[AI::MXNet::NDArray] |
| hash ref containing unpacked weights. |
| |
| Returns |
| ------- |
| $args : HashRef[AI::MXNet::NDArray] |
| hash ref with weights associated with |
| this cell, packed. |
| =cut |
| |
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        # Collect the per-gate slices in gate order, then fuse them back
        # into a single matrix by concatenating along the row axis.
        my @weight;
        my @bias;
        for my $name (@{ $self->_gate_names })
        {
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            push @weight, delete $args{$wname};
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            push @bias, delete $args{$bname};
        }
        $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@weight
        );
        $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@bias
        );
    }
    return \%args;
}
| |
| =head2 unroll |
| |
| Unroll an RNN cell across time steps. |
| |
| Parameters |
| ---------- |
| :$length : Int |
| number of steps to unroll |
| :$inputs : AI::MXNet::Symbol, array ref of Symbols, or undef |
| if inputs is a single Symbol (usually the output |
| of Embedding symbol), it should have shape |
| of [$batch_size, $length, ...] if layout == 'NTC' (batch, time series) |
| or ($length, $batch_size, ...) if layout == 'TNC' (time series, batch). |
| |
| If inputs is a array ref of symbols (usually output of |
| previous unroll), they should all have shape |
| ($batch_size, ...). |
| |
| If inputs is undef, a placeholder variables are |
| automatically created. |
| :$begin_state : array ref of Symbol |
| input states. Created by begin_state() |
| or output state of another cell. Created |
| from begin_state() if undef. |
:$input_prefix : str
    prefix for automatically created input
    placeholders.
| :$layout : str |
| layout of input symbol. Only used if the input |
| is a single Symbol. |
:$merge_outputs : Bool
    If 0, returns outputs as an array ref of Symbols.
    If 1, concatenates the output across the time steps
    and returns a single symbol with the shape
    [$batch_size, $length, ...] if the layout is 'NTC',
    or [$length, $batch_size, ...] if the layout is 'TNC'.
    If undef, output whatever is faster
| |
| Returns |
| ------- |
| $outputs : array ref of Symbol or Symbol |
| output symbols. |
| $states : Symbol or nested list of Symbol |
| has the same structure as begin_state() |
| =cut |
| |
| |
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    $self->reset;
    # Position of the time axis in the layout string (0 for TNC, 1 for NTC).
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        # No inputs given: create one placeholder variable per time step.
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        # A single symbol: split it along the time axis into $length steps.
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis => $axis,
            num_outputs => $length,
            squeeze_axis => 1
        );
    }
    else
    {
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my $outputs;
    my @inputs = @{ $inputs };
    # Step the cell once per time step, threading the state through.
    # $self->(...) uses the overloaded &{} to dispatch to ->call.
    for my $i (0..$length-1)
    {
        my $output;
        ($output, $states) = $self->(
            $inputs[$i],
            $states
        );
        push @$outputs, $output;
    }
    if($merge_outputs)
    {
        # Stack the per-step outputs into one symbol along the time axis.
        @$outputs = map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @$outputs;
        $outputs = AI::MXNet::Symbol->Concat(@$outputs, dim => $axis);
    }
    return($outputs, $states);
}
| |
method _get_activation($inputs, $activation, @kwargs)
{
    # A plain string names a built-in activation type; anything that is a
    # reference is invoked directly as a callable.
    return ref $activation
        ? $activation->($inputs, @kwargs)
        : AI::MXNet::Symbol->Activation($inputs, act_type => $activation, @kwargs);
}
| |
method _cells_state_shape($cells)
{
    # Flatten the per-cell state shapes into one list.
    my @shapes;
    push @shapes, @{ $_->state_shape } for @$cells;
    return \@shapes;
}
| |
method _cells_state_info($cells)
{
    # Flatten the per-cell state descriptors into one list.
    my @info;
    push @info, @{ $_->state_info } for @$cells;
    return \@info;
}
| |
method _cells_begin_state($cells, @kwargs)
{
    # Collect every cell's initial states into one flat list.
    my @states;
    push @states, @{ $_->begin_state(@kwargs) } for @$cells;
    return \@states;
}
| |
method _cells_unpack_weights($cells, $args)
{
    # Thread the args hash ref through every cell's unpack_weights.
    for my $cell (@$cells)
    {
        $args = $cell->unpack_weights($args);
    }
    return $args;
}
| |
method _cells_pack_weights($cells, $args)
{
    # Thread the args hash ref through every cell's pack_weights.
    for my $cell (@$cells)
    {
        $args = $cell->pack_weights($args);
    }
    return $args;
}
| |
| package AI::MXNet::RNN::Cell; |
| use Mouse; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::Cell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Simple recurrent neural network cell |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| activation : str or Symbol, default 'tanh' |
| type of activation function |
| prefix : str, default 'rnn_' |
| prefix for name of layers |
| (and name of weight if params is undef) |
params : AI::MXNet::RNN::Params or undef
| container for weight sharing between cells. |
| created if undef. |
| =cut |
| |
# Number of hidden units produced by this cell.
has '_num_hidden' => (is => 'ro', init_arg => 'num_hidden', isa => 'Int', required => 1);
# Forget-gate bias; left unset for plain RNN cells, defaulted to 1 by the
# LSTM subclass (see BUILD, which uses it to initialize i2h_bias).
has 'forget_bias' => (is => 'ro', isa => 'Num');
# Activation applied to i2h + h2h: a type name string or a callable.
has '_activation' => (
    is => 'ro',
    init_arg => 'activation',
    isa => 'Activation',
    default => 'tanh'
);
has '+_prefix' => (default => 'rnn_');
# Weight/bias symbols: _iW/_iB input-to-hidden, _hW/_hB hidden-to-hidden.
has [qw/_iW _iB
        _hW _hB/] => (is => 'rw', init_arg => undef);

# Allow a single positional argument as shorthand: ->new($num_hidden).
around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
| |
sub BUILD
{
    my $self = shift;
    # Fetch (or lazily create) this cell's weight/bias symbols from the
    # shared parameter container.
    $self->_iW($self->params->get('i2h_weight'));
    # The i2h bias gets the LSTM forget-gate initializer only when a
    # forget_bias was supplied (set by the LSTM subclass).
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            (defined($self->forget_bias)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : ()
            )
        )
    );
    $self->_hW($self->params->get('h2h_weight'));
    $self->_hB($self->params->get('h2h_bias'));
}
| |
method state_info()
{
    # A single hidden state; the batch dimension (0) is inferred at bind time.
    my $info = { __layout__ => 'NC', shape => [0, $self->_num_hidden] };
    return [$info];
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # One simple RNN step: out = activation(i2h(x) + h2h(h_prev)).
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data => $inputs,
        weight => $self->_iW,
        bias => $self->_iB,
        num_hidden => $self->_num_hidden,
        name => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data => @{$states}[0],
        weight => $self->_hW,
        bias => $self->_hB,
        num_hidden => $self->_num_hidden,
        name => "${name}h2h"
    );
    my $output = $self->_get_activation(
        $i2h + $h2h,
        $self->_activation,
        name => "${name}out"
    );
    # The output doubles as the next hidden state.
    return ($output, [$output]);
}
| |
| package AI::MXNet::RNN::LSTMCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::LSTMCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Long-Short Term Memory (LSTM) network cell. |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| prefix : str, default 'lstm_' |
| prefix for name of layers |
| (and name of weight if params is undef) |
params : AI::MXNet::RNN::Params or undef
| container for weight sharing between cells. |
| created if undef. |
| forget_bias : bias added to forget gate, default 1.0. |
| Jozefowicz et al. 2015 recommends setting this to 1.0 |
| =cut |
| |
has '+_prefix' => (default => 'lstm_');
# LSTM gates have fixed activations; the generic activation knob is disabled.
has '+_activation' => (init_arg => undef);
# Jozefowicz et al. 2015 recommend initializing the forget-gate bias to 1.
has '+forget_bias' => (is => 'ro', isa => 'Num', default => 1);
| |
method state_info()
{
    # Two states: the hidden output (h) and the memory cell (c), each
    # [batch, num_hidden] with layout NC; batch (0) is inferred at bind time.
    return [
        map { +{ shape => [0, $self->_num_hidden], __layout__ => 'NC' } } 0..1
    ];
}
| |
method _gate_names()
{
    # Gate order matches the fused layout: input, forget, cell, output.
    return ['_i', '_f', '_c', '_o'];
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # One LSTM step. $states is [h, c]. Both projections compute all four
    # gates at once (hence num_hidden*4) and are then sliced apart.
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my @states = @{ $states };
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data => $inputs,
        weight => $self->_iW,
        bias => $self->_iB,
        num_hidden => $self->_num_hidden*4,
        name => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data => $states[0],
        weight => $self->_hW,
        bias => $self->_hB,
        num_hidden => $self->_num_hidden*4,
        name => "${name}h2h"
    );
    my $gates = $i2h + $h2h;
    # Slice order matches _gate_names: input, forget, candidate, output.
    my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel(
        $gates, num_outputs => 4, name => "${name}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[0], act_type => "sigmoid", name => "${name}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[1], act_type => "sigmoid", name => "${name}f"
    );
    my $in_transform = AI::MXNet::Symbol->Activation(
        $slice_gates[2], act_type => "tanh", name => "${name}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[3], act_type => "sigmoid", name => "${name}o"
    );
    # c' = f * c + i * candidate
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * $states[1], $in_gate * $in_transform,
        name => "${name}state"
    );
    # h' = o * tanh(c')
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate,
        AI::MXNet::Symbol->Activation(
            $next_c, act_type => "tanh"
        ),
        name => "${name}out"
    );
    return ($next_h, [$next_h, $next_c]);

}
| |
| package AI::MXNet::RNN::GRUCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::GRUCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
Gated Recurrent Unit (GRU) network cell.
| Note: this is an implementation of the cuDNN version of GRUs |
| (slight modification compared to Cho et al. 2014). |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| prefix : str, default 'gru_' |
| prefix for name of layers |
| (and name of weight if params is undef) |
| params : AI::MXNet::RNN::Params or undef |
| container for weight sharing between cells. |
| created if undef. |
| =cut |
| |
has '+_prefix' => (default => 'gru_');

method _gate_names()
{
    # Gate order matches the cuDNN GRU layout: reset, update, output.
    [qw/_r _z _o/];
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # One GRU step (cuDNN variant: the reset gate is applied to the h2h
    # projection, not to the raw previous state).
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $prev_state_h = @{ $states }[0];
    # Both projections compute all three gates at once (num_hidden*3).
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data => $inputs,
        weight => $self->_iW,
        bias => $self->_iB,
        num_hidden => $self->_num_hidden*3,
        name => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data => $prev_state_h,
        weight => $self->_hW,
        bias => $self->_hB,
        num_hidden => $self->_num_hidden*3,
        name => "${name}h2h"
    );
    # Slice into (reset, update, output) parts; the third slice reuses the
    # $i2h/$h2h variables for the candidate-state contribution.
    my ($i2h_r, $i2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $i2h, num_outputs => 3, name => "${name}_i2h_slice"
    ) };
    my ($h2h_r, $h2h_z);
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $h2h, num_outputs => 3, name => "${name}_h2h_slice"
    ) };
    # r = sigmoid(...), z = sigmoid(...)
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid", name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid", name => "${name}_z_act"
    );
    # h~ = tanh(i2h + r * h2h)
    my $next_h_tmp = AI::MXNet::Symbol->Activation(
        $i2h + $reset_gate * $h2h, act_type => "tanh", name => "${name}_h_act"
    );
    # h' = (1 - z) * h~ + z * h_prev
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * $prev_state_h,
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}
| |
| package AI::MXNet::RNN::FusedCell; |
| use Mouse; |
| use AI::MXNet::Types; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::FusedCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Fusing RNN layers across time step into one kernel. |
| Improves speed but is less flexible. Currently only |
| supported if using cuDNN on GPU. |
| =cut |
| |
# Number of hidden units per direction and layer.
has '_num_hidden' => (is => 'ro', isa => 'Int', init_arg => 'num_hidden', required => 1);
# Number of stacked RNN layers inside the fused kernel.
has '_num_layers' => (is => 'ro', isa => 'Int', init_arg => 'num_layers', default => 1);
# Dropout probability applied between layers (passed to the RNN op as 'p').
has '_dropout' => (is => 'ro', isa => 'Num', init_arg => 'dropout', default => 0);
# Whether unroll() should also return the final states (state_outputs).
has '_get_next_state' => (is => 'ro', isa => 'Bool', init_arg => 'get_next_state', default => 0);
has '_bidirectional' => (is => 'ro', isa => 'Bool', init_arg => 'bidirectional', default => 0);
has 'forget_bias' => (is => 'ro', isa => 'Num', default => 1);
# Initializer for the fused parameter blob; wrapped in FusedRNN by BUILD.
has 'initializer' => (is => 'rw', isa => 'Maybe[Initializer]');
# Which fused kernel to run.
has '_mode' => (
    is => 'ro',
    isa => enum([qw/rnn_relu rnn_tanh lstm gru/]),
    init_arg => 'mode',
    default => 'lstm'
);
# _parameter  : the single flat symbol holding all weights and biases
# _directions : ['l'] or ['l', 'r'] depending on bidirectionality
has [qw/_parameter
        _directions/] => (is => 'rw', init_arg => undef);

# Allow a single positional argument as shorthand: ->new($num_hidden).
around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
| |
sub BUILD
{
    my $self = shift;
    # Default the prefix to "<mode>_" (e.g. "lstm_") when none was given.
    if(not $self->_prefix)
    {
        $self->_prefix($self->_mode.'_');
    }
    if(not defined $self->initializer)
    {
        $self->initializer(
            AI::MXNet::Xavier->new(
                factor_type => 'in',
                magnitude => 2.34
            )
        );
    }
    # Wrap any plain initializer in a FusedRNN initializer so the single
    # fused parameter blob is initialized with the right per-gate layout.
    if(not $self->initializer->isa('AI::MXNet::FusedRNN'))
    {
        $self->initializer(
            AI::MXNet::FusedRNN->new(
                init => $self->initializer,
                num_hidden => $self->_num_hidden,
                num_layers => $self->_num_layers,
                mode => $self->_mode,
                bidirectional => $self->_bidirectional,
                forget_bias => $self->forget_bias
            )
        );
    }
    # All weights and biases live in one flat 'parameters' vector.
    $self->_parameter($self->params->get('parameters', init => $self->initializer));
    $self->_directions($self->_bidirectional ? [qw/l r/] : ['l']);
}
| |
| |
method state_info()
{
    # One state for rnn_*/gru, two (h and c) for lstm. Each state is
    # [layers*directions, batch, num_hidden] with layout LNC; the batch
    # dimension (0) is inferred at bind time.
    my $b = @{ $self->_directions };
    my $n = $self->_mode eq 'lstm' ? 2 : 1;
    return [map { +{ shape => [$b*$self->_num_layers, 0, $self->_num_hidden], __layout__ => 'LNC' } } 0..$n-1];
}
| |
method _gate_names()
{
    # Per-mode gate-name suffixes, matching the fused parameter layout.
    my %gate_names_for = (
        rnn_relu => [''],
        rnn_tanh => [''],
        lstm     => [qw/_i _f _c _o/],
        gru      => [qw/_r _z _o/],
    );
    return $gate_names_for{ $self->_mode };
}
| |
method _num_gates()
{
    # Number of gate groups in this mode's fused parameter blob.
    my $count = scalar @{ $self->_gate_names };
    return $count;
}
| |
method _slice_weights($arr, $li, $lh)
{
    # Carve the flat fused parameter vector $arr into named per-layer,
    # per-direction, per-gate weight and bias views.
    # $li is the input size of layer 0; $lh is the hidden size.
    # NOTE: the returned slices are views into $arr, not copies.
    my %args;
    my @gate_names = @{ $self->_gate_names };
    my @directions = @{ $self->_directions };

    my $b = @directions;
    my $p = 0;
    # Weights come first in the blob: for each layer and direction, all
    # i2h gate matrices followed by all h2h gate matrices.
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size;
                if($layer > 0)
                {
                    # Layers above the first see the concatenated output of
                    # all directions of the previous layer ($b*$lh inputs).
                    $size = $b*$lh*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $b*$lh]);
                }
                else
                {
                    $size = $li*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $li]);
                }
                $p += $size;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size = $lh**2;
                $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $lh]);
                $p += $size;
            }
        }
    }
    # Biases follow all the weights, in the same traversal order.
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
        }
    }
    # Every element of the blob must be accounted for.
    assert($p == $arr->size, "Invalid parameters size for FusedRNNCell");
    return %args;
}
| |
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    # Remove the single fused blob; it is replaced by named slices.
    my $arr = delete $args{ $self->_parameter->name };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my $h = $self->_num_hidden;
    # Recover the layer-0 input size from the total blob size by inverting
    # the size formula used in pack_weights.
    my $num_input = int(int(int($arr->size/$b)/$h)/$m) - ($self->_num_layers - 1)*($h+$b*$h+2) - $h - 2;
    my %nargs = $self->_slice_weights($arr, $num_input, $self->_num_hidden);
    # Copy each view so the unpacked arrays are independent of the blob.
    %args = (%args, map { $_ => $nargs{$_}->copy } keys %nargs);
    return \%args
}
| |
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my @c = @{ $self->_gate_names };
    my $h = $self->_num_hidden;
    # Infer the layer-0 input size from the first i2h weight's shape.
    my $w0 = $args{ sprintf('%sl0_i2h%s_weight', $self->_prefix, $c[0]) };
    my $num_input = $w0->shape->[1];
    # Total blob size: layer 0 plus (num_layers-1) stacked layers, each
    # counting weights and the two bias vectors per gate.
    my $total = ($num_input+$h+2)*$h*$m*$b + ($self->_num_layers-1)*$m*$h*($h+$b*$h+2)*$b;
    my $arr = AI::MXNet::NDArray->zeros([$total], ctx => $w0->context, dtype => $w0->dtype);
    # _slice_weights returns views into $arr, so the in-place copy (.=)
    # below writes each named array into its slot of the blob.
    my %nargs = $self->_slice_weights($arr, $num_input, $h);
    while(my ($name, $nd) = each %nargs)
    {
        $nd .= delete $args{ $name };
    }
    $args{ $self->_parameter->name } = $arr;
    return \%args;
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # The fused kernel consumes whole sequences; per-step calls are unsupported.
    confess("AI::MXNet::RNN::FusedCell cannot be stepped. Please use unroll");
}
| |
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    $self->reset;
    # Position of the time axis in the layout string (0 for TNC, 1 for NTC).
    my $axis = index($layout, 'T');
    # The fused RNN op consumes one TNC symbol, not a list of step symbols.
    $inputs //= AI::MXNet::Symbol->Variable("${input_prefix}data");
    if(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        if($axis == 1)
        {
            # Fused kernel wants time-major input; swap batch and time.
            AI::MXNet::Logging->warning(
                "NTC layout detected. Consider using "
                ."TNC for RNN::FusedCell for faster speed"
            );
            $inputs = AI::MXNet::Symbol->SwapAxis($inputs, dim1 => 0, dim2 => 1);
        }
        else
        {
            assert($axis == 0, "Unsupported layout $layout");
        }
    }
    else
    {
        # A list of per-step symbols: stack them into one time-major symbol.
        assert(@$inputs == $length);
        $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => 0) } @{ $inputs }];
        $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => 0);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my @states = @{ $states };
    my %states;
    # LSTM carries both h and c; the other modes carry h only.
    if($self->_mode eq 'lstm')
    {
        %states = (state => $states[0], state_cell => $states[1]);
    }
    else
    {
        %states = (state => $states[0]);
    }
    my $rnn = AI::MXNet::Symbol->RNN(
        data => $inputs,
        parameters => $self->_parameter,
        state_size => $self->_num_hidden,
        num_layers => $self->_num_layers,
        bidirectional => $self->_bidirectional,
        p => $self->_dropout,
        state_outputs => $self->_get_next_state,
        mode => $self->_mode,
        name => $self->_prefix.'rnn',
        %states
    );
    # Untangle the op's outputs depending on whether state outputs were
    # requested and on the mode (LSTM has an extra cell-state output).
    my $outputs;
    my %attr = (__layout__ => 'LNC');
    if(not $self->_get_next_state)
    {
        ($outputs, $states) = ($rnn, []);
    }
    elsif($self->_mode eq 'lstm')
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        $rnn[2]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1], $rnn[2]]);
    }
    else
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1]]);
    }
    if(defined $merge_outputs and not $merge_outputs)
    {
        # Caller explicitly asked for per-step outputs: split the merged
        # time-major output back into $length step symbols.
        AI::MXNet::Logging->warning(
            "Call RNN::FusedCell->unroll with merge_outputs=1 "
            ."for faster speed"
        );
        $outputs = [@ {
            AI::MXNet::Symbol->SliceChannel(
                $outputs,
                axis => 0,
                num_outputs => $length,
                squeeze_axis => 1
            )
        }];
    }
    elsif($axis == 1)
    {
        # Swap back to the caller's NTC layout.
        $outputs = AI::MXNet::Symbol->SwapAxis($outputs, dim1 => 0, dim2 => 1);
    }
    return ($outputs, $states);
}
| |
| =head2 unfuse |
| |
| Unfuse the fused RNN |
| |
| Returns |
| ------- |
| $cell : AI::MXNet::RNN::SequentialCell |
| unfused cell that can be used for stepping, and can run on CPU. |
| =cut |
| |
method unfuse()
{
    my $stack = AI::MXNet::RNN::SequentialCell->new;
    # Factory producing the non-fused cell for this mode, given a prefix.
    my $get_cell = {
        rnn_relu => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'relu',
                prefix => shift
            )
        },
        rnn_tanh => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'tanh',
                prefix => shift
            )
        },
        lstm => sub {
            AI::MXNet::RNN::LSTMCell->new(
                num_hidden => $self->_num_hidden,
                prefix => shift
            )
        },
        gru => sub {
            AI::MXNet::RNN::GRUCell->new(
                num_hidden => $self->_num_hidden,
                prefix => shift
            )
        },
    }->{ $self->_mode };
    # One (bidirectional) cell per fused layer; the prefixes match the
    # names produced by _slice_weights so weights transfer directly.
    for my $i (0..$self->_num_layers-1)
    {
        if($self->_bidirectional)
        {
            $stack->add(
                AI::MXNet::RNN::BidirectionalCell->new(
                    $get_cell->(sprintf('%sl%d_', $self->_prefix, $i)),
                    $get_cell->(sprintf('%sr%d_', $self->_prefix, $i)),
                    output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $i)
                )
            );
        }
        else
        {
            $stack->add($get_cell->(sprintf('%sl%d_', $self->_prefix, $i)));
        }
    }
    return $stack;
}
| |
| package AI::MXNet::RNN::SequentialCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
AI::MXNet::RNN::SequentialCell
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Sequentially stacking multiple RNN cells |
| |
| Parameters |
| ---------- |
| params : AI::MXNet::RNN::Params or undef |
| container for weight sharing between cells. |
| created if undef. |
| =cut |
| |
# _override_cell_params : true when a params container was passed to new();
#   in that case its parameters override those of child cells added later.
# _cells : the ordered stack of child cells.
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    $self->_cells([]);
}
| |
| =head2 add |
| |
| Append a cell to the stack. |
| |
| Parameters |
| ---------- |
| $cell : AI::MXNet::RNN::Cell::Base |
| =cut |
| |
method add(AI::MXNet::RNN::Cell::Base $cell)
{
    push @{ $self->_cells }, $cell;
    if($self->_override_cell_params)
    {
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        # Push the stack's shared parameters down into the child cell.
        %{ $cell->params->_params } = (%{ $cell->params->_params }, %{ $self->params->_params });
    }
    # Mirror the child's parameters up into the stack's container so the
    # stack's params() covers every child.
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $cell->params->_params });
}
| |
method state_info()
{
    # Concatenation of every child cell's state descriptors, in stack order.
    return $self->_cells_state_info($self->_cells);
}
| |
method begin_state(@kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    # Flat list of every child cell's initial states, in stack order.
    return $self->_cells_begin_state($self->_cells, @kwargs);
}
| |
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to each child cell in turn.
    return $self->_cells_unpack_weights($self->_cells, $args)
}
| |
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to each child cell in turn.
    return $self->_cells_pack_weights($self->_cells, $args);
}
| |
method call($inputs, $states)
{
    # One step through the stack: each cell consumes its own slice of the
    # flat state list, and its output becomes the next cell's input.
    $self->_counter($self->_counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell (@{ $self->_cells })
    {
        # Bidirectional cells cannot be stepped one time step at a time.
        # BUGFIX: the class defined in this file is
        # AI::MXNet::RNN::BidirectionalCell; the old check used the
        # non-existent package AI::MXNet::BidirectionalCell and never fired.
        assert(not $cell->isa('AI::MXNet::RNN::BidirectionalCell'));
        my $n = scalar(@{ $cell->state_info });
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = $cell->($inputs, $state);
        push @next_states, $state;
    }
    # Flatten the per-cell state lists back into one flat list.
    return ($inputs, [map { @$_} @next_states]);
}
| |
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    # Unroll all stacked cells over $length time steps; the outputs of
    # cell i are fed as inputs to cell i+1.
    my $num_cells = @{ $self->_cells };
    $begin_state //= $self->begin_state;
    my $p = 0;
    my $states;
    my @next_states;
    enumerate(sub {
        my ($i, $cell) = @_;
        # Each cell consumes its own contiguous slice of the flat
        # begin_state list.
        my $n = @{ $cell->state_info };
        $states = [@{$begin_state}[$p..$p+$n-1]];
        $p += $n;
        # Only the last cell honours the caller's merge_outputs choice;
        # intermediate cells are left free to pick the cheaper form.
        ($inputs, $states) = $cell->unroll(
            $length,
            inputs => $inputs,
            input_prefix => $input_prefix,
            begin_state => $states,
            layout => $layout,
            merge_outputs => ($i < $num_cells-1) ? undef : $merge_outputs
        );
        push @next_states, $states;
    }, $self->_cells);
    # Flatten the per-cell state lists back into one flat list.
    return ($inputs, [map { @{ $_ } } @next_states]);
}
| |
| package AI::MXNet::RNN::BidirectionalCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::BidirectionalCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Bidirectional RNN cell |
| |
| Parameters |
| ---------- |
| l_cell : AI::MXNet::RNN::Cell::Base |
| cell for forward unrolling |
| r_cell : AI::MXNet::RNN::Cell::Base |
| cell for backward unrolling |
| output_prefix : str, default 'bi_' |
| prefix for name of output |
| =cut |
| |
# Forward (l_cell) and backward (r_cell) child cells; both required.
has 'l_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has 'r_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
# Prefix used for the names of the concatenated output symbols.
has '_output_prefix' => (is => 'ro', init_arg => 'output_prefix', isa => 'Str', default => 'bi_');
# Same bookkeeping as SequentialCell: shared-params flag and child list.
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);
| |
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    # Support positional construction: ->new($l_cell, $r_cell, %kwargs).
    if(@args >= 2 and blessed $args[0] and blessed $args[1])
    {
        my ($left, $right) = splice(@args, 0, 2);
        return $class->$orig(
            l_cell => $left,
            r_cell => $right,
            @args
        );
    }
    return $class->$orig(@args);
};
| |
sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    if($self->_override_cell_params)
    {
        # A shared container was supplied: children must not own params,
        # and both receive the shared entries.
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
        );
        %{ $self->l_cell->params->_params } = (%{ $self->l_cell->params->_params }, %{ $self->params->_params });
        %{ $self->r_cell->params->_params } = (%{ $self->r_cell->params->_params }, %{ $self->params->_params });
    }
    # Mirror both children's parameters into this cell's container so all
    # weights are reachable through a single params object.
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->l_cell->params->_params });
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->r_cell->params->_params });
    $self->_cells([$self->l_cell, $self->r_cell]);
}
| |
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to the shared multi-cell helper.
    return $self->_cells_unpack_weights($self->_cells, $args);
}
| |
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to the shared multi-cell helper.
    my $packed = $self->_cells_pack_weights($self->_cells, $args);
    return $packed;
}
| |
method call($inputs, $states)
{
    # Single-step execution is undefined for a bidirectional cell: the
    # backward direction needs the whole sequence up front.
    confess('Bidirectional cannot be stepped. Please use unroll');
}
| |
method state_info()
{
    # Forward cell's states followed by the backward cell's states.
    my $info = $self->_cells_state_info($self->_cells);
    return $info;
}
| |
method begin_state(@kwargs)
{
    # Refuse direct use once a modifier cell has wrapped this one.
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}
| |
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{

    # Unroll forward and backward cells over the sequence, then
    # concatenate their outputs along the hidden dimension.
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        # No inputs given: create one placeholder Variable per time step.
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        # A single symbol: split it along the time axis into $length slices.
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = [ @{ AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis => $axis,
            num_outputs => $length,
            squeeze_axis => 1
        ) }];
    }
    else
    {
        # Already a list of per-step symbols: only validate its length.
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my ($l_cell, $r_cell) = @{ $self->_cells };
    # Forward pass consumes the first block of the flat state list.
    my ($l_outputs, $l_states) = $l_cell->unroll(
        $length, inputs => $inputs,
        begin_state => [@{$states}[0..@{$l_cell->state_info}-1]],
        layout => $layout,
        merge_outputs => $merge_outputs
    );
    # Backward pass runs on the time-reversed inputs with the remainder.
    my ($r_outputs, $r_states) = $r_cell->unroll(
        $length, inputs => [reverse @{$inputs}],
        begin_state => [@{$states}[@{$l_cell->state_info}..@{$states}-1]],
        layout => $layout,
        merge_outputs => $merge_outputs
    );
    if(not defined $merge_outputs)
    {
        # Merge only if both directions returned single merged symbols;
        # otherwise normalize both to per-step lists.
        $merge_outputs = (
            blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol')
            and
            blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol')
        );
        if(not $merge_outputs)
        {
            if(blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol'))
            {
                $l_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $l_outputs, axis => $axis,
                        num_outputs => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
            if(blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol'))
            {
                $r_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $r_outputs, axis => $axis,
                        num_outputs => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
        }
    }
    if($merge_outputs)
    {
        $l_outputs = [@{ $l_outputs }];
        # Backward outputs are in reverse time order; flip them back along T.
        $r_outputs = [@{ AI::MXNet::Symbol->reverse(blessed $r_outputs ? $r_outputs : @{ $r_outputs }, axis=>$axis) }];
    }
    else
    {
        $r_outputs = [reverse(@{ $r_outputs })];
    }
    my $outputs = [];
    # Concatenate forward/backward outputs: per step along dim 1 when
    # unmerged, or along dim 2 on the single merged symbol.
    for(zip([0..@{ $l_outputs }-1], [@{ $l_outputs }], [@{ $r_outputs }])) {
        my ($i, $l_o, $r_o) = @$_;
        push @$outputs, AI::MXNet::Symbol->Concat(
            $l_o, $r_o, dim=>(1+($merge_outputs?1:0)),
            name => $merge_outputs
            ? sprintf('%sout', $self->_output_prefix)
            : sprintf('%st%d', $self->_output_prefix, $i)
        );
    }
    if($merge_outputs)
    {
        $outputs = @{ $outputs }[0];
    }
    # NOTE(review): states are returned nested as [l_states, r_states]
    # rather than flattened like state_info implies — confirm callers
    # expect the nested form.
    $states = [$l_states, $r_states];
    return($outputs, $states);
}
| |
| package AI::MXNet::RNN::ConvCell::Base; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
AI::MXNet::RNN::ConvCell::Base
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Abstract base class for Convolutional RNN cells |
| |
| =cut |
| |
# State-to-state (h2h) convolution geometry; _h2h_pad is derived in BUILD
# so that the spatial state shape is preserved across steps.
has '_h2h_kernel' => (is => 'ro', isa => 'Shape', init_arg => 'h2h_kernel');
has '_h2h_dilate' => (is => 'ro', isa => 'Shape', init_arg => 'h2h_dilate');
has '_h2h_pad' => (is => 'rw', isa => 'Shape', init_arg => undef);
# Input-to-state (i2h) convolution geometry.
has '_i2h_kernel' => (is => 'ro', isa => 'Shape', init_arg => 'i2h_kernel');
has '_i2h_stride' => (is => 'ro', isa => 'Shape', init_arg => 'i2h_stride');
has '_i2h_dilate' => (is => 'ro', isa => 'Shape', init_arg => 'i2h_dilate');
has '_i2h_pad' => (is => 'ro', isa => 'Shape', init_arg => 'i2h_pad');
# Number of output channels (per gate) and the fixed input shape used to
# infer the state shape in BUILD.
has '_num_hidden' => (is => 'ro', isa => 'DimSize', init_arg => 'num_hidden');
has '_input_shape' => (is => 'ro', isa => 'Shape', init_arg => 'input_shape');
has '_conv_layout' => (is => 'ro', isa => 'Str', init_arg => 'conv_layout', default => 'NCHW');
has '_activation' => (is => 'ro', init_arg => 'activation');
# Inferred state shape (batch dim zeroed); set once in BUILD.
has '_state_shape' => (is => 'rw', init_arg => undef);
has [qw/i2h_weight_initializer h2h_weight_initializer
    i2h_bias_initializer h2h_bias_initializer/] => (is => 'rw', isa => 'Maybe[Initializer]');
| |
sub BUILD
{
    my $self = shift;
    # Odd h2h kernels are required so symmetric padding can keep the
    # spatial state shape constant across time steps.
    assert (
        ($self->_h2h_kernel->[0] % 2 == 1 and $self->_h2h_kernel->[1] % 2 == 1),
        # BUGFIX: interpolate the kernel's elements, not the arrayref —
        # "@{[ $self->_h2h_kernel ]}" printed "ARRAY(0x...)".
        "Only support odd numbers, got h2h_kernel= (@{ $self->_h2h_kernel })"
    );
    # "Same" padding for the h2h convolution, accounting for dilation.
    $self->_h2h_pad([
        int($self->_h2h_dilate->[0] * ($self->_h2h_kernel->[0] - 1) / 2),
        int($self->_h2h_dilate->[1] * ($self->_h2h_kernel->[1] - 1) / 2)
    ]);
    # Infer state shape by building a throwaway i2h convolution and asking
    # MXNet for its output shape given the declared input shape.
    my $data = AI::MXNet::Symbol->Variable('data');
    my $state_shape = AI::MXNet::Symbol->Convolution(
        data => $data,
        num_filter => $self->_num_hidden,
        kernel => $self->_i2h_kernel,
        stride => $self->_i2h_stride,
        pad => $self->_i2h_pad,
        dilate => $self->_i2h_dilate,
        layout => $self->_conv_layout
    );
    $state_shape = ($state_shape->infer_shape(data=>$self->_input_shape))[1]->[0];
    # Zero the batch dimension so begin_state can broadcast to any batch.
    $state_shape->[0] = 0;
    $self->_state_shape($state_shape);
}
| |
method state_info()
{
    # Two recurrent states, both with the inferred convolutional shape
    # and this cell's convolution layout.
    my %info = (shape => $self->_state_shape, __layout__ => $self->_conv_layout);
    return [ { %info }, { %info } ];
}
| |
method call($inputs, $states)
{
    # Subclasses (ConvCell, ConvLSTMCell, ConvGRUCell) implement stepping.
    confess('AI::MXNet::RNN::ConvCell::Base is abstract class for convolutional RNN');
}
| |
| package AI::MXNet::RNN::ConvCell; |
| use Mouse; |
| extends 'AI::MXNet::RNN::ConvCell::Base'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ConvCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Convolutional RNN cells |
| |
| Parameters |
| ---------- |
| input_shape : array ref of int |
| Shape of input in single timestep. |
| num_hidden : int |
| Number of units in output symbol. |
| h2h_kernel : array ref of int, default (3, 3) |
| Kernel of Convolution operator in state-to-state transitions. |
| h2h_dilate : array ref of int, default (1, 1) |
| Dilation of Convolution operator in state-to-state transitions. |
| i2h_kernel : array ref of int, default (3, 3) |
| Kernel of Convolution operator in input-to-state transitions. |
| i2h_stride : array ref of int, default (1, 1) |
| Stride of Convolution operator in input-to-state transitions. |
| i2h_pad : array ref of int, default (1, 1) |
| Pad of Convolution operator in input-to-state transitions. |
| i2h_dilate : array ref of int, default (1, 1) |
| Dilation of Convolution operator in input-to-state transitions. |
activation : str or coderef,
    default is a sub applying AI::MXNet::Symbol->LeakyReLU(act_type => 'leaky', slope => 0.2)
| Type of activation function. |
prefix : str, default 'ConvRNN_'
    Prefix for name of layers (and name of weight if params is undef).
params : AI::MXNet::RNN::Params, default undef
    Container for weight sharing between cells. Created if undef.
conv_layout : str, default 'NCHW'
    Layout of the Convolution operator.
| =cut |
| |
# Convolution geometry defaults: 3x3 kernels, unit stride/dilation, and
# padding 1 (i.e. "same" spatial size for 3x3).
has '+_h2h_kernel' => (default => sub { [3, 3] });
has '+_h2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_kernel' => (default => sub { [3, 3] });
has '+_i2h_stride' => (default => sub { [1, 1] });
has '+_i2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_pad' => (default => sub { [1, 1] });
has '+_prefix' => (default => 'ConvRNN_');
# Default activation: leaky ReLU with slope 0.2 (mirrors the Python
# functools.partial default).
has '+_activation' => (default => sub { sub { AI::MXNet::Symbol->LeakyReLU(@_, act_type => 'leaky', slope => 0.2) } });
has '+i2h_bias_initializer' => (default => 'zeros');
has '+h2h_bias_initializer' => (default => 'zeros');
# When set, the i2h bias is initialized via AI::MXNet::LSTMBias in BUILD.
has 'forget_bias' => (is => 'ro', isa => 'Num');
# Weight/bias parameter symbols fetched from the params container in BUILD.
has [qw/_iW _iB
        _hW _hB/] => (is => 'rw', init_arg => undef);
| |
| |
sub BUILD
{
    my $self = shift;
    # Fetch (or create) the convolution weights/biases from the shared
    # parameter container.
    # NOTE(review): '_params' and 'params' accessors are mixed below —
    # presumably equivalent on this class; confirm and unify.
    $self->_iW($self->_params->get('i2h_weight', init => $self->i2h_weight_initializer));
    $self->_hW($self->_params->get('h2h_weight', init => $self->h2h_weight_initializer));
    # NOTE(review): defined(...) wraps the whole conjunction, so this
    # effectively selects LSTMBias whenever forget_bias is defined
    # (i2h_bias_initializer defaults to 'zeros', making "not defined"
    # always false). The parenthesization looks unintentional — confirm
    # against the Python reference before changing.
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            (defined($self->forget_bias and not defined $self->i2h_bias_initializer)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : (init => $self->i2h_bias_initializer)
            )
        )
    );
    $self->_hB($self->_params->get('h2h_bias', init => $self->h2h_bias_initializer));
}
| |
method _num_gates()
{
    # One convolution output group per gate name.
    return scalar(@{ $self->_gate_names() });
}
| |
method _gate_names()
{
    # Plain ConvRNN has a single, unnamed gate.
    return [''];
}
| |
method _conv_forward($inputs, $states, $name)
{
    # Build the input-to-hidden and hidden-to-hidden convolutions shared
    # by all convolutional cell variants; one filter group per gate.
    my $num_filter = $self->_num_hidden * $self->_num_gates();
    my $input_conv = AI::MXNet::Symbol->Convolution(
        name       => "${name}i2h",
        data       => $inputs,
        num_filter => $num_filter,
        kernel     => $self->_i2h_kernel,
        stride     => $self->_i2h_stride,
        pad        => $self->_i2h_pad,
        dilate     => $self->_i2h_dilate,
        weight     => $self->_iW,
        bias       => $self->_iB
    );
    my $state_conv = AI::MXNet::Symbol->Convolution(
        name       => "${name}h2h",
        data       => $states->[0],
        num_filter => $num_filter,
        kernel     => $self->_h2h_kernel,
        stride     => [1, 1],
        pad        => $self->_h2h_pad,
        dilate     => $self->_h2h_dilate,
        weight     => $self->_hW,
        bias       => $self->_hB
    );
    return ($input_conv, $state_conv);
}
| |
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvRNN step: h' = act(conv_i2h(x) + conv_h2h(h)).
    $self->_counter($self->_counter + 1);
    my $prefix = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $prefix);
    my $out = $self->_get_activation($i2h + $h2h, $self->_activation, name => "${prefix}out");
    return ($out, [$out]);
}
| |
| package AI::MXNet::RNN::ConvLSTMCell; |
| use Mouse; |
| extends 'AI::MXNet::RNN::ConvCell'; |
# Forget-gate bias default of 1 (encourages remembering early in training).
has '+forget_bias' => (default => 1);
has '+_prefix' => (default => 'ConvLSTM_');
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ConvLSTMCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Convolutional LSTM network cell. |
| |
| Reference: |
| Xingjian et al. NIPS2015 |
| =cut |
| |
method _gate_names()
{
    # LSTM gate order: input, forget, candidate, output.
    my @gates = ('_i', '_f', '_c', '_o');
    return \@gates;
}
| |
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvLSTM step. $states is [h, c]; returns (h', [h', c']).
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    # The convolutions produce all four gates stacked along the channel
    # axis; split them in i/f/c/o order (see _gate_names).
    my $gates = $i2h + $h2h;
    my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel(
        $gates,
        num_outputs => 4,
        axis => index($self->_conv_layout, 'C'),
        name => "${name}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[0],
        act_type => "sigmoid",
        name => "${name}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[1],
        act_type => "sigmoid",
        name => "${name}f"
    );
    my $in_transform = $self->_get_activation(
        $slice_gates[2],
        $self->_activation,
        name => "${name}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[3],
        act_type => "sigmoid",
        name => "${name}o"
    );
    # c' = f * c + i * candidate
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * @{$states}[1],
        $in_gate * $in_transform,
        name => "${name}state"
    );
    # h' = o * act(c')
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate, $self->_get_activation($next_c, $self->_activation),
        name => "${name}out"
    );
    return ($next_h, [$next_h, $next_c]);
}
| |
| package AI::MXNet::RNN::ConvGRUCell; |
| use Mouse; |
| extends 'AI::MXNet::RNN::ConvCell'; |
# GRU-specific prefix for generated symbol names.
has '+_prefix' => (default => 'ConvGRU_');
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ConvGRUCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Convolutional GRU network cell. |
| =cut |
| |
method _gate_names()
{
    # GRU gate order: reset, update, candidate ("output").
    my @gates = ('_r', '_z', '_o');
    return \@gates;
}
| |
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvGRU step. $states is [h]; returns (h', [h']).
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    # Split the stacked r/z/candidate channels; the leftover third slice
    # re-uses $i2h/$h2h as the candidate contributions.
    my ($i2h_r, $i2h_z, $h2h_r, $h2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel($i2h, num_outputs => 3, name => "${name}_i2h_slice") };
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel($h2h, num_outputs => 3, name => "${name}_h2h_slice") };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid",
        name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid",
        name => "${name}_z_act"
    );
    # Candidate state: act(i2h + r * h2h)
    my $next_h_tmp = $self->_get_activation($i2h + $reset_gate * $h2h, $self->_activation, name => "${name}_h_act");
    # h' = (1 - z) * candidate + z * h
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * @{$states}[0],
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}
| |
| package AI::MXNet::RNN::ModifierCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::Cell::Base'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ModifierCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Base class for modifier cells. A modifier |
| cell takes a base cell, apply modifications |
| on it (e.g. Dropout), and returns a new cell. |
| |
| After applying modifiers the base cell should |
| no longer be called directly. The modifer cell |
| should be used instead. |
| =cut |
| |
# The wrapped cell this modifier decorates; it is marked _modified in BUILD.
has 'base_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
| |
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    # An odd argument count means the first positional arg is the base cell.
    return $class->$orig(base_cell => shift(@args), @args) if @args % 2;
    return $class->$orig(@args);
};
| |
sub BUILD
{
    my ($self) = @_;
    # Mark the wrapped cell so direct calls on it are rejected from now on.
    $self->base_cell->_modified(1);
}
| |
method params()
{
    # The modifier owns no parameters of its own; expose the base cell's.
    $self->_own_params(0);
    return $self->base_cell->params;
}
| |
method state_info()
{
    # Modifiers do not change the state layout of the wrapped cell.
    my $info = $self->base_cell->state_info;
    return $info;
}
| |
method begin_state(CodeRef :$init_sym=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    # Temporarily clear the modified flag so the base cell accepts the call.
    $self->base_cell->_modified(0);
    my $states = $self->base_cell->begin_state(func => $init_sym, @kwargs);
    $self->base_cell->_modified(1);
    return $states;
}
| |
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Weight packing is entirely handled by the wrapped cell.
    return $self->base_cell->unpack_weights($args);
}
| |
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Weight packing is entirely handled by the wrapped cell.
    return $self->base_cell->pack_weights($args);
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Concrete modifiers (Dropout/Zoneout/Residual) must override this.
    confess('Not Implemented');
}
| |
| package AI::MXNet::RNN::DropoutCell; |
| use Mouse; |
| extends 'AI::MXNet::RNN::ModifierCell'; |
# Dropout probabilities for the cell output and for each state (0 = off).
has [qw/dropout_outputs dropout_states/] => (is => 'ro', isa => 'Num', default => 0);
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::DropoutCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Apply the dropout on base cell |
| =cut |
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Step the wrapped cell, then apply dropout to outputs and/or states.
    my ($output, $new_states) = $self->base_cell->($inputs, $states);
    $output = AI::MXNet::Symbol->Dropout(data => $output, p => $self->dropout_outputs)
        if $self->dropout_outputs > 0;
    if($self->dropout_states > 0)
    {
        $new_states = [
            map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) }
                @{ $new_states }
        ];
    }
    return ($output, $new_states);
}
| |
| package AI::MXNet::RNN::ZoneoutCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::ModifierCell'; |
# Zoneout probabilities for outputs and states (0 = off).
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Previous step's output, blended into the next step's output mask.
has 'prev_output' => (is => 'rw', init_arg => undef);
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ZoneoutCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Apply Zoneout on base cell. |
| =cut |
| |
sub BUILD
{
    my $self = shift;
    # Zoneout is only defined for stepwise, unidirectional base cells.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. ".
        "Please unfuse first."
    );
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # NOTE(review): $self->_bidirectional is not an attribute of this class;
    # the Python reference checks the wrapped cell, suggesting this should
    # be $self->base_cell->_bidirectional. It is only evaluated when
    # base_cell is a SequentialCell (short-circuit) — confirm and fix.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::SequentialCell') or not $self->_bidirectional),
        "Bidirectional SequentialCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}
| |
method reset()
{
    # Clear the cached previous-step output as well as the base counters.
    $self->SUPER::reset;
    $self->prev_output(undef);
}
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Zoneout step: with probability p, keep the previous step's
    # output/state instead of the newly computed one.
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = $cell->($inputs, $states);
    # Keep-mask built by applying Dropout to a ones tensor shaped like
    # the value being masked.
    my $mask = sub {
        my ($p, $like) = @_;
        AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like(
                $like
            ),
            p => $p
        );
    };
    # First step has no previous output; use a broadcastable zeros stub.
    my $prev_output = $self->prev_output // AI::MXNet::Symbol->zeros(shape => [0, 0]);
    my $output = $p_outputs != 0
        ? AI::MXNet::Symbol->where(
            $mask->($p_outputs, $next_output),
            $next_output,
            $prev_output
        )
        : $next_output;
    my @states;
    if($p_states != 0)
    {
        # Blend each new state with its predecessor under its own mask.
        for(zip($next_states, $states)) {
            my ($new_s, $old_s) = @$_;
            push @states, AI::MXNet::Symbol->where(
                $mask->($p_states, $new_s),
                $new_s,
                $old_s
            );
        }
    }
    $self->prev_output($output);
    return ($output, @states ? \@states : $next_states);
}
| |
| package AI::MXNet::RNN::ResidualCell; |
| use Mouse; |
| use AI::MXNet::Base; |
| extends 'AI::MXNet::RNN::ModifierCell'; |
| |
| =head1 NAME |
| |
| AI::MXNet::RNN::ResidualCell |
| =cut |
| |
| =head1 DESCRIPTION |
| |
| Adds residual connection as described in Wu et al, 2016 |
| (https://arxiv.org/abs/1609.08144). |
| Output of the cell is output of the base cell plus input. |
| =cut |
| |
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Residual step: output = base_cell(inputs, states) + inputs.
    my ($output, $new_states) = $self->base_cell->($inputs, $states);
    my $summed = AI::MXNet::Symbol->elemwise_add($output, $inputs, name => $output->name.'_plus_residual');
    return ($summed, $new_states);
}
| |
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    # Unroll the wrapped cell, then add the residual connection to every
    # time step (or to the merged output symbol).
    $self->reset;
    # Temporarily clear the modified flag so the base cell accepts the call.
    $self->base_cell->_modified(0);
    my ($outputs, $states) = $self->base_cell->unroll($length, inputs=>$inputs, begin_state=>$begin_state,
                                                layout=>$layout, merge_outputs=>$merge_outputs);
    $self->base_cell->_modified(1);
    # Match the inputs' representation (merged symbol vs per-step list)
    # to whatever the base cell produced.
    $merge_outputs //= (blessed($outputs) and $outputs->isa('AI::MXNet::Symbol'));
    ($inputs) = _normalize_sequence($length, $inputs, $layout, $merge_outputs);
    if($merge_outputs)
    {
        $outputs = AI::MXNet::Symbol->elemwise_add($outputs, $inputs, name => $outputs->name . "_plus_residual");
    }
    else
    {
        my @temp;
        for(zip([@{ $outputs }], [@{ $inputs }])) {
            my ($output_sym, $input_sym) = @$_;
            push @temp, AI::MXNet::Symbol->elemwise_add($output_sym, $input_sym,
                            name=>$output_sym->name."_plus_residual");
        }
        $outputs = \@temp;
    }
    return ($outputs, $states);
}
| |
# Normalize a sequence input to either a per-step list (merge false) or a
# single time-axis-merged symbol (merge true), returning ($inputs, $axis).
# $in_layout describes the layout of $inputs when it differs from $layout.
func _normalize_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert((defined $inputs),
        "unroll(inputs=>undef) has been deprecated. ".
        "Please create input variables outside unroll."
    );

    my $axis = index($layout, 'T');
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    if(blessed($inputs))
    {
        if(not $merge)
        {
            # Single symbol but a per-step list is wanted: split along T.
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please "
                ."convert to list first or let unroll handle splitting"
            );
            $inputs = [ @{ AI::MXNet::Symbol->split(
                $inputs,
                axis => $in_axis,
                num_outputs => $length,
                squeeze_axis => 1
            ) }];
        }
    }
    else
    {
        assert(not defined $length or @$inputs == $length);
        if($merge)
        {
            # Per-step list but a merged symbol is wanted: re-add the time
            # axis to every step and concatenate along it.
            $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis=>$axis) } @{ $inputs }];
            $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim=>$axis);
            $in_axis = $axis;
        }
    }

    # Reconcile differing time-axis positions between input/target layouts.
    if(blessed($inputs) and $axis != $in_axis)
    {
        $inputs = AI::MXNet::Symbol->swapaxes($inputs, dim0=>$axis, dim1=>$in_axis);
    }
    return ($inputs, $axis);
}
| |
| 1; |