| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
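/*!
 * \brief Train an AlexNet-style network on MNIST with the MXNet C++ API.
 *
 * Each 28x28 grey-scale digit is bilinearly resized and replicated across
 * channels to fill the (batch, 3, 256, 256) input the network is bound with.
 */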
| #include <iostream> |
| #include <map> |
| #include <string> |
| #include <fstream> |
| #include <cstdlib> |
| #include "utils.h" |
| #include "mxnet-cpp/MxNetCpp.h" |
| |
| using namespace mxnet::cpp; |
| |
| Symbol AlexnetSymbol(int num_classes) { |
| auto input_data = Symbol::Variable("data"); |
| auto target_label = Symbol::Variable("label"); |
| /*stage 1*/ |
| auto conv1 = Operator("Convolution") |
| .SetParam("kernel", Shape(11, 11)) |
| .SetParam("num_filter", 96) |
| .SetParam("stride", Shape(4, 4)) |
| .SetParam("dilate", Shape(1, 1)) |
| .SetParam("pad", Shape(0, 0)) |
| .SetParam("num_group", 1) |
| .SetParam("workspace", 512) |
| .SetParam("no_bias", false) |
| .SetInput("data", input_data) |
| .CreateSymbol("conv1"); |
| auto relu1 = Operator("Activation") |
| .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ |
| .SetInput("data", conv1) |
| .CreateSymbol("relu1"); |
| auto pool1 = Operator("Pooling") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("pool_type", "max") /*avg,max,sum */ |
| .SetParam("global_pool", false) |
| .SetParam("stride", Shape(2, 2)) |
| .SetParam("pad", Shape(0, 0)) |
| .SetInput("data", relu1) |
| .CreateSymbol("pool1"); |
| auto lrn1 = Operator("LRN") |
| .SetParam("nsize", 5) |
| .SetParam("alpha", 0.0001) |
| .SetParam("beta", 0.75) |
| .SetParam("knorm", 1) |
| .SetInput("data", pool1) |
| .CreateSymbol("lrn1"); |
| /*stage 2*/ |
| auto conv2 = Operator("Convolution") |
| .SetParam("kernel", Shape(5, 5)) |
| .SetParam("num_filter", 256) |
| .SetParam("stride", Shape(1, 1)) |
| .SetParam("dilate", Shape(1, 1)) |
| .SetParam("pad", Shape(2, 2)) |
| .SetParam("num_group", 1) |
| .SetParam("workspace", 512) |
| .SetParam("no_bias", false) |
| .SetInput("data", lrn1) |
| .CreateSymbol("conv2"); |
| auto relu2 = Operator("Activation") |
| .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ |
| .SetInput("data", conv2) |
| .CreateSymbol("relu2"); |
| auto pool2 = Operator("Pooling") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("pool_type", "max") /*avg,max,sum */ |
| .SetParam("global_pool", false) |
| .SetParam("stride", Shape(2, 2)) |
| .SetParam("pad", Shape(0, 0)) |
| .SetInput("data", relu2) |
| .CreateSymbol("pool2"); |
| auto lrn2 = Operator("LRN") |
| .SetParam("nsize", 5) |
| .SetParam("alpha", 0.0001) |
| .SetParam("beta", 0.75) |
| .SetParam("knorm", 1) |
| .SetInput("data", pool2) |
| .CreateSymbol("lrn2"); |
| /*stage 3*/ |
| auto conv3 = Operator("Convolution") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("num_filter", 384) |
| .SetParam("stride", Shape(1, 1)) |
| .SetParam("dilate", Shape(1, 1)) |
| .SetParam("pad", Shape(1, 1)) |
| .SetParam("num_group", 1) |
| .SetParam("workspace", 512) |
| .SetParam("no_bias", false) |
| .SetInput("data", lrn2) |
| .CreateSymbol("conv3"); |
| auto relu3 = Operator("Activation") |
| .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ |
| .SetInput("data", conv3) |
| .CreateSymbol("relu3"); |
| auto conv4 = Operator("Convolution") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("num_filter", 384) |
| .SetParam("stride", Shape(1, 1)) |
| .SetParam("dilate", Shape(1, 1)) |
| .SetParam("pad", Shape(1, 1)) |
| .SetParam("num_group", 1) |
| .SetParam("workspace", 512) |
| .SetParam("no_bias", false) |
| .SetInput("data", relu3) |
| .CreateSymbol("conv4"); |
| auto relu4 = Operator("Activation") |
| .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ |
| .SetInput("data", conv4) |
| .CreateSymbol("relu4"); |
| auto conv5 = Operator("Convolution") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("num_filter", 256) |
| .SetParam("stride", Shape(1, 1)) |
| .SetParam("dilate", Shape(1, 1)) |
| .SetParam("pad", Shape(1, 1)) |
| .SetParam("num_group", 1) |
| .SetParam("workspace", 512) |
| .SetParam("no_bias", false) |
| .SetInput("data", relu4) |
| .CreateSymbol("conv5"); |
| auto relu5 = Operator("Activation") |
| .SetParam("act_type", "relu") |
| .SetInput("data", conv5) |
| .CreateSymbol("relu5"); |
| auto pool3 = Operator("Pooling") |
| .SetParam("kernel", Shape(3, 3)) |
| .SetParam("pool_type", "max") |
| .SetParam("global_pool", false) |
| .SetParam("stride", Shape(2, 2)) |
| .SetParam("pad", Shape(0, 0)) |
| .SetInput("data", relu5) |
| .CreateSymbol("pool3"); |
| /*stage4*/ |
| auto flatten = |
| Operator("Flatten").SetInput("data", pool3).CreateSymbol("flatten"); |
| auto fc1 = Operator("FullyConnected") |
| .SetParam("num_hidden", 4096) |
| .SetParam("no_bias", false) |
| .SetInput("data", flatten) |
| .CreateSymbol("fc1"); |
| auto relu6 = Operator("Activation") |
| .SetParam("act_type", "relu") |
| .SetInput("data", fc1) |
| .CreateSymbol("relu6"); |
| auto dropout1 = Operator("Dropout") |
| .SetParam("p", 0.5) |
| .SetInput("data", relu6) |
| .CreateSymbol("dropout1"); |
| /*stage5*/ |
| auto fc2 = Operator("FullyConnected") |
| .SetParam("num_hidden", 4096) |
| .SetParam("no_bias", false) |
| .SetInput("data", dropout1) |
| .CreateSymbol("fc2"); |
| auto relu7 = Operator("Activation") |
| .SetParam("act_type", "relu") |
| .SetInput("data", fc2) |
| .CreateSymbol("relu7"); |
| auto dropout2 = Operator("Dropout") |
| .SetParam("p", 0.5) |
| .SetInput("data", relu7) |
| .CreateSymbol("dropout2"); |
| /*stage6*/ |
| auto fc3 = Operator("FullyConnected") |
| .SetParam("num_hidden", num_classes) |
| .SetParam("no_bias", false) |
| .SetInput("data", dropout2) |
| .CreateSymbol("fc3"); |
| auto softmax = Operator("SoftmaxOutput") |
| .SetParam("grad_scale", 1) |
| .SetParam("ignore_label", -1) |
| .SetParam("multi_output", false) |
| .SetParam("use_ignore", false) |
| .SetParam("normalization", "null") /*batch,null,valid */ |
| .SetInput("data", fc3) |
| .SetInput("label", target_label) |
| .CreateSymbol("softmax"); |
| return softmax; |
| } |
| |
| NDArray ResizeInput(NDArray data, const Shape new_shape) { |
| NDArray pic = data.Reshape(Shape(0, 1, 28, 28)); |
| NDArray pic_1channel; |
| Operator("_contrib_BilinearResize2D") |
| .SetParam("height", new_shape[2]) |
| .SetParam("width", new_shape[3]) |
| (pic).Invoke(pic_1channel); |
| NDArray output; |
| Operator("tile") |
| .SetParam("reps", Shape(1, 3, 1, 1)) |
| (pic_1channel).Invoke(output); |
| return output; |
| } |
| |
| int main(int argc, char const *argv[]) { |
| /*basic config*/ |
| int max_epo = argc > 1 ? strtol(argv[1], nullptr, 10) : 100; |
| float learning_rate = 1e-4; |
| float weight_decay = 1e-4; |
| |
| /*context*/ |
| auto ctx = Context::cpu(); |
| int num_gpu; |
| MXGetGPUCount(&num_gpu); |
| int batch_size = 32; |
| #if MXNET_USE_CUDA |
| if (num_gpu > 0) { |
| ctx = Context::gpu(); |
| batch_size = 256; |
| } |
| #endif |
| |
| TRY |
| /*net symbol*/ |
| auto Net = AlexnetSymbol(10); |
| |
| /*args_map and aux_map is used for parameters' saving*/ |
| std::map<std::string, NDArray> args_map; |
| std::map<std::string, NDArray> aux_map; |
| |
| /*we should tell mxnet the shape of data and label*/ |
| const Shape data_shape = Shape(batch_size, 3, 256, 256), |
| label_shape = Shape(batch_size); |
| args_map["data"] = NDArray(data_shape, ctx); |
| args_map["label"] = NDArray(label_shape, ctx); |
| |
| /*with data and label, executor can be generated automatically*/ |
| auto *exec = Net.SimpleBind(ctx, args_map); |
| auto arg_names = Net.ListArguments(); |
| aux_map = exec->aux_dict(); |
| args_map = exec->arg_dict(); |
| |
| /*if fine tune from some pre-trained model, we should load the parameters*/ |
| // NDArray::Load("./model/alex_params_3", nullptr, &args_map); |
| /*else, we should use initializer Xavier to init the params*/ |
| auto initializer = Uniform(0.07); |
| for (auto &arg : args_map) { |
| /*be careful here, the arg's name must has some specific ends or starts for |
| * initializer to call*/ |
| initializer(arg.first, &arg.second); |
| } |
| |
| /*these binary files should be generated using im2rc tools, which can be found |
| * in mxnet/bin*/ |
| std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte", |
| "./data/mnist_data/train-labels-idx1-ubyte", |
| "./data/mnist_data/t10k-images-idx3-ubyte", |
| "./data/mnist_data/t10k-labels-idx1-ubyte" |
| }; |
| |
| auto train_iter = MXDataIter("MNISTIter"); |
| if (!setDataIter(&train_iter, "Train", data_files, batch_size)) { |
| return 1; |
| } |
| |
| auto val_iter = MXDataIter("MNISTIter"); |
| if (!setDataIter(&val_iter, "Label", data_files, batch_size)) { |
| return 1; |
| } |
| |
| Optimizer* opt = OptimizerRegistry::Find("sgd"); |
| opt->SetParam("momentum", 0.9) |
| ->SetParam("rescale_grad", 1.0 / batch_size) |
| ->SetParam("clip_gradient", 10) |
| ->SetParam("lr", learning_rate) |
| ->SetParam("wd", weight_decay); |
| |
| Accuracy acu_train, acu_val; |
| LogLoss logloss_train, logloss_val; |
| for (int epoch = 0; epoch < max_epo; ++epoch) { |
| LG << "Train Epoch: " << epoch; |
| /*reset the metric every epoch*/ |
| acu_train.Reset(); |
| /*reset the data iter every epoch*/ |
| train_iter.Reset(); |
| int iter = 0; |
| while (train_iter.Next()) { |
| auto batch = train_iter.GetDataBatch(); |
| /*use copyto to feed new data and label to the executor*/ |
| ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]); |
| batch.label.CopyTo(&args_map["label"]); |
| exec->Forward(true); |
| exec->Backward(); |
| for (size_t i = 0; i < arg_names.size(); ++i) { |
| if (arg_names[i] == "data" || arg_names[i] == "label") continue; |
| opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]); |
| } |
| |
| NDArray::WaitAll(); |
| acu_train.Update(batch.label, exec->outputs[0]); |
| logloss_train.Reset(); |
| logloss_train.Update(batch.label, exec->outputs[0]); |
| ++iter; |
| LG << "EPOCH: " << epoch << " ITER: " << iter |
| << " Train Accuracy: " << acu_train.Get() |
| << " Train Loss: " << logloss_train.Get(); |
| } |
| LG << "EPOCH: " << epoch << " Train Accuracy: " << acu_train.Get(); |
| |
| LG << "Val Epoch: " << epoch; |
| acu_val.Reset(); |
| val_iter.Reset(); |
| logloss_val.Reset(); |
| iter = 0; |
| while (val_iter.Next()) { |
| auto batch = val_iter.GetDataBatch(); |
| ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]); |
| batch.label.CopyTo(&args_map["label"]); |
| exec->Forward(false); |
| NDArray::WaitAll(); |
| acu_val.Update(batch.label, exec->outputs[0]); |
| logloss_val.Update(batch.label, exec->outputs[0]); |
| LG << "EPOCH: " << epoch << " ITER: " << iter << " Val Accuracy: " << acu_val.Get(); |
| ++iter; |
| } |
| LG << "EPOCH: " << epoch << " Val Accuracy: " << acu_val.Get(); |
| LG << "EPOCH: " << epoch << " Val LogLoss: " << logloss_val.Get(); |
| |
| /*save the parameters*/ |
| std::stringstream ss; |
| ss << epoch; |
| std::string epoch_str; |
| ss >> epoch_str; |
| std::string save_path_param = "alex_param_" + epoch_str; |
| auto save_args = args_map; |
| /*we do not want to save the data and label*/ |
| save_args.erase(save_args.find("data")); |
| save_args.erase(save_args.find("label")); |
| /*the alexnet does not get any aux array, so we do not need to save |
| * aux_map*/ |
| LG << "EPOCH: " << epoch << " Saving to..." << save_path_param; |
| NDArray::Save(save_path_param, save_args); |
| } |
| /*don't foget to release the executor*/ |
| delete exec; |
| delete opt; |
| MXNotifyShutdown(); |
| CATCH |
| return 0; |
| } |