# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@doc doc"""
    SGD(; kwargs...)

Stochastic gradient descent optimizer.

Vanilla SGD:
```math
\theta \leftarrow \theta - \eta \nabla
```

SGD with momentum:
```math
\begin{align*}
  \nu_t & \leftarrow \mu \nu_{t-1} - \eta \nabla \\
  \theta & \leftarrow \theta + \nu_t
\end{align*}
```

### Arguments
* `η`: default `0.01`, learning rate.
* `μ`: default `0`, the momentum coefficient; commonly set to `0.9`.
* `λ`: default `0.0001`, weight decay; equivalent to adding a global
  L2 regularizer on the parameters.
* `clip`: default `0`, gradient clipping.
  If positive, the gradient is clipped to the range `[-clip, clip]`.
* `scale`: default `0`, gradient rescaling.
  If nonzero, the gradient is multiplied by `scale` before the update,
  often chosen as `1.0 / batch_size`.
  If left at the default, a high-level API such as `fit!` will set it to
  `1.0 / batch_size`, since `fit!` knows the `batch_size`.
* `μ_sched::AbstractMomentumScheduler`: default `Momentum.Null()`,
  a dynamic momentum scheduler. If set, it overrides the `μ` parameter.
* `η_sched::AbstractLearningRateScheduler`: default `LearningRate.Fixed(η)`,
  a dynamic learning rate scheduler. If set, it overrides the `η` parameter.
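
### Example

A minimal construction sketch; the values shown are illustrative, not tuned
recommendations:
```julia
# Leave `scale` at its default `0` so a high-level API like `fit!`
# can set it to `1.0 / batch_size`.
opt = SGD(η = 0.01, μ = 0.9, λ = 0.0001)
```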
"""
SGD
@defstruct SGD <: AbstractOptimizer (
  (η      :: Real = 0.01,   η > 0),
  (μ      :: Real = 0.0,    μ >= 0),
  (clip   :: Real = 0,      clip >= 0),
   scale  :: Real = 0,
  (λ      :: Real = 0.0001, λ >= 0),
  η_sched :: Any  = initlrsched(η),
  μ_sched :: Any  = initmomsched(μ)
)
# SGD only needs per-parameter state (the momentum buffer ν) when a
# momentum scheduler is in use; plain SGD is stateless.
create_state(sgd::SGD, ::Int, W::NDArray) =
  isa(sgd.μ_sched, Momentum.Null) ? nothing : zeros(size(W), context(W))
# Vanilla SGD update (no momentum state): θ ← θ - η∇.
function update!(sgd::SGD, ::Int, W::NDArray, ∇::NDArray, ::Nothing)
  η = get(sgd.η_sched)
  normgrad!(sgd, W, ∇)
  @inplace W += -η .* ∇
end
# SGD update with momentum state ν: ν ← μν - η∇, then θ ← θ + ν.
function update!(sgd::SGD, ::Int, W::NDArray, ∇::NDArray, ν::NDArray)
  η = get(sgd.η_sched)
  μ = get(sgd.μ_sched)

  normgrad!(sgd, W, ∇)

  @inplace ν .*= μ
  @inplace ν .+= -η .* ∇
  @inplace W .+= ν
end
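
# A minimal reference sketch (not part of the MXNet.jl API) of the same
# momentum update on plain Julia arrays, mirroring the math in the docstring
# above; the helper name `sgd_momentum_reference!` is hypothetical.
function sgd_momentum_reference!(θ::AbstractVector, ∇::AbstractVector,
                                 ν::AbstractVector; η = 0.01, μ = 0.9)
  @. ν = μ * ν - η * ∇   # ν ← μν_{t-1} - η∇
  @. θ = θ + ν           # θ ← θ + ν
  return θ
end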