# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
| |
# Attach the package under test. library() stops with an error when mxnet is
# not installed, whereas require() only returns FALSE and would let every test
# below fail later with an unrelated "could not find function" error.
library(mxnet)

# Helper script; presumably defines the Get*() data helpers called by the
# tests below (GetMNIST_ubyte, GetInception, ...) -- TODO confirm.
source("get_data.R")

context("models")
| |
# Decide whether GPU testing was requested.
#
# `flag` is the raw string value of the R_GPU_ENABLE environment variable (the
# default reads it from the environment). The result is a single TRUE only
# when the flag is non-empty and parses to the integer 1; an unset, empty or
# non-numeric flag yields FALSE instead of an NA condition that would abort
# the `if` below.
gpu_testing_enabled <- function(flag = Sys.getenv("R_GPU_ENABLE")) {
  # `&&` is the scalar, short-circuiting operator expected in a condition:
  # an empty flag is never handed to as.integer().
  nzchar(flag) && isTRUE(suppressWarnings(as.integer(flag)) == 1L)
}

if (gpu_testing_enabled()) {
  # Make the GPU the default context used by the tests in this file.
  mx.ctx.default(new = mx.gpu())
  message("Using GPU for testing.")
}
| |
test_that("MNIST", {
  # Helper defined outside this file; presumably makes the data/*-ubyte files
  # read by the iterators below available.
  GetMNIST_ubyte()
  batch.size <- 100

  # Fraction of samples whose arg-max prediction matches the 0-based label.
  # `pred` is transposed so that max.col() picks the best class per sample.
  accuracy <- function(label, pred) {
    best.class <- max.col(t(as.array(pred)))
    hits <- sum((as.array(label) + 1) == best.class)
    hits / length(label)
  }

  # Network configuration: three fully connected layers (128, 64, 10 units)
  # with ReLU activations in between and a softmax on top.
  net.input <- mx.symbol.Variable("data")
  hidden1 <- mx.symbol.FullyConnected(net.input, name = "fc1", num_hidden = 128)
  relu1 <- mx.symbol.Activation(hidden1, name = "relu1", act_type = "relu")
  hidden2 <- mx.symbol.FullyConnected(relu1, name = "fc2", num_hidden = 64)
  relu2 <- mx.symbol.Activation(hidden2, name = "relu2", act_type = "relu")
  logits <- mx.symbol.FullyConnected(relu2, name = "fc3", num_hidden = 10)
  softmax <- mx.symbol.Softmax(logits, name = "sm")

  # Shuffled training iterator and unshuffled evaluation iterator.
  dtrain <- mx.io.MNISTIter(
    image = "data/train-images-idx3-ubyte",
    label = "data/train-labels-idx1-ubyte",
    data.shape = 784, batch.size = batch.size, shuffle = TRUE, flat = TRUE,
    silent = 0, seed = 10
  )
  dtest <- mx.io.MNISTIter(
    image = "data/t10k-images-idx3-ubyte",
    label = "data/t10k-labels-idx1-ubyte",
    data.shape = 784, batch.size = batch.size, shuffle = FALSE, flat = TRUE,
    silent = 0
  )

  mx.set.seed(0)

  # Train for a single round, saving a checkpoint under the prefix "chkpt".
  model <- mx.model.FeedForward.create(
    softmax,
    X = dtrain,
    eval.data = dtest,
    ctx = mx.ctx.default(),
    num.round = 1,
    learning.rate = 0.1,
    momentum = 0.9,
    initializer = mx.init.uniform(0.07),
    epoch.end.callback = mx.callback.save.checkpoint("chkpt"),
    batch.end.callback = mx.callback.log.train.metric(100)
  )

  # Predict once straight from the iterator ...
  pred.iter <- predict(model, dtest)
  label <- mx.io.extract(dtest, "label")
  # ... and once from the same data extracted into an R array.
  test.array <- mx.io.extract(dtest, "data")
  pred.array <- predict(model, X = test.array)

  # Both prediction paths must reach (roughly) the same accuracy.
  expect_equal(
    accuracy(label, pred.iter),
    accuracy(label, pred.array),
    tolerance = 0.1
  )

  # Clean up the checkpoint files left behind by the training run.
  file.remove(c("chkpt-0001.params", "chkpt-symbol.json"))
})
| |
test_that("Regression", {
  data(BostonHousing, package = "mlbench")
  # Train on every third row; the upper bound is taken from the data set
  # rather than hard-coded.
  train.ind <- seq(1, nrow(BostonHousing), 3)

  # ---- Single-output regression: column 14 from the remaining columns ----
  train.x <- data.matrix(BostonHousing[train.ind, -14])
  train.y <- BostonHousing[train.ind, 14]

  data <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data, num_hidden = 1)
  lro <- mx.symbol.LinearRegressionOutput(fc1)

  # Custom evaluation metric: mean absolute error between label and prediction.
  demo.metric.mae <- mx.metric.custom("mae", function(label, pred) {
    pred <- mx.nd.reshape(pred, shape = 0)
    res <- mx.nd.mean(mx.nd.abs(label - pred))
    as.array(res)
  })

  mx.set.seed(0)
  model <- mx.model.FeedForward.create(
    lro,
    X = train.x, y = train.y,
    ctx = mx.ctx.default(),
    num.round = 5, array.batch.size = 20,
    learning.rate = 2e-06, momentum = 0.9,
    eval.metric = demo.metric.mae
  )
  # Without an expectation testthat treats the test as empty; at least pin
  # that training returned a model object.
  expect_s3_class(model, "MXFeedForwardModel")

  # ---- Two-output regression: columns 13 and 14 from the other columns ----
  train.x <- data.matrix(BostonHousing[train.ind, -(13:14)])
  train.y <- BostonHousing[train.ind, c(13:14)]

  data <- mx.symbol.Variable("data")
  fc2 <- mx.symbol.FullyConnected(data, num_hidden = 2)
  lro2 <- mx.symbol.LinearRegressionOutput(fc2)

  mx.set.seed(0)
  # The iterator is fed the transposed matrices (one column per observation).
  train_iter <- mx.io.arrayiter(data = t(train.x), label = t(train.y))

  model <- mx.model.FeedForward.create(
    lro2,
    X = train_iter,
    ctx = mx.ctx.default(),
    num.round = 50, array.batch.size = 20,
    learning.rate = 2e-06, momentum = 0.9
  )
  expect_s3_class(model, "MXFeedForwardModel")
})
| |
| |
test_that("Classification", {
  data(Sonar, package = "mlbench")
  # Recode the class column (61) from a factor to 0/1 labels.
  Sonar[, 61] <- as.numeric(Sonar[, 61]) - 1

  # Fixed train/test split: rows 1-50 and 100-150 train, the rest is held out.
  train.ind <- c(1:50, 100:150)
  train.x <- data.matrix(Sonar[train.ind, 1:60])
  train.y <- Sonar[train.ind, 61]
  test.x <- data.matrix(Sonar[-train.ind, 1:60])
  test.y <- Sonar[-train.ind, 61]

  mx.set.seed(0)
  model <- mx.mlp(
    train.x, train.y,
    hidden_node = 10, out_node = 2, out_activation = "softmax",
    num.round = 5, array.batch.size = 15,
    learning.rate = 0.07, momentum = 0.9,
    eval.metric = mx.metric.accuracy
  )
  expect_s3_class(model, "MXFeedForwardModel")

  # The held-out split was previously computed but never used. Predict on it
  # and check that there is one column of two class scores per test sample.
  preds <- predict(model, test.x)
  expect_equal(dim(preds), c(2, length(test.y)))
})
| |
test_that("Fine-tune", {
  # Helpers defined outside this file; presumably they make the pretrained
  # Inception-BN checkpoint and the cats/dogs record files available.
  GetInception()
  GetCatDog()

  # Augmented training iterator and plain validation iterator, 8 images of
  # 224 x 224 x 3 per batch.
  train_iter <- mx.io.ImageRecordIter(
    path.imgrec = "./data/cats_dogs/cats_dogs_train.rec",
    batch.size = 8,
    data.shape = c(224, 224, 3),
    rand.crop = TRUE,
    rand.mirror = TRUE
  )
  val_iter <- mx.io.ImageRecordIter(
    path.imgrec = "./data/cats_dogs/cats_dogs_val.rec",
    batch.size = 8,
    data.shape = c(224, 224, 3),
    rand.crop = FALSE,
    rand.mirror = FALSE
  )

  # Load the pretrained network and cut it at its "flatten_output" layer.
  inception_bn <- mx.model.load("./model/Inception-BN", iteration = 126)
  net.internals <- inception_bn$symbol$get.internals()
  flatten.idx <- which(net.internals$outputs == "flatten_output")
  flatten <- net.internals$get.output(flatten.idx)

  # Attach a fresh two-class head on top of the truncated network.
  new_fc <- mx.symbol.FullyConnected(data = flatten, num_hidden = 2, name = "fc1")
  new_soft <- mx.symbol.SoftmaxOutput(data = new_fc, name = "softmax")

  # Initialise parameters for the new symbol; only the new head's weight and
  # bias are taken from this initialisation.
  fresh.args <- mx.model.init.params(
    symbol = new_soft,
    input.shape = list(data = c(224, 224, 3, 8)),
    output.shape = NULL,
    initializer = mx.init.uniform(0.1),
    ctx = mx.cpu()
  )$arg.params

  # Start from the pretrained arguments and add/overwrite the "fc1" entries.
  arg_params_new <- inception_bn$arg.params
  for (head.param in c("fc1_weight", "fc1_bias")) {
    arg_params_new[[head.param]] <- fresh.args[[head.param]]
  }

  # model <- mx.model.FeedForward.create(symbol = new_soft, X = train_iter,
  # eval.data = val_iter, ctx = mx.ctx.default(), eval.metric = mx.metric.accuracy,
  # num.round = 2, learning.rate = 0.05, momentum = 0.9, wd = 0.00001, kvstore =
  # 'local', batch.end.callback = mx.callback.log.train.metric(50), initializer =
  # mx.init.Xavier(factor_type = 'in', magnitude = 2.34), optimizer = 'sgd',
  # arg.params = arg_params_new, aux.params = inception_bn$aux.params)
})
| |
test_that("Matrix Factorization", {

  # Use fake random data instead of GetMovieLens() to remove external dependency
  set.seed(123)
  n.obs <- 1e+05
  # TRUE is spelled out: `T` is an ordinary variable that can be reassigned.
  user_ids <- sample(943, size = n.obs, replace = TRUE)
  item_ids <- sample(1682, size = n.obs, replace = TRUE)
  ratings <- sample(5, size = n.obs, replace = TRUE)
  DF <- data.frame(user = user_ids, item = item_ids, score = ratings)

  max_user <- max(DF$user)
  max_item <- max(DF$item)
  # k is used both as the embedding size and as the iterator batch size.
  k <- 64

  # Model: the predicted score is the sum over the element-wise product of a
  # k-dimensional user embedding and a k-dimensional item embedding.
  user <- mx.symbol.Variable("user")
  item <- mx.symbol.Variable("item")
  score <- mx.symbol.Variable("score")
  user1 <- mx.symbol.Embedding(
    data = mx.symbol.BlockGrad(user), input_dim = max_user,
    output_dim = k, name = "user1"
  )
  item1 <- mx.symbol.Embedding(
    data = mx.symbol.BlockGrad(item), input_dim = max_item,
    output_dim = k, name = "item1"
  )
  pred <- user1 * item1
  pred1 <- mx.symbol.sum_axis(pred, axis = 1, name = "pred1")
  pred2 <- mx.symbol.Flatten(pred1, name = "pred2")
  pred3 <- mx.symbol.LinearRegressionOutput(
    data = pred2, label = score, name = "pred3"
  )

  mx.set.seed(123)

  # Iterator that advances two array iterators in lock-step and exposes their
  # batches as a list with the entries user / item / score.
  CustomIter <- setRefClass(
    "CustomIter",
    fields = c("iter1", "iter2"),
    contains = "Rcpp_MXArrayDataIter",
    methods = list(
      initialize = function(iter1, iter2) {
        .self$iter1 <- iter1
        .self$iter2 <- iter2
        .self
      },
      value = function() {
        # Read the first iterator's batch once; it provides both the user
        # ids (data) and the scores (label).
        user_batch <- .self$iter1$value()
        list(
          user = user_batch$data,
          item = .self$iter2$value()$data,
          score = user_batch$label
        )
      },
      iter.next = function() {
        .self$iter1$iter.next()
        .self$iter2$iter.next()
      },
      reset = function() {
        .self$iter1$reset()
        .self$iter2$reset()
      },
      num.pad = function() {
        .self$iter1$num.pad()
      },
      finalize = function() {
        .self$iter1$finalize()
        .self$iter2$finalize()
      }
    )
  )

  # Both iterators carry the score column as their label.
  user_iter <- mx.io.arrayiter(data = DF[, 1], label = DF[, 3], batch.size = k)
  item_iter <- mx.io.arrayiter(data = DF[, 2], label = DF[, 3], batch.size = k)
  train_iter <- CustomIter$new(user_iter, item_iter)

  model <- mx.model.FeedForward.create(
    pred3,
    X = train_iter,
    ctx = mx.ctx.default(),
    num.round = 5,
    initializer = mx.init.uniform(0.07),
    learning.rate = 0.07,
    eval.metric = mx.metric.rmse,
    momentum = 0.9,
    epoch.end.callback = mx.callback.log.train.metric(1),
    input.names = c("user", "item"),
    output.names = "score"
  )
  # Without an expectation testthat treats the test as empty; at least pin
  # that training returned a model object.
  expect_s3_class(model, "MXFeedForwardModel")
})
| |
test_that("Captcha", {
  # Helper defined outside this file; presumably makes the captcha record and
  # list files under ./data/captcha_example available.
  GetCaptcha_data()
  data.shape <- c(80, 30, 3)
  batch_size <- 40

  # Record iterators over the train and test sets; each sample carries a
  # label of width 4.
  train <- mx.io.ImageRecordIter(
    path.imgrec = "./data/captcha_example/captcha_train.rec",
    path.imglist = "./data/captcha_example/captcha_train.lst",
    batch.size = batch_size,
    label.width = 4,
    data.shape = data.shape,
    mean.img = "mean.bin"
  )
  val <- mx.io.ImageRecordIter(
    path.imgrec = "./data/captcha_example/captcha_test.rec",
    path.imglist = "./data/captcha_example/captcha_test.lst",
    batch.size = batch_size,
    label.width = 4,
    data.shape = data.shape,
    mean.img = "mean.bin"
  )

  data <- mx.symbol.Variable("data")
  label <- mx.symbol.Variable("label")

  # Two convolution -> pooling -> ReLU stages (max pooling, then average).
  conv1 <- mx.symbol.Convolution(data = data, kernel = c(5, 5), num_filter = 32)
  pool1 <- mx.symbol.Pooling(
    data = conv1, pool_type = "max", kernel = c(2, 2), stride = c(1, 1)
  )
  relu1 <- mx.symbol.Activation(data = pool1, act_type = "relu")

  conv2 <- mx.symbol.Convolution(data = relu1, kernel = c(5, 5), num_filter = 32)
  pool2 <- mx.symbol.Pooling(
    data = conv2, pool_type = "avg", kernel = c(2, 2), stride = c(1, 1)
  )
  relu2 <- mx.symbol.Activation(data = pool2, act_type = "relu")

  # Shared 120-unit layer feeding four separate 10-way heads.
  flatten <- mx.symbol.Flatten(data = relu2)
  fc1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 120)
  fc21 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  fc22 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  fc23 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  fc24 <- mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
  digit.heads <- c(fc21, fc22, fc23, fc24)
  fc2 <- mx.symbol.concat(digit.heads, dim = 0, num.args = 4)

  # The label is transposed and reshaped before being paired with the
  # concatenated head outputs in the softmax.
  label <- mx.symbol.transpose(data = label)
  label <- mx.symbol.Reshape(data = label, target_shape = 0)
  captcha_net <- mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")

  # Custom metric: share of label columns for which all four entries match
  # the arg-max predictions.
  mx.metric.acc2 <- mx.metric.custom("accuracy", function(label, pred) {
    truth <- as.array(label)
    scores <- as.array(pred)
    guess <- matrix(
      max.col(t(scores)) - 1,
      nrow = nrow(truth), ncol = ncol(truth), byrow = TRUE
    )
    sum(colSums(truth == guess) == 4) / ncol(truth)
  })

  mx.set.seed(42)

  # Move the training iterator onto its first batch so that shapes can be read.
  train$reset()
  train$iter.next()

  # Named lists (one entry per name) holding the dims of the current batch.
  batch_dims <- function(field.names) {
    dims <- lapply(field.names, function(n) dim(train$value()[[n]]))
    setNames(dims, field.names)
  }

  input.names <- "data"
  input.shape <- batch_dims(input.names)
  arg_names <- arguments(captcha_net)
  output.names <- "label"
  output.shape <- batch_dims(output.names)

  params <- mx.model.init.params(
    captcha_net,
    input.shape,
    output.shape,
    mx.init.Xavier(factor_type = "in", magnitude = 2.34),
    mx.cpu()
  )

  # model <- mx.model.FeedForward.create( X = train, eval.data = val, ctx =
  # mx.ctx.default(), symbol = captcha_net, eval.metric = mx.metric.acc2, num.round
  # = 1, learning.rate = 1e-04, momentum = 0.9, wd = 1e-05, batch.end.callback =
  # mx.callback.log.train.metric(50), initializer = mx.init.Xavier(factor_type =
  # 'in', magnitude = 2.34), optimizer = 'sgd', clip_gradient = 10)
})