blob: df5540407219ff4dce6d2d3e70c2c0f8331647ca [file]
"""
learning rate scheduler, which adaptive changes the learning rate based on the
progress
"""
import logging
class LRScheduler(object):
    """Base class of a learning rate scheduler.

    Subclasses implement ``__call__`` to map training progress to a
    learning rate; this base only holds the starting rate.

    Attributes
    ----------
    base_lr : float
        the initial learning rate (default 0.01)
    """
    def __init__(self):
        # Default starting rate; optimizers typically overwrite this.
        self.base_lr = 0.01

    def __call__(self, num_update):
        """Return the learning rate for the given training progress.

        ``num_update`` is roughly the number of minibatches processed so
        far: it is non-decreasing and grows by at most one per call. More
        precisely it is an upper bound on the number of updates applied to
        any single weight/index (see
        https://github.com/dmlc/mxnet/issues/625).

        Parameters
        ----------
        num_update: int
            the maximal number of updates applied to a weight.
        """
        raise NotImplementedError("must override this")
class FactorScheduler(LRScheduler):
    """Reduce the learning rate by a fixed factor every ``step`` updates.

    After the weights have been updated ``n`` times, the learning rate is

        base_lr * factor^(floor(n/step))

    but never below ``stop_factor_lr``, at which point it is clamped.

    Parameters
    ----------
    step: int
        schedule learning rate after n updates
    factor: float
        the factor for reducing the learning rate
    stop_factor_lr: float
        lower bound for the learning rate; once reached it stays fixed
    """
    def __init__(self, step, factor=1, stop_factor_lr=1e-8):
        super(FactorScheduler, self).__init__()
        if step < 1:
            raise ValueError("Schedule step must be greater or equal than 1 round")
        if factor > 1.0:
            raise ValueError("Factor must be no more than 1 to make lr reduce")
        self.step = step
        self.factor = factor
        self.stop_factor_lr = stop_factor_lr
        # Progress already accounted for by past decays.
        self.count = 0

    def __call__(self, num_update):
        """Return the learning rate for the given training progress.

        Parameters
        ----------
        num_update: int
            the maximal number of updates applied to a weight.
        """
        # A while (not if) lets a resumed run (load_epoch) catch up through
        # several pending decays in a single call.
        while num_update > self.count + self.step:
            self.count += self.step
            decayed = self.base_lr * self.factor
            if decayed < self.stop_factor_lr:
                self.base_lr = self.stop_factor_lr
                logging.info("Update[%d]: now learning rate arrived at %0.5e, will not "
                             "change in the future", num_update, self.base_lr)
            else:
                self.base_lr = decayed
                logging.info("Update[%d]: Change learning rate to %0.5e",
                             num_update, self.base_lr)
        return self.base_lr
class MultiFactorScheduler(LRScheduler):
    """Reduce the learning rate by a factor at each milestone in a list.

    After the weights have been updated ``n`` times, the learning rate is

        base_lr * factor^(sum((step/n)<=1))   # step is an array

    i.e. one decay is applied for every milestone already passed.

    Parameters
    ----------
    step: list of int
        schedule learning rate after n updates
    factor: float
        the factor for reducing the learning rate
    """
    def __init__(self, step, factor=1):
        super(MultiFactorScheduler, self).__init__()
        assert isinstance(step, list) and len(step) >= 1
        # Milestones must be a strictly increasing list of positive ints.
        previous = None
        for milestone in step:
            if previous is not None and milestone <= previous:
                raise ValueError("Schedule step must be an increasing integer list")
            if milestone < 1:
                raise ValueError("Schedule step must be greater or equal than 1 round")
            previous = milestone
        self.step = step
        # Index of the next milestone not yet applied.
        self.cur_step_ind = 0
        self.factor = factor
        self.count = 0

    def __call__(self, num_update):
        """Return the learning rate for the given training progress.

        Parameters
        ----------
        num_update: int
            the maximal number of updates applied to a weight.
        """
        # A while (not if) lets a resumed run (load_epoch) catch up through
        # several pending decays in a single call.
        while self.cur_step_ind < len(self.step):
            if num_update <= self.step[self.cur_step_ind]:
                return self.base_lr
            self.count = self.step[self.cur_step_ind]
            self.cur_step_ind += 1
            self.base_lr *= self.factor
            logging.info("Update[%d]: Change learning rate to %0.5e",
                         num_update, self.base_lr)
        return self.base_lr