/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * Example: training an AlexNet-style network on MNIST with the MXNet C++ API.
 */
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <fstream>
#include <cstdlib>
#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;
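/* Builds the AlexNet symbol graph: five convolution layers with ReLU
 * activations (max-pooling after stages 1, 2 and 5; LRN after stages 1
 * and 2), followed by three fully connected layers with dropout and a
 * SoftmaxOutput head over num_classes classes. */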
Symbol AlexnetSymbol(int num_classes) {
auto input_data = Symbol::Variable("data");
auto target_label = Symbol::Variable("label");
/*stage 1*/
auto conv1 = Operator("Convolution")
.SetParam("kernel", Shape(11, 11))
.SetParam("num_filter", 96)
.SetParam("stride", Shape(4, 4))
.SetParam("dilate", Shape(1, 1))
.SetParam("pad", Shape(0, 0))
.SetParam("num_group", 1)
.SetParam("workspace", 512)
.SetParam("no_bias", false)
.SetInput("data", input_data)
.CreateSymbol("conv1");
auto relu1 = Operator("Activation")
.SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */
.SetInput("data", conv1)
.CreateSymbol("relu1");
auto pool1 = Operator("Pooling")
.SetParam("kernel", Shape(3, 3))
.SetParam("pool_type", "max") /*avg,max,sum */
.SetParam("global_pool", false)
.SetParam("stride", Shape(2, 2))
.SetParam("pad", Shape(0, 0))
.SetInput("data", relu1)
.CreateSymbol("pool1");
auto lrn1 = Operator("LRN")
.SetParam("nsize", 5)
.SetParam("alpha", 0.0001)
.SetParam("beta", 0.75)
.SetParam("knorm", 1)
.SetInput("data", pool1)
.CreateSymbol("lrn1");
/*stage 2*/
auto conv2 = Operator("Convolution")
.SetParam("kernel", Shape(5, 5))
.SetParam("num_filter", 256)
.SetParam("stride", Shape(1, 1))
.SetParam("dilate", Shape(1, 1))
.SetParam("pad", Shape(2, 2))
.SetParam("num_group", 1)
.SetParam("workspace", 512)
.SetParam("no_bias", false)
.SetInput("data", lrn1)
.CreateSymbol("conv2");
auto relu2 = Operator("Activation")
.SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */
.SetInput("data", conv2)
.CreateSymbol("relu2");
auto pool2 = Operator("Pooling")
.SetParam("kernel", Shape(3, 3))
.SetParam("pool_type", "max") /*avg,max,sum */
.SetParam("global_pool", false)
.SetParam("stride", Shape(2, 2))
.SetParam("pad", Shape(0, 0))
.SetInput("data", relu2)
.CreateSymbol("pool2");
auto lrn2 = Operator("LRN")
.SetParam("nsize", 5)
.SetParam("alpha", 0.0001)
.SetParam("beta", 0.75)
.SetParam("knorm", 1)
.SetInput("data", pool2)
.CreateSymbol("lrn2");
/*stage 3*/
auto conv3 = Operator("Convolution")
.SetParam("kernel", Shape(3, 3))
.SetParam("num_filter", 384)
.SetParam("stride", Shape(1, 1))
.SetParam("dilate", Shape(1, 1))
.SetParam("pad", Shape(1, 1))
.SetParam("num_group", 1)
.SetParam("workspace", 512)
.SetParam("no_bias", false)
.SetInput("data", lrn2)
.CreateSymbol("conv3");
auto relu3 = Operator("Activation")
.SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */
.SetInput("data", conv3)
.CreateSymbol("relu3");
auto conv4 = Operator("Convolution")
.SetParam("kernel", Shape(3, 3))
.SetParam("num_filter", 384)
.SetParam("stride", Shape(1, 1))
.SetParam("dilate", Shape(1, 1))
.SetParam("pad", Shape(1, 1))
.SetParam("num_group", 1)
.SetParam("workspace", 512)
.SetParam("no_bias", false)
.SetInput("data", relu3)
.CreateSymbol("conv4");
auto relu4 = Operator("Activation")
.SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */
.SetInput("data", conv4)
.CreateSymbol("relu4");
auto conv5 = Operator("Convolution")
.SetParam("kernel", Shape(3, 3))
.SetParam("num_filter", 256)
.SetParam("stride", Shape(1, 1))
.SetParam("dilate", Shape(1, 1))
.SetParam("pad", Shape(1, 1))
.SetParam("num_group", 1)
.SetParam("workspace", 512)
.SetParam("no_bias", false)
.SetInput("data", relu4)
.CreateSymbol("conv5");
auto relu5 = Operator("Activation")
.SetParam("act_type", "relu")
.SetInput("data", conv5)
.CreateSymbol("relu5");
auto pool3 = Operator("Pooling")
.SetParam("kernel", Shape(3, 3))
.SetParam("pool_type", "max")
.SetParam("global_pool", false)
.SetParam("stride", Shape(2, 2))
.SetParam("pad", Shape(0, 0))
.SetInput("data", relu5)
.CreateSymbol("pool3");
/*stage4*/
auto flatten =
Operator("Flatten").SetInput("data", pool3).CreateSymbol("flatten");
auto fc1 = Operator("FullyConnected")
.SetParam("num_hidden", 4096)
.SetParam("no_bias", false)
.SetInput("data", flatten)
.CreateSymbol("fc1");
auto relu6 = Operator("Activation")
.SetParam("act_type", "relu")
.SetInput("data", fc1)
.CreateSymbol("relu6");
auto dropout1 = Operator("Dropout")
.SetParam("p", 0.5)
.SetInput("data", relu6)
.CreateSymbol("dropout1");
/*stage5*/
auto fc2 = Operator("FullyConnected")
.SetParam("num_hidden", 4096)
.SetParam("no_bias", false)
.SetInput("data", dropout1)
.CreateSymbol("fc2");
auto relu7 = Operator("Activation")
.SetParam("act_type", "relu")
.SetInput("data", fc2)
.CreateSymbol("relu7");
auto dropout2 = Operator("Dropout")
.SetParam("p", 0.5)
.SetInput("data", relu7)
.CreateSymbol("dropout2");
/*stage6*/
auto fc3 = Operator("FullyConnected")
.SetParam("num_hidden", num_classes)
.SetParam("no_bias", false)
.SetInput("data", dropout2)
.CreateSymbol("fc3");
auto softmax = Operator("SoftmaxOutput")
.SetParam("grad_scale", 1)
.SetParam("ignore_label", -1)
.SetParam("multi_output", false)
.SetParam("use_ignore", false)
.SetParam("normalization", "null") /*batch,null,valid */
.SetInput("data", fc3)
.SetInput("label", target_label)
.CreateSymbol("softmax");
return softmax;
}
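/* MNIST images are 28x28 single-channel, while this AlexNet expects larger
 * 3-channel input: upsample with a bilinear resize, then tile the channel
 * axis 3x. */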
NDArray ResizeInput(NDArray data, const Shape new_shape) {
NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
NDArray pic_1channel;
Operator("_contrib_BilinearResize2D")
.SetParam("height", new_shape[2])
.SetParam("width", new_shape[3])
(pic).Invoke(pic_1channel);
NDArray output;
Operator("tile")
.SetParam("reps", Shape(1, 3, 1, 1))
(pic_1channel).Invoke(output);
return output;
}
int main(int argc, char const *argv[]) {
/*basic config*/
int max_epo = argc > 1 ? strtol(argv[1], nullptr, 10) : 100;
float learning_rate = 1e-4;
float weight_decay = 1e-4;
/*context*/
auto ctx = Context::cpu();
int num_gpu;
MXGetGPUCount(&num_gpu);
int batch_size = 32;
#if MXNET_USE_CUDA
if (num_gpu > 0) {
ctx = Context::gpu();
batch_size = 256;
}
#endif
TRY
/*net symbol*/
auto Net = AlexnetSymbol(10);
/*args_map and aux_map are used for saving the parameters*/
std::map<std::string, NDArray> args_map;
std::map<std::string, NDArray> aux_map;
/*we must tell MXNet the shapes of the data and the label*/
const Shape data_shape = Shape(batch_size, 3, 256, 256),
label_shape = Shape(batch_size);
args_map["data"] = NDArray(data_shape, ctx);
args_map["label"] = NDArray(label_shape, ctx);
/*given the data and label shapes, the executor can be created automatically*/
auto *exec = Net.SimpleBind(ctx, args_map);
auto arg_names = Net.ListArguments();
aux_map = exec->aux_dict();
args_map = exec->arg_dict();
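/* arg_dict()/aux_dict() return name -> NDArray views of the executor's own
 * arrays, so writing into these maps feeds the executor directly */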
/*to fine-tune from a pre-trained model, load its parameters:*/
// NDArray::Load("./model/alex_params_3", nullptr, &args_map);
/*otherwise, initialize the parameters from scratch; this example uses a
 * uniform distribution*/
auto initializer = Uniform(0.07);
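/* a Xavier initializer is a common alternative to the Uniform line above;
 * a minimal sketch using mxnet-cpp's Xavier class (magnitude chosen for
 * illustration):
 *
 *   auto initializer = Xavier(Xavier::gaussian, Xavier::in, 2.34);
 */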
for (auto &arg : args_map) {
/*note: the initializer dispatches on the argument name's suffix
 * ("weight", "bias", "gamma", "beta"), so parameters must be named
 * accordingly*/
initializer(arg.first, &arg.second);
}
/*these are the raw MNIST data files, available from
 * http://yann.lecun.com/exdb/mnist/ */
std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte",
"./data/mnist_data/train-labels-idx1-ubyte",
"./data/mnist_data/t10k-images-idx3-ubyte",
"./data/mnist_data/t10k-labels-idx1-ubyte"
};
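/* setDataIter (from utils.h) checks that the files exist and configures the
 * MNISTIter: "Train" selects the training images/labels, "Label" the test
 * set, used here for validation */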
auto train_iter = MXDataIter("MNISTIter");
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}
auto val_iter = MXDataIter("MNISTIter");
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}
Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0 / batch_size)
->SetParam("clip_gradient", 10)
->SetParam("lr", learning_rate)
->SetParam("wd", weight_decay);
Accuracy acu_train, acu_val;
LogLoss logloss_train, logloss_val;
for (int epoch = 0; epoch < max_epo; ++epoch) {
LG << "Train Epoch: " << epoch;
/*reset the metric every epoch*/
acu_train.Reset();
/*reset the data iter every epoch*/
train_iter.Reset();
int iter = 0;
while (train_iter.Next()) {
auto batch = train_iter.GetDataBatch();
/*use CopyTo to feed the new data and label to the executor*/
ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]);
batch.label.CopyTo(&args_map["label"]);
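/* is_train=true: Dropout masks activations during the forward pass */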
exec->Forward(true);
exec->Backward();
for (size_t i = 0; i < arg_names.size(); ++i) {
if (arg_names[i] == "data" || arg_names[i] == "label") continue;
opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
}
NDArray::WaitAll();
acu_train.Update(batch.label, exec->outputs[0]);
logloss_train.Reset();
logloss_train.Update(batch.label, exec->outputs[0]);
++iter;
LG << "EPOCH: " << epoch << " ITER: " << iter
<< " Train Accuracy: " << acu_train.Get()
<< " Train Loss: " << logloss_train.Get();
}
LG << "EPOCH: " << epoch << " Train Accuracy: " << acu_train.Get();
LG << "Val Epoch: " << epoch;
acu_val.Reset();
val_iter.Reset();
logloss_val.Reset();
iter = 0;
while (val_iter.Next()) {
auto batch = val_iter.GetDataBatch();
ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]);
batch.label.CopyTo(&args_map["label"]);
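/* is_train=false: Dropout acts as the identity at inference time */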
exec->Forward(false);
NDArray::WaitAll();
acu_val.Update(batch.label, exec->outputs[0]);
logloss_val.Update(batch.label, exec->outputs[0]);
LG << "EPOCH: " << epoch << " ITER: " << iter << " Val Accuracy: " << acu_val.Get();
++iter;
}
LG << "EPOCH: " << epoch << " Val Accuracy: " << acu_val.Get();
LG << "EPOCH: " << epoch << " Val LogLoss: " << logloss_val.Get();
/*save the parameters*/
std::string save_path_param = "alex_param_" + std::to_string(epoch);
auto save_args = args_map;
/*we do not want to save the data and label*/
save_args.erase(save_args.find("data"));
save_args.erase(save_args.find("label"));
/*AlexNet has no auxiliary states (e.g. BatchNorm moving statistics), so
 * there is no aux_map to save*/
LG << "EPOCH: " << epoch << " Saving to..." << save_path_param;
NDArray::Save(save_path_param, save_args);
}
/*don't forget to release the executor*/
delete exec;
delete opt;
MXNotifyShutdown();
CATCH
return 0;
}