blob: b853703c5f6b728474dcf56f2cc07ceed709df97 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file optimizer.h
* \brief definition of optimizer
* \author Chuntao Hong, Zhang Chen
*/
#ifndef MXNET_CPP_OPTIMIZER_H_
#define MXNET_CPP_OPTIMIZER_H_
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include <dmlc/strtonum.h>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Optimizer interface
 *
 * Hyper-parameters are stored as strings in params_ (see SetParam).
 * Concrete optimizers implement GetType() and Update().
 */
class Optimizer {
 public:
  /*!
   * \brief constructor
   * \param begin_num_update The initial number of updates
   */
  explicit Optimizer(unsigned begin_num_update);
  /*!
   * \brief get optimizer type
   * \return string of optimizer type
   */
  virtual std::string GetType() const = 0;
  /*!
   * \brief destructor
   */
  virtual ~Optimizer();
  /*!
   * \brief set config parameters
   *
   * The value is formatted with operator<< and the complete formatted text is
   * stored, so values whose textual form contains whitespace are kept intact
   * rather than being cut at the first space.
   * \param name name of the config parameter
   * \param value value of the config parameter
   * \return pointer to self, so calls can be chained
   */
  template <typename T>
  Optimizer* SetParam(const std::string& name, const T& value) {
    std::ostringstream ss;
    ss << value;
    params_[name] = ss.str();
    return this;
  }
  /*!
   * \brief set the lr scheduler
   *
   * Takes ownership of the scheduler and seeds it with the current "lr"
   * parameter, which must already be present in the parameters.
   * \param lrScheduler lr scheduler used for this optimizer; must not be null
   * \return pointer to self, so calls can be chained
   */
  Optimizer* SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
    CHECK(lrScheduler);
    // Use find() instead of operator[]: a missing "lr" must be reported
    // clearly, not silently inserted as "" and then rejected by stof.
    const auto lr_it = params_.find("lr");
    CHECK(lr_it != params_.end())
        << "parameter 'lr' must be set before calling SetLRScheduler";
    lrScheduler_ = std::move(lrScheduler);
    lrScheduler_->SetLR(dmlc::stof(lr_it->second));
    return this;
  }
  /*!
   * \brief Update a weight with gradient.
   * \param index the unique index for the weight.
   * \param weight the weight to update.
   * \param grad gradient for the weight.
   */
  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
  // TODO(zhangcheng-qinyinghua)
  // implement Update a list of arrays, maybe in the form of map
  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
  // grad, mx_float lr);
  /*!
   * \brief Serialize the optimizer parameters to a string.
   * \return serialization
   */
  std::string Serialize() const;

 protected:
  /*! \brief hyper-parameters as name -> string value (written by SetParam) */
  std::map<std::string, std::string> params_;
  static OpMap*& op_map();
  const std::vector<const char*> GetParamKeys_() const;
  const std::vector<const char*> GetParamValues_() const;
  // NOTE: the declaration order of the data members below is kept unchanged;
  // the constructor definition (outside this header) may depend on it.
  std::map<int, unsigned> count_;
  unsigned begin_num_update_, num_update_;
  unsigned UpdateCount_(int index);
  float GetLR_(int index);
  float GetWD_(int index);
  virtual void CreateState_(int index, NDArray weight);
  /*! \brief scheduler installed by SetLRScheduler; null until one is set */
  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
};
/*! \brief Factory callable that returns a newly created Optimizer. */
using OptimizerCreator = std::function<Optimizer*()>;
/*!
 * \brief Name-based registry of optimizer factories.
 *
 * The class cannot be instantiated: its constructor and destructor are
 * deleted, and every member function is static.
 */
class OptimizerRegistry {
 public:
  OptimizerRegistry() = delete;
  ~OptimizerRegistry() = delete;

  /*!
   * \brief Obtain an optimizer for the given registered name.
   * \param name registered optimizer name
   * \return optimizer pointer; ownership and the result for an unknown name
   *         are defined by the implementation (not visible in this header)
   */
  static Optimizer* Find(const std::string& name);
  /*!
   * \brief Register a factory under the given name.
   * \param name name to register the optimizer under
   * \param creator factory used to construct the optimizer
   * \return int result whose meaning is defined by the implementation
   */
  static int __REGISTER__(const std::string& name, OptimizerCreator creator);

 private:
  /*! \brief name -> factory map (presumably the storage behind Find/__REGISTER__) */
  static std::map<std::string, OptimizerCreator>& cmap();
};
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \
OptimizerRegistry::__REGISTER__(#Name, []() { return new OptimizerType(); })
class SGDOptimizer : public Optimizer {
public:
explicit SGDOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~SGDOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> states_;
AtomicSymbolCreator update_handle_;
AtomicSymbolCreator mom_update_handle_;
};
class SignumOptimizer : public Optimizer {
public:
explicit SignumOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~SignumOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> states_;
AtomicSymbolCreator update_handle_;
AtomicSymbolCreator mom_update_handle_;
};
class RMSPropOptimizer : public Optimizer {
public:
explicit RMSPropOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~RMSPropOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> n_, g_, delta_;
AtomicSymbolCreator update_handle_;
AtomicSymbolCreator alex_update_handle_;
};
class AdamOptimizer : public Optimizer {
public:
explicit AdamOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~AdamOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> mean_;
std::map<int, NDArray*> var_;
AtomicSymbolCreator update_handle_;
};
class AdaGradOptimizer : public Optimizer {
public:
explicit AdaGradOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~AdaGradOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> history_;
};
class AdaDeltaOptimizer : public Optimizer {
public:
explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~AdaDeltaOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> acc_g_, acc_delta_;
};
} // namespace cpp
} // namespace mxnet
#endif // MXNET_CPP_OPTIMIZER_H_