/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file diag_op-inl.h
* \brief Function definition of the diag op
* \author Istvan Fehervari, Zhijingcheng Yu
*/
#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
#include <dmlc/parameter.h>
#include <vector>
#include <algorithm>
#include <utility>
#include <cstdlib>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "./broadcast_reduce_op.h"
namespace mxnet {
namespace op {
struct DiagParam : public dmlc::Parameter<DiagParam> {
int k;
int32_t axis1;
int32_t axis2;
DMLC_DECLARE_PARAMETER(DiagParam) {
DMLC_DECLARE_FIELD(k)
.set_default(0)
.describe("Diagonal in question. The default is 0. "
"Use k>0 for diagonals above the main diagonal, "
"and k<0 for diagonals below the main diagonal. "
"If input has shape (S0 S1) k must be between -S0 and S1");
DMLC_DECLARE_FIELD(axis1)
.set_default(0)
.describe("The first axis of the sub-arrays of interest. "
"Ignored when the input is a 1-D array.");
DMLC_DECLARE_FIELD(axis2)
.set_default(1)
.describe("The second axis of the sub-arrays of interest. "
"Ignored when the input is a 1-D array.");
}
};
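// Computes the output shape of diag. Worked examples (illustrative,
// traced from the code below, not part of the original source):
//   1-D input (3,) with k = 1 -> a (4, 4) matrix holding the input
//     on its first superdiagonal;
//   3-D input (2, 3, 4) with axis1 = 0, axis2 = 2, k = 1
//     -> h = 2, w = 4 - 1 = 3, s = min(2, 3) = 2; axes 0 and 2 are
//     removed, the diagonal axis is appended, so oshape = (3, 2).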
inline mxnet::TShape DiagShapeImpl(const mxnet::TShape& ishape, const int k,
const int32_t axis1, const int32_t axis2) {
if (ishape.ndim() == 1) {
auto s = ishape[0] + std::abs(k);
return mxnet::TShape({s, s});
}
int32_t x1 = CheckAxis(axis1, ishape.ndim());
int32_t x2 = CheckAxis(axis2, ishape.ndim());
CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1;
auto h = ishape[x1];
auto w = ishape[x2];
if (k > 0) {
w -= k;
} else if (k < 0) {
h += k;
}
auto s = std::min(h, w);
if (s < 0) {
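    // the requested diagonal lies entirely outside the array;
    // mark the dim unknown so the caller can report the error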
s = -1;
}
if (x1 > x2) {
std::swap(x1, x2);
}
int32_t n_dim = ishape.ndim() - 1;
mxnet::TShape oshape(n_dim, -1);
  // remove axis1 and axis2 and append the diagonal axis to the end
uint32_t idx = 0;
for (int i = 0; i <= n_dim; ++i) {
if (i != x1 && i != x2) {
oshape[idx++] = ishape[i];
}
}
oshape[n_dim - 1] = s;
return oshape;
}
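// Shape inference: validates the attributes and delegates to
// DiagShapeImpl; fails hard if the requested diagonal does not exist.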
inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!mxnet::ndim_is_known(ishape)) {
return false;
}
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
mxnet::TShape oshape = DiagShapeImpl(ishape,
param.k,
param.axis1,
param.axis2);
if (shape_is_none(oshape)) {
LOG(FATAL) << "Diagonal does not exist.";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return shape_is_known(out_attrs->at(0));
}
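// Type inference: the output dtype mirrors the input dtype, assigned
// in both directions so either side can seed the other.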
inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
return (*out_attrs)[0] != -1;
}
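// Kernel for the N-D case (input has more than one dimension).
// Each output element i is split into idx = i / base (which diagonal,
// i.e. the flattened (leading, body, trailing) coordinate) and
// i - idx * base (the position along that diagonal). The matching
// input element sits at the ravelled base coordinate plus the
// k-offset plus `stride` steps along the diagonal. With back = true
// the same mapping is reused to scatter gradients instead of gather.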
template<int ndim, int req, bool back>
struct diag {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
mshadow::Shape<ndim> oshape,
mshadow::Shape<ndim> ishape,
index_t stride, index_t offset,
index_t base) {
using namespace mxnet_op;
index_t idx = i / base;
index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base);
if (back) {
KERNEL_ASSIGN(out[j], req, a[i]);
} else {
KERNEL_ASSIGN(out[i], req, a[j]);
}
}
};
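// Kernel for the 1-D case: generate a (possibly shifted) diagonal
// matrix from a vector. Element (r, c) of the output is a[min(r, c)]
// when c == r + k, and 0 otherwise (forward only); with back = true
// the k-th diagonal of the incoming gradient is gathered back into
// the input gradient vector.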
template<int req, bool back>
struct diag_gen {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
mshadow::Shape<2> oshape, int k) {
using namespace mxnet_op;
auto j = unravel(i, oshape);
if (j[1] == (j[0] + k)) {
auto l = j[0] < j[1] ? j[0] : j[1];
if (back) {
KERNEL_ASSIGN(out[l], req, a[i]);
} else {
KERNEL_ASSIGN(out[i], req, a[l]);
}
} else if (!back) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
}
}
};
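// Copies the requested diagonal (forward) or scatters gradients back
// along it (backward, back = true).
// Worked stride example (illustrative, not part of the original
// source): ishape = (2, 3, 4, 5), axis1 = 1, axis2 = 3
//   -> minx = 1, maxx = 3, oleading = 2, obody = 4, otrailing = 1;
//      ileading = 2, ibody = 12, itrailing = 5;
//      stride1 = itrailing * obody = 20 (one step along axis1) and
//      stride2 = otrailing = 1 (one step along axis2), so
//      stride1 + stride2 advances one element along the diagonal.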
template<typename xpu, bool back>
void DiagOpProcess(const TBlob& in_data,
const TBlob& out_data,
const mxnet::TShape& ishape,
const mxnet::TShape& oshape,
index_t dsize,
const DiagParam& param,
mxnet_op::Stream<xpu> *s,
const std::vector<OpReqType>& req) {
using namespace mxnet_op;
using namespace mshadow;
if (ishape.ndim() > 1) {
    // decompose the input as (leading, ishape[minx] x body, ishape[maxx] x trailing):
    // the axes before, between, and after the two diagonal axes
uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());
uint32_t idim = ishape.ndim(), odim = oshape.ndim();
uint32_t minx = x1, maxx = x2;
if (minx > maxx) {
std::swap(minx, maxx);
}
    // merge contiguous axes that are not separated
    // by axis1 or axis2, since they map directly
    // to the output and there is no need
    // to distinguish them
    // (after this the input has no more than
    // three axes, which improves ravel and
    // unravel efficiency)
index_t oleading = 1,
obody = 1,
otrailing = 1;
for (uint32_t i = 0; i < minx; ++i) {
oleading *= ishape[i];
}
for (uint32_t i = minx + 1; i < maxx; ++i) {
obody *= ishape[i];
}
for (uint32_t i = maxx + 1; i < idim; ++i) {
otrailing *= ishape[i];
}
index_t ileading = oleading,
ibody = obody * ishape[minx],
itrailing = otrailing * ishape[maxx];
index_t stride1 = itrailing * obody,
stride2 = otrailing;
// stride1 + stride2 is the stride for
// iterating over the diagonal in question
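    // k counts relative to the (axis1, axis2) order, so if axis1 is
    // the axis with the larger index the two step sizes trade places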
if (x1 == maxx) {
std::swap(stride1, stride2);
}
// the extra index offset introduced by k
index_t offset;
int k = param.k;
if (k > 0) {
offset = stride2 * k;
} else if (k < 0) {
offset = stride1 * -k;
} else {
offset = 0;
}
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
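        // the backward kernel writes only the diagonal entries, so for
        // plain write requests clear the whole buffer first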
if (back && req[0] != kAddTo && req[0] != kNullOp) {
out_data.FlatTo1D<xpu, DType>(s) = 0;
}
if (ileading == 1) {
Kernel<diag<2, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(obody, otrailing),
Shape2(ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
} else {
Kernel<diag<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape3(oleading, obody, otrailing),
Shape3(ileading, ibody, itrailing),
stride1 + stride2, offset, oshape[odim - 1]);
}
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<diag_gen<req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]),
param.k);
});
});
}
}
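// Forward pass: extracts the requested diagonal of the single input
// into the single output.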
template<typename xpu>
void DiagOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(req[0], kWriteTo);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
}
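// Backward pass: routes the incoming gradient through the same index
// mapping with the roles of the input and output shapes swapped (note
// the argument order in the DiagOpProcess call below).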
template<typename xpu>
void DiagOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
DiagOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_