blob: 4e5d0b77277e84d986631363582ab8977af4a41f [file] [log] [blame]
require(mxnet)
#' Train an MXNet feed-forward model from a parsed argument list.
#'
#' Options are read from `args` with exact name matching (`[[`), so that a
#' missing `lr_factor` is not silently filled in from `lr_factor_epoch` by the
#' partial matching that `$` performs on lists.
#'
#' @param args Named list of options. Fields read: `log_file`, `log_dir`,
#'   `network`, `model_prefix`, `load_epoch`, `gpus`, `lr_factor`,
#'   `lr_multifactor`, `lr_factor_epoch`, `num_examples`, `batch_size`,
#'   `num_round`, `lr`, `mom`, `wd`, `kv_store`, `clip_gradient`.
#' @param network Symbol to train; replaced by the symbol of the loaded model
#'   when `args$load_epoch` is set.
#' @param data_loader Function called as `data_loader(args)`; expected to
#'   return a list with elements `train` and `value` (validation data).
#' @return The model returned by `mx.model.FeedForward.create()`, invisibly.
train_model.fit <- function(args, network, data_loader) {
  # Log: divert printed output to a file for the duration of this call.
  # Only the "output" stream is diverted: sink() cannot send "message" to a
  # file name, and the former type = c("output", "message") resolved to
  # "output" through match.arg() anyway.
  if (!is.null(args[["log_file"]])) {
    sink(file.path(args[["log_dir"]], args[["log_file"]]), append = FALSE,
         type = "output")
    # Remove the diversion (which also closes the file) even if training
    # fails; previously the sink was left in place after returning.
    on.exit(sink(type = "output"), add = TRUE)
    cat(paste0("Starting computation of ", args[["network"]], " at ",
               Sys.time(), "\n"))
  }
  # Newline so the header of print() below starts on its own line.
  cat("Arguments\n")
  print(unlist(args))

  # Epoch-end checkpoint callback, only when a model prefix is configured.
  checkpoint <- NULL
  if (!is.null(args[["model_prefix"]])) {
    checkpoint <- mx.callback.save.checkpoint(args[["model_prefix"]])
  }

  # Optionally resume from the model stored for iteration `load_epoch`.
  arg.params <- NULL
  aux.params <- NULL
  begin.round <- 1
  if (!is.null(args[["load_epoch"]])) {
    if (is.null(args[["model_prefix"]])) stop("model_prefix should not be empty")
    begin.round <- args[["load_epoch"]]
    model <- mx.model.load(args[["model_prefix"]], iteration = begin.round)
    network <- model$symbol
    arg.params <- model$arg.params
    aux.params <- model$aux.params
  }

  # Data iterators supplied by the caller's loader.
  data <- data_loader(args)
  train <- data$train
  val <- data$value

  devs <- .train_model_devices(args[["gpus"]])
  lr_scheduler <- .train_model_lr_scheduler(args, begin.round)

  model <- mx.model.FeedForward.create(
    X = train,
    eval.data = val,
    ctx = devs,
    symbol = network,
    begin.round = begin.round,
    eval.metric = mx.metric.top_k_accuracy,
    num.round = args[["num_round"]],
    learning.rate = args[["lr"]],
    momentum = args[["mom"]],
    wd = args[["wd"]],
    kvstore = args[["kv_store"]],
    array.batch.size = args[["batch_size"]],
    clip_gradient = args[["clip_gradient"]],
    lr_scheduler = lr_scheduler,
    optimizer = "sgd",
    initializer = mx.init.Xavier(factor_type = "in", magnitude = 2),
    arg.params = arg.params,
    aux.params = aux.params,
    epoch.end.callback = checkpoint,
    batch.end.callback = mx.callback.log.train.metric(50))
  invisible(model)
}

# Build the device context for training.
#
# gpus: NULL for CPU training, otherwise comma-separated indices, e.g. "0,1".
# Returns mx.cpu() when gpus is NULL, else a list with one mx.gpu() per index.
# Stops with a clear message when an index is not an integer.
.train_model_devices <- function(gpus) {
  if (is.null(gpus)) {
    return(mx.cpu())
  }
  ids <- as.integer(unlist(strsplit(gpus, ",", fixed = TRUE)))
  if (anyNA(ids)) {
    stop("invalid gpus specification: ", paste(gpus, collapse = ","),
         call. = FALSE)
  }
  lapply(ids, mx.gpu)
}

# Build the learning-rate scheduler, or NULL when no decay is requested
# (`lr_factor` absent or >= 1).
#
# With `lr_multifactor` (e.g. "30,60") a MultiFactorScheduler decays at those
# epochs; otherwise a FactorScheduler decays every `lr_factor_epoch` epochs.
# Steps are expressed in batches: epochs * batches-per-epoch.
.train_model_lr_scheduler <- function(args, begin.round) {
  lr_factor <- args[["lr_factor"]]
  if (is.null(lr_factor) || lr_factor >= 1) {
    return(NULL)
  }
  # Batches per epoch, floored at 1 so a dataset smaller than one batch still
  # yields valid (>= 1) scheduler steps. (The old code misplaced this max()
  # so the floor was ignored and steps could become 0.)
  epoch_size <- max(as.integer(args[["num_examples"]] / args[["batch_size"]]), 1L)
  if (!is.null(args[["lr_multifactor"]])) {
    step_epochs <- as.integer(strsplit(args[["lr_multifactor"]], ",")[[1]])
    # Shift decay epochs relative to the round training resumes from and drop
    # those already passed.
    # NOTE(review): the base learning rate is not reduced for the dropped
    # steps, so a resumed run restarts at args$lr -- confirm this is intended.
    step_epochs <- step_epochs - begin.round + 1
    step_epochs <- step_epochs[step_epochs > 0]
    mx.lr_scheduler.MultiFactorScheduler(step = epoch_size * step_epochs,
                                         factor_val = lr_factor)
  } else {
    mx.lr_scheduler.FactorScheduler(
      step = as.integer(max(epoch_size * args[["lr_factor_epoch"]], 1)),
      factor_val = lr_factor)
  }
}