/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file elemwise_binary_op.h
* \brief Function definition of elementwise binary operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <vector>
#include <string>
#include <utility>
#include <typeinfo>
#include <algorithm>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "elemwise_unary_op.h"
#include "../../common/utils.h"
namespace mxnet {
namespace op {
/*! \brief Gathers binary operator functions into the ElemwiseBinaryOp class */
class ElemwiseBinaryOp : public OpBase {
public:
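/*! \brief Kernel for the backward pass of ops whose gradient needs only the output gradient:
* igrad[i] = OP(ograd[i])
*/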
template<typename OP, int Req>
struct BackwardUseNoneOp {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *igrad, const DType *ograd) {
KERNEL_ASSIGN(igrad[i], Req, OP::Map(ograd[i]));
}
};
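/*! \brief Kernel for the backward pass of ops whose gradient needs the forward inputs:
* igrad[i] = ograd[i] * OP(lhs[i], rhs[i])
*/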
template<typename OP, int Req>
struct BackwardUseInOp {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *igrad,
const DType *ograd, const DType *lhs, const DType *rhs) {
KERNEL_ASSIGN(igrad[i], Req, ograd[i] * OP::Map(lhs[i], rhs[i]));
}
};
/*! \brief For sparse, assume missing rvalue is 0 */
template<typename OP, int Req>
struct MissingRValueOp {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs) {
KERNEL_ASSIGN(out[i], Req, OP::Map(lhs[i], DType(0)));
}
};
/*! \brief For sparse, assume missing lvalue is 0 */
template<typename OP, int Req>
struct MissingLValueOp {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *rhs) {
KERNEL_ASSIGN(out[i], Req, OP::Map(DType(0), rhs[i]));
}
};
private:
/*!
* \brief CSR operation requires temp space
*/
enum ResourceRequestType {
kTempSpace
};
/*! \brief Fill contiguous dense output rows with value computed from 0 lhs and 0 rhs input */
template<typename xpu, typename DType, typename OP>
static inline size_t FillDense(mshadow::Stream<xpu> *s,
const size_t idx_l,
const size_t idx_r,
const OpReqType req,
mshadow::Tensor<xpu, 2, DType> *out,
const size_t iter_out) {
using namespace mshadow::expr;
const int index_out_min = static_cast<int>(std::min(idx_l, idx_r));
if (static_cast<size_t>(index_out_min) > iter_out) {
const size_t size = (*out)[iter_out].shape_.Size();
const DType zero_input_val = OP::Map(DType(0), DType(0));
#pragma omp parallel for
for (int i = static_cast<int>(iter_out); i < index_out_min; ++i) {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<SetToScalar<Req>, xpu>::Launch(s, size, (*out)[i].dptr_,
zero_input_val);
});
}
}
return index_out_min;
}
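/*! \brief Whether two NDArrays refer to the same engine variable (i.e. the same underlying array) */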
static inline bool IsSameArray(const NDArray& a1, const NDArray& a2) {
return a1.var() == a2.var();
}
/*! \brief Minimum of three */
static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
return a < b ? (a < c ? a : c) : (b < c ? b : c);
}
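/*! \brief Dense backward using only the output gradient: lgrad = LOP(ograd), rgrad = ROP(ograd).
* Identity gradients requested as kWriteInplace are skipped, since input and output already alias.
*/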
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const int size = static_cast<int>((outputs[0].Size() + DataType<DType>::kLanes - 1)
/ DataType<DType>::kLanes);
const DType *ograd_dptr = inputs[0].dptr<DType>();
if (std::is_same<LOP, mshadow_op::identity>::value && req[0] == kWriteInplace) {
CHECK_EQ(ograd_dptr, outputs[0].dptr<DType>());
} else if (req[0] != kNullOp) {
DType *lgrad_dptr = outputs[0].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
Kernel<BackwardUseNoneOp<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
});
}
if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == kWriteInplace) {
CHECK_EQ(ograd_dptr, outputs[1].dptr<DType>());
} else if (req[1] != kNullOp) {
DType *rgrad_dptr = outputs[1].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
Kernel<BackwardUseNoneOp<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
});
}
}
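/*! \brief Dense backward using the output gradient and both forward inputs:
* lgrad = ograd * LOP(lhs, rhs), rgrad = ograd * ROP(lhs, rhs)
*/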
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseIn_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
DCHECK_EQ(outputs.size(), 2U);
DCHECK_EQ(inputs.size(), 3U);
mxnet_op::Stream<xpu> *s = ctx.get_stream<xpu>();
const DType *ograd_dptr = inputs[0].dptr<DType>();
const DType *lhs_dptr = inputs[1].dptr<DType>();
const DType *rhs_dptr = inputs[2].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
const int size = static_cast<int>(
(outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * lgrad_dptr = outputs[0].dptr<DType>();
mxnet_op::Kernel<BackwardUseInOp<LOP, Req>, xpu>::Launch(
s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
const int size = static_cast<int>(
(outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * rgrad_dptr = outputs[1].dptr<DType>();
mxnet_op::Kernel<BackwardUseInOp<ROP, Req>, xpu>::Launch(
s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
}
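/*! \brief Sparse (row_sparse) backward using the forward inputs. For each requested gradient,
* computes LOP(lhs, rhs) (or ROP(lhs, rhs)) into the gradient output and then multiplies it
* in place by the incoming output gradient.
*/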
template<
typename xpu,
typename LOP,
typename ROP,
typename DType,
bool in0_ok_dense = false,
bool in1_ok_dense = false,
bool in2_ok_dense = false,
typename BackupCompute>
static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs,
BackupCompute backup_compute) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// lhs grad
if (req[0] != kNullOp) {
// RspRspOp can handle dense outputs so long as OP(0, 0) == 0
MSHADOW_IDX_TYPE_SWITCH(inputs[1].aux_type(rowsparse::kIdx), IType, {
RspRspOp<DType, IType, LOP>(
s, attrs, ctx, inputs[1], inputs[2], req[0], outputs[0],
false, false, false, false);
});
// lhs in-place
MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
RspRspOp<DType, IType, mshadow::op::mul>(
s, attrs, ctx, outputs[0], inputs[0], req[0], outputs[0],
false, false, true, false);
});
}
// rhs grad
if (req[1] != kNullOp) {
MSHADOW_IDX_TYPE_SWITCH(inputs[1].aux_type(rowsparse::kIdx), IType, {
RspRspOp<DType, IType, ROP>(
s, attrs, ctx, inputs[1], inputs[2], req[1], outputs[1],
false, false, false, false);
});
// rhs in-place
MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
RspRspOp<DType, IType, mshadow::op::mul>(
s, attrs, ctx, inputs[0], outputs[1], req[1], outputs[1],
false, false, true, false);
});
}
}
protected:
/*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename DType, typename IType, typename OP>
static void RspRspOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
bool lhs_may_be_dense,
bool rhs_may_be_dense,
bool allow_inplace,
bool scatter);
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename DType, typename IType, typename CType, typename OP>
static inline void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
public:
/*!
* \brief Rsp-op-Rsp operation which produces a dense result
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool SparseSparseWithDenseResult(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
/*!
* \brief Allow one of the inputs to be dense and still produce a sparse output
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool AllowLRDenseInputWithSparseOutputStorageType(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
/*!
* \brief Backward pass computing input gradient using forward inputs
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
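/*! \brief Dense forward: out[i] = OP(lhs[i], rhs[i]) */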
template<typename xpu, typename OP>
static void Compute(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
if (req[0] != kNullOp) {
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
});
});
}
}
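/*! \brief Same as Compute, but permits the vectorized half2 type for fp16 data */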
template<typename xpu, typename OP>
static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
if (req[0] != kNullOp) {
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
});
});
}
}
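/*! \brief Sparse forward for rsp,rsp -> rsp/dns and csr,csr -> csr */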
template<typename xpu, typename OP>
static void ComputeEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 2);
CHECK_EQ(outputs.size(), 1);
if (req[0] == kNullOp) return;
const auto lhs_stype = inputs[0].storage_type();
const auto out_stype = outputs[0].storage_type();
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if ((common::ContainsOnlyStorage(inputs, kRowSparseStorage))
&& (out_stype == kRowSparseStorage || out_stype == kDefaultStorage)) {
// rsp, rsp -> rsp
// rsp, rsp -> dns
const int rsp_input_idx = lhs_stype == kRowSparseStorage ? 0 : 1;
MSHADOW_IDX_TYPE_SWITCH(inputs[rsp_input_idx].aux_type(rowsparse::kIdx), IType, {
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
RspRspOp<DType, IType, OP>(
s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0], false, false, false, false);
});
});
} else if (common::ContainsOnlyStorage(inputs, kCSRStorage) && out_stype == kCSRStorage) {
// csr, csr -> csr
MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIndPtr), CType, {
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
CsrCsrOp<DType, IType, CType, OP>(
s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0]);
});
});
});
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
}
/*! \brief ComputeEx allowing dense lvalue and/or rvalue */
template<typename xpu, typename OP, bool lhs_may_be_dense, bool rhs_may_be_dense>
static void ComputeDnsLRValueEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2);
CHECK_EQ(outputs.size(), 1);
if (req[0] == kNullOp) return;
const auto lhs_stype = inputs[0].storage_type();
const auto rhs_stype = inputs[1].storage_type();
const auto out_stype = outputs[0].storage_type();
if ((out_stype == kRowSparseStorage || out_stype == kDefaultStorage) &&
((lhs_stype == kRowSparseStorage && rhs_stype == kRowSparseStorage) ||
(lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) &&
lhs_may_be_dense && rhs_may_be_dense) {
// rsp, rsp -> rsp
// rsp, rsp -> dns
// rsp, dns -> rsp
// dns, rsp -> rsp
// More than one dense input is not allowed (this is checked in RspRspOp):
// rsp, dns -> dns <-- NOT ALLOWED
// dns, rsp -> dns <-- NOT ALLOWED
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, {
RspRspOp<DType, IType, OP>(
s, attrs, ctx, inputs[0], inputs[1],
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
});
});
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
}
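/*! \brief Dense backward entry point for ops whose gradient needs only the output gradient */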
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNone(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
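/*! \brief Same as BackwardUseNone, but permits the vectorized half2 type for fp16 data */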
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNoneWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
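/*! \brief Sparse backward entry point for ops whose gradient needs only the output gradient */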
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNoneEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 1U); // output grad
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
const auto in_stype = inputs[0].storage_type();
const auto lhs_stype = outputs[0].storage_type();
const auto rhs_stype = outputs[1].storage_type();
// lhs grad
if (req[0] != kNullOp) {
if (in_stype == lhs_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
CHECK_EQ(outputs[0].storage_type(), in_stype);
// rsp -> rsp, _. The op must map a zero input to a zero output
DCHECK_LT(fabs(static_cast<float>(LOP::Map(0))), 1e-5f);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<LOP, Req>>(attrs, ctx, inputs,
req, {outputs[0]});
});
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
}
// rhs grad
if (req[1] != kNullOp) {
if (in_stype == rhs_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
CHECK_EQ(outputs[0].storage_type(), in_stype);
// rsp -> _, rsp. The op must map a zero input to a zero output
DCHECK_LT(fabs(static_cast<float>(ROP::Map(0))), 1e-5f);
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
UnaryOp::KernelComputeEx<xpu, BackwardUseNoneOp<ROP, Req>>(attrs, ctx, inputs,
req, {outputs[1]});
});
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
}
}
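/*! \brief Dense backward entry point for ops whose gradient needs the forward inputs */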
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseIn(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
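/*! \brief Same as BackwardUseIn, but permits the vectorized half2 type for fp16 data */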
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseInWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
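/*! \brief Sparse backward entry point (row_sparse inputs only) for ops whose gradient needs the forward inputs */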
template<
typename xpu, typename LOP, typename ROP,
bool in0_ok_dense = false, bool in1_ok_dense = false, bool in2_ok_dense = false>
static inline void BackwardUseInEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace common;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
const auto lhs_grad_stype = outputs[0].storage_type();
const auto rhs_grad_stype = outputs[1].storage_type();
if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
(lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
(rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
// rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
});
}
}
}; // class ElemwiseBinaryOp
/*! \brief Binary launch */
#define MXNET_OPERATOR_REGISTER_BINARY(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "first input") \
.add_argument("rhs", "NDArray-or-Symbol", "second input")
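/*! Typical usage of the macro above, as a sketch; the operator name "elemwise_mul_demo" is
* hypothetical and the kernel choice is only illustrative:
*
* MXNET_OPERATOR_REGISTER_BINARY(elemwise_mul_demo)
* .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow::op::mul>);
*/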
/*! \brief Binary launch, with FComputeEx for csr and rsp available */
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseStorageType<2, 1, true, true, true>) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>) \
.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};})
/*! \brief Binary launch, dense result
* FInferStorageType is set to ElemwiseBinaryOp::SparseSparseWithDenseResult,
* so sparse-sparse inputs are dispatched to a dense output.
*/
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseBinaryOp::SparseSparseWithDenseResult) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>)
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_