| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file optimizer_op-inl.h |
| * \brief Optimizer operators |
| * \author Junyuan Xie |
| */ |
| #ifndef MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |
| #define MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |
| #include <dmlc/parameter.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/operator_util.h> |
| #include <mxnet/op_attr_types.h> |
| #include <mshadow/base.h> |
| #include <nnvm/op.h> |
| #include <nnvm/op_attr_types.h> |
| #include <vector> |
| #include "./operator_common.h" |
| #include "./mshadow_op.h" |
| #include "./elemwise_op_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| struct SGDParam : public dmlc::Parameter<SGDParam> { |
| float lr; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(SGDParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay (L2 regularization coefficient)."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient as grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip gradient to the range " |
| "[-clip_gradient, clip_gradient], i.e. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient). " |
| "Otherwise gradient clipping is turned off."); |
| } |
| }; |
| |
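| // SGD update computed by SGDUpdate below (the clip() term applies only when |
| // clip_gradient >= 0): |
| //   grad_t   = clip(rescale_grad * grad, clip_gradient) |
| //   weight_t = (1 - lr * wd) * weight_{t-1} - lr * grad_t |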
| template<typename xpu> |
| inline void SGDUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| if (param.clip_gradient >= 0.0f) { |
| Assign(out, req[0], |
| scalar<DType>(1.f-param.lr*param.wd)*weight |
| - scalar<DType>(param.lr) |
| * F<clip>(scalar<DType>(param.rescale_grad)*grad, |
| DType(param.clip_gradient))); |
| } else { |
| Assign(out, req[0], |
| scalar<DType>(1.f-param.lr*param.wd)*weight |
| - scalar<DType>(param.lr*param.rescale_grad)*grad); |
| } |
| }); |
| } |
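| |
| // The kernels in this header are plain FCompute functions; they are meant to |
| // be attached to NNVM operators in a separate .cc/.cu file. A minimal |
| // registration sketch for SGDUpdate is shown here as an illustration only; |
| // the operator name and argument descriptions are assumptions, not |
| // necessarily what the actual registration uses: |
| // |
| //   NNVM_REGISTER_OP(sgd_update) |
| //   .set_num_inputs(2) |
| //   .set_num_outputs(1) |
| //   .set_attr_parser(ParamParser<SGDParam>) |
| //   .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) |
| //   .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) |
| //   .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>) |
| //   .add_argument("weight", "NDArray", "weight") |
| //   .add_argument("grad", "NDArray", "gradient") |
| //   .add_arguments(SGDParam::__FIELDS__()); |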
| |
| struct SGDMomParam : public dmlc::Parameter<SGDMomParam> { |
| float lr; |
| float momentum; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(SGDMomParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate."); |
| DMLC_DECLARE_FIELD(momentum) |
| .set_default(0.0f) |
| .describe("Momentum coefficient applied to the previous update."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay (L2 regularization coefficient)."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient as grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip gradient to the range " |
| "[-clip_gradient, clip_gradient], i.e. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient). " |
| "Otherwise gradient clipping is turned off."); |
| } |
| }; |
| |
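| // SGD-with-momentum update computed by SGDMomUpdate below (the momentum |
| // state, inputs[2], is updated in place; the clip() term applies only when |
| // clip_gradient >= 0): |
| //   grad_t   = clip(rescale_grad * grad, clip_gradient) |
| //   mom_t    = momentum * mom_{t-1} - lr * wd * weight_{t-1} - lr * grad_t |
| //   weight_t = weight_{t-1} + mom_t |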
| template<typename xpu> |
| inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| if (param.clip_gradient >= 0.0f) { |
| mom = scalar<DType>(param.momentum)*mom |
| - scalar<DType>(param.lr*param.wd)*weight |
| - scalar<DType>(param.lr) |
| * F<clip>(scalar<DType>(param.rescale_grad)*grad, |
| DType(param.clip_gradient)); |
| } else { |
| mom = scalar<DType>(param.momentum)*mom |
| - scalar<DType>(param.lr*param.wd)*weight |
| - scalar<DType>(param.lr*param.rescale_grad)*grad; |
| } |
| Assign(out, req[0], weight + mom); |
| }); |
| } |
| |
| struct AdamParam : public dmlc::Parameter<AdamParam> { |
| float lr; |
| float beta1; |
| float beta2; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(AdamParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate."); |
| DMLC_DECLARE_FIELD(beta1) |
| .set_default(0.9f) |
| .describe("Exponential decay rate for the first moment (mean) estimate."); |
| DMLC_DECLARE_FIELD(beta2) |
| .set_default(0.999f) |
| .describe("Exponential decay rate for the second moment (variance) estimate."); |
| DMLC_DECLARE_FIELD(epsilon) |
| .set_default(1e-8f) |
| .describe("Small constant added to the denominator for numerical stability."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay (L2 regularization coefficient)."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient as grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip gradient to the range " |
| "[-clip_gradient, clip_gradient], i.e. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient). " |
| "Otherwise gradient clipping is turned off."); |
| } |
| }; |
| |
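| // Adam update computed by AdamUpdate below. Unlike the original paper, this |
| // kernel applies no bias correction to mean/var; grad (inputs[1]), mean |
| // (inputs[2]) and var (inputs[3]) are updated in place: |
| //   grad_t   = rescale_grad * grad + wd * weight_{t-1} |
| //   mean_t   = beta1 * mean_{t-1} + (1 - beta1) * clip(grad_t, clip_gradient) |
| //   var_t    = beta2 * var_{t-1}  + (1 - beta2) * clip(grad_t, clip_gradient)^2 |
| //   weight_t = weight_{t-1} - lr * mean_t / (sqrt(var_t) + epsilon) |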
| template<typename xpu> |
| inline void AdamUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| grad = scalar<DType>(param.rescale_grad) * grad + |
| scalar<DType>(param.wd) * weight; |
| |
| if (param.clip_gradient >= 0.0f) { |
| mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * |
| F<clip>(grad, DType(param.clip_gradient)); |
| var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>( |
| F<clip>(grad, DType(param.clip_gradient))); |
| } else { |
| mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad; |
| var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad); |
| } |
| Assign(out, req[0], |
| weight - |
| scalar<DType>(param.lr) * mean / |
| (F<square_root>(var) + scalar<DType>(param.epsilon))); |
| }); |
| } |
| |
| // This RMSProp code follows the version in |
| // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) |
| // by Alex Graves, 2013. |
| struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> { |
| float lr; |
| float gamma1; |
| float gamma2; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| float clip_weights; |
| DMLC_DECLARE_PARAMETER(RMSPropAlexParam) { |
| DMLC_DECLARE_FIELD(lr).describe("Learning rate."); |
| DMLC_DECLARE_FIELD(gamma1).set_default(0.95f) |
| .describe("Decay rate for the moving averages of the gradient " |
| "and the squared gradient."); |
| DMLC_DECLARE_FIELD(gamma2).set_default(0.9f) |
| .describe("Momentum factor applied to the previous update (delta) term."); |
| DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f) |
| .describe("Small constant added to the denominator for numerical stability."); |
| DMLC_DECLARE_FIELD(wd).set_default(0.0f) |
| .describe("Weight decay (L2 regularization coefficient)."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient as grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip gradient to the range " |
| "[-clip_gradient, clip_gradient], i.e. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient). " |
| "Otherwise gradient clipping is turned off."); |
| DMLC_DECLARE_FIELD(clip_weights) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip updated weights to the range " |
| "[-clip_weights, clip_weights], i.e. " |
| "weights = max(min(weights, clip_weights), -clip_weights). " |
| "Otherwise weight clipping is turned off."); |
| } |
| }; |
| |
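| // RMSPropAlex update computed by RMSPropAlexUpdate below (the gradient and |
| // the running statistics n, g, delta, i.e. inputs[1..4], are updated in |
| // place; each clip() term applies only when the corresponding clip_* |
| // parameter >= 0): |
| //   grad_t   = rescale_grad * grad + wd * weight_{t-1} |
| //   n_t      = gamma1 * n_{t-1} + (1 - gamma1) * clip(grad_t, clip_gradient)^2 |
| //   g_t      = gamma1 * g_{t-1} + (1 - gamma1) * clip(grad_t, clip_gradient) |
| //   delta_t  = gamma2 * delta_{t-1} |
| //              - lr * clip(grad_t, clip_gradient) / (sqrt(n_t - g_t^2) + epsilon) |
| //   weight_t = clip(weight_{t-1} + delta_t, clip_weights) |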
| template <typename xpu> |
| inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| const RMSPropAlexParam ¶m = nnvm::get<RMSPropAlexParam>(attrs.parsed); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> state_g = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> delta = inputs[4].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| grad = scalar<DType>(param.rescale_grad) * grad + |
| scalar<DType>(param.wd) * weight; |
| |
| if (param.clip_gradient >= 0.0f) { |
| state_n = scalar<DType>(1.f - param.gamma1) * |
| F<clip>(grad, DType(param.clip_gradient)) * |
| F<clip>(grad, DType(param.clip_gradient)) + |
| scalar<DType>(param.gamma1) * state_n; |
| state_g = scalar<DType>(1.f - param.gamma1) * |
| F<clip>(grad, DType(param.clip_gradient)) + |
| scalar<DType>(param.gamma1) * state_g; |
| delta = scalar<DType>(param.gamma2) * delta - |
| scalar<DType>(param.lr) * |
| (F<clip>(grad, DType(param.clip_gradient)) / |
| (F<square_root>(state_n - state_g * state_g) + |
| scalar<DType>(param.epsilon))); |
| } else { |
| state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) + |
| scalar<DType>(param.gamma1) * state_n; |
| state_g = scalar<DType>(1.f - param.gamma1) * grad + |
| scalar<DType>(param.gamma1) * state_g; |
| delta = scalar<DType>(param.gamma2) * delta - |
| scalar<DType>(param.lr) * |
| (grad / (F<square_root>(state_n - state_g * state_g) + |
| scalar<DType>(param.epsilon))); |
| } |
| |
| if (param.clip_weights >= 0.0f) { |
| Assign(out, req[0], F<clip>(weight + delta, DType(param.clip_weights))); |
| } else { |
| Assign(out, req[0], weight + delta); |
| } |
| }); |
| } |
| |
| // This RMSProp code follows the version in |
| // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf |
| // by Tieleman & Hinton, 2012 |
| struct RMSPropParam : public dmlc::Parameter<RMSPropParam> { |
| float lr; |
| float gamma1; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| float clip_weights; |
| DMLC_DECLARE_PARAMETER(RMSPropParam) { |
| DMLC_DECLARE_FIELD(lr).describe("Learning rate."); |
| DMLC_DECLARE_FIELD(gamma1).set_default(0.95f) |
| .describe("Decay rate for the moving average of the squared gradient."); |
| DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f) |
| .describe("Small constant added to the denominator for numerical stability."); |
| DMLC_DECLARE_FIELD(wd).set_default(0.0f) |
| .describe("Weight decay (L2 regularization coefficient)."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient as grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip gradient to the range " |
| "[-clip_gradient, clip_gradient], i.e. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient). " |
| "Otherwise gradient clipping is turned off."); |
| DMLC_DECLARE_FIELD(clip_weights) |
| .set_default(-1.0f) |
| .describe("If greater than 0, clip updated weights to the range " |
| "[-clip_weights, clip_weights], i.e. " |
| "weights = max(min(weights, clip_weights), -clip_weights). " |
| "Otherwise weight clipping is turned off."); |
| } |
| }; |
| |
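| // RMSProp update computed by RMSPropUpdate below (the gradient and the |
| // squared-gradient running average, inputs[1] and inputs[2], are updated in |
| // place; each clip() term applies only when the corresponding clip_* |
| // parameter >= 0): |
| //   grad_t   = rescale_grad * grad + wd * weight_{t-1} |
| //   n_t      = gamma1 * n_{t-1} + (1 - gamma1) * clip(grad_t, clip_gradient)^2 |
| //   weight_t = clip(weight_{t-1} |
| //                   - lr * clip(grad_t, clip_gradient) / (sqrt(n_t) + epsilon), |
| //              clip_weights) |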
| template <typename xpu> |
| inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| const RMSPropParam ¶m = nnvm::get<RMSPropParam>(attrs.parsed); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| grad = scalar<DType>(param.rescale_grad) * grad + |
| scalar<DType>(param.wd) * weight; |
| |
| if (param.clip_gradient >= 0.0f) { |
| state_n = scalar<DType>(1.f - param.gamma1) * |
| F<clip>(grad, DType(param.clip_gradient)) * |
| F<clip>(grad, DType(param.clip_gradient)) + |
| scalar<DType>(param.gamma1) * state_n; |
| if (param.clip_weights >= 0.0f) { |
| Assign(out, req[0], |
| F<clip>(weight - |
| scalar<DType>(param.lr) * |
| (F<clip>(grad, DType(param.clip_gradient)) / |
| (F<square_root>(state_n) + |
| scalar<DType>(param.epsilon))), |
| DType(param.clip_weights))); |
| } else { |
| Assign(out, req[0], weight - |
| scalar<DType>(param.lr) * |
| (F<clip>(grad, DType(param.clip_gradient)) / |
| (F<square_root>(state_n) + |
| scalar<DType>(param.epsilon)))); |
| } |
| } else { |
| state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) + |
| scalar<DType>(param.gamma1) * state_n; |
| if (param.clip_weights >= 0.0f) { |
| Assign(out, req[0], |
| F<clip>(weight - |
| scalar<DType>(param.lr) * |
| (grad / (F<square_root>(state_n) + |
| scalar<DType>(param.epsilon))), |
| DType(param.clip_weights))); |
| } else { |
| Assign(out, req[0], weight - |
| scalar<DType>(param.lr) * |
| (grad / (F<square_root>(state_n) + |
| scalar<DType>(param.epsilon)))); |
| } |
| } |
| }); |
| } |
| |
| } // namespace op |
| } // namespace mxnet |
| |
| |
| #endif // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |