| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file elemwise_binary_scalar_op.h |
| * \brief Function definition of elementwise binary scalar operators |
| */ |
| #ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_ |
| #define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_ |
| |
| #include <mxnet/operator_util.h> |
| #include <vector> |
| #include <utility> |
| #include "../mshadow_op.h" |
| #include "../elemwise_op_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| template<typename xpu, typename OP> |
| void BinaryScalarCompute(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| double alpha = nnvm::get<double>(attrs.parsed); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> lhs = inputs[0].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(out, req[0], F<OP>(lhs, scalar<DType>(DType(alpha)))); |
| }); |
| } |
| |
| template<typename xpu, typename OP> |
| void BinaryScalarBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| double alpha = nnvm::get<double>(attrs.parsed); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s); |
| Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(igrad, req[0], ograd*F<OP>(lhs, scalar<DType>(DType(alpha)))); |
| }); |
| } |
| |
/*!
 * \brief Register an operator that takes one tensor input and one scalar
 *        attribute and produces one output.
 *
 * The attribute parser converts the "scalar" entry of the attribute
 * dictionary with std::stod and stores the resulting double in
 * attrs->parsed. Shape and type inference use the one-input/one-output
 * elementwise helpers, and output 0 is declared as able to reuse the
 * storage of input 0.
 *
 * The expansion is an open registration chain, so callers can append
 * further .set_attr(...) / .describe(...) calls after the macro.
 *
 * \param name operator name passed to NNVM_REGISTER_OP
 */
#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name)                     \
  NNVM_REGISTER_OP(name)                                                \
  .set_num_inputs(1)                                                    \
  .set_num_outputs(1)                                                   \
  .set_attr_parser([](NodeAttrs* node_attrs) {                          \
      const std::string& scalar_text = node_attrs->dict["scalar"];      \
      node_attrs->parsed = std::stod(scalar_text);                      \
    })                                                                  \
  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)      \
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)         \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                     \
    [](const NodeAttrs&) {                                              \
      return std::vector<std::pair<int, int> >(1, std::make_pair(0, 0)); \
    })                                                                  \
  .add_argument("data", "NDArray-or-Symbol", "source input")            \
  .add_argument("scalar", "float", "scalar input")
| |
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_ |