blob: b65611215b7abd9dc2743db7e0868a8cfd43ad85 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
*/
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"
// Allow IDE to parse the types
#include "../include/mxnet-cpp/op.h"
using namespace mxnet::cpp;
/*!
 * \brief Convolution -> BatchNorm -> ReLU building block.
 *
 * Every layer and parameter name is derived from `name + suffix`, so each
 * call site must use a distinct name/suffix pair.
 * \param data       input symbol.
 * \param num_filter number of convolution output channels.
 * \param kernel     convolution kernel size.
 * \param stride     convolution stride.
 * \param pad        convolution zero-padding.
 * \param name       base name for the generated layers.
 * \param suffix     optional extra suffix appended to `name`.
 * \return the ReLU activation symbol.
 */
Symbol ConvFactoryBN(Symbol data, int num_filter,
                     Shape kernel, Shape stride, Shape pad,
                     const std::string & name,
                     const std::string & suffix = "") {
  const std::string id = name + suffix;
  Symbol conv_w("conv_" + id + "_w"), conv_b("conv_" + id + "_b");
  // The Shape(1, 1) between stride and pad is presumably the dilation.
  Symbol conv = Convolution("conv_" + id, data,
                            conv_w, conv_b, kernel,
                            num_filter, stride, Shape(1, 1), pad);
  // Each BatchNorm gets its own explicitly named scale (gamma) and shift
  // (beta) variables instead of sharing file-level default-constructed
  // Symbols across every layer.
  Symbol bn_gamma("bn_" + id + "_gamma"), bn_beta("bn_" + id + "_beta");
  Symbol bn = BatchNorm("bn_" + id, conv, bn_gamma, bn_beta);
  return Activation("relu_" + id, bn, "relu");
}
/*!
 * \brief Inception-BN "type A" module: four parallel branches whose outputs
 *        are concatenated.
 *
 * Branches (all stride 1 with "same" padding, so the spatial size of `data`
 * is preserved):
 *   - 1x1 convolution with `num_1x1` filters;
 *   - 1x1 reduction (`num_3x3red`) followed by a 3x3 convolution (`num_3x3`);
 *   - 1x1 reduction (`num_d3x3red`) followed by two 3x3 convolutions
 *     (`num_d3x3` each);
 *   - 3x3 pooling of type `pool` on the module input, followed by a 1x1
 *     projection with `proj` filters.
 * \param data input symbol.
 * \param name module name used as a prefix for every layer/parameter name.
 * \return Concat of the four branch outputs.
 */
Symbol InceptionFactoryA(Symbol data, int num_1x1, int num_3x3red,
                         int num_3x3, int num_d3x3red, int num_d3x3,
                         PoolingPoolType pool, int proj,
                         const std::string & name) {
  // 1x1 branch ("_"-separated like every other layer name).
  Symbol c1x1 = ConvFactoryBN(data, num_1x1, Shape(1, 1), Shape(1, 1),
                              Shape(0, 0), name + "_1x1");
  // 1x1 reduce + 3x3 branch; the reduce layer follows the same
  // "<name>_3x3" + "_reduce" naming scheme as InceptionFactoryB.
  Symbol c3x3r = ConvFactoryBN(data, num_3x3red, Shape(1, 1), Shape(1, 1),
                               Shape(0, 0), name + "_3x3", "_reduce");
  Symbol c3x3 = ConvFactoryBN(c3x3r, num_3x3, Shape(3, 3), Shape(1, 1),
                              Shape(1, 1), name + "_3x3");
  // 1x1 reduce + double 3x3 branch.
  Symbol cd3x3r = ConvFactoryBN(data, num_d3x3red, Shape(1, 1), Shape(1, 1),
                                Shape(0, 0), name + "_double_3x3", "_reduce");
  Symbol cd3x3 = ConvFactoryBN(cd3x3r, num_d3x3, Shape(3, 3), Shape(1, 1),
                               Shape(1, 1), name + "_double_3x3_0");
  // Chain on cd3x3 only; `data` must stay the module input because the
  // pooling branch below reads it.
  cd3x3 = ConvFactoryBN(cd3x3, num_d3x3, Shape(3, 3), Shape(1, 1),
                        Shape(1, 1), name + "_double_3x3_1");
  // Pool + 1x1 projection branch, applied to the module input
  // (kernel 3, stride 1, pad 1).
  Symbol pooling = Pooling(name + "_pool", data,
                           Shape(3, 3), pool, false, false,
                           PoolingPoolingConvention::valid,
                           Shape(1, 1), Shape(1, 1));
  Symbol cproj = ConvFactoryBN(pooling, proj, Shape(1, 1), Shape(1, 1),
                               Shape(0, 0), name + "_proj");
  std::vector<Symbol> lst{c1x1, c3x3, cd3x3, cproj};
  return Concat("ch_concat_" + name + "_chconcat", lst, lst.size());
}
/*!
 * \brief Inception-BN "type B" (reduction) module: three stride-2 branches
 *        whose outputs are concatenated.
 *
 * Branches:
 *   - 1x1 reduction (`num_3x3red`) followed by a stride-2 3x3 convolution
 *     (`num_3x3`);
 *   - 1x1 reduction (`num_d3x3red`), a stride-1 3x3 convolution and a
 *     stride-2 3x3 convolution (`num_d3x3` each);
 *   - stride-2 3x3 max pooling of the module input.
 * \param data input symbol.
 * \param name module name used as a prefix for every layer/parameter name.
 * \return Concat of the three branch outputs.
 */
Symbol InceptionFactoryB(Symbol data, int num_3x3red, int num_3x3,
                         int num_d3x3red, int num_d3x3, const std::string & name) {
  // 1x1 reduce + stride-2 3x3 branch.
  Symbol c3x3r = ConvFactoryBN(data, num_3x3red, Shape(1, 1),
                               Shape(1, 1), Shape(0, 0),
                               name + "_3x3", "_reduce");
  Symbol c3x3 = ConvFactoryBN(c3x3r, num_3x3, Shape(3, 3), Shape(2, 2),
                              Shape(1, 1), name + "_3x3");
  // 1x1 reduce + double 3x3 branch (second conv downsamples).
  Symbol cd3x3r = ConvFactoryBN(data, num_d3x3red, Shape(1, 1), Shape(1, 1),
                                Shape(0, 0), name + "_double_3x3", "_reduce");
  Symbol cd3x3 = ConvFactoryBN(cd3x3r, num_d3x3, Shape(3, 3), Shape(1, 1),
                               Shape(1, 1), name + "_double_3x3_0");
  cd3x3 = ConvFactoryBN(cd3x3, num_d3x3, Shape(3, 3), Shape(2, 2),
                        Shape(1, 1), name + "_double_3x3_1");
  // Max-pool branch. Pad by 1 like the stride-2 convolutions above, so all
  // three branches produce the same spatial size. Without padding the pooled
  // map is one pixel smaller (e.g. 13 vs 14 for a 27x27 input) and the
  // channel Concat below fails shape inference.
  Symbol pooling = Pooling("max_pool_" + name + "_pool", data,
                           Shape(3, 3), PoolingPoolType::max,
                           false, false, PoolingPoolingConvention::valid,
                           Shape(2, 2), Shape(1, 1));
  std::vector<Symbol> lst{c3x3, cd3x3, pooling};
  return Concat("ch_concat_" + name + "_chconcat", lst, lst.size());
}
/*!
 * \brief Builds the full Inception-BN classification network.
 *
 * Inputs are the variables "data" and "data_label". The graph is a conv/pool
 * stem, Inception stages 3-5, 7x7 average pooling, a fully connected layer
 * with `num_classes` outputs and a SoftmaxOutput.
 * \param num_classes number of output classes.
 * \return the SoftmaxOutput symbol.
 */
Symbol InceptionSymbol(int num_classes) {
  // data and label
  Symbol data = Symbol::Variable("data");
  Symbol data_label = Symbol::Variable("data_label");

  // stage 1
  Symbol conv1 = ConvFactoryBN(data, 64, Shape(7, 7), Shape(2, 2), Shape(3, 3), "conv1");
  Symbol pool1 = Pooling("pool1", conv1, Shape(3, 3), PoolingPoolType::max,
                         false, false, PoolingPoolingConvention::valid, Shape(2, 2));

  // stage 2
  Symbol conv2red = ConvFactoryBN(pool1, 64, Shape(1, 1), Shape(1, 1), Shape(0, 0), "conv2red");
  Symbol conv2 = ConvFactoryBN(conv2red, 192, Shape(3, 3), Shape(1, 1), Shape(1, 1), "conv2");
  Symbol pool2 = Pooling("pool2", conv2, Shape(3, 3), PoolingPoolType::max,
                         false, false, PoolingPoolingConvention::valid, Shape(2, 2));

  // stage 3
  Symbol in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, PoolingPoolType::avg, 32, "3a");
  Symbol in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, PoolingPoolType::avg, 64, "3b");
  Symbol in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, "3c");

  // stage 4
  Symbol in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, PoolingPoolType::avg, 128, "4a");
  Symbol in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, PoolingPoolType::avg, 128, "4b");
  Symbol in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, PoolingPoolType::avg, 128, "4c");
  Symbol in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, PoolingPoolType::avg, 128, "4d");
  Symbol in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, "4e");

  // stage 5
  Symbol in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, PoolingPoolType::avg, 128, "5a");
  Symbol in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, PoolingPoolType::max, 128, "5b");

  // average pooling (7x7 kernel)
  Symbol avg = Pooling("global_pool", in5b, Shape(7, 7), PoolingPoolType::avg);

  // classifier: the fully connected layer's parameters are named after the
  // layer itself ("fc1_*"), not after the first convolution.
  Symbol flatten = Flatten("flatten", avg);
  Symbol fc1_w("fc1_w"), fc1_b("fc1_b");
  Symbol fc1 = FullyConnected("fc1", flatten, fc1_w, fc1_b, num_classes);
  return SoftmaxOutput("softmax", fc1, data_label);
}
/*!
 * \brief Trains the Inception-BN network (10 classes) on GPU and logs the
 *        validation accuracy after every epoch.
 *
 * Reads ./train.lst + ./train.rec and ./val.lst + ./val.rec through
 * ImageRecordIter with 3x224x224 data.
 */
int main(int argc, char const *argv[]) {
  int batch_size = 40;
  int max_epoch = 100;
  float learning_rate = 1e-4;
  float weight_decay = 1e-4;

  auto inception_bn_net = InceptionSymbol(10);
  std::map<std::string, NDArray> args_map;
  args_map["data"] = NDArray(Shape(batch_size, 3, 224, 224), Context::gpu());
  args_map["data_label"] = NDArray(Shape(batch_size), Context::gpu());
  // Fill args_map with arrays for the remaining arguments, using shapes
  // inferred from "data" and "data_label".
  // NOTE(review): this file never explicitly initializes those parameter
  // arrays; consider running an initializer over args_map before training.
  inception_bn_net.InferArgsMap(Context::gpu(), &args_map, args_map);

  auto train_iter = MXDataIter("ImageRecordIter")
      .SetParam("path_imglist", "./train.lst")
      .SetParam("path_imgrec", "./train.rec")
      .SetParam("data_shape", Shape(3, 224, 224))
      .SetParam("batch_size", batch_size)
      .SetParam("shuffle", 1)
      .CreateDataIter();

  auto val_iter = MXDataIter("ImageRecordIter")
      .SetParam("path_imglist", "./val.lst")
      .SetParam("path_imgrec", "./val.rec")
      .SetParam("data_shape", Shape(3, 224, 224))
      .SetParam("batch_size", batch_size)
      .CreateDataIter();

  Optimizer* opt = OptimizerRegistry::Find("ccsgd");
  if (opt == nullptr) {
    // Guard against an unregistered optimizer name instead of dereferencing null.
    LG << "Optimizer \"ccsgd\" not found";
    MXNotifyShutdown();
    return 1;
  }
  opt->SetParam("momentum", 0.9)
     ->SetParam("rescale_grad", 1.0 / batch_size)
     ->SetParam("clip_gradient", 10);

  auto *exec = inception_bn_net.SimpleBind(Context::gpu(), args_map);

  for (int iter = 0; iter < max_epoch; ++iter) {
    LG << "Epoch: " << iter;

    // Training pass: copy each batch into the bound input arrays, then run
    // forward, backward and a parameter update.
    train_iter.Reset();
    while (train_iter.Next()) {
      auto data_batch = train_iter.GetDataBatch();
      data_batch.data.CopyTo(&args_map["data"]);
      data_batch.label.CopyTo(&args_map["data_label"]);
      NDArray::WaitAll();

      exec->Forward(true);
      exec->Backward();
      exec->UpdateAll(opt, learning_rate, weight_decay);
      NDArray::WaitAll();
    }

    // Validation pass: inference only, accumulating accuracy.
    Accuracy acu;
    val_iter.Reset();
    while (val_iter.Next()) {
      auto data_batch = val_iter.GetDataBatch();
      data_batch.data.CopyTo(&args_map["data"]);
      data_batch.label.CopyTo(&args_map["data_label"]);
      NDArray::WaitAll();
      exec->Forward(false);
      NDArray::WaitAll();
      acu.Update(data_batch.label, exec->outputs[0]);
    }
    LG << "Accuracy: " << acu.Get();
  }

  delete exec;
  // Release the optimizer obtained from OptimizerRegistry::Find (presumably
  // caller-owned, like exec; verify against the cpp-package version in use).
  delete opt;
  MXNotifyShutdown();
  return 0;
}