| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file elemwise_binary_op-inl.h |
 * \brief Function definitions of elementwise binary operators
| */ |
| #ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_INL_H_ |
| #define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_INL_H_ |
| |
| #include <vector> |
| #include <algorithm> |
| #include "./elemwise_binary_op.h" |
| #include "../mxnet_op.h" |
| #define WARP_SIZE 32 |
| #define WARP_SIZE_BITS 5 |
| |
| namespace mxnet { |
| namespace op { |
| |
/*! \brief binary op handling for the following row-sparse inputs/outputs:
 *   rsp, rsp -> rsp
 *   dns, rsp -> rsp
 *   rsp, dns -> rsp
 *   dns, rsp -> dns
 *   rsp, dns -> dns
 */
| template<typename OP> |
| void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s, |
| const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &lhs, |
| const NDArray &rhs, |
| const OpReqType req, |
| const NDArray &output, |
| const bool lhs_may_be_dense, |
| const bool rhs_may_be_dense, |
| const bool allow_inplace, |
| const bool scatter) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| const NDArray& rsp = lhs.storage_type() == kRowSparseStorage ? lhs : rhs; |
| const bool is_dense_result = output.storage_type() == kDefaultStorage; |
| const bool lhs_is_dense = lhs.storage_type() == kDefaultStorage; |
| const bool rhs_is_dense = rhs.storage_type() == kDefaultStorage; |
  CHECK(!lhs_is_dense || lhs_may_be_dense) << "lvalue cannot be dense";
| CHECK(!rhs_is_dense || rhs_may_be_dense) << "rvalue cannot be dense"; |
| CHECK(!lhs_is_dense || !rhs_is_dense); |
| MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { |
| MSHADOW_TYPE_SWITCH(output.dtype(), DType, { |
| // Only one item at most may be dense (lhs, rhs or result) |
| if (rhs_is_dense) { |
        // For a dense rhs, the output can only stay sparse if a zero lhs input
        // always produces a zero output
| CHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-4f); |
| CHECK(!is_dense_result); // Currently not handled |
| } |
| if (lhs_is_dense) { |
        // For a dense lhs, the output can only stay sparse if a zero rhs input
        // always produces a zero output
| CHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-4f); |
| CHECK(!is_dense_result); // Currently not handled |
| } |
| |
| // Memory Estimation: This is (roughly) the number of result rows. We may still |
| // need to subtract the number of common rows |
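      // For example (illustrative, rsp-rsp output without in-place): if lhs stores
      // rows {0, 2, 5} and rhs stores rows {2, 3}, we reserve 3 + 2 = 5 output rows
      // up front; row 2 is common to both inputs, and the index shape is shrunk by
      // num_common_rows (here 1) at the end of this function.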
| bool lhs_in_place = false, rhs_in_place = false; |
| const size_t num_rows_l = lhs_is_dense ? lhs.shape()[0] : |
| lhs.aux_shape(rowsparse::kIdx).Size(); |
| const size_t num_rows_r = rhs_is_dense ? rhs.shape()[0] : |
| rhs.aux_shape(rowsparse::kIdx).Size(); |
| if (is_dense_result) { |
| output.CheckAndAlloc(); |
| } else { |
| if (rhs_is_dense || scatter) { |
| output.CheckAndAlloc({mshadow::Shape1(num_rows_l)}); |
| } else if (lhs_is_dense) { |
| output.CheckAndAlloc({mshadow::Shape1(num_rows_r)}); |
| } else { |
| lhs_in_place = IsSameArray(lhs, output); |
| rhs_in_place = IsSameArray(rhs, output); |
| if (!lhs_in_place && !rhs_in_place) { |
| output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)}); |
| } else { |
| CHECK_EQ(allow_inplace, true); |
| CHECK_EQ(is_dense_result, false); |
| if (lhs_in_place) { |
| // For in-place, zero L-value must always be zero output |
              DCHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-3f);
| } else { |
| // For in-place, zero R-value must always be zero output |
              DCHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-3f);
| } |
| } |
| } |
| } |
| |
| // Indices |
| const Tensor<cpu, 1, IType> indices_l = lhs_is_dense ? |
| Tensor<cpu, 1, IType>() : |
| lhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s); |
| const Tensor<cpu, 1, IType> indices_r = rhs_is_dense ? |
| Tensor<cpu, 1, IType>() : |
| rhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s); |
| Tensor<cpu, 1, IType> indices_out = is_dense_result ? |
| Tensor<cpu, 1, IType>() : |
| output.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s); |
| |
| // Data |
| // TODO(cjolivier01): Change to get_with_shape() calls |
| const Tensor<cpu, 2, DType> data_l = AsRowise2D<DType>(s, lhs.data()); |
| const Tensor<cpu, 2, DType> data_r = AsRowise2D<DType>(s, rhs.data()); |
| Tensor<cpu, 2, DType> out = AsRowise2D<DType>(s, output.data()); |
| |
| size_t iter_l = 0; |
| size_t iter_r = 0; |
| size_t iter_out = 0; |
| int32_t num_common_rows = 0; |
| |
| if (is_dense_result) { |
| if (!num_rows_l && !num_rows_r) { |
| const size_t all_rows = static_cast<size_t>(lhs.shape()[0]); |
| iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out); |
| } |
| } |
| |
| while (iter_l < num_rows_l && iter_r < num_rows_r) { |
| IType idx_l = lhs_is_dense ? indices_r[iter_r] : indices_l[iter_l]; |
| IType idx_r = rhs_is_dense ? idx_l : indices_r[iter_r]; |
| if (lhs_in_place) { |
| while (idx_r < idx_l && ++iter_r < num_rows_r) { |
| idx_r = indices_r[iter_r]; |
| } |
| if (iter_r >= num_rows_r) { |
| break; |
| } |
| } else if (rhs_in_place) { |
| while (idx_l < idx_r && ++iter_l < num_rows_l) { |
| idx_l = indices_l[iter_l]; |
| } |
| if (iter_l >= num_rows_l) { |
| break; |
| } |
| } |
| if (is_dense_result) { |
| iter_out = FillDense<DType, OP>(s, idx_l, idx_r, req, &out, iter_out); |
| DCHECK_EQ(iter_out, static_cast<size_t>(std::min(idx_l, idx_r))); |
| } |
| if (idx_l == idx_r) { |
| // Same row |
| if (!is_dense_result) { |
| indices_out[iter_out] = idx_l; |
| } |
| Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l]; |
| Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r]; |
| DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size()); |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch( |
| s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_); |
| }); |
| num_common_rows++; |
| } else if (idx_l < idx_r) { |
| // Left only |
| if (!is_dense_result) { |
| indices_out[iter_out] = idx_l; |
| } |
| Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l]; |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch( |
| s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_); |
| }); |
| } else { |
| // Right only |
| if (scatter) { |
| ++iter_r; |
| continue; // skip '++iter_out' below |
| } |
| if (!is_dense_result) { |
| indices_out[iter_out] = idx_r; |
| } |
| Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r]; |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch( |
| s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_); |
| }); |
| } |
| ++iter_out; |
| } |
      // Evaluate the remaining rows beyond the intersection of the lhs and rhs row indices
| while (iter_l < num_rows_l && !lhs_is_dense && !rhs_in_place) { |
| if (!is_dense_result) { |
| indices_out[iter_out] = indices_l[iter_l]; |
| } else { |
| const IType idx_l = indices_l[iter_l]; |
| iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_l, req, &out, iter_out); |
| } |
| Tensor<cpu, 1, DType> lvalue = data_l[iter_l++]; |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch( |
| s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_); |
| }); |
| } |
| while (iter_r < num_rows_r && !rhs_is_dense && !lhs_in_place && !scatter) { |
| if (!is_dense_result) { |
| indices_out[iter_out] = indices_r[iter_r]; |
| } else { |
| const IType idx_r = indices_r[iter_r]; |
| iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_r, req, &out, iter_out); |
| } |
| Tensor<cpu, 1, DType> rvalue = data_r[iter_r++]; |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch( |
| s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_); |
| }); |
| } |
| if (is_dense_result) { |
| const size_t all_rows = static_cast<size_t>(lhs.shape()[0]); |
| iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out); |
| } else { |
| if (lhs_in_place) { |
| CHECK_LE(iter_out, num_rows_l); |
| } |
| if (rhs_in_place) { |
| CHECK_LE(iter_out, num_rows_r); |
| } |
| DCHECK_LE(iter_out, num_rows_l + num_rows_r); // Make sure that we didn't overrun |
| mxnet::TShape new_shape = output.aux_shape(rowsparse::kIdx); |
| CHECK_LE(iter_out, new_shape.Size()); |
| if (!rhs_is_dense && !lhs_is_dense && !lhs_in_place && !rhs_in_place && !scatter) { |
| // Reduce the first-dimension size by the number of common rows |
| new_shape[0] -= num_common_rows; |
| output.set_aux_shape(rowsparse::kIdx, new_shape); |
| } |
| } |
| }); |
| }); |
| } |
| |
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray (GPU stub; not implemented) */
| template<typename OP> |
| void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<gpu> *s, |
| const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &lhs, |
| const NDArray &rhs, |
| const OpReqType req, |
| const NDArray &output) { |
| LOG(FATAL) << "GPU not supported for CsrCsrOp"; |
| } |
| |
| /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */ |
| template<typename OP> |
| void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s, |
| const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &lhs, |
| const NDArray &rhs, |
| const OpReqType req, |
| const NDArray &output) { |
| using namespace mshadow; |
| using namespace mxnet_op; |
| using namespace mshadow::expr; |
| |
| const auto nr_rows = static_cast<size_t>(lhs.shape()[0]); |
| if (!nr_rows) { |
| return; |
| } |
| CHECK_EQ(lhs.aux_shape(csr::kIndPtr).Size(), nr_rows + 1); |
| const size_t nr_cols = lhs.shape().Size() / nr_rows; |
| |
| CHECK_EQ(lhs.shape().Size(), rhs.shape().Size()); |
| |
| const bool same_lhs_rhs = IsSameArray(lhs, rhs); |
| |
| const size_t lhs_nnz = lhs.storage_shape().Size(); |
| const size_t rhs_nnz = rhs.storage_shape().Size(); |
| |
| const size_t output_nnz_guess = same_lhs_rhs ? lhs_nnz : lhs_nnz + rhs_nnz; |
| |
| output.CheckAndAlloc({mshadow::Shape1(lhs.shape()[0] + 1), |
| mshadow::Shape1(std::min(output_nnz_guess, lhs.shape().Size()))}); |
| DCHECK_EQ(output.aux_shape(csr::kIndPtr), lhs.aux_shape(csr::kIndPtr)); |
| |
| MSHADOW_IDX_TYPE_SWITCH(lhs.aux_type(csr::kIdx), IType, { |
| MSHADOW_IDX_TYPE_SWITCH(lhs.aux_type(csr::kIndPtr), CType, { |
| MSHADOW_TYPE_SWITCH(output.dtype(), DType, { |
| const size_t alloc_size = nr_cols * sizeof(IType) + 2 * nr_cols * sizeof(DType); |
| |
| Tensor<cpu, 1, uint8_t> workspace = |
| ctx.requested[ResourceRequestType::kTempSpace].get_space_typed<cpu, 1, uint8_t>( |
| mshadow::Shape1(alloc_size), s); |
| |
        // Partition the temp workspace into three tensors: the 'next' column linked list
        // and the lhs_row / rhs_row dense accumulators
| mshadow::Tensor<cpu, 1, IType> next(reinterpret_cast<IType *>(workspace.dptr_), |
| Shape1(nr_cols)); |
| mshadow::Tensor<cpu, 1, DType> lhs_row(reinterpret_cast<DType *>( |
| workspace.dptr_ + nr_cols * sizeof(IType)), |
| Shape1(nr_cols)); |
| mshadow::Tensor<cpu, 1, DType> rhs_row; |
| |
| OpBase::FillDense<IType>(s, next.shape_.Size(), IType(-1), req, next.dptr_); |
| OpBase::FillDense<DType>(s, lhs_row.shape_.Size(), DType(0), req, lhs_row.dptr_); |
| |
| if (!same_lhs_rhs) { |
| rhs_row = Tensor<cpu, 1, DType>(lhs_row.dptr_ + nr_cols, Shape1(nr_cols)); |
| OpBase::FillDense<DType>(s, rhs_row.shape_.Size(), DType(0), req, rhs_row.dptr_); |
| } else { |
| rhs_row = lhs_row; |
| } |
| |
| // Column indices |
| const Tensor<cpu, 1, IType> col_indices_l = lhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s); |
| const Tensor<cpu, 1, IType> col_indices_r = rhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s); |
| Tensor<cpu, 1, IType> col_indices_out = output.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s); |
| |
| // Row pointers |
| const Tensor<cpu, 1, CType> row_ptr_l = lhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s); |
| const Tensor<cpu, 1, CType> row_ptr_r = rhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s); |
| Tensor<cpu, 1, CType> row_ptr_out = output.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s); |
| |
| Tensor<cpu, 1, DType> data_l = lhs.data().FlatTo1D<cpu, DType>(s); |
| Tensor<cpu, 1, DType> data_r = rhs.data().FlatTo1D<cpu, DType>(s); |
| Tensor<cpu, 1, DType> data_out = output.data().FlatTo1D<cpu, DType>(s); |
| |
| IType nnz = 0; |
| row_ptr_out[0] = 0; |
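
        // Per-row sparse accumulator (illustrative description): 'next' threads a singly
        // linked list through the columns touched in the current row, with 'head' pointing
        // at the most recently added column and -1 marking "not in the list". For example,
        // if a row touches columns 4 and then 1, we end up with head = 1, next[1] = 4 and
        // next[4] = -2 (end marker), so the output scan below visits exactly those columns
        // and then resets them to -1 / 0 for the next row.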
| |
| for (IType i = 0; i < static_cast<IType>(nr_rows); i++) { |
| IType head = -2; |
| IType length = 0; |
| |
          // accumulate row i of the lhs into the dense buffer lhs_row
| const IType i_start_l = row_ptr_l[i]; |
| const IType i_end_l = row_ptr_l[i + 1]; |
| for (IType jj = i_start_l; jj < i_end_l; jj++) { |
| IType col = col_indices_l[jj]; |
| lhs_row[col] += data_l[jj]; |
| |
| if (next[col] == -1) { |
| next[col] = head; |
| head = col; |
| ++length; |
| } |
| } |
| |
| if (!same_lhs_rhs) { |
            // accumulate row i of the rhs into the dense buffer rhs_row
| const IType i_start_r = row_ptr_r[i]; |
| const IType i_end_r = row_ptr_r[i + 1]; |
| for (IType jj = i_start_r; jj < i_end_r; jj++) { |
| const IType col = col_indices_r[jj]; |
| rhs_row[col] += data_r[jj]; |
| |
| if (next[col] == -1) { |
| next[col] = head; |
| head = col; |
| ++length; |
| } |
| } |
| } |
| |
          // scan through the columns where the lhs or rhs row
          // contributed a non-zero entry
| for (IType jj = 0; jj < length; jj++) { |
| const DType result = OP::Map(lhs_row[head], rhs_row[head]); |
| |
| if (result != 0) { |
| col_indices_out[nnz] = head; |
| data_out[nnz] = result; |
| ++nnz; |
| } |
| |
| const IType temp = head; |
| head = next[head]; |
| |
| next[temp] = -1; |
| lhs_row[temp] = 0; |
| if (!same_lhs_rhs) rhs_row[temp] = 0; |
| } |
| |
| row_ptr_out[i + 1] = nnz; |
| } |
| }); |
| }); |
| }); |
| } |
| |
| /*! |
| * \brief Kernel for performing elemwise op between dense and csr matrix |
| * \param i global thread id |
| * \param req type of request |
| * \param out output array |
| * \param dns_data data array of dense input |
| * \param csr_data data array of csr input |
| * \param csr_indices indices array of csr input |
| * \param csr_indptr indptr array of csr input |
| * \param num_rows number of rows of both inputs |
| * \param num_cols number of columns of both inputs |
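 *
 * A minimal CPU launch sketch (pointer names hypothetical), mirroring how the kernel
 * is used below; one thread is launched per csr row:
 * \code
 * MXNET_ASSIGN_REQ_SWITCH(req, Req, {
 *   mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, OP>, cpu>::Launch(
 *       s, num_csr_rows, out_ptr, dns_ptr,
 *       csr_data_ptr, csr_indices_ptr, csr_indptr_ptr,
 *       num_csr_rows, num_csr_cols);
 * });
 * \endcode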
| */ |
| template<int req, typename OP> |
| struct ElemwiseDnsCsrDnsKernel { |
| template<typename DType, typename IType, typename CType> |
| MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data, |
| const DType* csr_data, const IType* csr_indices, |
| const CType* csr_indptr, const nnvm::dim_t num_rows, |
| const nnvm::dim_t num_cols) { |
| if (i < num_rows) { |
| for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) { |
| KERNEL_ASSIGN(out[i * num_cols + csr_indices[j]], req, |
| OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j])); |
| } |
| } |
| } |
| }; |
| |
| /*! |
 * \brief Warp-per-row kernel for performing elemwise op between dense and csr matrix
| * \param tid global thread id |
| * \param req type of request |
| * \param out output array |
| * \param dns_data data array of dense input |
| * \param csr_data data array of csr input |
| * \param csr_indices indices array of csr input |
| * \param csr_indptr indptr array of csr input |
| * \param num_rows number of rows of both inputs |
| * \param num_cols number of columns of both inputs |
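 *
 * Intended for WARP_SIZE * num_rows threads: each csr row is handled by WARP_SIZE
 * consecutive threads. For example (illustrative), with WARP_SIZE = 32 a thread with
 * tid = 70 gets row_id = 70 >> 5 = 2 and lane 70 & 31 = 6, so it processes the
 * nonzeros at positions csr_indptr[2] + 6, + 38, ... of row 2.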
| */ |
| template<int req, typename OP> |
| struct ElemwiseDnsCsrDnsWarpKernel { |
| template<typename DType, typename IType, typename CType> |
| MSHADOW_XINLINE static void Map(int tid, DType* out, DType* dns_data, |
| const DType* csr_data, const IType* csr_indices, |
| const CType* csr_indptr, const nnvm::dim_t num_rows, |
| const nnvm::dim_t num_cols) { |
    if (tid < WARP_SIZE * num_rows) {
      const int row_id = tid >> WARP_SIZE_BITS;   // each warp handles one csr row
      const int warp_id = tid & (WARP_SIZE - 1);  // lane index within the warp
| for (int j = csr_indptr[row_id] + warp_id; j < csr_indptr[row_id+1]; j += WARP_SIZE) { |
| KERNEL_ASSIGN(out[row_id * num_cols + csr_indices[j]], req, |
| OP::Map(dns_data[row_id * num_cols + csr_indices[j]], csr_data[j])); |
| } |
| } |
| } |
| }; |
| |
| /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */ |
| template<typename OP> |
| void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s, |
| const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &dns, |
| const NDArray &csr, |
| const OpReqType req, |
| const NDArray &output, |
| const bool reverse) { |
| using namespace mshadow; |
| using namespace mxnet_op; |
| CHECK_EQ(dns.storage_type(), kDefaultStorage); |
| CHECK_EQ(csr.storage_type(), kCSRStorage); |
| CHECK(req != kAddTo); |
| CHECK(req != kNullOp); |
| const bool supported_op = std::is_same<OP, mshadow_op::minus>::value || |
| std::is_same<OP, mshadow_op::plus>::value; |
  CHECK(supported_op == true) << "elemwise(dns, csr) = dns only supports plus and minus";
| const nnvm::dim_t num_csr_rows = csr.shape()[0]; |
| const nnvm::dim_t num_csr_cols = csr.shape()[1]; |
| TBlob csr_data = csr.data(); |
| TBlob csr_indices = csr.aux_data(csr::kIdx); |
| TBlob csr_indptr = csr.aux_data(csr::kIndPtr); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { |
| MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
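          // reverse minus (csr - dns) is computed as (-dns) + csr: negate the dense
          // input into the output, then add the csr nonzeros on top of it.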
| if (reverse && std::is_same<OP, mshadow_op::minus>::value) { |
| mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, cpu>::Launch( |
| s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>()); |
| if (!csr.storage_initialized()) { return; } |
| mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, mshadow_op::plus>, cpu>::Launch( |
| s, num_csr_rows, output.data().dptr<DType>(), |
| output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(), |
| csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols); |
| } else { |
| if (req == kWriteTo) { |
| mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch( |
| s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>()); |
| } |
| if (!csr.storage_initialized()) { return; } |
| mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, OP>, cpu>::Launch( |
| s, num_csr_rows, output.data().dptr<DType>(), |
| output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(), |
| csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols); |
| } |
| }); |
| }); |
| }); |
| }); |
| } |
| |
| /*! |
 * \brief Kernel for performing elemwise op between dense and csr matrix, producing csr output
| * \param i global thread id |
| * \param req type of request |
| * \param out output array |
| * \param dns_data data array of dense input |
| * \param csr_data data array of csr input |
| * \param csr_indices indices array of csr input |
| * \param csr_indptr indptr array of csr input |
| * \param num_rows number of rows of both inputs |
| * \param num_cols number of columns of both inputs |
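 *
 * \note reverse selects the operand order: reverse == true computes
 *       OP(dns_value, csr_value) while reverse == false computes OP(csr_value, dns_value);
 *       only the positions that are nonzero in the csr input are written.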
| */ |
| template<int req, typename OP, bool reverse> |
| struct ElemwiseDnsCsrCsrKernel { |
| template<typename DType, typename IType, typename CType> |
| MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data, |
| const DType* csr_data, const IType* csr_indices, |
| const CType* csr_indptr, const nnvm::dim_t num_rows, |
| const nnvm::dim_t num_cols) { |
| if (i < num_rows) { |
| for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) { |
| KERNEL_ASSIGN(out[j], req, reverse ? |
| OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) : |
| OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]])); |
| } |
| } |
| } |
| }; |
| |
| /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */ |
| template<typename xpu, typename OP> |
| void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &dns, |
| const NDArray &csr, |
| const OpReqType req, |
| const NDArray &output, |
| const bool reverse) { |
| using namespace mshadow; |
| using namespace mxnet_op; |
| using namespace csr; |
| CHECK_EQ(dns.storage_type(), kDefaultStorage); |
| CHECK_EQ(csr.storage_type(), kCSRStorage); |
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
| const bool supported_op = std::is_same<OP, mshadow_op::mul>::value; |
| CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul"; |
| const nnvm::dim_t num_csr_rows = csr.shape()[0]; |
| const nnvm::dim_t num_csr_cols = csr.shape()[1]; |
| const nnvm::dim_t nnz = csr.storage_shape()[0]; |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| |
| output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)}); |
| if (csr.storage_initialized()) { |
| TBlob csr_data = csr.data(); |
| TBlob csr_indices = csr.aux_data(kIdx); |
| TBlob csr_indptr = csr.aux_data(kIndPtr); |
| MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { |
| MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
| if (reverse) { |
| Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch( |
| s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(), |
| csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(), |
| num_csr_rows, num_csr_cols); |
| } else { |
| Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch( |
| s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(), |
| csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(), |
| num_csr_rows, num_csr_cols); |
| } |
| Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(), |
| csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s); |
| Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), |
| csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s); |
| }); |
| }); |
| }); |
| }); |
| } else { |
| FillZerosCsrImpl(s, output); |
| } |
| } |
| |
| /*! |
| * \brief Kernel for performing elemwise op between dense and rsp tensor |
| * \param i global thread id |
| * \param req type of request |
| * \param out output array |
| * \param dns_data data array of dense input |
| * \param rsp_data data array of rsp input |
| * \param rsp_indices indices array of rsp input |
| * \param num_rows number of rows of both inputs |
| * \param nz_rows number of non-zero rows of rsp tensor |
| * \param num_cols number of columns of both inputs |
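 *
 * \note thread i covers one element of the stored rsp rows: rsp_idx = i / num_cols
 *       selects the stored row, col = i % num_cols selects the column, and the result
 *       is written to dense row rsp_indices[rsp_idx]; launched with nz_rows * num_cols
 *       threads.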
| */ |
| template<int req, typename OP> |
| struct ElemwiseDnsRspDnsKernel { |
| template<typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data, |
| const DType* rsp_data, const IType* rsp_indices, |
| const nnvm::dim_t num_rows, const nnvm::dim_t nz_rows, |
| const nnvm::dim_t num_cols) { |
| if (i < nz_rows * num_cols) { |
| const nnvm::dim_t rsp_idx = i / num_cols; |
| const nnvm::dim_t dns_row = rsp_indices[rsp_idx]; |
| const nnvm::dim_t col = i % num_cols; |
| KERNEL_ASSIGN(out[dns_row * num_cols + col], req, |
| OP::Map(dns_data[dns_row * num_cols + col], |
| rsp_data[rsp_idx * num_cols + col])); |
| } |
| } |
| }; |
| |
| /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */ |
| template<typename xpu, typename OP> |
| void ElemwiseBinaryOp::DnsRspDnsOp(mshadow::Stream<xpu> *s, |
| const nnvm::NodeAttrs &attrs, |
| const OpContext &ctx, |
| const NDArray &dns, |
| const NDArray &rsp, |
| const OpReqType req, |
| const NDArray &output, |
| const bool reverse) { |
| using namespace mshadow; |
| using namespace mxnet_op; |
| CHECK(dns.storage_type() == kDefaultStorage || dns.storage_type() == kRowSparseStorage); |
| CHECK_EQ(rsp.storage_type(), kRowSparseStorage); |
| CHECK_EQ(output.data().Size(), dns.data().Size()); |
| CHECK(req != kAddTo); |
| if (req == kNullOp) return; |
| const bool supported_op = std::is_same<OP, mshadow_op::minus>::value || |
| std::is_same<OP, mshadow_op::plus>::value; |
| CHECK(supported_op == true) << |
| "Only plus and minus supported now for elemwise operation between default and rsp matrices"; |
| const nnvm::dim_t num_rows = dns.shape()[0]; |
| const nnvm::dim_t num_cols = dns.data().Size() / num_rows; |
| const nnvm::dim_t nz_rows = rsp.aux_shape(rowsparse::kIdx).Size(); |
| TBlob rsp_data = rsp.data(); |
| TBlob rsp_indices = rsp.aux_data(rowsparse::kIdx); |
| |
| MSHADOW_TYPE_SWITCH(rsp_data.type_flag_, DType, { |
| MSHADOW_IDX_TYPE_SWITCH(rsp_indices.type_flag_, IType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, Req, { |
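        // reverse minus (rsp - dns) is computed as (-dns) + rsp: negate the dense
        // input into the output, then add the stored rsp rows on top of it.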
| if (reverse && std::is_same<OP, mshadow_op::minus>::value) { |
| mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, xpu>::Launch( |
| s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>()); |
| if (rsp.storage_initialized()) { |
| mxnet_op::Kernel<ElemwiseDnsRspDnsKernel<Req, mshadow_op::plus>, xpu>::Launch( |
| s, nz_rows * num_cols, output.data().dptr<DType>(), |
| output.data().dptr<DType>(), rsp_data.dptr<DType>(), rsp_indices.dptr<IType>(), |
| num_rows, nz_rows, num_cols); |
| } |
| } else { |
| if (req == kWriteTo) { |
| mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch( |
| s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>()); |
| } |
| if (rsp.storage_initialized()) { |
| mxnet_op::Kernel<ElemwiseDnsRspDnsKernel<Req, OP>, xpu>::Launch( |
| s, nz_rows * num_cols, output.data().dptr<DType>(), |
| output.data().dptr<DType>(), rsp_data.dptr<DType>(), rsp_indices.dptr<IType>(), |
| num_rows, nz_rows, num_cols); |
| } |
| } |
| }); |
| }); |
| }); |
| } |
| |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_INL_H_ |