/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file dot-inl.h
* \brief Function definition of matrix dot operator
*/
#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_H_
#define MXNET_OPERATOR_TENSOR_DOT_INL_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include <sstream>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include <type_traits>
#include "./util/tensor_util-inl.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./init_op.h"
#include "../mxnet_op.h"
#ifdef __CUDACC__
#include "./dot-inl.cuh"
#endif // __CUDACC__
namespace mxnet {
namespace op {
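// Positional indices of the operator's inputs and outputs. The plain dot/batch_dot
// kernels below only use lhs/rhs and out; the *_min/*_max entries appear to be
// reserved for quantized variants that carry per-tensor min/max values.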
enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
enum DotOut { out = 0, out_min, out_max };
struct DotParam : public dmlc::Parameter<DotParam> {
bool transpose_a;
bool transpose_b;
dmlc::optional<int> forward_stype;
DMLC_DECLARE_PARAMETER(DotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(forward_stype)
.describe(
"The desired storage type of the forward output given by user, if the"
"combination of input storage types and this hint does not match"
"any implemented ones, the dot operator will perform fallback operation"
"and still produce an output of the desired storage type.")
.add_enum("default", kDefaultStorage)
.add_enum("row_sparse", kRowSparseStorage)
.add_enum("csr", kCSRStorage)
.set_default(dmlc::optional<int>());
}
bool operator==(const DotParam& other) const {
return this->transpose_a == other.transpose_a && this->transpose_b == other.transpose_b &&
this->forward_stype == other.forward_stype;
}
std::string ForwardStype2String(int forward_stype) {
switch (forward_stype) {
case kDefaultStorage:
return "default";
case kRowSparseStorage:
return "row_sparse";
case kCSRStorage:
return "csr";
default:
LOG(FATAL) << "Unknown forward stype enum " << forward_stype;
}
LOG(FATAL) << "should not reach here ";
return "";
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream transpose_a_s, transpose_b_s, forward_stype_s;
transpose_a_s << transpose_a;
transpose_b_s << transpose_b;
forward_stype_s << forward_stype;
(*dict)["transpose_a"] = transpose_a_s.str();
(*dict)["transpose_b"] = transpose_b_s.str();
if (forward_stype.has_value()) {
(*dict)["forward_stype"] = ForwardStype2String(forward_stype.value());
} else {
(*dict)["forward_stype"] = forward_stype_s.str();
}
}
};
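// A minimal usage sketch (illustrative only, not code from this header): every
// compute kernel recovers the parsed parameter from the node attributes, e.g.
//   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
//   if (param.transpose_a) { /* treat lhs as transposed */ }
// forward_stype, when set, is only a hint to DotForwardInferStorageType below;
// combinations without a dedicated implementation fall back to dense compute.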
template <typename xpu>
void DotForward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
      << "Binary function only supports input/output with the same type";
  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
      << "Binary function only supports input/output with the same type";
CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
outputs[DotOut::out].type_flag_ == kFloat64 ||
(outputs[DotOut::out].type_flag_ == kFloat16 &&
ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
<< "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
// VectorDot() with fp16 is not supported in mshadow. Dispatch to dot() instead.
if (inputs[DotIn::lhs].ndim() == 1 && inputs[DotIn::rhs].ndim() == 1 &&
inputs[DotIn::lhs].type_flag_ != kFloat16) {
CHECK_NE(req[DotOut::out], kAddTo) << "AddTo not yet supported";
Tensor<xpu, 1, DType> out = outputs[DotOut::out].get<xpu, 1, DType>(s);
VectorDot(
out, inputs[DotIn::lhs].get<xpu, 1, DType>(s), inputs[DotIn::rhs].get<xpu, 1, DType>(s));
} else {
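      // Collapse a possibly N-D operand into a 2-D matrix: without a transpose the
      // last axis of lhs is the contraction axis (na) and everything before it is
      // flattened into ma; with transpose_a the first axis stays as ma and the rest
      // is flattened into na. E.g. an lhs of shape (2, 3, 4) is viewed as (6, 4).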
index_t ma, na, mb, nb, m, n;
if (param.transpose_a) {
ma = inputs[DotIn::lhs].size(0);
na = inputs[DotIn::lhs].Size() / ma;
m = na;
} else {
na = inputs[DotIn::lhs].size(inputs[DotIn::lhs].ndim() - 1);
ma = inputs[DotIn::lhs].Size() / na;
m = ma;
}
if (param.transpose_b) {
nb = inputs[DotIn::rhs].size(inputs[DotIn::rhs].ndim() - 1);
mb = inputs[DotIn::rhs].Size() / nb;
n = mb;
} else {
mb = inputs[DotIn::rhs].size(0);
nb = inputs[DotIn::rhs].Size() / mb;
n = nb;
}
Tensor<xpu, 2, DType> input0 =
inputs[DotIn::lhs].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
Tensor<xpu, 2, DType> input1 =
inputs[DotIn::rhs].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
Tensor<xpu, 2, DType> out =
outputs[DotOut::out].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
if (param.transpose_a && param.transpose_b) {
ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1.T()));
} else if (!param.transpose_a && param.transpose_b) {
ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1.T()));
} else if (param.transpose_a && !param.transpose_b) {
ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1));
} else {
ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1));
}
}
});
}
template <typename xpu>
void DotBackward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
CHECK_NE(req[0], kWriteInplace);
CHECK_NE(req[1], kWriteInplace);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) {
Tensor<xpu, 1, DType> mout_grad = inputs[0].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> mlhs_data = inputs[1].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> mrhs_data = inputs[2].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> mlhs_grad = outputs[0].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> mrhs_grad = outputs[1].get<xpu, 1, DType>(s);
ASSIGN_DISPATCH(mrhs_grad, req[1], broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data);
ASSIGN_DISPATCH(mlhs_grad, req[0], broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data);
} else {
int ma, na, mb, nb, m, n;
if (param.transpose_a) {
ma = outputs[0].size(0);
na = outputs[0].Size() / ma;
m = na;
} else {
na = outputs[0].size(outputs[0].ndim() - 1);
ma = outputs[0].Size() / na;
m = ma;
}
if (param.transpose_b) {
nb = outputs[1].size(outputs[1].ndim() - 1);
mb = outputs[1].Size() / nb;
n = mb;
} else {
mb = outputs[1].size(0);
nb = outputs[1].Size() / mb;
n = nb;
}
Tensor<xpu, 2, DType> mout_grad = inputs[0].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
Tensor<xpu, 2, DType> mlhs_data = inputs[1].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
Tensor<xpu, 2, DType> mrhs_data = inputs[2].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
Tensor<xpu, 2, DType> mlhs_grad = outputs[0].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
Tensor<xpu, 2, DType> mrhs_grad = outputs[1].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
if (param.transpose_a && param.transpose_b) {
// Gradient of z = dot(x.T, y.T)
// dy = dot(x, dz).T = dot(dz.T, x.T)
// dx = dot(dz, y).T = dot(y.T, dz.T)
ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T()));
ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T()));
} else if (!param.transpose_a && param.transpose_b) {
// Gradient of z = dot(x, y.T)
// dy = dot(x.T, dz).T = dot(dz.T, x)
// dx = dot(dz, y)
ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data));
ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data));
} else if (param.transpose_a && !param.transpose_b) {
// Gradient of z = dot(x.T, y)
// dy = dot(x, dz)
// dx = dot(dz, y.T).T = dot(y, dz.T)
ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad));
ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T()));
} else {
// Gradient of z = dot(x, y)
// dy = dot(x.T, dz)
// dx = dot(dz, y.T)
ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad));
ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T()));
}
}
});
}
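/*!
 * \brief Storage type inference for the forward pass.
 * Ignoring the optional forward_stype hint, the implemented combinations are
 * roughly: dns * dns -> dns, csr.T * dns/rsp -> rsp, csr * dns/rsp -> dns,
 * dns * csr -> csr (CPU only) and dns * csr/csr.T -> dns; everything else is
 * dispatched to the dense fallback.
 */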
inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
// csr has many zero columns, so the result of dot(csr.T, matrix) should be
// rsp
const auto& lhs_stype = in_attrs->at(0);
const auto& rhs_stype = in_attrs->at(1);
auto& out_stype = out_attrs->at(0);
bool dispatched = false;
bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
bool hint_has_value = param.forward_stype.has_value();
NDArrayStorageType target_stype =
hint_has_value ? static_cast<NDArrayStorageType>(param.forward_stype.value()) :
kUndefinedStorage;
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
// dns, dns -> dns
target_stype = hint_has_value ? target_stype : kDefaultStorage;
if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype,
kDefaultStorage,
dispatch_mode,
#if MXNET_USE_ONEDNN == 1
DispatchMode::kFComputeEx);
#else
DispatchMode::kFCompute);
#endif
}
}
if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
// csr.T, rsp/dns -> rsp
target_stype = hint_has_value ? target_stype : kRowSparseStorage;
if (target_stype == kRowSparseStorage) {
dispatched = storage_type_assign(
&out_stype, kRowSparseStorage, dispatch_mode, DispatchMode::kFComputeEx);
// csr.T, rsp/dns -> dns
} else if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
}
if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns && !param.transpose_a &&
!param.transpose_b) {
// csr, rsp/dns -> dns
target_stype = hint_has_value ? target_stype : kDefaultStorage;
if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
}
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a) {
target_stype = hint_has_value ? target_stype : kCSRStorage;
if (dev_mask == mshadow::cpu::kDevMask) {
// dns, csr -> csr on CPU
if (target_stype == kCSRStorage && !param.transpose_b) {
dispatched =
storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
// dns, csr/csr.T -> dns on CPU
} else if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
// dns, csr/csr.T -> dns on GPU
} else if (dev_mask == mshadow::gpu::kDevMask) {
if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
}
}
if (!dispatched) {
target_stype = (target_stype == kUndefinedStorage) ? kDefaultStorage : target_stype;
dispatched = storage_type_assign(
&out_stype, target_stype, dispatch_mode, DispatchMode::kFComputeFallback);
}
return dispatched;
}
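/*!
 * \brief Storage type inference for the backward pass.
 * Roughly: all-dense, csr lhs with dns/rsp ograd (producing a dns lhs grad and a
 * dns/rsp rhs grad), or csr rhs with dense inputs; anything else is dispatched to
 * the dense fallback.
 */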
inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 2U);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
const auto& ograd_stype = in_attrs->at(0);
const auto& lhs_stype = in_attrs->at(1);
const auto& rhs_stype = in_attrs->at(2);
const bool no_transpose = !param.transpose_a && !param.transpose_b;
auto& lhs_grad_stype = out_attrs->at(0);
auto& rhs_grad_stype = out_attrs->at(1);
bool dispatched = false;
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage &&
ograd_stype == kDefaultStorage) {
if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
type_assign(&rhs_grad_stype, kDefaultStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFCompute);
dispatched = true;
}
}
if (!dispatched && no_transpose && lhs_stype == kCSRStorage &&
(ograd_stype == kRowSparseStorage || ograd_stype == kDefaultStorage)) {
// backward: csr.T, rsp/dns -> rsp, dns.T, rsp/dns -> dns
if (type_assign(&rhs_grad_stype, kRowSparseStorage) &&
type_assign(&lhs_grad_stype, kDefaultStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatched = true;
}
}
if (!dispatched && param.transpose_a && !param.transpose_b && lhs_stype == kCSRStorage &&
(ograd_stype == kRowSparseStorage || ograd_stype == kDefaultStorage)) {
// backward: csr, rsp/dns -> dns, dns, rsp/dns -> dns
if (type_assign(&rhs_grad_stype, kDefaultStorage) &&
type_assign(&lhs_grad_stype, kDefaultStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatched = true;
}
}
if (!dispatched && !param.transpose_a && lhs_stype == kDefaultStorage &&
rhs_stype == kCSRStorage && ograd_stype == kDefaultStorage) {
if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
type_assign(&rhs_grad_stype, kDefaultStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatched = true;
}
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
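/*
 * A quick refresher on the CSR layout the kernels below traverse (a sketch with
 * made-up numbers, not data from this file): the 3x4 matrix
 *   [[1 0 2 0]
 *    [0 0 0 0]
 *    [0 3 0 4]]
 * is stored as
 *   data    = [1, 2, 3, 4]   // non-zero values in row-major order
 *   indptr  = [0, 2, 2, 4]   // row j owns data[indptr[j] .. indptr[j+1])
 *   col_idx = [0, 2, 1, 3]   // column index of each stored value
 * so dot(csr, dns) accumulates data[k] * dns[col_idx[k], :] into out[j, :].
 */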
/*!
* \brief CPU Kernel of dot(csr, dns1) = dns2
* Parallelization by row blocks
*/
struct DotCsrDnsDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows,
const nnvm::dim_t num_cols) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows)
return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows);
for (dim_t j = seg_start; j < seg_end; ++j) {
if (indptr_l[j] == indptr_l[j + 1])
continue;
const dim_t offset_out = j * num_cols;
for (IType k = indptr_l[j]; k < indptr_l[j + 1]; ++k) {
const DType val = data_l[k];
const dim_t offset_r = col_idx_l[k] * num_cols;
for (dim_t l = 0; l < num_cols; ++l) {
out[offset_out + l] += data_r[offset_r + l] * val;
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(csr.T(), dns1) = dns2
* Parallelization by row blocks
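 * Each thread owns the output rows [i*seg_len, (i+1)*seg_len), i.e. a block of
 * csr columns, and scans the whole csr, keeping only the entries whose column
 * index falls inside its block, so no two threads write to the same output row.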
*/
struct DotCsrTransDnsDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_rows,
const nnvm::dim_t num_cols) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows)
return;
const dim_t seg_end = (i + 1) * seg_len;
for (dim_t j = 0; j < num_rows_l; ++j) {
if (indptr_l[j] == indptr_l[j + 1])
continue;
const dim_t offset_r = j * num_cols;
for (IType k = indptr_l[j]; k < indptr_l[j + 1]; ++k) {
const CType col_idx = col_idx_l[k];
if (col_idx < seg_start || col_idx >= seg_end)
continue;
const dim_t offset_out = col_idx * num_cols;
const DType val = data_l[k];
for (dim_t l = 0; l < num_cols; ++l) {
out[offset_out + l] += data_r[offset_r + l] * val;
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(csr.T(), dns) = rsp
* Parallelization by row blocks which evenly partition the non-zero rows.
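 * row_flg_sum is an inclusive prefix sum over the per-row "has a non-zero" flags
 * of the output, so row_flg_sum[col_idx] - 1 maps a dense row index to its slot
 * in the compacted rsp storage.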
*/
struct DotCsrTransDnsRspByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template <typename DType, typename IType, typename CType, typename RType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
nnvm::dim_t* row_flg_sum,
RType* row_idx,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t nnr,
const nnvm::dim_t num_cols) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= nnr)
return;
const dim_t seg_end = (i + 1) * seg_len;
const dim_t col_start = row_idx[seg_start];
const dim_t col_end = seg_end >= nnr ? (row_idx[nnr - 1] + 1) : row_idx[seg_end];
for (dim_t j = 0; j < num_rows_l; ++j) {
if (indptr_l[j] == indptr_l[j + 1])
continue;
const dim_t offset_r = j * num_cols;
for (IType k = indptr_l[j]; k < indptr_l[j + 1]; ++k) {
const CType col_idx = col_idx_l[k];
if (col_idx < col_start || col_idx >= col_end)
continue;
const nnvm::dim_t rsp_row = row_flg_sum[col_idx] - 1;
const nnvm::dim_t offset_out = rsp_row * num_cols;
const DType val = data_l[k];
for (dim_t l = 0; l < num_cols; ++l) {
out[offset_out + l] += data_r[offset_r + l] * val;
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(csr, rsp) = dns
* Parallelization by row blocks
*/
struct DotCsrRspDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param nnr_r storage_shape[0] of the rsp
* \param num_rows dns.shape[0]
* \param num_cols dns.shape[1]
*/
template <typename DType, typename IType, typename CType, typename RType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const RType* row_idx_r,
const nnvm::dim_t nnr_r,
const nnvm::dim_t num_rows,
const nnvm::dim_t num_cols,
const nnvm::dim_t seg_len) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows)
return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows);
for (dim_t j = seg_start; j < seg_end; ++j) {
if (indptr_l[j] == indptr_l[j + 1])
continue;
const dim_t offset_out = j * num_cols;
// Use binary search to find the lower_bound of val in row_idx array
const RType* first = row_idx_r;
const RType* last = row_idx_r + nnr_r;
const CType val = col_idx_l[indptr_l[j]];
const RType* it;
int count = last - first, step;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (*it < val) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
const RType* row_idx_ptr = first;
// end of binary search
if (row_idx_ptr == row_idx_r + nnr_r || *row_idx_ptr > col_idx_l[indptr_l[j + 1] - 1])
continue;
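      // Walk the (sorted) column indices of csr row j and the (sorted) rsp row
      // indices in lockstep; only columns that are actually stored rows of the
      // rsp operand contribute to the output row.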
for (IType k = indptr_l[j]; k < indptr_l[j + 1] && row_idx_ptr != row_idx_r + nnr_r;) {
if (col_idx_l[k] == *row_idx_ptr) {
const dim_t offset_r = (row_idx_ptr - row_idx_r) * num_cols;
for (dim_t l = 0; l < num_cols; ++l) {
out[offset_out + l] += data_l[k] * data_r[offset_r + l];
}
++k;
++row_idx_ptr;
} else if (col_idx_l[k] < *row_idx_ptr) {
++k;
} else {
++row_idx_ptr;
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(csr.T(), rsp1) = rsp2, with row_idx marked for non-zero rows
* Parallelization by row blocks
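 * As in DotCsrTransDnsDnsByRowBlocks, each thread owns a block of output rows and
 * scans only the stored (non-zero) rows of the rsp rhs.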
*/
struct DotCsrTransRspRspByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param num_rows_l number of rows of lhs matrix
* \param nnr_r number of non-zero rows of rhs matrix
* \param num_rows number of rows of out matrix
* \param num_cols number of cols of out matrix
*/
template <typename DType, typename IType, typename CType, typename RType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
RType* row_idx_out,
const DType* data_l,
const IType* indptr_l,
const CType* col_idx_l,
const DType* data_r,
const RType* row_idx_r,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t nnr_r,
const nnvm::dim_t num_rows,
const nnvm::dim_t num_cols,
const nnvm::dim_t seg_len) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows)
return;
const dim_t seg_end = (i + 1) * seg_len;
for (dim_t rid = 0; rid < nnr_r; ++rid) {
const RType j = row_idx_r[rid];
if (indptr_l[j] == indptr_l[j + 1])
continue;
const dim_t offset_r = rid * num_cols;
for (IType k = indptr_l[j]; k < indptr_l[j + 1]; ++k) {
const CType col_idx = col_idx_l[k];
if (col_idx < seg_start || col_idx >= seg_end)
continue;
row_idx_out[col_idx] = 1; // mark nonzero row as 1
const dim_t offset_out = col_idx * num_cols;
for (dim_t l = 0; l < num_cols; ++l) {
out[offset_out + l] += data_r[offset_r + l] * data_l[k];
}
}
}
}
};
/*!
* \brief CPU Kernel of PopulateCsrForNNC
* Parallelization by individual rows
* Populates the indptr and indices array
* based on number of non zero columns
*/
struct PopulateCsrForNNC {
/*!
* \brief
* \param i the i-th thread
* \param nnc_idx all non zero column indexes
* \param indptr_out indptr array for output
* \param col_idx_out column indices for output
* \param nnc number of non zero columns in the output
* \param num_rows_l number of rows in lhs
*/
template <typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
const CType* nnc_idx,
IType* indptr_out,
CType* col_idx_out,
const nnvm::dim_t nnc,
const nnvm::dim_t num_rows_l) {
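    // Every output row stores exactly nnc entries (one per non-zero column of the
    // rhs csr), so indptr is simply i * nnc and each row's column indices are a
    // copy of nnc_idx.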
const CType start_idx = i * nnc;
nnvm::dim_t cur = 0;
indptr_out[i] = start_idx;
if (static_cast<nnvm::dim_t>(i) == (num_rows_l - 1))
indptr_out[i + 1] = indptr_out[i] + nnc;
for (IType idx = start_idx; idx < (start_idx + nnc); idx++) {
col_idx_out[idx] = nnc_idx[cur++];
}
}
};
/*!
* \brief CPU Impl of dot(dns, csr) = csr
*/
struct DotDnsCsrCsrByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param num_rows_r number of rows in rhs
* \param num_rows_l number of rows in lhs
* \param num_cols number of columns in output
* \param nnc number of non zero columns
*/
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const IType* indptr_r,
const CType* col_idx_r,
const DType* data_r,
const nnvm::dim_t seg_len,
const IType num_rows_r,
const IType num_rows_l,
const nnvm::dim_t num_cols,
const nnvm::dim_t nnc,
const CType* prefix_sum) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l)
return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = seg_start; j < seg_end; j++) {
for (dim_t k = 0; k < num_rows_r; k++) {
const dim_t working_idx = j * num_rows_r + k;
const DType val = data_l[working_idx];
if (indptr_r[k] == indptr_r[k + 1])
continue;
const dim_t row_start = j * nnc;
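        // prefix_sum[c] - 1 is the rank of column c among the non-zero columns,
        // i.e. its offset within the current output row of the csr result.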
for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) {
dim_t cur_col_idx_r = col_idx_r[cur];
const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1;
out[out_idx] += val * data_r[cur];
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(dns1, csr) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param out output matrix
* \param data_l data of lhs
* \param data_r values of csr
* \param indptr_r row offsets of csr
* \param col_idx_r column indices of csr
* \param seg_len workload of this thread
* \param num_rows_l number of rows in lhs
* \param num_cols_l number of columns in lhs
* \param num_rows_r number of rows in rhs
* \param num_cols_r number of columns in rhs
*/
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l)
return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j + 1])
continue;
for (IType k = indptr_r[j]; k < indptr_r[j + 1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r * num_cols_r + col_idx] += data_l[r * num_cols_l + j] * val;
}
}
}
}
};
/*!
* \brief CPU Kernel of dot(dns1, csr.T) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrTransDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param out output matrix
* \param data_l data of lhs
* \param data_r values of csr
* \param indptr_r row offsets of csr
* \param col_idx_r column indices of csr
* \param seg_len workload of this thread
* \param num_rows_l number of rows in lhs
* \param num_cols_l number of columns in lhs
* \param num_rows_r number of rows in rhs
* \param num_cols_r number of columns in rhs
*/
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l)
return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j + 1])
continue;
for (IType k = indptr_r[j]; k < indptr_r[j + 1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r * num_rows_r + j] += data_l[r * num_cols_l + col_idx] * val;
}
}
}
}
};
/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
*/
inline void DotCsrDnsDnsImpl(const OpContext& ctx,
const cpu& cpu_dev,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
TBlob* ret) {
if (kNullOp == req)
return;
CHECK_EQ(lhs.storage_type(), kCSRStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!lhs.storage_initialized()) {
Fill(s, *ret, req, 0);
return;
}
using nnvm::dim_t;
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob& data_r = rhs;
const TBlob data_out = *ret;
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
dim_t num_threads;
if (kWriteTo == req) {
num_threads = data_out.Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
bool dynamic = false;
const dim_t large_matrix_threshold = 1024 * 10;
if (data_out.shape_[0] > large_matrix_threshold) {
dynamic = true;
// each unit of work processes at least 1024 elements in the output
const dim_t unit_work_per_thread = 1024;
num_threads = data_out.Size() / unit_work_per_thread;
}
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
seg_len,
lhs.shape()[0],
data_out.shape_[0],
data_out.shape_[1]);
} else {
if (dynamic) {
mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, cpu>::LaunchDynamic(s,
num_threads,
data_out.dptr<DType>(),
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
seg_len,
data_out.shape_[0],
data_out.shape_[1]);
} else {
mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
seg_len,
data_out.shape_[0],
data_out.shape_[1]);
}
}
});
});
});
}
/*!
* \brief CPU Impl of dot(csr.T, dns) = rsp
*/
inline void DotCsrDnsRspImpl(const OpContext& ctx,
const cpu& cpu_dev,
const NDArray& lhs,
const TBlob& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
if (kNullOp == req)
return;
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(ret->storage_type(), kRowSparseStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!lhs.storage_initialized()) {
FillZerosRspImpl(s, *ret);
return;
}
CHECK_EQ(req, kWriteTo);
using namespace mxnet_op;
using nnvm::dim_t;
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob& data_r = rhs;
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(ret->aux_type(rowsparse::kIdx), RType, { // row idx type
const dim_t num_rows = lhs.shape()[1];
size_t workspace_size = num_rows * sizeof(dim_t);
mshadow::Tensor<cpu, 1, char> workspace =
ctx.requested[0].get_space_typed<cpu, 1, char>(mshadow::Shape1(workspace_size), s);
dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
// prefix sum array re-uses the row_flg array temp space
dim_t* prefix_sum = row_flg;
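          // MarkRowFlgKernel flags every lhs column (i.e. output row) that holds at
          // least one non-zero, and the loop below turns the flags into an inclusive
          // prefix sum in place. Example with made-up numbers: for 5 rows with
          // non-zeros only in rows 1 and 3, row_flg = [0,1,0,1,0] becomes
          // prefix_sum = [0,1,1,2,2], so nnr = 2 and prefix_sum[r] - 1 is the
          // compacted row index of every flagged row r.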
Kernel<set_zero, cpu>::Launch(s, num_rows, row_flg);
Kernel<MarkRowFlgKernel, cpu>::Launch(
s, lhs.aux_shape(csr::kIdx)[0], row_flg, col_idx_l.dptr<CType>());
prefix_sum[0] = row_flg[0];
for (nnvm::dim_t i = 1; i < num_rows; i++) {
prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
}
dim_t nnr = prefix_sum[num_rows - 1];
if (nnr == 0) {
FillZerosRspImpl(s, *ret);
return;
}
ret->CheckAndAlloc({mshadow::Shape1(nnr)});
const TBlob& data_out = ret->data();
const TBlob& row_idx = ret->aux_data(rowsparse::kIdx);
dim_t num_threads = data_out.Size();
mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
RType* row_idx_out = row_idx.dptr<RType>();
mxnet_op::Kernel<FillRspRowIdxKernel, cpu>::Launch(
s, num_rows, row_idx_out, prefix_sum, num_rows);
num_threads = mxnet_op::get_num_threads<cpu>(nnr);
dim_t seg_len = (nnr + num_threads - 1) / num_threads;
if (trans_lhs) {
mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
prefix_sum,
row_idx_out,
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
seg_len,
lhs.shape()[0],
nnr,
ret->shape()[1]);
} else {
LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet.";
}
});
});
});
});
}
/*!
* \brief CPU Impl of dot(csr, rsp) = dns
*/
inline void DotCsrRspDnsImpl(const OpContext& ctx,
const cpu& cpu_dev,
const NDArray& lhs,
const NDArray& rhs,
const OpReqType req,
const bool trans_lhs,
TBlob* ret) {
if (kNullOp == req)
return;
// reuse csr dns implementation when storage_shape == shape for rhs
if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense
DotCsrDnsDnsImpl(ctx, cpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
return;
}
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
if (kWriteTo == req) {
MSHADOW_SGL_DBL_TYPE_SWITCH(ret->type_flag_, DType, { // data type
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, ret->Size(), ret->dptr<DType>());
});
}
return;
}
using nnvm::dim_t;
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob data_r = rhs.data();
const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx);
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // row idx type
dim_t num_threads;
if (kWriteTo == req) {
num_threads = ret->Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_threads, ret->dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(ret->shape_[0]);
dim_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet";
} else {
mxnet_op::Kernel<DotCsrRspDnsByRowBlocks, cpu>::Launch(s,
num_threads,
ret->dptr<DType>(),
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
row_idx_r.dptr<RType>(),
rhs.storage_shape()[0],
ret->shape_[0],
ret->shape_[1],
seg_len);
}
});
});
});
});
}
/*!
* \brief CPU Impl of dot(csr.T, rsp1) = rsp2
*/
inline void DotCsrRspRspImpl(const OpContext& ctx,
const cpu& cpu_dev,
const NDArray& lhs,
const NDArray& rhs,
const OpReqType req,
const bool trans_lhs,
NDArray* ret) {
if (kNullOp == req)
return;
// reuse csr dns implementation when storage_shape == shape for rhs
if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense
DotCsrDnsRspImpl(ctx, cpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
return;
}
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
CHECK_EQ(ret->storage_type(), kRowSparseStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
FillZerosRspImpl(s, *ret);
return;
}
CHECK_EQ(req, kWriteTo);
using mxnet_op::set_zero;
using nnvm::dim_t;
const TBlob data_l = lhs.data();
const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
const TBlob data_r = rhs.data();
const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx);
// pre-allocate spaces for ret using the dense dimension size
if (ret->storage_type() == kRowSparseStorage) {
ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
}
const TBlob data_out = ret->data();
const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // row idx type
dim_t num_threads = data_out.Size();
mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
RType* row_idx = row_idx_out.dptr<RType>();
num_threads = row_idx_out.Size();
mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, row_idx);
mxnet_op::Kernel<DotCsrTransRspRspByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
row_idx,
data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(),
data_r.dptr<DType>(),
row_idx_r.dptr<RType>(),
lhs.shape()[0],
rhs.storage_shape()[0],
ret->shape()[0],
ret->shape()[1],
seg_len);
dim_t nnr = 0;
nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
if (0 == nnr) {
FillZerosRspImpl(s, *ret);
return;
}
ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
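            // Compact the result in place: row_idx currently holds 0/1 flags per
            // dense output row; flagged rows are copied to the front of the data
            // and row_idx is rewritten to hold the surviving row indices.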
mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
dim_t idx = 0;
for (index_t i = 0; i < ret->shape()[0]; ++i) {
if (row_idx[i] > 0) {
row_idx[idx] = i;
mshadow::Copy(rsp_data[idx], rsp_data[i], s);
++idx;
}
}
} else {
LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr, rsp) = rsp2 yet";
}
});
});
});
});
}
/*!
 * \brief CPU Impl of dot(dns, csr) = csr
*/
inline void DotDnsCsrCsrImpl(const OpContext& ctx,
const cpu& cpu_dev,
const TBlob& lhs,
const NDArray& rhs,
const OpReqType req,
NDArray* ret) {
if (kNullOp == req)
return;
CHECK_EQ(req, kWriteTo);
CHECK_EQ(rhs.storage_type(), kCSRStorage);
using namespace mshadow;
using namespace mshadow::expr;
using nnvm::dim_t;
/* Initialize data structures */
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
const NDArray& out = *ret;
const TBlob data_l = lhs;
const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
if (!rhs.storage_initialized()) {
FillZerosCsrImpl(s, *ret);
return;
}
MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type
/* Allocate workspace */
CType num_cols_out = out.shape()[1];
CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
size_t workspace_size = 2 * num_cols_out * sizeof(CType);
Tensor<cpu, 1, char> workspace =
ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_size), s);
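        // Workspace layout: two CType arrays of length num_cols_out, col_flg
        // (reused in place as its own prefix sum) followed by nnc_idx, the list of
        // non-zero column indices of the rhs csr.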
        CType* col_flg = reinterpret_cast<CType*>(workspace.dptr_);
CType* prefix_sum = col_flg;
CType* nnc_idx = prefix_sum + num_cols_out;
/* Set the column flags for nnz columns */
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out, col_flg);
mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());
/* 1. Calculate prefix sum from col flgs
         * 2. Store all non zero column indexes in nnc_idx
*/
CType cur = 0;
prefix_sum[0] = col_flg[0];
if (prefix_sum[0])
nnc_idx[cur++] = 0;
for (CType i = 1; i < num_cols_out; i++) {
prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
if (prefix_sum[i] > prefix_sum[i - 1])
nnc_idx[cur++] = i;
}
/* Allocate aux data for out */
IType num_rows_l = lhs.shape_[0];
dim_t nnc = prefix_sum[num_cols_out - 1];
dim_t nnz = nnc * num_rows_l;
out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
/* Set csr indptr and index according to nnc_idx*/
IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
DType* data_out = out.data().dptr<DType>();
mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);
const dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;
IType num_rows_r = rhs.shape()[0];
mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch(s,
num_threads,
data_out,
data_l.dptr<DType>(),
indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(),
data_r.dptr<DType>(),
seg_len,
num_rows_r,
num_rows_l,
num_cols_out,
nnc,
prefix_sum);
});
});
});
}
/*!
 * \brief CPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
*/
inline void DotDnsCsrDnsImpl(const OpContext& ctx,
const cpu& cpu_dev,
const TBlob& dns,
const NDArray& rhs,
const OpReqType req,
NDArray* ret,
const bool transpose_b) {
if (req == kNullOp)
return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!rhs.storage_initialized()) {
Fill(s, ret->data(), req, 0);
return;
}
using nnvm::dim_t;
const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
const TBlob& data_l = dns;
const TBlob data_out = ret->data();
MSHADOW_REAL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type
dim_t num_threads;
if (req == kWriteTo || req == kWriteInplace) {
num_threads = data_out.Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
// seg by output row
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (transpose_b) {
mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
data_l.dptr<DType>(),
data_r.dptr<DType>(),
indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(),
seg_len,
dns.shape_[0],
dns.shape_[1],
rhs.shape()[0],
rhs.shape()[1]);
} else {
mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s,
num_threads,
data_out.dptr<DType>(),
data_l.dptr<DType>(),
data_r.dptr<DType>(),
indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(),
seg_len,
dns.shape_[0],
dns.shape_[1],
rhs.shape()[0],
rhs.shape()[1]);
}
});
});
});
}
inline bool DotShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
if (!ndim_is_known(lshape) || !ndim_is_known(rshape))
return false;
CHECK_GT(lshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
CHECK_GT(rshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
if (lshape.ndim() == 1 && rshape.ndim() == 1) {
CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors";
CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1));
} else {
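    // Split each operand into (kept axes, contraction axis): the contraction
    // axes L[!Ta] and R[Tb] must match in size, and the output shape is L[Ta]
    // concatenated with R[!Tb]. E.g. lhs (2, 3, 4) dotted with rhs (4, 5) and no
    // transposes yields an output of shape (2, 3, 5).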
bool Ta = param.transpose_a, Tb = param.transpose_b;
mxnet::TShape L[2], R[2];
if (Ta) {
L[0] = mshadow::Shape1(lshape[0]);
L[1] = lshape.ndim() > 1 ? mxnet::TShape(&lshape[1], lshape.end()) : mxnet::TShape(1, 1);
} else {
L[0] = lshape.ndim() > 1 ? mxnet::TShape(&lshape[0], &lshape[lshape.ndim() - 1]) :
mxnet::TShape(1, 1);
L[1] = mshadow::Shape1(lshape[lshape.ndim() - 1]);
}
if (Tb) {
R[0] = rshape.ndim() > 1 ? mxnet::TShape(&rshape[0], &rshape[rshape.ndim() - 1]) :
mxnet::TShape(1, 1);
R[1] = mshadow::Shape1(rshape[rshape.ndim() - 1]);
} else {
R[0] = mshadow::Shape1(rshape[0]);
R[1] = rshape.ndim() > 1 ? mxnet::TShape(&rshape[1], rshape.end()) : mxnet::TShape(1, 1);
}
if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) {
CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) << "dot shape error: " << lshape << " X " << rshape;
}
std::vector<index_t> buf;
if (lshape.ndim() > 1)
buf.insert(buf.end(), &L[Ta][0], L[Ta].end());
if (rshape.ndim() > 1)
buf.insert(buf.end(), &R[!Tb][0], R[!Tb].end());
mxnet::TShape oshape(buf.begin(), buf.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
}
// return true if output shape is fully inferred
return shape_is_known((*out_attrs)[0]);
}
#if MXNET_USE_ONEDNN == 1
void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
#endif
template <typename xpu>
void DotForwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
#if MXNET_USE_ONEDNN == 1
if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
if (std::is_same<xpu, cpu>::value) {
DotForwardExDNNL(attrs, ctx, inputs, req, outputs);
} else {
FallBackCompute(DotForward_<gpu>, attrs, ctx, inputs, req, outputs);
}
return;
}
#endif
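  // The sparse paths below are 2-D only: each branch maps one supported
  // (lhs, rhs, out) storage-type combination to its dedicated implementation, and
  // any other combination is reported through LogUnimplementedOp.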
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
CHECK_EQ(inputs[DotIn::lhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
CHECK_EQ(inputs[DotIn::rhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
auto lhs_stype = inputs[DotIn::lhs].storage_type();
auto rhs_stype = inputs[DotIn::rhs].storage_type();
auto out_stype = outputs[DotOut::out].storage_type();
if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage &&
!param.transpose_b) {
TBlob ret = outputs[DotOut::out].data();
DotCsrDnsDnsImpl(ctx,
xpu(),
inputs[DotIn::lhs],
inputs[DotIn::rhs].data(),
req[DotOut::out],
param.transpose_a,
&ret);
} else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
out_stype == kDefaultStorage && !param.transpose_b) {
TBlob ret = outputs[DotOut::out].data();
DotCsrRspDnsImpl(ctx,
xpu(),
inputs[DotIn::lhs],
inputs[DotIn::rhs],
req[DotOut::out],
param.transpose_a,
&ret);
} else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage &&
out_stype == kRowSparseStorage && !param.transpose_b) {
NDArray out = outputs[DotOut::out];
DotCsrDnsRspImpl(ctx,
xpu(),
inputs[DotIn::lhs],
inputs[DotIn::rhs].data(),
req[DotOut::out],
param.transpose_a,
&out);
} else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
out_stype == kRowSparseStorage && !param.transpose_b) {
NDArray ret = outputs[DotOut::out];
DotCsrRspRspImpl(ctx,
xpu(),
inputs[DotIn::lhs],
inputs[DotIn::rhs],
req[DotOut::out],
param.transpose_a,
&ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && out_stype == kCSRStorage &&
!(param.transpose_a || param.transpose_b)) {
NDArray ret = outputs[DotOut::out];
DotDnsCsrCsrImpl(
ctx, xpu(), inputs[DotIn::lhs].data(), inputs[DotIn::rhs], req[DotOut::out], &ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
out_stype == kDefaultStorage && !(param.transpose_a)) {
NDArray ret = outputs[DotOut::out];
DotDnsCsrDnsImpl(ctx,
xpu(),
inputs[DotIn::lhs].data(),
inputs[DotIn::rhs],
req[DotOut::out],
&ret,
param.transpose_b);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
template <typename xpu>
void DotBackwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);
CHECK_EQ(req.size(), 2U);
CHECK(!(req[0] != kNullOp && outputs[0].storage_type() == kCSRStorage))
<< "sparse dot does not support computing the gradient of csr";
CHECK(!(req[1] != kNullOp && outputs[1].storage_type() == kCSRStorage))
<< "sparse dot does not support computing the gradient of csr";
CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
const auto ograd_stype = inputs[0].storage_type();
const auto lhs_stype = inputs[1].storage_type();
const auto rhs_stype = inputs[2].storage_type();
const auto grad_rhs_stype = outputs[1].storage_type();
if (ograd_stype == kDefaultStorage // ograd dns format
&& lhs_stype == kCSRStorage // csr input lhs of the op
&& grad_rhs_stype == kDefaultStorage && !param.transpose_b) { // grad(rhs) dns format
TBlob ret = outputs[1].data();
DotCsrDnsDnsImpl(ctx, xpu(), inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret);
} else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
grad_rhs_stype == kRowSparseStorage && !param.transpose_b) {
NDArray ret = outputs[1];
DotCsrDnsRspImpl(ctx, xpu(), inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret);
} else if (ograd_stype == kRowSparseStorage && lhs_stype == kCSRStorage &&
grad_rhs_stype == kRowSparseStorage && !param.transpose_b) {
NDArray ret = outputs[1];
DotCsrRspRspImpl(ctx, xpu(), inputs[1], inputs[0], req[1], !param.transpose_a, &ret);
} else if (ograd_stype == kRowSparseStorage && lhs_stype == kCSRStorage &&
grad_rhs_stype == kDefaultStorage && !param.transpose_b) {
TBlob ret = outputs[1].data();
DotCsrRspDnsImpl(ctx, xpu(), inputs[1], inputs[0], req[1], !param.transpose_a, &ret);
} else if (ograd_stype == kDefaultStorage && // ograd dns format
lhs_stype == kDefaultStorage && // lhs dns format
rhs_stype == kCSRStorage && !param.transpose_a) {
NDArray ret = outputs[0];
DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[2], req[0], &ret, !param.transpose_b);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}
template <typename xpu>
void BatchDotForward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (req[DotOut::out] == kNullOp)
return;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
      << "Binary function only supports input/output with the same type";
  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
      << "Binary function only supports input/output with the same type";
CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
outputs[DotOut::out].type_flag_ == kFloat64 ||
(outputs[DotOut::out].type_flag_ == kFloat16 &&
ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
<< "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
int ndim = outputs[DotOut::out].ndim();
if (outputs[DotOut::out].shape_.Size() == 0 || inputs[DotIn::lhs].shape_.Size() == 0 ||
inputs[DotIn::rhs].shape_.Size() == 0) {
if (outputs[DotOut::out].shape_.Size() != 0 && req[DotOut::out] != kAddTo) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, outputs[DotOut::out].shape_.Size(), outputs[DotOut::out].dptr<DType>());
}
return;
}
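    // Collapse all leading batch dimensions into a single batch axis: an output of
    // shape (B1, ..., Bk, M, N) is viewed as (B1*...*Bk, M, N) and handed to the
    // batched GEMM below, with lhs and rhs reshaped the same way.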
size_t batch_size = outputs[DotOut::out].shape_.ProdShape(0, ndim - 2);
mshadow::Tensor<xpu, 3, DType> out = outputs[DotOut::out].get_with_shape<xpu, 3, DType>(
Shape3(batch_size,
outputs[DotOut::out].shape_[ndim - 2],
outputs[DotOut::out].shape_[ndim - 1]),
s);
mshadow::Tensor<xpu, 3, DType> mlhs = inputs[DotIn::lhs].get_with_shape<xpu, 3, DType>(
Shape3(
batch_size, inputs[DotIn::lhs].shape_[ndim - 2], inputs[DotIn::lhs].shape_[ndim - 1]),
s);
mshadow::Tensor<xpu, 3, DType> mrhs = inputs[DotIn::rhs].get_with_shape<xpu, 3, DType>(
Shape3(
batch_size, inputs[DotIn::rhs].shape_[ndim - 2], inputs[DotIn::rhs].shape_[ndim - 1]),
s);
mshadow::Tensor<xpu, 1, DType*> workspace =
ctx.requested[0].get_space_typed<xpu, 1, DType*>(mshadow::Shape1(3 * out.size(0)), s);
if (param.transpose_a && param.transpose_b) {
mshadow::BatchGEMM<true, true>(out,
mlhs,
mrhs,
(DType)1.0f,
(kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
workspace);
} else if (!param.transpose_a && param.transpose_b) {
mshadow::BatchGEMM<false, true>(out,
mlhs,
mrhs,
(DType)1.0f,
(kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
workspace);
} else if (param.transpose_a && !param.transpose_b) {
mshadow::BatchGEMM<true, false>(out,
mlhs,
mrhs,
(DType)1.0f,
(kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
workspace);
} else {
mshadow::BatchGEMM<false, false>(out,
mlhs,
mrhs,
(DType)1.0f,
(kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
workspace);
}
});
}
template <typename ParamType>
inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
// return false if lhs and rhs both have fully unknown shape
if (!ndim_is_known(lshape) || !ndim_is_known(rshape))
return false;
if (lshape.ndim() >= 3 && rshape.ndim() >= 3 && lshape.ndim() == rshape.ndim()) {
int ndim = lshape.ndim();
    // shape inference requires the last dim of lhs and the second-to-last dim of rhs to be known
bool last_dim_known = dim_size_is_known(lshape, ndim - 1);
bool second_dim_known = dim_size_is_known(rshape, ndim - 2);
if (!last_dim_known || !second_dim_known)
return false;
for (int i = 0; i < ndim - 2; i++) {
CHECK_EQ(lshape[i], rshape[i])
<< "batch_dot shape error (the leading batch dimensions must be equal): " << lshape
<< " X " << rshape << " trans_a=" << param.transpose_a
<< " trans_b=" << param.transpose_b;
}
dim_t out_m = param.transpose_a ? lshape[ndim - 1] : lshape[ndim - 2];
dim_t lshape_k = param.transpose_a ? lshape[ndim - 2] : lshape[ndim - 1];
dim_t out_n = param.transpose_b ? rshape[ndim - 2] : rshape[ndim - 1];
dim_t rshape_k = param.transpose_b ? rshape[ndim - 1] : rshape[ndim - 2];
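    // With the transposes applied, lhs is (..., out_m, lshape_k) and rhs is
    // (..., rshape_k, out_n); the contraction sizes must match and the output is
    // (..., out_m, out_n).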
CHECK_EQ(lshape_k, rshape_k) << "batch_dot shape error (shape mismatch): " << lshape << " X "
<< rshape << " trans_a=" << param.transpose_a
<< " trans_b=" << param.transpose_b;
std::vector<dim_t> out_shape_vec;
for (int i = 0; i < ndim - 2; i++) {
out_shape_vec.push_back(lshape[i]);
}
out_shape_vec.push_back(out_m);
out_shape_vec.push_back(out_n);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(out_shape_vec));
} else {
LOG(FATAL) << "batch_dot currently only support N-D*N-D array (N >= 3)" << lshape << " v.s. "
<< rshape;
}
// return true if output shape is fully inferred
return shape_is_known((*out_attrs)[0]);
}
} // namespace op
} // namespace mxnet
namespace std {
template <>
struct hash<mxnet::op::DotParam> {
size_t operator()(const mxnet::op::DotParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.transpose_a);
ret = dmlc::HashCombine(ret, val.transpose_b);
ret = dmlc::HashCombine(ret, val.forward_stype);
return ret;
}
};
} // namespace std
#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_