// blob: 1f69cf6138f22c48e90050bfe13d566edd12a8b8
/*!
* Copyright (c) 2015 by Contributors
* \file loss_binary_op.cc
* \brief loss operators that take data and a label as inputs
*/
#include "./loss_binary_op-inl.h"
namespace mxnet {
namespace op {
// Registers the forward `softmax_cross_entropy` operator together with its
// CPU kernel. Per the description string, the operator calculates
// cross_entropy(data, one_hot(label)).
NNVM_REGISTER_OP(softmax_cross_entropy)
  .MXNET_DESCRIBE("Calculate cross_entropy(data, one_hot(label))")
  // Interface: two inputs, declared in the order (data, label), one output.
  .set_num_inputs(2)
  .set_num_outputs(1)
  .add_argument("data", "NDArray", "Input data")
  .add_argument("label", "NDArray", "Input label")
  // Shape and type inference are delegated to helpers defined elsewhere.
  .set_attr<nnvm::FInferShape>("FInferShape", SoftmaxCrossEntropyShape)
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
  // The operator asks for a single temp-space resource, independent of the
  // node attributes (the callback ignores its argument).
  .set_attr<FResourceRequest>(
    "FResourceRequest",
    [](const NodeAttrs&) -> std::vector<ResourceRequest> {
      return {ResourceRequest::kTempSpace};
    })
  // CPU implementation of the forward pass.
  .set_attr<FCompute>("FCompute<cpu>", SoftmaxCrossEntropyForward<cpu>)
  // Gradient is produced by the `_backward_softmax_cross_entropy` operator.
  .set_attr<nnvm::FGradient>(
    "FGradient",
    ElemwiseGradUseIn{"_backward_softmax_cross_entropy"});
// Registers `_backward_softmax_cross_entropy`, the operator named as the
// gradient of `softmax_cross_entropy`, together with its CPU kernel.
NNVM_REGISTER_OP(_backward_softmax_cross_entropy)
  // Flag this node as a backward operator.
  .set_attr<nnvm::TIsBackward>("TIsBackward", true)
  // Three inputs, two outputs.
  // NOTE(review): presumably the incoming output gradient plus the two
  // forward inputs, yielding one gradient per forward input -- confirm
  // against ElemwiseGradUseIn.
  .set_num_inputs(3)
  .set_num_outputs(2)
  // The operator asks for a single temp-space resource, independent of the
  // node attributes (the callback ignores its argument).
  .set_attr<FResourceRequest>(
    "FResourceRequest",
    [](const NodeAttrs&) -> std::vector<ResourceRequest> {
      return {ResourceRequest::kTempSpace};
    })
  // CPU implementation of the backward pass.
  .set_attr<FCompute>("FCompute<cpu>", SoftmaxCrossEntropyBackward<cpu>);
} // namespace op
} // namespace mxnet