/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file elemwise_binary_op.h
* \brief Function definition of elementwise binary operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <vector>
#include <string>
#include <utility>
#include <typeinfo>
#include <algorithm>
#include <cmath>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "../../engine/openmp.h"
#include "elemwise_unary_op.h"
#include "../../common/utils.h"
#include "./init_op.h"
namespace mxnet {
namespace op {
/*! \brief Binary operator functions gathered into the ElemwiseBinaryOp class */
class ElemwiseBinaryOp : public OpBase {
public:
/*! \brief For sparse, assume missing rvalue is 0 */
template<typename OP, int Req>
struct MissingRValueOp {
typedef OP Operation;
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs) {
KERNEL_ASSIGN(out[i], Req, OP::Map(lhs[i], DType(0)));
}
};
/*! \brief For sparse, assume missing lvalue is 0 */
template<typename OP, int Req>
struct MissingLValueOp {
typedef OP Operation;
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *rhs) {
KERNEL_ASSIGN(out[i], Req, OP::Map(DType(0), rhs[i]));
}
};
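// Illustrative sketch (not part of the original header): these missing-value
// functors are launched element-wise through mxnet_op::Kernel on rows of a
// sparse operand that have no matching row on the other side, e.g.
//   mxnet_op::Kernel<MissingRValueOp<OP, kWriteTo>, cpu>::Launch(
//       s, row_length, out_row_ptr, lhs_row_ptr);
// assigns OP::Map(lhs[i], 0) to out[i] for every element of that row
// ('row_length', 'out_row_ptr' and 'lhs_row_ptr' are assumed names here).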
private:
/*!
* \brief CSR operation requires temp space
*/
enum ResourceRequestType {
kTempSpace
};
/*!
* \brief Fill contiguous dense output rows with the value computed from a zero lhs and a zero rhs input.
* CPU-only version.
*/
template<typename DType, typename OP, typename xpu>
static inline size_t FillDense(mshadow::Stream<xpu> *s,
const size_t idx_l,
const size_t idx_r,
const OpReqType req,
mshadow::Tensor<xpu, 2, DType> *out,
const size_t iter_out) {
const int index_out_min = static_cast<int>(std::min(idx_l, idx_r));
if (static_cast<size_t>(index_out_min) > iter_out) {
const DType zero_input_val = OP::Map(DType(0), DType(0));
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int i = static_cast<int>(iter_out); i < index_out_min; ++i) {
Fill<false>(s, (*out)[i], req, zero_input_val);
}
}
return static_cast<size_t>(index_out_min); // MSVC wants OMP loops to always use 'int'
}
static inline bool IsSameArray(const NDArray& a1, const NDArray& a2) {
return a1.var() == a2.var();
}
/*! \brief Minimum of three */
static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) {
return a < b ? (a < c ? a : c) : (b < c ? b : c);
}
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const int size = static_cast<int>((outputs[0].Size() + DataType<DType>::kLanes - 1)
/ DataType<DType>::kLanes);
const DType *ograd_dptr = inputs[0].dptr<DType>();
if (std::is_same<LOP, mshadow_op::identity>::value && req[0] == kWriteInplace) {
CHECK_EQ(ograd_dptr, outputs[0].dptr<DType>());
} else if (req[0] != kNullOp) {
DType *lgrad_dptr = outputs[0].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr);
});
}
if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == kWriteInplace) {
CHECK_EQ(ograd_dptr, outputs[1].dptr<DType>());
} else if (req[1] != kNullOp) {
DType *rgrad_dptr = outputs[1].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr);
});
}
}
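// Example (hedged): for an operator whose gradient passes the output gradient
// straight through to both inputs (e.g. element-wise add), this would be
// instantiated with LOP = ROP = mshadow_op::identity; each requested output
// then receives ograd, or shares ograd's buffer when the req is kWriteInplace.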
template<typename xpu, typename LOP, typename ROP, typename DType>
static void BackwardUseIn_(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
DCHECK_EQ(outputs.size(), 2U);
DCHECK_EQ(inputs.size(), 3U);
mxnet_op::Stream<xpu> *s = ctx.get_stream<xpu>();
const DType *ograd_dptr = inputs[0].dptr<DType>();
const DType *lhs_dptr = inputs[1].dptr<DType>();
const DType *rhs_dptr = inputs[2].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
const int size = static_cast<int>(
(outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * lgrad_dptr = outputs[0].dptr<DType>();
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<LOP>, Req>, xpu>::Launch(
s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
const int size = static_cast<int>(
(outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
/ mxnet_op::DataType<DType>::kLanes);
DType * rgrad_dptr = outputs[1].dptr<DType>();
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<ROP>, Req>, xpu>::Launch(
s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
}
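// Example (hedged): for element-wise multiply, d(lhs) = ograd * rhs and
// d(rhs) = ograd * lhs, so this would be instantiated with
// LOP = mshadow_op::right and ROP = mshadow_op::left;
// backward_grad_tuned<LOP> then computes ograd * LOP::Map(lhs, rhs) per element.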
template<
typename xpu,
typename LOP,
typename ROP,
bool in0_ok_dense = false,
bool in1_ok_dense = false,
bool in2_ok_dense = false,
typename BackupCompute>
static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs,
BackupCompute backup_compute) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// lhs grad
if (req[0] != kNullOp) {
// RspRspOp can handle dense outputs so long as OP(0, 0) == 0
RspRspOp<LOP>(
s, attrs, ctx, inputs[1], inputs[2], req[0], outputs[0],
false, false, false, false);
// lhs in-place
RspRspOp<op::mshadow_op::mul>(
s, attrs, ctx, outputs[0], inputs[0], req[0], outputs[0],
false, false, true, false);
}
// rhs grad
if (req[1] != kNullOp) {
RspRspOp<ROP>(
s, attrs, ctx, inputs[1], inputs[2], req[1], outputs[1],
false, false, false, false);
// rhs in-place
RspRspOp<op::mshadow_op::mul>(
s, attrs, ctx, inputs[0], outputs[1], req[1], outputs[1],
false, false, true, false);
}
}
template<typename xpu, typename LOP, typename ROP>
static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
std::is_same<mshadow_op::left, ROP>::value;
CHECK(supported_ops)
<< "Only backward for mul is supported (LOP should be right, ROP should be left)";
const NDArray& out_grad = inputs[0];
const NDArray& lhs_in = inputs[1];
const NDArray& rhs_in = inputs[2];
const NDArray& lhs_grad = outputs[0];
const NDArray& rhs_grad = outputs[1];
const bool reverse = (outputs[0].storage_type() == kCSRStorage);
if (reverse) {
DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0], lhs_grad, false);
Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), lhs_in.data()}, {req[1]},
{rhs_grad.data()});
} else {
DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1], rhs_grad, false);
Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), rhs_in.data()}, {req[0]},
{lhs_grad.data()});
}
}
public:
/*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename OP>
static void RspRspOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
bool lhs_may_be_dense,
bool rhs_may_be_dense,
bool allow_inplace,
bool scatter);
/*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename OP>
static void RspRspOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
bool lhs_may_be_dense,
bool rhs_may_be_dense,
bool allow_inplace,
bool scatter);
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
/*! \brief DNS -op- CSR binary operator (dense result) for non-canonical NDArray */
template<typename OP>
static void DnsCsrDnsOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
/*! \brief DNS -op- CSR binary operator (dense result) for non-canonical NDArray */
template<typename OP>
static void DnsCsrDnsOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
/*! \brief DNS -op- CSR binary operator (CSR result) for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
/*! \brief DNS -op- RSP binary operator (dense result) for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
public:
/*!
* \brief Rsp-op-Rsp operation which produces a dense result
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool SparseSparseWithDenseResult(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
/*!
* \brief Allow one of the binary inputs to be dense and still produce a sparse output.
* Typically used for sparse * dense = sparse.
* Note: for csr storage, it dispatches to the fallback path except for the csr, csr -> csr case
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool PreferSparseStorageType(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace common;
CHECK_EQ(in_attrs->size(), 2U) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), 1U) << " in operator " << attrs.name;
const auto& lhs_stype = in_attrs->at(0);
const auto& rhs_stype = in_attrs->at(1);
auto& out_stype = out_attrs->at(0);
bool dispatched = false;
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
DispatchMode::kFComputeEx;
if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
// dns, dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) {
// rsp, rsp -> rsp
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched && ContainsOnlyStorage(*in_attrs, kCSRStorage)) {
// csr, csr -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched &&
((lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage))) {
// rsp, dns -> rsp
// dns, rsp -> rsp
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched &&
((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
// csr, dns -> csr
// dns, csr -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
/*!
* \brief Allow one of the inputs to be dense and produce a dense output;
* row-sparse (rsp) inputs are only supported when both inputs are rsp.
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
template<bool cpu_only, bool rsp, bool csr>
static bool PreferDenseStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace common;
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
const auto lhs_stype = (*in_attrs)[0];
const auto rhs_stype = (*in_attrs)[1];
bool dispatched = false;
const bool invalid_ctx = cpu_only && dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
DispatchMode::kFComputeEx;
if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
// dns, dns ... -> dns
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && rsp && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) {
// rsp, rsp, ... -> rsp
dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && csr && ContainsOnlyStorage(*in_attrs, kCSRStorage)) {
// csr, csr, ... -> csr
dispatched = storage_type_assign(out_attrs, kCSRStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched && ((lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage) ||
(lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage))) {
// dense, csr -> dense / csr, dense -> dense
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && ((lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage) ||
(lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage))) {
// dense, rsp -> dense / rsp, dense -> dense
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
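// Note (illustrative): the storage-type inference functions above are wired in
// as an operator's FInferStorageType attribute, e.g.
//   .set_attr<FInferStorageType>("FInferStorageType",
//                                ElemwiseBinaryOp::PreferDenseStorageType<true, true, true>)
// which is what the registration macros at the bottom of this file do.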
/*!
* \brief Backward pass computing input gradient using forward inputs
* \param attrs Attributes
* \param dev_mask Device mask
* \param dispatch_mode Dispatch Mode
* \param in_attrs Input storage attributes
* \param out_attrs Output storage attributes
* \return true if handled
*/
static bool BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
template<typename xpu, typename OP>
static void Compute(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
if (req[0] != kNullOp) {
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
}
});
});
}
}
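// Note (illustrative): Compute is the dense FCompute kernel of an operator and
// is typically registered through the macros at the bottom of this file, e.g.
//   .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>)
// The element count is rounded up in units of DataType<DType>::kLanes so that
// packed types (such as half2) still cover the whole tensor.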
template<typename xpu, typename OP>
static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
if (req[0] != kNullOp) {
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+ DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
if (size != 0) {
Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
}
});
});
}
}
template<typename xpu, typename OP>
static void ComputeEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace common;
CHECK_EQ(inputs.size(), 2);
CHECK_EQ(outputs.size(), 1);
if (req[0] == kNullOp) return;
const auto lhs_stype = inputs[0].storage_type();
const auto rhs_stype = inputs[1].storage_type();
const auto out_stype = outputs[0].storage_type();
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if ((ContainsOnlyStorage(inputs, kRowSparseStorage)) &&
(out_stype == kRowSparseStorage || out_stype == kDefaultStorage)) {
// rsp, rsp -> rsp
// rsp, rsp -> dns
RspRspOp<OP>(
s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0], false, false, false, false);
} else if (ContainsOnlyStorage(inputs, kCSRStorage) && out_stype == kCSRStorage) {
// csr, csr -> csr
CsrCsrOp<OP>(s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kDefaultStorage) {
const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kCSRStorage);
DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else if (((lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) &&
out_stype == kDefaultStorage) {
const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kRowSparseStorage);
const NDArray& rsp = (reverse)? inputs[0] : inputs[1];
DnsRspDnsOp<xpu, OP>(s, attrs, ctx, dns, rsp, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
/*! \brief ComputeEx allowing dense lvalue and/or rvalue */
template<typename xpu, typename OP, bool lhs_may_be_dense, bool rhs_may_be_dense>
static void ComputeDnsLRValueEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2);
CHECK_EQ(outputs.size(), 1);
if (req[0] == kNullOp) return;
const auto lhs_stype = inputs[0].storage_type();
const auto rhs_stype = inputs[1].storage_type();
const auto out_stype = outputs[0].storage_type();
if ((out_stype == kRowSparseStorage || out_stype == kDefaultStorage) &&
((lhs_stype == kRowSparseStorage && rhs_stype == kRowSparseStorage) ||
(lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) &&
lhs_may_be_dense && rhs_may_be_dense) {
// rsp, rsp -> rsp
// rsp, rsp -> dns
// rsp, dns -> rsp
// dns, rsp -> rsp
// More than one dense input is not allowed (this will be checked in RspRspOp):
// rsp, dns -> dns <-- NOT ALLOWED
// dns, rsp -> dns <-- NOT ALLOWED
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
RspRspOp<OP>(
s, attrs, ctx, inputs[0], inputs[1],
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kCSRStorage) {
const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kCSRStorage);
DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNone(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNoneWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseNoneEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 1U); // output grad
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
const auto in_stype = inputs[0].storage_type();
const auto lhs_stype = outputs[0].storage_type();
const auto rhs_stype = outputs[1].storage_type();
// lhs grad
if (req[0] != kNullOp) {
if (in_stype == lhs_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
CHECK_EQ(outputs[0].storage_type(), in_stype);
// rsp -> rsp, _. The op must map a zero input to a zero output
DCHECK_LT(std::fabs(static_cast<float>(LOP::Map(0))), 1e-5f);
UnaryOp::ComputeEx<xpu, LOP>(attrs, ctx, inputs, req, {outputs[0]});
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
// rhs grad
if (req[1] != kNullOp) {
if (in_stype == rhs_stype && (in_stype == kRowSparseStorage || in_stype == kCSRStorage)) {
CHECK_EQ(outputs[1].storage_type(), in_stype);
// rsp -> _, rsp. The op must map a zero input to a zero output
DCHECK_LT(std::fabs(static_cast<float>(ROP::Map(0))), 1e-5f);
UnaryOp::ComputeEx<xpu, ROP>(attrs, ctx, inputs, req, {outputs[1]});
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
}
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseIn(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
template<typename xpu, typename LOP, typename ROP>
static inline void BackwardUseInWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
});
}
template<
typename xpu, typename LOP, typename ROP,
bool in0_ok_dense = false, bool in1_ok_dense = false, bool in2_ok_dense = false>
static inline void BackwardUseInEx(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace common;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
const auto out_grad_stype = inputs[0].storage_type();
const auto lhs_grad_stype = outputs[0].storage_type();
const auto rhs_grad_stype = outputs[1].storage_type();
if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
(lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
(rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
// rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
}
if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage) ||
(lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage)) &&
out_grad_stype == kDefaultStorage) {
// dns, csr, dns -> [csr, dns] / csr, dns, dns -> [dns, csr]
DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}
}
}; // class ElemwiseBinaryOp
/*! \brief Binary launch */
#define MXNET_OPERATOR_REGISTER_BINARY(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "first input") \
.add_argument("rhs", "NDArray-or-Symbol", "second input")
/*! \brief Binary launch, with FComputeEx for csr and rsp available */
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseStorageType<2, 1, true, true, true>) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>) \
.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};})
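// Illustrative usage (hypothetical operator name; nothing is registered here):
//   MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(my_elemwise_op, op::mshadow_op::plus)
//   .describe("Element-wise op with CPU FComputeEx support for csr and rsp inputs.");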
/*! \brief Binary launch, with FComputeEx for csr and rsp available.
* When the inputs contain both sparse and dense storage, a sparse output is preferred. */
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_PS(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseBinaryOp::PreferSparseStorageType) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>) \
.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};})
/*! \brief Binary launch, dense result
* FInferStorageType is set to ElemwiseBinaryOp::SparseSparseWithDenseResult,
* so sparse, sparse inputs infer a dense output.
*/
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseBinaryOp::SparseSparseWithDenseResult) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>)
/*! \brief Binary launch, with FComputeEx for prefer dense */
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_PD(__name$, __kernel$) \
MXNET_OPERATOR_REGISTER_BINARY(__name$) \
.set_attr<FInferStorageType>("FInferStorageType", \
ElemwiseBinaryOp::PreferDenseStorageType<true, true, true>) \
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>) \
.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};})
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_