/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
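/*!
 * Train a LeNet-style convolutional network on MNIST with the MXNet C++ API
 * (mxnet-cpp). An optional first command-line argument sets the number of
 * training epochs (default 100); the MNIST data files are expected under
 * ./data/mnist_data/.
 */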
#include <map>
#include <string>
#include <vector>
#include <fstream>
#include <chrono>
#include <cstdlib>
#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;
Symbol LenetSymbol() {
  /*
   * LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
   * "Gradient-based learning applied to document recognition."
   * Proceedings of the IEEE (1998)
   */

  /* define the symbolic net */
  Symbol data = Symbol::Variable("data");
  Symbol data_label = Symbol::Variable("data_label");
  Symbol conv1_w("conv1_w"), conv1_b("conv1_b");
  Symbol conv2_w("conv2_w"), conv2_b("conv2_b");
  Symbol fc1_w("fc1_w"), fc1_b("fc1_b");
  Symbol fc2_w("fc2_w"), fc2_b("fc2_b");
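
  // Shapes for a 28x28 input: conv1 5x5/20ch -> 24x24, 2x2 max pool (stride 2)
  // -> 12x12, conv2 5x5/50ch -> 8x8, pool -> 4x4. Flatten then yields
  // 50 * 4 * 4 = 800 features, which fc1 maps to 500 and fc2 to the 10 digit
  // classes.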
  Symbol conv1 = Convolution("conv1", data, conv1_w, conv1_b, Shape(5, 5), 20);
  Symbol tanh1 = Activation("tanh1", conv1, ActivationActType::kTanh);
  Symbol pool1 = Pooling("pool1", tanh1, Shape(2, 2), PoolingPoolType::kMax,
                         false, false, PoolingPoolingConvention::kValid, Shape(2, 2));

  Symbol conv2 = Convolution("conv2", pool1, conv2_w, conv2_b, Shape(5, 5), 50);
  Symbol tanh2 = Activation("tanh2", conv2, ActivationActType::kTanh);
  Symbol pool2 = Pooling("pool2", tanh2, Shape(2, 2), PoolingPoolType::kMax,
                         false, false, PoolingPoolingConvention::kValid, Shape(2, 2));

  Symbol flatten = Flatten("flatten", pool2);
  Symbol fc1 = FullyConnected("fc1", flatten, fc1_w, fc1_b, 500);
  Symbol tanh3 = Activation("tanh3", fc1, ActivationActType::kTanh);
  Symbol fc2 = FullyConnected("fc2", tanh3, fc2_w, fc2_b, 10);

  Symbol lenet = SoftmaxOutput("softmax", fc2, data_label);
  return lenet;
}
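
// Reshape a flat batch of MNIST images to NCHW layout (the 0 keeps the
// incoming batch dimension) and bilinearly resize to the requested
// height/width. With W = H = 28 the size is unchanged, but this lets the
// example drive networks that expect larger inputs.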
NDArray ResizeInput(NDArray data, const Shape new_shape) {
  NDArray pic = data.Reshape(Shape(0, 1, 28, 28));
  NDArray output;
  Operator("_contrib_BilinearResize2D")
    .SetParam("height", new_shape[2])
    .SetParam("width", new_shape[3])
    (pic).Invoke(output);
  return output;
}
int main(int argc, char const *argv[]) {
  /* setup basic configs */
  int W = 28;
  int H = 28;
  int batch_size = 128;
  int max_epoch = argc > 1 ? strtol(argv[1], nullptr, 10) : 100;
  float learning_rate = 1e-4;
  float weight_decay = 1e-4;

  auto dev_ctx = Context::cpu();
  int num_gpu;
  MXGetGPUCount(&num_gpu);
#if MXNET_USE_CUDA
  if (num_gpu > 0) {
    dev_ctx = Context::gpu();
  }
#endif
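
  // TRY/CATCH are helper macros defined in utils.h; they wrap the body in a
  // try/catch for dmlc::Error so that a failure inside MXNet is reported
  // rather than terminating the program with an unhandled exception.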
  TRY
  auto lenet = LenetSymbol();
  std::map<std::string, NDArray> args_map;

  const Shape data_shape = Shape(batch_size, 1, H, W),
              label_shape = Shape(batch_size);
  args_map["data"] = NDArray(data_shape, dev_ctx);
  args_map["data_label"] = NDArray(label_shape, dev_ctx);
  lenet.InferArgsMap(dev_ctx, &args_map, args_map);
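
  // InferArgsMap infers the shapes of all remaining arguments from the
  // "data"/"data_label" entries above and allocates arrays for them. Two of
  // those arrays are then re-initialized explicitly: fc1_w from a standard
  // Gaussian and fc2_b to all zeros (assigning a scalar fills the array).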
args_map["fc1_w"] = NDArray(Shape(500, 4 * 4 * 50), dev_ctx);
NDArray::SampleGaussian(0, 1, &args_map["fc1_w"]);
args_map["fc2_b"] = NDArray(Shape(10), dev_ctx);
args_map["fc2_b"] = 0;
  std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte",
                                          "./data/mnist_data/train-labels-idx1-ubyte",
                                          "./data/mnist_data/t10k-images-idx3-ubyte",
                                          "./data/mnist_data/t10k-labels-idx1-ubyte" };
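
  // setDataIter (from utils.h) checks that the MNIST files exist and
  // configures the built-in MNISTIter: "Train" points the iterator at the
  // 60k-image training files, "Label" at the 10k-image t10k files used here
  // as the validation set.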
  auto train_iter = MXDataIter("MNISTIter");
  if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
    return 1;
  }

  auto val_iter = MXDataIter("MNISTIter");
  if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
    return 1;
  }
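
  // Momentum SGD with weight decay and gradient clipping; rescale_grad = 1.0
  // applies gradients as computed, without dividing by the batch size.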
  Optimizer* opt = OptimizerRegistry::Find("sgd");
  opt->SetParam("momentum", 0.9)
     ->SetParam("rescale_grad", 1.0)
     ->SetParam("clip_gradient", 10)
     ->SetParam("lr", learning_rate)
     ->SetParam("wd", weight_decay);
  auto *exec = lenet.SimpleBind(dev_ctx, args_map);
  auto arg_names = lenet.ListArguments();

  // Create metrics
  Accuracy train_acc, val_acc;
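
  // Each epoch streams the full training set once, updating parameters per
  // batch, then measures accuracy on the held-out set.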
  for (int iter = 0; iter < max_epoch; ++iter) {
    int samples = 0;
    train_iter.Reset();
    train_acc.Reset();

    auto tic = std::chrono::system_clock::now();

    while (train_iter.Next()) {
      samples += batch_size;
      auto data_batch = train_iter.GetDataBatch();

      ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
      data_batch.label.CopyTo(&args_map["data_label"]);
      NDArray::WaitAll();

      // Compute gradients
      exec->Forward(true);
      exec->Backward();

      // Update parameters, skipping the input and label arrays
      for (size_t i = 0; i < arg_names.size(); ++i) {
        if (arg_names[i] == "data" || arg_names[i] == "data_label") continue;
        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
      }

      // Update metric
      train_acc.Update(data_batch.label, exec->outputs[0]);
    }
    // one epoch of training is finished
    auto toc = std::chrono::system_clock::now();
    float duration = std::chrono::duration_cast<std::chrono::milliseconds>
                     (toc - tic).count() / 1000.0;
    LG << "Epoch[" << iter << "] " << samples / duration
       << " samples/sec " << "Train-Accuracy=" << train_acc.Get();
    val_iter.Reset();
    val_acc.Reset();

    while (val_iter.Next()) {
      auto data_batch = val_iter.GetDataBatch();
      ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]);
      data_batch.label.CopyTo(&args_map["data_label"]);
      NDArray::WaitAll();

      // Only the forward pass is needed; no gradients are required for evaluation
      exec->Forward(false);
      NDArray::WaitAll();

      val_acc.Update(data_batch.label, exec->outputs[0]);
    }
    LG << "Epoch[" << iter << "] Val-Accuracy=" << val_acc.Get();
  }

  delete exec;
  delete opt;
  MXNotifyShutdown();
  CATCH
  return 0;
}