blob: ff63c0e00dcf330ad3ad2313a0b723d7a815319f [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file regression_output.cc
* \brief regression output operator
*/
#include "./regression_output-inl.h"
#include "./mshadow_op.h"
namespace mxnet {
namespace op {
template<>
Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type,
RegressionOutputParam param) {
switch (type) {
case reg_enum::kLinear:
return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow::op::minus>(param);
case reg_enum::kLogistic:
return new RegressionOutputOp<cpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
case reg_enum::kMAE:
return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
default:
LOG(FATAL) << "unknown activation type " << type;
}
return nullptr;
}
// DO_BIND_DISPATCH comes from operator_common.h
template<reg_enum::RegressionOutputType type>
Operator *RegressionOutputProp<type>::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateRegressionOutputOp, type, param_);
}
DMLC_REGISTER_PARAMETER(RegressionOutputParam);
MXNET_REGISTER_OP_PROPERTY(LinearRegressionOutput, RegressionOutputProp<reg_enum::kLinear>)
.describe("Use linear regression for final output, this is used on final output of a net.")
.add_argument("data", "Symbol", "Input data to function.")
.add_argument("label", "Symbol", "Input label to function.")
.add_arguments(RegressionOutputParam::__FIELDS__());
MXNET_REGISTER_OP_PROPERTY(MAERegressionOutput, RegressionOutputProp<reg_enum::kMAE>)
.describe("Use mean absolute error regression for final output, "
"this is used on final output of a net.")
.add_argument("data", "Symbol", "Input data to function.")
.add_argument("label", "Symbol", "Input label to function.")
.add_arguments(RegressionOutputParam::__FIELDS__());
MXNET_REGISTER_OP_PROPERTY(LogisticRegressionOutput, RegressionOutputProp<reg_enum::kLogistic>)
.describe("Use Logistic regression for final output, this is used on final output of a net.\n"
"Logistic regression is suitable for binary classification "
"or probability prediction tasks.")
.add_argument("data", "Symbol", "Input data to function.")
.add_argument("label", "Symbol", "Input label to function.")
.add_arguments(RegressionOutputParam::__FIELDS__());
} // namespace op
} // namespace mxnet