# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
package AI::MXNet::Gluon::Trainer;
use AI::MXNet::NS;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use IO::File;
use Mouse;
=head1 NAME
AI::MXNet::Gluon::Trainer
=cut
=head1 DESCRIPTION
Applies an `Optimizer` on a set of Parameters. Trainer should
be used together with `autograd`.
Parameters
----------
params : AI::MXNet::Gluon::ParameterDict, hash ref, or array ref of Parameters
The set of parameters to optimize.
optimizer : str or Optimizer
The optimizer to use. See
http://mxnet.io/api/python/optimization/optimization.html#the-mxnet-optimizer-package
for a list of available optimizers.
optimizer_params : hash ref
Keyword arguments to be passed to the optimizer constructor. For example,
{learning_rate => 0.1}. All optimizers accept learning_rate, wd (weight decay),
clip_gradient, and lr_scheduler. See each optimizer's
constructor for a list of additional supported arguments.
kvstore : str or KVStore
kvstore type used for multi-GPU and distributed training. See help on
mx->kvstore->create for more information.
compression_params : hash ref
Specifies type of gradient compression and additional arguments depending
on the type of compression being used. For example, 2bit compression requires a threshold.
Arguments would then be {type => '2bit', threshold => 0.5}
See AI::MXNet::KVStore->set_gradient_compression method for more details on gradient compression.
update_on_kvstore : Bool, default undef
Whether to perform parameter updates on the kvstore. If undef, the trainer
chooses the more suitable option depending on the type of kvstore.
Properties
----------
learning_rate : float
The current learning rate of the optimizer. Given an optimizer object
$optimizer, its learning rate can be accessed as $optimizer->learning_rate.
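A minimal usage sketch (the network, data shapes, and hyperparameters here
are illustrative assumptions, not fixed by this module):

    use AI::MXNet qw(mx);
    use AI::MXNet::Gluon qw(gluon);

    my $net = gluon->nn->Dense(1);
    $net->initialize();

    # optimize all of the block's parameters with SGD
    my $trainer = gluon->Trainer(
        $net->collect_params, 'sgd',
        { learning_rate => 0.1, wd => 1e-4 }
    );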
=cut
has 'params' => (is => 'rw', isa => 'HashRef|ArrayRef|AI::MXNet::Gluon::ParameterDict');
has 'optimizer' => (is => 'ro', isa => 'Optimizer');
has 'optimizer_params' => (is => 'ro', isa => 'Maybe[HashRef]');
has 'compression_params' => (is => 'ro', isa => 'Maybe[HashRef]');
has 'kvstore' => (is => 'rw', isa => 'Maybe[KVStore]', default => 'device');
has 'update_on_kvstore' => (is => 'rw', isa => 'Maybe[Bool]');
has [qw/_scale _contexts
_kv_initialized
_param2idx
_kvstore_params
_contains_sparse
_params_to_init
_updaters
_optimizer/] => (is => 'rw', init_arg => undef);
around BUILDARGS => \&AI::MXNet::Base::process_arguments;
method python_constructor_arguments()
{
[qw/params optimizer optimizer_params kvstore compression_params update_on_kvstore/]
}
sub BUILD
{
my $self = shift;
my @params;
if(blessed $self->params)
{
@params = $self->params->values;
}
elsif(ref $self->params eq 'HASH')
{
@params = values %{ $self->params };
}
else
{
@params = @{ $self->params };
}
$self->params([]);
$self->_contains_sparse(0);
$self->_param2idx({});
for(enumerate(\@params))
{
my ($i, $param) = @$_;
if(not(blessed $param and $param->isa('AI::MXNet::Gluon::Parameter')))
{
confess(
"First argument must be a array or hash of Parameters, ".
"got list of [$param]."
);
}
$self->_param2idx->{ $param->name } = $i;
push @{ $self->params }, $param;
$param->_set_trainer($self);
if($param->stype ne 'default')
{
$self->_contains_sparse(1);
}
}
my $optimizer_params = $self->optimizer_params//{};
$self->_scale(delete $optimizer_params->{rescale_grad}//1);
$self->_contexts($self->_check_contexts);
$self->_init_optimizer($self->optimizer, $optimizer_params);
$self->_kvstore_params({
kvstore => $self->kvstore,
update_on_kvstore => $self->update_on_kvstore
});
$self->_kv_initialized(0);
$self->kvstore(undef);
$self->update_on_kvstore(undef);
$self->_params_to_init([]);
$self->_reset_kvstore();
}
method _check_contexts()
{
my $contexts;
for my $param (@{ $self->params })
{
my $ctx = $param->list_ctx;
assert(
(not defined $contexts or join('', @{ $contexts }) eq join('', @{ $ctx })),
"All Parameters must be initialized on the same set of contexts, ".
"but Parameter ${\ $param->name } is initialized on @{ $ctx//[] } while previous Parameters ".
"are initialized on @{ $contexts//[] }."
);
$contexts = $ctx;
}
return $contexts;
}
method _init_optimizer($optimizer, $optimizer_params)
{
my %param_dict = map { $_ => $self->params->[$_] } 0 .. @{ $self->params } - 1;
if(blessed $optimizer and $optimizer->isa('AI::MXNet::Optimizer'))
{
assert(
(not %{ $optimizer_params }),
"optimizer_params must be empty if optimizer is an instance of ".
"Optimizer instead of str"
);
$self->_optimizer($optimizer);
$self->_optimizer->param_dict(\%param_dict);
}
else
{
$self->_optimizer(
AI::MXNet::Optimizer->create(
$optimizer, param_dict => \%param_dict,
%{ $optimizer_params }
)
);
}
$self->_updaters([
map { AI::MXNet::Optimizer->get_updater($self->_optimizer) } @{ $self->_contexts }
]);
}
method _init_params()
{
assert(
$self->_kv_initialized,
"Cannot initialize parameters in KVStore ".
"when KVStore is not initialized."
);
my @params_to_init;
if($self->kvstore)
{
for my $param (@{ $self->_params_to_init })
{
if(@{ $param->_deferred_init })
{
push @params_to_init, $param;
}
else
{
my $param_arrays = $param->_check_and_get($param->_data, []);
my $idx = $self->_param2idx->{ $param->name };
$self->kvstore->init($idx, $param_arrays->[0]);
if($param->stype eq 'default')
{
$self->kvstore->pull($idx, out => $param_arrays, priority=>-$idx);
}
}
}
}
$self->_params_to_init(\@params_to_init);
}
method _reset_kvstore()
{
if($self->kvstore and $self->kvstore->type =~ /dist/)
{
confess("Cannot reset distributed KVStore.");
}
$self->_kv_initialized(0);
$self->kvstore(undef);
$self->update_on_kvstore(undef);
$self->_params_to_init([@{ $self->params }]);
}
method _init_kvstore()
{
my $config = $self->_kvstore_params;
my ($kvstore, $update_on_kvstore);
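# sparse gradients and/or weights require updates to happen on the kvstore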
if($self->_contains_sparse)
{
($kvstore, $update_on_kvstore) = AI::MXNet::Module::_create_sparse_kvstore($config->{kvstore});
# update_on_kvstore is set to False by the user
if(defined $config->{update_on_kvstore} and not $config->{update_on_kvstore})
{
confess(
"Cannot set update_on_kvstore to False when sparse ".
"gradients and/or sparse weights are present."
)
}
}
else
{
my %arg_arrays = map { $_->name => $_->data($self->_contexts->[0]) } @{ $self->params };
($kvstore, $update_on_kvstore) = AI::MXNet::Module::_create_kvstore(
$config->{kvstore}, scalar(@{$self->_contexts }), \%arg_arrays
);
if(defined $config->{update_on_kvstore})
{
$update_on_kvstore = $config->{update_on_kvstore};
}
}
if($kvstore)
{
if($self->compression_params)
{
$kvstore->set_gradient_compression($self->compression_params);
}
# kv->pull(row_sparse_grad) is not supported
if($kvstore->type =~ /dist/ and not $self->_contains_sparse)
{
$update_on_kvstore = 0;
}
if($update_on_kvstore)
{
# the optimizer should be set before kvstore init so that multi-precision state is created correctly
$kvstore->set_optimizer($self->_optimizer);
}
$self->kvstore($kvstore);
$self->update_on_kvstore($update_on_kvstore);
}
else
{
$self->kvstore(undef);
$self->update_on_kvstore(undef);
}
$self->_kv_initialized(1);
}
# Internal method to invoke pull operations on KVStore. If $full_idx is set to 1,
# $kv->pull is preferred instead of $kv->row_sparse_pull.
method _row_sparse_pull($parameter, $out, $row_id, $full_idx=0)
{
# initialize kv and params if not already
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
my $idx = $self->_param2idx->{ $parameter->name };
if($full_idx and not $self->kvstore->type =~ /dist/)
{
assert($row_id->size == $out->shape->[0]);
$self->kvstore->pull($idx, out => $out, priority => -$idx, ignore_sparse => 0);
}
else
{
$self->kvstore->row_sparse_pull($idx, out => $out, row_ids => $row_id, priority => -$idx);
}
}
=head2 step
Makes one step of parameter update. Should be called after
`autograd->backward()` and outside of `record()` scope.
For normal parameter updates, `step()` should be used, which internally calls
`allreduce_grads()` and then `update()`. However, if you need to get the reduced
gradients to perform certain transformation, such as in gradient clipping, then
you may want to manually call `allreduce_grads()` and `update()` separately.
Parameters
----------
$batch_size : Int
Batch size of data processed. Gradients will be normalized by `1/$batch_size`.
Set this to 1 if you normalized the loss manually with `$loss = mean($loss)`.
$ignore_stale_grad : Bool, optional, default=0
If true, ignores Parameters with stale gradients (gradients that have not
been updated by `backward` since the last step) and skips their update.
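A sketch of a typical iteration ($net, $loss_fn, $data, and $label are
assumed to exist, as in the constructor example above):

    my $batch_size = 32;
    my $loss;
    mx->autograd->record(sub {
        $loss = $loss_fn->($net->($data), $label);
    });
    $loss->backward;
    $trainer->step($batch_size);    # gradients are rescaled by 1/32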
=cut
method step(Int $batch_size, Bool $ignore_stale_grad=0)
{
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
$self->_optimizer->rescale_grad($self->_scale/$batch_size);
$self->_allreduce_grads();
$self->_update($ignore_stale_grad);
}
=head2 allreduce_grads
For each parameter, reduce the gradients from different contexts.
Should be called after `autograd->backward()`, outside of the `record()` scope,
and before `trainer->update()`.
For normal parameter updates, `step()` should be used, which internally calls
`allreduce_grads()` and then `update()`. However, if you need to get the reduced
gradients to perform certain transformation, such as in gradient clipping, then
you may want to manually call `allreduce_grads()` and `update()` separately.
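A sketch of the manual variant (assumes the trainer was created with
update_on_kvstore => 0; the global-norm computation is illustrative only):

    $loss->backward;
    $trainer->allreduce_grads;
    # inspect the reduced gradients, e.g. compute their global norm
    my $sq_sum = 0;
    for my $param ($net->collect_params->values)
    {
        next if $param->grad_req eq 'null';
        my $g = $param->list_grad->[0];
        $sq_sum += ($g * $g)->sum->asscalar;
    }
    my $global_norm = sqrt($sq_sum);
    $trainer->update($batch_size);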
=cut
method allreduce_grads()
{
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
assert(
(not ($self->kvstore and $self->update_on_kvstore)),
'allreduce_grads() when parameters are updated on kvstore '.
'is not supported. Try setting `update_on_kvstore` '.
'to False when creating trainer.'
);
$self->_allreduce_grads();
}
method _allreduce_grads()
{
if($self->kvstore)
{
for(enumerate($self->params))
{
my ($i, $param) = @$_;
if($param->grad_req ne 'null')
{
$self->kvstore->push($i, $param->list_grad(), priority=>-$i);
if(not $self->update_on_kvstore)
{
$self->kvstore->pull($i, out => $param->list_grad(), priority=>-$i);
}
}
}
}
}
method learning_rate(Maybe[Num] $lr=)
{
if(not blessed $self->_optimizer)
{
AI::MXNet::Logging->warning(
"Optimizer has to be defined before its learning ".
"rate can be accessed."
);
return;
}
else
{
if(defined $lr)
{
$self->_optimizer->lr($lr);
}
return $self->_optimizer->lr;
}
}
=head2 set_learning_rate
Sets a new learning rate of the optimizer.
Parameters
----------
lr : float
The new learning rate of the optimizer.
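For example, a simple step decay (a sketch; $epoch is a hypothetical loop
counter):

    # halve the learning rate every 10th epoch
    $trainer->set_learning_rate($trainer->learning_rate * 0.5)
        if $epoch and $epoch % 10 == 0;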
=cut
method set_learning_rate(Num $lr)
{
$self->learning_rate($lr);
}
=head2 update
Makes one step of parameter update.
Should be called after autograd->backward(), outside of the record() scope,
and after trainer->allreduce_grads().
For normal parameter updates, step() should be used, which internally calls
allreduce_grads() and then update(). However, if you need to get the reduced
gradients to perform certain transformation, such as in gradient clipping, then
you may want to manually call allreduce_grads() and update() separately.
Parameters
----------
$batch_size : Int
Batch size of data processed. Gradients will be normalized by `1/$batch_size`.
Set this to 1 if you normalized the loss manually with `$loss = mean($loss)`.
$ignore_stale_grad : Bool, optional, default=0
If true, ignores Parameters with stale gradients (gradients that have not
been updated by backward() since the last step) and skips their update.
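For instance, when only a sub-network ran in this iteration, stale gradients
can be skipped explicitly (a sketch; assumes update_on_kvstore => 0):

    $trainer->allreduce_grads;
    $trainer->update($batch_size, 1);    # second argument is $ignore_stale_grad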
=cut
method update(Int $batch_size, Bool $ignore_stale_grad=0)
{
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
assert(
(not ($self->kvstore and $self->update_on_kvstore)),
'update() when parameters are updated on kvstore '.
'is not supported. Try setting `update_on_kvstore` '.
'to False when creating trainer.'
);
$self->_optimizer->rescale_grad($self->_scale/$batch_size);
$self->_update($ignore_stale_grad);
}
method _update(Bool $ignore_stale_grad=0)
{
for(enumerate($self->params))
{
my ($i, $param) = @$_;
next if($param->grad_req eq 'null');
if(not $ignore_stale_grad)
{
for my $data (@{ $param->_check_and_get($param->_data, []) })
{
if(not $data->_fresh_grad)
{
AI::MXNet::Logging->warning(
"Gradient of Parameter '%s' on context %s has not been updated ".
"by backward since last `step`. This could mean a bug in your ".
"model that made it only use a subset of the Parameters (Blocks) ".
"for this iteration. If you are intentionally only using a subset, ".
"call step with ignore_stale_grad=True to suppress this ".
"warning and skip updating of Parameters with stale gradient",
$param->name, $data->context
);
}
}
}
if($self->kvstore and $self->update_on_kvstore)
{
if($param->stype eq 'default')
{
# 'row_sparse' parameters are not pulled immediately - they're pulled
# in `SparseBlock.sparse_forward`
$self->kvstore->pull($i, out => $param->list_data(), priority=>-$i);
}
next;
}
for(zip($self->_updaters, $param->list_data(), $param->list_grad()))
{
my ($upd, $arr, $grad) = @$_;
if(not $ignore_stale_grad or $arr->_fresh_grad)
{
$upd->($i, $grad, $arr);
$arr->_fresh_grad(0);
}
}
}
}
=head2 save_states
Saves trainer states (e.g. optimizer, momentum) to a file.
Parameters
----------
fname : str
Path to output states file.
=cut
method save_states(Str $fname)
{
assert(defined $self->_optimizer);
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
if($self->update_on_kvstore)
{
$self->kvstore->save_optimizer_states($fname, dump_optimizer=>1);
}
else
{
open(my $fh, '>', $fname) or Carp::confess("cannot open $fname: $!");
print $fh $self->_updaters->[0]->get_states(dump_optimizer => 1);
close($fh);
}
}
=head2 load_states
Loads trainer states (e.g. optimizer, momentum) from a file.
Parameters
----------
fname : str
Path to input states file.
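A checkpoint round trip might look like this (the file name is arbitrary):

    $trainer->save_states('trainer.states');
    # ... later, e.g. after recreating an identical trainer ...
    $trainer->load_states('trainer.states');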
=cut
method load_states(Str $fname)
{
$self->_init_kvstore() unless $self->_kv_initialized;
$self->_init_params() if scalar(@{ $self->_params_to_init });
if($self->update_on_kvstore)
{
$self->kvstore->load_optimizer_states($fname);
$self->_optimizer($self->kvstore->_updater->optimizer);
$self->_optimizer->param_dict({ map { $_->[0] => $_->[1] } enumerate($self->params) });
}
else
{
my $states = join('', IO::File->new($fname)->getlines);
for my $updater (@{ $self->_updaters })
{
$updater->set_states($states);
$updater->optimizer($self->_updaters->[0]->optimizer);
}
$self->_optimizer($self->_updaters->[0]->optimizer);
}
}
__PACKAGE__->AI::MXNet::NS::register('AI::MXNet::Gluon');
1;