blob: ea1c0ad328161bdeb2892986dd229039270250f0 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file optimizer_op.cc
* \brief Optimizer operators
* \author Junyuan Xie
*/
#include "./optimizer_op-inl.h"
namespace mxnet {
namespace op {
// Register the SGDParam structure with the dmlc parameter system.
DMLC_REGISTER_PARAMETER(SGDParam);

// sgd_update: operator registration for the SGD updater.
NNVM_REGISTER_OP(sgd_update)
    .describe("Updater function for sgd optimizer")
    // Arity: two inputs, one output.
    .set_num_inputs(2)
    .set_num_outputs(1)
    // Operator attributes are parsed into an SGDParam.
    .set_attr_parser(ParamParser<SGDParam>)
    // Shape/type inference; the <2, 1> arguments mirror the arity above.
    .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
    // CPU kernel.
    .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>);
// sgd_mom_update: operator registration for the SGD-with-momentum updater.
// NOTE(review): "mom" is read as "momentum" from the operator/kernel names;
// confirm against SGDMomUpdate in optimizer_op-inl.h.
NNVM_REGISTER_OP(sgd_mom_update)
    // The description previously duplicated sgd_update's text verbatim, which
    // made the two operators indistinguishable in generated documentation.
    .describe("Updater function for sgd optimizer with momentum")
    // Arity: three inputs, one output.
    .set_num_inputs(3)
    .set_num_outputs(1)
    // Operator attributes are parsed into an SGDParam.
    .set_attr_parser(ParamParser<SGDParam>)
    // Shape/type inference; the <3, 1> arguments mirror the arity above.
    .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
    // Input index 2 is declared as mutated by the operator
    // (presumably the momentum state buffer -- TODO confirm).
    .set_attr<nnvm::FMutateInputs>("FMutateInputs",
      [](const nnvm::NodeAttrs& attrs) {
        return std::vector<uint32_t>{2};
      })
    // CPU kernel.
    .set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>);
// Register the AdamParam structure with the dmlc parameter system.
DMLC_REGISTER_PARAMETER(AdamParam);

// adam_update: operator registration for the Adam updater.
NNVM_REGISTER_OP(adam_update)
    .describe("Updater function for adam optimizer")
    // Arity: four inputs, one output.
    .set_num_inputs(4)
    .set_num_outputs(1)
    // Operator attributes are parsed into an AdamParam.
    .set_attr_parser(ParamParser<AdamParam>)
    // Shape/type inference; the <4, 1> arguments mirror the arity above.
    .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
    // Input indices 2 and 3 are declared as mutated by the operator
    // (presumably Adam's running moment estimates -- TODO confirm).
    .set_attr<nnvm::FMutateInputs>(
        "FMutateInputs",
        [](const nnvm::NodeAttrs& /*attrs*/) -> std::vector<uint32_t> {
          return {2, 3};
        })
    // CPU kernel.
    .set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>);
} // namespace op
} // namespace mxnet