| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file optimizer.h |
| * \brief definition of optimizer |
| * \author Chuntao Hong, Zhang Chen |
| */ |
| |
| #ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ |
| #define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ |
| |
| #include <map> |
| #include <vector> |
| #include <string> |
| #include <memory> |
| #include <functional> |
| #include "mxnet-cpp/base.h" |
| #include "dmlc/logging.h" |
| #include "mxnet-cpp/ndarray.h" |
| #include "mxnet-cpp/op_map.h" |
| |
| namespace mxnet { |
| namespace cpp { |
| |
| /*! |
| * \brief Optimizer interface |
| */ |
| class Optimizer { |
| public: |
| /*! |
| * \brief constructor |
| * \param beign_num_update The initial number of updates |
| */ |
| explicit Optimizer(unsigned begin_num_update); |
| /*! |
| * \brief get optimizer type |
| * \return string of optimizer type |
| */ |
| virtual std::string GetType() const = 0; |
| /*! |
| * \brief destructor |
| */ |
| virtual ~Optimizer(); |
| /*! |
| * \brief set config parameters |
| * \param name name of the config parameter |
| * \param value value of the config parameter |
| * \return reference of self |
| */ |
| template <typename T> |
| Optimizer *SetParam(const std::string &name, const T &value) { |
| std::string value_str; |
| std::stringstream ss; |
| ss << value; |
| ss >> value_str; |
| |
| params_[name] = value_str; |
| return this; |
| } |
| /*! |
| * \brief Update a weight with gradient. |
| * \param index the unique index for the weight. |
| * \param weight the weight to update. |
| * \param grad gradient for the weight. |
| * \param lr learning rate. |
| * \param wd weight decay. |
| */ |
| void Update(int index, NDArray weight, NDArray grad, mx_float lr, |
| mx_float wd); |
| /*! |
| * \brief Update a weight with gradient. |
| * \param index the unique index for the weight. |
| * \param weight the weight to update. |
| * \param grad gradient for the weight. |
| */ |
| virtual void Update(int index, NDArray weight, NDArray grad) = 0; |
| // TODO(zhangcheng-qinyinghua) |
| // implement Update a list of arrays, maybe in the form of map |
| // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray> |
| // grad, mx_float lr); |
| |
| /*! |
| * \brief Serialize the optimizer parameters to a string. |
| * \return serialization |
| */ |
| std::string Serialize() const; |
| |
| protected: |
| std::map<std::string, std::string> params_; |
| static OpMap*& op_map(); |
| const std::vector<const char*> GetParamKeys_() const; |
| const std::vector<const char*> GetParamValues_() const; |
| std::map<int, unsigned> count_; |
| unsigned begin_num_update_, num_update_; |
| unsigned UpdateCount_(int index); |
| virtual void CreateState_(int index, NDArray weight); |
| }; |
| |
| typedef std::function<Optimizer*()> OptimizerCreator; |
| |
| class OptimizerRegistry { |
| public: |
| static Optimizer* Find(const std::string& name); |
| static int __REGISTER__(const std::string& name, OptimizerCreator creator); |
| private: |
| static std::map<std::string, OptimizerCreator>& cmap(); |
| OptimizerRegistry() = delete; |
| ~OptimizerRegistry() = delete; |
| }; |
| |
| #define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \ |
| static int __make_ ## OptimizerType ## _ ## Name ## __ = \ |
| OptimizerRegistry::__REGISTER__(#Name, [](){return new OptimizerType();}) |
| |
| class SGDOptimizer : public Optimizer { |
| public: |
| explicit SGDOptimizer(unsigned begin_num_update = 0); |
| std::string GetType() const override; |
| void Update(int index, NDArray weight, NDArray grad) override; |
| private: |
| virtual ~SGDOptimizer(); |
| void CreateState_(int index, NDArray weight) override; |
| std::map<int, NDArray*> states_; |
| AtomicSymbolCreator update_handle_; |
| AtomicSymbolCreator mom_update_handle_; |
| }; |
| |
| class RMSPropOptimizer : public Optimizer { |
| public: |
| explicit RMSPropOptimizer(unsigned begin_num_update = 0); |
| std::string GetType() const override; |
| void Update(int index, NDArray weight, NDArray grad) override; |
| private: |
| virtual ~RMSPropOptimizer(); |
| void CreateState_(int index, NDArray weight) override; |
| std::map<int, NDArray*> n_, g_, delta_; |
| AtomicSymbolCreator update_handle_; |
| AtomicSymbolCreator alex_update_handle_; |
| }; |
| |
| class AdamOptimizer : public Optimizer { |
| public: |
| explicit AdamOptimizer(unsigned begin_num_update = 0); |
| std::string GetType() const override; |
| void Update(int index, NDArray weight, NDArray grad) override; |
| private: |
| virtual ~AdamOptimizer(); |
| void CreateState_(int index, NDArray weight) override; |
| std::map<int, NDArray*> mean_; |
| std::map<int, NDArray*> var_; |
| AtomicSymbolCreator update_handle_; |
| }; |
| |
| class AdaGradOptimizer : public Optimizer { |
| public: |
| explicit AdaGradOptimizer(unsigned begin_num_update = 0); |
| std::string GetType() const override; |
| void Update(int index, NDArray weight, NDArray grad) override; |
| private: |
| virtual ~AdaGradOptimizer(); |
| void CreateState_(int index, NDArray weight) override; |
| std::map<int, NDArray*> history_; |
| }; |
| |
| class AdaDeltaOptimizer : public Optimizer { |
| public: |
| explicit AdaDeltaOptimizer(unsigned begin_num_update = 0); |
| std::string GetType() const override; |
| void Update(int index, NDArray weight, NDArray grad) override; |
| private: |
| virtual ~AdaDeltaOptimizer(); |
| void CreateState_(int index, NDArray weight) override; |
| std::map<int, NDArray*> acc_g_, acc_delta_; |
| }; |
| |
| |
| } // namespace cpp |
| } // namespace mxnet |
| |
| #endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPTIMIZER_H_ |