# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
import numpy as np
import mxnet as mx
from mxnet.test_utils import *
# * GroupAdaGrad
class PyGroupAdaGrad(mx.optimizer.Optimizer):
"""The python reference of Group AdaGrad optimizer.
Parameters
----------
eps: float, optional
Small value to avoid division by 0.
"""
def __init__(self, eps=1e-5, **kwargs):
super(PyGroupAdaGrad, self).__init__(**kwargs)
self.float_stable_eps = eps
def create_state(self, index, weight):
assert len(weight.shape) == 2
history = mx.nd.zeros(
(weight.shape[0], 1), weight.context, stype=weight.stype)
return history
def update(self, index, weight, grad, state):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
assert wd == 0
history = state
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mx.nd.mean(mx.nd.square(grad), axis=1, keepdims=True)
div = lr * grad / mx.nd.sqrt(history + self.float_stable_eps)
weight[:] -= div
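
# For reference, the update implemented by PyGroupAdaGrad above is the Group
# AdaGrad rule: the squared gradient is averaged over each row (group),
# accumulated into `history`, and the row is scaled by 1 / sqrt(history + eps).
# The sketch below runs one direct update step with the operator-backed
# contrib optimizer; the shapes and learning rate are arbitrary choices and
# the helper is illustrative only, not a test case.
def _group_adagrad_usage_sketch():
    opt = mx.optimizer.contrib.GroupAdaGrad(learning_rate=0.1)
    weight = mx.nd.random.uniform(shape=(3, 4))
    grad = mx.nd.random.uniform(shape=(3, 4))
    state = opt.create_state(0, weight)  # one accumulator value per row
    opt.update(0, weight, grad, state)
    return weight
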
def test_group_adagrad():
mx.random.seed(0)
opt1 = PyGroupAdaGrad
opt2 = mx.optimizer.contrib.GroupAdaGrad
shape = (3, 4)
eps_options = [{}, {'eps': 1e-8}]
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
for dtype in [np.float32]:
for options in itertools.product(eps_options, cg_options, rg_options):
kwarg = dict(wd=0.0)
for option in options:
kwarg.update(option)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
compare_states=False)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
w_stype='row_sparse',
g_stype='row_sparse',
compare_states=False)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
g_stype='row_sparse',
compare_states=False)
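
# Note on the helper used above: compare_optimizer (from mx.test_utils) draws
# a random weight/gradient pair in the requested storage types, runs a single
# update step with each of the two optimizers, and asserts that the resulting
# weights (and, when compare_states is True, the optimizer states) agree
# within tolerance. This is a summary of the helper's behavior; see
# mx.test_utils for the exact implementation.
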
def test_adamw():
shape = (3, 4)
weight = mx.nd.random.uniform(shape=shape)
weight_ref = weight.copy()
grad = mx.nd.random.uniform(shape=shape)
m = mx.nd.random.uniform(shape=shape)
v = mx.nd.random.uniform(shape=shape)
rescale_grad = mx.nd.array([10])
eta, lr, wd, epsilon = 1, 1, 0, 1e-8
beta1, beta2 = 0.9, 0.999
kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
'beta1': beta1, 'beta2': beta2}
# update is skipped for rescale = nan scalar
mx.nd.contrib.adamw_update(weight, grad, m, v,
np.nan, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
# update is skipped for rescale = 0
mx.nd.contrib.adamw_update(weight, grad, m, v,
rescale_grad * 0, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
# update is skipped for rescale = nan
mx.nd.contrib.adamw_update(weight, grad, m, v,
rescale_grad * np.nan, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
# update is skipped for rescale = inf
mx.nd.contrib.adamw_update(weight, grad, m, v,
rescale_grad * np.inf, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
# multi-precision update is skipped for rescale = nan
weight_fp16 = weight.astype('float16')
grad_fp16 = grad.astype('float16')
weight_fp16_ref = weight_fp16.copy()
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
rescale_grad * np.nan, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
# multi-precision update is skipped for rescale = nan scalar
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
np.nan, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
# multi-precision update is skipped for rescale = inf
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
rescale_grad * np.inf, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
# multi-precision update is skipped for rescale = 0
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
rescale_grad * 0, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
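    # The skip checks above mirror how adamw_update is used with dynamic loss
    # scaling: a rescale_grad of 0, NaN, or Inf typically signals a gradient
    # overflow, in which case the operator is expected to leave the weights
    # unchanged rather than apply a corrupted step.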
# reference normal update
grad_rescale = rescale_grad * grad
m_ref = beta1*m + (1-beta1)*grad_rescale
v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
    weight_ref = weight - eta * (lr * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
m_test = m.copy()
v_test = v.copy()
weight_test = weight.copy()
# op normal update
mx.nd.contrib.adamw_update(weight_test, grad, m_test, v_test,
rescale_grad, out=weight_test, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_test.asnumpy())
mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_test.asnumpy())
mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_test.asnumpy())
# reference normal multi-precision update
m_fp32 = m.copy()
v_fp32 = v.copy()
weight_fp32 = weight.copy()
grad_rescale = rescale_grad * grad_fp16.astype('float32')
m_ref = beta1*m_fp32 + (1-beta1)*grad_rescale
v_ref = beta2*v_fp32 + (1-beta2)*(grad_rescale**2)
    weight_ref = weight - eta * (lr * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
weight_fp16_ref = weight_ref.astype('float16')
# op normal multi-precision update
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m_fp32, v_fp32, weight_fp32,
rescale_grad, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_fp32.asnumpy())
mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_fp32.asnumpy())
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_fp32.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
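
# For reference, a NumPy transcription of the decoupled-weight-decay (AdamW)
# step that the assertions above verify: eta scales the whole step, lr scales
# only the Adam part, and wd is applied directly to the weight. This mirrors
# the in-test reference computation and is an illustrative helper, not a
# test case.
def _adamw_step_sketch(weight, grad, m, v, lr, eta, wd, beta1, beta2,
                       epsilon, rescale_grad):
    grad = rescale_grad * grad
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    weight = weight - eta * (lr * m / (np.sqrt(v) + epsilon) + wd * weight)
    return weight, m, v
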
if __name__ == '__main__':
import nose
nose.runmodule()