// blob: afdb49b0f564d16a302afeed32ad082be45f29ea [file] [log] [blame]
/*
* Copyright (c) 2005-2019, NumPy Developers.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
*
* * Neither the name of the NumPy Developers nor the names of any
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*!
* \file np_einsum_op.cc
* \brief CPU Implementation of numpy-compatible einsum
*/
#include "./np_einsum_op-inl.h"
#include <cstdlib>
#include <cstring>
namespace mxnet {
namespace op {
/*!
 * \brief Normalize an einsum subscripts string into explicit input and output parts.
 *
 * Spaces are removed, every character is validated, each "..." is expanded into
 * concrete labels taken from the letters the expression does not use, and a
 * missing output is built from the labels that occur exactly once in the inputs.
 *
 * \param subscripts einsum expression such as "ij,jk->ik" or "...ij,...jk"
 * \param shapes shapes of the operands; their ndim() decides how many labels
 *        each "..." stands for and their count must match the number of terms
 * \return two strings: [0] the comma separated input terms, [1] the output term
 *
 * Every malformed expression is reported through a failing CHECK.
 */
inline std::vector<std::string> _parse_einsum_input(std::string subscripts,
                                                    const mxnet::ShapeVector& shapes) {
  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
  std::bitset<MAXAXIS> einsum_symbols_set;
  for (const char& c : einsum_symbols) {
    einsum_symbols_set.set(c);
  }

  // Bounds-checked membership test. std::bitset::test throws std::out_of_range
  // for positions >= size(), which is what a non-ASCII (possibly negative) char
  // turns into; such characters must reach the CHECK message instead.
  const auto is_symbol = [&einsum_symbols_set](const char c) {
    const size_t pos = static_cast<unsigned char>(c);
    return pos < einsum_symbols_set.size() && einsum_symbols_set.test(pos);
  };

  // Labels occurring exactly once in `terms` (commas ignored), in sorted order,
  // leaving out the ones contained in `excluded`.
  const auto unique_labels = [&is_symbol](std::string terms,
                                          const std::bitset<MAXAXIS>& excluded) {
    std::string labels;
    std::sort(terms.begin(), terms.end());
    const size_t len = terms.length();
    for (size_t i = 0; i < len; ++i) {
      const char c = terms[i];
      if (c == ',') {
        continue;
      }
      CHECK(is_symbol(c)) << "Character " << c << " is not a valid symbol.";
      // After sorting, a label is unique iff both neighbours differ from it.
      const bool differs_from_prev = (i == 0 || terms[i - 1] != c);
      const bool differs_from_next = (i + 1 == len || terms[i + 1] != c);
      if (differs_from_prev && differs_from_next &&
          !excluded.test(static_cast<unsigned char>(c))) {
        labels.append(1, c);
      }
    }
    return labels;
  };

  CHECK_NE(shapes.size(), 0U) << "No input operands";
  subscripts.erase(std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());

  // Ensure all characters are valid
  for (const char& c : subscripts) {
    if (c == '.' || c == ',' || c == '-' || c == '>') {
      continue;
    }
    CHECK(is_symbol(c)) << "Character " << c << " is not a valid symbol.";
  }

  // Check for proper "->"
  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
    bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
                    std::count(subscripts.begin(), subscripts.end(), '>') > 1);
    CHECK(!invalid && _count_substring(subscripts, "->") == 1)
        << "Subscripts can only contain one '->'.";
  }

  // Parse ellipses
  if (subscripts.find('.') != std::string::npos) {
    std::string used = subscripts;
    used.erase(
        std::remove_if(used.begin(),
                       used.end(),
                       [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
        used.end());
    std::bitset<MAXAXIS> used_set = str2set(used);
    // Symbols the expression does not use; the tail of this list replaces "...".
    std::string ellipse_inds = "";
    for (const char& c : einsum_symbols) {
      if (!used_set.test(static_cast<int>(c))) {
        ellipse_inds.append(1, c);
      }
    }

    int longest = 0;
    std::string output_sub;
    std::vector<std::string> split_subscripts;
    bool out_sub = false;
    if (subscripts.find("->") != std::string::npos) {
      std::vector<std::string> tmp = split(subscripts, "->");
      output_sub                   = tmp[1];
      split_subscripts             = split(tmp[0], ",");
      out_sub                      = true;
    } else {
      split_subscripts = split(subscripts, ",");
    }

    const size_t size_split_subscripts = split_subscripts.size();
    subscripts                         = "";
    for (size_t i = 0; i < size_split_subscripts; ++i) {
      const std::string& sub = split_subscripts[i];
      if (sub.find('.') != std::string::npos) {
        CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
        CHECK_EQ(_count_substring(sub, "..."), 1) << "Invalid Ellipses";
        // The operand count is only verified at the end of this function, so
        // shapes[i] has to be guarded here to avoid reading past the vector.
        CHECK(i < shapes.size()) << "Number of einsum subscripts must be equal to the "
                                 << "number of operands.";
        // Number of dimensions the ellipsis has to cover for this operand
        int ellipse_count = 0;
        if (shapes[i].ndim() != 0) {
          ellipse_count = std::max(shapes[i].ndim(), 1) - (static_cast<int>(sub.length()) - 3);
        }
        if (ellipse_count > longest) {
          longest = ellipse_count;
        }
        CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
        const size_t ellipse_pos = sub.find("...");
        if (ellipse_count == 0) {
          split_subscripts[i].erase(ellipse_pos, 3);
        } else {
          // Without this guard the unsigned subtraction below wraps around and
          // substr throws std::out_of_range.
          CHECK(static_cast<size_t>(ellipse_count) <= ellipse_inds.length())
              << "Not enough unused symbols to expand the ellipsis.";
          const std::string rep_inds =
              ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
          split_subscripts[i].replace(ellipse_pos, 3, rep_inds);
        }
      }
      subscripts += split_subscripts[i];
      if (i + 1 < size_split_subscripts) {
        subscripts += ",";
      }
    }

    const std::string out_ellipse =
        (longest == 0) ? std::string() : ellipse_inds.substr(ellipse_inds.length() - longest);
    if (out_sub) {
      // An explicit output without "..." is kept verbatim: find() returns npos
      // in that case and std::string::replace(npos, ...) would throw.
      const size_t ellipse_pos = output_sub.find("...");
      if (ellipse_pos != std::string::npos) {
        output_sub.replace(ellipse_pos, 3, out_ellipse);
      }
      subscripts += "->" + output_sub;
    } else {
      // Special care for outputless ellipses: the broadcast labels come first
      // and are not repeated among the remaining unique labels.
      const std::string output_subscript = unique_labels(subscripts, str2set(out_ellipse));
      subscripts += "->" + out_ellipse + output_subscript;
    }
  }

  // Build output string if does not exist
  std::vector<std::string> ret(2);
  if (subscripts.find("->") != std::string::npos) {
    ret = split(subscripts, "->");
  } else {
    ret[0] = subscripts;
    ret[1] = unique_labels(subscripts, std::bitset<MAXAXIS>());
  }

  // Make sure output subscripts are in the input
  std::bitset<MAXAXIS> input_subscripts_set = str2set(ret[0]);
  for (const char& c : ret[1]) {
    CHECK(input_subscripts_set.test(static_cast<unsigned char>(c)))
        << "Output character " << c << " did not appear in the input";
  }

  // Make sure number operands is equivalent to the number of terms
  CHECK_EQ(std::count(ret[0].begin(), ret[0].end(), ',') + 1, shapes.size())
      << "Number of einsum subscripts must be equal to the "
      << "number of operands.";
  return ret;
}
/*!
 * \brief Shape inference function of the einsum operator.
 *
 * Parses the subscripts, records the extent of every label across the operands
 * (an extent of 1 may be overridden by a later operand) and assembles the
 * output shape from the output labels.
 *
 * \param attrs node attributes holding the parsed NumpyEinsumParam
 * \param in_attrs shapes of the operands, one per `num_args`
 * \param out_attrs single output shape, assigned through SHAPE_ASSIGN_CHECK
 * \return false when some operand shape is not known yet, otherwise whether the
 *         computed output shape is known
 */
bool NumpyEinsumShape(const nnvm::NodeAttrs& attrs,
                      mxnet::ShapeVector* in_attrs,
                      mxnet::ShapeVector* out_attrs) {
  const NumpyEinsumParam& param = nnvm::get<NumpyEinsumParam>(attrs.parsed);
  const int num_args            = param.num_args;
  CHECK_EQ(in_attrs->size(), num_args);
  CHECK_EQ(out_attrs->size(), 1U);

  // Inference has to wait until every operand shape is available.
  for (int arg = 0; arg < num_args; ++arg) {
    if (!shape_is_known(in_attrs->at(arg))) {
      return false;
    }
  }

  // [0]: comma separated input terms, [1]: output term
  const std::vector<std::string> parsed = _parse_einsum_input(param.subscripts, *in_attrs);
  const std::vector<std::string> terms  = split(parsed[0], ",");

  // Extent of each label, indexed by the label's character code; -1 = not seen.
  dim_t dimension_dict[MAXAXIS];
  for (dim_t& slot : dimension_dict) {
    slot = -1;
  }

  const size_t num_terms = terms.size();
  for (size_t i = 0; i < num_terms; ++i) {
    const std::string& term = terms[i];
    const TShape& sh        = in_attrs->at(i);
    CHECK_EQ(sh.ndim(), term.length())
        << "Einstein sum subscript " << term << " does not contain the "
        << "correct number of indices for operand " << i << ".";
    for (size_t j = 0; j < term.length(); ++j) {
      const char label   = term[j];
      const dim_t extent = sh[j];
      dim_t& known       = dimension_dict[static_cast<int>(label)];
      if (known == -1) {
        // first occurrence of this label
        known = extent;
        continue;
      }
      // For broadcasting cases we always want the largest dim size
      if (known == 1) {
        known = extent;
      }
      CHECK(extent == 1 || extent == known)
          << "Size of label '" << label << "' for operand " << i << " (" << known
          << ") does not match previous terms (" << extent << ").";
    }
  }

  // Assemble the output shape label by label.
  const std::string& out_labels = parsed[1];
  const size_t odim             = out_labels.size();
  TShape oshape(odim, -1);
  for (size_t k = 0; k < odim; ++k) {
    oshape[k] = dimension_dict[static_cast<int>(out_labels[k])];
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);

  // Every tensor involved must be addressable with index_t.
  const size_t lim = static_cast<size_t>(std::numeric_limits<index_t>::max());
  for (int arg = 0; arg < num_args; ++arg) {
    CHECK_LE(in_attrs->at(arg).Size(), lim)
        << "Size of operand " << arg << " exceeds the maximum index."
        << " Try setting `USE_INT64_TENSOR_SIZE`.";
  }
  CHECK_LE(oshape.Size(), lim) << "Size of output"
                               << " exceeds the maximum index."
                               << " Try setting `USE_INT64_TENSOR_SIZE`.";
  return shape_is_known(oshape);
}
/*!
 * \brief Create the operator state of an einsum node.
 *
 * The state is an EinsumOp constructed from the node's `num_args`, `optimize`
 * and `subscripts` parameters. `ctx`, `in_shapes` and `in_types` are not read.
 */
OpStatePtr CreateEinsumState(const NodeAttrs& attrs,
                             Context ctx,
                             const mxnet::ShapeVector& in_shapes,
                             const std::vector<int>& in_types) {
  const auto& einsum_param = dmlc::get<NumpyEinsumParam>(attrs.parsed);
  return OpStatePtr::Create<EinsumOp>(
      einsum_param.num_args, einsum_param.optimize, einsum_param.subscripts);
}
DMLC_REGISTER_PARAMETER(NumpyEinsumParam);

// Registration of the forward einsum operator. The number of inputs is not
// fixed: it is read from the `num_args` parameter of each node.
NNVM_REGISTER_OP(_npi_einsum)
    .describe(R"doc()doc" ADD_FILELINE)
    .set_attr_parser(ParamParser<NumpyEinsumParam>)
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
      const NumpyEinsumParam& param = dmlc::get<NumpyEinsumParam>(attrs.parsed);
      return static_cast<uint32_t>(param.num_args);
    })
    .set_num_outputs(1)
    .set_attr<std::string>("key_var_num_args", "num_args")
    // Inputs are named arg0, arg1, ..., arg<num_args - 1>.
    .set_attr<nnvm::FListInputNames>("FListInputNames",
                                     [](const nnvm::NodeAttrs& attrs) {
                                       int num_args =
                                           dmlc::get<NumpyEinsumParam>(attrs.parsed).num_args;
                                       std::vector<std::string> ret;
                                       ret.reserve(num_args);
                                       for (int i = 0; i < num_args; i++) {
                                         ret.push_back(std::string("arg") + std::to_string(i));
                                       }
                                       return ret;
                                     })
    .set_attr<mxnet::FInferShape>("FInferShape", NumpyEinsumShape)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
    .set_attr<FCreateOpState>("FCreateOpState", CreateEinsumState)
    // One temporary workspace is requested for the computation.
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& attrs) {
                                  return std::vector<ResourceRequest>(1,
                                                                      ResourceRequest::kTempSpace);
                                })
    .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", NumpyEinsumForward<cpu>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_einsum"})
    // Fixed typo in the user-visible argument description ("eimsum").
    .add_argument("data", "NDArray-or-Symbol[]", "List of einsum operands")
    .add_arguments(NumpyEinsumParam::__FIELDS__());
// Registration of the backward einsum operator: `num_args + 1` inputs and
// `num_args` outputs.
// NOTE(review): the extra input is presumably the output gradient supplied by
// ElemwiseGradUseIn on the forward operator - confirm.
NNVM_REGISTER_OP(_backward_npi_einsum)
    .set_attr_parser(ParamParser<NumpyEinsumParam>)
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
      const int operand_count = dmlc::get<NumpyEinsumParam>(attrs.parsed).num_args;
      return static_cast<uint32_t>(operand_count + 1);
    })
    .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
      const int operand_count = dmlc::get<NumpyEinsumParam>(attrs.parsed).num_args;
      return static_cast<uint32_t>(operand_count);
    })
    .set_attr<bool>("TIsLayerOpBackward", true)
    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
    // One temporary workspace is requested for the computation.
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& attrs) {
                                  return std::vector<ResourceRequest>{
                                      ResourceRequest::kTempSpace};
                                })
    .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", NumpyEinsumBackward<cpu>);
} // namespace op
} // namespace mxnet