| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file optimizer_op-inl.h |
| * \brief Optimizer operators |
| * \author Junyuan Xie |
| */ |
| #ifndef MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |
| #define MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |
| #include <dmlc/parameter.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/operator_util.h> |
| #include <mxnet/op_attr_types.h> |
| #include <mshadow/base.h> |
| #include <nnvm/op.h> |
| #include <nnvm/op_attr_types.h> |
| #include <vector> |
| #include "./operator_common.h" |
| #include "./mshadow_op.h" |
| #include "./elemwise_op_common.h" |
| #include "mxnet_op.h" |
| #include "./tensor/init_op.h" |
| #include "./tensor/util/tensor_util-inl.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| /* |
| * \brief log message for optimizers with lazy update. |
| */ |
| inline void LogLazyUpdate() { |
| common::LogOnce("Optimizer with lazy_update = True detected. " |
| "Be aware that lazy update with row_sparse gradient is different from " |
| "standard update, and may lead to different empirical results. See " |
| "https://mxnet.apache.org/api/python/optimization/optimization.html " |
| "for more details."); |
| } |
| |
| struct SGDParam : public dmlc::Parameter<SGDParam> { |
| float lr; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| bool lazy_update; |
| DMLC_DECLARE_PARAMETER(SGDParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(lazy_update) |
| .set_default(true) |
| .describe("If true, lazy updates are applied if gradient's stype is row_sparse."); |
| } |
| }; |
| |
| struct MultiSGDParam : public dmlc::Parameter<MultiSGDParam> { |
| mxnet::Tuple<float> lrs; |
| mxnet::Tuple<float> wds; |
| float rescale_grad; |
| float clip_gradient; |
| int num_weights; |
| DMLC_DECLARE_PARAMETER(MultiSGDParam) { |
| DMLC_DECLARE_FIELD(lrs) |
| .describe("Learning rates."); |
| DMLC_DECLARE_FIELD(wds) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(num_weights) |
| .set_default(1) |
| .describe("Number of updated weights."); |
| } |
| }; |
| |
| struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> { |
| mxnet::Tuple<float> lrs; |
| mxnet::Tuple<float> wds; |
| float momentum; |
| float rescale_grad; |
| float clip_gradient; |
| int num_weights; |
| DMLC_DECLARE_PARAMETER(MultiSGDMomParam) { |
| DMLC_DECLARE_FIELD(lrs) |
| .describe("Learning rates."); |
| DMLC_DECLARE_FIELD(wds) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(momentum) |
| .set_default(0.0f) |
| .describe("The decay rate of momentum estimates at each epoch."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(num_weights) |
| .set_default(1) |
| .describe("Number of updated weights."); |
| } |
| }; |
| |
| |
| template<typename ParamType, int input_stride> |
| inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector *in_attrs, |
| mxnet::ShapeVector *out_attrs) { |
| const ParamType& param = dmlc::get<ParamType>(attrs.parsed); |
| CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); |
| CHECK_EQ(out_attrs->size(), param.num_weights); |
| |
| bool all_inferred = true; |
| auto& input_shapes = *in_attrs; |
| auto& output_shapes = *out_attrs; |
| // Learning rates |
| CHECK_EQ(param.lrs.ndim(), param.num_weights) |
| << "Number of learning rates is inconsistent with num_weights " |
| << "parameter passed. Expected number of learning rates: " |
| << param.num_weights << ", and got " << param.lrs.ndim(); |
| // Weight decays |
| CHECK_EQ(param.wds.ndim(), param.num_weights) |
| << "Number of weight decays is inconsistent with num_weights " |
| << "parameter passed. Expected number of weight decays: " |
| << param.num_weights << ", and got " << param.wds.ndim(); |
| // Weights and gradients |
| for (int i = 0; i < param.num_weights; ++i) { |
| mxnet::ShapeVector input_vec; |
| mxnet::ShapeVector output_vec({output_shapes[i]}); |
| for (int j = 0; j < input_stride; ++j) { |
| input_vec.push_back(input_shapes[i * input_stride + j]); |
| } |
| all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec); |
| } |
| return all_inferred; |
| } |
| |
| template <typename ParamType, int input_stride, int num_fp32_inputs> |
| inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_attrs, |
| std::vector<int> *out_attrs) { |
| const ParamType& param = dmlc::get<ParamType>(attrs.parsed); |
| CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); |
| CHECK_EQ(out_attrs->size(), param.num_weights); |
| |
| bool all_inferred = true; |
| auto& input_types = *in_attrs; |
| auto& output_types = *out_attrs; |
| // Weights and gradients |
| for (int i = 0; i < param.num_weights; ++i) { |
| std::vector<int> input_vec; |
| std::vector<int> output_vec({output_types[i]}); |
| for (int j = 0; j < input_stride - num_fp32_inputs; ++j) { |
| input_vec.push_back(input_types[i * input_stride + j]); |
| } |
| all_inferred = all_inferred && |
| ElemwiseType<input_stride - num_fp32_inputs, 1>(attrs, &input_vec, &output_vec); |
| } |
| // master copies of weights |
| for (int i = 0; i < param.num_weights; ++i) { |
| for (int j = 0; j < num_fp32_inputs; ++j) { |
| TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32); |
| } |
| } |
| return all_inferred; |
| } |
| |
/*!
 * \brief Plain bundle of pointers and hyper-parameters describing up to N tensors
 *        for a single multi-tensor SGD kernel call.
 * \tparam DType element type of weights, gradients and outputs
 * \tparam MPDType element type of momentum, 32-bit weights and scalar hyper-parameters
 */
template<typename DType, typename MPDType>
struct MultiSGDKernelParam {
  /*! \brief capacity of every per-tensor array below */
  static const int N = 60;
  /*! \brief number of array slots actually in use */
  int count;
  /*! \brief largest entry of sizes[0..count) */
  size_t max_size;
  /*! \brief element count of each tensor */
  size_t sizes[N];
  /*! \brief weight data of each tensor */
  DType* weights[N];
  /*! \brief gradient data of each tensor */
  DType* grads[N];
  /*! \brief momentum data of each tensor */
  MPDType* mom[N];
  /*! \brief MPDType copy of the weights of each tensor */
  MPDType* weights32[N];
  /*! \brief output data of each tensor */
  DType* out_data[N];
  /*! \brief learning rate of each tensor */
  MPDType lrs[N];
  /*! \brief weight decay of each tensor */
  MPDType wds[N];
  /*! \brief gradient clipping bound shared by all tensors */
  MPDType clip_gradient;
  /*! \brief gradient rescaling factor shared by all tensors */
  MPDType rescale_grad;
  /*! \brief momentum factor shared by all tensors */
  MPDType momentum;
};
| |
| template <typename MPDType, bool has_momentum, bool has_mixed_precision> |
| struct MultiSGDKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, const MultiSGDKernelParam<DType, MPDType>& param, |
| const OpReqType req) { |
| for (int index = 0; index < param.count; ++index) { |
| if ((size_t)i < param.sizes[index]) { |
| MPDType w = has_mixed_precision ? param.weights32[index][i] : |
| MPDType(param.weights[index][i]); |
| MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0); |
| if (param.clip_gradient >= 0.0f) { |
| mom = param.momentum*mom |
| - param.lrs[index]*param.wds[index]*w |
| - param.lrs[index] |
| *mshadow_op::clip::Map(param.rescale_grad * |
| static_cast<MPDType>(param.grads[index][i]), |
| param.clip_gradient); |
| } else { |
| mom = param.momentum*mom |
| - param.lrs[index]*param.wds[index]*w |
| - param.lrs[index]*param.rescale_grad*static_cast<MPDType>(param.grads[index][i]); |
| } |
| if (has_momentum) { |
| param.mom[index][i] = mom; |
| } |
| w = w + mom; |
| if (has_mixed_precision) { |
| param.weights32[index][i] = w; |
| } |
| KERNEL_ASSIGN(param.out_data[index][i], req, w); |
| } |
| } |
| } |
| }; |
| |
| template<typename xpu, |
| typename DType, |
| typename MPDType, |
| typename ParamType = MultiSGDParam, |
| int input_stride = 2> |
| MultiSGDKernelParam<DType, MPDType> FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const ParamType& p = nnvm::get<ParamType>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MultiSGDKernelParam<DType, MPDType> param; |
| param.clip_gradient = p.clip_gradient; |
| param.rescale_grad = p.rescale_grad; |
| param.momentum = 0; |
| param.count = p.num_weights; |
| param.max_size = 0; |
| for (int i = 0; i < param.count; ++i) { |
| param.sizes[i] = inputs[i * input_stride].shape_.Size(); |
| if (param.max_size < param.sizes[i]) { |
| param.max_size = param.sizes[i]; |
| } |
| param.weights[i] = inputs[i * input_stride].FlatTo2D<xpu, DType>(s).dptr_; |
| param.grads[i] = inputs[i * input_stride + 1].FlatTo2D<xpu, DType>(s).dptr_; |
| // if mixed precision, then the last input in a set |
| // is 32-bit master copy of the weights |
| if (!std::is_same<DType, MPDType>::value) { |
| param.weights32[i] = inputs[i * input_stride + input_stride - 1] |
| .FlatTo2D<xpu, MPDType>(s).dptr_; |
| } |
| param.out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_; |
| param.lrs[i] = p.lrs[i]; |
| param.wds[i] = p.wds[i]; |
| } |
| |
| return param; |
| } |
| |
| |
| template<typename xpu, |
| typename DType, |
| typename MPDType, |
| int input_stride = 3> |
| MultiSGDKernelParam<DType, MPDType> FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const MultiSGDMomParam& p = nnvm::get<MultiSGDMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MultiSGDKernelParam<DType, MPDType> param = |
| FillMultiSGDKernelParam<xpu, |
| DType, |
| MPDType, |
| MultiSGDMomParam, |
| input_stride>(attrs, ctx, inputs, outputs); |
| param.momentum = p.momentum; |
| for (int i = 0; i < param.count; ++i) { |
| param.mom[i] = inputs[i * input_stride + 2].FlatTo2D<xpu, MPDType>(s).dptr_; |
| } |
| |
| return param; |
| } |
| |
/*! \brief Type chooser whose nested type is the template argument itself. */
template<typename T>
struct type_identity {
  using type = T;
};
| |
/*! \brief Type chooser whose nested type is float for every template argument. */
template<typename T>
struct single_precision {
  using type = float;
};
| |
| template<typename xpu, template<typename> class MPTypeChooser, int input_stride> |
| inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| using MPDType = typename MPTypeChooser<DType>::type; |
| MultiSGDKernelParam<DType, MPDType> param = |
| FillMultiSGDKernelParam<xpu, |
| DType, |
| MPDType, |
| MultiSGDParam, |
| input_stride>(attrs, ctx, inputs, outputs); |
| Kernel<MultiSGDKernel<MPDType, |
| false, |
| !std::is_same<DType, MPDType>::value>, |
| xpu>::Launch(s, param.max_size, param, req[0]); |
| }); |
| } |
| |
| template<typename xpu, template<typename> class MPTypeChooser, int input_stride> |
| inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| using MPDType = typename MPTypeChooser<DType>::type; |
| MultiSGDKernelParam<DType, MPDType> param = |
| FillMultiSGDMomKernelParam<xpu, |
| DType, |
| MPDType, |
| input_stride>(attrs, ctx, inputs, outputs); |
| Kernel<MultiSGDKernel<MPDType, |
| true, |
| !std::is_same<DType, MPDType>::value>, |
| xpu>::Launch(s, param.max_size, param, req[0]); |
| }); |
| } |
| |
| struct SGDKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data, |
| const DType* grad_data, const DType param_clip_gradient, |
| const DType param_lr, const DType param_wd, const DType param_rescale_grad, |
| const OpReqType req) { |
| if (param_clip_gradient >= 0.0f) { |
| KERNEL_ASSIGN(out_data[i], req, |
| (1.f-param_lr*param_wd)*weight_data[i] |
| - (param_lr) |
| * mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient)); |
| } else { |
| KERNEL_ASSIGN(out_data[i], req, |
| (1.f-param_lr*param_wd)*weight_data[i] |
| - (param_lr*param_rescale_grad)*grad_data[i]); |
| } |
| } |
| }; |
| |
| template<typename xpu> |
| inline void SGDUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_, |
| grad.dptr_, static_cast<DType>(param.clip_gradient), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad), req[0]); |
| }); |
| } |
| |
| /*! \brief kernel for sparse sgd |
| */ |
| template<int req, typename xpu> |
| struct SGDDnsRspKernel; |
| |
| template<int req> |
| struct SGDDnsRspKernel<req, gpu> { |
| // DType is the output data type |
| // IType is row sparse idx type |
| // i is the ith element in row sparse gradient |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType* out, const DType* weight, |
| const IType* grad_idx, const DType *grad_val, |
| const DType clip_gradient, const DType lr, |
| const DType wd, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t row_id = i / row_length; |
| const dim_t col_id = i % row_length; |
| const dim_t row_offset = grad_idx[row_id] * row_length; |
| const dim_t data_i = row_offset + col_id; |
| if (clip_gradient >= 0.0f) { |
| KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - |
| (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[i], clip_gradient)); |
| } else { |
| KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - |
| (lr * rescale_grad) * grad_val[i]); |
| } |
| } |
| }; |
| |
| /*! \brief kernel for sparse sgd |
| */ |
| template<int req> |
| struct SGDDnsRspKernel<req, cpu> { |
| // DType is the output data type |
| // IType is row sparse idx type |
| // i is the ith row in row sparse gradient |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType* out, const DType* weight, |
| const IType* grad_idx, const DType *grad_val, |
| const DType clip_gradient, const DType lr, |
| const DType wd, const DType rescale_grad) { |
| for (index_t j = 0; j < row_length; j++) { |
| index_t data_i = grad_idx[i] * row_length + j; |
| index_t grad_i = i * row_length + j; |
| if (clip_gradient >= 0.0f) { |
| KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - |
| (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); |
| } else { |
| KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - |
| (lr * rescale_grad) * grad_val[grad_i]); |
| } |
| } |
| } |
| }; |
| |
| /* |
| * \brief SGD update implementation for dense weight and row_sparse grad. |
| * Both standard update and lazy update are supported. |
| */ |
| template<typename xpu> |
| inline void SGDUpdateDnsRspImpl(const SGDParam& param, |
| const OpContext &ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const OpReqType& req, |
| TBlob *out) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mshadow_op; |
| using namespace mxnet_op; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(grad.storage_type(), kRowSparseStorage); |
| // if gradients are zeros, no weights are updated |
| if (req == kNullOp) return; |
| CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update"; |
| CHECK_GT(weight.shape_.Size(), 0); |
| |
| MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, req_type, { |
| DType* weight_data = weight.dptr<DType>(); |
| float wd = param.wd; |
| // apply standard weight decay if not lazy update |
| if (!param.lazy_update) { |
| Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(), |
| weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd)); |
| wd = 0; |
| } |
| if (!grad.storage_initialized()) return; |
| const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>(); |
| const DType* grad_val = grad.data().dptr<DType>(); |
| const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0]; |
| const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); |
| size_t num_threads = num_rows; |
| if (std::is_same<xpu, gpu>::value) { |
| num_threads = num_rows * row_length; |
| } |
| Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length, |
| out->dptr<DType>(), weight_data, grad_idx, grad_val, |
| static_cast<DType>(param.clip_gradient), |
| static_cast<DType>(param.lr), static_cast<DType>(wd), |
| static_cast<DType>(param.rescale_grad)); |
| }); |
| }); |
| }); |
| } |
| |
| /* |
| * \brief SGD update implementation for row_sparse grad. |
| * Both standard update and lazy update are supported. |
| */ |
| template<typename xpu> |
| inline void SGDUpdateRspImpl(const SGDParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const OpReqType& req, |
| NDArray *out) { |
| CheckAllRowsPresent(weight, "SGDUpdate", "weights"); |
| // reuse dns rsp implementation when storage_shape == shape |
| TBlob out_blob = out->data(); |
| SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob); |
| } |
| |
| template<typename xpu> |
| inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<NDArray> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<NDArray> &outputs) { |
| const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); |
| const auto w_stype = inputs[0].storage_type(); |
| const auto g_stype = inputs[1].storage_type(); |
| const auto o_stype = outputs[0].storage_type(); |
| if (o_stype == w_stype && g_stype == kRowSparseStorage && |
| (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) { |
| NDArray out = outputs[0]; |
| // std update and lazy update with rsp grad |
| SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| struct SGDMomParam : public dmlc::Parameter<SGDMomParam> { |
| float lr; |
| float momentum; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| bool lazy_update; |
| DMLC_DECLARE_PARAMETER(SGDMomParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(momentum) |
| .set_default(0.0f) |
| .describe("The decay rate of momentum estimates at each epoch."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(lazy_update) |
| .set_default(true) |
| .describe("If true, lazy updates are applied if gradient's stype is row_sparse " |
| "and both weight and momentum have the same stype"); |
| } |
| }; |
| |
| |
| struct SGDMomKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data, |
| const DType* grad_data, const DType param_clip_gradient, const DType param_momentum, |
| const DType param_lr, const DType param_wd, const DType param_rescale_grad, |
| const OpReqType req) { |
| if (param_clip_gradient >= 0.0f) { |
| mom_data[i] = param_momentum*mom_data[i] |
| - param_lr*param_wd*weight_data[i] |
| - param_lr |
| *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient); |
| } else { |
| mom_data[i] = param_momentum*mom_data[i] |
| - param_lr*param_wd*weight_data[i] |
| - param_lr*param_rescale_grad*grad_data[i]; |
| } |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i] + mom_data[i]); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_, |
| grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad), req[0]); |
| }); |
| } |
| |
| template<int n_in, int n_out, int total_in> |
| inline bool MP_InferType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_attrs, |
| std::vector<int> *out_attrs) { |
| CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name; |
| CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name; |
| for (int i = n_in; i < total_in; ++i) { |
| TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32); |
| } |
| return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>( |
| attrs, in_attrs, out_attrs, -1); |
| } |
| |
| struct MP_SGDKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data, |
| const DType* grad_data, float* weight32, const float param_clip_gradient, |
| const float param_lr, const float param_wd, const float param_rescale_grad, |
| const OpReqType req) { |
| if (param_clip_gradient >= 0.0f) { |
| float w = weight32[i]; |
| w = (1.f - param_lr*param_wd)*w - |
| (param_lr) * mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), |
| param_clip_gradient); |
| weight32[i] = w; |
| KERNEL_ASSIGN(out_data[i], req, (DType)w); |
| } else { |
| float w = weight32[i]; |
| w = (1.f-param_lr*param_wd)*w |
| - (param_lr*param_rescale_grad)*static_cast<float>(grad_data[i]); |
| weight32[i] = w; |
| KERNEL_ASSIGN(out_data[i], req, (DType)w); |
| } |
| } |
| }; |
| |
| template<typename xpu> |
| inline void MP_SGDUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, float> weight32 = inputs[2].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<MP_SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_, |
| grad.dptr_, weight32.dptr_, param.clip_gradient, |
| param.lr, param.wd, |
| param.rescale_grad, req[0]); |
| }); |
| } |
| |
| struct MP_SGDMomKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mom_data, |
| const DType* weight_data, const DType* grad_data, float* weight32, |
| const float param_clip_gradient, const float param_momentum, const float param_lr, |
| const float param_wd, const float param_rescale_grad, const OpReqType req) { |
| float w = weight32[i]; |
| float mom = mom_data[i]; |
| if (param_clip_gradient >= 0.0f) { |
| mom = param_momentum*mom |
| - param_lr*param_wd*w |
| - param_lr |
| *mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), |
| param_clip_gradient); |
| } else { |
| mom = param_momentum*mom |
| - param_lr*param_wd*w |
| - param_lr*param_rescale_grad*static_cast<float>(grad_data[i]); |
| } |
| mom_data[i] = mom; |
| w = w + mom; |
| weight32[i] = w; |
| KERNEL_ASSIGN(out_data[i], req, w); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<MP_SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, |
| weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.momentum, |
| param.lr, param.wd, param.rescale_grad, req[0]); |
| }); |
| } |
| |
| template<int req, typename xpu> |
| struct SGDMomDnsRspDnsKernel; |
| |
| template<int req> |
| struct SGDMomDnsRspDnsKernel<req, cpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, |
| DType* mom_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType momentum, |
| const DType lr, const DType wd, const DType rescale_grad) { |
| const DType rate = lr * wd; |
| for (index_t j = 0; j < row_length; j++) { |
| index_t data_i = grad_idx[i] * row_length + j; |
| index_t grad_i = i * row_length + j; |
| if (clip_gradient >= 0.0f) { |
| mom_data[data_i] = momentum * mom_data[data_i] |
| - rate * weight_data[data_i] |
| - lr * |
| mshadow_op::clip::Map(rescale_grad * grad_data[grad_i], |
| clip_gradient); |
| } else { |
| mom_data[data_i] = momentum * mom_data[data_i] |
| - rate * weight_data[data_i] |
| - lr * rescale_grad * grad_data[grad_i]; |
| } |
| KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); |
| } |
| } |
| }; |
| |
| template<int req> |
| struct SGDMomDnsRspDnsKernel<req, gpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, |
| DType* mom_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType momentum, |
| const DType lr, const DType wd, const DType rescale_grad) { |
| using nnvm::dim_t; |
| const DType rate = lr * wd; |
| const dim_t row_id = i / row_length; |
| const dim_t col_id = i % row_length; |
| const dim_t data_i = grad_idx[row_id] * row_length + col_id; |
| if (clip_gradient >= 0.0f) { |
| mom_data[data_i] = momentum * mom_data[data_i] |
| - rate * weight_data[data_i] |
| - lr * |
| mshadow_op::clip::Map(rescale_grad * grad_data[i], |
| clip_gradient); |
| } else { |
| mom_data[data_i] = momentum * mom_data[data_i] |
| - rate * weight_data[data_i] |
| - lr * rescale_grad * grad_data[i]; |
| } |
| KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); |
| } |
| }; |
| |
| /* |
| * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state. |
| */ |
| template<typename xpu> |
| inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& mom, |
| const OpReqType& req, |
| TBlob *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| if (!grad.storage_initialized() || req == kNullOp) return; |
| CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update"; |
| CHECK_GT(weight.shape_.Size(), 0); |
| CHECK_GT(mom.shape_.Size(), 0); |
| |
| MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, req_type, { |
| DType* weight_data = weight.dptr<DType>(); |
| IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); |
| DType* grad_val = grad.data().dptr<DType>(); |
| DType* mom_data = mom.dptr<DType>(); |
| DType* out_data = out->dptr<DType>(); |
| index_t num_rows = grad.aux_shape(kIdx)[0]; |
| auto row_length = weight.shape_.ProdShape(1, weight.ndim()); |
| size_t num_threads = num_rows; |
| if (std::is_same<xpu, gpu>::value) { |
| num_threads = num_rows * row_length; |
| } |
| Kernel<SGDMomDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length, |
| out_data, mom_data, weight_data, grad_idx, grad_val, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad)); |
| }); |
| }); |
| }); |
| } |
| |
| /* |
| * \brief sgd momentum lazy update for row_sparse grad. |
| */ |
| template<typename xpu> |
| inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& mom, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "SGDMomUpdate", "weights"); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| // fill mom with zero values (if it's in rsp storage) |
| // in order to reuse the sgd mom dns impl |
| if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) { |
| NDArray mom_zeros = mom; |
| FillDnsZerosRspImpl(s, &mom_zeros); |
| } |
| TBlob out_blob = out->data(); |
| // reuse dns rsp implementation when storage_shape == shape |
| SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, |
| mom.data(), req, &out_blob); |
| } |
| |
| /*! |
| * \brief Storge type inference function for optimizers which support both |
| * lazy update and standard update, with states (e.g. 2nd order moment) |
| * \param num_states The number of states that could be row_sparse or dense |
| */ |
| template<size_t num_states, typename ParamType> |
| inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs, |
| const int dev_mask, |
| DispatchMode* dispatch_mode, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| using namespace common; |
| const ParamType& param = nnvm::get<ParamType>(attrs.parsed); |
| // weight, grad, state 0, state 1, ... -> weight |
| CHECK_EQ(in_attrs->size(), 2 + num_states); |
| CHECK_EQ(out_attrs->size(), 1U); |
| const int weight_stype = in_attrs->at(0); |
| const int grad_stype = in_attrs->at(1); |
| const int state_stype = in_attrs->at(2); |
| // the storage type of all states should be the same |
| for (size_t i = 3; i < 2 + num_states; i++) { |
| CHECK_EQ(state_stype, in_attrs->at(i)) |
| << "Inconsistent storage types detected in state " << i; |
| } |
| bool dispatched = false; |
| if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { |
| // dns, ... -> dns |
| dispatched = storage_type_assign(out_attrs, kDefaultStorage, |
| dispatch_mode, DispatchMode::kFCompute); |
| } |
| if (!dispatched && grad_stype == kRowSparseStorage && |
| (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) && |
| state_stype == weight_stype) { |
| // weight and state share stype, grad's stype = rsp |
| dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype), |
| dispatch_mode, DispatchMode::kFComputeEx); |
| // warn users if lazy_update is turned on |
| if (dispatched && param.lazy_update) LogLazyUpdate(); |
| } |
| if (!dispatched && grad_stype == kRowSparseStorage && |
| weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) { |
| // weight, grad, state, ... -> weight |
| // rsp, rsp, dns, ... -> rsp, standard update |
| dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype), |
| dispatch_mode, DispatchMode::kFComputeEx); |
| } |
| if (!dispatched) { |
| dispatched = dispatch_fallback(out_attrs, dispatch_mode); |
| } |
| return dispatched; |
| } |
| |
/*!
 * \brief Kernel for the standard (non-lazy) momentum update with a dense
 *        weight, a row_sparse gradient and a dense state.
 *        Declaration only; no definition is provided at this point.
 * \tparam req write request the kernel is specialised for
 * \tparam xpu device type
 */
template <int req, typename xpu>
struct SGDMomStdDnsRspDnsKernel;
| |
| |
| /* |
| * \brief standard momentum update for dense weight, row_sparse grad and dense states. |
| */ |
| template<typename xpu> |
| void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& mom, |
| const OpReqType& req, |
| TBlob *out); |
| |
| /* |
| * \brief standard momentum update for row_sparse grad. |
| * both row_sparse and dense weight are supported. |
| */ |
| template<typename xpu> |
| inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& mom, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "SGDMomUpdate", "weights"); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| // fill mom with zero values (if it's in rsp storage) |
| // in order to reuse the sgd mom dns impl |
| if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) { |
| NDArray mom_zeros = mom; |
| FillDnsZerosRspImpl(s, &mom_zeros); |
| } |
| TBlob out_blob = out->data(); |
| SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, |
| mom.data(), req, &out_blob); |
| } |
| |
| template<typename xpu> |
| inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<NDArray> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<NDArray> &outputs) { |
| using namespace mxnet_op; |
| const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed); |
| auto &weight = inputs[0]; |
| auto &grad = inputs[1]; |
| auto &mom = inputs[2]; |
| const auto w_stype = weight.storage_type(); |
| const auto m_stype = mom.storage_type(); |
| const auto out_stype = outputs[0].storage_type(); |
| NDArray out = outputs[0]; |
| const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage; |
| const bool valid_grad = grad.storage_type() == kRowSparseStorage; |
| const bool lazy_update = param.lazy_update; |
| CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype"; |
| if (valid_weight && valid_grad && m_stype == w_stype) { |
| if (lazy_update) { |
| // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update |
| SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); |
| } else { |
| // rsp grad && m.stype = w.stype && lazy_update = false -> std update |
| SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); |
| } |
| } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) { |
| // rsp weight, rsp grad, dns state -> std update |
| SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| |
| struct NAGParam : public dmlc::Parameter<NAGParam> { |
| float lr; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(NAGParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude " |
| "of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| struct NAGMomParam : public dmlc::Parameter<NAGMomParam> { |
| float lr; |
| float momentum; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(NAGMomParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(momentum) |
| .set_default(0.0f) |
| .describe("The decay rate of momentum estimates at each epoch."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude " |
| "of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| struct NAGMomKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, |
| const DType* weight_data, const DType* grad_data, |
| const DType param_clip_gradient, const DType param_momentum, |
| const DType param_lr, const DType param_wd, |
| const DType param_rescale_grad, const OpReqType req) { |
| if (param_clip_gradient >= 0.0f) { |
| mom_data[i] = param_momentum*mom_data[i]; |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) |
| *(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad |
| *grad_data[i], param_clip_gradient)+(param_wd*weight_data[i]))))); |
| mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i], |
| param_clip_gradient))+(param_wd*weight_data[i]))); |
| } else { |
| mom_data[i] = param_momentum*mom_data[i]; |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) |
| *(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i])))); |
| mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i]) |
| +(param_wd*weight_data[i])); |
| } |
| } |
| }; |
| |
| template<typename xpu> |
| inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, |
| mom.dptr_, weight.dptr_, grad.dptr_, |
| static_cast<DType>(param.clip_gradient), |
| static_cast<DType>(param.momentum), static_cast<DType>(param.lr), |
| static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad), |
| req[0]); |
| }); |
| } |
| |
| struct MP_NAGMomKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| float* mom_data, const DType* weight_data, |
| const DType* grad_data, float* weight32, |
| const float param_clip_gradient, |
| const float param_momentum, const float param_lr, |
| const float param_wd, const float param_rescale_grad, |
| const OpReqType req) { |
| float w = weight32[i]; |
| if (param_clip_gradient >= 0.0f) { |
| mom_data[i] = param_momentum*mom_data[i]; |
| w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr |
| *(mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), |
| param_clip_gradient)+(param_wd*w))); |
| mom_data[i] = mom_data[i] - param_lr |
| *((mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]), |
| param_clip_gradient))+(param_wd*w)); |
| weight32[i] = w; |
| KERNEL_ASSIGN(out_data[i], req, w); |
| } else { |
| mom_data[i] = param_momentum*mom_data[i]; |
| w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr |
| *(param_rescale_grad*static_cast<float>(grad_data[i])+(param_wd*w))); |
| mom_data[i] = mom_data[i] - param_lr |
| *((param_rescale_grad*static_cast<float>(grad_data[i]))+(param_wd*w)); |
| weight32[i] = w; |
| KERNEL_ASSIGN(out_data[i], req, w); |
| } |
| } |
| }; |
| |
| template<typename xpu> |
| inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, |
| mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, |
| param.clip_gradient, param.momentum, param.lr, param.wd, |
| param.rescale_grad, req[0]); |
| }); |
| } |
| |
| |
| struct FTMLParam : public dmlc::Parameter<FTMLParam> { |
| float lr; |
| float beta1; |
| float beta2; |
| double epsilon; |
| int t; |
| float wd; |
| float rescale_grad; |
| float clip_grad; |
| DMLC_DECLARE_PARAMETER(FTMLParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate."); |
| DMLC_DECLARE_FIELD(beta1) |
| .set_default(0.6f) |
| .set_range(0.0f, 1.0f) |
| .describe("Generally close to 0.5."); |
| DMLC_DECLARE_FIELD(beta2) |
| .set_default(0.999f) |
| .set_range(0.0f, 1.0f) |
| .describe("Generally close to 1."); |
| DMLC_DECLARE_FIELD(epsilon) |
| .set_default(1e-8f) |
| .describe("Epsilon to prevent div 0."); |
| DMLC_DECLARE_FIELD(t) |
| .describe("Number of update."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_grad) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| |
| struct FTMLKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out, DType* weight, DType* grad, |
| DType* d, DType* v, DType* z, const DType lr, const DType beta1, |
| const DType beta2, const DType epsilon, const DType t, |
| const DType wd, const DType rescale_grad, const DType clip_grad, |
| const OpReqType req) { |
| using namespace mshadow_op; |
| const DType grad_i = clip_grad >= 0.0f |
| ? clip::Map(rescale_grad * grad[i] + wd * weight[i], clip_grad) |
| : (rescale_grad * grad[i] + wd * weight[i]); |
| v[i] = beta2 * v[i] + (1 - beta2) * square::Map(grad_i); |
| const DType d_t = (1 - power::Map(beta1, t)) / lr * |
| (square_root::Map(v[i] / (1 - power::Map(beta2, t))) + epsilon); |
| z[i] = beta1 * z[i] + (1 - beta1) * grad_i - (d_t - beta1 * d[i]) * weight[i]; |
| d[i] = d_t; |
| KERNEL_ASSIGN(out[i], req, - z[i] / d_t); |
| } |
| }; |
| |
| |
| template<typename xpu> |
| inline void FTMLUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| FTMLParam param = nnvm::get<FTMLParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| DType* weight_data = inputs[0].dptr<DType>(); |
| DType* grad_data = inputs[1].dptr<DType>(); |
| DType* d_data = inputs[2].dptr<DType>(); |
| DType* v_data = inputs[3].dptr<DType>(); |
| DType* z_data = inputs[4].dptr<DType>(); |
| DType* out_data = outputs[0].dptr<DType>(); |
| Kernel<FTMLKernel, xpu>::Launch(s, inputs[0].shape_.Size(), out_data, |
| weight_data, grad_data, d_data, v_data, z_data, static_cast<DType>(param.lr), |
| static_cast<DType>(param.beta1), static_cast<DType>(param.beta2), |
| static_cast<DType>(param.epsilon), static_cast<DType>(param.t), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad), static_cast<DType>(param.clip_grad), |
| req[0]); |
| }); |
| } |
| |
| struct AdamParam : public dmlc::Parameter<AdamParam> { |
| float lr; |
| float beta1; |
| float beta2; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| bool lazy_update; |
| DMLC_DECLARE_PARAMETER(AdamParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(beta1) |
| .set_default(0.9f) |
| .describe("The decay rate for the 1st moment estimates."); |
| DMLC_DECLARE_FIELD(beta2) |
| .set_default(0.999f) |
| .describe("The decay rate for the 2nd moment estimates."); |
| DMLC_DECLARE_FIELD(epsilon) |
| .set_default(1e-8f) |
| .describe("A small constant for numerical stability."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(lazy_update) |
| .set_default(true) |
| .describe("If true, lazy updates are applied if gradient's stype is row_sparse " |
| "and all of w, m and v have the same stype"); |
| } |
| }; |
| |
| struct AdamUpdateKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data, |
| const DType clip_gradient, const DType rescale_grad, |
| const DType beta1, const DType beta2, |
| const DType lr, const DType wd, |
| const DType epsilon, const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd; |
| if (clip_gradient >= 0.f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled; |
| var_data[i] = beta2 * var_data[i] + |
| (1.f - beta2) * grad_rescaled * grad_rescaled; |
| |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] / |
| (square_root::Map(var_data[i]) + epsilon)); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void AdamUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad), |
| static_cast<DType>(param.beta1), static_cast<DType>(param.beta2), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.epsilon), req[0]); |
| }); |
| } |
| |
/*!
 * \brief Primary template of the lazy sparse Adam kernel; only the per-device
 *        partial specialisations are defined.
 * \tparam req write request the kernel is specialised for
 * \tparam xpu device type
 */
template <int req, typename xpu>
struct AdamDnsRspDnsKernel;
| |
| /*! |
| * Note: this kernel performs sparse adam update. For each row-slice in row_sparse |
| * gradient, it finds the corresponding elements in weight, mean and var and performs |
| * the update. |
| * The kernel assumes dense weight/mean/var, and row_sparse gradient |
| */ |
| template<int req> |
| struct AdamDnsRspDnsKernel<req, cpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data, |
| DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2, |
| const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t row_offset = grad_idx[i] * row_length; |
| for (dim_t j = 0; j < row_length; j++) { |
| // index in data/mean/var |
| const dim_t data_i = row_offset + j; |
| // index in grad |
| const dim_t grad_i = i * row_length + j; |
| const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd; |
| if (clip_gradient >= 0.0f) { |
| mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * |
| clip::Map(grad_rescaled, clip_gradient); |
| var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map( |
| clip::Map(grad_rescaled, clip_gradient)); |
| } else { |
| mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled; |
| var_data[data_i] = beta2 * var_data[data_i] + |
| (1.f - beta2) * grad_rescaled * grad_rescaled; |
| } |
| KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] / |
| (square_root::Map(var_data[data_i]) + epsilon)); |
| } |
| } |
| }; |
| |
| |
| template<int req> |
| struct AdamDnsRspDnsKernel<req, gpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data, |
| DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2, |
| const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t row_id = i / row_length; |
| const dim_t col_id = i % row_length; |
| const dim_t row_offset = grad_idx[row_id] * row_length; |
| // index in data/mean/var |
| const dim_t data_i = row_offset + col_id; |
| // index in grad |
| DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[data_i] * wd; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled; |
| var_data[data_i] = beta2 * var_data[data_i] + |
| (1.f - beta2) * grad_rescaled * grad_rescaled; |
| KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] / |
| (square_root::Map(var_data[data_i]) + epsilon)); |
| } |
| }; |
| |
| /* |
| * \brief lazy adam update for dense weight, dense states and rsp grad. |
| */ |
| template<typename xpu> |
| inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& mean, |
| const TBlob& var, |
| const OpReqType& req, |
| TBlob *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| if (!grad.storage_initialized() || req == kNullOp) return; |
| CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update"; |
| CHECK_GT(weight.shape_.Size(), 0); |
| CHECK_GT(mean.shape_.Size(), 0); |
| CHECK_GT(var.shape_.Size(), 0); |
| |
| MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, req_type, { |
| const DType* weight_data = weight.dptr<DType>(); |
| const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); |
| const DType* grad_val = grad.data().dptr<DType>(); |
| DType* mean_data = mean.dptr<DType>(); |
| DType* var_data = var.dptr<DType>(); |
| DType* out_data = out->dptr<DType>(); |
| nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0]; |
| const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); |
| size_t num_threads = num_rows; |
| if (std::is_same<xpu, gpu>::value) { |
| num_threads = num_rows * row_length; |
| } |
| Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads, |
| row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1), |
| static_cast<DType>(param.beta2), static_cast<DType>(param.lr), |
| static_cast<DType>(param.wd), static_cast<DType>(param.epsilon), |
| static_cast<DType>(param.rescale_grad)); |
| }); |
| }); |
| }); |
| } |
| |
| /* |
| * \brief lazy adam update for both row_sparse and dense weight. |
| * grad is expected to be row_sparse. |
| */ |
| template<typename xpu> |
| inline void AdamLazyUpdateRspImpl(const AdamParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& mean, |
| const NDArray& var, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "AdamUpdate", "weights"); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| // fill mean and variance with zero values in order to reuse the sgd mom dns impl |
| if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) { |
| NDArray mean_zeros = mean; |
| FillDnsZerosRspImpl(s, &mean_zeros); |
| } |
| if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) { |
| NDArray var_zeros = var; |
| FillDnsZerosRspImpl(s, &var_zeros); |
| } |
| TBlob out_blob = out->data(); |
| // reuse dns rsp implementation when storage_shape == shape |
| AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(), |
| var.data(), req, &out_blob); |
| } |
| |
/*!
 * \brief Kernel for the standard (non-lazy) Adam update with a dense weight,
 *        a row_sparse gradient and dense states.
 *        Declaration only; no definition is provided at this point.
 * \tparam req write request the kernel is specialised for
 * \tparam xpu device type
 */
template <int req, typename xpu>
struct AdamStdDnsRspDnsKernel;
| |
| /* |
| * \brief standard adam update for dense weight, row_sparse grad and dense states. |
| */ |
| template<typename xpu> |
| void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& mean, |
| const TBlob& var, |
| const OpReqType& req, |
| TBlob *out); |
| |
| /* |
| * \brief standard adam update for both row_sparse and dense weight. |
| * states are expected to be dense, while grad is expected to be row_sparse. |
| */ |
| template<typename xpu> |
| inline void AdamStdUpdateRspImpl(const AdamParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& mean, |
| const NDArray& var, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "AdamStdUpdate", "weights"); |
| TBlob out_blob = out->data(); |
| // reuse dns rsp implementation when storage_shape == shape |
| AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(), |
| var.data(), req, &out_blob); |
| } |
| |
| template<typename xpu> |
| inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<NDArray> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<NDArray> &outputs) { |
| const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed); |
| const auto w_stype = inputs[0].storage_type(); |
| const auto g_stype = inputs[1].storage_type(); |
| const auto m_stype = inputs[2].storage_type(); |
| const auto v_stype = inputs[3].storage_type(); |
| const auto out_stype = outputs[0].storage_type(); |
| NDArray out = outputs[0]; |
| const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage; |
| CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype"; |
| CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype"; |
| if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) { |
| if (param.lazy_update) { |
| // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update |
| AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], |
| inputs[3], req[0], &out); |
| } else { |
| // rsp grad && m.stype = w.stype && lazy_update = false -> std update |
| AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], |
| inputs[3], req[0], &out); |
| } |
| } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage && |
| m_stype == kDefaultStorage) { |
| // rsp, rsp, dns, dns -> rsp, standard update |
| AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], |
| inputs[3], req[0], &out); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> { |
| float beta1; |
| float beta2; |
| float epsilon; |
| float t; |
| bool bias_correction; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) { |
| DMLC_DECLARE_FIELD(beta1) |
| .set_default(0.9f) |
| .describe("The decay rate for the 1st moment estimates."); |
| DMLC_DECLARE_FIELD(beta2) |
| .set_default(0.999f) |
| .describe("The decay rate for the 2nd moment estimates."); |
| DMLC_DECLARE_FIELD(epsilon) |
| .set_default(1e-6f) |
| .describe("A small constant for numerical stability."); |
| DMLC_DECLARE_FIELD(t) |
| .describe("Index update count."); |
| DMLC_DECLARE_FIELD(bias_correction) |
| .set_default(false) |
| .describe("Whether to use bias correction."); |
| DMLC_DECLARE_FIELD(wd) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> { |
| float lr; |
| float lower_bound; |
| float upper_bound; |
| DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(lower_bound) |
| .set_default(-1.0f) |
| .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set"); |
| DMLC_DECLARE_FIELD(upper_bound) |
| .set_default(-1.0f) |
| .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set"); |
| } |
| }; |
| |
| struct LambUpdatePhaseOneKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data, |
| const DType clip_gradient, const DType rescale_grad, |
| const DType beta1, const DType beta2, const DType wd, |
| const DType epsilon, const DType t, |
| bool bias_correction, const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType grad_rescaled = grad_data[i] * rescale_grad; |
| if (clip_gradient >= 0.f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled; |
| var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled; |
| |
| DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i]; |
| |
| if (bias_correction) { |
| DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t)); |
| DType var_hat = var_data[i] / (1 - power::Map(beta2, t)); |
| g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i]; |
| } |
| KERNEL_ASSIGN(out_data[i], req, g); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad), |
| static_cast<DType>(param.beta1), static_cast<DType>(param.beta2), |
| static_cast<DType>(param.wd), static_cast<DType>(param.epsilon), |
| static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]); |
| }); |
| } |
| |
| inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 4U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| |
| mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1); |
| |
| mxnet::TShape& weight_shape = in_attrs->at(0); |
| mxnet::TShape& g_shape = in_attrs->at(1); |
| CHECK_EQ(weight_shape.ndim(), g_shape.ndim()) |
| << "total no. of dimensions for weights and g must match"; |
| for (int i=0; i < weight_shape.ndim(); ++i) { |
| CHECK_EQ(weight_shape[i], g_shape[i]) |
| << "weight and g dimension size mismatch at " << i << "-th index"; |
| } |
| mxnet::TShape& r1_shape = in_attrs->at(2); |
| mxnet::TShape& r2_shape = in_attrs->at(3); |
| CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect"; |
| CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect"; |
| for (int i=0; i < expected_out.ndim(); ++i) { |
| expected_out[i] = weight_shape[i]; |
| } |
| |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out); |
| return shape_is_known(expected_out); |
| } |
| |
| struct LambUpdatePhaseTwoKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| const DType* weight_data, const DType* g, |
| const DType* r1, const DType* r2, |
| DType lr, const DType lower_bound, |
| const DType upper_bound, const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType new_r1 = r1[0]; |
| if (lower_bound >= 0) { |
| new_r1 = maximum::Map(new_r1, lower_bound); |
| } |
| if (upper_bound >= 0) { |
| new_r1 = minimum::Map(new_r1, upper_bound); |
| } |
| if (new_r1 == 0.0f || r2[0] == 0.0f) { |
| lr = lr * 1.0f; |
| } else { |
| lr = lr * new_r1 / r2[0]; |
| } |
| |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, |
| static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound), |
| static_cast<DType>(param.upper_bound), req[0]); |
| }); |
| } |
| |
| template<int n_in, int n_out, int total_in> |
| inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_attrs, |
| std::vector<int> *out_attrs) { |
| CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name; |
| CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name; |
| for (int i = 0; i < n_in; ++i) { |
| TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16); |
| } |
| for (int i = n_in; i < total_in; ++i) { |
| TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32); |
| } |
| for (int i = 0; i < n_out; ++i) { |
| TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32); |
| } |
| return true; |
| } |
| |
| struct MPLambUpdatePhaseOneKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, float* out_data, |
| float* mean_data, float* var_data, const DType* weight_data, |
| const DType* grad_data, const float* weight32_data, |
| const float clip_gradient, const float rescale_grad, |
| const float beta1, const float beta2, const float wd, |
| const float epsilon, const float t, |
| bool bias_correction, const OpReqType req) { |
| using namespace mshadow_op; |
| |
| float grad_rescaled = grad_data[i] * rescale_grad; |
| if (clip_gradient >= 0.f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled; |
| var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled; |
| |
| float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i]; |
| |
| if (bias_correction) { |
| float mean_hat = mean_data[i] / (1. - power::Map(beta1, t)); |
| float var_hat = var_data[i] / (1 - power::Map(beta2, t)); |
| g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i]; |
| } |
| KERNEL_ASSIGN(out_data[i], req, g); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s); |
| |
| Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, |
| param.clip_gradient, param.rescale_grad, param.beta1, param.beta2, |
| param.wd, param.epsilon, param.t, param.bias_correction, req[0]); |
| }); |
| } |
| |
| inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 5U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| |
| mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1); |
| |
| mxnet::TShape& weight_shape = in_attrs->at(0); |
| mxnet::TShape& g_shape = in_attrs->at(1); |
| mxnet::TShape& weight32_shape = in_attrs->at(4); |
| CHECK_EQ(weight_shape.ndim(), g_shape.ndim()) |
| << "total no. of dimensions for weights and g must match"; |
| CHECK_EQ(weight_shape.ndim(), weight32_shape.ndim()) |
| << "total no. of dimensions for weights and g must match"; |
| for (int i=0; i < weight_shape.ndim(); ++i) { |
| CHECK_EQ(weight_shape[i], g_shape[i]) |
| << "weight and g dimension size mismatch at " << i << "-th index"; |
| CHECK_EQ(weight_shape[i], weight32_shape[i]) |
| << "weight and g dimension size mismatch at " << i << "-th index"; |
| } |
| mxnet::TShape& r1_shape = in_attrs->at(2); |
| mxnet::TShape& r2_shape = in_attrs->at(3); |
| CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect"; |
| CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect"; |
| for (int i=0; i < expected_out.ndim(); ++i) { |
| expected_out[i] = weight_shape[i]; |
| } |
| |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out); |
| return shape_is_known(expected_out); |
| } |
| |
| struct MPLambUpdatePhaseTwoKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| const DType* weight_data, const float* g, |
| const float* r1, const float* r2, const float* weight32, |
| float lr, const float lower_bound, |
| const float upper_bound, const OpReqType req) { |
| using namespace mshadow_op; |
| |
| float new_r1 = r1[0]; |
| if (lower_bound >= 0) { |
| new_r1 = maximum::Map(new_r1, lower_bound); |
| } |
| if (upper_bound >= 0) { |
| new_r1 = minimum::Map(new_r1, upper_bound); |
| } |
| if (new_r1 == 0.0f || r2[0] == 0.0f) { |
| lr = lr * 1.0f; |
| } else { |
| lr = lr * new_r1 / r2[0]; |
| } |
| |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void MPLambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, float> g = inputs[1].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> r1 = inputs[2].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> r2 = inputs[3].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| Kernel<MPLambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, weight32.dptr_, |
| param.lr, param.lower_bound, |
| param.upper_bound, req[0]); |
| }); |
| } |
| |
| // This RMSProp code follows the version in |
| // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) |
| // by Alex Graves, 2013. |
| struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> { |
| float lr; |
| float gamma1; |
| float gamma2; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| float clip_weights; |
| DMLC_DECLARE_PARAMETER(RMSPropAlexParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(gamma1).set_default(0.95f) |
| .describe("Decay rate."); |
| DMLC_DECLARE_FIELD(gamma2).set_default(0.9f) |
| .describe("Decay rate."); |
| DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f) |
| .describe("A small constant for numerical stability."); |
| DMLC_DECLARE_FIELD(wd).set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(clip_weights) |
| .set_default(-1.0f) |
| .describe("Clip weights to the range of [-clip_weights, clip_weights] " |
| "If clip_weights <= 0, weight clipping is turned off. " |
| "weights = max(min(weights, clip_weights), -clip_weights)."); |
| } |
| }; |
| |
| struct RMSPropAlexUpdateKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| DType* state_n_data, DType* state_g_data, DType* delta_data, |
| const DType* weight_data, const DType* grad_data, |
| const DType clip_gradient, const DType rescale_grad, |
| const DType gamma1, const DType gamma2, |
| const DType lr, const DType wd, |
| const DType clip_weights, const DType epsilon, |
| const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i]; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled + |
| gamma1 * state_n_data[i]; |
| state_g_data[i] = (1.f - gamma1) * grad_rescaled + |
| gamma1 * state_g_data[i]; |
| delta_data[i] = gamma2 * delta_data[i] - |
| (lr * (grad_rescaled) / |
| (square_root::Map(state_n_data[i] - |
| state_g_data[i] * state_g_data[i] + epsilon))); |
| |
| if (clip_weights >= 0.0f) { |
| const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights); |
| KERNEL_ASSIGN(out_data[i], req, clipped_weight); |
| } else { |
| KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]); |
| } |
| } |
| }; |
| |
| template <typename xpu> |
| inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const RMSPropAlexParam ¶m = nnvm::get<RMSPropAlexParam>(attrs.parsed); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| DType* weight_data = inputs[0].dptr<DType>(); |
| DType* grad_data = inputs[1].dptr<DType>(); |
| DType* state_n_data = inputs[2].dptr<DType>(); |
| DType* state_g_data = inputs[3].dptr<DType>(); |
| DType* delta_data = inputs[4].dptr<DType>(); |
| DType* out_data = outputs[0].dptr<DType>(); |
| |
| Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(), |
| out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad), |
| static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]); |
| }); |
| } |
| |
| // This RMSProp code follows the version in |
| // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf |
| // by Tieleman & Hinton, 2012 |
| struct RMSPropParam : public dmlc::Parameter<RMSPropParam> { |
| float lr; |
| float gamma1; |
| float epsilon; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| float clip_weights; |
| DMLC_DECLARE_PARAMETER(RMSPropParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(gamma1).set_default(0.95f) |
| .describe("The decay rate of momentum estimates."); |
| DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f) |
| .describe("A small constant for numerical stability."); |
| DMLC_DECLARE_FIELD(wd).set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(clip_weights) |
| .set_default(-1.0f) |
| .describe("Clip weights to the range of [-clip_weights, clip_weights] " |
| "If clip_weights <= 0, weight clipping is turned off. " |
| "weights = max(min(weights, clip_weights), -clip_weights)."); |
| } |
| }; |
| |
| struct RMSPropUpdateKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, |
| DType* out_data, DType* state_n_data, |
| const DType* weight_data, const DType* grad_data, |
| const DType clip_gradient, const DType rescale_grad, |
| const DType gamma1, const DType lr, const DType wd, |
| const DType clip_weights, const DType epsilon, |
| const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i]; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i]; |
| |
| DType weight = weight_data[i] - |
| lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon)); |
| if (clip_weights >= 0.0f) { |
| weight = clip::Map(weight, clip_weights); |
| } |
| KERNEL_ASSIGN(out_data[i], req, weight); |
| } |
| }; |
| |
| template <typename xpu> |
| inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const RMSPropParam ¶m = nnvm::get<RMSPropParam>(attrs.parsed); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| DType* weight_data = inputs[0].dptr<DType>(); |
| DType* grad_data = inputs[1].dptr<DType>(); |
| DType* state_n_data = inputs[2].dptr<DType>(); |
| DType* out_data = outputs[0].dptr<DType>(); |
| |
| Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(), |
| out_data, state_n_data, weight_data, grad_data, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad), |
| static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]); |
| }); |
| } |
| |
| struct FtrlParam : public dmlc::Parameter<FtrlParam> { |
| float lr; |
| float lamda1; |
| float beta; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(FtrlParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(lamda1) |
| .set_default(0.01f) |
| .describe("The L1 regularization coefficient."); |
| DMLC_DECLARE_FIELD(beta) |
| .set_default(1.0f) |
| .describe("Per-Coordinate Learning Rate beta."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| struct FtrlUpdateKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, |
| DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data, |
| const DType clip_gradient, const DType rescale_grad, |
| const DType beta, const DType lamda1, |
| const DType lr, const DType wd, |
| const OpReqType req) { |
| using namespace mshadow_op; |
| |
| DType grad_rescaled = grad_data[i] * rescale_grad; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| |
| z_data[i] += grad_rescaled - (square_root::Map(n_data[i] + |
| square::Map(grad_rescaled)) - square_root::Map(n_data[i])) * |
| weight_data[i] / lr; |
| n_data[i] += square::Map(grad_rescaled); |
| |
| KERNEL_ASSIGN(out_data[i], req, |
| (sign::Map(z_data[i]) * lamda1 - z_data[i]) / |
| ((beta + square_root::Map(n_data[i])) / lr + wd) * |
| gt::Map(abs::Map(z_data[i]), lamda1)); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void FtrlUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| |
| const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> z = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| |
| Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(), |
| out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad), |
| static_cast<DType>(param.beta), static_cast<DType>(param.lamda1), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]); |
| }); |
| } |
| |
| template<int req> |
| struct FtrlDnsRspDnsKernel { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data, |
| DType* z_data, DType* n_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType lamda1, const DType beta, |
| const DType lr, const DType wd, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t row_offset = grad_idx[i] * row_length; |
| for (dim_t j = 0; j < row_length; j++) { |
| // index in data/z/n |
| const dim_t data_i = row_offset + j; |
| // index in grad |
| const dim_t grad_i = i * row_length + j; |
| const DType grad_rescaled = grad_data[grad_i] * rescale_grad; |
| if (clip_gradient >= 0.0f) { |
| z_data[data_i] += clip::Map(grad_rescaled, clip_gradient) - |
| (square_root::Map(n_data[data_i] + |
| square::Map(clip::Map(grad_rescaled, clip_gradient))) - |
| square_root::Map(n_data[data_i])) * weight_data[data_i] / lr; |
| n_data[data_i] += square::Map(clip::Map(grad_rescaled, clip_gradient)); |
| } else { |
| z_data[data_i] += grad_rescaled - (square_root::Map(n_data[data_i] + |
| square::Map(grad_rescaled)) - square_root::Map(n_data[data_i])) * |
| weight_data[data_i] / lr; |
| n_data[data_i] += square::Map(grad_rescaled); |
| } |
| KERNEL_ASSIGN(out_data[data_i], req, |
| (sign::Map(z_data[data_i]) * lamda1 - z_data[data_i]) / |
| ((beta + square_root::Map(n_data[data_i])) / lr + wd) * |
| gt::Map(abs::Map(z_data[data_i]), lamda1)); |
| } |
| } |
| }; |
| |
| |
| template<typename xpu> |
| inline void FtrlUpdateDnsRspDnsImpl(const FtrlParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& z, |
| const TBlob& n, |
| const OpReqType& req, |
| TBlob *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| if (!grad.storage_initialized() || req == kNullOp) return; |
| CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse ftrl_update"; |
| CHECK_GT(weight.shape_.Size(), 0); |
| CHECK_GT(z.shape_.Size(), 0); |
| CHECK_GT(n.shape_.Size(), 0); |
| |
| MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, req_type, { |
| const DType* weight_data = weight.dptr<DType>(); |
| const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); |
| const DType* grad_val = grad.data().dptr<DType>(); |
| DType* z_data = z.dptr<DType>(); |
| DType* n_data = n.dptr<DType>(); |
| DType* out_data = out->dptr<DType>(); |
| nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0]; |
| const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); |
| Kernel<FtrlDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length, |
| out_data, z_data, n_data, weight_data, grad_idx, grad_val, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.lamda1), |
| static_cast<DType>(param.beta), static_cast<DType>(param.lr), |
| static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad)); |
| }); |
| }); |
| }); |
| } |
| |
| template<typename xpu> |
| inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& z, |
| const NDArray& n, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "FtrlUpdate", "weights"); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| // fill z and n with zero values in order to reuse the sgd mom dns impl |
| if (!z.storage_initialized()) { |
| NDArray z_zeros = z; |
| FillDnsZerosRspImpl(s, &z_zeros); |
| } |
| if (!n.storage_initialized()) { |
| NDArray n_zeros = n; |
| FillDnsZerosRspImpl(s, &n_zeros); |
| } |
| TBlob out_blob = out->data(); |
| // reuse dns rsp implementation when storage_shape == shape |
| FtrlUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, z.data(), |
| n.data(), req, &out_blob); |
| } |
| |
| template<typename xpu> |
| inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<NDArray> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<NDArray> &outputs) { |
| const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed); |
| const auto weight_stype = inputs[0].storage_type(); |
| const auto z_stype = inputs[2].storage_type(); |
| const auto n_stype = inputs[3].storage_type(); |
| |
| const auto out_stype = outputs[0].storage_type(); |
| CHECK_EQ(z_stype, weight_stype) << "Inconsistent storage type detected between " |
| << " z.stype = " << z_stype << " and weight.stype = " << weight_stype; |
| CHECK_EQ(n_stype, weight_stype) << "Inconsistent storage type detected between " |
| << " n.stype = " << n_stype << " and weight.stype = " << weight_stype; |
| if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && out_stype == kRowSparseStorage) { |
| NDArray out = outputs[0]; |
| FtrlUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], |
| inputs[3], req[0], &out); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| |
| // Implementation for signSGD and Signum |
| |
| struct SignSGDParam : public dmlc::Parameter<SignSGDParam> { |
| float lr; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| DMLC_DECLARE_PARAMETER(SignSGDParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| |
| struct SignSGDKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data, |
| const DType* grad_data, const DType param_clip_gradient, |
| const DType param_lr, const DType param_wd, const DType param_rescale_grad, |
| const OpReqType req) { |
| |
| // param_clip_gradient has no effect for SignSGD |
| KERNEL_ASSIGN(out_data[i], req, |
| (1.f-param_lr*param_wd)*weight_data[i] |
| - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0))); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_, |
| grad.dptr_, static_cast<DType>(param.clip_gradient), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad), req[0]); |
| }); |
| } |
| |
| |
| struct SignumParam : public dmlc::Parameter<SignumParam> { |
| float lr; |
| float momentum; |
| float wd; |
| float rescale_grad; |
| float clip_gradient; |
| float wd_lh; // the amount of algorithmic weight decay by Loshchilov and Frank Hutter |
| DMLC_DECLARE_PARAMETER(SignumParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(momentum) |
| .set_default(0.0f) |
| .describe("The decay rate of momentum estimates at each epoch."); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("Weight decay augments the objective function with a " |
| "regularization term that penalizes large weights. " |
| "The penalty scales with the square of the magnitude of each weight."); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| DMLC_DECLARE_FIELD(wd_lh) |
| .set_default(0.0f) |
| .describe("The amount of weight decay that does not go into gradient/momentum calculations" |
| "otherwise do weight decay algorithmically only."); |
| } |
| }; |
| |
| struct SignumKernel { |
| template<typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data, |
| const DType* grad_data, const DType param_clip_gradient, const DType param_momentum, |
| const DType param_lr, const DType param_wd, const DType param_rescale_grad, |
| const DType param_wd_lh, const OpReqType req) { |
| if (param_clip_gradient >= 0.0f) { |
| mom_data[i] = param_momentum*mom_data[i] |
| - (1-param_momentum)*param_wd*weight_data[i] |
| - (1-param_momentum) |
| *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient); |
| } else { |
| mom_data[i] = param_momentum*mom_data[i] |
| - (1-param_momentum)*param_wd*weight_data[i] |
| - (1-param_momentum)*param_rescale_grad*grad_data[i]; |
| } |
| KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i] |
| + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0))); |
| } |
| }; |
| |
| template<typename xpu> |
| inline void SignumUpdate(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<TBlob> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &outputs) { |
| using namespace mxnet_op; |
| SignumParam param = nnvm::get<SignumParam>(attrs.parsed); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { |
| Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s); |
| Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s); |
| Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_, |
| grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum), |
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), |
| static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]); |
| }); |
| } |
| |
| struct AdagradParam : public dmlc::Parameter<AdagradParam> { |
| float lr; |
| float epsilon; |
| float rescale_grad; |
| float clip_gradient; |
| float wd; |
| DMLC_DECLARE_PARAMETER(AdagradParam) { |
| DMLC_DECLARE_FIELD(lr) |
| .describe("Learning rate"); |
| DMLC_DECLARE_FIELD(epsilon) |
| .set_default(1.0e-7) |
| .describe("epsilon"); |
| DMLC_DECLARE_FIELD(wd) |
| .set_default(0.0f) |
| .describe("weight decay"); |
| DMLC_DECLARE_FIELD(rescale_grad) |
| .set_default(1.0f) |
| .describe("Rescale gradient to grad = rescale_grad*grad."); |
| DMLC_DECLARE_FIELD(clip_gradient) |
| .set_default(-1.0f) |
| .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " |
| "If clip_gradient <= 0, gradient clipping is turned off. " |
| "grad = max(min(grad, clip_gradient), -clip_gradient)."); |
| } |
| }; |
| |
| inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs, |
| const int dev_mask, |
| DispatchMode* dispatch_mode, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed); |
| CHECK_EQ(in_attrs->size(), 3U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| const int weight_stype = in_attrs->at(0); |
| const int grad_stype = in_attrs->at(1); |
| const int state_stype = in_attrs->at(2); |
| bool dispatched = false; |
| if (!dispatched && grad_stype == kRowSparseStorage && |
| (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) && |
| state_stype == weight_stype && param.wd == 0.0f) { |
| // weight and state share stype, grad's stype = rsp |
| dispatched = storage_type_assign( |
| out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode, |
| DispatchMode::kFComputeEx); |
| } |
| return dispatched; |
| } |
| |
| template<typename xpu> |
| struct AdagradDnsRspDnsKernel; |
| |
| template<> |
| struct AdagradDnsRspDnsKernel<cpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, |
| DType* state_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType epsilon, |
| const DType lr, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t data_i = grad_idx[i] * row_length; |
| const dim_t grad_i = i * row_length; |
| for (dim_t j = 0; j < row_length; j++) { |
| const dim_t data_j = data_i + j; |
| const dim_t grad_j = grad_i + j; |
| DType grad_rescaled = grad_data[grad_j] * rescale_grad; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| const DType grad_squared = grad_rescaled * grad_rescaled; |
| state_data[data_j] += grad_squared; |
| const DType div = grad_rescaled / square_root::Map(state_data[data_j] + epsilon); |
| // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace |
| out_data[data_j] = weight_data[data_j] - div * lr; |
| } |
| } |
| }; |
| |
| template<> |
| struct AdagradDnsRspDnsKernel<gpu> { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, |
| DType* state_data, const DType* weight_data, const IType* grad_idx, |
| const DType* grad_data, const DType clip_gradient, const DType epsilon, |
| const DType lr, const DType rescale_grad) { |
| using nnvm::dim_t; |
| using namespace mshadow_op; |
| const dim_t row_id = i / row_length; |
| const dim_t col_id = i % row_length; |
| const dim_t data_i = grad_idx[row_id] * row_length + col_id; |
| DType grad_rescaled = grad_data[i] * rescale_grad; |
| if (clip_gradient >= 0.0f) { |
| grad_rescaled = clip::Map(grad_rescaled, clip_gradient); |
| } |
| const DType grad_squared = grad_rescaled * grad_rescaled; |
| state_data[data_i] += grad_squared; |
| const DType div = grad_rescaled / square_root::Map(state_data[data_i] + epsilon); |
| // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace |
| out_data[data_i] = weight_data[data_i] - div * lr; |
| } |
| }; |
| |
| template<typename xpu> |
| void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param, |
| const OpContext& ctx, |
| const TBlob& weight, |
| const NDArray& grad, |
| const TBlob& state, |
| const OpReqType& req, |
| TBlob *out) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| using namespace mshadow; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(param.wd, 0.0f) |
| << "sparse adagrad_update does not support wd."; |
| if (req == kNullOp || !grad.storage_initialized()) return; |
| CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adagrad_update"; |
| CHECK_GT(weight.shape_.Size(), 0); |
| CHECK_GT(state.shape_.Size(), 0); |
| MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { |
| const DType* weight_data = weight.dptr<DType>(); |
| const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>(); |
| const DType* grad_val = grad.data().dptr<DType>(); |
| DType* state_data = state.dptr<DType>(); |
| DType* out_data = out->dptr<DType>(); |
| const nnvm::dim_t nnr = grad.storage_shape()[0]; |
| const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); |
| size_t num_threads = nnr; |
| if (std::is_same<xpu, gpu>::value) { |
| num_threads = nnr * row_length; |
| } |
| Kernel<AdagradDnsRspDnsKernel<xpu>, xpu>::Launch(s, num_threads, row_length, |
| out_data, state_data, weight_data, grad_idx, grad_val, |
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon), |
| static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad)); |
| }); |
| }); |
| } |
| |
| template<typename xpu> |
| inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param, |
| const OpContext& ctx, |
| const NDArray& weight, |
| const NDArray& grad, |
| const NDArray& state, |
| const OpReqType& req, |
| NDArray *out) { |
| using namespace mshadow; |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| CheckAllRowsPresent(weight, "AdagradUpdate", "weights"); |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| // fill history with zero values |
| if (!state.storage_initialized()) { |
| NDArray state_zeros = state; |
| FillDnsZerosRspImpl(s, &state_zeros); |
| } |
| TBlob out_blob = out->data(); |
| // reuse dns rsp implementation when storage_shape == shape |
| AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, |
| state.data(), req, &out_blob); |
| } |
| |
| template<typename xpu> |
| inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs, |
| const OpContext &ctx, |
| const std::vector<NDArray> &inputs, |
| const std::vector<OpReqType> &req, |
| const std::vector<NDArray> &outputs) { |
| using namespace mxnet_op; |
| const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed); |
| |
| const auto weight_stype = inputs[0].storage_type(); |
| const auto grad_stype = inputs[1].storage_type(); |
| const auto state_stype = inputs[2].storage_type(); |
| const auto output_stype = outputs[0].storage_type(); |
| |
| if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && |
| common::ContainsOnlyStorage(outputs, kRowSparseStorage)) { |
| NDArray out = outputs[0]; |
| AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], |
| req[0], &out); |
| } else if (state_stype == weight_stype && output_stype == weight_stype && |
| weight_stype == kDefaultStorage && |
| grad_stype == kRowSparseStorage) { |
| TBlob out_blob = outputs[0].data(); |
| AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1], |
| inputs[2].data(), req[0], |
| &out_blob); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_ |