/*!
* Copyright (c) 2016 by Contributors
* \file elemwise_binary_op.h
 * \brief Function definition of elementwise binary operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
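
/*!
 * \brief Forward kernel for an elementwise binary operator.
 *        Flattens lhs and rhs to 1-D and assigns OP(lhs, rhs) to the output
 *        according to the request type in req[0].
 */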
template<typename xpu, typename OP>
void BinaryCompute(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> lhs = inputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> rhs = inputs[1].FlatTo1D<xpu, DType>(s);
    ASSIGN_DISPATCH(out, req[0], F<OP>(lhs, rhs));
  });
}

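/*!
 * \brief Backward kernel for operators whose gradients need neither the inputs
 *        nor the output: lgrad = LOP(ograd) and rgrad = ROP(ograd).
 */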
template<typename xpu, typename LOP, typename ROP>
void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
    ASSIGN_DISPATCH(lgrad, req[0], F<LOP>(ograd));
    ASSIGN_DISPATCH(rgrad, req[1], F<ROP>(ograd));
  });
}

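/*!
 * \brief Backward kernel for operators whose gradients depend on the forward
 *        output: lgrad = ograd * LOP(out) and rgrad = ograd * ROP(out).
 */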
template<typename xpu, typename LOP, typename ROP>
void BinaryBackwardUseOut(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> out = inputs[1].FlatTo1D<xpu, DType>(s);
    ASSIGN_DISPATCH(lgrad, req[0], ograd * F<LOP>(out));
    ASSIGN_DISPATCH(rgrad, req[1], ograd * F<ROP>(out));
  });
}

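/*!
 * \brief Backward kernel for operators whose gradients depend on the forward
 *        inputs: lgrad = ograd * LOP(lhs, rhs) and rgrad = ograd * ROP(lhs, rhs).
 */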
template<typename xpu, typename LOP, typename ROP>
void BinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> lgrad = outputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> rgrad = outputs[1].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> ograd = inputs[0].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> lhs = inputs[1].FlatTo1D<xpu, DType>(s);
    Tensor<xpu, 1, DType> rhs = inputs[2].FlatTo1D<xpu, DType>(s);
    ASSIGN_DISPATCH(lgrad, req[0], ograd * F<LOP>(lhs, rhs));
    ASSIGN_DISPATCH(rgrad, req[1], ograd * F<ROP>(lhs, rhs));
  });
}

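/*!
 * \brief Registers the attributes shared by elementwise binary operators:
 *        two inputs named "lhs" and "rhs", one output, elementwise shape and
 *        type inference, and the option to write the output in place over
 *        either input.
 */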
#define MXNET_OPERATOR_REGISTER_BINARY(name) \
  NNVM_REGISTER_OP(name) \
  .set_num_inputs(2) \
  .set_num_outputs(1) \
  .set_attr<nnvm::FListInputNames>("FListInputNames", \
    [](const NodeAttrs& attrs) { \
      return std::vector<std::string>{"lhs", "rhs"}; \
    }) \
  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
    [](const NodeAttrs& attrs) { \
      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
    }) \
  .add_argument("lhs", "ndarray-or-symbol", "first input") \
  .add_argument("rhs", "ndarray-or-symbol", "second input")
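
// Example usage (a sketch, not part of this header): a concrete operator would
// combine the macro above with one of the compute kernels in its .cc file. The
// operator name "_example_add" is hypothetical; FCompute and mshadow::op::plus
// are the usual MXNet/mshadow names assumed to be available at the point of
// registration.
//
//   MXNET_OPERATOR_REGISTER_BINARY(_example_add)
//   .set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, mshadow::op::plus>);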
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_