blob: b158981ecf9742e4cc18cc9b2c99c194f5a89067 [file] [log] [blame]
# pylint: disable=invalid-name
"""Learning rate scheduler."""
import math
import logging
class LearningRateScheduler(object):
"""Base class of learning rate scheduler."""
def __init__(self):
self.base_lr = 0.01
def __call__(self, iteration):
"""
Call to schedule current learning rate.
Parameters
----------
iteration: int
Current iteration count.
"""
raise NotImplementedError("must override this")
class FactorScheduler(LearningRateScheduler):
"""Reduce learning rate in factor.
Parameters
----------
step: int
Schedule learning rate after every round.
factor: float
Reduce learning rate factor.
"""
def __init__(self, step, factor=0.1):
super(FactorScheduler, self).__init__()
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor >= 1.0:
raise ValueError("Factor must be less than 1 to make lr reduce")
self.step = step
self.factor = factor
self.old_lr = self.base_lr
self.init = False
def __call__(self, iteration):
"""
Call to schedule current learning rate.
Parameters
----------
iteration: int
Current iteration count.
"""
if not self.init:
self.init = True
self.old_lr = self.base_lr
lr = self.base_lr * math.pow(self.factor, int(iteration / self.step))
if lr != self.old_lr:
self.old_lr = lr
logging.info("At Iteration [%d]: Swith to new learning rate %.5f",
iteration, lr)
return lr