blob: 39d8f25520359dad9304a7925ad7801fe9e5a6cc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
AbstractCallback
Abstract type of callback functions used in training.
"""
abstract type AbstractCallback end
"""
AbstractBatchCallback
Abstract type of callbacks to be called every mini-batch.
"""
abstract type AbstractBatchCallback <: AbstractCallback end
"""
AbstractEpochCallback
Abstract type of callbacks to be called every epoch.
"""
abstract type AbstractEpochCallback <: AbstractCallback end
mutable struct BatchCallback <: AbstractBatchCallback
frequency :: Int
call_on_0 :: Bool
callback :: Function
end
"""
every_n_batch(callback :: Function, n :: Int; call_on_0 = false)
A convenient function to construct a callback that runs every `n` mini-batches.
# Arguments
* `call_on_0::Bool`: keyword argument, default false. Unless set, the callback
will *not* be run on batch 0.
For example, the [`speedometer`](@ref) callback is defined as
```julia
every_n_batch(frequency, call_on_0=true) do state :: OptimizationState
if state.curr_batch == 0
# reset timer
else
# compute and print speed
end
end
```
See also [`every_n_epoch`](@ref) and [`speedometer`](@ref).
"""
function every_n_batch(callback::Function, n::Int; call_on_0::Bool = false)
BatchCallback(n, call_on_0, callback)
end
function (cb :: BatchCallback)(state :: OptimizationState)
if state.curr_batch == 0
if cb.call_on_0
cb.callback(state)
end
elseif state.curr_batch % cb.frequency == 0
cb.callback(state)
end
end
"""
speedometer(;frequency=50)
Create an `AbstractBatchCallback` that measure the training speed
(number of samples processed per second) every k mini-batches.
# Arguments
* `frequency::Int`: keyword argument, default 50. The frequency (number of
min-batches) to measure and report the speed.
"""
function speedometer(;frequency::Int = 50)
cl_tic = 0
every_n_batch(frequency, call_on_0 = true) do state::OptimizationState
if state.curr_batch == 0
# reset timer
cl_tic = time()
else
speed = frequency * state.batch_size / (time() - cl_tic)
@info(format("Speed: {1:>6.2f} samples/sec", speed))
cl_tic = time()
end
end
end
mutable struct EpochCallback <: AbstractEpochCallback
frequency :: Int
call_on_0 :: Bool
callback :: Function
end
"""
every_n_epoch(callback :: Function, n :: Int; call_on_0 = false)
A convenient function to construct a callback that runs every `n` full data-passes.
* `call_on_0::Bool`: keyword argument, default false. Unless set, the callback
will *not* be run on epoch 0. Epoch 0 means no training has been performed
yet. This is useful if you want to inspect the randomly initialized model
that has not seen any data yet.
See also [`every_n_batch`](@ref).
"""
every_n_epoch(callback::Function, n::Int; call_on_0::Bool = false) =
EpochCallback(n, call_on_0, callback)
function (cb::EpochCallback)(model::Any, state::OptimizationState,
metric::Vector{Tuple{Symbol, T}}) where T<:Real
if state.curr_epoch == 0
if cb.call_on_0
cb.callback(model, state, metric)
end
elseif state.curr_epoch % cb.frequency == 0
cb.callback(model, state, metric)
end
end
"""
do_checkpoint(prefix; frequency=1, save_epoch_0=false)
Create an `AbstractEpochCallback` that save checkpoints of the model to disk.
The checkpoints can be loaded back later on.
# Arguments
* `prefix::AbstractString`: the prefix of the filenames to save the model.
The model architecture will be saved to prefix-symbol.json,
while the weights will be saved to prefix-0012.params,
for example, for the 12-th epoch.
* `frequency::Int`: keyword argument, default is 1.
The frequency (measured in epochs) to save checkpoints.
* `save_epoch_0::Bool`: keyword argument, default false. Whether we should save a
checkpoint for epoch 0 (model initialized but not seen any data yet).
"""
function do_checkpoint(prefix::AbstractString;
frequency::Int = 1, save_epoch_0::Bool = false)
mkpath(dirname(prefix))
every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric
save_checkpoint(model, prefix, state)
end
end