# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@doc doc"""
    Nadam(; kwargs...)
Nesterov Adam optimizer: Adam with Nesterov momentum;
see [1] and the notes below for further description.
### Arguments
* `η`: default `0.001`, learning rate.
* `β1`: default `0.99`.
* `β2`: default `0.999`.
* `ϵ`: default `1e-8`, small value added for numerical stability.
* `clip`: default `0`, gradient clipping.
  If positive, the gradient is clipped to the range `[-clip, clip]`.
* `scale`: default `0`, gradient rescaling.
  If nonzero, the gradient is multiplied by `scale` before the update,
  often chosen as `1.0 / batch_size`.
  If left at the default, a high-level API such as `fit!` will set it to
  `1.0 / batch_size`, since `fit!` knows the `batch_size`.
* `λ`: default `0.00001`, weight decay; equivalent to adding a global
  L2 regularizer over all parameters.
* `η_sched::AbstractLearningRateScheduler`: default `nothing`, a
dynamic learning rate scheduler. If set, will overwrite the `η`
parameter.
* `μ_sched::NadamScheduler`: default `NadamScheduler()`, a momentum
  scheduler of the form below (a short sketch follows this list):
  ```math
  \mu_t = \beta_1 (1 - 0.5 \times 0.96^{t \times 0.004})
  ```
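For intuition, a minimal sketch of this schedule in plain Julia; the
helper `μ` below is illustrative only and is not part of the package:

```julia
# Hypothetical helper, not exported by the package: evaluates the
# NadamScheduler momentum at step t for the default β1 = 0.99.
μ(t; β1 = 0.99) = β1 * (1 - 0.5 * 0.96^(t * 0.004))

μ(1)          # ≈ 0.495: momentum starts near β1/2 ...
μ(1_000_000)  # ≈ 0.99:  ... and warms up toward β1 as t grows
```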
### Notes
Default parameters follow those provided in the paper.
It is recommended to leave the parameters of this optimizer
at their default values.
### References
1. [Incorporating Nesterov Momentum into Adam]
(http://cs229.stanford.edu/proj2015/054_report.pdf).
2. [On the importance of initialization and momentum in deep learning]
(http://www.cs.toronto.edu/~fritz/absps/momentum.pdf).
"""
Nadam
@defstruct Nadam <: AbstractOptimizer (
  (η       :: Real = 0.001,  η > 0),
  (β1      :: Real = 0.99,   0 <= β1 < 1),
  (β2      :: Real = 0.999,  0 <= β2 < 1),
  (ϵ       :: Real = 1e-8,   ϵ > 0),
  (clip    :: Real = 0,      clip >= 0),
   scale   :: Real = 0,
  (λ       :: Real = 1e-5,   λ >= 0),
   η_sched :: Any  = initlrsched(η),
   μ_sched :: Momentum.NadamScheduler = Momentum.NadamScheduler(μ = β1)
)
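# A quick usage sketch, assuming MXNet.jl is loaded as `mx`; `model` and
# `train_data` are hypothetical placeholders for a FeedForward network and
# a data provider constructed elsewhere:
#
#   using MXNet
#
#   # Default hyperparameters, or override any of them by keyword:
#   opt = mx.Nadam(η = 0.001, β1 = 0.99, β2 = 0.999)
#
#   # Hand the optimizer to the high-level training entry point, which
#   # also fills in `scale = 1.0 / batch_size` as noted in the docstring:
#   mx.fit(model, opt, train_data, n_epoch = 10)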
mutable struct NadamState
  m   :: NDArray
  n   :: NDArray
  Πμ  :: Float64   # running product of the scheduled momenta ∏ μᵢ
  β2ᵗ :: Float64   # running power β2^t, used for bias correction
  t   :: Int       # step counter fed to the NadamScheduler;
                   # stored here because a state is created per `index`
end
create_state(n::Nadam, ::Int, W::NDArray) =
  NadamState(zeros(size(W), context(W)), zeros(size(W), context(W)),
             1.0, n.β2, 1)
function update!(na::Nadam, ::Int, W::NDArray, ∇::NDArray, s::NadamState)
  η = get(na.η_sched)
  μₜ, μₜ₁ = get(na.μ_sched, s.t)
  β1, β2 = na.β1, na.β2
  ϵ = na.ϵ

  normgrad!(na, W, ∇)
  s.t += 1

  # accumulate the product of scheduled momenta, plus its one-step look-ahead
  s.Πμ *= μₜ
  Πμ′ = s.Πμ * μₜ₁

  # bias-corrected gradient
  ∇′ = ∇ / (1.0 - s.Πμ)

  # first-moment estimate and its bias correction
  @inplace s.m .*= β1
  @inplace s.m .+= (1.0 - β1) * ∇
  m̂ = s.m / (1.0 - Πμ′)

  # second-moment estimate and its bias correction
  @inplace s.n .*= β2
  @inplace s.n .+= (1.0 - β2) .* ∇.^2
  n̂ = s.n / (1.0 - s.β2ᵗ)
  s.β2ᵗ *= β2

  # Nesterov-style mix of corrected gradient and corrected momentum
  m̄ = (1.0 - μₜ) * ∇′ + μₜ₁ * m̂
  @inplace W .+= -η * m̄ ./ (sqrt(n̂) + ϵ)
end
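# For reference, a self-contained plain-Julia sketch of the same update on
# ordinary Arrays (no MXNet dependency). It is illustrative, not part of the
# package: a constant momentum `μ` stands in for the NadamScheduler (an
# assumption; the real `update!` draws μₜ and μₜ₊₁ from the scheduler), and
# the running scalars are returned since Julia floats are immutable.
#
#   # Illustrative only: one Nadam step on plain Arrays.
#   function nadam_step!(W::Array, ∇::Array, m::Array, n::Array,
#                        Πμ::Float64, β2ᵗ::Float64;
#                        η = 0.001, μ = 0.99, β2 = 0.999, ϵ = 1e-8)
#     Πμ  *= μ                # running product of momentum coefficients
#     Πμ′  = Πμ * μ           # extended one step ahead (Nesterov look-ahead)
#     ∇′   = ∇ ./ (1.0 - Πμ)  # bias-corrected gradient
#
#     m .= μ .* m .+ (1.0 - μ) .* ∇         # first moment
#     m̂  = m ./ (1.0 - Πμ′)
#
#     n .= β2 .* n .+ (1.0 - β2) .* ∇ .^ 2  # second moment
#     n̂  = n ./ (1.0 - β2ᵗ)
#     β2ᵗ *= β2
#
#     # Nesterov-style mix of corrected gradient and corrected momentum
#     m̄ = (1.0 - μ) .* ∇′ .+ μ .* m̂
#     W .-= η .* m̄ ./ (sqrt.(n̂) .+ ϵ)
#     return Πμ, β2ᵗ          # caller stores these for the next step
#   end
#
# On the first call pass Πμ = 1.0 and β2ᵗ = β2, mirroring `create_state` above.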