// blob: 5525b9209fc140a41f7265c783c5ae59ee04ae70 [file] [log] [blame]
/*
* Copyright (c) 2005-2019, NumPy Developers.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
*
* * Neither the name of the NumPy Developers nor the names of any
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*!
* \file np_einsum_op-inl.h
* \brief Function definition of numpy-compatible einsum operator
* modified by Haozheng Fan(@hzfan) from:
* https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/einsum.c.src
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
#include <mxnet/operator_util.h>
#include <string>
#include <vector>
#include <algorithm>
#include "./np_tensordot_op-inl.h"
#include "./np_einsum_path_op-inl.h"
#include "../../common/static_array.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16
/*!
 * \brief Compute row-major element strides for \p shape.
 *
 * Dimensions of extent <= 1 get stride 0 so that the resulting strides can be
 * used directly for broadcasting (stepping along a broadcast axis stays put).
 *
 * \param shape input shape
 * \return a TShape of the same ndim holding the stride of each axis
 */
inline TShape get_stride(const TShape& shape) {
  int ndim = shape.ndim();
  // Use a 64-bit accumulator: the running product of dimension sizes can
  // exceed INT_MAX for large tensors, which would silently overflow an int.
  index_t prod = 1;
  TShape stride = TShape(ndim, -1);
  for (int i = ndim - 1; i >= 0; i--) {
    stride[i] = shape[i] > 1 ? prod : 0;
    prod      = prod * shape[i];
  }
  return stride;
}
/*!
 * \brief Extend \p shape to \p odim dimensions by appending trailing 1s.
 *
 * \param shape original shape (must have ndim <= odim)
 * \param odim  target number of dimensions
 * \return the padded shape
 */
inline TShape pad(const TShape& shape, int odim) {
  const int ndim = shape.ndim();
  CHECK_GE(odim, ndim);
  // Start from an all-ones shape, then copy the leading real extents over.
  TShape padded(odim, 1);
  for (int axis = 0; axis < ndim; ++axis) {
    padded[axis] = shape[axis];
  }
  return padded;
}
/*
* Parses the subscripts for one operand into an output of 'ndim'
* labels. The resulting 'op_labels' array will have:
* - the ASCII code of the label for the first occurrence of a label;
* - the (negative) offset to the first occurrence of the label for
* repeated labels;
* - zero for broadcast dimensions, if subscripts has an ellipsis.
* For example:
* - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
* - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
*/
/*!
 * \brief Parse the subscripts of a single operand (see the comment block above
 * for the exact encoding written into \p op_labels).
 *
 * \param subscripts   start of this operand's subscript segment
 * \param length       number of characters in the segment
 * \param ndim         number of dimensions of the operand
 * \param iop          operand index (used only in error messages)
 * \param op_labels    out: per-dimension label codes (see encoding above)
 * \param label_counts in/out: occurrence count per ASCII label across operands
 * \param min_label    in/out: smallest label seen so far
 * \param max_label    in/out: largest label seen so far
 * \return 0 on success (failures abort via CHECK)
 */
inline int parse_operand_subscripts(const char* subscripts,
                                    int length,
                                    int ndim,
                                    int iop,
                                    char* op_labels,
                                    char* label_counts,
                                    int* min_label,
                                    int* max_label) {
  using namespace mxnet_op;
  int i;
  int idim = 0;
  int ellipsis = -1;  // dimension index where '...' appeared, or -1 if absent
  /* Process all labels for this operand */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];
    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check we don't exceed the operator dimensions. */
      CHECK(idim < ndim) << "einstein sum subscripts string contains "
                         << "too many subscripts for operand " << iop;
      op_labels[idim++] = label;
      if (label < *min_label) {
        *min_label = label;
      }
      if (label > *max_label) {
        *max_label = label;
      }
      label_counts[label]++;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it's a proper ellipsis. */
      // Note: the two ++i below consume the remaining dots of '...'.
      CHECK(
          !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
          << "einstein sum subscripts string contains a "
          << "'.' that is not part of an ellipsis ('...') "
          << "in operand " << iop;
      ellipsis = idim;
    } else {
      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
                          << "' in einstein sum "
                          << "subscripts string, subscripts must "
                          << "be letters";
    }
  }
  /* No ellipsis found, labels must match dimensions exactly. */
  if (ellipsis == -1) {
    CHECK(idim == ndim) << "operand has more dimensions than subscripts "
                        << "given in einstein sum, but no '...' ellipsis "
                        << "provided to broadcast the extra dimensions.";
  } else if (idim < ndim) {
    /* Ellipsis found, may have to add broadcast dimensions. */
    /* Move labels after ellipsis to the end. */
    for (i = 0; i < idim - ellipsis; ++i) {
      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
    }
    /* Set all broadcast dimensions to zero. */
    for (i = 0; i < ndim - idim; ++i) {
      op_labels[ellipsis + i] = 0;
    }
  }
  /*
   * Find any labels duplicated for this operand, and turn them
   * into negative offsets to the axis to merge with.
   *
   * In C, the char type may be signed or unsigned, but with
   * twos complement arithmetic the char is ok either way here, and
   * later where it matters the char is cast to a signed char.
   */
  for (idim = 0; idim < ndim - 1; ++idim) {
    int label = op_labels[idim];
    /* If it is a proper label, find any duplicates of it. */
    if (label > 0) {
      /* Search for the next matching label. */
      char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
      while (next != nullptr) {
        /* The offset from next to op_labels[idim] (negative). */
        *next = static_cast<char>((op_labels + idim) - next);
        /* Search for the next matching label. */
        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
      }
    }
  }
  return 0;
}
/*
* Parses the subscripts for the output operand into an output that
* includes 'ndim_broadcast' unlabeled dimensions, and returns the total
* number of output dimensions, or -1 if there is an error. Similarly
* to parse_operand_subscripts, the 'out_labels' array will have, for
* each dimension:
* - the ASCII code of the corresponding label;
* - zero for broadcast dimensions, if subscripts has an ellipsis.
*/
/*!
 * \brief Parse the output side of the subscript string (after '->').
 *
 * \param subscripts     start of the output subscript segment
 * \param length         number of characters in the segment
 * \param ndim_broadcast number of unlabeled broadcast dims from the inputs
 * \param label_counts   per-label occurrence counts gathered from the inputs
 * \param out_labels     out: one entry per output dim (label code or 0)
 * \return the number of output dimensions (failures abort via CHECK)
 */
inline int parse_output_subscripts(const char* subscripts,
                                   int length,
                                   int ndim_broadcast,
                                   const char* label_counts,
                                   char* out_labels) {
  using namespace mxnet_op;
  int i, bdim;
  int ndim = 0;
  int ellipsis = 0;  // set to 1 once '...' has been consumed
  /* Process all the output labels. */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];
    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check that it doesn't occur again. */
      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
          << "einstein sum subscripts string includes "
          << "output subscript '" << static_cast<char>(label) << "' multiple times";
      /* Check that it was used in the inputs. */
      CHECK(label_counts[label] != 0)
          << "einstein sum subscripts string included "
          << "output subscript '" << static_cast<char>(label) << "' which never appeared "
          << "in an input";
      /* Check that there is room in out_labels for this label. */
      CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
                                << "too many subscripts in the output";
      out_labels[ndim++] = label;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it is a proper ellipsis. */
      // Note: the two ++i below consume the remaining dots of '...'.
      CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
          << "einstein sum subscripts string "
          << "contains a '.' that is not part of "
          << "an ellipsis ('...') in the output";
      /* Check there is room in out_labels for broadcast dims. */
      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
                                                  << "too many subscripts in the output";
      ellipsis = 1;
      // Broadcast dims are marked with label 0, matching the operand encoding.
      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
        out_labels[ndim++] = 0;
      }
    } else {
      CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
                          << "' in einstein sum "
                          << "subscripts string, subscripts must "
                          << "be letters";
    }
  }
  /* If no ellipsis was found there should be no broadcast dimensions. */
  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
                                            << "given in einstein sum, but no '...' ellipsis "
                                            << "provided to broadcast the extra dimensions.";
  return ndim;
}
/*!
 * \brief Build a "view" (shape + strides) of operand \p op in which axes that
 * carry a repeated label are collapsed onto their first occurrence, i.e. the
 * diagonal along those axes is selected by summing the strides.
 *
 * \param op        the operand blob
 * \param iop       operand index (used only in error messages)
 * \param labels    per-dimension label codes from parse_operand_subscripts;
 *                  mutated in place so the surviving labels are compacted
 * \param newshape  out: collapsed shape (pre-sized to the collapsed ndim)
 * \param newstride out: collapsed strides (same size as newshape)
 */
inline void get_combined_dims_view(const TBlob& op,
                                   int iop,
                                   char* labels,
                                   TShape* newshape,
                                   TShape* newstride) {
  using namespace mxnet_op;
  int idim, ndim, icombine, combineoffset;
  int icombinemap[NPY_MAXDIMS];  // maps original dim index -> collapsed index
  int newdim;
  const TShape& shape = op.shape_;
  TShape stride = get_stride(shape);
  ndim = op.shape_.ndim();
  newdim = newshape->ndim();
  /* Initialize the dimensions and strides to zero */
  for (idim = 0; idim < newdim; ++idim) {
    (*newshape)[idim] = 0;
    (*newstride)[idim] = 0;
  }
  /* Copy the dimensions and strides, except when collapsing */
  icombine = 0;
  for (idim = 0; idim < ndim; ++idim) {
    /*
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    int label = (signed char)labels[idim];
    /* If this label says to merge axes, get the actual label */
    if (label < 0) {
      // Negative labels are offsets to the first occurrence of the label.
      combineoffset = label;
      label = labels[idim + label];
    } else {
      combineoffset = 0;
      if (icombine != idim) {
        labels[icombine] = labels[idim];
      }
      icombinemap[idim] = icombine;
    }
    /* If the label is 0, it's an unlabeled broadcast dimension */
    if (label == 0) {
      (*newshape)[icombine] = shape[idim];
      (*newstride)[icombine] = stride[idim];
    } else {
      /* Update the combined axis dimensions and strides */
      int i = icombinemap[idim + combineoffset];
      // Repeated labels must refer to axes of equal extent.
      CHECK(!(combineoffset < 0 && (*newshape)[i] != 0 && (*newshape)[i] != shape[idim]))
          << "dimensions in operand " << iop << " for collapsing index '" << label
          << "' don't match (" << static_cast<int>((*newshape)[i]) << " != " << shape[idim] << ")";
      (*newshape)[i] = shape[idim];
      // Summing strides makes one step along the collapsed axis advance all
      // merged axes at once, i.e. walk the diagonal.
      (*newstride)[i] += stride[idim];
    }
    /* If the label didn't say to combine axes, increment dest i */
    if (combineoffset == 0) {
      icombine++;
    }
  }
}
/*!
 * \brief Map each iterator dimension to an axis of one operand.
 *
 * For every iterator label, axes[i] receives the index of the matching axis in
 * \p labels, or -1 if the operand lacks that label (a "newaxis" to broadcast).
 * Unlabeled (0) iterator dims consume the operand's own unlabeled dims from
 * the right.
 *
 * \param ndim        number of dimensions of the operand
 * \param iop         operand index (unused here, kept for signature parity)
 * \param labels      per-dimension label codes of the operand
 * \param axes        out: operand axis for each iterator dim (or -1)
 * \param ndim_iter   number of iterator dimensions
 * \param iter_labels label code of each iterator dimension
 * \return 0 on success
 */
inline static int prepare_op_axes(int ndim,
                                  int iop,
                                  char* labels,
                                  int* axes,
                                  int ndim_iter,
                                  char* iter_labels) {
  int next_broadcast = ndim - 1;  // rightmost unclaimed broadcast axis
  for (int pos = ndim_iter - 1; pos >= 0; --pos) {
    const int label = iter_labels[pos];
    if (label != 0) {
      // Labeled dimension: look the label up in the operand.
      char* found = reinterpret_cast<char*>(memchr(labels, label, ndim));
      axes[pos] = (found == nullptr) ? -1 : static_cast<int>(found - labels);
      continue;
    }
    // Unlabeled broadcast dimension: take the operand's next unlabeled axis.
    while (next_broadcast >= 0 && labels[next_broadcast] != 0) {
      --next_broadcast;
    }
    if (next_broadcast < 0) {
      // Operand ran out of broadcast axes; extend with a "newaxis".
      axes[pos] = -1;
    } else {
      axes[pos] = next_broadcast--;
    }
  }
  return 0;
}
/*! \brief Operator parameters for the numpy-compatible einsum op. */
struct NumpyEinsumParam : public dmlc::Parameter<NumpyEinsumParam> {
  int num_args;             // number of input arrays
  int optimize;             // non-zero enables the einsum-path optimization
  std::string subscripts;   // the einsum subscript specification
  DMLC_DECLARE_PARAMETER(NumpyEinsumParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1).describe("Number of input arrays.");
    DMLC_DECLARE_FIELD(subscripts)
        .set_default("")
        .describe(
            "Specifies the subscripts for summation as comma separated list"
            " of subscript labels. An implicit (classical Einstein summation) calculation"
            " is performed unless the explicit indicator '->' is included as well as"
            " subscript labels of the precise output form.");
    DMLC_DECLARE_FIELD(optimize).set_default(0);
  }
  // Export the parameter values as strings, e.g. for operator serialization.
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    (*dict)["num_args"]   = std::to_string(num_args);
    (*dict)["optimize"]   = std::to_string(optimize);
    (*dict)["subscripts"] = subscripts;
  }
};
/*!
 * \brief Stateful einsum operator: keeps the parsed parameters, the computed
 * contraction paths, and a persistent buffer holding intermediate results so
 * the backward pass can reuse them.
 */
class EinsumOp {
 public:
  int num_args;
  int optimize;
  std::string subscripts;
  std::shared_ptr<NDArray> tempspace;  // intermediate results kept for backward
  std::vector<Step> paths;             // contraction plan from einsum_path
  explicit EinsumOp(int num_args, int optimize, std::string subscripts)
      : num_args(num_args), optimize(optimize), subscripts(std::move(subscripts)) {}
  bool operator==(const EinsumOp& other) const {
    return num_args == other.num_args && optimize == other.optimize &&
           subscripts == other.subscripts;
  }
};  // class EinsumOp
/*!
 * \brief Kernel computing one element of an einsum result.
 *
 * Each thread i handles one point of the iterated output space: it zeroes the
 * destination if req == kWriteTo, then sums over the reduction space the
 * product of the operand elements addressed via per-operand strides, and
 * accumulates the result into out (so kAddTo also works).
 *
 * Template parameters:
 *  - dimension: padded number of iterator dims;
 *  - req: write request type;
 *  - back: 0 for forward; 1 for backward, where the kernel computes the
 *    gradient of operand iop0 (that operand is skipped in the product and the
 *    upstream gradient out_grad is multiplied in instead);
 *  - AType: accumulation type.
 *
 * \param i        linear index into the output (forward) / operand-grad space
 * \param out      destination buffer
 * \param op       device pointers of the nop operands
 * \param oshape   shape of the iterated output space
 * \param ostride  per-operand strides over the output dims (entry nop is the
 *                 out_grad stride set in the backward pass)
 * \param reduceshape shape of the reduced (summed-over) dims
 * \param rstride  per-operand strides over the reduced dims
 * \param nop      number of operands
 * \param iop0     operand being differentiated (backward), -1 in forward
 * \param out_grad gradient w.r.t. the einsum output; only read when back
 */
template <int dimension, int req, bool back, typename AType>
struct numpy_einsum {
  template <typename DType>
  MSHADOW_XINLINE static void Map(
      index_t i,
      DType* out,
      common::StaticArray<DType*, NPY_MAXARGS> op,
      mshadow::Shape<dimension> oshape,
      common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride,
      mshadow::Shape<dimension> reduceshape,
      common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride,
      int nop,
      int iop0,
      const DType* out_grad) {
    using namespace mxnet_op;
    mshadow::Shape<dimension> oidx = unravel(i, oshape);
    // Backward writes into the (possibly strided) layout of operand iop0.
    i = back ? dot(oidx, ostride[iop0]) : i;
    if (req == kWriteTo) {
      out[i] = (DType)0;
    }
    // An empty reduction space means there is nothing to sum; leave the
    // (possibly just-zeroed) destination untouched.
    for (int rdim = 0; rdim < dimension; ++rdim) {
      if (reduceshape[rdim] == 0) {
        return;
      }
    }
    mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
    AType sum = 0;
    // Walk the whole reduction space, multiplying one element per operand.
    do {
      AType tmp =
          back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) + dot(ridx, rstride[nop])])
               : (AType)1;
      for (int iop = 0; iop < nop; ++iop) {
        if (iop != iop0) {
          index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
          tmp = tmp * static_cast<AType>(op[iop][k]);
        }
      }
      sum = sum + tmp;
    } while (inc(&ridx, reduceshape));
    // Accumulate so that kAddTo requests work as well.
    out[i] = out[i] + static_cast<DType>(sum);
  }
};
/*!
 * \brief Core einsum routine: parses the subscript string, builds a shared
 * iteration space over all operands, and launches the numpy_einsum kernel.
 *
 * When back == false: inputs are the nop operands and outputs[0] is the
 * einsum result. When back == true: inputs[0] is the output gradient,
 * inputs[1..nop] are the forward operands, and outputs[0..nop-1] receive the
 * operand gradients (hence the recurring `inputs[iop + back]` indexing).
 *
 * \param subscripts the einsum subscript string, e.g. "ij,jk->ik"
 * \param nop        number of operands
 */
template <typename xpu, bool back>
inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs,
                               const char* subscripts,
                               int nop,
                               const OpContext& ctx) {
  using namespace mxnet_op;
  /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
  CHECK(nop < NPY_MAXARGS) << "too many operands provided to einstein sum function";
  CHECK(nop >= 1) << "not enough operands provided to einstein sum function";
  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = 127, max_label = 0;
  char label_counts[128], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    // Each operand's segment ends at the next ',' or at "->".
    int length = static_cast<int>(strcspn(subscripts, ",-"));
    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
        << "more operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK(!(iop < nop - 1 && subscripts[length] != ','))
        << "fewer operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK_GE(parse_operand_subscripts(subscripts,
                                      length,
                                      inputs[iop + back].shape_.ndim(),
                                      iop,
                                      op_labels[iop],
                                      label_counts,
                                      &min_label,
                                      &max_label),
             0);
    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;
    if (iop < nop - 1) {
      subscripts++;
    }
  }
  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char* labels = op_labels[iop];
    ndim = inputs[iop + back].shape_.ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {
        ++count_zeros;
      }
    }
    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }
  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
                                         << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    CHECK(subscripts[0] == '-' && subscripts[1] == '>') << "einstein sum subscript string does not "
                                                        << "contain proper '->' output specified";
    subscripts += 2;
    /* Parse the output subscript string. */
    ndim_output = parse_output_subscripts(
        subscripts, strlen(subscripts), ndim_broadcast, label_counts, output_labels);
    CHECK_GE(ndim_output, 0);
  }
  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  std::vector<TShape> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char* labels = op_labels[iop];
    int combine, ndim;
    ndim = inputs[iop + back].shape_.ndim();
    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }
    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      TShape tshape(ndim - combine, -1);
      TShape tstride(ndim - combine, -1);
      get_combined_dims_view(inputs[iop + back], iop, labels, &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back].shape_;
      opstride_true[iop] = get_stride(opshape[iop]);
    }
  }
  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  char* iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }
  /* Step 4: Set up the op_axes for the iterator */
  // itershape is the joint iteration space: output dims first, then the
  // reduced (summed-over) dims. iterstride[iop] holds operand iop's stride
  // per iterator dim; entry nop is reserved for the output (or out_grad).
  TShape itershape(ndim_iter, -1);
  std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
  TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
  TShape ostride_true = get_stride(oshape);
  TShape reduceshape;
  std::vector<TShape> remainshape(nop);
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int* op_axes[NPY_MAXARGS];
  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    CHECK_GE(prepare_op_axes(
                 opshape[iop].ndim(), iop, op_labels[iop], op_axes[iop], ndim_iter, iter_labels),
             0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop][idim] = opstride_true[iop][op_axes[iop][idim]];
        // Fill in the iterator extent, letting a size-1 extent be widened by
        // a later operand (broadcasting).
        if (itershape[idim] != -1) {
          if (itershape[idim] == 1) {
            itershape[idim] = opshape[iop][op_axes[iop][idim]];
          }
        } else {
          itershape[idim] = opshape[iop][op_axes[iop][idim]];
        }
      }
    }
  }
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop][idim] = ostride_true[idim];
  }
  // The trailing (non-output) iterator dims form the reduction space.
  reduceshape = TShape(ndim_iter - ndim_output, 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape[idim - ndim_output] = itershape[idim];
  }
  // remainshape[iop]: iterator dims that operand iop does not index directly
  // (missing axis or broadcast mismatch) — used by the backward pass.
  for (iop = 0; iop < nop; iop++) {
    std::vector<size_t> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1 ||
          itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]]) {
        rsh.push_back(itershape[idim]);
      }
    }
    remainshape[iop] = TShape(rsh.begin(), rsh.end());
  }
  // exclude the 0-dim case
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  // Pad every shape/stride to ndim_iter so they share one kernel `dimension`.
  itershape = pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = pad(iterstride[iop], ndim_iter);
  }
  oshape = pad(oshape, ndim_iter);
  reduceshape = pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = pad(opshape[iop], ndim_iter);
    remainshape[iop] = pad(remainshape[iop], ndim_iter);
  }
  if (!back) {
    // Forward: one kernel launch over the output space.
    if (oshape.Size() == 0) {
      return;
    }
    const TBlob& out_data = outputs[0];
    MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
      mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
      for (iop = 0; iop < nop; ++iop) {
        op[iop] = inputs[iop].dptr<DType>();
      }
      MXNET_ASSIGN_REQ_SWITCH(
          req[0], req_type, {MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
            // Split each operand's iterator strides into output-dim strides
            // (otmp) and reduction-dim strides (rtmp); unused slots get 1.
            for (iop = 0; iop < nop; ++iop) {
              mshadow::Shape<dimension> otmp, rtmp;
              for (idim = 0; idim < dimension; ++idim) {
                otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
                rtmp[idim] =
                    idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
              }
              ostride_arr[iop] = otmp;
              rstride_arr[iop] = rtmp;
            }
            Kernel<numpy_einsum<dimension, req_type, 0, AType>, xpu>::Launch(
                ctx.get_stream<xpu>(),
                oshape.Size(),
                out_data.dptr<DType>(),
                op,
                oshape.get<dimension>(),
                ostride_arr,
                reduceshape.get<dimension>(),
                rstride_arr,
                nop,
                -1,  // forward: no operand is skipped
                reinterpret_cast<DType*>(NULL));
          })})
    })
  } else {
    // Backward: the output gradient is empty, so just zero the operand
    // gradients that are written with kWriteTo and return.
    if (oshape.Size() == 0) {
      for (iop = 0; iop < nop; ++iop) {
        const TBlob& out_data = outputs[iop];
        if (opshape[iop].Size() > 0) {
          MSHADOW_TYPE_SWITCH(
              out_data.type_flag_, DType, {MXNET_ASSIGN_REQ_SWITCH(req[iop], req_type, {
                if (req_type == kWriteTo) {
                  out_data.FlatTo1D<xpu, DType>(ctx.get_stream<xpu>()) = 0;
                }
              })})
        }
      }
      return;
    }
    // One kernel launch per operand gradient.
    for (int i = 0; i < nop; ++i) {
      const TBlob& out_data = outputs[i];
      const TBlob& out_grad = inputs[0];
      // For operand i, re-split every operand's iterator strides into strides
      // over operand i's own axes (opstride) vs. the dims summed over when
      // forming its gradient (remainstride). NOTE(review): the broadcast
      // condition below (size-1 axis in operand i vs. size>1 in operand iop)
      // routes broadcast axes into the reduction — intricate; kept verbatim.
      std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
      std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
      for (iop = 0; iop <= nop; ++iop) {
        int j = 0;
        for (idim = 0; idim < ndim_iter; ++idim) {
          if (op_axes_arrays[i][idim] == -1 ||
              (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 &&
               op_axes_arrays[iop][idim] != -1 && opshape[iop][op_axes_arrays[iop][idim]] != 1)) {
            remainstride[iop][j++] = iterstride[iop][idim];
          } else {
            opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
          }
        }
      }
      MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
        mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
        for (iop = 0; iop < nop; ++iop) {
          op[iop] = inputs[iop + back].dptr<DType>();
        }
        MXNET_ASSIGN_REQ_SWITCH(
            req[i], req_type, {MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
              mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
              mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
              for (iop = 0; iop <= nop; ++iop) {
                opstride_arr[iop] = opstride[iop].get<dimension>();
                remainstride_arr[iop] = remainstride[iop].get<dimension>();
              }
              Kernel<numpy_einsum<dimension, req_type, 1, AType>, xpu>::Launch(
                  ctx.get_stream<xpu>(),
                  opshape[i].Size(),
                  out_data.dptr<DType>(),
                  op,
                  opshape[i].get<dimension>(),
                  opstride_arr,
                  remainshape[i].get<dimension>(),
                  remainstride_arr,
                  nop,
                  i,  // operand whose gradient is being computed
                  out_grad.dptr<DType>());
            })})
      })
    }
  }
}
/*!
 * \brief Forward pass of the stateful einsum operator.
 *
 * With optimize == 0 this is a single NumpyEinsumProcess call. Otherwise the
 * contraction is decomposed by einsum_path into pairwise steps; intermediate
 * results are written into state.tempspace (a persistent NDArray), which the
 * backward pass re-reads. Layout of that buffer: the first
 * max_temp_space_size elements are reusable scratch (for BLAS staging), the
 * rest holds the intermediate result of each path step back to back.
 */
template <typename xpu>
inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(inputs.size(), num_args);
  CHECK_EQ(outputs.size(), 1U);
  if (optimize == 0) {
    // Unoptimized: evaluate the whole einsum in one pass.
    NumpyEinsumProcess<xpu, 0>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  std::vector<Step>& paths = state.paths;
  std::vector<std::vector<int> > pos;
  std::string string_repr;
  // Plan the contraction order; the plan is stored on the state for backward.
  paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
  int paths_len = paths.size();
  // Total buffer = scratch of the largest step output + one slot per
  // intermediate result (all steps except the last, which writes outputs[0]).
  size_t temp_space_size = 0, max_temp_space_size = 0;
  std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
  for (int i = 0; i + 1 < paths_len; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // Persist the buffer on the op state so backward can read intermediates.
    state.tempspace.reset<NDArray>(new NDArray(
        TShape(Shape1(temp_space_size)), ctx.run_ctx.ctx, false, outputs[0].type_flag_));
    Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    // Carve per-step intermediate blobs out of the buffer after the scratch.
    size_t begin = max_temp_space_size;
    for (int i = 0; i < paths_len - 1; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_space_vec[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
    for (int i = 0; i < paths_len; ++i) {
      tmp_operands.clear();
      // We remove inds from right to left
      for (const int& p : paths[i].contract_inds) {
        tmp_operands.push_back(operands[p]);
        operands.erase(operands.begin() + p);
      }
      bool handle_out = (i == paths_len - 1);
      // Call tensordot if still possible
      if (paths[i].do_blas) {
        // Contract!
        if (paths[i].do_einsum || handle_out) {
          // Tensordot result goes into the scratch region, then an einsum
          // pass (blas2einsum_str) reshapes/relabels it into the destination.
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space.FlatTo1D<xpu, DType>(s) = 0;
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          size_t tensordot_tempspace_size =
              TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                          paths[i].right_pos,
                                          tmp_operands[0],
                                          tmp_operands[1],
                                          max_temp_space,
                                          std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace =
              ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             max_temp_space,
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{max_temp_space},
                                     handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
                                     handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
                                     paths[i].blas2einsum_str.c_str(),
                                     1,
                                     ctx);
        } else {
          // Tensordot can write the step result directly.
          size_t tensordot_tempspace_size =
              TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                          paths[i].right_pos,
                                          tmp_operands[0],
                                          tmp_operands[1],
                                          temp_space_vec[i],
                                          std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace =
              ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             temp_space_vec[i],
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
        }
      } else {
        // Step cannot use BLAS: evaluate it as a plain einsum.
        NumpyEinsumProcess<xpu, 0>(tmp_operands,
                                   handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
                                   handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
                                   paths[i].einsum_str.c_str(),
                                   tmp_operands.size(),
                                   ctx);
      }
      if (!handle_out) {
        // The intermediate result becomes an operand for later steps.
        operands.push_back(temp_space_vec[i]);
      }
    }
  });
}
/*!
 * \brief Backward pass of the stateful einsum operator.
 *
 * inputs[0] is the gradient of the einsum output, inputs[1..num_args] are the
 * forward operands, and outputs[0..num_args-1] receive the operand gradients.
 * With optimize != 0 it replays the forward contraction plan (state.paths) in
 * reverse, reading the forward intermediates from state.tempspace and
 * propagating gradients step by step through temp_grad buffers.
 */
template <typename xpu>
inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow_op;
  const EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(inputs.size(), 1 + num_args);
  CHECK_EQ(outputs.size(), num_args);
  if (optimize == 0) {
    // Unoptimized: one backward einsum pass over all operands.
    NumpyEinsumProcess<xpu, 1>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  // calculate temporary space size for temp_grad
  const std::vector<Step>& paths = state.paths;
  int paths_len = paths.size();
  size_t temp_space_size = 0, max_temp_space_size = 0;
  for (int i = 0; i < paths_len - 1; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
  // replay the forward process
  // op_idx[i] lists which tensors are live before step i: entries >= 1 refer
  // to inputs[idx] (gradient goes to outputs[idx - 1]); entries <= 0 encode
  // the intermediate result of step -idx.
  std::vector<std::vector<int> > op_idx(paths_len + 1);
  for (int i = 0; i <= paths_len; ++i) {
    if (i == 0) {
      op_idx[i].reserve(num_args);
      for (int j = 0; j < num_args; ++j) {
        op_idx[i].push_back(j + 1);
      }
    } else {
      op_idx[i] = op_idx[i - 1];
      // We remove inds from right to left
      for (const int& p : paths[i - 1].contract_inds) {
        op_idx[i].erase(op_idx[i].begin() + p);
      }
      op_idx[i].push_back(-static_cast<int>(i - 1));
    }
  }
  // calculate temporary space size for tensordot
  size_t tensordot_max_tempspace_size = 0;
  size_t begin_tensordot_tempspace = 0;
  std::vector<TBlob> temp_inputs, temp_outputs;
  std::vector<OpReqType> temp_req;
  std::vector<size_t> tensordot_tempspace_size;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // Dry run over all steps with null-pointer TBlobs (only shapes matter)
    // to size the largest tensordot backward workspace needed.
    for (int i = 0; i < paths_len; ++i) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      bool handle_out = (i == paths_len - 1);
      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(
            TBlob(reinterpret_cast<DType*>(NULL), paths[i].oshape, xpu::kDevMask));
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          temp_inputs.push_back(
              TBlob(reinterpret_cast<DType*>(NULL), paths[-idx].oshape, xpu::kDevMask));
          temp_outputs.push_back(
              TBlob(reinterpret_cast<DType*>(NULL), paths[-idx].oshape, xpu::kDevMask));
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      size_t cur_tensordot_tempspace_size = 0;
      if (paths[i].do_blas) {
        if (paths[i].do_einsum) {
          cur_tensordot_tempspace_size = TensordotBackwardWorkspaceSize<xpu>(
              paths[i].left_pos,
              paths[i].right_pos,
              TBlob(reinterpret_cast<DType*>(NULL), paths[i].tshape, xpu::kDevMask),
              temp_inputs[1],
              temp_inputs[2],
              temp_outputs[0],
              temp_outputs[1],
              temp_req);
        } else {
          cur_tensordot_tempspace_size = TensordotBackwardWorkspaceSize<xpu>(paths[i].left_pos,
                                                                             paths[i].right_pos,
                                                                             temp_inputs[0],
                                                                             temp_inputs[1],
                                                                             temp_inputs[2],
                                                                             temp_outputs[0],
                                                                             temp_outputs[1],
                                                                             temp_req);
        }
      }
      tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
      tensordot_max_tempspace_size =
          std::max(tensordot_max_tempspace_size, cur_tensordot_tempspace_size);
    }
    begin_tensordot_tempspace = temp_space_size;
    // Convert the byte-sized tensordot workspace to DType elements, rounding
    // up, and append it to the DType-typed temp space.
    temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
  });
  // allocate temporary space and propagate
  std::vector<TBlob> temp_grad(paths_len - 1), temp_data(paths_len - 1);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // allocate temporary space for gradients of intermediate results
    Tensor<xpu, 1, DType> temp_space =
        ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(temp_space_size), s);
    size_t begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_grad[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
    // reinterprete ndarray for intermediate results
    // (same layout the forward pass wrote into state.tempspace)
    Tensor<xpu, 1, DType> ndarray_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(ndarray_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_data[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
    // go through the paths in the reversed order
    for (int i = paths_len - 1; i >= 0; i--) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      bool handle_out = (i == paths_len - 1);
      // temp_inputs[0] is the incoming gradient of this step's output.
      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(temp_grad[i]);
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          temp_inputs.push_back(temp_data[-idx]);
          temp_outputs.push_back(temp_grad[-idx]);
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      if (paths[i].do_blas) {
        CHECK_EQ(temp_inputs.size(), 3U);
        CHECK_EQ(temp_outputs.size(), 2U);
        CHECK_EQ(temp_req.size(), 2U);
        // The tensordot workspace lives at the tail of temp_space; view the
        // DType storage as bytes for the tensordot API.
        Tensor<xpu, 1, DType> tensordot_tempspace =
            temp_space.Slice(begin_tensordot_tempspace, temp_space_size);
        Tensor<xpu, 1, char> char_tempspace =
            Tensor<xpu, 1, char>(reinterpret_cast<char*>(tensordot_tempspace.dptr_),
                                 Shape1(tensordot_tempspace_size[i]),
                                 tensordot_tempspace.stream_);
        if (paths[i].do_einsum) {
          // Undo the forward blas2einsum relabeling: map the incoming grad
          // back to tensordot layout before running tensordot's backward.
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{temp_inputs[0]},
                                     std::vector<OpReqType>{kWriteTo},
                                     std::vector<TBlob>{max_temp_space},
                                     paths[i].einsum2blas_str.c_str(),
                                     1,
                                     ctx);
          TensordotBackwardImpl<xpu>(paths[i].left_pos,
                                     paths[i].right_pos,
                                     ctx,
                                     max_temp_space,
                                     temp_inputs[1],
                                     temp_inputs[2],
                                     temp_outputs[0],
                                     temp_outputs[1],
                                     temp_req,
                                     char_tempspace);
        } else {
          TensordotBackwardImpl<xpu>(paths[i].left_pos,
                                     paths[i].right_pos,
                                     ctx,
                                     temp_inputs[0],
                                     temp_inputs[1],
                                     temp_inputs[2],
                                     temp_outputs[0],
                                     temp_outputs[1],
                                     temp_req,
                                     char_tempspace);
        }
      } else {
        // Non-BLAS step: backward einsum distributes the gradient.
        NumpyEinsumProcess<xpu, 1>(temp_inputs,
                                   temp_req,
                                   temp_outputs,
                                   paths[i].einsum_str.c_str(),
                                   temp_outputs.size(),
                                   ctx);
      }
    }
  });
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_