/*!
 * Copyright (c) 2015 by Contributors
 * \file sgd-inl.h
 * \brief SGD optimizer implementation for mxnet.
 * \author Junyuan Xie
 */
#ifndef MXNET_OPTIMIZER_SGD_INL_H_
#define MXNET_OPTIMIZER_SGD_INL_H_
#include <mshadow/tensor.h>
#include <mxnet/optimizer.h>
#include <dmlc/parameter.h>
#include <string>
#include <vector>
#include <map>
#include <utility>
namespace mxnet {
namespace opt {
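/*!
 * \brief Hyper-parameters of the SGD optimizer, parsed from string kwargs
 *        via dmlc::Parameter in SGDOpt::Init().
 */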
struct SGDParam : public dmlc::Parameter<SGDParam> {
  float momentum;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(SGDParam) {
    DMLC_DECLARE_FIELD(momentum)
      .set_range(0.0f, 1.0f)
      .set_default(0.0f)
      .describe("Momentum coefficient; set to 0 to disable momentum.");
    DMLC_DECLARE_FIELD(rescale_grad)
      .set_default(1.0f)
      .describe("Rescale gradient as grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
      .set_default(-1.0f)
      .describe("If greater than 0, clip gradient to "
                "grad = max(min(grad, clip_gradient), -clip_gradient). "
                "Otherwise gradient clipping is turned off.");
  }
};
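/*! \brief Element-wise clipping kernel: clamps each value to the range [-bound, bound]. */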
struct sgd_clip {
  MSHADOW_XINLINE static real_t Map(real_t x, real_t bound) {
    if (x > bound) {
      return bound;
    } else if (x < -bound) {
      return -bound;
    } else {
      return x;
    }
  }
};
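/*!
 * \brief SGD update with momentum:
 *          mom    = momentum*mom - lr*(rescale_grad*clip(grad) + wd*weight)
 *          weight += mom
 *        Tensors are flattened to 2D and updated in place on the device stream.
 */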
template<typename xpu>
void sgd_mom_update(RunContext ctx, TBlob weight, const TBlob grad, TBlob mom,
                    float lr, float wd, const SGDParam& param) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  Tensor<xpu, 2> weight2d = weight.FlatTo2D<xpu, real_t>(s);
  Tensor<xpu, 2> mom2d = mom.FlatTo2D<xpu, real_t>(s);
  Tensor<xpu, 2> grad2d = grad.FlatTo2D<xpu, real_t>(s);
  if (param.clip_gradient > 0.0f) {
    mom2d = param.momentum*mom2d -
            lr*(param.rescale_grad*F<sgd_clip>(grad2d, param.clip_gradient) + wd*weight2d);
  } else {
    mom2d = param.momentum*mom2d - lr*(param.rescale_grad*grad2d + wd*weight2d);
  }
  weight2d += mom2d;
}
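/*!
 * \brief Vanilla SGD update (no momentum):
 *          weight -= lr*(rescale_grad*clip(grad) + wd*weight)
 *        Gradient clipping is applied only when clip_gradient > 0, mirroring the
 *        momentum path above.
 */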
template<typename xpu>
void sgd_update(RunContext ctx, TBlob weight, const TBlob grad,
                float lr, float wd, const SGDParam& param) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  Tensor<xpu, 2> weight2d = weight.FlatTo2D<xpu, real_t>(s);
  Tensor<xpu, 2> grad2d = grad.FlatTo2D<xpu, real_t>(s);
  if (param.clip_gradient > 0.0f) {
    weight2d -= lr*(param.rescale_grad*F<sgd_clip>(grad2d, param.clip_gradient) +
                    wd*weight2d);
  } else {
    weight2d -= lr*(param.rescale_grad*grad2d + wd*weight2d);
  }
}
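// Device-specific wrappers around the templated updates above. Their definitions are
// expected to live in the corresponding .cc/.cu translation units so that GPU code is
// only compiled when CUDA support is available.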
void call_sgd_mom_update_cpu(RunContext ctx, TBlob weight, const TBlob grad, TBlob mom,
                             float lr, float wd, const SGDParam& param);
void call_sgd_update_cpu(RunContext ctx, TBlob weight, const TBlob grad,
                         float lr, float wd, const SGDParam& param);
#if MXNET_USE_CUDA
void call_sgd_mom_update_gpu(RunContext ctx, TBlob weight, const TBlob grad, TBlob mom,
                             float lr, float wd, const SGDParam& param);
void call_sgd_update_gpu(RunContext ctx, TBlob weight, const TBlob grad,
                         float lr, float wd, const SGDParam& param);
#endif  // MXNET_USE_CUDA
#if DMLC_USE_CXX11
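/*!
 * \brief SGD optimizer frontend. Keeps a lazily allocated momentum buffer per weight
 *        index and pushes asynchronous update operations to the execution engine.
 */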
class SGDOpt : public Optimizer {
 public:
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
  }
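  // Allocate and zero the momentum buffer for this weight index, but only when
  // momentum is enabled and no state has been created for the index yet.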
  void CreateState(const int index, const NDArray *weight) override {
    if (param_.momentum > 0.0f && mom.find(index) == mom.end()) {
      mom[index] = NDArray(weight->shape(), weight->ctx());
      mom[index] = 0.0f;
    }
  }
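  // Perform one SGD step for a single weight array. The computation is pushed to the
  // engine as an asynchronous operation that reads the gradient and writes the weight
  // (and the momentum buffer when momentum is enabled), dispatched by device type.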
  void Update(const int index, NDArray *weight,
              const NDArray *grad, const float lr, const float wd) override {
    NDArray w = *weight, g = *grad;
    CreateState(index, weight);
    switch (w.ctx().dev_type) {
      case Context::kCPU:
      case Context::kCPUPinned:
        if (param_.momentum > 0.0f) {
          Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
            call_sgd_mom_update_cpu(ctx, w.data(), g.data(), mom[index].data(), lr, wd, param_);
          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()},
          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
        } else {
          Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
            call_sgd_update_cpu(ctx, w.data(), g.data(), lr, wd, param_);
          }, w.ctx(), {g.var()}, {w.var()},
          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
        }
        break;
      case Context::kGPU:
#if MXNET_USE_CUDA
        if (param_.momentum > 0.0f) {
          Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
            call_sgd_mom_update_gpu(ctx, w.data(), g.data(), mom[index].data(), lr, wd, param_);
          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()},
          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
        } else {
          Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
            call_sgd_update_gpu(ctx, w.data(), g.data(), lr, wd, param_);
          }, w.ctx(), {g.var()}, {w.var()},
          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
        }
        break;
#else
        LOG(FATAL) << "Please compile with CUDA enabled for cuda features";
#endif  // MXNET_USE_CUDA
      default:
        LOG(FATAL) << "Unsupported device type for sgd optimizer: " << w.ctx().dev_type;
    }
  }
 private:
  SGDParam param_;             // hyper-parameters parsed in Init()
  std::map<int, NDArray> mom;  // per-index momentum buffers, created lazily
};
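// Note: the parameter struct and the optimizer itself are assumed to be registered in a
// separate .cc source file (e.g. via DMLC_REGISTER_PARAMETER(SGDParam) and the optimizer
// registry); this header only provides the implementation.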
#endif // DMLC_USE_CXX11
} // namespace opt
} // namespace mxnet
#endif // MXNET_OPTIMIZER_SGD_INL_H_