| MXNet R Tutorial on Callback Function |
| ====================================== |
| |
| This vignette gives users a guideline for using and writing callback functions, |
| which can be very useful in model training. |
| |
| This tutorial is written in Rmarkdown. |
| |
| - You can directly view the hosted version of the tutorial from [MXNet R Document](http://mxnet.readthedocs.io/en/latest/packages/r/CallbackFunctionTutorial.html) |
| |
| - You can find the Rmarkdown source from [here](https://github.com/dmlc/mxnet/blob/master/R-package/vignettes/CallbackFunctionTutorial.Rmd) |
| |
| Model training example |
| ---------- |
| |
| Let's begin with a small example. We can build and train a model using the following code: |
| |
| ```{r} |
| library(mxnet) |
| data(BostonHousing, package="mlbench") |
| train.ind = seq(1, 506, 3) |
| train.x = data.matrix(BostonHousing[train.ind, -14]) |
| train.y = BostonHousing[train.ind, 14] |
| test.x = data.matrix(BostonHousing[-train.ind, -14]) |
| test.y = BostonHousing[-train.ind, 14] |
| data <- mx.symbol.Variable("data") |
| fc1 <- mx.symbol.FullyConnected(data, num_hidden=1) |
| lro <- mx.symbol.LinearRegressionOutput(fc1) |
| mx.set.seed(0) |
| model <- mx.model.FeedForward.create( |
| lro, X=train.x, y=train.y, |
| eval.data=list(data=test.x, label=test.y), |
| ctx=mx.cpu(), num.round=10, array.batch.size=20, |
| learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse) |
| ``` |
| |
| `mx.model.FeedForward.create` also accepts two optional parameters, `batch.end.callback` and `epoch.end.callback`, which are called at the end of each batch and each epoch respectively and give you considerable flexibility in model training. |
| |
| How to use callback functions |
| --------- |
| |
| Two callback functions are provided in this package: |
| |
| - `mx.callback.save.checkpoint` saves a model checkpoint to files every `period` iterations. |
| |
| ```{r} |
| model <- mx.model.FeedForward.create( |
| lro, X=train.x, y=train.y, |
| eval.data=list(data=test.x, label=test.y), |
| ctx=mx.cpu(), num.round=10, array.batch.size=20, |
| learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse, |
| epoch.end.callback = mx.callback.save.checkpoint("boston")) |
| ``` |
| |
| |
| - `mx.callback.log.train.metric` logs the training metric every `period` batches. You can use it as either a `batch.end.callback` or an |
| `epoch.end.callback`. |
| |
| ```{r} |
| model <- mx.model.FeedForward.create( |
| lro, X=train.x, y=train.y, |
| eval.data=list(data=test.x, label=test.y), |
| ctx=mx.cpu(), num.round=10, array.batch.size=20, |
| learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse, |
| batch.end.callback = mx.callback.log.train.metric(5)) |
| ``` |
| |
| You can also save the training and evaluation errors for later use by passing a reference class object, here an `mx.metric.logger`. |
| |
| ```{r} |
| logger <- mx.metric.logger$new() |
| model <- mx.model.FeedForward.create( |
| lro, X=train.x, y=train.y, |
| eval.data=list(data=test.x, label=test.y), |
| ctx=mx.cpu(), num.round=10, array.batch.size=20, |
| learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse, |
| epoch.end.callback = mx.callback.log.train.metric(5, logger)) |
| head(logger$train) |
| head(logger$eval) |
| ``` |
| |
| How to write your own callback functions |
| ---------- |
| |
| You can find the source code for the two callback functions [here](https://github.com/dmlc/mxnet/blob/master/R-package/R/callback.R); they can serve as templates for your own. |
| |
| Basically, all callback functions follow the structure below: |
| |
| ```{r, eval=FALSE} |
| mx.callback.fun <- function() { |
|   function(iteration, nbatch, env, verbose) { |
|   } |
| } |
| ``` |
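| |
| To make the calling convention concrete, here is a minimal sketch in base R that does not need MXNet. The `run.rounds` driver and the `mx.callback.print.round` callback are hypothetical names used only for illustration. The driver merely mimics how a training loop could invoke an epoch-end callback and stop once it returns `FALSE`; it is not the package's actual loop. |
| |
| ```{r} |
| # Hypothetical callback generator following the structure above. |
| mx.callback.print.round <- function(max.round) { |
|   function(iteration, nbatch, env, verbose) { |
|     if (verbose) cat(sprintf("Finished round %d\n", iteration)) |
|     # Keep going until max.round has been reached. |
|     return(iteration < max.round) |
|   } |
| } |
| |
| # Hypothetical driver standing in for the training loop (not MXNet code). |
| run.rounds <- function(num.round, epoch.end.callback, verbose = FALSE) { |
|   env <- new.env() |
|   completed <- 0 |
|   for (iteration in seq_len(num.round)) { |
|     completed <- iteration |
|     # This mock passes nbatch = 0 because no batch is in progress. |
|     keep.going <- epoch.end.callback(iteration, 0, env, verbose) |
|     if (!keep.going) break |
|   } |
|   completed |
| } |
| |
| completed <- run.rounds(10, mx.callback.print.round(3)) |
| ``` |
| |
| The outer function receives the configuration (`max.round`), while the inner function is what gets called during training, so the configuration stays available through the closure. |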
| |
| The `mx.callback.save.checkpoint` function below is stateless. It just gets the model from the environment and saves it. |
| |
| ```{r, eval=FALSE} |
| mx.callback.save.checkpoint <- function(prefix, period=1) { |
|   function(iteration, nbatch, env, verbose) { |
|     if (iteration %% period == 0) { |
|       mx.model.save(env$model, prefix, iteration) |
|       if(verbose) cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration)) |
|     } |
|     return(TRUE) |
|   } |
| } |
| ``` |
| |
| The `mx.callback.log.train.metric` function is a little more complex. It holds a reference class object and updates it during the training |
| process. |
| |
| ```{r, eval=FALSE} |
| mx.callback.log.train.metric <- function(period, logger=NULL) { |
|   function(iteration, nbatch, env, verbose) { |
|     if (nbatch %% period == 0 && !is.null(env$metric)) { |
|       result <- env$metric$get(env$train.metric) |
|       if (nbatch != 0) |
|         if(verbose) cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n")) |
|       if (!is.null(logger)) { |
|         if (class(logger) != "mx.metric.logger") { |
|           stop("Invalid mx.metric.logger.") |
|         } |
|         logger$train <- c(logger$train, result$value) |
|         if (!is.null(env$eval.metric)) { |
|           result <- env$metric$get(env$eval.metric) |
|           if (nbatch != 0) |
|             cat(paste0("Batch [", nbatch, "] Validation-", result$name, "=", result$value, "\n")) |
|           logger$eval <- c(logger$eval, result$value) |
|         } |
|       } |
|     } |
|     return(TRUE) |
|   } |
| } |
| ``` |
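| |
| A reference class is one way to keep state between calls. As a rough sketch of the same idea, an R environment can play that role too, because environments are modified in place. The callback name `mx.callback.track.best`, the `state` fields and the `mock.env` object below are all hypothetical. In particular, `mock.env` only imitates the `env$metric$get(env$eval.metric)` access pattern seen above and is not how MXNet stores its metrics. |
| |
| ```{r} |
| # Hypothetical callback that remembers the lowest validation metric seen. |
| mx.callback.track.best <- function(state) { |
|   function(iteration, nbatch, env, verbose) { |
|     if (!is.null(env$metric) && !is.null(env$eval.metric)) { |
|       result <- env$metric$get(env$eval.metric) |
|       if (result$value < state$best) { |
|         # Environments have reference semantics, so this update persists. |
|         state$best <- result$value |
|         state$best.round <- iteration |
|       } |
|     } |
|     return(TRUE) |
|   } |
| } |
| |
| state <- new.env() |
| state$best <- Inf |
| state$best.round <- NA |
| |
| # Mock training environment (an assumption for illustration only): |
| # here the metric "state" is simply the metric value itself. |
| mock.env <- new.env() |
| mock.env$metric <- list(get = function(x) list(name = "rmse", value = x)) |
| |
| track <- mx.callback.track.best(state) |
| values <- c(12.5, 9.8, 10.4) |
| for (i in seq_along(values)) { |
|   mock.env$eval.metric <- values[i] |
|   track(i, 0, mock.env, FALSE) |
| } |
| ``` |
| |
| After the loop, `state$best` and `state$best.round` hold the lowest simulated value and the round in which it occurred. |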
| |
| Now you might be curious why both callback functions `return(TRUE)`. |
| Can we `return(FALSE)`? |
| |
| Yes! You can stop the training early by returning `FALSE`. See the example below. |
| |
| ```{r} |
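| # eval.metric is the threshold below which training should stop. |
| mx.callback.early.stop <- function(eval.metric) { |
|   function(iteration, nbatch, env, verbose) { |
|     if (!is.null(env$metric)) { |
|       if (!is.null(eval.metric)) { |
|         # Read the current validation metric from the training environment. |
|         result <- env$metric$get(env$eval.metric) |
|         if (result$value < eval.metric) { |
|           # Returning FALSE stops the training. |
|           return(FALSE) |
|         } |
|       } |
|     } |
|     return(TRUE) |
|   } |
| } |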
| model <- mx.model.FeedForward.create( |
| lro, X=train.x, y=train.y, |
| eval.data=list(data=test.x, label=test.y), |
| ctx=mx.cpu(), num.round=10, array.batch.size=20, |
| learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse, |
| epoch.end.callback = mx.callback.early.stop(10)) |
| ``` |
| |
| As you can see, once the validation metric falls below the threshold we set, the training process stops early. |
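| |
| If you want more than one behaviour at the same point in training, one option is to wrap several callbacks into a single function. The sketch below assumes that a callback parameter takes a single function. `combine.callbacks` is a hypothetical helper, not part of the package, and the two small callbacks exist only to show the effect without running MXNet. |
| |
| ```{r} |
| # Hypothetical helper: run every callback, stop if any of them returns FALSE. |
| combine.callbacks <- function(...) { |
|   callbacks <- list(...) |
|   function(iteration, nbatch, env, verbose) { |
|     keep.going <- TRUE |
|     for (callback in callbacks) { |
|       if (!isTRUE(callback(iteration, nbatch, env, verbose))) keep.going <- FALSE |
|     } |
|     return(keep.going) |
|   } |
| } |
| |
| # Two toy callbacks that record when they are called. |
| calls <- character(0) |
| always.continue <- function(iteration, nbatch, env, verbose) { |
|   calls <<- c(calls, "log") |
|   TRUE |
| } |
| stop.at.two <- function(iteration, nbatch, env, verbose) { |
|   calls <<- c(calls, "stop") |
|   iteration < 2 |
| } |
| |
| combined <- combine.callbacks(always.continue, stop.at.two) |
| first <- combined(1, 0, new.env(), FALSE) |
| second <- combined(2, 0, new.env(), FALSE) |
| ``` |
| |
| Under the same assumption, real callbacks could be combined in the same way, for example `epoch.end.callback = combine.callbacks(mx.callback.save.checkpoint("boston"), mx.callback.early.stop(10))`. |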