/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * Copyright (c) 2019 by Contributors
 * \file np_norm.cc
 * \brief CPU registration of np.linalg.norm
 */
#include "./np_norm-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyNormParam);

inline bool NumpyLpNormShape(const nnvm::NodeAttrs& attrs,
                             mxnet::ShapeVector *in_attrs,
                             mxnet::ShapeVector *out_attrs) {
  if (!shape_is_known((*in_attrs)[0])) return false;
  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
  const int ndim = (*in_attrs)[0].ndim();
  if ((!param.axis.has_value() && param.flag != 0 && ndim > 2) ||
      (param.axis.has_value() && param.axis.value().ndim() > 2))
    LOG(FATAL) << "Improper number of dimensions to norm.";
  if (!param.axis.has_value()) {
    if ((ndim == 0 && param.flag != 0) ||  // for scalar
        (ndim == 1 && (param.flag == 2)) ||
        (ndim >= 2 && (param.ord == 0 || param.ord > 2 || param.ord < -2))) {
      LOG(FATAL) << "Invalid norm order for inputs.";
    }
  } else {
    if ((param.axis.value().ndim() == 0 && param.flag != 0) ||  // for scalar
        (param.axis.value().ndim() == 1 && (param.flag == 2)) ||
        (param.axis.value().ndim() == 2 && (param.ord == 0 || param.ord > 2 || param.ord < -2))) {
      LOG(FATAL) << "Invalid norm order for inputs.";
    }
  }
  if (!param.keepdims && (*in_attrs)[0].ndim() == 1) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1));
  } else {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0,
                       ReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims, false));
  }
  return true;
}
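
// Illustrative trace for NumpyLpNormShape (not part of the original source; the
// shapes are made up for the example): an input of shape (2, 3, 4) with
// axis = (1,) and keepdims = false reduces to (2, 4) via ReduceAxesShapeImpl,
// or to (2, 1, 4) with keepdims = true; a 1-D input without keepdims takes the
// TShape(0, -1) branch above and yields a 0-dim scalar shape.
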
inline bool NumpyMatrixNormShape(const nnvm::NodeAttrs& attrs,
                                 mxnet::ShapeVector *in_attrs,
                                 mxnet::ShapeVector *out_attrs) {
  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
  const int ndim = (*in_attrs)[0].ndim();
  auto shape = swapMatDims((*in_attrs)[0], param.axis.value());
  if (param.axis.value().ndim() == 2) {
    int batch_dim = 1;
    int row_dim = (*in_attrs)[0][param.axis.value()[0]];
    int col_dim = (*in_attrs)[0][param.axis.value()[1]];
    TShape out_shape(ndim - (param.keepdims ? 0 : 2), 1);
    for (int i = 0; i < ndim - 2; ++i) {
      batch_dim *= shape[i];
    }
    if (param.keepdims) {
      out_shape = (*in_attrs)[0];
      out_shape[param.axis.value()[0]] = 1;
      out_shape[param.axis.value()[1]] = 1;
    } else {
      for (int i = 0; i < ndim - 2; ++i) {
        out_shape[i] = shape[i];
      }
    }
    int svd_dim = row_dim < col_dim ? row_dim : col_dim;
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
    if (param.ord == 2 || param.ord == -2) {
      SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({ batch_dim, row_dim, row_dim }));  // UT
      SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ batch_dim, svd_dim }));  // L
      SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ batch_dim, row_dim, col_dim }));  // V
    } else {
      TShape sum_shape = (*in_attrs)[0];
      TShape mat_axis = param.axis.value();
      int sum_dim = mat_axis[!(param.ord == 1 || param.ord == -1)];
      TShape small(3, 1);
      sum_shape[sum_dim] = 1;
      small[0] = sum_shape.ProdShape(0, sum_dim);
      small[2] = sum_shape.ProdShape(sum_dim + 1, sum_shape.ndim());
      SHAPE_ASSIGN_CHECK(*out_attrs, 1, small);  // sum
      SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ 0, 0 }));  // L
      SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ 0, 0, 0 }));  // V
    }
  } else {
    LOG(FATAL) << "Invalid norm or ord arguments.";
  }
  return true;
}
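
// Illustrative trace for NumpyMatrixNormShape (not part of the original source;
// the shapes are made up for the example): for an input of shape (6, 3, 4) with
// axis = (1, 2), we get batch_dim = 6, row_dim = 3, col_dim = 4, svd_dim = 3.
// With ord = +/-2 the SVD workspaces are UT: (6, 3, 3), L: (6, 3), V: (6, 3, 4),
// and the reduced output is (6,) (or (6, 1, 1) with keepdims). With ord = +/-1
// the reduction instead sums over mat_axis[0], giving a flattened (6, 1, 4) sum
// workspace; other ord values (presumably the inf-norm cases encoded via
// param.flag) sum over mat_axis[1], and the SVD outputs stay as dummy shapes.
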
inline void assign_svd_empty(mxnet::ShapeVector *out_attrs) {
  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({ 0, 0, 0 }));  // UT
  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape({ 0, 0 }));  // L
  SHAPE_ASSIGN_CHECK(*out_attrs, 3, TShape({ 0, 0, 0 }));  // V
}

bool NumpyNormType(const nnvm::NodeAttrs& attrs,
                   std::vector<int>* in_attrs,
                   std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 4U);
  const int in_type = in_attrs->at(0);
  const int out_type = in_type;
  if (!common::is_float(in_type)) {
    LOG(WARNING) << "Integer input to norm. This will result in integer "
                    "output, which differs from standard NumPy behavior and "
                    "breaks gradient computation in backward. Please cast the "
                    "input to a floating point type first.";
  }
  for (int i = 0; i < 4; ++i) {
    TYPE_ASSIGN_CHECK(*out_attrs, i, out_type);
  }
  return out_attrs->at(0) != -1;
}
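
// Note (added comment, not in the original source): NumpyNormType deliberately
// keeps the input dtype, so e.g. an int32 tensor produces an int32 norm here,
// whereas NumPy's np.linalg.norm would promote an integer input to float64.
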
bool NumpyNormShape(const nnvm::NodeAttrs& attrs,
                    mxnet::ShapeVector *in_attrs,
                    mxnet::ShapeVector *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 4U);  // reduced, UT, S, V
  const NumpyNormParam& param = nnvm::get<NumpyNormParam>(attrs.parsed);
  if (!param.axis.has_value()) {
    if (param.flag == -2) {
      int ndim = param.keepdims ? (*in_attrs)[0].ndim() : 0;
      int sz = param.keepdims ? 1 : -1;
      SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(ndim, sz));
      assign_svd_empty(out_attrs);
      return true;
    }
    if ((*in_attrs)[0].ndim() >= 2) {
      TShape axis(2, 0);
      axis[0] = (*in_attrs)[0].ndim() - 2;
      axis[1] = (*in_attrs)[0].ndim() - 1;
      const_cast<NumpyNormParam&>(param).axis = axis;
      return NumpyMatrixNormShape(attrs, in_attrs, out_attrs);
    } else {
      TShape axis(1, (*in_attrs)[0].ndim() - 1);
      const_cast<NumpyNormParam&>(param).axis = axis;
      assign_svd_empty(out_attrs);
      return NumpyLpNormShape(attrs, in_attrs, out_attrs);
    }
  } else {
    TShape axis(param.axis.value().ndim(), 0);
    for (int i = 0; i < param.axis.value().ndim(); ++i) {
      axis[i] = param.axis.value()[i] < 0 ?
                (*in_attrs)[0].ndim() + param.axis.value()[i] :
                param.axis.value()[i];
    }
    const_cast<NumpyNormParam&>(param).axis = axis;
    if (param.axis.value().ndim() == 2) {
      return NumpyMatrixNormShape(attrs, in_attrs, out_attrs);
    } else {
      assign_svd_empty(out_attrs);
      return NumpyLpNormShape(attrs, in_attrs, out_attrs);
    }
  }
}
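
// Illustrative trace for NumpyNormShape (not part of the original source):
// negative axes are normalized in place via the const_cast, e.g. axis = (-2, -1)
// on a 3-D input becomes (1, 2) and dispatches to NumpyMatrixNormShape, while a
// single-axis tuple such as (-1,) becomes (2,) and dispatches to
// NumpyLpNormShape. The flag == -2 branch (presumably the axis-less,
// flattened-input case) emits a scalar output directly.
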
TShape swapMatDims(const TShape &shape, const TShape &axis) {
  TShape ret(shape.ndim(), 1);
  int i, j = 0;
  for (i = 0; i < shape.ndim(); ++i) {
    if (i != axis[0] && i != axis[1]) {
      ret[j++] = shape[i];
    }
  }
  ret[j++] = shape[axis[0]];
  ret[j] = shape[axis[1]];
  return ret;
}
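
// Illustrative trace for swapMatDims (not part of the original source): the two
// matrix axes are moved to the end while the batch axes keep their order, e.g.
// shape = (2, 3, 4, 5) with axis = (1, 2) yields (2, 5, 3, 4).
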
TShape inverseTranspose(const TShape &axes) {
  TShape ret(axes.ndim(), 1);
  for (int i = 0; i < axes.ndim(); ++i) {
    ret[axes[i]] = i;
  }
  return ret;
}
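
// Illustrative trace for inverseTranspose (not part of the original source): it
// returns the permutation that undoes a transpose, e.g. axes = (2, 0, 1) yields
// (1, 2, 0), since transposing by (2, 0, 1) and then by (1, 2, 0) restores the
// original layout.
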
} // namespace op
} // namespace mxnet