blob: 2a0168b5c79db6486258185a7f154c9d8469d73e [file] [log] [blame]
import numpy as np
import mxnet as mx
import math
from mxnet.test_utils import *
class PyAdam(mx.optimizer.Optimizer):
    """Python reference implementation of the Adam optimizer.

    Used by ``test_adam`` to cross-check ``mx.optimizer.Adam``: both are given
    identical weights, gradients and hyper-parameters, and their states and
    weights are compared after one update.

    Parameters
    ----------
    learning_rate : float
        Base step size.
    beta1 : float
        Decay rate of the running mean of the gradient.
    beta2 : float
        Decay rate of the running mean of the squared gradient.
    epsilon : float
        Added to the denominator of the step for numerical stability.
    decay_factor : float
        Stored on the instance; not used by ``update``.
    **kwargs
        Forwarded to ``mx.optimizer.Optimizer`` (e.g. ``rescale_grad``,
        ``clip_gradient``).
    """

    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 decay_factor=(1 - 1e-8), **kwargs):
        super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.decay_factor = decay_factor

    def create_state(self, index, weight):
        """Create the optimizer state: zero-filled mean and variance arrays.

        Parameters
        ----------
        index : int
            Key of the parameter (unused here).
        weight : NDArray
            The weight data; the state arrays copy its shape, context and dtype.

        Returns
        -------
        tuple of NDArray
            ``(mean, variance)``.
        """
        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance

    def update(self, index, weight, grad, state):
        """Apply one Adam step to ``weight`` in place.

        Parameters
        ----------
        index : int
            A unique integer key used to index the parameters.
        weight : NDArray
            Weight array; updated in place.
        grad : NDArray
            Gradient array. It is not modified; a rescaled copy is used.
        state : tuple of NDArray
            ``(mean, variance)`` as returned by ``create_state``; both are
            updated in place.
        """
        lr = self._get_lr(index)
        self._update_count(index)
        t = self._index_update_count[index]
        mean, variance = state

        # Rescale into a fresh array: the clip and square below write into it
        # with ``out=``, and that must not clobber the caller's gradient.
        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)

        mean *= self.beta1
        mean += grad * (1. - self.beta1)
        variance *= self.beta2
        variance += (1 - self.beta2) * mx.nd.square(grad, out=grad)

        # Fold the bias corrections of both running means into the step size.
        coef1 = 1. - self.beta1**t
        coef2 = 1. - self.beta2**t
        lr *= math.sqrt(coef2) / coef1
        weight -= lr * mean / (mx.nd.sqrt(variance) + self.epsilon)

        # Weight decay is applied after the Adam step, to the updated weight.
        wd = self._get_wd(index)
        if wd > 0.:
            weight[:] -= (lr * wd) * weight
def test_lr_wd_mult():
    """Check that per-parameter lr_mult / wd_mult attributes reach the optimizer.

    fc1 is declared with lr_mult=0 while its explicit bias variable uses
    lr_mult=1.0, and fc2 is declared with wd_mult=0.5. After one optimizer
    update, fc1_weight must be unchanged while fc1_bias and fc2_weight must
    have moved.
    """
    data = mx.sym.Variable('data')
    fc1_bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)
    net = mx.sym.FullyConnected(data=data, bias=fc1_bias, name='fc1',
                                num_hidden=10, lr_mult=0)
    net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10,
                                wd_mult=0.5)

    mod = mx.mod.Module(symbol=net, label_names=None)
    mod.bind(data_shapes=[('data', (5, 10))])
    mod.init_params(initializer=mx.init.Uniform(1.0))
    mod.init_optimizer(optimizer_params={'learning_rate': 1.0})

    def snapshot():
        # Copy the module's current argument parameters out as numpy arrays.
        arg_params, _ = mod.get_params()
        return {name: arr.asnumpy() for name, arr in arg_params.items()}

    before = snapshot()
    batch_data = mx.random.uniform(low=-1.0, high=1.0, shape=(5, 10))
    mod.forward(mx.io.DataBatch(data=[batch_data], label=None), is_train=True)
    mod.backward(mod.get_outputs())
    mod.update()
    after = snapshot()

    optimizer = mod._optimizer
    assert optimizer.lr_mult == {'fc1_bias': 1.0, 'fc1_weight': 0.0}
    assert optimizer.wd_mult == {'fc2_bias': 0.5, 'fc2_weight': 0.5, 'fc1_bias': 0.0}

    # lr_mult=0 freezes fc1_weight; the remaining checked parameters must move.
    assert mx.test_utils.almost_equal(before['fc1_weight'], after['fc1_weight'], 1e-10)
    for name in ('fc1_bias', 'fc2_weight'):
        assert not mx.test_utils.almost_equal(before[name], after[name], 1e-1)
def _state_arrays(state):
    """Return an optimizer state as a list of arrays.

    Accepts ``None`` (no state), a tuple/list of arrays, or a single array, so
    optimizers with no state or a single state array can be compared too.
    """
    if state is None:
        return []
    if isinstance(state, (tuple, list)):
        return list(state)
    return [state]


def compare_optimizer(opt1, opt2, shape):
    """Check that two optimizers agree on state and weight after one step.

    Both optimizers start from identical random weights and gradients. Their
    initial states must match exactly. After a single ``update``, the states
    and the weights must agree to within a relative difference of 1e-5.

    Parameters
    ----------
    opt1, opt2 : mx.optimizer.Optimizer
        Optimizer instances to compare (e.g. a Python reference and the
        built-in implementation).
    shape : tuple of int
        Shape of the weight and gradient arrays.
    """
    w1 = mx.random.uniform(shape=shape)
    g1 = mx.random.uniform(shape=shape)
    # Each optimizer gets its own copies because updates happen in place.
    w2 = w1.copyto(mx.cpu())
    g2 = g1.copyto(mx.cpu())

    state1 = opt1.create_state(0, w1)
    state2 = opt2.create_state(0, w2)
    # The lists hold the same array objects as the states, so they observe the
    # in-place changes made by ``update`` below.
    states1 = _state_arrays(state1)
    states2 = _state_arrays(state2)
    # zip() would silently ignore extra state arrays on either side.
    assert len(states1) == len(states2)
    for s1, s2 in zip(states1, states2):
        assert same(s1.asnumpy(), s2.asnumpy())

    opt1.update(0, w1, g1, state1)
    opt2.update(0, w2, g2, state2)
    for s1, s2 in zip(states1, states2):
        assert reldiff(s1.asnumpy(), s2.asnumpy()) < 1e-5
    assert reldiff(w1.asnumpy(), w2.asnumpy()) < 1e-5
def test_adam():
    """Compare PyAdam with mx.optimizer.Adam for several hyper-parameter sets."""
    mx.random.seed(0)
    shape = (3, 4, 5)
    configs = (
        {},
        {'clip_gradient': 0.5},
        {'clip_gradient': 0.1},
        {'rescale_grad': 0.1},
    )
    for config in configs:
        reference = PyAdam(**config)
        builtin = mx.optimizer.Adam(**config)
        compare_optimizer(reference, builtin, shape)
if __name__ == '__main__':
    # Run this module's tests when executed as a script. test_adam seeds the
    # RNG itself, so running test_lr_wd_mult first does not affect it.
    test_lr_wd_mult()
    test_adam()