/*!
* Copyright (c) 2016 by Contributors
* \file optimizer.hpp
 * \brief implementation of the optimizers declared in optimizer.h
* \author Chuntao Hong, Zhang Chen
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_
#include <algorithm>
#include <utility>
#include <numeric>
#include <map>
#include <cmath>
#include <string>
#include <vector>
#include "mxnet-cpp/optimizer.h"
#include "mxnet-cpp/op.h"
#include "mxnet-cpp/op_map.h"
namespace {
// TODO(lx75249): Add imperative operators to op.h under ndarray namespace
inline void _clip(mxnet::cpp::NDArray &data, float limit) {
  data = mxnet::cpp::Operator("clip")
      .SetParam("a_min", -limit)
      .SetParam("a_max", limit)
      .SetInput("data", data)
      .Invoke()[0];
}
inline mxnet::cpp::NDArray _sqrt(mxnet::cpp::NDArray data) {
  return mxnet::cpp::Operator("sqrt")
      .SetInput("data", data)
      .Invoke()[0];
}
}  // namespace
namespace mxnet {
namespace cpp {
inline Optimizer::Optimizer(unsigned begin_num_update)
    : begin_num_update_(begin_num_update),
      num_update_(begin_num_update_) {
}
inline std::map<std::string, OptimizerCreator>& OptimizerRegistry::cmap() {
  static std::map<std::string, OptimizerCreator> cmap_;
  return cmap_;
}
inline OpMap*& Optimizer::op_map() {
  static OpMap *op_map_ = new OpMap();
  return op_map_;
}
inline Optimizer::~Optimizer() {}
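// A minimal usage sketch (not part of this file's API): optimizers are
// typically obtained from the registry and driven once per parameter per
// batch. `exec` (an Executor) and `arg_names` are assumed to come from the
// caller's model:
//
//   Optimizer* opt = OptimizerRegistry::Find("sgd");
//   opt->SetParam("momentum", 0.9);
//   for (size_t i = 0; i < arg_names.size(); ++i)
//     opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i],
//                 0.01f /*lr*/, 1e-4f /*wd*/);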
inline void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr,
                              mx_float wd) {
  params_["lr"] = std::to_string(lr);
  params_["wd"] = std::to_string(wd);
  UpdateCount_(index);
  Update(index, weight, grad);
}
inline void Optimizer::CreateState_(int index, NDArray weight) {
}
inline std::string Optimizer::Serialize() const {
  using ValueType = std::map<std::string, std::string>::value_type;
  auto params = params_;
  params.emplace("opt_type", GetType());
  return std::accumulate(params.cbegin(), params.cend(), std::string(""),
                         [](const std::string& sum, const ValueType& i) {
                           return sum + '\n' + i.first + '=' + i.second;
                         }).substr(1);
}
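// For example, after Update() has run with lr=0.01 and wd=0.0001 on an SGD
// optimizer, Serialize() yields newline-separated key=value pairs, sorted by
// key (std::map iteration order):
//
//   lr=0.010000
//   opt_type=sgd
//   wd=0.000100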
inline const std::vector<const char*> Optimizer::GetParamKeys_() const {
  std::vector<const char*> keys;
  for (auto& iter : params_)
    keys.push_back(iter.first.c_str());
  return keys;
}
inline const std::vector<const char*> Optimizer::GetParamValues_() const {
  std::vector<const char*> values;
  for (auto& iter : params_)
    values.push_back(iter.second.c_str());
  return values;
}
inline unsigned Optimizer::UpdateCount_(int index) {
  if (count_.count(index) == 0) {
    count_.emplace(index, begin_num_update_);
  }
  unsigned new_count = ++count_[index];
  num_update_ = std::max(num_update_, new_count);
  return new_count;
}
inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
  MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer);  // For backward compatibility
  MXNETCPP_REGISTER_OPTIMIZER(rmsprop, RMSPropOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
  MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
  auto it = cmap().find(name);
  if (it == cmap().end())
    return nullptr;
  return it->second();
}
inline int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) {
  CHECK_EQ(cmap().count(name), 0) << name << " already registered";
  cmap().emplace(name, std::move(creator));
  return 0;
}
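// Registering a custom optimizer follows the same pattern as the built-ins
// inside Find() above; `MyOptimizer` here is a hypothetical Optimizer
// subclass supplied by the user:
//
//   MXNETCPP_REGISTER_OPTIMIZER(myopt, MyOptimizer);
//   Optimizer* opt = OptimizerRegistry::Find("myopt");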
inline SGDOptimizer::SGDOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("sgd_update");
  mom_update_handle_ = op_map()->GetSymbolCreator("sgd_mom_update");
}
inline std::string SGDOptimizer::GetType() const {
  return "sgd";
}
inline SGDOptimizer::~SGDOptimizer() {
  for (auto &it : states_) {
    delete it.second;
  }
}
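// The two kernels dispatched below implement, in pseudo-NDArray form:
//   sgd_update:      weight -= lr * (grad + wd * weight)
//   sgd_mom_update:  mom = momentum * mom - lr * (grad + wd * weight)
//                    weight += mom
// The momentum path is taken only when the "momentum" parameter is set, in
// which case CreateState_ allocates a per-index state array.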
inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  if (states_[index] == nullptr) {
    MXImperativeInvoke(update_handle_, 2, inputs,
                       &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  } else {
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs,
                       &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  }
}
inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
  if (params_.count("momentum") == 0) {
    states_[index] = nullptr;
  } else {
    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
    *states_[index] = 0;
  }
}
inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
  alex_update_handle_ = op_map()->GetSymbolCreator("rmspropalex_update");
  SetParam("gamma1", 0.9f);
  SetParam("gamma2", 0.9f);
  SetParam("epsilon", 1e-8);
}
inline std::string RMSPropOptimizer::GetType() const {
  return "rmsprop";
}
inline RMSPropOptimizer::~RMSPropOptimizer() {
  for (auto &it : n_) {
    delete it.second;
  }
  for (auto &it : g_) {
    delete it.second;
  }
  for (auto &it : delta_) {
    delete it.second;
  }
}
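// Update always dispatches rmspropalex_update, the centered RMSProp variant
// from Graves, "Generating Sequences With Recurrent Neural Networks" (2013):
//   n     = gamma1 * n + (1 - gamma1) * grad^2
//   g     = gamma1 * g + (1 - gamma1) * grad
//   delta = gamma2 * delta - lr * grad / sqrt(n - g^2 + epsilon)
//   weight += delta
// update_handle_ (plain rmsprop_update) is fetched but currently unused.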
inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (n_.count(index) == 0) {
    CreateState_(index, weight);
  }

  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[5];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = n_[index]->GetHandle();
  inputs[3] = g_[index]->GetHandle();
  inputs[4] = delta_[index]->GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(alex_update_handle_, 5, inputs,
                     &num_outputs, &outputs,
                     keys.size(), keys.data(), values.data());
}
inline void RMSPropOptimizer::CreateState_(int index, NDArray weight) {
  n_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *n_[index] = 0;
  g_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *g_[index] = 0;
  delta_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *delta_[index] = 0;
}
inline AdamOptimizer::AdamOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("adam_update");
  SetParam("beta1", 0.9f);
  SetParam("beta2", 0.999f);
  SetParam("epsilon", 1e-8);
}
inline std::string AdamOptimizer::GetType() const {
  return "adam";
}
inline AdamOptimizer::~AdamOptimizer() {
  for (auto &it : mean_) {
    delete it.second;
  }
  for (auto &it : var_) {
    delete it.second;
  }
}
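// adam_update implements, in pseudo-NDArray form:
//   mean = beta1 * mean + (1 - beta1) * grad
//   var  = beta2 * var  + (1 - beta2) * grad^2
//   weight -= lr * mean / (sqrt(var) + epsilon)
// The kernel does not apply bias correction itself, so the corrected rate
// lr' = lr * sqrt(1 - beta2^t) / (1 - beta1^t) is computed on the host below.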
inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (mean_.count(index) == 0) {
    CreateState_(index, weight);
  }

  // Apply the bias correction to the learning rate before snapshotting the
  // parameter strings, so the corrected value actually reaches the kernel.
  float lr = std::stof(params_["lr"]);
  float b1 = std::stof(params_["beta1"]);
  float b2 = std::stof(params_["beta2"]);
  float t = count_[index];
  float coef1 = 1.0f - std::pow(b1, t);
  float coef2 = 1.0f - std::pow(b2, t);
  lr *= std::sqrt(coef2) / coef1;
  params_["lr"] = std::to_string(lr);

  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[4];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = mean_[index]->GetHandle();
  inputs[3] = var_[index]->GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(update_handle_, 4, inputs,
                     &num_outputs, &outputs,
                     keys.size(), keys.data(), values.data());
}
inline void AdamOptimizer::CreateState_(int index, NDArray weight) {
  mean_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *mean_[index] = 0;
  var_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *var_[index] = 0;
}
inline AdaGradOptimizer::AdaGradOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  SetParam("eps", 1e-7);
}
inline std::string AdaGradOptimizer::GetType() const {
  return "adagrad";
}
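// AdaGrad (Duchi et al., 2011) is computed with imperative NDArray ops
// rather than a fused kernel:
//   history += grad^2
//   weight  -= lr * (grad / sqrt(history + eps) + wd * weight)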
inline void AdaGradOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (history_.count(index) == 0) {
    CreateState_(index, weight);
  }

  float lr = std::stof(params_["lr"]);
  float wd = std::stof(params_["wd"]);
  float eps = std::stof(params_["eps"]);

  if (params_.count("rescale_grad") > 0) {
    grad *= std::stof(params_["rescale_grad"]);
  }
  if (params_.count("clip_gradient") > 0) {
    _clip(grad, std::stof(params_["clip_gradient"]));
  }

  auto& history = *history_[index];
  history += grad * grad;
  weight -= (grad / _sqrt(history + eps) + weight * wd) * lr;
}
inline AdaGradOptimizer::~AdaGradOptimizer() {
  for (auto& it : history_) {
    delete it.second;
  }
}
inline void AdaGradOptimizer::CreateState_(int index, NDArray weight) {
  history_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *history_[index] = 0;
}
inline AdaDeltaOptimizer::AdaDeltaOptimizer(unsigned begin_num_update)
    : Optimizer(begin_num_update) {
  SetParam("rho", 0.90f);
  SetParam("epsilon", 1e-5);
}
inline std::string AdaDeltaOptimizer::GetType() const {
  return "adadelta";
}
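// AdaDelta (Zeiler, 2012), also computed with imperative NDArray ops:
//   acc_g     = rho * acc_g + (1 - rho) * grad^2
//   delta     = sqrt(acc_delta + epsilon) / sqrt(acc_g + epsilon) * grad
//   acc_delta = rho * acc_delta + (1 - rho) * delta^2
//   weight    = (1 - wd) * weight - delta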
inline void AdaDeltaOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (acc_g_.count(index) == 0) {
    CreateState_(index, weight);
  }

  float wd = std::stof(params_["wd"]);
  float rho = std::stof(params_["rho"]);
  float epsilon = std::stof(params_["epsilon"]);

  if (params_.count("rescale_grad") > 0) {
    grad *= std::stof(params_["rescale_grad"]);
  }
  if (params_.count("clip_gradient") > 0) {
    _clip(grad, std::stof(params_["clip_gradient"]));
  }

  auto& acc_g = *acc_g_[index];
  auto& acc_delta = *acc_delta_[index];
  acc_g *= rho;
  acc_g += grad * grad * (1.0f - rho);

  auto delta = _sqrt(acc_delta + epsilon) / _sqrt(acc_g + epsilon) * grad;
  acc_delta *= rho;
  acc_delta += delta * delta * (1.0f - rho);
  weight *= 1.0f - wd;
  weight -= delta;
}
inline AdaDeltaOptimizer::~AdaDeltaOptimizer() {
  for (auto& it : acc_g_) {
    delete it.second;
  }
  for (auto& it : acc_delta_) {
    delete it.second;
  }
}
inline void AdaDeltaOptimizer::CreateState_(int index, NDArray weight) {
  acc_g_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_g_[index] = 0;
  acc_delta_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_delta_[index] = 0;
}
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_HPP_