"""
Contains the definition of the Inception Resnet V2 architecture.
As described in http://arxiv.org/abs/1602.07261.
Inception-v4, Inception-ResNet and the Impact of Residual Connections
on Learning
Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
"""
import mxnet as mx

def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0),
                act_type="relu", mirror_attr={}, with_act=True):
    """Convolution followed by BatchNorm and an optional activation."""
    conv = mx.symbol.Convolution(
        data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    bn = mx.symbol.BatchNorm(data=conv)
    if with_act:
        act = mx.symbol.Activation(
            data=bn, act_type=act_type, attr=mirror_attr)
        return act
    else:
        return bn
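
# A worked shape example (assumed input, for illustration): applying
# ConvFactory with num_filter=32, kernel=(3, 3), stride=(2, 2), pad=(0, 0)
# to a 3x299x299 image gives a 32x149x149 output, since
# out = (in + 2*pad - kernel) // stride + 1 = (299 - 3) // 2 + 1 = 149.
# This matches conv1a_3_3 in get_symbol below.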

def block35(net, input_num_channels, scale=1.0, with_act=True,
            act_type='relu', mirror_attr={}):
    """35x35 residual block (Inception-ResNet-A)."""
    tower_conv = ConvFactory(net, 32, (1, 1))
    tower_conv1_0 = ConvFactory(net, 32, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, 32, (3, 3), pad=(1, 1))
    tower_conv2_0 = ConvFactory(net, 32, (1, 1))
    tower_conv2_1 = ConvFactory(tower_conv2_0, 48, (3, 3), pad=(1, 1))
    tower_conv2_2 = ConvFactory(tower_conv2_1, 64, (3, 3), pad=(1, 1))
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2])
    tower_out = ConvFactory(
        tower_mixed, input_num_channels, (1, 1), with_act=False)
    net += scale * tower_out
    if with_act:
        act = mx.symbol.Activation(
            data=net, act_type=act_type, attr=mirror_attr)
        return act
    else:
        return net
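
# As recommended in the paper cited above, the residual branch is scaled
# down (by factors of roughly 0.1 to 0.3, e.g. scale=0.17 below) before
# being added to the shortcut, which stabilizes training of very deep
# residual Inception networks.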

def block17(net, input_num_channels, scale=1.0, with_act=True,
            act_type='relu', mirror_attr={}):
    """17x17 residual block (Inception-ResNet-B)."""
    tower_conv = ConvFactory(net, 192, (1, 1))
    tower_conv1_0 = ConvFactory(net, 128, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, 160, (1, 7), pad=(0, 3))
    tower_conv1_2 = ConvFactory(tower_conv1_1, 192, (7, 1), pad=(3, 0))
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
    tower_out = ConvFactory(
        tower_mixed, input_num_channels, (1, 1), with_act=False)
    net += scale * tower_out
    if with_act:
        act = mx.symbol.Activation(
            data=net, act_type=act_type, attr=mirror_attr)
        return act
    else:
        return net
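
# block17 factorizes a 7x7 convolution into a 1x7 followed by a 7x1,
# covering the same receptive field with 14 weights per channel pair
# instead of 49, the usual Inception-style parameter saving; block8
# below does the same for a 3x3 kernel.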

def block8(net, input_num_channels, scale=1.0, with_act=True,
           act_type='relu', mirror_attr={}):
    """8x8 residual block (Inception-ResNet-C)."""
    tower_conv = ConvFactory(net, 192, (1, 1))
    tower_conv1_0 = ConvFactory(net, 192, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, 224, (1, 3), pad=(0, 1))
    tower_conv1_2 = ConvFactory(tower_conv1_1, 256, (3, 1), pad=(1, 0))
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
    tower_out = ConvFactory(
        tower_mixed, input_num_channels, (1, 1), with_act=False)
    net += scale * tower_out
    if with_act:
        act = mx.symbol.Activation(
            data=net, act_type=act_type, attr=mirror_attr)
        return act
    else:
        return net

def repeat(inputs, repetitions, layer, *args, **kwargs):
    """Apply `layer` to `inputs` `repetitions` times, chaining the outputs."""
    outputs = inputs
    for i in range(repetitions):
        outputs = layer(outputs, *args, **kwargs)
    return outputs

def get_symbol(num_classes=1000, **kwargs):
    """Build the Inception-ResNet-v2 symbol for `num_classes`-way
    classification. The network is designed for 299x299 RGB inputs."""
    data = mx.symbol.Variable(name='data')

    # Stem: 299x299 input -> 35x35 feature maps
    conv1a_3_3 = ConvFactory(data=data, num_filter=32,
                             kernel=(3, 3), stride=(2, 2))
    conv2a_3_3 = ConvFactory(conv1a_3_3, 32, (3, 3))
    conv2b_3_3 = ConvFactory(conv2a_3_3, 64, (3, 3), pad=(1, 1))
    maxpool3a_3_3 = mx.symbol.Pooling(
        data=conv2b_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max')
    conv3b_1_1 = ConvFactory(maxpool3a_3_3, 80, (1, 1))
    conv4a_3_3 = ConvFactory(conv3b_1_1, 192, (3, 3))
    maxpool5a_3_3 = mx.symbol.Pooling(
        data=conv4a_3_3, kernel=(3, 3), stride=(2, 2), pool_type='max')

    # Mixed 5b: four parallel towers concatenated to 96+64+96+64 = 320 channels
    tower_conv = ConvFactory(maxpool5a_3_3, 96, (1, 1))
    tower_conv1_0 = ConvFactory(maxpool5a_3_3, 48, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, 64, (5, 5), pad=(2, 2))
    tower_conv2_0 = ConvFactory(maxpool5a_3_3, 64, (1, 1))
    tower_conv2_1 = ConvFactory(tower_conv2_0, 96, (3, 3), pad=(1, 1))
    tower_conv2_2 = ConvFactory(tower_conv2_1, 96, (3, 3), pad=(1, 1))
    tower_pool3_0 = mx.symbol.Pooling(
        data=maxpool5a_3_3, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
        pool_type='avg')
    tower_conv3_1 = ConvFactory(tower_pool3_0, 64, (1, 1))
    tower_5b_out = mx.symbol.Concat(
        *[tower_conv, tower_conv1_1, tower_conv2_2, tower_conv3_1])

    # 10 x Inception-ResNet-A blocks at 35x35
    net = repeat(tower_5b_out, 10, block35, scale=0.17, input_num_channels=320)

    # Reduction-A: 35x35 -> 17x17, 384+384+320 = 1088 channels
    tower_conv = ConvFactory(net, 384, (3, 3), stride=(2, 2))
    tower_conv1_0 = ConvFactory(net, 256, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, 256, (3, 3), pad=(1, 1))
    tower_conv1_2 = ConvFactory(tower_conv1_1, 384, (3, 3), stride=(2, 2))
    tower_pool = mx.symbol.Pooling(
        net, kernel=(3, 3), stride=(2, 2), pool_type='max')
    net = mx.symbol.Concat(*[tower_conv, tower_conv1_2, tower_pool])

    # 20 x Inception-ResNet-B blocks at 17x17
    net = repeat(net, 20, block17, scale=0.1, input_num_channels=1088)

    # Reduction-B: 17x17 -> 8x8, 384+288+320+1088 = 2080 channels
    tower_conv = ConvFactory(net, 256, (1, 1))
    tower_conv0_1 = ConvFactory(tower_conv, 384, (3, 3), stride=(2, 2))
    tower_conv1 = ConvFactory(net, 256, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1, 288, (3, 3), stride=(2, 2))
    tower_conv2 = ConvFactory(net, 256, (1, 1))
    tower_conv2_1 = ConvFactory(tower_conv2, 288, (3, 3), pad=(1, 1))
    tower_conv2_2 = ConvFactory(tower_conv2_1, 320, (3, 3), stride=(2, 2))
    tower_pool = mx.symbol.Pooling(
        net, kernel=(3, 3), stride=(2, 2), pool_type='max')
    net = mx.symbol.Concat(
        *[tower_conv0_1, tower_conv1_1, tower_conv2_2, tower_pool])

    # 9 x Inception-ResNet-C blocks at 8x8, plus one final block
    # without the trailing activation
    net = repeat(net, 9, block8, scale=0.2, input_num_channels=2080)
    net = block8(net, with_act=False, input_num_channels=2080)

    # Head: 1x1 conv, global average pooling, dropout, classifier
    net = ConvFactory(net, 1536, (1, 1))
    # kernel and stride are ignored when global_pool is set
    net = mx.symbol.Pooling(net, kernel=(1, 1), global_pool=True,
                            stride=(2, 2), pool_type='avg')
    net = mx.symbol.Flatten(net)
    net = mx.symbol.Dropout(data=net, p=0.2)
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes)
    softmax = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    return softmax
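

# A minimal smoke test, not part of the original example; assumes the
# standard 299x299 RGB input size for Inception-ResNet-v2.
if __name__ == '__main__':
    sym = get_symbol(num_classes=1000)
    # Infer all shapes from the data shape; the label shape is deduced
    # automatically by SoftmaxOutput.
    arg_shapes, out_shapes, aux_shapes = sym.infer_shape(data=(1, 3, 299, 299))
    print('softmax output shape:', out_shapes[0])  # expected: (1, 1000)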