blob: 4af7d62cf533125a146484f7b9209c1b5ea9b2d7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
require(mxnet)
source("get_data.R")
# Opt-in GPU testing: when the R_GPU_ENABLE environment variable is set to 1,
# make the GPU the default mxnet context for everything that follows.
#
# as.integer() yields NA for an unset ("") or non-numeric (e.g. "true") value;
# isTRUE() maps that NA to FALSE so the `if` never sees a missing condition.
# suppressWarnings() only silences the "NAs introduced by coercion" warning
# raised by that single conversion.
if (isTRUE(suppressWarnings(as.integer(Sys.getenv("R_GPU_ENABLE"))) == 1L)) {
  mx.ctx.default(new = mx.gpu())
  message("Using GPU for testing.")
}
# Debug helper: run mxnet shape inference on `net` and print the inferred
# output shapes.
#
# The input named "data" is given the fixed shape c(168, 168, 1, 2), which
# lines up with the 168x168 images and batch size of 2 used by the UNET test
# in this file (presumably width, height, channels, batch -- confirm against
# the mxnet R conventions).
#
# Returns whatever print() returns, i.e. the printed value, invisibly.
print_inferred_shape <- function(net) {
  probe_shape <- c(168, 168, 1, 2)
  inferred <- mx.symbol.infer.shape(symbol = net, data = probe_shape)
  print(inferred$out.shapes)
}
# Append one U-Net building block to the symbol `net` and return the result.
#
# Stages, applied in this order (each one optional):
#   1. up_pool     - 2x2 stride-2 Deconvolution with `filter_count` filters,
#                    followed by BatchNorm (always, independent of
#                    `batch_norm`) and the activation.
#   2. convolution - Convolution using `kernel_size`, `stride`, `pad_size`
#                    and `filter_count`.
#   3. batch_norm  - BatchNorm.
#   4. activation  - Activation of type `act_type`; "" disables it.
#   5. down_pool   - 2x2 stride-2 max Pooling.
#
# `work_space` is forwarded as `workspace` to the (de)convolution operators.
# As a side effect the inferred output shape of the resulting symbol is
# printed through print_inferred_shape().
convolution_module <- function(net, kernel_size, pad_size, filter_count,
                               stride = c(1, 1), work_space = 2048,
                               batch_norm = TRUE, down_pool = FALSE,
                               up_pool = FALSE, act_type = "relu",
                               convolution = TRUE) {
  # Shared by the up-pool branch and the main path; a no-op when act_type is "".
  apply_activation <- function(sym) {
    if (act_type != "") {
      sym <- mx.symbol.Activation(sym, act_type = act_type)
    }
    sym
  }

  two_by_two <- c(2, 2)
  out <- net

  if (up_pool) {
    out <- mx.symbol.Deconvolution(out,
      kernel = two_by_two, pad = c(0, 0), stride = two_by_two,
      num_filter = filter_count, workspace = work_space
    )
    out <- mx.symbol.BatchNorm(out)
    out <- apply_activation(out)
  }

  if (convolution) {
    out <- mx.symbol.Convolution(
      data = out, kernel = kernel_size, stride = stride,
      pad = pad_size, num_filter = filter_count, workspace = work_space
    )
  }

  if (batch_norm) {
    out <- mx.symbol.BatchNorm(out)
  }

  out <- apply_activation(out)

  if (down_pool) {
    out <- mx.symbol.Pooling(out,
      pool_type = "max", kernel = two_by_two, stride = two_by_two
    )
  }

  print_inferred_shape(out)
  out
}
# Assemble the U-Net style segmentation symbol exercised by the "UNET" test.
#
# Encoder: five convolution_module() stages with down-pooling, using
# 32, 64, 128, 128 and 256 filters (Dropout sits before the fifth stage).
# Decoder: five up-pooling stages; the outputs of the first three encoder
# stages (pool3, pool2, pool1) are concatenated back in as skip connections,
# each followed by Dropout and a plain convolution stage.
# Head: a single-filter convolution without batch norm or activation, fed
# into a SoftmaxOutput named "sm".
#
# Every convolution_module() result is part of the returned graph; no stage
# is built and then thrown away.
#
# Returns the SoftmaxOutput symbol.
get_unet <- function() {
  data <- mx.symbol.Variable("data")
  kernel_size <- c(3, 3)
  pad_size <- c(1, 1)
  filter_count <- 32

  # ---- Encoder ----
  pool1 <- convolution_module(data, kernel_size, pad_size,
    filter_count = filter_count, down_pool = TRUE
  )
  pool2 <- convolution_module(pool1, kernel_size, pad_size,
    filter_count = filter_count * 2, down_pool = TRUE
  )
  pool3 <- convolution_module(pool2, kernel_size, pad_size,
    filter_count = filter_count * 4, down_pool = TRUE
  )
  pool4 <- convolution_module(pool3, kernel_size, pad_size,
    filter_count = filter_count * 4, down_pool = TRUE
  )
  net <- mx.symbol.Dropout(pool4)
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 8, down_pool = TRUE
  )

  # ---- Decoder ----
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 4, up_pool = TRUE
  )
  # The larger padding makes this stage's output bigger than a pad of 1 would;
  # presumably this is so Crop can trim it to pool3's spatial size before the
  # concat (verify with the printed shapes).
  net <- convolution_module(net, kernel_size,
    pad_size = c(2, 2),
    filter_count = filter_count * 4, up_pool = TRUE
  )
  net <- mx.symbol.Crop(net, pool3, num.args = 2)
  net <- mx.symbol.concat(c(pool3, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 4
  )

  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 4, up_pool = TRUE
  )
  net <- mx.symbol.concat(c(pool2, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 4
  )

  # The up-pooled output goes straight into the concat, exactly as at the
  # pool2 level above.
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 4, up_pool = TRUE
  )
  net <- mx.symbol.concat(c(pool1, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 2
  )

  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = filter_count * 2, up_pool = TRUE
  )

  # ---- Head ----
  net <- convolution_module(net, kernel_size, pad_size,
    filter_count = 1, batch_norm = FALSE, act_type = ""
  )
  mx.symbol.SoftmaxOutput(data = net, name = "sm")
}
context("Image segmentation")

# End-to-end smoke test: build the U-Net symbol with get_unet() and train it
# for two rounds on the ISBI images found under data/ISBI/.
test_that("UNET", {
  # imager is used below to read and resize the images; install it on demand
  # when it is missing from the library.
  list.of.packages <- c("imager")
  new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[,
    "Package"])]
  if (length(new.packages) > 0) {
    install.packages(new.packages, repos = "https://cloud.r-project.org/")
  }

  # Presumably provided by the sourced get_data.R and responsible for placing
  # the images under data/ISBI/ -- confirm there.
  GetISBI_data()
  library(imager)

  IMG_SIZE <- 168

  # Read every file in `dir_path`, resize each image to IMG_SIZE x IMG_SIZE and
  # stack them into an array of dim (IMG_SIZE, IMG_SIZE, 1, <number of files>).
  # The last dimension follows the files actually present instead of assuming
  # a fixed count.
  load_image_stack <- function(dir_path) {
    image_paths <- list.files(path = dir_path, full.names = TRUE)
    # Fail with a clear message rather than an obscure dim<- error.
    stopifnot(length(image_paths) > 0)
    pixel_columns <- lapply(image_paths, function(image_path) {
      as.vector(resize(load.image(image_path), size_x = IMG_SIZE, size_y = IMG_SIZE))
    })
    stacked <- do.call("cbind", pixel_columns)
    dim(stacked) <- c(IMG_SIZE, IMG_SIZE, 1, length(image_paths))
    stacked
  }

  train.array <- load_image_stack("data/ISBI/train-volume")

  train.y.array <- load_image_stack("data/ISBI/train-labels")
  # Binarise the resized label masks. The upper bound is inclusive so that a
  # pixel that lands exactly on 0.5 still becomes a 0/1 label.
  train.y.array[which(train.y.array < 0.5)] <- 0
  train.y.array[which(train.y.array >= 0.5)] <- 1

  devices <- mx.ctx.default()
  mx.set.seed(0)
  net <- get_unet()
  model <- mx.model.FeedForward.create(net, X = train.array, y = train.y.array,
    ctx = devices, num.round = 2, initializer = mx.init.normal(sqrt(2/576)),
    learning.rate = 0.05, momentum = 0.99, array.batch.size = 2)
})