#' Log training and evaluation metrics
#'
#' @export mx.metric.logger
mx.metric.logger <- setRefClass("mx.metric.logger", fields = list(train = "numeric", eval = "numeric"))

#' Log training metric each period
#'
#' @param period The number of batches after which the training evaluation metric is logged
#' @param logger The \code{mx.metric.logger} instance in which to record the metric values (optional)
#'
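#' @examples
#' \dontrun{
#' # A minimal usage sketch: `net`, train.x and train.y are assumed placeholders
#' # (an MXNet symbol and training arrays), not defined in this file.
#' logger <- mx.metric.logger$new()
#' model <- mx.model.FeedForward.create(
#'   symbol = net, X = train.x, y = train.y,
#'   num.round = 10, array.batch.size = 100,
#'   batch.end.callback = mx.callback.log.train.metric(100, logger))
#' # The logger now holds one metric value per logged batch:
#' plot(logger$train, type = "l")
#' }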
#' @export
mx.callback.log.train.metric <- function(period, logger = NULL) {
  function(iteration, nbatch, env, verbose = TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
      result <- env$metric$get(env$train.metric)
      if (nbatch != 0 && verbose)
        message(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value))
      if (!is.null(logger)) {
        if (!inherits(logger, "mx.metric.logger")) {
          stop("Invalid mx.metric.logger.")
        }
        logger$train <- c(logger$train, result$value)
        if (!is.null(env$eval.metric)) {
          result <- env$metric$get(env$eval.metric)
          if (nbatch != 0 && verbose)
            message(paste0("Batch [", nbatch, "] Validation-", result$name, "=", result$value))
          logger$eval <- c(logger$eval, result$value)
        }
      }
    }
    return(TRUE)
  }
}

#' Calculate the training speed
#'
#' @param batch.size The size of each mini-batch
#' @param frequency The number of batches between speed updates
#'
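#' @examples
#' \dontrun{
#' # A minimal usage sketch: `net`, train.x and train.y are assumed placeholders.
#' # Prints the processing speed every 50 batches of 100 samples.
#' model <- mx.model.FeedForward.create(
#'   symbol = net, X = train.x, y = train.y,
#'   num.round = 10, array.batch.size = 100,
#'   batch.end.callback = mx.callback.log.speedometer(batch.size = 100, frequency = 50))
#' }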
#' @export
mx.callback.log.speedometer <- function(batch.size, frequency = 50) {
  function(iteration, nbatch, env, verbose = TRUE) {
    count <- nbatch
    if (is.null(env$count)) env$count <- 0
    if (is.null(env$init)) env$init <- FALSE
    # A drop in the batch counter means a new epoch started: restart the timer.
    if (env$count > count) env$init <- FALSE
    env$count <- count
    if (env$init) {
      if (count %% frequency == 0 && !is.null(env$metric)) {
        time <- as.double(difftime(Sys.time(), env$tic, units = "secs"))
        speed <- frequency * batch.size / time
        result <- env$metric$get(env$train.metric)
        if (nbatch != 0 && verbose)
          message(paste0("Batch [", nbatch, "] Speed: ", speed, " samples/sec Train-",
                         result$name, "=", result$value))
        env$tic <- Sys.time()
      }
    } else {
      env$init <- TRUE
      env$tic <- Sys.time()
    }
  }
}

#' Save checkpoint to files each period iteration.
#'
#' @param prefix The prefix of the model checkpoint.
#' @param period How often, in iterations, to save the checkpoint. Default is 1.
#'
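#' @examples
#' \dontrun{
#' # A minimal usage sketch: `net`, train.x and train.y are assumed placeholders.
#' # Writes chkpt-0005.params, chkpt-0010.params, ... during training.
#' model <- mx.model.FeedForward.create(
#'   symbol = net, X = train.x, y = train.y,
#'   num.round = 10, array.batch.size = 100,
#'   epoch.end.callback = mx.callback.save.checkpoint("chkpt", period = 5))
#' # Reload the checkpoint saved at iteration 10:
#' model <- mx.model.load("chkpt", iteration = 10)
#' }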
#' @export
mx.callback.save.checkpoint <- function(prefix, period = 1) {
  function(iteration, nbatch, env, verbose = TRUE) {
    if (iteration %% period == 0) {
      mx.model.save(env$model, prefix, iteration)
      if (verbose) message(sprintf("Model checkpoint saved to %s-%04d.params", prefix, iteration))
    }
    return(TRUE)
  }
}

#' Early stop with different conditions
#'
#' Early stopping based on different conditions: hard thresholds on the metrics, or the
#' number of epochs since the best score. Tested as an "epoch.end.callback" function.
#'
#' @param train.metric Numeric. Hard threshold for the metric on the training data set (optional)
#' @param eval.metric Numeric. Hard threshold for the metric on the evaluation data set, if one is set (optional)
#' @param bad.steps Integer. How many epochs without improvement over the best score are allowed? Requires an evaluation data set
#' @param maximize Logical. Should the metric be maximized (TRUE) or minimized (FALSE)?
#' @param verbose Logical
#'
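#' @examples
#' \dontrun{
#' # A minimal usage sketch: `net` plus the train/validation arrays are assumed
#' # placeholders. Training stops once validation accuracy has not improved
#' # for 3 consecutive epochs.
#' model <- mx.model.FeedForward.create(
#'   symbol = net, X = train.x, y = train.y,
#'   eval.data = list(data = val.x, label = val.y),
#'   eval.metric = mx.metric.accuracy,
#'   num.round = 100, array.batch.size = 100,
#'   epoch.end.callback = mx.callback.early.stop(bad.steps = 3, maximize = TRUE,
#'                                               verbose = TRUE))
#' }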
#' @export
#'
mx.callback.early.stop <- function(train.metric = NULL, eval.metric = NULL, bad.steps = NULL,
                                   maximize = FALSE, verbose = FALSE) {
  # Capture the callback-level verbosity so the inner default below does not
  # self-reference its own `verbose` argument.
  log.verbose <- verbose
  function(iteration, nbatch, env, verbose = log.verbose) {
    # hard threshold for the train metric
    if (!is.null(env$metric)) {
      if (!is.null(train.metric)) {
        result <- env$metric$get(env$train.metric)
        if ((!maximize && result$value < train.metric) || (maximize && result$value > train.metric)) {
          return(FALSE)
        }
      }
      # hard threshold for the evaluation metric
      if (!is.null(eval.metric)) {
        if (!is.null(env$eval.metric)) {
          result <- env$metric$get(env$eval.metric)
          if ((!maximize && result$value < eval.metric) || (maximize && result$value > eval.metric)) {
            return(FALSE)
          }
        }
      }
    }
    # stop when there is no improvement over the best score for `bad.steps` epochs
    if (!is.null(bad.steps)) {
      # set / reset iteration variables
      # it may not be the best practice to use global variables,
      # but let's not touch the "model.r" file
      if (iteration == 1) {
        # reset the step counter
        mx.best.iter <<- 1
        # reset the best score; -Inf is the safe starting point when maximizing,
        # in case the metric can take negative values
        if (maximize) {
          mx.best.score <<- -Inf
        } else {
          mx.best.score <<- Inf
        }
      }
      # test the early stop round
      if (!is.null(env$eval.metric)) {
        result <- env$metric$get(env$eval.metric)
        if ((!maximize && result$value > mx.best.score) || (maximize && result$value < mx.best.score)) {
          if (mx.best.iter == bad.steps) {
            if (verbose) {
              message(paste0("Best score=", mx.best.score, ", iteration [", iteration - bad.steps, "]"))
            }
            return(FALSE)
          } else {
            mx.best.iter <<- mx.best.iter + 1
          }
        } else {
          mx.best.score <<- result$value
          mx.best.iter <<- 1
        }
      }
    }
    return(TRUE)
  }
}