| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
# Install any helper packages this script needs that are not yet available.
list.of.packages <- c("R.utils")

# system.file(package = ) returns "" for a package that cannot be found.
# This is the lookup the R documentation recommends for checking a named
# package, rather than scanning the whole library with installed.packages().
is.installed <- vapply(
  list.of.packages,
  function(pkg) nzchar(system.file(package = pkg)),
  logical(1)
)
new.packages <- list.of.packages[!is.installed]

if (length(new.packages) > 0) {
  install.packages(new.packages, repos = "https://cloud.r-project.org/")
}
| |
# Fetch the four MNIST archives and unpack them.
#
# The working directory is switched to the session's temporary directory
# because the rest of the script refers to the unpacked files and to the
# model checkpoints by relative path.
setwd(tempdir())

# library() stops immediately if R.utils is missing; require() would only
# return FALSE and the failure would surface later as a missing gunzip().
library(R.utils)

mnist.url <- "http://yann.lecun.com/exdb/mnist/"
mnist.files <- c(
  "train-images-idx3-ubyte",
  "train-labels-idx1-ubyte",
  "t10k-images-idx3-ubyte",
  "t10k-labels-idx1-ubyte"
)

for (mnist.file in mnist.files) {
  mnist.archive <- paste0(mnist.file, ".gz")

  # mode = "wb" requests a binary transfer explicitly instead of relying on
  # download.file()'s platform-specific handling of the file extension.
  status <- download.file(
    paste0(mnist.url, mnist.archive),
    destfile = mnist.archive,
    mode = "wb"
  )
  # download.file() reports some failures through a non-zero return code
  # rather than an error, so check it before trying to decompress.
  if (status != 0) {
    stop("Failed to download ", mnist.archive, call. = FALSE)
  }

  # overwrite = TRUE lets the script be re-run in the same session, where
  # the unpacked file from the previous run is still present.
  gunzip(mnist.archive, overwrite = TRUE)
}
| |
# library() fails loudly when mxnet is not installed; require() would only
# return FALSE and the first mx.* call below would fail instead.
library(mxnet)

# Network configuration: three fully connected layers (128, 64 and 10
# hidden units) with ReLU activations after the first two, followed by a
# softmax output named "sm".
batch.size <- 100
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name = "fc1", num_hidden = 128)
act1 <- mx.symbol.Activation(fc1, name = "relu1", act_type = "relu")
fc2 <- mx.symbol.FullyConnected(act1, name = "fc2", num_hidden = 64)
act2 <- mx.symbol.Activation(fc2, name = "relu2", act_type = "relu")
fc3 <- mx.symbol.FullyConnected(act2, name = "fc3", num_hidden = 10)
softmax <- mx.symbol.Softmax(fc3, name = "sm")
| |
# Training iterator: reads the unpacked training images and labels in
# batches of `batch.size`, shuffled with a fixed seed.
dtrain <- mx.io.MNISTIter(
  image      = "train-images-idx3-ubyte",
  label      = "train-labels-idx1-ubyte",
  data.shape = c(784),
  batch.size = batch.size,
  shuffle    = TRUE,
  flat       = TRUE,
  silent     = 0,
  seed       = 10
)

# Evaluation iterator: same settings over the t10k files, without shuffling.
dtest <- mx.io.MNISTIter(
  image      = "t10k-images-idx3-ubyte",
  label      = "t10k-labels-idx1-ubyte",
  data.shape = c(784),
  batch.size = batch.size,
  shuffle    = FALSE,
  flat       = TRUE,
  silent     = 0
)
| |
# Fix the MXNet seed, then build the list of contexts to train on:
# one CPU context for each of the device ids 1 and 2.
mx.set.seed(0)
devices <- lapply(1:2, mx.cpu)
| |
# Train the network for a single round on the configured devices.
# A checkpoint is written under the "chkpt" prefix at the end of the epoch,
# and the training metric is logged through the batch-end callback.
model <- mx.model.FeedForward.create(
  softmax,
  X                  = dtrain,
  eval.data          = dtest,
  ctx                = devices,
  num.round          = 1,
  learning.rate      = 0.1,
  momentum           = 0.9,
  initializer        = mx.init.uniform(0.07),
  epoch.end.callback = mx.callback.save.checkpoint("chkpt"),
  batch.end.callback = mx.callback.log.train.metric(100)
)
| |
# Predict on the evaluation set straight from the iterator ...
pred <- predict(model, dtest)

# ... then pull the labels and the inputs out of that same iterator and
# predict a second time from the extracted array.
label <- mx.io.extract(dtest, "label")
dataX <- mx.io.extract(dtest, "data")
pred2 <- predict(model, X = dataX)
| |
# Fraction of samples whose highest-scoring class matches the label.
#
# label: class labels, one per sample; they are treated as 0-based (1 is
#        added before comparing with the 1-based predicted index).
# pred:  score array with one column per sample; it is transposed so that
#        the arg-max is taken across the classes of each sample.
# Both arguments are passed through as.array() first.
# Returns a single numeric value.
accuracy <- function(label, pred) {
  # max.col() breaks ties at random by default, which makes the reported
  # accuracy vary between runs whenever two classes score the same.
  # Taking the first maximum keeps the result deterministic.
  ypred <- max.col(t(as.array(pred)), ties.method = "first")
  sum((as.array(label) + 1) == ypred) / length(label)
}
| |
# Report the accuracy of the iterator-based predictions, then that of the
# array-based predictions.
acc.iter <- accuracy(label, pred)
print(paste0("Finish prediction... accuracy = ", acc.iter))
acc.array <- accuracy(label, pred2)
print(paste0("Finish prediction... accuracy2 = ", acc.array))
| |
| |
| |
# Restore the checkpoint saved under the "chkpt" prefix for iteration 1.
model <- mx.model.load("chkpt", 1)

# Resume training from the restored symbol and parameters for five rounds,
# this time writing checkpoints under the "reload_chkpt" prefix.
model <- mx.model.FeedForward.create(
  model$symbol,
  X                  = dtrain,
  eval.data          = dtest,
  ctx                = devices,
  num.round          = 5,
  learning.rate      = 0.1,
  momentum           = 0.9,
  epoch.end.callback = mx.callback.save.checkpoint("reload_chkpt"),
  batch.end.callback = mx.callback.log.train.metric(100),
  arg.params         = model$arg.params,
  aux.params         = model$aux.params
)
| |
# Predict with the retrained model straight from the evaluation iterator ...
pred <- predict(model, dtest)

# ... then extract the labels and inputs from that iterator and predict
# again from the extracted array.
label <- mx.io.extract(dtest, "label")
dataX <- mx.io.extract(dtest, "data")
pred2 <- predict(model, X = dataX)
| |
# Fraction of samples whose highest-scoring class matches the label.
#
# label: class labels, one per sample; they are treated as 0-based (1 is
#        added before comparing with the 1-based predicted index).
# pred:  score array with one column per sample; it is transposed so that
#        the arg-max is taken across the classes of each sample.
# Both arguments are passed through as.array() first.
# Returns a single numeric value.
accuracy <- function(label, pred) {
  # max.col() breaks ties at random by default, which makes the reported
  # accuracy vary between runs whenever two classes score the same.
  # Taking the first maximum keeps the result deterministic.
  ypred <- max.col(t(as.array(pred)), ties.method = "first")
  sum((as.array(label) + 1) == ypred) / length(label)
}
| |
# Report the accuracy of the iterator-based predictions, then that of the
# array-based predictions, for the retrained model.
acc.iter <- accuracy(label, pred)
print(paste0("Finish prediction... accuracy=", acc.iter))
acc.array <- accuracy(label, pred2)
print(paste0("Finish prediction... accuracy2=", acc.array))