Add multi-tensor (aggregated) LAMB optimizer: MultiLAMB
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 0297a1d..f358635 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -605,3 +605,27 @@
                                                     etas=etas,
                                                     name=name,
                                                     **kwargs)
+
+def multi_lamb_update(weights, grads, mean, var, temp_g, step_count,
+                      out=None, name=None, num_tensors=0, **kwargs):
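+    """Apply a single fused LAMB update across a list of tensors.
+
+    ``num_tensors`` defaults to ``len(weights)`` when 0.
+    """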
+    if not num_tensors:
+        num_tensors = len(weights)
+    temp_list = _flatten_list(zip(weights, grads, mean, var, temp_g))
+    return ndarray._internal._multi_lamb_update(*temp_list,
+                                                out=out,
+                                                num_tensors=num_tensors,
+                                                step_count=step_count,
+                                                name=name,
+                                                **kwargs)
+
+def multi_mp_lamb_update(weights, grads, mean, var, temp_g, weights32, step_count,
+                         out=None, name=None, num_tensors=0, **kwargs):
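+    """Mixed-precision variant of ``multi_lamb_update``; ``weights32`` holds
+    the float32 master copy of each weight.
+    """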
+    if not num_tensors:
+        num_tensors = len(weights)
+    temp_list = _flatten_list(zip(weights, grads, mean, var, temp_g, weights32))
+    return ndarray._internal._multi_mp_lamb_update(*temp_list,
+                                                   out=out,
+                                                   num_tensors=num_tensors,
+                                                   step_count=step_count,
+                                                   name=name,
+                                                   **kwargs)
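
A minimal usage sketch of the wrapper above (the two-tensor setup, shapes, and
hyperparameter values are illustrative only, not part of the patch):

    import mxnet as mx
    from mxnet.ndarray.contrib import multi_lamb_update

    weights = [mx.nd.ones((4, 4)), mx.nd.ones((8,))]
    grads = [w * 0.1 for w in weights]
    mean = [mx.nd.zeros(w.shape) for w in weights]
    var = [mx.nd.zeros(w.shape) for w in weights]
    temp_g = [mx.nd.zeros(w.shape) for w in weights]

    multi_lamb_update(weights, grads, mean, var, temp_g,
                      step_count=[1, 1], out=weights,
                      learning_rate=0.001, wd=0.01, rescale_grad=1.0)
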
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index 42d418e..7f455f2 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -36,13 +36,14 @@
                        preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
                        preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2,
                        mp_lamb_update_phase1, mp_lamb_update_phase2)
+from ..ndarray.contrib import (multi_lamb_update, multi_mp_lamb_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
 
 __all__ = [
     'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
-    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
+    'MultiLAMB', 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
     'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
 ]
 
@@ -1044,6 +1045,95 @@
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
+@register
+class MultiLAMB(Optimizer):
+    """multiLAMB optimizer.
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
+        super(MultiLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bias_correction = bias_correction
+        self.aggregate_num = max(1, min(50, int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "50"))))
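+        # The cap of 50 matches the compile-time limit MultiLAMBKernelParam::N.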
+
+    def create_state(self, index, weight):
+        stype = weight.stype
+        dtype = weight.dtype
+        return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype), # mean
+                zeros(weight.shape, weight.context, dtype=dtype, stype=stype), # variance
+                zeros(weight.shape, weight.context, dtype=dtype, stype=stype)) # temp_g
+
+    def _update_impl(self, index, weights, grads, states, multi_precision=False):
+        step_count = []
+        if not isinstance(index, (tuple, list)):
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+            self._update_count(index)
+            step_count.append(self._index_update_count[index])
+            lr = self._get_lr(index)
+            wd = self._get_wd(index)
+        else:
+            for i, (weight, grad) in enumerate(zip(weights, grads)):
+                assert isinstance(weight, NDArray)
+                assert isinstance(grad, NDArray)
+                self._update_count(index[i])
+                step_count.append(self._index_update_count[index[i]])
+            lr = self._get_lr(index[0])
+            wd = self._get_wd(index[0])
+
+        kwargs = {'learning_rate': lr, 'beta1': self.beta1, 'beta2': self.beta2,
+                  'epsilon': self.epsilon, 'wd': wd,
+                  'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
+                  'bias_correction': self.bias_correction,
+                  'rescale_grad': self.rescale_grad}
+
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+
+        updated_tensors = 0
+        while updated_tensors < len(weights):
+            sidx = updated_tensors
+            eidx = min(updated_tensors + self.aggregate_num, len(weights))
+            if not multi_precision:
+                mean, var, temp_g = list(zip(*states[sidx:eidx]))
+                multi_lamb_update(weights[sidx:eidx],
+                                  grads[sidx:eidx],
+                                  mean, var, temp_g,
+                                  out=weights[sidx:eidx],
+                                  step_count=step_count[sidx:eidx],
+                                  **kwargs)
+            else:
+                weights32, mean_var_g = zip(*states[sidx:eidx])
+                mean, var, temp_g = zip(*mean_var_g)
+                multi_mp_lamb_update(weights[sidx:eidx],
+                                     grads[sidx:eidx],
+                                     mean, var, temp_g,
+                                     weights32,
+                                     out=weights[sidx:eidx],
+                                     step_count=step_count[sidx:eidx],
+                                     **kwargs)
+
+            updated_tensors += self.aggregate_num
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        if not isinstance(index, (tuple, list)):
+            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+        else:
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        self._update_impl(index, weight, grad, state,
+                          multi_precision=use_multi_precision)
+
 #
 @register
 class LBSGD(Optimizer):
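
Because the class is decorated with @register, it is also reachable through the
lowercased string name "multilamb"; a usage sketch ("net" here stands for a
hypothetical Gluon model, not part of the patch):

    import mxnet as mx
    opt = mx.optimizer.create('multilamb', learning_rate=0.001,
                              bias_correction=True)
    trainer = mx.gluon.Trainer(net.collect_params(), opt)
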
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index abdf570..6eaec47 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -2141,36 +2141,60 @@
 
 
 def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
-                      rtol=1e-4, atol=1e-5, compare_states=True):
+                      rtol=1e-4, atol=1e-5, compare_states=True, ntensors=1):
     """Compare opt1 and opt2."""
-    if w_stype == 'default':
-        w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
-        w1 = w2.copyto(default_context())
-    elif w_stype in ('row_sparse', 'csr'):
-        w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
-        w1 = w2.copyto(default_context()).tostype('default')
+    if ntensors == 1:
+        if w_stype == 'default':
+            w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+            w1 = w2.copyto(default_context())
+        elif w_stype in ('row_sparse', 'csr'):
+            w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
+            w1 = w2.copyto(default_context()).tostype('default')
+        else:
+            raise Exception("type not supported yet")
+        if g_stype == 'default':
+            g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+            g1 = g2.copyto(default_context())
+        elif g_stype in ('row_sparse', 'csr'):
+            g2 = rand_ndarray(shape, g_stype, dtype=dtype)
+            g1 = g2.copyto(default_context()).tostype('default')
+        else:
+            raise Exception("type not supported yet")
+
+        state1 = opt1.create_state_multi_precision(0, w1)
+        state2 = opt2.create_state_multi_precision(0, w2)
+        if compare_states:
+            compare_ndarray_tuple(state1, state2)
+
+        opt1.update_multi_precision(0, w1, g1, state1)
+        opt2.update_multi_precision(0, w2, g2, state2)
+        if compare_states:
+            compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
+        assert_almost_equal(w1, w2, rtol=rtol, atol=atol)
     else:
-        raise Exception("type not supported yet")
-    if g_stype == 'default':
-        g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
-        g1 = g2.copyto(default_context())
-    elif g_stype in ('row_sparse', 'csr'):
-        g2 = rand_ndarray(shape, g_stype, dtype=dtype)
-        g1 = g2.copyto(default_context()).tostype('default')
-    else:
-        raise Exception("type not supported yet")
+        # test multi-tensor: Opt1 single-tensor, Opt2 multi-tensor
+        from copy import deepcopy
+        if not isinstance(shape, list):
+            w1 = [mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+                  for _ in range(ntensors)]
+            g1 = [mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+                  for _ in range(ntensors)]
+        else:
+            w1, g1 = [], []
+            for s in shape:
+                w1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
+                g1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
+        w1 = tuple(w1)
+        w2 = deepcopy(w1)
+        g1 = tuple(g1)
+        g2 = deepcopy(g1)
+        state2 = [opt2.create_state_multi_precision(i, w2[i]) for i in range(ntensors)]
 
-    state1 = opt1.create_state_multi_precision(0, w1)
-    state2 = opt2.create_state_multi_precision(0, w2)
-    if compare_states:
-        compare_ndarray_tuple(state1, state2)
-
-    opt1.update_multi_precision(0, w1, g1, state1)
-    opt2.update_multi_precision(0, w2, g2, state2)
-    if compare_states:
-        compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
-    assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
-
+        opt2.update_multi_precision(list(range(ntensors)), w2, g2, state2)
+        for i in range(ntensors):
+            state1 = opt1.create_state_multi_precision(i, w1[i])
+            opt1.update_multi_precision(i, w1[i], g1[i], state1)
+            if compare_states:
+                compare_ndarray_tuple(state1, state2[i], rtol=rtol, atol=atol)
+            assert_almost_equal(w1[i], w2[i], rtol=rtol, atol=atol)
 
 def same_symbol_structure(sym1, sym2):
     """Compare two symbols to check if they have the same computation graph structure.
diff --git a/src/operator/contrib/multi_lamb-inl.h b/src/operator/contrib/multi_lamb-inl.h
new file mode 100644
index 0000000..107a4c9
--- /dev/null
+++ b/src/operator/contrib/multi_lamb-inl.h
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file multi_lamb-inl.h
+ * \brief multi-tensor (aggregated) LAMB optimizer update
+ * \author Moises Hernandez
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
+#define MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
+
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <mxnet/op_attr_types.h>
+#include <mshadow/base.h>
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../tensor/util/tensor_util-inl.h"
+#include "multi_sum_sq-inl.h"
+
+namespace mxnet {
+namespace op {
+
+namespace multilamb {
+enum MultiLambUpdateResource {kTempSpace};
+}  // namespace multilamb
+
+struct MultiLAMBParam : public dmlc::Parameter<MultiLAMBParam> {
+  float learning_rate;
+  float beta1;
+  float beta2;
+  float epsilon;
+  float wd;
+  float rescale_grad;
+  float lower_bound;
+  float upper_bound;
+  float clip_gradient;
+  bool bias_correction;
+  int num_tensors;
+  mxnet::Tuple<int> step_count;
+
+  DMLC_DECLARE_PARAMETER(MultiLAMBParam) {
+    DMLC_DECLARE_FIELD(learning_rate)
+    .set_default(0.001f)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(beta1)
+    .set_default(0.9f)
+    .describe("Exponential decay rate for the first moment estimates.");
+    DMLC_DECLARE_FIELD(beta2)
+    .set_default(0.999f)
+    .describe("Exponential decay rate for the second moment estimates.");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1e-6f)
+    .describe("Small value to avoid division by 0.");
+    DMLC_DECLARE_FIELD(wd)
+    .set_default(0.0f)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+               "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Gradient rescaling factor");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(1e-3f)
+    .describe("Lower limit of norm of weight.");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(10.0f)
+    .describe("Upper limit of norm of weight.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(bias_correction)
+    .set_default(false)
+    .describe("Whether to use bias correction.");
+    DMLC_DECLARE_FIELD(step_count)
+    .describe("Step count for each tensor");
+    DMLC_DECLARE_FIELD(num_tensors)
+    .set_default(1)
+    .describe("Number of tensors");
+  }
+};
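+
+// Sketch of the two-phase update the kernels below implement (per tensor,
+// per element; g' is the rescaled and optionally clipped gradient):
+//   phase 1:  m = beta1*m + (1-beta1)*g'
+//             v = beta2*v + (1-beta2)*g'*g'
+//             temp_g = m_hat / (sqrt(v_hat) + epsilon) + wd*w,
+//             where m_hat, v_hat are bias-corrected if bias_correction is set
+//   phase 2:  r1 = clip(||w||, lower_bound, upper_bound);  r2 = ||temp_g||
+//             w -= learning_rate * (r1/r2) * temp_g  (ratio is 1 if r1 or r2 == 0)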
+
+template<typename ParamType, int input_stride>
+inline bool MultiLAMB_InferShape(const nnvm::NodeAttrs& attrs,
+                                 mxnet::ShapeVector *in_attrs,
+                                 mxnet::ShapeVector *out_attrs) {
+  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors);
+  CHECK_EQ(out_attrs->size(), param.num_tensors);
+
+  bool all_inferred = true;
+  auto& input_shapes = *in_attrs;
+  auto& output_shapes = *out_attrs;
+
+  CHECK_EQ(param.step_count.ndim(), param.num_tensors)
+    << "Number of step counts is inconsistent with num_weights."
+    << "Expected number of step counts: "
+    << param.num_tensors << ", and got " << param.step_count.ndim();
+
+  // Weights, gradients, mean and variance
+  for (int i = 0; i < param.num_tensors; ++i) {
+    mxnet::ShapeVector input_vec;
+    mxnet::ShapeVector output_vec({output_shapes[i]});
+    for (int j = 0; j < input_stride; ++j) {
+      input_vec.push_back(input_shapes[i * input_stride + j]);
+    }
+    all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
+  }
+  return all_inferred;
+}
+
+template <typename ParamType, int input_stride>
+inline bool MP_MultiLAMB_InferType(const nnvm::NodeAttrs& attrs,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors);
+  CHECK_EQ(out_attrs->size(), param.num_tensors);
+
+  bool all_inferred = true;
+  auto& input_types = *in_attrs;
+  auto& output_types = *out_attrs;
+
+  // weights, gradients
+  for (int i = 0; i < param.num_tensors; ++i) {
+    std::vector<int> input_vec;
+    std::vector<int> output_vec({output_types[i]});
+    for (int j = 0; j < 2; ++j) {
+      input_vec.push_back(input_types[i * input_stride + j]);
+    }
+    all_inferred = all_inferred &&
+            ElemwiseType<2, 1>(attrs, &input_vec, &output_vec);
+  }
+
+  // mean, var, temp_g, weights32 (master copies of weights)
+  for (int i = 0; i < param.num_tensors; ++i) {
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + 4, mshadow::kFloat32);
+    TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1, mshadow::kFloat32);
+  }
+  return all_inferred;
+}
+
+template<typename T>
+class LAMB_type_identity {
+ public:
+  using type = T;
+};
+
+template<typename T>
+class LAMB_single_precision {
+ public:
+  using type = float;
+};
+
+template<typename DType, typename MPDType>
+struct MultiLAMBKernelParam {
+  static const int N = 50;
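+  // Hard limit on tensors per operator call; the Python optimizer clamps
+  // MXNET_OPTIMIZER_AGGREGATION_SIZE to this value.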
+  size_t ntensors;
+  size_t max_size;
+  size_t total_size;
+  size_t sizes[N];
+  DType* weights[N];
+  DType* grads[N];
+  MPDType* mean[N];
+  MPDType* var[N];
+  MPDType* temp_g[N];
+  MPDType* weights32[N];
+  DType* out_data[N];
+  int step_count[N];
+
+  // gpu chunking: tensors are processed in chunk_size-element chunks,
+  // one CUDA block per chunk
+  int chunk_size = 65536;
+  int nchunks;
+};
+
+template<typename xpu,
+         typename DType,
+         typename MPDType,
+         typename ParamType = MultiLAMBParam,
+         int input_stride = 5>
+void FillMultiLAMBKernelParam(const nnvm::NodeAttrs& attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<TBlob> &outputs,
+                              MultiLAMBKernelParam<DType, MPDType> *multi_param) {
+  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
+  mxnet_op::Stream<xpu>* s = ctx.get_stream<xpu>();
+
+  multi_param->ntensors = p.num_tensors;
+  multi_param->total_size = 0;
+  multi_param->max_size = 0;
+  multi_param->nchunks = 0;
+
+  constexpr bool isSame = std::is_same<DType, MPDType>::value;
+  for (size_t i = 0; i < multi_param->ntensors; ++i) {
+    const auto idx = i * input_stride;
+    multi_param->sizes[i] = inputs[idx].shape_.Size();
+    multi_param->total_size += multi_param->sizes[i];
+    if (multi_param->max_size < multi_param->sizes[i])
+      multi_param->max_size = multi_param->sizes[i];
+
+    multi_param->weights[i] = inputs[idx].FlatTo2D<xpu, DType>(s).dptr_;
+    multi_param->grads[i] = inputs[idx + 1].FlatTo2D<xpu, DType>(s).dptr_;
+    multi_param->mean[i] = inputs[idx + 2].FlatTo2D<xpu, MPDType>(s).dptr_;
+    multi_param->var[i]  = inputs[idx + 3].FlatTo2D<xpu, MPDType>(s).dptr_;
+    multi_param->temp_g[i]  = inputs[idx + 4].FlatTo2D<xpu, MPDType>(s).dptr_;
+    // if mixed precision, then the last input in a set
+    // is 32-bit master copy of the weights
+    if (!isSame)
+      multi_param->weights32[i] = inputs[idx + input_stride - 1].FlatTo2D<xpu, MPDType>(s).dptr_;
+    multi_param->out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
+    multi_param->nchunks += (multi_param->sizes[i] + multi_param->chunk_size - 1)
+                            /multi_param->chunk_size;
+  }
+  memcpy(multi_param->step_count, p.step_count.begin(), multi_param->ntensors * sizeof(int));
+}
+
+using namespace mxnet_op;
+
+// Declarations; the cpu definitions live in multi_lamb.cc and the gpu
+// definitions in multi_lamb.cu. The signatures must match the call sites
+// in multiLAMB() below.
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<cpu>* s, const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam& param, int* block_to_tensor, int* block_to_chunk);
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<gpu>* s, const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam& param, int* block_to_tensor, int* block_to_chunk);
+
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<cpu>* s, const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam& param, float* r1, float* r2,
+                  int* block_to_tensor, int* block_to_chunk, const OpReqType req);
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<gpu>* s, const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam& param, float* r1, float* r2,
+                  int* block_to_tensor, int* block_to_chunk, const OpReqType req);
+
+template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
+inline void multiLAMB(const nnvm::NodeAttrs& attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+  auto param = nnvm::get<MultiLAMBParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    using MPDType = typename MPTypeChooser<DType>::type;
+    MultiLAMBKernelParam<DType, MPDType> kernel_params;
+    FillMultiLAMBKernelParam<xpu, DType, MPDType, MultiLAMBParam, input_stride>
+            (attrs, ctx, inputs, outputs, &kernel_params);
+
+    // gather per-tensor TBlob views of the weights and of temp_g, in tensor
+    // order, for the multi-tensor sum-of-squares computations below
+    std::vector<TBlob> weights, temp_g;
+    for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+      weights.emplace_back(inputs[index * input_stride]);
+      temp_g.emplace_back(inputs[index * input_stride + 4]);
+    }
+
+    // Calculate amount of temporary storage
+    size_t workspace_size = 2 * kernel_params.ntensors * sizeof(float) +
+        2 * kernel_params.nchunks * sizeof(int);
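+    // workspace layout: [r1: ntensors floats | r2: ntensors floats |
+    //                    block_to_tensor: nchunks ints | block_to_chunk: nchunks ints]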
+
+    // Request temporary storage
+    Tensor<xpu, 1, char> workspace =
+        ctx.requested[multilamb::kTempSpace].get_space_typed<xpu, 1, char>(
+            Shape1(workspace_size), s);
+
+    // Create tensors
+    size_t pos_wspace = 0;
+    Tensor<xpu, 1, float> r1(reinterpret_cast<float*>(&workspace[pos_wspace]),
+      Shape1(kernel_params.ntensors), s);
+    pos_wspace += kernel_params.ntensors * sizeof(float);
+    Tensor<xpu, 1, float> r2(reinterpret_cast<float*>(&workspace[pos_wspace]),
+      Shape1(kernel_params.ntensors), s);
+    pos_wspace += kernel_params.ntensors * sizeof(float);
+    Tensor<xpu, 1, int> block_to_tensor(reinterpret_cast<int*>(&workspace[pos_wspace]),
+      Shape1(kernel_params.nchunks), s);
+    pos_wspace += kernel_params.nchunks * sizeof(int);
+    Tensor<xpu, 1, int> block_to_chunk(reinterpret_cast<int*>(&workspace[pos_wspace]),
+      Shape1(kernel_params.nchunks), s);
+
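+    // r1 <- per-tensor ||w||^2; step 1 fills temp_g; then r2 <- per-tensor
+    // ||temp_g||^2; step 2 applies the trust-ratio weight update.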
+    MultiSumSqRun<xpu>(weights, kernel_params.ntensors, r1.dptr_, s);
+    call_kernel1<MPDType, DType>(s, kernel_params, param, block_to_tensor.dptr_,
+                                  block_to_chunk.dptr_);
+    MultiSumSqRun<xpu>(temp_g, kernel_params.ntensors, r2.dptr_, s);
+    call_kernel2<MPDType, DType>(s, kernel_params, param, r1.dptr_, r2.dptr_,
+                                 block_to_tensor.dptr_, block_to_chunk.dptr_,
+                                 req[0]);
+  });
+}
+
+template<typename xpu, bool MP>
+inline void multiLAMBUpdate(const nnvm::NodeAttrs& attrs,
+                            const OpContext &ctx,
+                            const std::vector<TBlob> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<TBlob> &outputs) {
+  if (!MP) {
+    multiLAMB<xpu, LAMB_type_identity, 5>
+      (attrs, ctx, inputs, req, outputs);
+  } else {
+    multiLAMB<xpu, LAMB_single_precision, 6>
+      (attrs, ctx, inputs, req, outputs);
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CONTRIB_MULTI_LAMB_INL_H_
diff --git a/src/operator/contrib/multi_lamb.cc b/src/operator/contrib/multi_lamb.cc
new file mode 100644
index 0000000..ebfa7a2
--- /dev/null
+++ b/src/operator/contrib/multi_lamb.cc
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file multi_lamb.cc
+ * \brief multi-tensor LAMB optimizer update, CPU implementation
+ * \author Moises Hernandez
+ */
+
+#include "./multi_lamb-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiLAMB_step1_kernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                                  const float learning_rate,
+                                  const float beta1, const float beta2,
+                                  const float epsilon,
+                                  const float wd,
+                                  const float clip_gradient,
+                                  const bool bias_correction,
+                                  const float rescale_grad) {
+    using namespace mshadow_op;
+    for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+      if ((size_t)i < kernel_params.sizes[index]) {
+        MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]:
+                                          MPDType(kernel_params.weights[index][i]);
+        MPDType scaled_grad = static_cast<MPDType>(kernel_params.grads[index][i])*rescale_grad;
+        if (clip_gradient >= 0.0f)
+            scaled_grad = mshadow_op::clip::Map(scaled_grad, static_cast<MPDType>(clip_gradient));
+        MPDType mean = static_cast<MPDType>(beta1) * kernel_params.mean[index][i] +
+          (static_cast<MPDType>(1.0f) - static_cast<MPDType>(beta1)) * scaled_grad;
+        MPDType var = static_cast<MPDType>(beta2) * kernel_params.var[index][i] +
+          (static_cast<MPDType>(1.0f) - static_cast<MPDType>(beta2)) * scaled_grad * scaled_grad;
+        kernel_params.mean[index][i] = mean;
+        kernel_params.var[index][i] = var;
+
+        MPDType g;
+        if (bias_correction) {
+          MPDType mean_hat = mean / (static_cast<MPDType>(1.0f) -
+                                    power::Map(static_cast<MPDType>(beta1),
+                                    static_cast<MPDType>(kernel_params.step_count[index])));
+          MPDType var_hat = var / (static_cast<MPDType>(1.0f) -
+                                  power::Map(static_cast<MPDType>(beta2),
+                                  static_cast<MPDType>(kernel_params.step_count[index])));
+          g = mean_hat / (sqrt(var_hat) + epsilon) + wd * w;
+        } else {
+          g = mean / (sqrt(var) + epsilon) + wd * w;
+        }
+        kernel_params.temp_g[index][i] = g;
+      }
+    }
+  }
+};
+
+template<typename MPDType, bool has_mixed_precision>
+struct MultiLAMB_step2_kernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                                  const float* sumSqWeights,
+                                  const float* sumSqTempG,
+                                  const float learning_rate,
+                                  const float lower_bound,
+                                  const float upper_bound,
+                                  const OpReqType req) {
+    for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+      if ((size_t)i < kernel_params.sizes[index]) {
+        MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]:
+                                            MPDType(kernel_params.weights[index][i]);
+        float r1 = sqrt(sumSqWeights[index]);
+        float r2 = sqrt(sumSqTempG[index]);
+        r1 = std::min(std::max(r1, lower_bound), upper_bound);
+
+        // calculate lamb_trust_ratio
+        MPDType r;
+        if (r1 == 0.0f || r2 == 0.0f)
+          r = 1.0f;
+        else
+          r = r1/r2;
+
+        MPDType lr_adjusted = learning_rate * r;
+        w -= lr_adjusted * kernel_params.temp_g[index][i];
+
+        // update weights
+        if (has_mixed_precision)
+          kernel_params.weights32[index][i] = w;
+        KERNEL_ASSIGN(kernel_params.out_data[index][i], req, w);
+      }
+    }
+  }
+};
+
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<cpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+                  int* block_to_tensor,
+                  int* block_to_chunk) {
+  Kernel<MultiLAMB_step1_kernel<MPDType, !std::is_same<DType, MPDType>::value>, cpu>::
+                                 Launch(s, kernel_params.max_size,
+                                 kernel_params,
+                                 param.learning_rate,
+                                 param.beta1, param.beta2,
+                                 param.epsilon,
+                                 param.wd,
+                                 param.clip_gradient,
+                                 param.bias_correction,
+                                 param.rescale_grad);
+}
+
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<cpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+                  float* r1, float* r2,
+                  int* block_to_tensor,
+                  int* block_to_chunk,
+                  const OpReqType req) {
+  Kernel<MultiLAMB_step2_kernel<MPDType, !std::is_same<DType, MPDType>::value>, cpu>::
+                                 Launch(s, kernel_params.max_size,
+                                 kernel_params,
+                                 r1, r2,
+                                 param.learning_rate,
+                                 param.lower_bound, param.upper_bound,
+                                 req);
+}
+
+DMLC_REGISTER_PARAMETER(MultiLAMBParam);
+
+std::vector<std::string> LAMBParamToVector(uint32_t num_args, const char *pName[], size_t nParams) {
+  std::vector<std::string> ret;
+  for (uint32_t i = 0; i < num_args; ++i) {
+    const auto idx = std::to_string(i);
+    for (size_t j = 0; j < nParams; ++j)
+      ret.push_back(std::string(pName[j]) + idx);
+  }
+
+  return ret;
+}
+
+inline uint32_t num_tensors(const nnvm::NodeAttrs& attrs) {
+  return static_cast<uint32_t>(dmlc::get<MultiLAMBParam>(attrs.parsed).num_tensors);
+}
+
+NNVM_REGISTER_OP(_multi_lamb_update)
+.describe(R"code(Compute the LAMB coefficients of multiple weights and grads"
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    return num_tensors(attrs) * 5;
+  })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    return num_tensors(attrs);
+  })
+.set_attr_parser(ParamParser<MultiLAMBParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MultiLAMB_InferShape<MultiLAMBParam, 5>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "temp_g"};
+    return LAMBParamToVector(num_tensors(attrs), paramName, sizeof(paramName)/sizeof(paramName[0]));
+  })
+// mutable: mean, var, temp_g,
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    std::vector<uint32_t> ret;
+    const auto iMax = num_tensors(attrs);
+    for (size_t i = 0; i < iMax; ++i) {
+      ret.push_back(i * 5 + 2);
+      ret.push_back(i * 5 + 3);
+      ret.push_back(i * 5 + 4);
+    }
+    return ret;
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", multiLAMBUpdate<cpu, false>)
+.add_argument("data", "NDArray-or-Symbol[]", "data")
+.add_arguments(MultiLAMBParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_multi_mp_lamb_update)
+.describe(R"code(Compute the LAMB coefficients of multiple weights and grads with Mix Precision"
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    return num_tensors(attrs) * 6;
+  })
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    return num_tensors(attrs);
+  })
+.set_attr_parser(ParamParser<MultiLAMBParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MultiLAMB_InferShape<MultiLAMBParam, 6>)
+.set_attr<nnvm::FInferType>("FInferType", MP_MultiLAMB_InferType<MultiLAMBParam, 6>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "temp_g", "weight32_"};
+    return LAMBParamToVector(num_tensors(attrs), paramName, sizeof(paramName)/sizeof(paramName[0]));
+  })
+// mutable: mean, var, temp_g, weights32
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    std::vector<uint32_t> ret;
+    const auto iMax = num_tensors(attrs);
+    for (size_t i = 0; i < iMax; ++i) {
+      ret.push_back(i * 6 + 2);
+      ret.push_back(i * 6 + 3);
+      ret.push_back(i * 6 + 4);
+      ret.push_back(i * 6 + 5);
+    }
+    return ret;
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", multiLAMBUpdate<cpu, true>)
+.add_argument("data", "NDArray-or-Symbol[]", "data")
+.add_arguments(MultiLAMBParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/multi_lamb.cu b/src/operator/contrib/multi_lamb.cu
new file mode 100644
index 0000000..ed357d8
--- /dev/null
+++ b/src/operator/contrib/multi_lamb.cu
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file multi_lamb.cu
+ * \brief multi-tensor LAMB optimizer update, GPU implementation
+ * \author Moises Hernandez
+ */
+
+#include "./multi_lamb-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define BLOCK_SIZE_LAMB 512
+#define ILP_LAMB 4
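+
+// Each thread block processes one chunk_size-element chunk of one tensor
+// (block_to_tensor / block_to_chunk give the mapping); each thread handles
+// ILP_LAMB elements per loop iteration.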
+
+template<bool has_mixed_precision, typename MPDType, typename DType>
+__global__ void kernel_step1(const MultiLAMBKernelParam<DType, MPDType> kernel_params,
+                             const float learning_rate,
+                             const float beta1, const float beta2,
+                             const MPDType beta3, const MPDType beta4,
+                             const float epsilon,
+                             const float wd,
+                             const float clip_gradient,
+                             const bool bias_correction,
+                             const float rescale_grad,
+                             int* block_to_tensor,
+                             int* block_to_chunk) {
+  const int tensorID = block_to_tensor[blockIdx.x];
+  const int chunkID = block_to_chunk[blockIdx.x];
+  const int startPos = chunkID * kernel_params.chunk_size + threadIdx.x;
+  const int stopPos = chunkID * kernel_params.chunk_size + kernel_params.chunk_size;
+
+  MPDType biascorrection1, biascorrection2;
+  if (bias_correction) {
+      biascorrection1 = 1.0f - powf(beta1, kernel_params.step_count[tensorID]);
+      biascorrection2 = 1.0f - powf(beta2, kernel_params.step_count[tensorID]);
+  } else {
+      biascorrection1 = 1.0f;
+      biascorrection2 = 1.0f;
+  }
+
+  MPDType r_weight[ILP_LAMB];
+  MPDType r_grad[ILP_LAMB];
+  MPDType r_mean[ILP_LAMB];
+  MPDType r_var[ILP_LAMB];
+  MPDType r_g[ILP_LAMB];
+
+  for (size_t i = startPos; i < stopPos && i < kernel_params.sizes[tensorID];
+                                                  i += blockDim.x * ILP_LAMB) {
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          int load_pos = i + ii*blockDim.x;
+          if (load_pos < stopPos && load_pos < kernel_params.sizes[tensorID]) {
+              r_weight[ii] = has_mixed_precision ? kernel_params.weights32[tensorID][load_pos]:
+                                static_cast<MPDType>(kernel_params.weights[tensorID][load_pos]);
+              r_grad[ii] = static_cast<MPDType>(kernel_params.grads[tensorID][load_pos]);
+              r_mean[ii] = kernel_params.mean[tensorID][load_pos];
+              r_var[ii] = kernel_params.var[tensorID][load_pos];
+          } else {
+              r_weight[ii] = static_cast<MPDType>(0);
+              r_grad[ii] = static_cast<MPDType>(0);
+              r_mean[ii] = static_cast<MPDType>(0);
+              r_var[ii] = static_cast<MPDType>(0);
+          }
+      }
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          r_grad[ii] = r_grad[ii] * rescale_grad;
+          if (clip_gradient >= 0.0f)
+              r_grad[ii] = max(min(r_grad[ii], clip_gradient), -clip_gradient);
+          r_mean[ii] = static_cast<MPDType>(beta1) * r_mean[ii] + beta3 * r_grad[ii];
+          r_var[ii] = static_cast<MPDType>(beta2) * r_var[ii] + beta4 * r_grad[ii] * r_grad[ii];
+          r_g[ii] = (r_mean[ii] / biascorrection1) / (sqrtf(r_var[ii] / biascorrection2) + epsilon)
+                    + wd * r_weight[ii];
+       }
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          int store_pos = i + ii*blockDim.x;
+          if (store_pos < stopPos && store_pos < kernel_params.sizes[tensorID]) {
+              kernel_params.mean[tensorID][store_pos] = r_mean[ii];
+              kernel_params.var[tensorID][store_pos] = r_var[ii];
+              kernel_params.temp_g[tensorID][store_pos] = r_g[ii];
+          }
+      }
+  }
+}
+
+template<bool has_mixed_precision, typename MPDType, typename DType>
+__global__ void kernel_step2(const MultiLAMBKernelParam<DType, MPDType> kernel_params,
+                             const float* sumSqWeights,
+                             const float* sumSqTempG,
+                             const float learning_rate,
+                             const float lower_bound,
+                             const float upper_bound,
+                             int* block_to_tensor,
+                             int* block_to_chunk,
+                             const OpReqType req) {
+  const int tensorID = block_to_tensor[blockIdx.x];
+  const int chunkID = block_to_chunk[blockIdx.x];
+  const int startPos = chunkID * kernel_params.chunk_size + threadIdx.x;
+  const int stopPos = chunkID * kernel_params.chunk_size + kernel_params.chunk_size;
+
+  MPDType r1 = sqrtf(sumSqWeights[tensorID]);
+  MPDType r2 = sqrtf(sumSqTempG[tensorID]);
+  r1 = min(max(r1, lower_bound), upper_bound);
+
+  MPDType lr_adjusted;
+  if (r1 == 0.0f || r2 == 0.0f)
+      lr_adjusted = learning_rate;
+  else
+      lr_adjusted = learning_rate * r1/r2;
+
+  MPDType r_weight[ILP_LAMB];
+  MPDType r_g[ILP_LAMB];
+
+  for (size_t i = startPos; i < stopPos && i < kernel_params.sizes[tensorID];
+                                                  i += blockDim.x * ILP_LAMB) {
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          int load_pos = i + ii*blockDim.x;
+          if (load_pos < stopPos&& load_pos < kernel_params.sizes[tensorID]) {
+              r_weight[ii] = has_mixed_precision ? kernel_params.weights32[tensorID][load_pos]:
+                                static_cast<MPDType>(kernel_params.weights[tensorID][load_pos]);
+              r_g[ii] = kernel_params.temp_g[tensorID][load_pos];
+          }
+      }
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          r_weight[ii] -= lr_adjusted * r_g[ii];
+      }
+#pragma unroll
+      for (int ii = 0; ii < ILP_LAMB; ii++) {
+          int store_pos = i + ii*blockDim.x;
+          if (store_pos < stopPos && store_pos < kernel_params.sizes[tensorID]) {
+              if (has_mixed_precision)
+                  kernel_params.weights32[tensorID][store_pos] = r_weight[ii];
+              KERNEL_ASSIGN(kernel_params.out_data[tensorID][store_pos], req, r_weight[ii]);
+          }
+      }
+  }
+}
+
+template<typename MPDType, typename DType>
+void call_kernel1(Stream<gpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+                  int* block_to_tensor,
+                  int* block_to_chunk) {
+  int nblocks = kernel_params.nchunks;
+  int* host_block2tensor = reinterpret_cast<int*>(malloc(kernel_params.nchunks*sizeof(int)));
+  int* host_block2chunk = reinterpret_cast<int*>(malloc(kernel_params.nchunks*sizeof(int)));
+  int chunkID = 0;
+  for (size_t index = 0; index < kernel_params.ntensors; ++index) {
+    int current_chunk = 0;
+    for (size_t j = 0; j < kernel_params.sizes[index]; j+=kernel_params.chunk_size) {
+        host_block2tensor[chunkID] = index;
+        host_block2chunk[chunkID] = current_chunk;
+        current_chunk++;
+        chunkID++;
+    }
+  }
+  cudaMemcpyAsync(block_to_tensor, host_block2tensor, kernel_params.nchunks*sizeof(int),
+                  cudaMemcpyHostToDevice, Stream<gpu>::GetStream(s));
+  cudaMemcpyAsync(block_to_chunk, host_block2chunk, kernel_params.nchunks*sizeof(int),
+                  cudaMemcpyHostToDevice, Stream<gpu>::GetStream(s));
+
+  bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
+  MPDType beta3 = 1.0f - param.beta1;
+  MPDType beta4 = 1.0f - param.beta2;
+
+  if (has_mixed_precision)
+    kernel_step1<true><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+                      kernel_params,
+                      param.learning_rate,
+                      param.beta1, param.beta2,
+                      beta3, beta4,
+                      param.epsilon, param.wd,
+                      param.clip_gradient,
+                      param.bias_correction,
+                      param.rescale_grad,
+                      block_to_tensor,
+                      block_to_chunk);
+  else
+    kernel_step1<false><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+                      kernel_params,
+                      param.learning_rate,
+                      param.beta1, param.beta2,
+                      beta3, beta4,
+                      param.epsilon, param.wd,
+                      param.clip_gradient,
+                      param.bias_correction,
+                      param.rescale_grad,
+                      block_to_tensor,
+                      block_to_chunk);
+}
+
+template<typename MPDType, typename DType>
+void call_kernel2(Stream<gpu>* s,
+                  const MultiLAMBKernelParam<DType, MPDType>& kernel_params,
+                  const MultiLAMBParam &param,
+                  float* r1, float* r2,
+                  int* block_to_tensor,
+                  int* block_to_chunk,
+                  const OpReqType req) {
+  size_t nblocks = kernel_params.nchunks;
+  bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
+  if (has_mixed_precision)
+    kernel_step2<true><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+                      kernel_params,
+                      r1, r2,
+                      param.learning_rate,
+                      param.lower_bound, param.upper_bound,
+                      block_to_tensor,
+                      block_to_chunk,
+                      req);
+  else
+    kernel_step2<false><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(
+                      kernel_params,
+                      r1, r2,
+                      param.learning_rate,
+                      param.lower_bound, param.upper_bound,
+                      block_to_tensor,
+                      block_to_chunk,
+                      req);
+}
+
+
+NNVM_REGISTER_OP(_multi_lamb_update)
+.set_attr<FCompute>("FCompute<gpu>",  multiLAMBUpdate<gpu, false>);
+
+NNVM_REGISTER_OP(_multi_mp_lamb_update)
+.set_attr<FCompute>("FCompute<gpu>",  multiLAMBUpdate<gpu, true>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8f86bdf..44487d9 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -775,6 +775,104 @@
                                           dtype, w_stype='default', g_stype='row_sparse',
                                           rtol=1e-4, atol=2e-5)
 
+# MultiLAMB
+class PyMultiLAMB(mx.optimizer.Optimizer):
+    """python reference implemenation of lamb"""
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, 
+                 multi_precision=False, clip_gradient=-1, **kwargs):
+        super(PyMultiLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bias_correction = bias_correction
+        self.multi_precision = multi_precision
+        self.clip_gradient = clip_gradient
+
+    def create_state(self, index, weight):
+        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),
+                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),
+                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))
+
+    def update(self, index, weight, grad, state):
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        index_update_count = self._index_update_count[index]
+
+        grad *= self.rescale_grad
+        if self.clip_gradient >= 0:
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+
+        mean, var, temp_g = state
+        mean[:] = self.beta1 * mean + (1. - self.beta1) * grad.astype(mean.dtype)
+        var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad.astype(mean.dtype))
+
+        r1 = weight.norm()
+        if self.lower_bound:
+            r1 = mx.nd.maximum(r1, self.lower_bound)
+        if self.upper_bound:
+            r1 = mx.nd.minimum(r1, self.upper_bound)
+
+        if self.bias_correction:
+            mean_hat = mean / (1. - mx.nd.power(self.beta1, index_update_count))
+            var_hat = var / (1. - mx.nd.power(self.beta2, index_update_count))
+        else:
+            mean_hat = mean
+            var_hat = var
+
+        temp_g[:] = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight
+        r2 = temp_g.norm()
+        # calculate lamb_trust_ratio
+        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
+        lr *= r
+        weight[:] -= lr * temp_g
+
+@with_seed()
+def test_multilamb():
+    opt1 = PyMultiLAMB
+    opt2 = mx.optimizer.MultiLAMB
+
+    # shapes as in BERT-large
+    dims_x = [1024, 4096, 1024, 1024]
+    dims_y = [1, 1, 1024, 4096]
+    dims_occurrences = [9, 1, 4, 2]
+    nlayers = 4  # use 24 for the full BERT-large stack
+    extra_dims_x = [30522, 512, 30522]
+    extra_dims_y = [1, 1024, 1024]
+    shapes = []
+    for _ in range(nlayers):
+        for i, (dx, dy) in enumerate(zip(dims_x, dims_y)):
+            for _ in range(dims_occurrences[i]):
+                shapes.append((dx, dy))
+    for dx, dy in zip(extra_dims_x, extra_dims_y):
+        shapes.append((dx, dy))
+
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    bias_options = [{'bias_correction': False}, {'bias_correction': True}]
+
+    for dtype in [np.float16, np.float32, np.float64]:
+        for cg_option in cg_options:
+            for rg_option in rg_options:
+                for wd_option in wd_options:
+                    for bias_option in bias_options:
+                        kwarg = {}
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        kwarg.update(bias_option)
+                        if dtype == np.float16:
+                            kwarg.update({'multi_precision': True})
+                        atol = 1e-3
+                        rtol = 1e-6
+                        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes, dtype,
+                                          rtol=rtol, atol=atol, ntensors=len(shapes))
+
 # AdaMax
 class PyAdamax(mx.optimizer.Optimizer):
     """The python reference of AdaMax optimizer.