/*!
 * Copyright (c) 2015 by Contributors
 * \file elemwise_binary_broadcast_op.h
 * \brief Function definition of elementwise binary broadcast operators
 */
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_
#include <mxnet/operator_util.h>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./elemwise_binary_op.h"
#include "../operator_common.h"
#include "broadcast_reduce-inl.h"
namespace mxnet {
namespace op {
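/*!
 * \brief Infer the output shape of a binary broadcast operator.
 *
 * Follows numpy broadcasting semantics: the shapes are right-aligned,
 * missing leading dimensions are treated as 1, and each aligned pair of
 * dimensions must either match or contain a 1, which is stretched to the
 * other value. For example, lhs (2, 3, 1) and rhs (3, 4) broadcast to an
 * output of (2, 3, 4). Returns false while either input shape is unknown.
 */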
inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
TShape& lhs = (*in_attrs)[0];
TShape& rhs = (*in_attrs)[1];
  // Avoid premature shape inference when either operand's shape is unknown.
  if (lhs.ndim() == 0 || rhs.ndim() == 0) return false;
if (lhs == rhs) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, lhs);
return true;
}
TShape out(std::max(lhs.ndim(), rhs.ndim()));
index_t bl = out.ndim() - lhs.ndim();
index_t br = out.ndim() - rhs.ndim();
for (index_t i = 0; i < out.ndim(); ++i) {
index_t l = 1, r = 1;
if (i >= bl) l = lhs[i-bl];
if (i >= br) r = rhs[i-br];
if (l != r) {
if (l == 0 || r == 0) {
out[i] = 0;
} else {
CHECK(l == 1 || r == 1)
<< "operands could not be broadcast together with shapes " << lhs << " " << rhs;
out[i] = std::max(l, r);
}
} else {
out[i] = l;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return true;
}
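/*!
 * \brief Dispatch to a compile-time number of dimensions.
 *
 * Pads the runtime ndim up to the nearest supported value (2, 4, or
 * broadcast::MAX_DIM), binds it to the compile-time constant NDim, and
 * runs the enclosed code with it. This keeps the number of template
 * instantiations of the broadcast kernels small.
 */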
#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) \
if (ndim <= 2) { \
const int NDim = 2; \
{__VA_ARGS__} \
} else if (ndim <= 4) { \
const int NDim = 4; \
{__VA_ARGS__} \
} else if (ndim <= broadcast::MAX_DIM) { \
const int NDim = broadcast::MAX_DIM; \
{__VA_ARGS__} \
} else { \
    LOG(FATAL) << "NDim too large";   \
}
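/*!
 * \brief Compact the operand shapes into the fewest dimensions that still
 *        describe the broadcast.
 *
 * Consecutive dimensions with the same broadcast pattern are merged into
 * one, so the kernels only change stride behavior where the pattern
 * actually changes. For example, lhs (2, 3, 4) with rhs (2, 3, 1) compacts
 * to lhs (6, 4), rhs (6, 1), output (6, 4). Returns the number of
 * compacted dimensions, or 0 when the shapes are identical and no
 * broadcasting is needed.
 */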
inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshape,
const TShape& oshape, TShape *new_lshape,
TShape *new_rshape, TShape *new_oshape) {
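  // Identical shapes need no broadcasting; 0 tells callers to take the
  // plain elementwise path.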
if (lshape == rshape) return 0;
index_t odim = std::max<index_t>(oshape.ndim(), broadcast::MAX_DIM);
*new_lshape = TShape(odim);
*new_rshape = TShape(odim);
*new_oshape = TShape(odim);
index_t bl = oshape.ndim() - lshape.ndim();
index_t br = oshape.ndim() - rshape.ndim();
index_t j = 0, lprod = 1, rprod = 1, oprod = 1;
for (index_t i = 0; i < oshape.ndim(); ++i) {
index_t l = 1, r = 1, o = oshape[i];
if (i >= bl) l = lshape[i-bl];
if (i >= br) r = rshape[i-br];
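    // Flush the accumulated group when the broadcast pattern changes,
    // i.e. the running products diverge or the current dims differ, and
    // neither side's product would remain a trivial 1.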
if ((lprod != rprod || l != r) &&
lprod*l > 1 && rprod*r > 1) {
(*new_lshape)[j] = lprod;
(*new_rshape)[j] = rprod;
(*new_oshape)[j] = oprod;
lprod = rprod = oprod = 1; ++j;
}
lprod *= l;
rprod *= r;
oprod *= o;
}
if (lprod > 1 || rprod > 1) {
(*new_lshape)[j] = lprod;
(*new_rshape)[j] = rprod;
(*new_oshape)[j] = oprod;
++j;
}
  // Pad any remaining dimensions with 1 so that the shapes trimmed below
  // stay valid when fewer than NDim compacted dimensions were produced.
  for (index_t i = j; i < static_cast<index_t>(broadcast::MAX_DIM); ++i) {
    (*new_lshape)[i] = 1;
    (*new_rshape)[i] = 1;
    (*new_oshape)[i] = 1;
  }
  if (j <= broadcast::MAX_DIM) {
BROADCAST_NDIM_SWITCH(j, NDim, {
new_lshape->assign(&(*new_lshape)[0], &(*new_lshape)[NDim]);
new_rshape->assign(&(*new_rshape)[0], &(*new_rshape)[NDim]);
new_oshape->assign(&(*new_oshape)[0], &(*new_oshape)[NDim]);
});
} else {
LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape;
}
return j;
}
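/*!
 * \brief Forward computation of a binary operator with broadcasting.
 *
 * Falls back to the plain elementwise kernel when the shapes are
 * identical; otherwise compacts the shapes and dispatches a broadcast
 * kernel templated on the compacted number of dimensions.
 */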
template<typename xpu, typename OP>
void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace broadcast;
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
BinaryCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
BinaryBroadcastComputeImpl<NDim, DType, OP>(s, req[0], inputs[0].reshape(new_lshape),
inputs[1].reshape(new_rshape), outputs[0].reshape(new_oshape));
});
});
}
}
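/*!
 * \brief Backward pass that needs neither forward input, e.g. add/subtract.
 *
 * Each operand's gradient is the output gradient mapped through LOP/ROP
 * and sum-reduced over the dimensions that were broadcast in the forward
 * pass, so the gradient matches its operand's shape.
 */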
template<typename xpu, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace broadcast;
TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_,
&new_lshape, &new_rshape, &new_oshape);
if (!ndim) {
BinaryBackwardUseNone<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob lhs = outputs[0].reshape(new_lshape);
const TBlob rhs = outputs[1].reshape(new_rshape);
const TBlob out = inputs[0].reshape(new_oshape);
BROADCAST_NDIM_SWITCH(ndim, NDim, {
// Request temporary storage
size_t workspace_size_l = ReduceWorkspaceSize<NDim, DType>(s, lhs, req[0], out);
size_t workspace_size_r = ReduceWorkspaceSize<NDim, DType>(s, rhs, req[1], out);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, NDim, DType, LOP>(s, lhs, req[0], workspace, out);
Reduce<red::sum, NDim, DType, ROP>(s, rhs, req[1], workspace, out);
});
});
}
}
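/*!
 * \brief Shared implementation for BinaryBroadcastBackwardUseIn.
 *
 * Each operand's gradient is the sum over the broadcast dimensions of
 * ograd * OP(lhs, rhs), where OP is the partial-derivative functor LOP
 * or ROP.
 */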
template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const TShape& new_lshape,
const TShape& new_rshape,
const TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob lgrad = outputs[0].reshape(new_lshape);
const TBlob rgrad = outputs[1].reshape(new_rshape);
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
size_t workspace_size_l = ReduceWorkspaceSize<ndim, DType>(s, lgrad, req[0], ograd, lhs, rhs);
size_t workspace_size_r = ReduceWorkspaceSize<ndim, DType>(s, rgrad, req[1], ograd, lhs, rhs);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, mshadow::op::mul, LOP>(s, lgrad, req[0], workspace,
ograd, lhs, rhs);
Reduce<red::sum, ndim, DType, mshadow::op::mul, ROP>(s, rgrad, req[1], workspace,
ograd, lhs, rhs);
}
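/*!
 * \brief Backward pass that needs both forward inputs, e.g. multiply/divide.
 *
 * Falls back to the plain elementwise backward kernel when no broadcasting
 * took place; otherwise reduces the per-element gradients over the
 * broadcast dimensions.
 */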
template<typename xpu, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TShape new_lshape, new_rshape, new_oshape;
  bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_,
                                             inputs[0].shape_, &new_lshape,
                                             &new_rshape, &new_oshape) != 0;
if (!need_bc) {
BinaryBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
BinaryBroadcastBackwardUseInImpl<xpu, NDim, DType, LOP, ROP>(
ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
});
});
}
}
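/*!
 * \brief Register a binary broadcast operator with its common attributes.
 *
 * Compute kernels and gradients still have to be attached by the caller.
 * A registration in an operator .cc file would look roughly like this
 * (sketch, using the addition op as an example):
 *
 *   MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add)
 *   .set_attr<FCompute>("FCompute<cpu>",
 *                       BinaryBroadcastCompute<cpu, mshadow::op::plus>);
 */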
#define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_