// blob: 5e577c6e3d4f256735ebf05a0dacf1a2f5f4f4d8 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file elemwise_binary_scalar_op.h
* \brief Function definition of elementwise binary scalar operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <utility>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
/*!
 * \brief Forward kernel for an elementwise "tensor OP scalar" operator.
 *
 * Reads the scalar operand from attrs.parsed (stored as a double), views
 * inputs[0] and outputs[0] as 1-D tensors of the output's element type, and
 * dispatches the expression OP(input, scalar) into the output through
 * ASSIGN_DISPATCH using req[0].
 *
 * \tparam xpu device type used for the stream and the tensors
 * \tparam OP binary functor wrapped by the mshadow F<> expression
 * \param attrs node attributes; attrs.parsed must hold a double
 * \param ctx operator context supplying the device stream
 * \param inputs inputs[0] is the tensor operand
 * \param req req[0] is the request type handed to ASSIGN_DISPATCH for outputs[0]
 * \param outputs outputs[0] receives the result; its type_flag_ selects DType
 */
template<typename xpu, typename OP>
void BinaryScalarCompute(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu>* stream = ctx.get_stream<xpu>();
  const double scalar_operand = nnvm::get<double>(attrs.parsed);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    typedef Tensor<xpu, 1, DType> FlatTensor;
    FlatTensor result = outputs[0].FlatTo1D<xpu, DType>(stream);
    FlatTensor data = inputs[0].FlatTo1D<xpu, DType>(stream);
    // Narrow the double attribute to the tensor's element type once.
    const DType rhs = DType(scalar_operand);
    // Keep the expression inline: it is built and consumed in one statement.
    ASSIGN_DISPATCH(result, req[0], F<OP>(data, scalar<DType>(rhs)));
  });
}
/*!
 * \brief Backward kernel for an elementwise "tensor OP scalar" operator.
 *
 * Dispatches inputs[0] * OP(inputs[1], scalar) into outputs[0] through
 * ASSIGN_DISPATCH using req[0]. The scalar is read from attrs.parsed (stored
 * as a double) and all blobs are viewed as 1-D tensors of the output's
 * element type.
 *
 * \tparam xpu device type used for the stream and the tensors
 * \tparam OP binary functor wrapped by the mshadow F<> expression; here it is
 *         multiplied with inputs[0], so it presumably computes the local
 *         derivative of the forward operator (confirm at the registration site)
 * \param attrs node attributes; attrs.parsed must hold a double
 * \param ctx operator context supplying the device stream
 * \param inputs inputs[0] is the incoming (output) gradient, inputs[1] is the
 *        tensor passed as the left operand of OP
 * \param req req[0] is the request type handed to ASSIGN_DISPATCH for outputs[0]
 * \param outputs outputs[0] receives the input gradient; its type_flag_
 *        selects DType
 */
template<typename xpu, typename OP>
void BinaryScalarBackward(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu>* stream = ctx.get_stream<xpu>();
  const double scalar_operand = nnvm::get<double>(attrs.parsed);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    typedef Tensor<xpu, 1, DType> FlatTensor;
    FlatTensor in_grad = outputs[0].FlatTo1D<xpu, DType>(stream);
    FlatTensor out_grad = inputs[0].FlatTo1D<xpu, DType>(stream);
    FlatTensor data = inputs[1].FlatTo1D<xpu, DType>(stream);
    // Narrow the double attribute to the tensor's element type once.
    const DType rhs = DType(scalar_operand);
    // Keep the expression inline: it is built and consumed in one statement.
    ASSIGN_DISPATCH(in_grad, req[0], out_grad * F<OP>(data, scalar<DType>(rhs)));
  });
}
/*!
 * \brief Registers a one-input, one-output "tensor OP scalar" operator.
 *
 * The attribute parser converts the "scalar" entry of the attribute dict to a
 * double with std::stod and stores it in attrs->parsed. Shape and type
 * inference use the elementwise 1-in/1-out helpers, and the in-place option
 * reports the pair {0, 0} (input 0 with output 0).
 *
 * The expansion ends without a semicolon so callers can chain further
 * attribute setters onto the registration.
 */
#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name)                       \
  NNVM_REGISTER_OP(name)                                                  \
  .set_num_inputs(1)                                                      \
  .set_num_outputs(1)                                                     \
  .set_attr_parser([](NodeAttrs* node_attrs) {                            \
      node_attrs->parsed = std::stod(node_attrs->dict["scalar"]);         \
    })                                                                    \
  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)        \
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)           \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                       \
    [](const NodeAttrs& node_attrs) {                                     \
      return std::vector<std::pair<int, int> >(                           \
          1, std::pair<int, int>(0, 0));                                  \
    })                                                                    \
  .add_argument("data", "NDArray-or-Symbol", "source input")              \
  .add_argument("scalar", "float", "scalar input")
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_