/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file elemwise_binary_broadcast_op.h
 * \brief Function definitions of elementwise binary broadcast operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./elemwise_binary_op.h"
#include "../operator_common.h"
#include "broadcast_reduce-inl.h"
namespace mxnet {
namespace op {
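/*!
 * \brief Infer the output shape of a broadcasting binary op following NumPy rules:
 *        shapes are aligned from the trailing dimension and each pair of dims must
 *        either match or one of them must be 1, e.g. (2, 1, 4) and (3, 4) broadcast
 *        to (2, 3, 4). Unknown dims are left at -1. Returns true only once the
 *        input and output shapes are fully known.
 */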
static inline bool BinaryBroadcastShapeCommon(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
mxnet::TShape& lhs = (*in_attrs)[0];
mxnet::TShape& rhs = (*in_attrs)[1];
  // Avoid premature shape inference before the ranks of both inputs are known.
if (!mxnet::ndim_is_known(lhs) || !mxnet::ndim_is_known(rhs))
return false;
if (lhs == rhs) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, lhs);
return shape_is_known(lhs);
}
mxnet::TShape out(std::max(lhs.ndim(), rhs.ndim()), -1);
const int bl = out.ndim() - lhs.ndim();
const int br = out.ndim() - rhs.ndim();
for (int i = 0; i < out.ndim(); ++i) {
dim_t l = 1, r = 1;
if (i >= bl)
l = lhs[i - bl];
if (i >= br)
r = rhs[i - br];
if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r))
continue;
if (l != r) {
// Make it compatible with NumPy.
// For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can broadcast to (2, 0, 3).
CHECK(l == 1 || r == 1) << "operands could not be broadcast together with shapes " << lhs
<< " " << rhs;
out[i] = (l == 1 ? r : l);
} else {
out[i] = l;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return shape_is_known(lhs) && shape_is_known(rhs) && shape_is_known(out);
}
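/*!
 * \brief FInferShape for binary broadcast ops: expects exactly two inputs and one
 *        output and delegates to BinaryBroadcastShapeCommon.
 */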
inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
return BinaryBroadcastShapeCommon(attrs, in_attrs, out_attrs);
}
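/*!
 * \brief Storage-type inference for broadcast_mul/div style ops:
 *        (dense, dense) -> dense (oneDNN FComputeEx on CPU when available),
 *        (csr, dense) -> csr via FComputeEx, anything else falls back.
 */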
inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const int lhs_stype = in_attrs->at(0);
const int rhs_stype = in_attrs->at(1);
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
#if MXNET_USE_ONEDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
dispatched =
storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
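/*!
 * \brief Storage-type inference for broadcast_add/sub style ops:
 *        (dense, dense) -> dense, while (csr, dense) and (dense, csr) produce a
 *        dense output through FComputeEx; anything else falls back.
 */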
inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const int lhs_stype = in_attrs->at(0);
const int rhs_stype = in_attrs->at(1);
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
#if MXNET_USE_ONEDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
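/*!
 * \brief Dispatch a runtime ndim to a compile-time NDim of 2, 4 or
 *        broadcast::MAX_DIM and run the given code block with it.
 */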
#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) \
if (ndim <= 2) { \
const int NDim = 2; \
{ __VA_ARGS__ } \
} else if (ndim <= 4) { \
const int NDim = 4; \
{ __VA_ARGS__ } \
} else if (ndim <= broadcast::MAX_DIM) { \
const int NDim = broadcast::MAX_DIM; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "NDim too large "; \
}
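/*!
 * \brief Collapse adjacent dimensions that share the same broadcasting pattern so
 *        the kernels only loop over the minimal number of axes, e.g.
 *        (2, 3, 4) op (1, 3, 4) is compacted to (2, 12) op (1, 12).
 *        Returns 0 when lshape == rshape (pure element-wise case), otherwise the
 *        number of compacted dimensions; the compacted shapes are padded with 1s
 *        up to the next supported NDim (2, 4 or broadcast::MAX_DIM).
 */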
inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape,
const mxnet::TShape& rshape,
const mxnet::TShape& oshape,
mxnet::TShape* new_lshape,
mxnet::TShape* new_rshape,
mxnet::TShape* new_oshape) {
if (lshape == rshape)
return 0;
const int odim = std::max(oshape.ndim(), broadcast::MAX_DIM);
*new_lshape = mxnet::TShape(odim, 1);
*new_rshape = mxnet::TShape(odim, 1);
*new_oshape = mxnet::TShape(odim, 1);
int bl = oshape.ndim() - lshape.ndim();
int br = oshape.ndim() - rshape.ndim();
int j = 0;
index_t lprod = 1, rprod = 1, oprod = 1;
for (int i = 0; i < oshape.ndim(); ++i) {
index_t l = 1, r = 1, o = oshape[i];
if (i >= bl)
l = lshape[i - bl];
if (i >= br)
r = rshape[i - br];
if ((lprod != rprod || l != r) && lprod * l > 1 && rprod * r > 1) {
(*new_lshape)[j] = lprod;
(*new_rshape)[j] = rprod;
(*new_oshape)[j] = oprod;
lprod = rprod = oprod = 1;
++j;
}
lprod *= l;
rprod *= r;
oprod *= o;
}
if (lprod > 1 || rprod > 1) {
(*new_lshape)[j] = lprod;
(*new_rshape)[j] = rprod;
(*new_oshape)[j] = oprod;
++j;
}
if (j <= broadcast::MAX_DIM) {
BROADCAST_NDIM_SWITCH(j, NDim, {
new_lshape->assign(new_lshape->begin(), new_lshape->begin() + NDim);
new_rshape->assign(new_rshape->begin(), new_rshape->begin() + NDim);
new_oshape->assign(new_oshape->begin(), new_oshape->begin() + NDim);
});
} else {
LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape;
}
return j;
}
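/*!
 * \brief Forward computation restricted to integer dtypes (boolean outputs are
 *        rejected); uses the element-wise kernel when no broadcasting is needed.
 */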
template <typename xpu, typename OP>
void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U)
return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(
inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::ComputeInt<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp)
return;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (outputs[0].type_flag_ == mshadow::kBool) {
LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
}
MXNET_INT_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(
s,
new_oshape.Size(),
req[0],
lstride,
rstride,
oshape,
inputs[0].dptr<DType>(),
inputs[1].dptr<DType>(),
outputs[0].dptr<DType>());
});
});
}
}
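/*!
 * \brief Forward computation for integer and boolean dtypes; uses the element-wise
 *        kernel when no broadcasting is needed.
 */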
template <typename xpu, typename OP>
void BinaryBroadcastIntComputeWithBool(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U)
return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(
inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::ComputeIntWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp)
return;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(
s,
new_oshape.Size(),
req[0],
lstride,
rstride,
oshape,
inputs[0].dptr<DType>(),
inputs[1].dptr<DType>(),
outputs[0].dptr<DType>());
});
});
}
}
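/*!
 * \brief Generic dense forward computation: falls back to the element-wise kernel
 *        when the compacted shapes match, otherwise launches the broadcast kernel
 *        over the compacted shapes. Boolean outputs are rejected.
 */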
template <typename xpu, typename OP>
void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U)
return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(
inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp)
return;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (outputs[0].type_flag_ == mshadow::kBool) {
LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
}
MSHADOW_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
broadcast::BinaryBroadcastComputeImpl<NDim, DType, OP>(s,
req[0],
inputs[0].reshape(new_lshape),
inputs[1].reshape(new_rshape),
outputs[0].reshape(new_oshape));
});
});
}
}
#if MXNET_USE_CUDA
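// Runtime-compiled (RTC) GPU variants of the forward and backward passes: the
// operator bodies are selected by name (OP / LOP / ROP) at run time; only the
// declarations live here, the definitions are in the corresponding .cu source.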
struct BinaryBroadcastRTCCompute {
std::string OP;
void operator()(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);
};
struct BinaryBroadcastRTCBackwardUseNone {
std::string LOP;
std::string ROP;
void operator()(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);
};
struct BinaryBroadcastRTCBackwardUseIn {
std::string LOP;
std::string ROP;
void operator()(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);
};
#endif // MXNET_USE_CUDA
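/*!
 * \brief Same as BinaryBroadcastCompute but also accepts boolean inputs/outputs.
 */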
template <typename xpu, typename OP>
void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U)
return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(
inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::ComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp)
return;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(
s,
new_oshape.Size(),
req[0],
lstride,
rstride,
oshape,
inputs[0].dptr<DType>(),
inputs[1].dptr<DType>(),
outputs[0].dptr<DType>());
});
});
}
}
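/*!
 * \brief Forward computation for comparison/logical ops: the two inputs may have
 *        different dtypes and the output is boolean.
 */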
template <typename xpu, typename OP>
void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (outputs[0].shape_.Size() == 0U)
return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];
int ndim = BinaryBroadcastShapeCompact(
lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
if (req[0] == kNullOp)
return;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(lhs.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(rhs.type_flag_, EType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(
s,
new_oshape.Size(),
req[0],
lstride,
rstride,
oshape,
lhs.dptr<DType>(),
rhs.dptr<EType>(),
out.dptr<bool>());
});
});
});
}
}
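/*!
 * \brief broadcast(csr, dense) with a csr output. The dense operand may be a
 *        scalar, a row vector or a column vector; the result reuses the sparsity
 *        pattern of the csr input (indices and indptr are copied unchanged), and an
 *        uninitialized csr input yields an all-zero csr output. kAddTo and
 *        kWriteInplace are not supported.
 */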
template <typename xpu, typename OP>
void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
const NDArray& csr,
const NDArray& dns,
const OpReqType req,
const NDArray& output) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK(req != kAddTo && req != kWriteInplace);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  // A 1-D dense operand is treated as a row vector; a 2-D operand is treated as
  // a column vector when its leading dimension matches that of the csr matrix.
  bool col_vec;
  if (dns.shape().ndim() == 1) {
    col_vec = false;
  } else {
    col_vec = (dns.shape()[0] == csr.shape()[0]);
  }
if (csr.storage_initialized()) {
const nnvm::dim_t nnz = csr.storage_shape()[0];
const nnvm::dim_t num_rows = output.shape()[0];
output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
// broadcast_mul/div between csr and a scalar case
if ((dns.shape().ndim() == 2 && dns.shape()[0] == 1 && dns.shape()[1] == 1) ||
(dns.shape().ndim() == 1 && dns.shape()[0] == 1)) {
Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, false>, xpu>::Launch(
s,
nnz,
csr.data().dptr<DType>(),
dns.data().dptr<DType>(),
output.data().dptr<DType>(),
nnz);
} else {
// broadcast_mul/div between csr and column vector
if (col_vec) {
Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, true>, xpu>::Launch(
s,
num_rows,
csr.data().dptr<DType>(),
csr.aux_data(kIdx).dptr<CType>(),
csr.aux_data(kIndPtr).dptr<RType>(),
dns.data().dptr<DType>(),
output.data().dptr<DType>());
// broadcast_mul/div between csr and row vector
} else {
Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, false>, xpu>::Launch(
s,
num_rows,
csr.data().dptr<DType>(),
csr.aux_data(kIdx).dptr<CType>(),
csr.aux_data(kIndPtr).dptr<RType>(),
dns.data().dptr<DType>(),
output.data().dptr<DType>());
}
}
Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
csr.aux_data(kIdx).FlatTo1D<xpu, CType>(),
s);
Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
s);
});
});
});
});
  } else {
    // The input csr matrix is empty: fill the output with zeros and return.
    FillZerosCsrImpl(s, output);
    return;
  }
}
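/*!
 * \brief broadcast add/sub between a csr and a dense operand with a dense output:
 *        the dense operand is first broadcast against an implicit zero matrix to
 *        fill the output, then the csr entries are applied on top. Only kWriteTo
 *        and plus/minus are supported; `reverse` means the dense operand is the lhs
 *        of the original op (e.g. dense - csr).
 */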
template <typename xpu, typename OP>
void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx,
const NDArray& csr,
const NDArray& dns,
const OpReqType req,
const NDArray& output,
const mxnet::TShape& new_csrshape,
const mxnet::TShape& new_dnsshape,
const mxnet::TShape& new_oshape,
const int ndim,
const bool reverse) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK(req == kWriteTo) << "Only kWriteTo supported for broadcast(csr, dns) = dns";
const bool legal_op =
std::is_same<OP, mshadow_op::plus>::value || std::is_same<OP, mshadow_op::minus>::value;
CHECK(legal_op) << "Only add/sub are supported for broadcast(csr, dns) = dns";
CHECK_EQ(csr.shape()[0], output.shape()[0]);
CHECK_EQ(csr.shape()[1], output.shape()[1]);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const nnvm::dim_t num_rows = output.shape()[0];
const nnvm::dim_t num_cols = output.shape()[1];
const TBlob& csr_data = csr.data();
const TBlob& csr_indices = csr.aux_data(kIdx);
const TBlob& csr_indptr = csr.aux_data(kIndPtr);
TBlob dns_data = dns.data();
TBlob out_data = output.data();
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
Shape<NDim> oshape = new_oshape.get<NDim>();
Shape<NDim> lstride = calc_stride(new_csrshape.get<NDim>());
Shape<NDim> rstride = calc_stride(new_dnsshape.get<NDim>());
if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
Kernel<binary_broadcast_kernel<NDim, mshadow_op::plus>, xpu>::template LaunchEx(
s,
new_oshape.Size(),
req,
lstride,
rstride,
oshape,
DType(0),
dns_data.dptr<DType>(),
out_data.dptr<DType>());
} else {
Kernel<binary_broadcast_kernel<NDim, OP>, xpu>::template LaunchEx(s,
new_oshape.Size(),
req,
lstride,
rstride,
oshape,
DType(0),
dns_data.dptr<DType>(),
out_data.dptr<DType>());
}
});
});
if (csr.storage_initialized()) {
MSHADOW_TYPE_SWITCH(csr.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIdx), CType, {
MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIndPtr), RType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
Kernel<csr_dns_map_kernel<req_type, mshadow_op::minus, true>, xpu>::Launch(
s,
num_rows,
csr_data.dptr<DType>(),
csr_indices.dptr<CType>(),
csr_indptr.dptr<RType>(),
out_data.dptr<DType>(),
num_rows,
num_cols);
} else {
Kernel<csr_dns_map_kernel<req_type, mshadow_op::plus>, xpu>::Launch(
s,
num_rows,
csr_data.dptr<DType>(),
csr_indices.dptr<CType>(),
csr_indptr.dptr<RType>(),
out_data.dptr<DType>(),
num_rows,
num_cols);
}
});
});
});
});
}
}
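/*!
 * \brief FComputeEx for broadcast ops with a csr output: broadcast(csr, dense
 *        vector/scalar) is handled by BinaryBroadcastCsrDnsCsrImpl, while a full
 *        dense rhs of the same shape is only supported for multiplication and is
 *        routed to the element-wise dense*csr kernel.
 */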
template <typename xpu, typename OP>
void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp)
return;
const NDArray& lhs = inputs[0];
const NDArray& rhs = inputs[1];
const NDArray& out = outputs[0];
const auto lhs_stype = lhs.storage_type();
const auto rhs_stype = rhs.storage_type();
const auto out_stype = out.storage_type();
  // rhs is a full matrix with no broadcastable (size-1) axis: handle as an element-wise op
if ((rhs.shape().ndim() != 1U) && (rhs.shape()[0] != 1) && (rhs.shape()[1] != 1)) {
if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
CHECK(supported_op)
<< "Please use elemwise_div for division between csr and dense of the same shape";
ElemwiseBinaryOp::DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, rhs, lhs, req[0], out, true);
}
} else {
// broadcast(CSR, Dense(1D)) = CSR
if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
BinaryBroadcastCsrDnsCsrImpl<xpu, OP>(ctx, lhs, rhs, req[0], out);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
}
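/*!
 * \brief FComputeEx for broadcast add/sub between a csr and a dense operand with a
 *        dense output: uses the element-wise kernel when the shapes already match,
 *        otherwise BinaryBroadcastCsrDnsDnsImpl.
 */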
template <typename xpu, typename OP>
void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp)
return;
const NDArray& lhs = inputs[0];
const NDArray& rhs = inputs[1];
const NDArray& out = outputs[0];
const auto lhs_stype = lhs.storage_type();
const auto rhs_stype = rhs.storage_type();
const auto out_stype = out.storage_type();
bool reverse = (lhs_stype == kDefaultStorage);
const NDArray& dns = (reverse) ? lhs : rhs;
const NDArray& csr = (reverse) ? rhs : lhs;
mxnet::TShape new_csrshape, new_dnsshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(
csr.shape(), dns.shape(), out.shape(), &new_csrshape, &new_dnsshape, &new_oshape);
if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kDefaultStorage) {
    // Shapes already match (no broadcasting): use the element-wise dense/csr kernel
if (!ndim) {
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
ElemwiseBinaryOp::DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
} else {
      // broadcast(csr, dense) producing a dense output
BinaryBroadcastCsrDnsDnsImpl<xpu, OP>(
ctx, csr, dns, req[0], out, new_csrshape, new_dnsshape, new_oshape, ndim, reverse);
}
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
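/*!
 * \brief CPU backward pass for ops whose gradients do not need the forward inputs
 *        (e.g. broadcast_add/sub): the output gradient is mapped through LOP/ROP
 *        and sum-reduced over the broadcast axes.
 */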
template <typename xpu, typename LOP, typename ROP>
inline typename std::enable_if<std::is_same<xpu, cpu>::value, void>::type
BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace broadcast;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_,
outputs[1].shape_,
inputs[0].shape_,
&new_lshape,
&new_rshape,
&new_oshape);
if (!ndim) {
ElemwiseBinaryOp::BackwardUseNone<cpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Stream<cpu>* s = ctx.get_stream<cpu>();
const TBlob lhs = outputs[0].reshape(new_lshape);
const TBlob rhs = outputs[1].reshape(new_rshape);
const TBlob out = inputs[0].reshape(new_oshape);
BROADCAST_NDIM_SWITCH(ndim, NDim, {
// Request temporary storage
size_t workspace_size = new_oshape.Size();
Tensor<cpu, 1, char> workspace = ctx.requested[0].get_space_typed<cpu, 1, char>(
Shape1(workspace_size * sizeof(index_t)), s);
ReduceWithExtraMem<red::sum, NDim, DType, LOP>(s, lhs, req[0], workspace, out);
ReduceWithExtraMem<red::sum, NDim, DType, ROP>(s, rhs, req[1], workspace, out);
});
});
}
}
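/*!
 * \brief Reduce-based backward for ops whose gradients need the forward inputs,
 *        using a caller-provided temporary workspace.
 */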
template <typename xpu, int ndim, typename DType, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mshadow::Tensor<xpu, 1, char>& workspace,
const mxnet::TShape& new_lshape,
const mxnet::TShape& new_rshape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu>* s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
if (ograd.Size() != 0) {
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(
s, lgrad, req[0], workspace, ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(
s, rgrad, req[1], workspace, ograd, lhs, rhs);
}
}
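/*!
 * \brief Same as BinaryBroadcastBackwardUseInImplWithWorkspace, but queries the
 *        required workspace size and allocates it from the op context.
 */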
template <typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_lshape,
const mxnet::TShape& new_rshape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu>* s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
size_t workspace_size_l =
ReduceWorkspaceSize(s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size_r =
ReduceWorkspaceSize(s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(
s, lgrad, req[0], workspace, ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(
s, rgrad, req[1], workspace, ograd, lhs, rhs);
}
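/*!
 * \brief Backward pass for ops whose gradients depend on the forward inputs
 *        (e.g. broadcast_mul/div): each input gradient is the sum-reduction of
 *        ograd * {LOP,ROP}(lhs, rhs) over its broadcast axes.
 */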
template <typename xpu, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
mxnet::TShape new_lshape, new_rshape, new_oshape;
const bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_,
outputs[1].shape_,
inputs[0].shape_,
&new_lshape,
&new_rshape,
&new_oshape) != 0;
if (!need_bc) {
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImpl<xpu, NDim, DType, LOP, ROP>(
ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
});
});
}
}
#if MXNET_USE_ONEDNN == 1
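// oneDNN-backed forward for broadcastable binary ops; declaration only, the
// definition lives with the oneDNN operator implementations.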
template <dnnl::algorithm alg>
void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
#endif // MXNET_USE_ONEDNN == 1
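/*!
 * \brief Register a binary broadcast operator with the common attributes: two
 *        inputs ("lhs", "rhs"), one output, broadcast shape inference, element-wise
 *        type inference and in-place options with either input.
 */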
#define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_