/*!
* Copyright (c) 2017 by Contributors
* \file control_flow.h
 * \brief Function definitions of control-flow operators
*/
#ifndef MXNET_OPERATOR_TENSOR_CONTROL_FLOW_OP_H_
#define MXNET_OPERATOR_TENSOR_CONTROL_FLOW_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
/*! \brief Choose elements from x or y depending on condition.
 * The condition, x, and y all have the same shape.
 * The output takes its element from x where the corresponding
 * element of condition is non-zero, and from y where it is zero.
 */
template<int req>
struct where {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
const DType* x, const DType* y) {
    KERNEL_ASSIGN(out[i], req, (0 != cond[i] ? x[i] : y[i]));
}
};
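// For illustration only, a sketch of the element-wise selection with
// made-up values (not part of the operator):
//   cond = {0, 1, 2},  x = {1, 2, 3},  y = {10, 20, 30}
//   => out = {10, 2, 3}, since out[i] = (cond[i] != 0) ? x[i] : y[i]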
/*! \brief Choose rows from x or y depending on condition.
 * The condition is a vector whose length equals the size of
 * the first dimension of x.
 * The output takes its i-th row from x when the i-th element
 * of condition is non-zero, and from y otherwise.
 */
template<int req>
struct where_batch {
// DType is the output data type
// CType is the condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
const DType* x, const DType* y, int M) {
    KERNEL_ASSIGN(out[i], req, (0 != cond[i/M] ? x[i] : y[i]));
}
};
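// For illustration only, a sketch of the row-wise selection with
// made-up values: for 2x3 inputs (row size M = 3),
//   cond = {1, 0},  x = {1, 2, 3,      y = {10, 20, 30,
//                        4, 5, 6},          40, 50, 60}
//   => out = {1, 2, 3, 40, 50, 60}; element i comes from x
//      when cond[i/M] != 0, and from y otherwise.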
/*!
 * \brief Template for calculating grad[x] and grad[y].
 * The template argument req is the OpReqType; negate selects
 * whether the output is grad_x (negate = true)
 * or grad_y (negate = false).
 */
template<int req, bool negate>
struct where_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* grad_out,
const DType* grad_in,
const CType* cond) {
KERNEL_ASSIGN(grad_out[i], req,
                  ((0 == cond[i]) ^ negate) ? grad_in[i] : static_cast<DType>(0));
}
};
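// For illustration only, with made-up values:
//   cond = {1, 0},  grad_in = {5, 7}
//   => grad_x = {5, 0}  (negate = true:  pass grad_in where cond != 0)
//      grad_y = {0, 7}  (negate = false: pass grad_in where cond == 0)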
/*!
 * \brief Template for calculating grad[x] and grad[y] in the batch case.
 * The template argument req is the OpReqType; negate selects
 * whether the output is grad_x (negate = true)
 * or grad_y (negate = false).
 * The condition is a vector whose length equals the size of
 * the first dimension of x.
 */
template<int req, bool negate>
struct where_batch_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* grad_out,
const DType* grad_in,
const CType* cond, int M) {
KERNEL_ASSIGN(grad_out[i], req,
                  ((0 == cond[i/M]) ^ negate) ? grad_in[i] : static_cast<DType>(0));
}
};
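// For illustration only, with made-up values and row size M = 3:
//   cond = {1, 0},  grad_in = {5, 6, 7, 8, 9, 10}
//   => grad_x = {5, 6, 7, 0, 0, 0}   (negate = true)
//      grad_y = {0, 0, 0, 8, 9, 10}  (negate = false)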
inline bool WhereOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
TShape tshape((*in_attrs)[1]);
if (!shape_assign(&tshape, (*in_attrs)[2])) return false;
if (!shape_assign(&tshape, (*out_attrs)[0])) return false;
SHAPE_ASSIGN_CHECK(*in_attrs, 1, tshape);
SHAPE_ASSIGN_CHECK(*in_attrs, 2, tshape);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
if ((*in_attrs)[0].ndim() == tshape.ndim()) {
if (!shape_assign(&tshape, (*in_attrs)[0])) return false;
SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape);
return true;
} else if ((*in_attrs)[0].ndim() == 1) {
return (*in_attrs)[0].Size() == tshape[0];
}
return false;
}
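// For illustration only, shape combinations accepted above (made-up shapes):
//   cond (2,3), x (2,3), y (2,3) -> out (2,3)  // element-wise case
//   cond (2,),  x (2,3), y (2,3) -> out (2,3)  // batch case: cond.Size() == first dim
//   cond (3,),  x (2,3), y (2,3) -> inference fails, since 3 != 2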
inline bool WhereOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
int dtype = -1;
if (!type_assign(&dtype, (*in_attrs)[1])) return false;
if (!type_assign(&dtype, (*in_attrs)[2])) return false;
if (!type_assign(&dtype, (*out_attrs)[0])) return false;
if (-1 == dtype) return false;
TYPE_ASSIGN_CHECK(*in_attrs, 1, dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 2, dtype);
TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype);
return true;
}
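// For illustration only: x, y, and the output must share one dtype, which is
// inferred in either direction, while the condition keeps its own dtype.
//   x float32, y unknown (-1), out unknown -> all three become float32
//   x float32, y int32 -> type inference does not succeed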
template<typename xpu>
void WhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
using namespace mxnet_op;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& cond = inputs[0];
const TBlob& x = inputs[1];
const TBlob& y = inputs[2];
const TBlob& out = outputs[0];
if (out.Size() == 0) return;
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(cond.type_flag_, CType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (cond.shape_ == x.shape_) {
Kernel<where<req_type>, xpu>::Launch(s, out.Size(), out.dptr<DType>(),
cond.dptr<CType>(), x.dptr<DType>(),
y.dptr<DType>());
} else {
Kernel<where_batch<req_type>, xpu>::Launch(s, out.Size(), out.dptr<DType>(),
cond.dptr<CType>(), x.dptr<DType>(),
y.dptr<DType>(), x.Size()/cond.Size());
}
});
});
});
}
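// For illustration only, how the forward dispatch above plays out with
// made-up shapes:
//   x, y (2,3), cond (2,3) -> element-wise kernel (shapes match)
//   x, y (2,3), cond (2,)  -> batch kernel with M = x.Size()/cond.Size()
//                             = 6/2 = 3, so output element i selects
//                             by cond[i/3]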
/*!
 * \brief Compute the gradients of the loss function
 * with respect to x and y. The gradient
 * with respect to condition is always zero, so it is not
 * produced here. The gradients with respect to x and y depend
 * on the corresponding elements of the condition: grad_x receives
 * the incoming gradient where the condition is non-zero, and
 * grad_y receives it where the condition is zero.
 * The inputs are the gradient with respect to the output
 * of the operator and the condition.
 * The outputs are the gradients with respect to x and y.
 */
template<typename xpu>
void WhereOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(req.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
using namespace mxnet_op;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& grad_in = inputs[0];
const TBlob& cond = inputs[1];
const TBlob& grad_x = outputs[0];
const TBlob& grad_y = outputs[1];
if (grad_in.Size() == 0) return;
MSHADOW_TYPE_SWITCH(grad_in.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(cond.type_flag_, CType, {
bool same_shape = (cond.shape_ == grad_in.shape_);
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type_x, {
if (same_shape) {
Kernel<where_backward<req_type_x, true>, xpu>::Launch(s, grad_in.Size(),
grad_x.dptr<DType>(), grad_in.dptr<DType>(), cond.dptr<CType>());
} else {
Kernel<where_batch_backward<req_type_x, true>, xpu>::Launch(s, grad_in.Size(),
grad_x.dptr<DType>(), grad_in.dptr<DType>(), cond.dptr<CType>(),
grad_in.Size()/cond.Size());
}
});
MXNET_ASSIGN_REQ_SWITCH(req[1], req_type_y, {
if (same_shape) {
Kernel<where_backward<req_type_y, false>, xpu>::Launch(s, grad_in.Size(),
grad_y.dptr<DType>(), grad_in.dptr<DType>(), cond.dptr<CType>());
} else {
Kernel<where_batch_backward<req_type_y, false>, xpu>::Launch(s, grad_in.Size(),
grad_y.dptr<DType>(), grad_in.dptr<DType>(), cond.dptr<CType>(),
grad_in.Size()/cond.Size());
}
});
});
});
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_CONTROL_FLOW_OP_H_