blob: 5d812ec8575bff2ea5c25a0de9e0b87564b80986 [file] [log] [blame]
library(mxnet)
# Basic Conv + BN + ReLU factory
ConvFactory <- function(data, num_filter, kernel, stride = c(1,1),
pad = c(0, 0), act_type = "relu") {
conv = mx.symbol.Convolution(
data = data, num_filter = num_filter, kernel = kernel, stride = stride, pad =
pad
)
bn = mx.symbol.BatchNorm(data = conv)
act = mx.symbol.Activation(data = bn, act_type = act_type)
return(act)
}
# A Simple Downsampling Factory
DownsampleFactory <- function(data, ch_3x3) {
# conv 3x3
conv = ConvFactory(
data = data, kernel = c(3, 3), stride = c(2, 2), num_filter = ch_3x3, pad =
c(1, 1)
)
# pool
pool = mx.symbol.Pooling(
data = data, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), pool_type =
'max'
)
# concat
concat = mx.symbol.Concat(c(conv, pool), num.args = 2)
return(concat)
}
# A Simple module
SimpleFactory <- function(data, ch_1x1, ch_3x3) {
# 1x1
conv1x1 = ConvFactory(
data = data, kernel = c(1, 1), pad = c(0, 0), num_filter = ch_1x1
)
# 3x3
conv3x3 = ConvFactory(
data = data, kernel = c(3, 3), pad = c(1, 1), num_filter = ch_3x3
)
#concat
concat = mx.symbol.Concat(c(conv1x1, conv3x3), num.args = 2)
return(concat)
}
get_symbol <- function(num_classes = 10) {
data = mx.symbol.Variable(name = "data")
conv1 = ConvFactory(
data = data, kernel = c(3,3), pad = c(1,1), num_filter = 96,
act_type = "relu"
)
in3a = SimpleFactory(conv1, 32, 32)
in3b = SimpleFactory(in3a, 32, 48)
in3c = DownsampleFactory(in3b, 80)
in4a = SimpleFactory(in3c, 112, 48)
in4b = SimpleFactory(in4a, 96, 64)
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
in4e = DownsampleFactory(in4d, 96)
in5a = SimpleFactory(in4e, 176, 160)
in5b = SimpleFactory(in5a, 176, 160)
pool = mx.symbol.Pooling(
data = in5b, pool_type = "avg", kernel = c(7,7), name = "global_pool"
)
flatten = mx.symbol.Flatten(data = pool, name = "flatten1")
fc = mx.symbol.FullyConnected(data = flatten, num_hidden = num_classes, name =
"fc1")
softmax = mx.symbol.SoftmaxOutput(data = fc, name = "softmax")
return(softmax)
}