| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file elemwise_binary_op.h |
 * \brief Function definitions of elementwise binary operators
| */ |
| #ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_ |
| #define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_ |
| |
| #include <mxnet/operator_util.h> |
| #include <vector> |
| #include <string> |
| #include <utility> |
| #include "../mshadow_op.h" |
| #include "../elemwise_op_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| template<typename xpu, typename OP> |
| void BinaryCompute(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> lhs = inputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> rhs = inputs[1].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(out, req[0], F<OP>(lhs, rhs)); |
| }); |
| } |
| |
| template<typename xpu, typename LOP, typename ROP> |
| void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(lgrad, req[0], F<LOP>(ograd)); |
| ASSIGN_DISPATCH(rgrad, req[1], F<ROP>(ograd)); |
| }); |
| } |
| |
| template<typename xpu, typename LOP, typename ROP> |
| void BinaryBackwardUseOut(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> out = inputs[1].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(lgrad, req[0], ograd*F<LOP>(out)); |
| ASSIGN_DISPATCH(rgrad, req[1], ograd*F<ROP>(out)); |
| }); |
| } |
| |
| template<typename xpu, typename LOP, typename ROP> |
| void BinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> rhs = inputs[2].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(lgrad, req[0], ograd*F<LOP>(lhs, rhs)); |
| ASSIGN_DISPATCH(rgrad, req[1], ograd*F<ROP>(lhs, rhs)); |
| }); |
| } |
| |
/*!
 * \brief Register the common skeleton of an elementwise binary operator.
 *
 * Expands to an NNVM_REGISTER_OP(name) chain with these settings:
 *  - two inputs, listed as "lhs" and "rhs", and one output;
 *  - shape / dtype inference via ElemwiseShape<2, 1> and ElemwiseType<2, 1>
 *    (presumably these require all operands to agree — confirm in
 *    elemwise_op_common.h);
 *  - FInplaceOption pairs {0, 0} and {1, 0}. These are read here as
 *    (input index, output index), so output 0 may reuse the storage of
 *    either input.
 *
 * The expansion sets no compute function and has no trailing semicolon.
 * Callers presumably chain further attributes (e.g. FCompute) onto it.
 *
 * Do not add // comments inside the macro body: with the trailing
 * backslashes they would swallow the following continuation lines.
 */
#define MXNET_OPERATOR_REGISTER_BINARY(name)                              \
  NNVM_REGISTER_OP(name)                                                  \
  .set_num_inputs(2)                                                      \
  .set_num_outputs(1)                                                     \
  .set_attr<nnvm::FListInputNames>("FListInputNames",                     \
    [](const NodeAttrs& attrs) {                                          \
      return std::vector<std::string>{"lhs", "rhs"};                      \
    })                                                                    \
  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)        \
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)           \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                       \
    [](const NodeAttrs& attrs){                                           \
      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};           \
    })                                                                    \
  .add_argument("lhs", "ndarray-or-symbol", "first input")                \
  .add_argument("rhs", "ndarray-or-symbol", "second input")
| |
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_ |