squad multi-lamb
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 0297a1d..f358635 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -605,3 +605,27 @@
etas=etas,
name=name,
**kwargs)
+
+def multi_lamb_update(weights, grads, mean, var, temp_g, step_count,
+                      out=None, name=None, num_tensors=0, **kwargs):
+    """Fused LAMB update of multiple weight tensors in a single operator call."""
+ if not num_tensors:
+ num_tensors = len(weights)
+ temp_list = _flatten_list(zip(weights, grads, mean, var, temp_g))
+ return ndarray._internal._multi_lamb_update(*temp_list,
+ out=out,
+ num_tensors=num_tensors,
+ step_count=step_count,
+ name=name,
+ **kwargs)
+
+def multi_mp_lamb_update(weights, grads, mean, var, temp_g, weights32, step_count,
+                         out=None, name=None, num_tensors=0, **kwargs):
+    """Mixed-precision variant of ``multi_lamb_update``; ``weights32`` holds the
+    float32 master copies of the (typically float16) weights."""
+ if not num_tensors:
+ num_tensors = len(weights)
+ temp_list = _flatten_list(zip(weights, grads, mean, var, temp_g, weights32))
+ return ndarray._internal._multi_mp_lamb_update(*temp_list,
+ out=out,
+ num_tensors=num_tensors,
+ step_count=step_count,
+ name=name,
+ **kwargs)
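
The wrappers interleave the per-tensor arguments so that each weight is followed by its own grad and optimizer states, matching the `input_stride` layout the C++ operator expects. A small illustration of that flattening, using placeholder strings instead of NDArrays:

```python
# stand-ins for NDArrays, two tensors
weights, grads = ['w0', 'w1'], ['g0', 'g1']
mean, var, temp_g = ['m0', 'm1'], ['v0', 'v1'], ['t0', 't1']

# equivalent to _flatten_list(zip(...)) in the wrappers above
flat = [x for group in zip(weights, grads, mean, var, temp_g) for x in group]
assert flat == ['w0', 'g0', 'm0', 'v0', 't0',
                'w1', 'g1', 'm1', 'v1', 't1']  # input_stride == 5
```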
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index 42d418e..7f455f2 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -36,13 +36,14 @@
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2,
mp_lamb_update_phase1, mp_lamb_update_phase2)
+from ..ndarray.contrib import (multi_lamb_update, multi_mp_lamb_update)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array
__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
- 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
+ 'MultiLAMB', 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]
@@ -1044,6 +1045,95 @@
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)
+@register
+class MultiLAMB(Optimizer):
+    """The multi-tensor LAMB optimizer, which fuses the updates of several
+    weights into aggregated kernel calls.
+
+    For details of the update rule, see *Large Batch Optimization for Deep
+    Learning: Training BERT in 76 Minutes* (https://arxiv.org/abs/1904.00962).
+    """
+ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+ lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
+ super(MultiLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ self.bias_correction = bias_correction
+        self.aggregate_num = max(1, min(50,
+            int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', '50'))))
+
+ def create_state(self, index, weight):
+ stype = weight.stype
+ dtype = weight.dtype
+ return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype), # mean
+ zeros(weight.shape, weight.context, dtype=dtype, stype=stype), # variance
+ zeros(weight.shape, weight.context, dtype=dtype, stype=stype)) # temp_g
+
+ def _update_impl(self, index, weights, grads, states, multi_precision=False):
+ step_count = []
+ if not isinstance(index, (tuple, list)):
+ weights = [weights]
+ grads = [grads]
+ states = [states]
+ self._update_count(index)
+ step_count.append(self._index_update_count[index])
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+ else:
+            for i, (weight, grad) in enumerate(zip(weights, grads)):
+                assert isinstance(weight, NDArray)
+                assert isinstance(grad, NDArray)
+                # track update counts per parameter index, not per list position
+                self._update_count(index[i])
+                step_count.append(self._index_update_count[index[i]])
+            lr = self._get_lr(index[0])
+            wd = self._get_wd(index[0])
+
+ kwargs = {'learning_rate': lr, 'beta1': self.beta1, 'beta2': self.beta2,
+ 'epsilon': self.epsilon, 'wd': wd,
+ 'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
+ 'bias_correction': self.bias_correction,
+ 'rescale_grad': self.rescale_grad}
+
+ if self.clip_gradient:
+ kwargs['clip_gradient'] = self.clip_gradient
+
+ updated_tensors = 0
+ while updated_tensors < len(weights):
+ sidx = updated_tensors
+ eidx = min(updated_tensors + self.aggregate_num, len(weights))
+ if not multi_precision:
+ mean, var, temp_g = list(zip(*states[sidx:eidx]))
+ multi_lamb_update(weights[sidx:eidx],
+ grads[sidx:eidx],
+ mean, var, temp_g,
+ out=weights[sidx:eidx],
+ step_count=step_count[sidx:eidx],
+ **kwargs)
+            else:
+                # each mixed-precision state is (weights32, (mean, var, temp_g))
+                weights32, mean_var_g = zip(*states[sidx:eidx])
+                mean, var, temp_g = zip(*mean_var_g)
+                multi_mp_lamb_update(weights[sidx:eidx],
+                                     grads[sidx:eidx],
+                                     mean, var, temp_g,
+                                     weights32,
+                                     out=weights[sidx:eidx],
+                                     step_count=step_count[sidx:eidx],
+                                     **kwargs)
+
+ updated_tensors += self.aggregate_num
+
+ def update(self, index, weight, grad, state):
+ self._update_impl(index, weight, grad, state, multi_precision=False)
+
+ def update_multi_precision(self, index, weight, grad, state):
+ if not isinstance(index, (tuple, list)):
+ use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+ else:
+ use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+ self._update_impl(index, weight, grad, state,
+ multi_precision=use_multi_precision)
+
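
Because of the `@register` decorator, the new optimizer is reachable through the usual registry (the key is the lower-cased class name). A minimal usage sketch, assuming this patch is built into MXNet; the network definition is elided:

```python
import mxnet as mx

# 'multilamb' is the registry key @register derives from the MultiLAMB class name
opt = mx.optimizer.create('multilamb', learning_rate=0.00025,
                          bias_correction=True, wd=0.01)

# Aggregation width is capped at 50 tensors per fused call and can be lowered
# via the MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable.
# trainer = mx.gluon.Trainer(net.collect_params(), opt, update_on_kvstore=False)
```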
#
@register
class LBSGD(Optimizer):
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index abdf570..6eaec47 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -2141,36 +2141,60 @@
def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
-                      rtol=1e-4, atol=1e-5, compare_states=True):
+                      rtol=1e-4, atol=1e-5, compare_states=True, ntensors=1):
-    """Compare opt1 and opt2."""
+    """Compare opt1 and opt2. With ``ntensors > 1``, ``opt2`` performs one
+    aggregated multi-tensor update while ``opt1`` updates each tensor in turn."""
- if w_stype == 'default':
- w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
- w1 = w2.copyto(default_context())
- elif w_stype in ('row_sparse', 'csr'):
- w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
- w1 = w2.copyto(default_context()).tostype('default')
+ if ntensors == 1:
+ if w_stype == 'default':
+ w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+ w1 = w2.copyto(default_context())
+ elif w_stype in ('row_sparse', 'csr'):
+ w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
+ w1 = w2.copyto(default_context()).tostype('default')
+ else:
+ raise Exception("type not supported yet")
+ if g_stype == 'default':
+ g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+ g1 = g2.copyto(default_context())
+ elif g_stype in ('row_sparse', 'csr'):
+ g2 = rand_ndarray(shape, g_stype, dtype=dtype)
+ g1 = g2.copyto(default_context()).tostype('default')
+ else:
+ raise Exception("type not supported yet")
+
+ state1 = opt1.create_state_multi_precision(0, w1)
+ state2 = opt2.create_state_multi_precision(0, w2)
+ if compare_states:
+ compare_ndarray_tuple(state1, state2)
+
+ opt1.update_multi_precision(0, w1, g1, state1)
+ opt2.update_multi_precision(0, w2, g2, state2)
+ if compare_states:
+ compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
+ assert_almost_equal(w1, w2, rtol=rtol, atol=atol)
else:
- raise Exception("type not supported yet")
- if g_stype == 'default':
- g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
- g1 = g2.copyto(default_context())
- elif g_stype in ('row_sparse', 'csr'):
- g2 = rand_ndarray(shape, g_stype, dtype=dtype)
- g1 = g2.copyto(default_context()).tostype('default')
- else:
- raise Exception("type not supported yet")
+        # multi-tensor case: opt1 updates tensor by tensor, opt2 in one aggregated call
+        from copy import deepcopy
+        if not isinstance(shape, list):
+            w1 = [mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+                  for _ in range(ntensors)]
+            g1 = [mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+                  for _ in range(ntensors)]
+        else:
+            w1, g1 = [], []
+            for s in shape:
+                w1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
+                g1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
+        w1 = tuple(w1)
+        w2 = deepcopy(w1)
+        g1 = tuple(g1)
+        g2 = deepcopy(g1)
+        state2 = [opt2.create_state_multi_precision(i, w2[i]) for i in range(ntensors)]
- state1 = opt1.create_state_multi_precision(0, w1)
- state2 = opt2.create_state_multi_precision(0, w2)
- if compare_states:
- compare_ndarray_tuple(state1, state2)
-
- opt1.update_multi_precision(0, w1, g1, state1)
- opt2.update_multi_precision(0, w2, g2, state2)
- if compare_states:
- compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
- assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
-
+        opt2.update_multi_precision(list(range(ntensors)), w2, g2, state2)
+        for i in range(ntensors):
+            state1 = opt1.create_state_multi_precision(i, w1[i])
+            opt1.update_multi_precision(i, w1[i], g1[i], state1)
+            if compare_states:
+                compare_ndarray_tuple(state1, state2[i], rtol, atol)
+            assert_almost_equal(w1[i], w2[i], rtol=rtol, atol=atol)
+
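
The `ntensors > 1` path can be exercised directly; a minimal sketch, assuming the MultiLAMB optimizer from this patch is available:

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import compare_optimizer

# opt1 walks the tensors one by one (single-tensor path), opt2 updates all
# four in one aggregated call; both are MultiLAMB, so the two paths must agree.
opt1 = mx.optimizer.MultiLAMB(learning_rate=0.01)
opt2 = mx.optimizer.MultiLAMB(learning_rate=0.01)
compare_optimizer(opt1, opt2, (16, 32), np.float32, ntensors=4)
```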
def same_symbol_structure(sym1, sym2):
"""Compare two symbols to check if they have the same computation graph structure.
diff --git a/src/operator/contrib/multi_lamb-inl.h b/src/operator/contrib/multi_lamb-inl.h
new file mode 100644
index 0000000..107a4c9
--- /dev/null
+++ b/src/operator/contrib/multi_lamb-inl.h
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lamb-inl.h
+ * \brief vectorized LAMB coefficient computed from sums of squared weights and grads
+ * \author Moises Hernandez
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
+#define MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
+
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <mxnet/op_attr_types.h>
+#include <mshadow/base.h>
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../tensor/util/tensor_util-inl.h"
+#include "multi_sum_sq-inl.h"
+
+namespace mxnet {
+namespace op {
+
+namespace multilamb {
+enum MultiLambUpdateResource {kTempSpace};
+} // namespace multilamb
+
+struct MultiLAMBParam : public dmlc::Parameter<MultiLAMBParam> {
+ float learning_rate;
+ float beta1;
+ float beta2;
+ float epsilon;
+ float wd;
+ float rescale_grad;
+ float lower_bound;
+ float upper_bound;
+ float clip_gradient;
+ bool bias_correction;
+ int num_tensors;
+ mxnet::Tuple<int> step_count;
+
+ DMLC_DECLARE_PARAMETER(MultiLAMBParam) {
+ DMLC_DECLARE_FIELD(learning_rate)
+ .set_default(0.001f)
+ .describe("Learning rate");
+ DMLC_DECLARE_FIELD(beta1)
+ .set_default(0.9f)
+ .describe("Exponential decay rate for the first moment estimates.");
+ DMLC_DECLARE_FIELD(beta2)
+ .set_default(0.999f)
+ .describe("Exponential decay rate for the second moment estimates.");
+ DMLC_DECLARE_FIELD(epsilon)
+ .set_default(1e-6f)
+ .describe("Small value to avoid division by 0.");
+ DMLC_DECLARE_FIELD(wd)
+ .set_default(0.0f)
+ .describe("Weight decay augments the objective function with a "
+ "regularization term that penalizes large weights. "
+ "The penalty scales with the square of the magnitude of each weight.");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Gradient rescaling factor");
+ DMLC_DECLARE_FIELD(lower_bound)
+ .set_default(1e-3f)
+ .describe("Lower limit of norm of weight.");
+ DMLC_DECLARE_FIELD(upper_bound)
+ .set_default(10.0f)
+ .describe("Upper limit of norm of weight.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(bias_correction)
+ .set_default(false)
+ .describe("Whether to use bias correction.");
+ DMLC_DECLARE_FIELD(step_count)
+ .describe("Step count for each tensor");
+ DMLC_DECLARE_FIELD(num_tensors)
+ .set_default(1)
+ .describe("Number of tensors");
+ }
+};
+
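
For reference, these fields parameterize the LAMB update rule of You et al., *Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes* (arXiv:1904.00962). Per tensor, writing `wd` as \(\lambda\), `learning_rate` as \(\eta\), and `lower_bound`/`upper_bound` as \(l\)/\(u\), the bias-corrected step is:

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad
r_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, w_{t-1}, \\
w_t = w_{t-1} - \eta\, \frac{\min(\max(\lVert w_{t-1}\rVert,\ l),\ u)}{\lVert r_t \rVert}\, r_t
```

where \(t\) is the per-tensor `step_count`; without `bias_correction` the hatted corrections are dropped, and the kernels replace the trust ratio by 1 when either norm is zero.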
+template<typename ParamType, int input_stride>
+inline bool MultiLAMB_InferShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors);
+ CHECK_EQ(out_attrs->size(), param.num_tensors);
+
+ bool all_inferred = true;
+ auto& input_shapes = *in_attrs;
+ auto& output_shapes = *out_attrs;
+
+  CHECK_EQ(param.step_count.ndim(), param.num_tensors)
+    << "Number of step counts is inconsistent with num_tensors. "
+    << "Expected number of step counts: "
+    << param.num_tensors << ", and got " << param.step_count.ndim();
+
+ // Weights, gradients, mean and variance
+ for (int i = 0; i < param.num_tensors; ++i) {
+ mxnet::ShapeVector input_vec;
+ mxnet::ShapeVector output_vec({output_shapes[i]});
+ for (int j = 0; j < input_stride; ++j) {
+ input_vec.push_back(input_shapes[i * input_stride + j]);
+ }
+ all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
+ }
+ return all_inferred;
+}
+
+template <typename ParamType, int input_stride>
+inline bool MP_MultiLAMB_InferType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors);
+ CHECK_EQ(out_attrs->size(), param.num_tensors);
+
+ bool all_inferred = true;
+ auto& input_types = *in_attrs;
+ auto& output_types = *out_attrs;
+
+ // weights, gradients
+ for (int i = 0; i < param.num_tensors; ++i) {
+ std::vector<int> input_vec;
+ std::vector<int> output_vec({output_types[i]});
+ for (int j = 0; j < 2; ++j) {
+ input_vec.push_back(input_types[i * input_stride + j]);
+ }
+ all_inferred = all_inferred &&
+ ElemwiseType<2, 1>(attrs, &input_vec, &output_vec);
+ }
+
+ // mean, var, temp_g, weights32 (master copies of weights)
+ for (int i = 0; i < param.num_tensors; ++i) {
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + 4, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1, mshadow::kFloat32);
+ }
+ return all_inferred;
+}
+
+template<typename T>
+class LAMB_type_identity {
+ public:
+ using type = T;
+};
+
+template<typename T>
+class LAMB_single_precision {
+ public:
+ using type = float;
+};
+
+template<typename DType, typename MPDType>
+struct MultiLAMBKernelParam {
+ static const int N = 50;
+ size_t ntensors;
+ size_t max_size;
+ size_t total_size;
+ size_t sizes[N];
+ DType* weights[N];
+ DType* grads[N];
+ MPDType* mean[N];
+ MPDType* var[N];
+ MPDType* temp_g[N];
+ MPDType* weights32[N];
+ DType* out_data[N];
+ int step_count[N];
+
+ // gpu
+ int chunk_size = 65536;
+ int nchunks;
+};
+
+template<typename xpu,
+ typename DType,
+ typename MPDType,
+ typename ParamType = MultiLAMBParam,
+ int input_stride = 5>
+void FillMultiLAMBKernelParam(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<TBlob> &outputs,
+ MultiLAMBKernelParam<DType, MPDType> *multi_param) {
+ const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
+ mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
+
+ multi_param->ntensors = p.num_tensors;
+ multi_param->total_size = 0;
+ multi_param->max_size = 0;
+ multi_param->nchunks = 0;
+
+ constexpr bool isSame = std::is_same<DType, MPDType>::value;
+ for (size_t i = 0; i < multi_param->ntensors; ++i) {
+ const auto idx = i * input_stride;
+ multi_param->sizes[i] = inputs[idx].shape_.Size();
+ multi_param->total_size += multi_param->sizes[i];
+ if (multi_param->max_size < multi_param->sizes[i])
+ multi_param->max_size = multi_param->sizes[i];
+
+ multi_param->weights[i] = inputs[idx].FlatTo2D<xpu, DType>(s).dptr_;
+ multi_param->grads[i] = inputs[idx + 1].FlatTo2D<xpu, DType>(s).dptr_;
+ multi_param->mean[i] = inputs[idx + 2].FlatTo2D<xpu, MPDType>(s).dptr_;
+ multi_param->var[i] = inputs[idx + 3].FlatTo2D<xpu, MPDType>(s).dptr_;
+ multi_param->temp_g[i] = inputs[idx + 4].FlatTo2D<xpu, MPDType>(s).dptr_;
+ // if mixed precision, then the last input in a set
+ // is 32-bit master copy of the weights
+ if (!isSame)
+ multi_param->weights32[i] = inputs[idx + input_stride - 1].FlatTo2D<xpu, MPDType>(s).dptr_;
+ multi_param->out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
+    multi_param->nchunks += (multi_param->sizes[i] + multi_param->chunk_size - 1)
+                            / multi_param->chunk_size;
+ }
+ memcpy(multi_param->step_count, p.step_count.begin(), multi_param->ntensors * sizeof(int));
+}
+
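
The chunk bookkeeping above is a per-tensor ceil-division; a quick sketch with hypothetical tensor sizes:

```python
# Hypothetical sizes; chunk_size matches the default in MultiLAMBKernelParam.
chunk_size = 65536
sizes = [1024 * 1024, 4096, 30522 * 1024]

# One GPU block per (tensor, chunk) pair, as in FillMultiLAMBKernelParam:
nchunks = sum((s + chunk_size - 1) // chunk_size for s in sizes)
print(nchunks)  # 16 + 1 + 477 = 494
```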
+using namespace mxnet_op;
+// forward declarations; definitions live in multi_lamb.cc (cpu) and multi_lamb.cu (gpu)
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<cpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType> &kernel_params,
+                  const MultiLAMBParam &param,
+                  int* block_to_tensor, int* block_to_chunk);
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<gpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType> &kernel_params,
+                  const MultiLAMBParam &param,
+                  int* block_to_tensor, int* block_to_chunk);
+
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<cpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType> &kernel_params,
+                  const MultiLAMBParam &param,
+                  float* r1, float* r2,
+                  int* block_to_tensor, int* block_to_chunk,
+                  const OpReqType req);
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<gpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType> &kernel_params,
+                  const MultiLAMBParam &param,
+                  float* r1, float* r2,
+                  int* block_to_tensor, int* block_to_chunk,
+                  const OpReqType req);
+
+template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
+inline void multiLAMB(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+  const auto& param = nnvm::get<MultiLAMBParam>(attrs.parsed);
+ Stream<xpu>* s = ctx.get_stream<xpu>();
+
+ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ using MPDType = typename MPTypeChooser<DType>::type;
+ MultiLAMBKernelParam<DType, MPDType> kernel_params;
+ FillMultiLAMBKernelParam<xpu, DType, MPDType, MultiLAMBParam, input_stride>
+ (attrs, ctx, inputs, outputs, &kernel_params);
+
+ // create vector of TBlob with all the weights contiguous
+ std::vector<TBlob> weights;
+ for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+ weights.emplace_back(inputs[index*input_stride]);
+ }
+ // create vector of TBlob with all the temp_g contiguous
+ std::vector<TBlob> temp_g;
+ for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+ temp_g.emplace_back(inputs[index*input_stride+4]);
+ }
+
+ // Calculate amount of temporary storage
+ size_t workspace_size = 2 * kernel_params.ntensors * sizeof(float) +
+ 2 * kernel_params.nchunks * sizeof(int);
+
+ // Request temporary storage
+ Tensor<xpu, 1, char> workspace =
+ ctx.requested[multilamb::kTempSpace].get_space_typed<xpu, 1, char>(
+ Shape1(workspace_size), s);
+
+ // Create tensors
+ size_t pos_wspace = 0;
+ Tensor<xpu, 1, float> r1(reinterpret_cast<float*>(&workspace[pos_wspace]),
+ Shape1(kernel_params.ntensors), s);
+ pos_wspace += kernel_params.ntensors * sizeof(float);
+ Tensor<xpu, 1, float> r2(reinterpret_cast<float*>(&workspace[pos_wspace]),
+ Shape1(kernel_params.ntensors), s);
+ pos_wspace += kernel_params.ntensors * sizeof(float);
+ Tensor<xpu, 1, int> block_to_tensor(reinterpret_cast<int*>(&workspace[pos_wspace]),
+ Shape1(kernel_params.nchunks), s);
+ pos_wspace += kernel_params.nchunks * sizeof(int);
+ Tensor<xpu, 1, int> block_to_chunk(reinterpret_cast<int*>(&workspace[pos_wspace]),
+ Shape1(kernel_params.nchunks), s);
+
+ MultiSumSqRun<xpu>(weights, kernel_params.ntensors, r1.dptr_, s);
+ call_kernel1<MPDType, DType>(s, kernel_params, param, block_to_tensor.dptr_,
+ block_to_chunk.dptr_);
+ MultiSumSqRun<xpu>(temp_g, kernel_params.ntensors, r2.dptr_, s);
+ call_kernel2<MPDType, DType>(s, kernel_params, param, r1.dptr_, r2.dptr_,
+ block_to_tensor.dptr_, block_to_chunk.dptr_,
+ req[0]);
+ });
+}
+
+template<typename xpu, bool MP>
+inline void multiLAMBUpdate(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
+ if (!MP) {
+ multiLAMB<xpu, LAMB_type_identity, 5>
+ (attrs, ctx, inputs, req, outputs);
+ } else {
+ multiLAMB<xpu, LAMB_single_precision, 6>
+ (attrs, ctx, inputs, req, outputs);
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
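
To summarize the header's control flow: phase 1 updates the moments and writes the raw step into `temp_g`, `MultiSumSqRun` reduces the squared norms of the weights and of `temp_g`, and phase 2 rescales the step by the trust ratio. A NumPy sketch of that flow for a single tensor without bias correction (a reference, not the production kernels):

```python
import numpy as np

def lamb_two_phase(w, g, mean, var, lr, wd=0.0, beta1=0.9, beta2=0.999,
                   eps=1e-6, lower=1e-3, upper=10.0):
    """Single-tensor reference of the two-phase split used by the kernels."""
    # phase 1 (kernel_step1): update moments, build the raw step temp_g
    mean[:] = beta1 * mean + (1.0 - beta1) * g
    var[:] = beta2 * var + (1.0 - beta2) * g * g
    temp_g = mean / (np.sqrt(var) + eps) + wd * w

    # norm reductions (MultiSumSqRun): r1 over weights, r2 over the raw step
    r1 = min(max(float(np.sqrt(np.sum(w * w))), lower), upper)
    r2 = float(np.sqrt(np.sum(temp_g * temp_g)))

    # phase 2 (kernel_step2): apply the trust ratio
    r = 1.0 if r1 == 0.0 or r2 == 0.0 else r1 / r2
    w -= lr * r * temp_g
    return w
```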
diff --git a/src/operator/contrib/multi_lamb.cc b/src/operator/contrib/multi_lamb.cc
new file mode 100644
index 0000000..ebfa7a2
--- /dev/null
+++ b/src/operator/contrib/multi_lamb.cc
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lamb.cc
+ * \brief vectorized LAMB coefficient computed from sums of squared weights and grads
+ * \author Moises Hernandez
+ */
+
+#include "./multi_lamb-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiLAMB_step1_kernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+ const float learning_rate,
+ const float beta1, const float beta2,
+ const float epsilon,
+ const float wd,
+ const float clip_gradient,
+ const bool bias_correction,
+ const float rescale_grad) {
+ using namespace mshadow_op;
+ for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+ if ((size_t)i < kernel_params.sizes[index]) {
+ MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]:
+ MPDType(kernel_params.weights[index][i]);
+ MPDType scaled_grad = static_cast<MPDType>(kernel_params.grads[index][i])*rescale_grad;
+ if (clip_gradient >= 0.0f)
+ scaled_grad = mshadow_op::clip::Map(scaled_grad, static_cast<MPDType>(clip_gradient));
+ MPDType mean = static_cast<MPDType>(beta1) * kernel_params.mean[index][i] +
+ (static_cast<MPDType>(1.0f) - static_cast<MPDType>(beta1)) * scaled_grad;
+ MPDType var = static_cast<MPDType>(beta2) * kernel_params.var[index][i] +
+ (static_cast<MPDType>(1.0f) - static_cast<MPDType>(beta2)) * scaled_grad * scaled_grad;
+ kernel_params.mean[index][i] = mean;
+ kernel_params.var[index][i] = var;
+
+ MPDType g;
+ if (bias_correction) {
+ MPDType mean_hat = mean / (static_cast<MPDType>(1.0f) -
+ power::Map(static_cast<MPDType>(beta1),
+ static_cast<MPDType>(kernel_params.step_count[index])));
+ MPDType var_hat = var / (static_cast<MPDType>(1.0f) -
+ power::Map(static_cast<MPDType>(beta2),
+ static_cast<MPDType>(kernel_params.step_count[index])));
+ g = mean_hat / (sqrt(var_hat) + epsilon) + wd * w;
+ } else {
+ g = mean / (sqrt(var) + epsilon) + wd * w;
+ }
+ kernel_params.temp_g[index][i] = g;
+ }
+ }
+ }
+};
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiLAMB_step2_kernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                                  const float* sumSqWeights,
+                                  const float* sumSqTempG,
+ const float learning_rate,
+ const float lower_bound,
+ const float upper_bound,
+ const OpReqType req) {
+ for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+ if ((size_t)i < kernel_params.sizes[index]) {
+ MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]:
+ MPDType(kernel_params.weights[index][i]);
+        float r1 = sqrt(sumSqWeights[index]);
+        float r2 = sqrt(sumSqTempG[index]);
+ r1 = std::min(std::max(r1, lower_bound), upper_bound);
+
+ // calculate lamb_trust_ratio
+ MPDType r;
+ if (r1 == 0.0f || r2 == 0.0f)
+ r = 1.0f;
+ else
+ r = r1/r2;
+
+ MPDType lr_adjusted = learning_rate * r;
+ w -= lr_adjusted * kernel_params.temp_g[index][i];
+
+ // update weights
+ if (has_mixed_precision)
+ kernel_params.weights32[index][i] = w;
+ KERNEL_ASSIGN(kernel_params.out_data[index][i], req, w);
+ }
+ }
+ }
+};
+
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<cpu>* s,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+ int* block_to_tensor,
+ int* block_to_chunk) {
+ Kernel<MultiLAMB_step1_kernel<MPDType, !std::is_same<DType, MPDType>::value>, cpu>::
+ Launch(s, kernel_params.max_size,
+ kernel_params,
+ param.learning_rate,
+ param.beta1, param.beta2,
+ param.epsilon,
+ param.wd,
+ param.clip_gradient,
+ param.bias_correction,
+ param.rescale_grad);
+}
+
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<cpu>* s,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+ float* r1, float* r2,
+ int* block_to_tensor,
+ int* block_to_chunk,
+ const OpReqType req) {
+ Kernel<MultiLAMB_step2_kernel<MPDType, !std::is_same<DType, MPDType>::value>, cpu>::
+ Launch(s, kernel_params.max_size,
+ kernel_params,
+ r1, r2,
+ param.learning_rate,
+ param.lower_bound, param.upper_bound,
+ req);
+}
+
+DMLC_REGISTER_PARAMETER(MultiLAMBParam);
+
+std::vector<std::string> LAMBParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+  std::vector<std::string> ret;
+  for (uint32_t i = 0; i < num_args; ++i) {
+    const auto idx = std::to_string(i);
+    for (size_t j = 0; j < nParams; ++j)
+      ret.push_back(std::string(pName[j]) + idx);  // e.g. "weight_0", "grad_0", ...
+  }
+
+  return ret;
+}
+
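
With two tensors, the non-mixed-precision name set expands as follows (a Python mirror of the loop above, for illustration only):

```python
names = ["weight_", "grad_", "mean_", "var_", "temp_g_"]
flat = [name + str(i) for i in range(2) for name in names]
# ['weight_0', 'grad_0', 'mean_0', 'var_0', 'temp_g_0',
#  'weight_1', 'grad_1', 'mean_1', 'var_1', 'temp_g_1']
```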
+inline uint32_t num_tensors(const nnvm::NodeAttrs& attrs) {
+ return static_cast<uint32_t>(dmlc::get<MultiLAMBParam>(attrs.parsed).num_tensors);
+}
+
+NNVM_REGISTER_OP(_multi_lamb_update)
+.describe(R"code(Compute the LAMB coefficients of multiple weights and grads"
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ return num_tensors(attrs) * 5;
+ })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ return num_tensors(attrs);
+ })
+.set_attr_parser(ParamParser<MultiLAMBParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MultiLAMB_InferShape<MultiLAMBParam, 5>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+    const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "temp_g_"};
+    return LAMBParamToVector(num_tensors(attrs), paramName,
+                             sizeof(paramName)/sizeof(paramName[0]));
+ })
+// mutable: mean, var, temp_g,
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ std::vector<uint32_t> ret;
+ const auto iMax = num_tensors(attrs);
+ for (size_t i = 0; i < iMax; ++i) {
+ ret.push_back(i * 5 + 2);
+ ret.push_back(i * 5 + 3);
+ ret.push_back(i * 5 + 4);
+ }
+ return ret;
+ })
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<FCompute>("FCompute<cpu>", multiLAMBUpdate<cpu, false>)
+.add_argument("data", "NDArray-or-Symbol[]", "data")
+.add_arguments(MultiLAMBParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_multi_mp_lamb_update)
+.describe(R"code(Compute the LAMB coefficients of multiple weights and grads with Mix Precision"
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ return num_tensors(attrs) * 6;
+ })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ return num_tensors(attrs);
+ })
+.set_attr_parser(ParamParser<MultiLAMBParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MultiLAMB_InferShape<MultiLAMBParam, 6>)
+.set_attr<nnvm::FInferType>("FInferType", MP_MultiLAMB_InferType<MultiLAMBParam, 6>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+    const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "temp_g_", "weight32_"};
+    return LAMBParamToVector(num_tensors(attrs), paramName,
+                             sizeof(paramName)/sizeof(paramName[0]));
+ })
+// mutable: mean, var, temp_g, weights32
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ [](const nnvm::NodeAttrs& attrs) {
+ std::vector<uint32_t> ret;
+ const auto iMax = num_tensors(attrs);
+ for (size_t i = 0; i < iMax; ++i) {
+ ret.push_back(i * 6 + 2);
+ ret.push_back(i * 6 + 3);
+ ret.push_back(i * 6 + 4);
+ ret.push_back(i * 6 + 5);
+ }
+ return ret;
+ })
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<FCompute>("FCompute<cpu>", multiLAMBUpdate<cpu, true>)
+.add_argument("data", "NDArray-or-Symbol[]", "data")
+.add_arguments(MultiLAMBParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
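
For a quick smoke test of the registered CPU path, the contrib wrapper can be driven directly; a sketch with arbitrary shapes and hyper-parameters, assuming this patch is built:

```python
import mxnet as mx

n, shape = 2, (8, 4)
weights = [mx.nd.random.uniform(shape=shape) for _ in range(n)]
grads = [mx.nd.random.uniform(shape=shape) for _ in range(n)]
mean = [mx.nd.zeros(shape) for _ in range(n)]
var = [mx.nd.zeros(shape) for _ in range(n)]
temp_g = [mx.nd.zeros(shape) for _ in range(n)]

# one fused call updates both tensors in place; mean/var/temp_g are mutated
mx.nd.contrib.multi_lamb_update(weights, grads, mean, var, temp_g,
                                step_count=[1, 1], learning_rate=0.01,
                                out=weights)
```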
diff --git a/src/operator/contrib/multi_lamb.cu b/src/operator/contrib/multi_lamb.cu
new file mode 100644
index 0000000..ed357d8
--- /dev/null
+++ b/src/operator/contrib/multi_lamb.cu
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file multi_lamb.cu
+ * \brief vectorized LAMB coefficient computed from sums of squared weights and grads
+ * \author Moises Hernandez
+ */
+
+#include "./multi_lamb-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define BLOCK_SIZE_LAMB 512
+#define ILP_LAMB 4
+
+template<bool has_mixed_precision, typename MPDType, typename DType>
+__global__ void kernel_step1(const MultiLAMBKernelParam<DType, MPDType> kernel_params,
+ const float learning_rate,
+ const float beta1, const float beta2,
+ const MPDType beta3, const MPDType beta4,
+ const float epsilon,
+ const float wd,
+ const float clip_gradient,
+ const bool bias_correction,
+ const float rescale_grad,
+ int* block_to_tensor,
+ int* block_to_chunk) {
+ const int tensorID = block_to_tensor[blockIdx.x];
+  const int chunkID = block_to_chunk[blockIdx.x];
+  const int startPos = chunkID * kernel_params.chunk_size + threadIdx.x;
+  const int stopPos = chunkID * kernel_params.chunk_size + kernel_params.chunk_size;
+
+ MPDType biascorrection1, biascorrection2;
+ if (bias_correction) {
+    biascorrection1 = 1.0f - powf(beta1, static_cast<float>(kernel_params.step_count[tensorID]));
+    biascorrection2 = 1.0f - powf(beta2, static_cast<float>(kernel_params.step_count[tensorID]));
+ } else {
+ biascorrection1 = 1.0f;
+ biascorrection2 = 1.0f;
+ }
+
+ MPDType r_weight[ILP_LAMB];
+ MPDType r_grad[ILP_LAMB];
+ MPDType r_mean[ILP_LAMB];
+ MPDType r_var[ILP_LAMB];
+ MPDType r_g[ILP_LAMB];
+
+  for (size_t i = startPos; i < stopPos && i < kernel_params.sizes[tensorID];
+       i += blockDim.x * ILP_LAMB) {
+#pragma unroll
+ for (int ii = 0; ii < ILP_LAMB; ii++) {
+ int load_pos = i + ii*blockDim.x;
+ if (load_pos < stopPos && load_pos < kernel_params.sizes[tensorID]) {
+ r_weight[ii] = has_mixed_precision ? kernel_params.weights32[tensorID][load_pos]:
+ static_cast<MPDType>(kernel_params.weights[tensorID][load_pos]);
+ r_grad[ii] = static_cast<MPDType>(kernel_params.grads[tensorID][load_pos]);
+ r_mean[ii] = kernel_params.mean[tensorID][load_pos];
+ r_var[ii] = kernel_params.var[tensorID][load_pos];
+ } else {
+ r_weight[ii] = static_cast<MPDType>(0);
+ r_grad[ii] = static_cast<MPDType>(0);
+ r_mean[ii] = static_cast<MPDType>(0);
+ r_var[ii] = static_cast<MPDType>(0);
+ }
+ }
+#pragma unroll
+ for (int ii = 0; ii < ILP_LAMB; ii++) {
+ r_grad[ii] = r_grad[ii] * rescale_grad;
+ if (clip_gradient >= 0.0f)
+ r_grad[ii] = max(min(r_grad[ii], clip_gradient), -clip_gradient);
+ r_mean[ii] = static_cast<MPDType>(beta1) * r_mean[ii] + beta3 * r_grad[ii];
+ r_var[ii] = static_cast<MPDType>(beta2) * r_var[ii] + beta4 * r_grad[ii] * r_grad[ii];
+ r_g[ii] = (r_mean[ii] / biascorrection1) / (sqrtf(r_var[ii] / biascorrection2) + epsilon)
+ + wd * r_weight[ii];
+ }
+#pragma unroll
+ for (int ii = 0; ii < ILP_LAMB; ii++) {
+ int store_pos = i + ii*blockDim.x;
+ if (store_pos < stopPos && store_pos < kernel_params.sizes[tensorID]) {
+ kernel_params.mean[tensorID][store_pos] = r_mean[ii];
+ kernel_params.var[tensorID][store_pos] = r_var[ii];
+ kernel_params.temp_g[tensorID][store_pos] = r_g[ii];
+ }
+ }
+ }
+}
+
+template<bool has_mixed_precision, typename MPDType, typename DType>
+__global__ void kernel_step2(const MultiLAMBKernelParam<DType, MPDType> kernel_params,
+                             const float* sumSqWeights,
+                             const float* sumSqTempG,
+ const float learning_rate,
+ const float lower_bound,
+ const float upper_bound,
+ int* block_to_tensor,
+ int* block_to_chunk,
+ const OpReqType req) {
+ const int tensorID = block_to_tensor[blockIdx.x];
+  const int chunkID = block_to_chunk[blockIdx.x];
+  const int startPos = chunkID * kernel_params.chunk_size + threadIdx.x;
+  const int stopPos = chunkID * kernel_params.chunk_size + kernel_params.chunk_size;
+
+  MPDType r1 = sqrtf(sumSqWeights[tensorID]);
+  MPDType r2 = sqrtf(sumSqTempG[tensorID]);
+ r1 = min(max(r1, lower_bound), upper_bound);
+
+ MPDType lr_adjusted;
+ if (r1 == 0.0f || r2 == 0.0f)
+ lr_adjusted = learning_rate;
+ else
+ lr_adjusted = learning_rate * r1/r2;
+
+ MPDType r_weight[ILP_LAMB];
+ MPDType r_g[ILP_LAMB];
+
+  for (size_t i = startPos; i < stopPos && i < kernel_params.sizes[tensorID];
+       i += blockDim.x * ILP_LAMB) {
+#pragma unroll
+    for (int ii = 0; ii < ILP_LAMB; ii++) {
+      int load_pos = i + ii*blockDim.x;
+      if (load_pos < stopPos && load_pos < kernel_params.sizes[tensorID]) {
+        r_weight[ii] = has_mixed_precision ? kernel_params.weights32[tensorID][load_pos]:
+                       static_cast<MPDType>(kernel_params.weights[tensorID][load_pos]);
+        r_g[ii] = kernel_params.temp_g[tensorID][load_pos];
+      } else {
+        // keep out-of-range lanes defined; their stores are masked below anyway
+        r_weight[ii] = static_cast<MPDType>(0);
+        r_g[ii] = static_cast<MPDType>(0);
+      }
+    }
+#pragma unroll
+ for (int ii = 0; ii < ILP_LAMB; ii++) {
+ r_weight[ii] -= lr_adjusted * r_g[ii];
+ }
+#pragma unroll
+ for (int ii = 0; ii < ILP_LAMB; ii++) {
+ int store_pos = i + ii*blockDim.x;
+ if (store_pos < stopPos && store_pos < kernel_params.sizes[tensorID]) {
+ if (has_mixed_precision)
+ kernel_params.weights32[tensorID][store_pos] = r_weight[ii];
+ KERNEL_ASSIGN(kernel_params.out_data[tensorID][store_pos], req, r_weight[ii]);
+ }
+ }
+ }
+}
+
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<gpu>* s,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+ int* block_to_tensor,
+ int* block_to_chunk) {
+ int nblocks = kernel_params.nchunks;
+  // build the block->(tensor, chunk) lookup tables on the host
+  std::vector<int> host_block2tensor(kernel_params.nchunks);
+  std::vector<int> host_block2chunk(kernel_params.nchunks);
+  int chunkID = 0;
+  for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+    int current_chunk = 0;
+    for (size_t j = 0; j < kernel_params.sizes[index]; j += kernel_params.chunk_size) {
+      host_block2tensor[chunkID] = static_cast<int>(index);
+      host_block2chunk[chunkID] = current_chunk;
+      current_chunk++;
+      chunkID++;
+    }
+  }
+  cudaMemcpyAsync(block_to_tensor, host_block2tensor.data(),
+                  kernel_params.nchunks*sizeof(int),
+                  cudaMemcpyHostToDevice, Stream<gpu>::GetStream(s));
+  cudaMemcpyAsync(block_to_chunk, host_block2chunk.data(),
+                  kernel_params.nchunks*sizeof(int),
+                  cudaMemcpyHostToDevice, Stream<gpu>::GetStream(s));
+  // the host vectors go out of scope when this function returns,
+  // so wait for the copies to finish first
+  cudaStreamSynchronize(Stream<gpu>::GetStream(s));
+
+ bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
+  MPDType beta3 = 1.0f - param.beta1;
+  MPDType beta4 = 1.0f - param.beta2;
+
+ if (has_mixed_precision)
+ kernel_step1<true><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+ kernel_params,
+ param.learning_rate,
+ param.beta1, param.beta2,
+ beta3, beta4,
+ param.epsilon, param.wd,
+ param.clip_gradient,
+ param.bias_correction,
+ param.rescale_grad,
+ block_to_tensor,
+ block_to_chunk);
+ else
+ kernel_step1<false><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+ kernel_params,
+ param.learning_rate,
+ param.beta1, param.beta2,
+ beta3, beta4,
+ param.epsilon, param.wd,
+ param.clip_gradient,
+ param.bias_correction,
+ param.rescale_grad,
+ block_to_tensor,
+ block_to_chunk);
+ }
+
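
The host-side loop above assigns one launch block per (tensor, chunk) pair; a sketch of the resulting tables for hypothetical sizes:

```python
# Mirror of the mapping built in call_kernel1, for illustration only.
chunk_size = 65536
sizes = [100000, 4096]

block_to_tensor, block_to_chunk = [], []
for tensor_id, size in enumerate(sizes):
    for chunk_id in range((size + chunk_size - 1) // chunk_size):
        block_to_tensor.append(tensor_id)
        block_to_chunk.append(chunk_id)

print(block_to_tensor)  # [0, 0, 1]
print(block_to_chunk)   # [0, 1, 0]
```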
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<gpu>* s,
+ const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+ float* r1, float* r2,
+ int* block_to_tensor,
+ int* block_to_chunk,
+ const OpReqType req) {
+ size_t nblocks = kernel_params.nchunks;
+ bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
+ if (has_mixed_precision)
+ kernel_step2<true><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+ kernel_params,
+ r1, r2,
+ param.learning_rate,
+ param.lower_bound, param.upper_bound,
+ block_to_tensor,
+ block_to_chunk,
+ req);
+ else
+ kernel_step2<false><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+ kernel_params,
+ r1, r2,
+ param.learning_rate,
+ param.lower_bound, param.upper_bound,
+ block_to_tensor,
+ block_to_chunk,
+ req);
+}
+
+
+NNVM_REGISTER_OP(_multi_lamb_update)
+.set_attr<FCompute>("FCompute<gpu>", multiLAMBUpdate<gpu, false>);
+
+NNVM_REGISTER_OP(_multi_mp_lamb_update)
+.set_attr<FCompute>("FCompute<gpu>", multiLAMBUpdate<gpu, true>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8f86bdf..44487d9 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -775,6 +775,104 @@
dtype, w_stype='default', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)
+# MultiLAMB
+class PyMultiLAMB(mx.optimizer.Optimizer):
+    """Python reference implementation of the multi-tensor LAMB optimizer."""
+ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+ lower_bound=1e-3, upper_bound=10.0, bias_correction=False,
+ multi_precision=False, clip_gradient=-1, **kwargs):
+ super(PyMultiLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ self.bias_correction = bias_correction
+ self.multi_precision = multi_precision
+ self.clip_gradient = clip_gradient
+
+ def create_state(self, index, weight):
+ return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),
+ mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),
+ mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))
+
+ def update(self, index, weight, grad, state):
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+ index_update_count = self._index_update_count[index]
+
+ grad *= self.rescale_grad
+ if self.clip_gradient >= 0:
+ grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+
+ mean, var, temp_g = state
+ mean[:] = self.beta1 * mean + (1. - self.beta1) * grad.astype(mean.dtype)
+ var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad.astype(mean.dtype))
+
+ r1 = weight.norm()
+ if self.lower_bound:
+ r1 = mx.nd.maximum(r1, self.lower_bound)
+ if self.upper_bound:
+ r1 = mx.nd.minimum(r1, self.upper_bound)
+
+ if self.bias_correction:
+ mean_hat = mean / (1. - mx.nd.power(self.beta1, index_update_count))
+ var_hat = var / (1. - mx.nd.power(self.beta2, index_update_count))
+ else:
+ mean_hat = mean
+ var_hat = var
+
+ temp_g[:] = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight
+ r2 = temp_g.norm()
+ # calculate lamb_trust_ratio
+ r = 1. if r1 == 0. or r2 == 0. else r1 / r2
+ lr *= r
+ weight[:] -= lr * temp_g
+
+@with_seed()
+def test_multilamb():
+ opt1 = PyMultiLAMB
+ opt2 = mx.optimizer.MultiLAMB
+
+    # shapes as in BERT-large
+    dims_x = [1024, 4096, 1024, 1024]
+    dims_y = [1, 1, 1024, 4096]
+    dims_occurrences = [9, 1, 4, 2]
+    nlayers = 4  # BERT-large uses 24; trimmed to keep the test fast
+    extra_dims_x = [30522, 512, 30522]
+    extra_dims_y = [1, 1024, 1024]
+    shapes = []
+    for _ in range(nlayers):
+        for i, (dx, dy) in enumerate(zip(dims_x, dims_y)):
+            for _ in range(dims_occurrences[i]):
+                shapes.append((dx, dy))
+    for dx, dy in zip(extra_dims_x, extra_dims_y):
+        shapes.append((dx, dy))
+
+ cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+ rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+ bias_options = [{'bias_correction': False}, {'bias_correction': True}]
+
+ for dtype in [np.float16, np.float32, np.float64]:
+ for cg_option in cg_options:
+ for rg_option in rg_options:
+ for wd_option in wd_options:
+ for bias_option in bias_options:
+ kwarg = {}
+ kwarg.update(cg_option)
+ kwarg.update(rg_option)
+ kwarg.update(wd_option)
+ kwarg.update(bias_option)
+                        if dtype == np.float16:
+                            kwarg.update({'multi_precision': True})
+                        atol = 1e-3
+                        rtol = 1e-6
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes, dtype,
+ rtol=rtol, atol=atol, ntensors=len(shapes))
+
# AdaMax
class PyAdamax(mx.optimizer.Optimizer):
"""The python reference of AdaMax optimizer.