/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file broadcast_reduce_op.h
* \brief Function definition of broadcast and reduce operators
*/
#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <utility>
#include <algorithm>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./elemwise_binary_broadcast_op.h"
#include "../mxnet_op.h"
namespace mxnet {
namespace op {
struct ReduceAxesParam : public dmlc::Parameter<ReduceAxesParam> {
dmlc::optional<mxnet::TShape> axis;
bool keepdims;
bool exclude;
DMLC_DECLARE_PARAMETER(ReduceAxesParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<mxnet::TShape>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is an int, a reduction is performed on that particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Negative values mean indexing from right to left.)code");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
struct NormParam : public dmlc::Parameter<NormParam> {
int ord;
dmlc::optional<mxnet::TShape> axis;
bool keepdims;
DMLC_DECLARE_PARAMETER(NormParam) {
DMLC_DECLARE_FIELD(ord).set_default(2)
.describe("Order of the norm. Currently ord=1 and ord=2 is supported.");
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<mxnet::TShape>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is an int, a reduction is performed on that particular axis.
If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices,
and the matrix norms of these matrices are computed.)code");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
}
};
struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
dmlc::optional<int> axis;
bool keepdims;
DMLC_DECLARE_PARAMETER(ReduceAxisParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>())
.describe("The axis along which to perform the reduction. "
"Negative values means indexing from right to left. "
"``Requires axis to be set as int, because global reduction "
"is not supported yet.``");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
}
};
enum PickOpMode {kWrap, kClip};
struct PickParam : public dmlc::Parameter<PickParam> {
dmlc::optional<int> axis;
int mode;
bool keepdims;
DMLC_DECLARE_PARAMETER(PickParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
.describe("int or None. The axis to picking the elements. "
"Negative values means indexing from right to left. "
"If is `None`, the elements in the index w.r.t the "
"flattened input will be picked.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If true, the axis where we pick the elements is left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(mode)
.add_enum("wrap", kWrap)
.add_enum("clip", kClip)
.set_default(kClip)
.describe("Specify how out-of-bound indices behave. Default is \"clip\"."
" \"clip\" means clip to the range. So, if all indices mentioned are too large,"
" they are replaced by the index that addresses the last element along an axis. "
" \"wrap\" means to wrap around.");
}
};
struct BroadcastAxesParam : public dmlc::Parameter<BroadcastAxesParam> {
mxnet::TShape axis;
mxnet::TShape size;
DMLC_DECLARE_PARAMETER(BroadcastAxesParam) {
DMLC_DECLARE_FIELD(axis).set_default(mxnet::TShape())
.describe("The axes to perform the broadcasting.");
DMLC_DECLARE_FIELD(size).set_default(mxnet::TShape())
.describe("Target sizes of the broadcasting axes.");
}
};
struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
mxnet::TShape shape;
DMLC_DECLARE_PARAMETER(BroadcastToParam) {
DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape())
.describe("The shape of the desired array."
" We can set the dim to zero if it's same as the original."
" E.g `A = broadcast_to(B, shape=(10, 0, 0))` "
"has the same meaning as `A = broadcast_axis(B, axis=0, size=10)`.");
}
};
struct BroadcastLikeParam : public dmlc::Parameter<BroadcastLikeParam> {
dmlc::optional<mxnet::TShape> lhs_axes;
dmlc::optional<mxnet::TShape> rhs_axes;
DMLC_DECLARE_PARAMETER(BroadcastLikeParam) {
DMLC_DECLARE_FIELD(lhs_axes).set_default(dmlc::optional<mxnet::TShape>())
.describe("Axes to perform broadcast on in the first input array");
DMLC_DECLARE_FIELD(rhs_axes).set_default(dmlc::optional<mxnet::TShape>())
.describe("Axes to copy from the second input array");
}
};
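/*! \brief Check that axis lies in [-ndim, ndim) and return its non-negative form,
* e.g. CheckAxis(-1, 3) == 2 and CheckAxis(1, 3) == 1.
*/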
inline int CheckAxis(int axis, int ndim) {
CHECK(axis < ndim && axis >= -ndim)
<< "axis " << axis << " exceeds the input dimension of " << ndim;
return (axis + ndim)%ndim;
}
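/*! \brief Compact a shape into (leading, M, trailing) around *axis, where M = shape[*axis].
* If allow_2d is set and either leading or trailing equals 1, a 2-D shape is returned
* instead and *axis is adjusted to point at M, e.g. shape=(2,3), axis=1, allow_2d=true
* yields Shape2(2,3) with *axis == 1.
*/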
inline mxnet::TShape AxisShapeCompact(mxnet::TShape shape, int *axis, bool allow_2d) {
int ndim = static_cast<int>(shape.ndim());
index_t leading = 1, trailing = 1, M = shape[*axis];
for (int i = 0; i < *axis; ++i) leading *= shape[i];
for (int i = *axis + 1; i < ndim; ++i) trailing *= shape[i];
if (allow_2d && trailing == 1) {
*axis = 1;
return mshadow::Shape2(leading, M);
}
if (allow_2d && leading == 1) {
*axis = 0;
return mshadow::Shape2(M, trailing);
}
*axis = 1;
return mshadow::Shape3(leading, M, trailing);
}
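/*! \brief Output shape of a reduction along a single optional axis,
* e.g. ishape=(2,3,4) with axis=1 gives (2,4), or (2,1,4) when keepdims is set.
*/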
inline mxnet::TShape ReduceAxisShapeImpl(const mxnet::TShape& ishape,
const dmlc::optional<int>& axis,
bool keepdims) {
if (!axis || ishape.ndim() == 1) {
if (keepdims) {
return mxnet::TShape(ishape.ndim());
}
return mshadow::Shape1(1);
}
int new_axis = CheckAxis(axis.value(), ishape.ndim());
if (keepdims) {
mxnet::TShape oshape = ishape;
oshape[new_axis] = 1;
return oshape;
}
mxnet::TShape oshape(ishape.ndim() - 1);
for (int i = 0; i < new_axis; ++i) oshape[i] = ishape[i];
for (int i = new_axis+1; i < static_cast<int>(ishape.ndim()); ++i) {
oshape[i-1] = ishape[i];
}
return oshape;
}
inline bool ReduceAxisShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
if (ishape.ndim() == 0) return false;
const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxisShapeImpl(ishape, param.axis, param.keepdims));
return true;
}
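/*! \brief Output shape of a reduction along multiple axes,
* e.g. ishape=(2,3,4) with axis=(1,) gives (2,4), (3,) when exclude is set,
* and (2,1,4) when keepdims is set.
*/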
inline mxnet::TShape ReduceAxesShapeImpl(const mxnet::TShape& ishape,
const dmlc::optional<mxnet::TShape>& axis,
bool keepdims, bool exclude) {
// if axis doesn't have a value, treat it the same as an empty mxnet::TShape().
if (!axis.has_value() || axis.value().ndim() == 0) {
if (keepdims) {
return mxnet::TShape(ishape.ndim());
} else {
return mxnet::TShape(1);
}
}
// axis has value
mxnet::TShape axes(axis.value());
for (index_t i = 0; i < axes.ndim(); i++) {
if (axes[i] < 0) {
axes[i] += ishape.ndim();
}
}
std::sort(axes.begin(), axes.end());
for (index_t i = 1; i < axes.ndim(); i++) {
CHECK_LT(axes[i-1], axes[i])
<< "Reduction axes have duplicates "
<< axes;
}
CHECK_LT(axes[axes.ndim()-1], ishape.ndim())
<< "Reduction axis " << axes[axes.ndim()-1]
<< " exceeds input dimensions " << ishape;
CHECK_GE(axes[0], 0)
<< "Reduction axis " << axis.value()
<< " exceeds input dimensions " << ishape;
mxnet::TShape oshape;
if (keepdims) {
oshape = mxnet::TShape(ishape);
} else if (exclude) {
oshape = mxnet::TShape(axes.ndim());
} else {
oshape = mxnet::TShape(std::max<index_t>(1, ishape.ndim() - axes.ndim()));
}
if (keepdims && exclude) {
for (index_t i = 0, j = 0; i < ishape.ndim(); ++i) {
if (j < axes.ndim() && i == axes[j]) {
++j;
continue;
}
oshape[i] = 1;
}
} else if (keepdims) {
for (index_t i = 0; i < axes.ndim(); ++i) {
oshape[axes[i]] = 1;
}
} else if (exclude) {
for (index_t i = 0; i < axes.ndim(); ++i) {
oshape[i] = ishape[axes[i]];
}
} else {
for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
if (j < axes.ndim() && i == axes[j]) {
++j;
continue;
}
oshape[k++] = ishape[i];
}
}
return oshape;
}
inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[0].ndim() == 0) return false;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
param.keepdims, param.exclude));
return true;
}
inline bool NormShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[0].ndim() == 0) return false;
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
param.keepdims, false));
return true;
}
inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[0].ndim() == 0) return false;
const BroadcastAxesParam& param = nnvm::get<BroadcastAxesParam>(attrs.parsed);
CHECK_EQ(param.axis.ndim(), param.size.ndim());
mxnet::TShape &ishape = (*in_attrs)[0];
mxnet::TShape oshape = ishape;
for (index_t i = 0; i < param.axis.ndim(); ++i) {
CHECK_EQ(oshape[param.axis[i]], 1U) << "Broadcasting axis must have size 1";
oshape[param.axis[i]] = param.size[i];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
}
inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
if (ishape.ndim() == 0) return false;
const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
CHECK_EQ(ishape.ndim(), param.shape.ndim())
<< "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape;
mxnet::TShape oshape = param.shape;
for (index_t i = 0; i < ishape.ndim(); ++i) {
if (oshape[i] != 0) {
CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
<< "Array cannot be broadcasted from " << ishape << " to " << param.shape;
} else {
oshape[i] = ishape[i];
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
}
inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lhs_shape = (*in_attrs)[0];
mxnet::TShape& rhs_shape = (*in_attrs)[1];
if ((lhs_shape.ndim() == 0) || (rhs_shape.ndim() == 0)) {
return false;
}
const BroadcastLikeParam& param = nnvm::get<BroadcastLikeParam>(attrs.parsed);
mxnet::TShape oshape;
// lhs or rhs or both params were not specified
if (!param.lhs_axes.has_value() || !param.rhs_axes.has_value()) {
CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
<< "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;
oshape = mxnet::TShape(rhs_shape);
for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
if (rhs_shape[i] != 0) {
CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
<< "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
} else {
oshape[i] = lhs_shape[i];
}
}
} else {
auto lhs_axes = param.lhs_axes.value();
auto rhs_axes = param.rhs_axes.value();
CHECK(rhs_axes.ndim() == lhs_axes.ndim())
<< "Input_axis and other_axis size does not match";
CHECK(lhs_axes.ndim() > 0)
<< "Empty axes tuple is not allowed";
oshape = mxnet::TShape(lhs_shape);
for (index_t i = 0; i < lhs_axes.ndim(); ++i) {
auto copyfrom = lhs_axes[i];
if (copyfrom < 0) {
copyfrom = lhs_shape.ndim() + copyfrom;
}
CHECK(copyfrom >= 0 && copyfrom < oshape.ndim())
<< "Invalid dimension specified in lhs_axes: " << lhs_axes[i];
auto copyto = rhs_axes[i];
if (copyto < 0) {
copyto = rhs_shape.ndim() + copyto;
}
CHECK(copyto >= 0 && copyto < rhs_shape.ndim())
<< "Invalid dimension specified in rhs_axes: " << rhs_axes[i];
CHECK(lhs_shape[copyfrom] == 1) << "Input axis " << lhs_axes[i]
<< " at dimension " << i << " cannot be broadcasted to " << rhs_shape[copyto];
oshape[copyfrom] = rhs_shape[copyto];
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
}
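/*! \brief Compact a (big, small) shape pair for broadcast/reduce kernels by merging
* runs of adjacent dimensions that are either all kept or all reduced,
* e.g. big=(2,3,4,5) and small=(1,3,4,1) merge into (2,12,5) and (1,12,1)
* before being padded to the compact ndim. Fails if more than
* MXNET_SPECIAL_MAX_NDIM merged groups are needed.
*/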
inline void BroadcastReduceShapeCompact(const mxnet::TShape& big, const mxnet::TShape& small,
mxnet::TShape *new_big, mxnet::TShape *new_small) {
index_t idim = std::max<index_t>(big.ndim(), MXNET_SPECIAL_MAX_NDIM);
*new_big = mxnet::TShape(idim);
*new_small = mxnet::TShape(idim);
index_t j = 0;
if (small.Size() == 1) {
(*new_big)[j++] = big.Size();
} else {
index_t bprod = 1, sprod = 1;
for (index_t i = 0, k = 0; i < big.ndim(); ++i) {
bool red_axis = big[i] != small[i];
if ((red_axis && sprod > 1) || (!red_axis && bprod != sprod)) {
(*new_big)[j] = bprod;
(*new_small)[j] = sprod;
bprod = sprod = 1; ++j;
}
bprod *= big[i];
if (red_axis) {
++k;
} else {
sprod *= big[i];
}
}
if (bprod > 1 || sprod > 1) {
(*new_big)[j] = bprod;
(*new_small)[j] = sprod;
++j;
}
}
if (j <= 2) {
new_small->assign(&(*new_small)[0], &(*new_small)[2]);
new_big->assign(&(*new_big)[0], &(*new_big)[2]);
} else if (j <= MXNET_SPECIAL_MAX_NDIM) {
new_small->assign(&(*new_small)[0], &(*new_small)[MXNET_SPECIAL_MAX_NDIM]);
new_big->assign(&(*new_big)[0], &(*new_big)[MXNET_SPECIAL_MAX_NDIM]);
} else {
LOG(FATAL) << "Too many reduction axes from " << big << " to " << small;
}
}
// infer storage function for sum(csr) and mean(csr)
inline bool ReduceAxesOpForwardStorage(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
const int in_stype = in_attrs->at(0);
int& out_stype = out_attrs->at(0);
// sum(csr) and mean(csr) are only supported on CPU for now.
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex =
invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
bool dispatched = false;
if (!dispatched && in_stype == kDefaultStorage) {
// When the input is dense, the output storage is set to dense and dispatched to
// the dense operator
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
}
mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape();
if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 &&
(axis[0] == 0 || axis[0] == 1) && !param.keepdims && !param.exclude) {
// If the input is csr and axis is 0 or 1, and neither keepdims nor exclude
// is set, dispatch to the sparse operator and set the output storage to dense
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
dispatch_ex);
}
if (!dispatched) {
// If the input is csr but keepdims or exclude is set, or the reduction is along
// an axis other than 0 or 1, fall back
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
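/*! \brief Forward pass of index-searching reductions (e.g. argmax/argmin):
* reduces along the single mandatory axis and writes the index of the
* selected element (reduce_with_axis with the mask flag set).
*/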
template<typename xpu, typename reducer>
void SearchAxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
if (!param.axis) LOG(FATAL) << "Global reduction not supported yet";
int axis = CheckAxis(param.axis.value(), inputs[0].shape_.ndim());
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, false);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> out = outputs[0].get_with_shape<xpu, 2, DType>(
Shape2(shape[0], shape[2]), s);
Tensor<xpu, 3, DType> in = inputs[0].get_with_shape<xpu, 3, DType>(
shape.get<3>(), s);
CHECK(req[0] != kAddTo) << "AddTo is not supported";
ASSIGN_DISPATCH(out, req[0], (reduce_with_axis<reducer, true>(in, 1)));
});
}
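/*! \brief Dense reduction into the target shape `small`: applies OP elementwise,
* reduces with `reducer`, and divides by the reduction size when normalize is
* set (i.e. computes a mean instead of a sum).
*/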
template<typename xpu, typename reducer, bool normalize = false,
typename OP = op::mshadow_op::identity>
void ReduceAxesComputeImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& small) {
using namespace mshadow;
using namespace mshadow::expr;
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
size_t workspace_size = broadcast::ReduceWorkspaceSize<NDim, DType>(
s, out_data.shape_, req[0], in_data.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
broadcast::Reduce<reducer, NDim, DType, OP>(
s, out_data, req[0], workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, DType>(s);
out /= scalar<DType>(src_shape.Size()/dst_shape.Size());
}
});
});
}
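/*! \brief FCompute entry point for axis reductions: derives the kept-dims output
* shape from ReduceAxesParam and forwards to ReduceAxesComputeImpl.
*/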
template<typename xpu, typename reducer, bool normalize = false,
typename OP = op::mshadow_op::identity>
void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
mxnet::TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude);
}
ReduceAxesComputeImpl<xpu, reducer, normalize, OP>(ctx, inputs, req, outputs, small);
}
template <typename red_op, int req, int axis>
struct ReduceCsrKernel;
/*! \brief The columns are divided equally among the number of threads available.
* Each thread gets a subset of columns and iterates through all rows for that
* subset. In each iteration, it does a binary search for the first column
* index between in_idx[in_indptr[row]] and in_idx[in_indptr[row+1]]. After an
* index equal or close to the first column is found, it does a linear search
* for the rest of the indices and adds their data to the intermediate sum.
* After iterating through all rows, it has the sum along the axis for the
* subset of columns.
*/
template <typename red_op, int req>
struct ReduceCsrKernel<red_op, req, 0> {
template <typename RType, typename IType, typename DType>
MSHADOW_XINLINE static void Map(int j, DType* out_data,
const RType* in_indptr, const IType* in_idx,
const DType* in_data,
DType* sum,
DType* residual,
RType num_rows,
IType num_cols,
const nnvm::dim_t seg_len) {
const IType seg_start = j * seg_len;
if (seg_start >= num_cols) return;
const IType seg_end = std::min(seg_start + seg_len, num_cols);
for (RType row = 0; row < num_rows; ++row) {
// row specific seg starts
IType row_seg_start = seg_start;
IType row_seg_end = seg_end;
// Cache starting and ending indptr values for the row
IType row_indptr_start = in_indptr[row];
IType row_indptr_end = in_indptr[row + 1] - 1;
if (row_indptr_start == (row_indptr_end + 1)) continue;
// If row_seg_start is less than the first index for the row, move the
// row_seg_start forward
while (row_seg_start < in_idx[row_indptr_start] &&
row_seg_start < row_seg_end) {
row_seg_start++;
}
// If row_seg_start is greater than last index for the row, move on to
// the next row
if (row_seg_start > in_idx[row_indptr_end]) continue;
// Do a binary search for row_seg_start between in_idx[row_indptr_start] and
// in_idx[row_indptr_end]
IType start = row_indptr_start;
IType end = row_indptr_end;
// Initialize mid with the first index of the row
IType mid = start;
while (start <= end) {
mid = start + (end - start) / 2;
if (in_idx[mid] == row_seg_start) {
break;
} else if (in_idx[mid] < row_seg_start) {
start = mid + 1;
} else {
end = mid - 1;
}
}
// At this point we have an in_idx[mid] which is close to row_seg_start
// Safety check to make sure mid is a valid indptr value
if (mid < row_indptr_start || mid > row_indptr_end)
mid = row_indptr_start;
// Linear search for nnzs for column subset between row_seg_start
// and row_seg_end
for (IType col = row_seg_start;
col < row_seg_end && mid <= row_indptr_end;) {
if (col == in_idx[mid]) {
red_op::Reduce(sum[col], in_data[mid], residual[col]);
mid++;
col++;
} else if (in_idx[mid] < col) {
mid++;
} else {
col++;
}
}
}
for (IType col = seg_start; col < seg_end; col++) {
KERNEL_ASSIGN(out_data[col], req, sum[col]);
}
}
};
template <typename red_op, int req>
struct ReduceCsrKernel<red_op, req, 1> {
template <typename RType, typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const RType* in_indptr,
const DType* in_data) {
DType sum, residual;
red_op::SetInitValue(sum, residual);
for (RType k = in_indptr[i]; k < in_indptr[i + 1]; k++) {
red_op::Reduce(sum, in_data[k], residual);
}
KERNEL_ASSIGN(out_data[i], req, sum);
}
};
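/*! \brief Reduce a CSR matrix along axis 0 (per-column sums) or axis 1 (per-row sums)
* into a dense output; when normalize is set, the result is divided by the length
* of the reduced dimension to produce a mean.
*/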
template <typename xpu, typename red_op, bool normalize = false>
void ReduceCsrImpl(mshadow::Stream<xpu>* s, const OpContext& ctx,
const NDArray& input, const OpReqType req,
NDArray* output, const mxnet::TShape reduce_axis) {
if (req == kNullOp) return;
int64_t out_data_size = 0;
if (reduce_axis[0] == 0) {
out_data_size = input.shape()[1];
} else {
out_data_size = input.shape()[0];
}
// only dense output storage type is supported
CHECK_EQ(output->storage_type(), kDefaultStorage);
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace csr;
using nnvm::dim_t;
if (req == kWriteTo || req == kWriteInplace) {
MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
Kernel<set_zero, xpu>::Launch(s, out_data_size,
output->data().dptr<DType>());
})
}
if (!input.storage_initialized()) {
return;
}
if (0 == reduce_axis[0]) {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
const IType* in_idx = input.aux_data(kIdx).dptr<IType>();
const DType* in_data = input.data().dptr<DType>();
const RType num_rows = input.shape()[0];
const IType num_cols = input.shape()[1];
dim_t num_threads = mxnet_op::get_num_threads<xpu>(16);
dim_t seg_len = (out_data_size + num_threads - 1) / num_threads;
mshadow::Tensor<xpu, 1, DType> workspace =
ctx.requested[0].get_space_typed<xpu, 1, DType>(
Shape1(2 * out_data_size), s);
mshadow::Tensor<xpu, 1, DType> sum(
reinterpret_cast<DType*>(workspace.dptr_),
Shape1(out_data_size));
mshadow::Tensor<xpu, 1, DType> residual(
reinterpret_cast<DType*>(workspace.dptr_ +
out_data_size),
Shape1(out_data_size));
Kernel<set_zero, xpu>::Launch(s, out_data_size, sum.dptr_);
Kernel<set_zero, xpu>::Launch(s, out_data_size, residual.dptr_);
Kernel<ReduceCsrKernel<red_op, req_type, 0>, xpu>::Launch(
s, num_threads, output->data().dptr<DType>(), in_indptr, in_idx,
in_data, sum.dptr_, residual.dptr_, num_rows, num_cols,
seg_len);
if (normalize) {
mxnet_op::Kernel<
mxnet_op::op_with_req<op::mshadow_op::div, req_type>,
xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
output->data().dptr<DType>(), DType(num_rows));
}
});
});
});
});
} else if (1 == reduce_axis[0]) {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(input.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(input.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const RType* in_indptr = input.aux_data(kIndPtr).dptr<RType>();
const DType* in_data = input.data().dptr<DType>();
const IType num_cols = input.shape()[1];
Kernel<ReduceCsrKernel<red_op, req_type, 1>, xpu>::Launch(
s, out_data_size, output->data().dptr<DType>(), in_indptr,
in_data);
if (normalize) {
mxnet_op::Kernel<
mxnet_op::op_with_req<op::mshadow_op::div, req_type>,
xpu>::Launch(s, out_data_size, output->data().dptr<DType>(),
output->data().dptr<DType>(), DType(num_cols));
}
});
});
});
});
}
}
/*! \brief If normalize is true, the mean is computed instead of the sum */
template <typename xpu, typename red_op, bool normalize = false>
void ReduceCsr(const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu>* s, const OpContext& ctx,
const NDArray& input, const OpReqType req, NDArray* output) {
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
CHECK(param.axis.has_value());
const mxnet::TShape axis = param.axis.value();
CHECK_EQ(axis.ndim(), 1U) << "sum(csr)/mean(csr) only supports axis 0 or 1";
CHECK(axis[0] == 0 || axis[0] == 1)
<< "sum(csr)/mean(csr) only support axis 0 or 1";
CHECK(!param.keepdims) << "keepdims not supported for sparse";
CHECK(!param.exclude) << "exclude not supported for sparse";
ReduceCsrImpl<xpu, red_op, normalize>(s, ctx, input, req, output, axis);
}
template <typename xpu, typename reducer, bool normalize = false>
void ReduceAxesOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const NDArrayStorageType istype = inputs[0].storage_type();
if (istype == kCSRStorage) {
NDArray output = outputs[0];
ReduceCsr<xpu, mshadow::red::sum, normalize>(attrs, s, ctx, inputs[0],
req[0], &output);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
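/*! \brief Backward helper for reductions whose gradient needs both the forward input
* and output: igrad = broadcast(ograd) * OP(data, broadcast(out)), divided by the
* reduction size when normalize is set (used below with OP=div for the L2-norm gradient).
*/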
template<typename xpu, typename OP, bool normalize = false>
void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx,
const mxnet::TShape &small,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, DType> igrad =
outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
Tensor<xpu, 2, DType> ograd =
inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, DType> data =
inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
Tensor<xpu, 2, DType> out =
inputs[2].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
ASSIGN_DISPATCH(igrad, req[0],
broadcast_to(ograd, src_shape)*F<OP>(data, broadcast_to(out, src_shape)));
if (normalize) igrad /= scalar<DType>(src_shape.Size()/dst_shape.Size());
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, DType> igrad =
outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> ograd =
inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> data =
inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> out =
inputs[2].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
ASSIGN_DISPATCH(igrad, req[0],
broadcast_to(ograd, src_shape)*F<OP>(data, broadcast_to(out, src_shape)));
if (normalize) igrad /= scalar<DType>(src_shape.Size()/dst_shape.Size());
}
});
}
// Works only when the output shape has already been inferred
template<typename xpu, typename OP, bool normalize = false>
void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
mxnet::TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude);
}
ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
}
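/*! \brief Broadcast the input from shape `small` up to the output shape, using the
* compact shape pair produced by BroadcastReduceShapeCompact.
*/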
template<typename xpu>
inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& small) {
using namespace mshadow;
using namespace mshadow::expr;
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, DType> out =
outputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, DType> data =
inputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
ASSIGN_DISPATCH(out, req[0], broadcast_to(data, dst_shape));
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, DType> out =
outputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> data =
inputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
ASSIGN_DISPATCH(out, req[0], broadcast_to(data, dst_shape));
}
});
}
template<typename xpu>
inline void BroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, inputs[0].shape_);
}
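/*! \brief Backward pass for sum-like reductions: broadcasts the output gradient back
* to the input shape; for mean (normalize=true) the gradient is additionally divided
* by the reduction size.
*/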
template<typename xpu, bool normalize = false>
inline void ReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
mxnet::TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, param.exclude);
}
BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, small);
if (normalize) {
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
igrad /= scalar<DType>(outputs[0].Size()/inputs[0].Size());
});
}
}
template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
PType param;
param.Init(attrs->dict);
attrs->parsed = std::move(param);
}
struct ReduceGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode(
op_name, n,
ograds, {n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
n->attrs.dict);
}
};
inline bool LpNormStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const int in_stype = in_attrs->at(0);
int& out_stype = out_attrs->at(0);
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
bool dispatched = false;
// L2 norm along a particular axis is only supported on CPU
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex =
invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
if (!dispatched && in_stype == kDefaultStorage) {
// dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
}
if (param.ord == 2) {
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape();
if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
axis.ndim() == 0 && param.ord == 2) {
// l2 norm: rsp/csr, axis = () -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
(axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
// l2 norm: csr, axis = 0/1 -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
dispatch_ex);
}
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
/*! \brief Reducer that squares each element and accumulates the sum */
struct sq_sum {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
dst += src * src;
}
/*! \brief do stable reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
DType y = src * src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of the reduction elements
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
// This won't be called in backward.
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};
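/*! \brief Full (flattened) L2 norm: writes sqrt(dot(in, in)) over all elements of the
* input into the (single-element) output.
*/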
template<typename xpu>
void L2NormComputeImpl(mshadow::Stream<xpu> *s,
const TBlob& input,
const OpReqType req,
const TBlob& output) {
MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
// assign_req switch exits immediately for null req
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mshadow::Tensor<xpu, 1, DType> out = output.get_with_shape<xpu, 1, DType>(
mshadow::Shape1(output.shape_.Size()), s);
mshadow::Tensor<xpu, 1, DType> in = input.get_with_shape<xpu, 1, DType>(
mshadow::Shape1(input.shape_.Size()), s);
mshadow::VectorDot(out, in, in);
DType* out_data = output.dptr<DType>();
using namespace mxnet_op;
Kernel<op_with_req<mshadow_op::square_root, Req>, xpu>::Launch(
s, output.Size(), out_data, out_data);
});
});
}
template<typename xpu>
void SqRootForL2(const OpContext& ctx, OpReqType req, const TBlob &output) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
DType* out_data = output.dptr<DType>();
using namespace mxnet_op;
Kernel<op_with_req<mshadow_op::square_root, Req>, xpu>::Launch(
s, output.Size(), out_data, out_data);
});
});
}
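/*! \brief Forward pass of the norm operator: ord=1 sums the absolute values,
* ord=2 uses the numerically stable nrm2 reducer.
*/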
template<typename xpu>
void LpNormCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
CHECK(param.ord == 1 || param.ord == 2) << "norm only supports ord=1 and ord=2";
if (req[0] == kNullOp) return;
mxnet::TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
}
if (param.ord == 1) {
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else if (param.ord == 2) {
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
}
}
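/*! \brief Backward pass of the norm operator: for ord=1 the gradient is
* broadcast(ograd) * sign(data); for ord=2 it reuses ReduceAxesBackwardUseInOutImpl
* with division, i.e. broadcast(ograd) * data / broadcast(norm).
*/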
template<typename xpu>
void LpNormGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
if (req[0] == kNullOp) return;
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
mxnet::TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
}
if (param.ord == 1) {
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, DType> ograd =
inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, DType> igrad =
outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
Tensor<xpu, 2, DType> data =
inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
ASSIGN_DISPATCH(igrad, req[0],
broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, DType> igrad =
outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> ograd =
inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
Tensor<xpu, ndim, DType> data =
inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
ASSIGN_DISPATCH(igrad, req[0],
broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
}
});
} else if (param.ord == 2) {
ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
req, outputs);
}
}
template<typename xpu>
void L2NormComputeSparseImpl(mshadow::Stream<xpu> *s,
const NDArray& input,
const OpReqType req,
const TBlob& output) {
if (req == kNullOp) return;
// input is all zeros
if (!input.storage_initialized()) {
// adding zeros is a no-op
if (req == kAddTo) return;
Fill<false>(s, output, req, 0);
} else {
L2NormComputeImpl(s, input.data(), req, output);
}
}
template<typename xpu>
void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
/*! \brief index element from array along axes */
template<int ndim, bool clip = true>
struct pick {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
const IType *idx, index_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
out[i] = a[j];
}
};
/*! \brief backward of pick: scatter-add the output gradient into the picked positions */
template<int ndim, bool clip = true>
struct pick_grad {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd,
const IType *idx, index_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
igrad[j] += ograd[i];
}
};
inline bool PickOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (ishape.ndim() == 0) return false;
const PickParam& param = nnvm::get<PickParam>(attrs.parsed);
if (!param.axis) LOG(FATAL)
<< "axis=None is not supported by pick yet. Must specify an axis.";
mxnet::TShape oshape = ReduceAxisShapeImpl(ishape, param.axis, param.keepdims);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
if (!(*in_attrs)[1].ndim()) return false;
if ((*in_attrs)[1].ndim() == ishape.ndim()) {
SHAPE_ASSIGN_CHECK(*in_attrs, 1,
ReduceAxisShapeImpl(ishape, param.axis, true));
} else {
SHAPE_ASSIGN_CHECK(*in_attrs, 1,
ReduceAxisShapeImpl(ishape, param.axis, false));
}
return true;
}
inline bool PickOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_NE((*in_attrs)[1], -1) << "Index type must be set for pick operator";
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
return (*out_attrs)[0] != -1;
}
template<typename xpu>
void PickOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(req[0], kWriteTo);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const PickParam& param = nnvm::get<PickParam>(attrs.parsed);
const mxnet::TShape& ishape = inputs[0].shape_;
index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
int leading = 1, trailing = 1, M = ishape[axis];
for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
if (param.mode == kWrap) {
if (trailing == 1) {
Kernel<pick<2, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick<3, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
} else {
if (trailing == 1) {
Kernel<pick<2, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick<3, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
}
});
});
}
template<typename xpu>
void PickOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const PickParam& param = nnvm::get<PickParam>(attrs.parsed);
const mxnet::TShape& ishape = outputs[0].shape_;
const index_t axis = CheckAxis(param.axis.value(), ishape.ndim());
int leading = 1, trailing = 1, M = ishape[axis];
for (index_t i = 0; i < axis; ++i) leading *= ishape[i];
for (index_t i = axis+1; i < ishape.ndim(); ++i) trailing *= ishape[i];
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index type
if (req[0] != kAddTo) outputs[0].FlatTo1D<xpu, DType>(s) = 0;
if (param.mode == kWrap) {
if (trailing == 1) {
Kernel<pick_grad<2, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick_grad<3, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
} else {
if (trailing == 1) {
Kernel<pick_grad<2, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick_grad<3, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
}
});
});
}
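// Registration helpers. A typical registration in a .cc file looks roughly like the
// following sketch (the operator name and reducer are illustrative):
//   MXNET_OPERATOR_REGISTER_REDUCE(sum)
//   .set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum>);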
#define MXNET_OPERATOR_REGISTER_REDUCE_AXIS(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ReduceAxisParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ReduceAxisShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_argument("data", "NDArray-or-Symbol", "The input") \
.add_arguments(ReduceAxisParam::__FIELDS__())
#define MXNET_OPERATOR_REGISTER_REDUCE(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(AxesParamParser<ReduceAxesParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ReduceAxesShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_argument("data", "NDArray-or-Symbol", "The input") \
.add_arguments(ReduceAxesParam::__FIELDS__())
#define MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(name) \
NNVM_REGISTER_OP(name) \
.set_num_outputs(1) \
.set_attr_parser(AxesParamParser<ReduceAxesParam>) \
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#define MXNET_OPERATOR_REGISTER_BROADCAST(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<nnvm::FGradient>("FGradient", \
[](const nnvm::NodePtr& n, \
const std::vector<nnvm::NodeEntry>& ograds) { \
return MakeNonlossGradNode("_broadcast_backward", n, ograds, {}, \
{{"keepdims", "true"}}); \
}) \
.add_argument("data", "NDArray-or-Symbol", "The input")
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_