# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
using MXNet
#--------------------------------------------------------------------------------
# Helper functions to construct larger networks
# basic Conv + BN + ReLU factory
function conv_factory(data, num_filter, kernel; stride=(1,1), pad=(0,0), act_type=:relu)
  conv = mx.Convolution(data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
  bn   = mx.BatchNorm(conv)
  act  = mx.Activation(bn, act_type=act_type)
  return act
end
# simple downsampling factory
function downsample_factory(data, ch_3x3)
  # conv 3x3
  conv = conv_factory(data, ch_3x3, (3,3), stride=(2,2), pad=(1,1))
  # pool
  pool = mx.Pooling(data, kernel=(3,3), stride=(2,2), pool_type=:max)
  # concat
  concat = mx.Concat(conv, pool)
  return concat
end
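# Note: both branches above halve the spatial resolution (stride 2), so their
# outputs can be concatenated along the channel dimension; the result carries
# ch_3x3 channels from the convolution branch plus the input's own channel
# count from the pooling branch.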
# a simple module
function simple_factory(data, ch_1x1, ch_3x3)
  # 1x1
  conv1x1 = conv_factory(data, ch_1x1, (1,1); pad=(0,0))
  # 3x3
  conv3x3 = conv_factory(data, ch_3x3, (3,3); pad=(1,1))
  # concat
  concat = mx.Concat(conv1x1, conv3x3)
  return concat
end
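# Optional sanity check (a sketch, commented out so it does not run as part of
# training): mx.infer_shape reports the output shape of any module built from
# these factories. MXNet.jl shapes are column-major, i.e.
# (width, height, channels, batch); the (28,28,3,1) input is an assumption
# matching the CIFAR-10 providers configured further below.
#   _, out_shapes, _ = mx.infer_shape(simple_factory(mx.Variable(:data), 32, 32),
#                                     data=(28,28,3,1))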
#--------------------------------------------------------------------------------
# Actual architecture: a stack of Inception-style modules for 10-class CIFAR-10
data = mx.Variable(:data)
conv1 = conv_factory(data, 96, (3,3); pad=(1,1), act_type=:relu)
in3a = simple_factory(conv1, 32, 32)
in3b = simple_factory(in3a, 32, 48)
in3c = downsample_factory(in3b, 80)
in4a = simple_factory(in3c, 112, 48)
in4b = simple_factory(in4a, 96, 64)
in4c = simple_factory(in4b, 80, 80)
in4d = simple_factory(in4c, 48, 96)
in4e = downsample_factory(in4d, 96)
in5a = simple_factory(in4e, 176, 160)
in5b = simple_factory(in5a, 176, 160)
pool = mx.Pooling(in5b, pool_type=:avg, kernel=(7,7), name=:global_pool)
flatten = mx.Flatten(pool, name=:flatten1)
fc = mx.FullyConnected(flatten, num_hidden=10, name=:fc1)
softmax = mx.SoftmaxOutput(fc, name=:loss)
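# Optional: inspect the assembled network before training (a sketch, commented
# out). mx.infer_shape reports argument/output shapes for a concrete input
# shape, and mx.to_graphviz renders the symbol graph in dot format; the
# (28,28,3,128) shape assumes the batch size used below.
#   arg_shapes, out_shapes, aux_shapes = mx.infer_shape(softmax, data=(28,28,3,128))
#   println(mx.to_graphviz(softmax))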
#--------------------------------------------------------------------------------
# Prepare data
filenames = mx.get_cifar10()
batch_size = 128
num_epoch = 10
num_gpus = 8
train_provider = mx.ImageRecordProvider(label_name=:loss_label,
                                        path_imgrec=filenames[:train], mean_img=filenames[:mean],
                                        rand_crop=true, rand_mirror=true, data_shape=(28,28,3),
                                        batch_size=batch_size, preprocess_threads=1)
test_provider = mx.ImageRecordProvider(label_name=:loss_label,
                                       path_imgrec=filenames[:test], mean_img=filenames[:mean],
                                       rand_crop=false, rand_mirror=false, data_shape=(28,28,3),
                                       batch_size=batch_size, preprocess_threads=1)
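# Note: data_shape is (width, height, channels); the provider appends the batch
# dimension, so each batch arrives as a (28,28,3,128) array. Random cropping
# and mirroring are enabled only for training, keeping test-set evaluation
# deterministic.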
#--------------------------------------------------------------------------------
# Training model
gpus = [mx.Context(mx.GPU, i) for i = 0:num_gpus-1]
model = mx.FeedForward(softmax, context=gpus)
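# Assumption: an 8-GPU machine. With fewer (or no) GPUs, lower num_gpus above,
# or fall back to the CPU, e.g. model = mx.FeedForward(softmax, context=mx.cpu())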
# optimizer
optimizer = mx.SGD(η=0.05, μ=0.9, λ=0.0001)
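# The Unicode keywords are the SGD hyper-parameters: η is the learning rate,
# μ the momentum coefficient, and λ the weight-decay (L2 regularization) term.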
# fit parameters
mx.fit(model, optimizer, train_provider, n_epoch=num_epoch, eval_data=test_provider,
       initializer=mx.UniformInitializer(0.07), callbacks=[mx.speedometer()])
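# Optional follow-ups (a sketch, commented out; the "cifar10" checkpoint prefix
# is an arbitrary example name): persist parameters every epoch via the
# do_checkpoint callback, then score the test set with predict.
#   mx.fit(model, optimizer, train_provider, n_epoch=num_epoch, eval_data=test_provider,
#          initializer=mx.UniformInitializer(0.07),
#          callbacks=[mx.speedometer(), mx.do_checkpoint("cifar10")])
#   probs = mx.predict(model, test_provider)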