MXNet R Tutorial on Callback Function
======================================
This vignette gives users a guideline for using and writing callback functions,
which can be very useful in model training.
This tutorial is written in Rmarkdown.
- You can directly view the hosted version of the tutorial from [MXNet R Document](http://mxnet.readthedocs.io/en/latest/packages/r/CallbackFunctionTutorial.html)
- You can find the Rmarkdown source from [here](https://github.com/dmlc/mxnet/blob/master/R-package/vignettes/CallbackFunctionTutorial.Rmd)
Model training example
----------
Let's begin with a small example. We can build and train a model using the following code:
```{r}
library(mxnet)
data(BostonHousing, package="mlbench")
train.ind = seq(1, 506, 3)
train.x = data.matrix(BostonHousing[train.ind, -14])
train.y = BostonHousing[train.ind, 14]
test.x = data.matrix(BostonHousing[-train.ind, -14])
test.y = BostonHousing[-train.ind, 14]
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=1)
lro <- mx.symbol.LinearRegressionOutput(fc1)
mx.set.seed(0)
model <- mx.model.FeedForward.create(
  lro, X=train.x, y=train.y,
  eval.data=list(data=test.x, label=test.y),
  ctx=mx.cpu(), num.round=10, array.batch.size=20,
  learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse)
```
In addition, `mx.model.FeedForward.create` accepts two optional parameters, `batch.end.callback` and `epoch.end.callback`, which provide great flexibility in model training.
How to use callback functions
---------
Two callback functions are provided in this package:
- `mx.callback.save.checkpoint` saves a checkpoint to files once every `period` iterations.
```{r}
model <- mx.model.FeedForward.create(
  lro, X=train.x, y=train.y,
  eval.data=list(data=test.x, label=test.y),
  ctx=mx.cpu(), num.round=10, array.batch.size=20,
  learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
  epoch.end.callback = mx.callback.save.checkpoint("boston"))
```
- `mx.callback.log.train.metric` logs the training metric at the given period. You can use it as either a `batch.end.callback` or an `epoch.end.callback`.
```{r}
model <- mx.model.FeedForward.create(
  lro, X=train.x, y=train.y,
  eval.data=list(data=test.x, label=test.y),
  ctx=mx.cpu(), num.round=10, array.batch.size=20,
  learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
  batch.end.callback = mx.callback.log.train.metric(5))
```
You can also save the training and evaluation errors for later use by passing a reference class object.
```{r}
logger <- mx.metric.logger$new()
model <- mx.model.FeedForward.create(
  lro, X=train.x, y=train.y,
  eval.data=list(data=test.x, label=test.y),
  ctx=mx.cpu(), num.round=10, array.batch.size=20,
  learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
  epoch.end.callback = mx.callback.log.train.metric(5, logger))
head(logger$train)
head(logger$eval)
```
How to write your own callback functions
----------
You can find the source code for the two callback functions [here](https://github.com/dmlc/mxnet/blob/master/R-package/R/callback.R), and you can use them as templates.
Basically, all callback functions follow the structure below:
```{r, eval=FALSE}
mx.callback.fun <- function() {
  function(iteration, nbatch, env, verbose) {
  }
}
```
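To make this structure concrete, here is a minimal sketch of a custom callback that only counts how often it is invoked. The name `my.callback.count.calls` and the mock loop are hypothetical and exist only for illustration. The loop stands in for the trainer so that the snippet runs without MXNet.

```{r}
# Hypothetical callback factory: the outer function holds configuration and
# state, the inner function has the signature the trainer expects.
my.callback.count.calls <- function() {
  calls <- 0  # state kept in the enclosing environment of the closure
  function(iteration, nbatch, env, verbose) {
    calls <<- calls + 1
    if (verbose) {
      cat(sprintf("Callback call %d (iteration %d, batch %d)\n",
                  calls, iteration, nbatch))
    }
    return(TRUE)  # TRUE asks the trainer to continue
  }
}

# Mock driver standing in for the training loop (an assumption, not MXNet code).
cb <- my.callback.count.calls()
mock.env <- new.env()
results <- logical(0)
for (iteration in 1:3) {
  results <- c(results, cb(iteration, 0, mock.env, verbose = FALSE))
}
# Read the counter back from the closure's enclosing environment.
n.calls <- get("calls", envir = environment(cb))
```

Keeping the counter in the enclosing environment, rather than in a global variable, means every call to the factory produces an independent callback.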
The `mx.callback.save.checkpoint` function below is stateless. It just gets the model from the environment and saves it.
```{r, eval=FALSE}
mx.callback.save.checkpoint <- function(prefix, period=1) {
  function(iteration, nbatch, env, verbose) {
    if (iteration %% period == 0) {
      mx.model.save(env$model, prefix, iteration)
      if(verbose) cat(sprintf("Model checkpoint saved to %s-%04d.params\n", prefix, iteration))
    }
    return(TRUE)
  }
}
```
The `mx.callback.log.train.metric` function is a little more complex. It holds a reference class object and updates it during the training process.
```{r, eval=FALSE}
mx.callback.log.train.metric <- function(period, logger=NULL) {
  function(iteration, nbatch, env, verbose) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
      result <- env$metric$get(env$train.metric)
      if (nbatch != 0)
        if(verbose) cat(paste0("Batch [", nbatch, "] Train-", result$name, "=", result$value, "\n"))
      if (!is.null(logger)) {
        if (class(logger) != "mx.metric.logger") {
          stop("Invalid mx.metric.logger.")
        }
        logger$train <- c(logger$train, result$value)
        if (!is.null(env$eval.metric)) {
          result <- env$metric$get(env$eval.metric)
          if (nbatch != 0)
            cat(paste0("Batch [", nbatch, "] Validation-", result$name, "=", result$value, "\n"))
          logger$eval <- c(logger$eval, result$value)
        }
      }
    }
    return(TRUE)
  }
}
```
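The reason a reference class works here is that its objects are modified in place, so values appended inside the callback remain visible to the caller after training. The following sketch illustrates this with a hypothetical `MyLogger` class and a hypothetical `my.callback.record` callback. The fields `train.value` and `eval.value` on the mock environment, and the numbers assigned to them, are placeholders rather than real training results.

```{r}
library(methods)

# Hypothetical stand-in for a metric logger: a reference class with two
# numeric fields, both starting out empty.
MyLogger <- setRefClass("MyLogger",
                        fields = list(train = "numeric", eval = "numeric"))

# Hypothetical callback that appends the current metric values to the logger.
my.callback.record <- function(logger) {
  function(iteration, nbatch, env, verbose) {
    # Reference semantics: this updates the caller's object, not a copy.
    logger$train <- c(logger$train, env$train.value)
    logger$eval <- c(logger$eval, env$eval.value)
    return(TRUE)
  }
}

my.logger <- MyLogger$new()
cb <- my.callback.record(my.logger)

# Mock driver standing in for the training loop (an assumption, not MXNet code).
mock.env <- new.env()
for (iteration in 1:3) {
  mock.env$train.value <- 10 / iteration  # placeholder metric value
  mock.env$eval.value <- 12 / iteration   # placeholder metric value
  cb(iteration, 0, mock.env, FALSE)
}
```

After the loop, `my.logger$train` and `my.logger$eval` each hold one entry per simulated epoch, even though the callback never returned them.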
Now you might be curious why both callback functions `return(TRUE)`.
Can we `return(FALSE)`?
Yes! You can stop the training early by returning `FALSE`. See the example below.
```{r}
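# eval.metric is the threshold: training stops once the validation metric drops below it
mx.callback.early.stop <- function(eval.metric) {
  function(iteration, nbatch, env, verbose) {
    if (!is.null(env$metric)) {
      # Only compare when a threshold was given and validation data is available
      if (!is.null(eval.metric) && !is.null(env$eval.metric)) {
        result <- env$metric$get(env$eval.metric)
        if (result$value < eval.metric) {
          return(FALSE)
        }
      }
    }
    return(TRUE)
  }
}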
model <- mx.model.FeedForward.create(
  lro, X=train.x, y=train.y,
  eval.data=list(data=test.x, label=test.y),
  ctx=mx.cpu(), num.round=10, array.batch.size=20,
  learning.rate=2e-6, momentum=0.9, eval.metric=mx.metric.rmse,
  epoch.end.callback = mx.callback.early.stop(10))
```
You can see that once the validation metric goes below the threshold we set, the training process stops early.
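
The stopping behaviour can also be sketched without MXNet. The snippet below assumes, as described above, that the trainer leaves its loop as soon as a callback returns `FALSE`. The callback `my.callback.stop.below`, the `eval.value` field, and the metric values are hypothetical placeholders chosen for illustration.

```{r}
# Hypothetical threshold-based early stopping that reads a mock field
# `eval.value` instead of the real metric objects.
my.callback.stop.below <- function(threshold) {
  function(iteration, nbatch, env, verbose) {
    if (!is.null(env$eval.value) && env$eval.value < threshold) {
      if (verbose) cat(sprintf("Stopping at iteration %d\n", iteration))
      return(FALSE)  # FALSE asks the trainer to stop
    }
    return(TRUE)
  }
}

cb <- my.callback.stop.below(10)
mock.env <- new.env()
mock.values <- c(16, 13, 11, 9, 8)  # placeholder validation metric per epoch

# Mock driver standing in for the training loop (an assumption, not MXNet code).
rounds.run <- 0
for (iteration in seq_along(mock.values)) {
  mock.env$eval.value <- mock.values[iteration]
  rounds.run <- iteration
  keep.going <- cb(iteration, 0, mock.env, FALSE)
  if (!keep.going) break
}
```

With these placeholder values the loop ends in the fourth round, the first one whose value falls below the threshold of 10, and the fifth round never runs.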