| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file la_op.h |
| * \brief Function definition of Operators for advanced linear algebra. |
| */ |
| #ifndef MXNET_OPERATOR_TENSOR_LA_OP_H_ |
| #define MXNET_OPERATOR_TENSOR_LA_OP_H_ |
| |
| #include <mxnet/operator_util.h> |
| #include <mxnet/imperative.h> |
| #include <vector> |
| #include <algorithm> |
| #include <string> |
| #include "../mshadow_op.h" |
| #include "../mxnet_op.h" |
| #include "../operator_common.h" |
| #include "../elemwise_op_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| // Parameters for general matrix-matrix multiply-accumulate (mac) |
| struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> { |
| bool transpose_a, transpose_b; |
| double alpha, beta; |
| int axis; |
| DMLC_DECLARE_PARAMETER(LaMatrixMacParam) { |
| DMLC_DECLARE_FIELD(transpose_a) |
| .set_default(false) |
| .describe("Multiply with transposed of first input (A)."); |
| DMLC_DECLARE_FIELD(transpose_b) |
| .set_default(false) |
| .describe("Multiply with transposed of second input (B)."); |
| DMLC_DECLARE_FIELD(alpha).set_default(1.0).describe("Scalar factor multiplied with A*B."); |
| DMLC_DECLARE_FIELD(beta).set_default(1.0).describe("Scalar factor multiplied with C."); |
| DMLC_DECLARE_FIELD(axis).set_default(-2).describe("Axis corresponding to the matrix rows."); |
| } |
| }; |
| |
| // Parameters for general matrix-matrix multiply |
| struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> { |
| bool transpose_a, transpose_b; |
| double alpha; |
| int axis; |
| DMLC_DECLARE_PARAMETER(LaMatrixMultParam) { |
| DMLC_DECLARE_FIELD(transpose_a) |
| .set_default(false) |
| .describe("Multiply with transposed of first input (A)."); |
| DMLC_DECLARE_FIELD(transpose_b) |
| .set_default(false) |
| .describe("Multiply with transposed of second input (B)."); |
| DMLC_DECLARE_FIELD(alpha).set_default(1.0).describe("Scalar factor multiplied with A*B."); |
| DMLC_DECLARE_FIELD(axis).set_default(-2).describe( |
| "Axis corresponding to the matrix row indices."); |
| } |
| }; |
| |
| // Parameters for Cholesky factorization and matrix inversion |
| struct LaCholeskyParam : public dmlc::Parameter<LaCholeskyParam> { |
| bool lower; |
| DMLC_DECLARE_PARAMETER(LaCholeskyParam) { |
| DMLC_DECLARE_FIELD(lower).set_default(true).describe( |
| "True if the triangular matrix is lower triangular, false if it is upper triangular."); |
| } |
| void SetAttrDict(std::unordered_map<std::string, std::string>* dict) { |
| std::ostringstream lower_s; |
| lower_s << lower; |
| (*dict)["lower"] = lower_s.str(); |
| } |
| }; |
| |
| // Parameters for matrix-matrix multiplication where one is a triangular matrix. |
| struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam> { |
| bool transpose; |
| bool rightside; |
| bool lower; |
| double alpha; |
| DMLC_DECLARE_PARAMETER(LaTriangMatrixMultParam) { |
| DMLC_DECLARE_FIELD(transpose).set_default(false).describe( |
| "Use transposed of the triangular matrix"); |
| DMLC_DECLARE_FIELD(rightside).set_default(false).describe( |
| "Multiply triangular matrix from the right to non-triangular one."); |
| DMLC_DECLARE_FIELD(lower).set_default(true).describe( |
| "True if the triangular matrix is lower triangular, false if it is upper triangular."); |
| DMLC_DECLARE_FIELD(alpha).set_default(1.0).describe( |
| "Scalar factor to be applied to the result."); |
| } |
| }; |
| |
| // Parameters for syrk |
| struct LaSyrkParam : public dmlc::Parameter<LaSyrkParam> { |
| bool transpose; |
| double alpha; |
| DMLC_DECLARE_PARAMETER(LaSyrkParam) { |
| DMLC_DECLARE_FIELD(transpose).set_default(false).describe("Use transpose of input matrix."); |
| DMLC_DECLARE_FIELD(alpha).set_default(1.0).describe( |
| "Scalar factor to be applied to the result."); |
| } |
| }; |
| |
| // Parameters for diag extraction/creation. |
| struct LaDiagParam : public dmlc::Parameter<LaDiagParam> { |
| int offset; |
| DMLC_DECLARE_PARAMETER(LaDiagParam) { |
| DMLC_DECLARE_FIELD(offset).set_default(0).describe( |
| "Offset of the diagonal versus the main diagonal. 0 corresponds to the main " |
| "diagonal, a negative/positive value to diagonals below/above the main diagonal."); |
| } |
| }; |
| |
| // Parameters for trian extraction/creation. |
| struct LaTrianParam : public dmlc::Parameter<LaTrianParam> { |
| int offset; |
| bool lower; |
| DMLC_DECLARE_PARAMETER(LaTrianParam) { |
| DMLC_DECLARE_FIELD(offset).set_default(0).describe( |
| "Offset of the diagonal versus the main diagonal. 0 corresponds to the main " |
| "diagonal, a negative/positive value to diagonals below/above the main diagonal."); |
| DMLC_DECLARE_FIELD(lower).set_default(true).describe( |
| "Refer to the lower triangular matrix if lower=true, refer to the upper otherwise." |
| " Only relevant when offset=0"); |
| } |
| }; |
| |
| // Common function for shape inference for matrix mult and matrix mac. |
| inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_GE(in_attrs->size(), 2); |
| CHECK_EQ(out_attrs->size(), 1); |
| bool transpose_a(false), transpose_b(false); |
| int axis_param(-2); |
| if (in_attrs->size() == 2) { |
| // Matrix-Matrix mult |
| transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a; |
| transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b; |
| axis_param = nnvm::get<LaMatrixMultParam>(attrs.parsed).axis; |
| } else { |
| // Matrix-Matrix mac |
| transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a; |
| transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b; |
| axis_param = nnvm::get<LaMatrixMacParam>(attrs.parsed).axis; |
| } |
| if ((*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim()) { |
| // Forward shape inference. |
| const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param); |
| CHECK(axis >= 0 && axis < ndim - 1) << "Invalid row axis (" << axis_param << ")"; |
| std::vector<int> oshape(ndim); |
| for (int i = 0; i < ndim - 1; ++i) { |
| if (i != axis) { |
| // Both inputs must have same shape except for row/col dimensions. |
| CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) |
| << "Shapes of inputs 0, 1 must be the same, except on row/col axis"; |
| } |
| oshape[i] = (*in_attrs)[0][i]; |
| } |
| CHECK_EQ((transpose_a ? (*in_attrs)[0][axis] : (*in_attrs)[0][ndim - 1]), |
| (transpose_b ? (*in_attrs)[1][ndim - 1] : (*in_attrs)[1][axis])) |
| << "Incompatible matrix dimensions for multiplication"; |
| oshape[axis] = (transpose_a ? (*in_attrs)[0][ndim - 1] : (*in_attrs)[0][axis]); |
| oshape[ndim - 1] = (transpose_b ? (*in_attrs)[1][axis] : (*in_attrs)[1][ndim - 1]); |
| mxnet::TShape tshape(oshape.begin(), oshape.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); |
| if (in_attrs->size() > 2) { |
| // Infer/check shape of third operand of a mac. |
| SHAPE_ASSIGN_CHECK(*in_attrs, 2, tshape); |
| } |
| return true; |
| } |
| // Can't do backward inference of shapes for this operator. |
| return false; |
| } |
| |
| inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); |
| CHECK_EQ(in_attrs->size(), 2); |
| CHECK_EQ(out_attrs->size(), 1); |
| if ((*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim()) { |
| // Forward shape inference. |
| const int ndim((*in_attrs)[0].ndim()); |
| CHECK_EQ((*in_attrs)[0][ndim - 2], (*in_attrs)[0][ndim - 1]) |
| << "First operand must be a tensor of square matrices"; |
| std::vector<int> oshape(ndim); |
| for (int i = 0; i < ndim - 2; ++i) { |
| // Must have same shape except for last two dimensions. |
| CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) |
| << "Shapes of inputs 0, 1 must be the same, except on last two dimensions"; |
| oshape[i] = (*in_attrs)[0][i]; |
| } |
| if (param.rightside) { |
| // We compute B * A where A is the first and B the second input. |
| CHECK_EQ((*in_attrs)[0][ndim - 2], (*in_attrs)[1][ndim - 1]) |
| << "Incompatible matrix dimensions for multiplication"; |
| oshape[ndim - 2] = (*in_attrs)[1][ndim - 2]; |
| oshape[ndim - 1] = (param.transpose ? (*in_attrs)[0][ndim - 2] : (*in_attrs)[0][ndim - 1]); |
| } else { |
| // We compute A * B where A is the first and B the second input. |
| CHECK_EQ((*in_attrs)[1][ndim - 2], (*in_attrs)[0][ndim - 1]) |
| << "Incompatible matrix dimensions for multiplication"; |
| oshape[ndim - 2] = (param.transpose ? (*in_attrs)[0][ndim - 1] : (*in_attrs)[0][ndim - 2]); |
| oshape[ndim - 1] = (*in_attrs)[1][ndim - 1]; |
| } |
| mxnet::TShape tshape(oshape.begin(), oshape.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); |
| return true; |
| } |
| if ((*out_attrs)[0].ndim() >= 2) { |
| // Backward shape inference. |
| const int odim((*out_attrs)[0].ndim()); |
| std::vector<int> ishape1(odim), ishape2(odim); |
| for (int i = 0; i < odim - 2; ++i) { |
| ishape1[i] = ishape2[i] = (*out_attrs)[0][i]; |
| } |
| if (param.rightside) { |
| // We compute B * A where A is the first and B the second input. |
| ishape2[odim - 2] = (*out_attrs)[0][odim - 2]; |
| ishape1[odim - 2] = ishape1[odim - 1] = ishape2[odim - 1] = (*out_attrs)[0][odim - 1]; |
| } else { |
| // We compute A * B where A is the first and B the second input. |
| ishape2[odim - 1] = (*out_attrs)[0][odim - 1]; |
| ishape1[odim - 2] = ishape1[odim - 1] = ishape2[odim - 2] = (*out_attrs)[0][odim - 2]; |
| } |
| mxnet::TShape tshape1(ishape1.begin(), ishape1.end()); |
| SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape1); |
| mxnet::TShape tshape2(ishape2.begin(), ishape2.end()); |
| SHAPE_ASSIGN_CHECK(*in_attrs, 1, tshape2); |
| return true; |
| } |
| return false; |
| } |
| |
| template <int dim> |
| inline bool LaReduceShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| // Shape for reduction of the dim lowest dimensions to a scalar. |
| // Can only deduct in forward direction. |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 1); |
| const int ndim((*in_attrs)[0].ndim()); |
| if (ndim < dim) { |
| return false; |
| } |
| std::vector<int> oshape(std::max(1, ndim - dim)); |
| oshape[0] = 1; |
| for (int i = 0; i < ndim - dim; ++i) { |
| oshape[i] = (*in_attrs)[0][i]; |
| } |
| // Will reduce all matrices/vectors to a scalar. |
| mxnet::TShape tshape(oshape.begin(), oshape.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); |
| return true; |
| } |
| |
| template <bool diag, bool extract> |
| inline bool LaDiagTrianShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 1); |
| const int ndim((*in_attrs)[0].ndim()); |
| // Only infer in forward direction |
| if (ndim == 0) { |
| return false; |
| } |
| const int offset = (diag ? nnvm::get<LaDiagParam>(attrs.parsed).offset : |
| nnvm::get<LaTrianParam>(attrs.parsed).offset); |
| std::vector<int> oshape(extract ? ndim - 1 : ndim + 1); |
| for (int i = 0; i < ndim - 1; ++i) { |
| oshape[i] = (*in_attrs)[0][i]; |
| } |
| if (extract) { |
| CHECK_GE(ndim, 2) << "Input operand must be a tensor of matrices"; |
| CHECK_EQ((*in_attrs)[0][ndim - 2], (*in_attrs)[0][ndim - 1]) |
| << "Input operand must be a tensor of square matrices"; |
| const int n((*in_attrs)[0][ndim - 1] - abs(offset)); |
| CHECK_GT(n, 0) << "Illegal offset " << offset |
| << " for diag/trian extraction of matrix with dimension " << ndim; |
| oshape[ndim - 2] = (diag ? n : (n * (n + 1)) / 2); |
| } else if (diag) { |
| oshape[ndim] = oshape[ndim - 1] = (*in_attrs)[0][ndim - 1] + abs(offset); |
| } else { |
| const int n((*in_attrs)[0][ndim - 1]); |
| const int m(std::floor(0.5 + (std::sqrt(8 * n + 1) - 1.0) * 0.5)); |
| CHECK_EQ((m * (m + 1)) / 2, n) |
| << "Input tensor of maketrian has an invalid dimension for the last axis."; |
| oshape[ndim] = oshape[ndim - 1] = m + abs(offset); |
| } |
| mxnet::TShape tshape(oshape.begin(), oshape.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); |
| return true; |
| } |
| |
| // Shape inference function for linalg_syrk |
| inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 1); |
| const mxnet::TShape& in_attr = (*in_attrs)[0]; |
| bool transpose = nnvm::get<LaSyrkParam>(attrs.parsed).transpose; |
| const int ndim = in_attr.ndim(); |
| if (ndim >= 2) { |
| // Forward shape inference. |
| std::vector<int> oshape(ndim); |
| for (int i = 0; i < ndim - 2; ++i) { |
| oshape[i] = in_attr[i]; |
| } |
| oshape[ndim - 2] = (transpose ? in_attr[ndim - 1] : in_attr[ndim - 2]); |
| oshape[ndim - 1] = oshape[ndim - 2]; |
| mxnet::TShape tshape(oshape.begin(), oshape.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); |
| return true; |
| } |
| // Can't do backward inference of shapes for this operator. |
| return false; |
| } |
| |
| // Shape inference function for linalg_gelqf |
| // Inputs: A. Outputs: Q, L |
| inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 2); |
| const mxnet::TShape& in_a = (*in_attrs)[0]; |
| const mxnet::TShape& out_q = (*out_attrs)[0]; |
| const mxnet::TShape& out_l = (*out_attrs)[1]; |
| if (in_a.ndim() >= 2) { |
| // Forward shape inference. |
| const int ndim(in_a.ndim()); |
| CHECK_LE(in_a[ndim - 2], in_a[ndim - 1]) |
| << "Input A shape wrong: Last dimension must be >= than second to last"; |
| // Q must have same shape as A |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a); |
| std::vector<int> oshape_l(ndim); |
| for (int i = 0; i < ndim - 1; ++i) { |
| oshape_l[i] = in_a[i]; |
| } |
| oshape_l[ndim - 1] = in_a[ndim - 2]; |
| mxnet::TShape tshape_l(oshape_l.begin(), oshape_l.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 1, tshape_l); |
| return true; |
| } |
| if (out_q.ndim() >= 2 && out_q.ndim() == out_l.ndim()) { |
| // Backward shape inference. |
| const int ndim(out_q.ndim()); |
| for (int i = 0; i < ndim - 1; ++i) { |
| CHECK_EQ(out_q[i], out_l[i]) << "Outputs Q, L must have same dimensions except for last"; |
| } |
| CHECK_LE(out_q[ndim - 2], out_q[ndim - 1]) |
| << "Output Q shape wrong: Last dimension must be >= than second to last"; |
| CHECK_EQ(out_l[ndim - 2], out_l[ndim - 1]) |
| << "Output L shape wrong: Last two dimensions must be equal"; |
| SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_q); |
| return true; |
| } |
| return false; |
| } |
| |
| // Shape inference function for linalg_inverse |
| // Inputs: A. Outputs: inverse(A) |
| inline bool InverseShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 1); |
| const mxnet::TShape& in = (*in_attrs)[0]; |
| if (!ndim_is_known(in)) |
| return false; |
| const int ndim(in.ndim()); |
| CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; |
| CHECK_EQ(in[ndim - 2], in[ndim - 1]) << "Input A's last two dimension must be equal"; |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, in); |
| return shape_is_known(in); |
| } |
| |
| // Shape inference function for det functions in linalg |
| template <int onum> |
| inline bool DetShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), onum + 2); |
| const mxnet::TShape& in = (*in_attrs)[0]; |
| if (!ndim_is_known(in)) |
| return false; |
| const int ndim(in.ndim()); |
| CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; |
| CHECK_EQ(in[ndim - 2], in[ndim - 1]) << "Input A's last two dimension must be equal"; |
| mxnet::TShape out; |
| if (ndim == 2) { |
| if (Imperative::Get()->is_np_shape() || in.Size() == 0U) { |
| out = mxnet::TShape(0, 1); |
| } else { |
| out = mxnet::TShape(1, 1); |
| } |
| } else { |
| out = mxnet::TShape(in.begin(), in.end() - 2); |
| } |
| for (int i = 0; i < onum; ++i) { |
| SHAPE_ASSIGN_CHECK(*out_attrs, i, out); /* sign or det or logdet */ |
| } |
| SHAPE_ASSIGN_CHECK(*out_attrs, onum, in); /* LU */ |
| SHAPE_ASSIGN_CHECK(*out_attrs, onum + 1, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ |
| return shape_is_known(in); |
| } |
| |
| // Type inference function for det functions in linalg |
| template <int onum> |
| inline bool DetType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_type, |
| std::vector<int>* out_type) { |
| using namespace mshadow; |
| CHECK_EQ(in_type->size(), 1); |
| CHECK_EQ(out_type->size(), onum + 2); |
| const int dtype = (*in_type)[0]; |
| if (dtype == -1) |
| return false; |
| CHECK(dtype == kFloat32 || dtype == kFloat64) |
| << "This operation only supports 32-bit and 64-bit floating point"; |
| for (int i = 0; i < onum; ++i) { |
| TYPE_ASSIGN_CHECK(*out_type, i, dtype); /* sign or det or logdet */ |
| } |
| TYPE_ASSIGN_CHECK(*out_type, onum, dtype); /* LU */ |
| TYPE_ASSIGN_CHECK(*out_type, onum + 1, index_type_flag); /* pivot */ |
| return true; |
| } |
| |
| // Shape inference function for linalg_syevd |
| // Inputs: A. Outputs: U, L |
| inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1); |
| CHECK_EQ(out_attrs->size(), 2); |
| const mxnet::TShape& in_a = (*in_attrs)[0]; |
| const mxnet::TShape& out_u = (*out_attrs)[0]; |
| const mxnet::TShape& out_l = (*out_attrs)[1]; |
| if (in_a.ndim() >= 2) { |
| // Forward shape inference. |
| const int ndim(in_a.ndim()); |
| CHECK_EQ(in_a[ndim - 2], in_a[ndim - 1]) |
| << "Input A shape wrong: Last two dimensions must be equal"; |
| // U must have same shape as A |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a); |
| std::vector<int> oshape_l(ndim - 1); |
| for (int i = 0; i < ndim - 1; ++i) { |
| oshape_l[i] = in_a[i]; |
| } |
| mxnet::TShape tshape_l(oshape_l.begin(), oshape_l.end()); |
| SHAPE_ASSIGN_CHECK(*out_attrs, 1, tshape_l); |
| return true; |
| } |
| if (out_u.ndim() >= 2 && out_u.ndim() == out_l.ndim() + 1) { |
| // Backward shape inference. |
| const int ndim(out_u.ndim()); |
| for (int i = 0; i < ndim - 1; ++i) { |
| CHECK_EQ(out_u[i], out_l[i]) << "Outputs U, L must have same dimensions except for last"; |
| } |
| CHECK_EQ(out_u[ndim - 2], out_u[ndim - 1]) |
| << "Output U shape wrong: Last two dimensions must be equal"; |
| // A must have same shape as U |
| SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_u); |
| return true; |
| } |
| return false; |
| } |
| |
| // Flattener for following adaptors. |
| template <typename xpu, int dim, typename DType> |
| mshadow::Tensor<xpu, dim, DType> LaOpFlatten(const TBlob& blob, |
| mshadow::Stream<xpu>* s, |
| int axis = -2) { |
| if (axis < 0) { |
| axis = blob.ndim() + axis; |
| } |
| if (axis >= blob.ndim() - 2) { |
| // Leave highest axis, collapse rest. |
| return blob.FlatToKD<xpu, dim, DType>(s); |
| } |
| // Collapse ranges [0,axis-1] and [axis+1,ndim-2]. |
| CHECK_EQ(dim, 4); |
| mxnet::TShape shape(dim, -1); |
| shape[0] = 1; |
| for (int i = 0; i < axis; ++i) { |
| shape[0] *= blob.shape_[i]; |
| } |
| shape[1] = blob.shape_[axis]; |
| shape[2] = 1; |
| for (int i = axis + 1; i < blob.ndim() - 1; ++i) { |
| shape[2] *= blob.shape_[i]; |
| } |
| shape[3] = blob.shape_[blob.ndim() - 1]; |
| return blob.get_with_shape<xpu, dim, DType>(shape.get<dim>(), s); |
| } |
| |
| // Adapters for calling the various operators with appropriate signatures. |
| |
| template <typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop> |
| struct LaOpCaller { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| CHECK(false) << "no specialized LaOpCaller defined for template parameters"; |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 1, 2, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[1], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[2], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[2], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[1], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 4, 1, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[2], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[3], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[2], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[3], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[1], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, int idim, int odim, typename laop> |
| struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| int axis = -2) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(LaOpFlatten<xpu, idim + 1, DType>(inputs[0], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[1], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[2], s, axis), |
| LaOpFlatten<xpu, idim + 1, DType>(inputs[3], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[0], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[1], s, axis), |
| LaOpFlatten<xpu, odim + 1, DType>(outputs[2], s, axis), |
| ctx, |
| attrs); |
| } |
| }; |
| |
| template <typename xpu, int idim, int odim, int inum, int onum, typename laop> |
| void LaOpForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| CHECK_EQ(inputs.size(), inum); |
| CHECK_EQ(outputs.size(), onum); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, attrs, ctx); |
| }); |
| } |
| |
| template <typename xpu, int idim, int odim, int inum, int onum, typename laop> |
| void LaOpBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(inputs.size(), inum); |
| CHECK_EQ(outputs.size(), onum); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| std::vector<TBlob> tspace(outputs); |
| for (int i = 0; i < onum; ++i) { |
| if (req[i] == kAddTo) { |
| tspace[i].dptr_ = |
| ctx.requested[0].get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_; |
| } |
| } |
| LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, attrs, ctx); |
| for (int i = 0; i < onum; ++i) { |
| if (req[i] == kAddTo) { |
| Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s); |
| out += tspace[i].FlatTo1D<xpu, OType>(s); |
| } |
| } |
| }); |
| } |
| |
| template <typename xpu, int idim, int odim, int inum, int onum, typename laop> |
| void LaOpGemmForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| CHECK_EQ(inputs.size(), inum); |
| CHECK_EQ(outputs.size(), onum); |
| const int axis(inputs.size() == 2 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis : |
| nnvm::get<LaMatrixMacParam>(attrs.parsed).axis); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| if (axis == -2 || axis == inputs[0].ndim() - 2) { |
| LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, attrs, ctx); |
| } else { |
| LaOpCaller<xpu, OType, idim + 1, odim + 1, inum, onum, laop>::op( |
| inputs, outputs, attrs, ctx, axis); |
| } |
| }); |
| } |
| |
| template <typename xpu, int idim, int odim, int inum, int onum, typename laop> |
| void LaOpGemmBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(inputs.size(), inum); |
| CHECK_EQ(outputs.size(), onum); |
| const int axis(inputs.size() == 3 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis : |
| nnvm::get<LaMatrixMacParam>(attrs.parsed).axis); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| std::vector<TBlob> tspace(outputs); |
| for (int i = 0; i < onum; ++i) { |
| if (req[i] == kAddTo) { |
| tspace[i].dptr_ = |
| ctx.requested[0].get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_; |
| } |
| } |
| if (axis == -2 || axis == inputs[0].ndim() - 2) { |
| LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, attrs, ctx); |
| } else { |
| LaOpCaller<xpu, OType, idim + 1, odim + 1, inum, onum, laop>::op( |
| inputs, outputs, attrs, ctx, axis); |
| } |
| for (int i = 0; i < onum; ++i) { |
| if (req[i] == kAddTo) { |
| Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s); |
| out += tspace[i].FlatTo1D<xpu, OType>(s); |
| } |
| } |
| }); |
| } |
| |
// Specific wrappers for syevd (cannot use the default ones, because A and U
// have a different dimensionality than L).
| |
| // (A) => (U, L) |
| template <typename xpu, typename laop> |
| void LaOpForwSyevd(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| CHECK_EQ(inputs.size(), 1); |
| CHECK_EQ(outputs.size(), 2); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(inputs[0].FlatToKD<xpu, 3, OType>(s), |
| outputs[0].FlatToKD<xpu, 3, OType>(s), |
| outputs[1].FlatToKD<xpu, 2, OType>(s), |
| ctx, |
| attrs); |
| }); |
| } |
| |
| // (dU, dL, U, L) => (dA) |
| template <typename xpu, typename laop> |
| void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(inputs.size(), 4); |
| CHECK_EQ(outputs.size(), 1); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| std::vector<TBlob> tspace(outputs); |
| if (req[0] == kAddTo) { |
| tspace[0].dptr_ = |
| ctx.requested[0].get_space_typed<xpu, 1, OType>(Shape1(outputs[0].Size()), s).dptr_; |
| } |
| laop::op(inputs[0].FlatToKD<xpu, 3, OType>(s), |
| inputs[1].FlatToKD<xpu, 2, OType>(s), |
| inputs[2].FlatToKD<xpu, 3, OType>(s), |
| inputs[3].FlatToKD<xpu, 2, OType>(s), |
| tspace[0].FlatToKD<xpu, 3, OType>(s), |
| ctx, |
| attrs); |
| if (req[0] == kAddTo) { |
| Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s); |
| out += tspace[0].FlatTo1D<xpu, OType>(s); |
| } |
| }); |
| } |
| |
| template <typename xpu, typename DType, int onum, typename laop, typename IndexT> |
| struct LaOpDetForwardCaller { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| CHECK(false) << "no specialized LaOpDetForward defined for template parameters"; |
| } |
| }; |
| template <typename xpu, typename DType, typename laop, typename IndexT> |
| struct LaOpDetForwardCaller<xpu, DType, 1, laop, IndexT> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(inputs[0].FlatToKD<xpu, 3, DType>(s), |
| outputs[0].FlatToKD<xpu, 1, DType>(s), |
| outputs[1].FlatToKD<xpu, 3, DType>(s), |
| outputs[2].FlatToKD<xpu, 2, IndexT>(s), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, typename laop, typename IndexT> |
| struct LaOpDetForwardCaller<xpu, DType, 2, laop, IndexT> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(inputs[0].FlatToKD<xpu, 3, DType>(s), |
| outputs[0].FlatToKD<xpu, 1, DType>(s), |
| outputs[1].FlatToKD<xpu, 1, DType>(s), |
| outputs[2].FlatToKD<xpu, 3, DType>(s), |
| outputs[3].FlatToKD<xpu, 2, IndexT>(s), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, int onum, typename laop> |
| void LaOpDetForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using IndexT = lapack_index_t; |
| CHECK_EQ(inputs.size(), 1); |
| CHECK_EQ(outputs.size(), onum + 2); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| LaOpDetForwardCaller<xpu, OType, onum, laop, IndexT>::op(inputs, outputs, attrs, ctx); |
| }); |
| } |
| |
| template <typename xpu, typename DType, int onum, typename laop, typename IndexT> |
| struct LaOpDetBackwardCaller { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| CHECK(false) << "no specialized LaOpDetBackward defined for template parameters"; |
| } |
| }; |
| template <typename xpu, typename DType, typename laop, typename IndexT> |
| struct LaOpDetBackwardCaller<xpu, DType, 1, laop, IndexT> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(inputs[0].FlatToKD<xpu, 1, DType>(s), |
| inputs[1].FlatToKD<xpu, 1, DType>(s), |
| inputs[2].FlatToKD<xpu, 3, DType>(s), |
| inputs[3].FlatToKD<xpu, 2, IndexT>(s), |
| outputs[0].FlatToKD<xpu, 3, DType>(s), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, typename DType, typename laop, typename IndexT> |
| struct LaOpDetBackwardCaller<xpu, DType, 2, laop, IndexT> { |
| static void op(const std::vector<TBlob>& inputs, |
| const std::vector<TBlob>& outputs, |
| const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx) { |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| laop::op(inputs[0].FlatToKD<xpu, 1, DType>(s), |
| inputs[1].FlatToKD<xpu, 1, DType>(s), |
| inputs[2].FlatToKD<xpu, 1, DType>(s), |
| inputs[3].FlatToKD<xpu, 3, DType>(s), |
| inputs[4].FlatToKD<xpu, 2, IndexT>(s), |
| outputs[0].FlatToKD<xpu, 3, DType>(s), |
| ctx, |
| attrs); |
| } |
| }; |
| template <typename xpu, int onum, typename laop> |
| void LaOpDetBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using IndexT = lapack_index_t; |
| if (outputs[0].shape_.Size() == 0U) { |
| return; |
| } |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| CHECK_EQ(inputs.size(), onum + 3); |
| CHECK_EQ(outputs.size(), 1); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { |
| std::vector<TBlob> tspace(outputs); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| if (req[i] == kAddTo) { |
| tspace[i].dptr_ = |
| ctx.requested[0].get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_; |
| } |
| } |
| LaOpDetBackwardCaller<xpu, OType, onum, laop, IndexT>::op(inputs, tspace, attrs, ctx); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| if (req[i] == kAddTo) { |
| Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s); |
| out += tspace[i].FlatTo1D<xpu, OType>(s); |
| } |
| } |
| }); |
| } |
| |
| // Only transfer ddet and outputs to gradient |
| template <int onum> |
| struct ReduceDetGrad { |
| const char* op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n, |
| const std::vector<nnvm::NodeEntry>& ograds) { |
| std::vector<nnvm::NodeEntry> heads; |
| heads.push_back(ograds[onum - 1]); |
| uint32_t n_out = n->num_outputs(); |
| for (uint32_t i = 0; i < n_out; ++i) { |
| heads.emplace_back(nnvm::NodeEntry{n, i, 0}); |
| } |
| return MakeGradNode(op_name, n, heads, n->attrs.dict); |
| } |
| }; |
| |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_TENSOR_LA_OP_H_ |