blob: 1ef26b13a1a9fd4a320db016a9e43b0cd4600c08 [file] [log] [blame]
``mx.symbol.mp_lamb_update_phase1``
======================================================================
Description
----------------------
Mixed Precision version of Phase I of lamb update
it performs the following operations and returns g:.
Link to paper: https://arxiv.org/pdf/1904.00962.pdf
.. math::
\begin{gather*}
grad32 = grad(float16) * rescale_grad
if (grad < -clip_gradient)
then
grad = -clip_gradient
if (grad > clip_gradient)
then
grad = clip_gradient
mean = beta1 * mean + (1 - beta1) * grad;
variance = beta2 * variance + (1. - beta2) * grad ^ 2;
if (bias_correction)
then
mean_hat = mean / (1. - beta1^t);
var_hat = var / (1 - beta2^t);
g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight32;
else
g = mean / (var_data^(1/2) + epsilon) + wd * weight32;
\end{gather*}
Usage
----------
.. code:: r
mx.symbol.mp_lamb_update_phase1(...)
Arguments
------------------
+----------------------------------------+------------------------------------------------------------+
| Argument | Description |
+========================================+============================================================+
| ``weight`` | NDArray-or-Symbol. |
| | |
| | Weight |
+----------------------------------------+------------------------------------------------------------+
| ``grad`` | NDArray-or-Symbol. |
| | |
| | Gradient |
+----------------------------------------+------------------------------------------------------------+
| ``mean`` | NDArray-or-Symbol. |
| | |
| | Moving mean |
+----------------------------------------+------------------------------------------------------------+
| ``var`` | NDArray-or-Symbol. |
| | |
| | Moving variance |
+----------------------------------------+------------------------------------------------------------+
| ``weight32`` | NDArray-or-Symbol. |
| | |
| | Weight32 |
+----------------------------------------+------------------------------------------------------------+
| ``beta1`` | float, optional, default=0.899999976. |
| | |
| | The decay rate for the 1st moment estimates. |
+----------------------------------------+------------------------------------------------------------+
| ``beta2`` | float, optional, default=0.999000013. |
| | |
| | The decay rate for the 2nd moment estimates. |
+----------------------------------------+------------------------------------------------------------+
| ``epsilon`` | float, optional, default=9.99999997e-07. |
| | |
| | A small constant for numerical stability. |
+----------------------------------------+------------------------------------------------------------+
| ``t`` | int, required. |
| | |
| | Index update count. |
+----------------------------------------+------------------------------------------------------------+
| ``bias.correction`` | boolean, optional, default=1. |
| | |
| | Whether to use bias correction. |
+----------------------------------------+------------------------------------------------------------+
| ``wd`` | float, required. |
| | |
| | Weight decay augments the objective function with a |
| | regularization term that penalizes large weights. The |
| | penalty scales with the square of the magnitude of each |
| | weight. |
+----------------------------------------+------------------------------------------------------------+
| ``rescale.grad`` | float, optional, default=1. |
| | |
| | Rescale gradient to grad = rescale_grad*grad. |
+----------------------------------------+------------------------------------------------------------+
| ``clip.gradient`` | float, optional, default=-1. |
| | |
| | Clip gradient to the range of [-clip_gradient, |
| | clip_gradient] If clip_gradient <= 0, gradient clipping is |
| | turned off. grad = max(min(grad, clip_gradient), |
| | -clip_gradient). |
+----------------------------------------+------------------------------------------------------------+
| ``name`` | string, optional. |
| | |
| | Name of the resulting symbol. |
+----------------------------------------+------------------------------------------------------------+
Value
----------
``out`` The result mx.symbol
Link to Source Code: http://github.com/apache/incubator-mxnet/blob/1.6.0/src/operator/optimizer_op.cc#L1033