blob: bc89721094039f427f5924305cfe23a0099dd47a [file] [log] [blame]
#' Learning rate scheduler. Reduction based on a factor value.
#'
#' Builds a closure that multiplies the learning rate by \code{factor_val}
#' once every \code{step} updates, clamping it at \code{stop_factor_lr}.
#'
#' @param step (integer)
#' Schedule learning rate after n updates
#' @param factor_val (double)
#' The factor for reducing the learning rate
#' @param stop_factor_lr (double)
#' Lower bound for the learning rate; a reduction that would go below it
#' sets the learning rate to this value instead
#' @param verbose (logical)
#' Whether to emit a message each time the learning rate is changed
#' @return scheduler function. It takes an environment holding
#' \code{num_update}, \code{count} and \code{lr}, and updates \code{lr}
#' and \code{count} in that environment.
#'
#' @export
mx.lr_scheduler.FactorScheduler <- function(step, factor_val,
                                            stop_factor_lr = 1e-8,
                                            verbose = TRUE) {
  if (step < 1) stop("Schedule step must be greater or equal than 1 round")
  if (factor_val > 1) stop("Factor must be no more than 1 to make lr reduce")
  function(optimizerEnv) {
    num_update <- optimizerEnv$num_update
    count <- optimizerEnv$count
    lr <- optimizerEnv$lr
    # `while` rather than `if`: if num_update has moved past several step
    # boundaries since the last call (e.g. when resuming training), apply
    # every missed reduction instead of only one.
    while (num_update > count + step) {
      count <- count + step
      lr <- lr * factor_val
      if (lr < stop_factor_lr) {
        lr <- stop_factor_lr
        if (verbose) message(paste0("Update[", num_update,
                                    "]: now learning rate arrived at ", lr,
                                    ", will not change in the future"))
      } else {
        if (verbose) message(paste0("Update[", num_update,
                                    "]: learning rate is changed to ", lr))
      }
      optimizerEnv$lr <- lr
      optimizerEnv$count <- count
    }
  }
}
#' Multifactor learning rate scheduler. Reduction based on a factor value at different steps.
#'
#' Builds a closure that multiplies the learning rate by \code{factor_val}
#' each time the update counter passes one of the thresholds in \code{step},
#' clamping it at \code{stop_factor_lr}.
#'
#' @param step (array of integer)
#' Schedule learning rate after n updates
#' @param factor_val (double)
#' The factor for reducing the learning rate
#' @param stop_factor_lr (double)
#' Lower bound for the learning rate; a reduction that would go below it
#' sets the learning rate to this value instead
#' @param verbose (logical)
#' Whether to emit a message each time the learning rate is changed
#' @return scheduler function. It takes an environment holding
#' \code{num_update} and \code{lr} (and optionally \code{cur_step_ind},
#' defaulting to 1), and updates \code{lr}, \code{count} and
#' \code{cur_step_ind} in that environment.
#'
#' @export
mx.lr_scheduler.MultiFactorScheduler <- function(step, factor_val,
                                                 stop_factor_lr = 1e-8,
                                                 verbose = TRUE) {
  if (!all(step == cummax(step))) stop("Schedule step must be an increasing integer list")
  if (any(step < 1)) stop("Schedule step must be greater or equal than 1 round")
  if (factor_val > 1) stop("Factor must be no more than 1 to make lr reduce")
  function(optimizerEnv) {
    cur_step_ind <- optimizerEnv$cur_step_ind
    if (is.null(cur_step_ind)) cur_step_ind <- 1
    num_update <- optimizerEnv$num_update
    lr <- optimizerEnv$lr
    # `<=` so the final threshold in `step` is applied too (R is 1-based);
    # `while` applies every threshold already passed, e.g. on resume.
    while (cur_step_ind <= length(step) && num_update > step[cur_step_ind]) {
      count <- step[cur_step_ind]
      cur_step_ind <- cur_step_ind + 1
      lr <- lr * factor_val
      if (lr < stop_factor_lr) {
        lr <- stop_factor_lr
        if (verbose) message(paste0("Update[", num_update,
                                    "]: now learning rate arrived at ", lr,
                                    ", will not change in the future"))
      } else {
        if (verbose) message(paste0("Update[", num_update,
                                    "]: learning rate is changed to ", lr))
      }
      optimizerEnv$lr <- lr
      optimizerEnv$count <- count
      optimizerEnv$cur_step_ind <- cur_step_ind
    }
  }
}