/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file matrix_op.cu
* \brief GPU Implementation of matrix operations
*/
#include <cub/cub.cuh>
#include "./matrix_op-inl.h"
#include "./elemwise_unary_op.h"
namespace mxnet {
namespace op {
/*!
* \brief Compute the number of elements of every row that fall within the sliced column range.
*/
struct SliceMarkCsrIndPtr {
/*!
* \brief Count the elements of row i that fall within the sliced column range.
* \param i the i-th row of the output csr ndarray
* \param prefix_sum indptr array of the output csr ndarray
* \param in_idx indices array of the input csr ndarray
* \param in_indptr indptr array of the input csr ndarray
* \param begin_col starting column index (inclusive)
* \param end_col ending column index (exclusive)
*/
template <typename IType, typename RType>
MSHADOW_XINLINE static void Map(int i,
RType* prefix_sum,
const IType* in_idx,
const RType* in_indptr,
const int begin_col,
const int end_col) {
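// Count the entries of row i whose column index falls in [begin_col, end_col)
// and write the count to prefix_sum[i + 1]; an inclusive scan in
// SliceDimTwoCsrImpl later turns these per-row counts into the output indptr.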
if (i == 0) {
prefix_sum[0] = 0;
}
RType size = 0;
for (RType j = in_indptr[i]; j < in_indptr[i + 1]; j++) {
// indices of CSRNDArray are in ascending order per row
if (in_idx[j] >= end_col) {
break;
} else if (in_idx[j] >= begin_col) {
size++;
}
}
prefix_sum[i + 1] = size;
}
};
template <>
void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape& begin,
const mxnet::TShape& end,
const OpContext& ctx,
const NDArray& in,
const NDArray& out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
Stream<gpu>* s = ctx.get_stream<gpu>();
nnvm::dim_t begin_row = begin[0], end_row = end[0];
nnvm::dim_t begin_col = begin[1], end_col = end[1];
nnvm::dim_t indptr_len = end_row - begin_row + 1;
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
// assume idx and indptr share the same type
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
RType* in_indptr = in.aux_data(kIndPtr).dptr<RType>();
IType* in_idx = in.aux_data(kIdx).dptr<IType>();
DType* in_data = in.data().dptr<DType>();
RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
Kernel<SliceMarkCsrIndPtr, gpu>::Launch(
s, indptr_len - 1, out_indptr, in_idx, in_indptr + begin_row, begin_col, end_col);
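// cub::DeviceScan::InclusiveSum follows CUB's two-pass pattern: the first call,
// made with a null d_temp_storage, only computes temp_storage_bytes; the second
// call performs the scan that turns the per-row counts into the output indptr.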
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
out_indptr,
out_indptr,
indptr_len,
Stream<gpu>::GetStream(s));
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
d_temp_storage = workspace.dptr_;
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
out_indptr,
out_indptr,
indptr_len,
Stream<gpu>::GetStream(s));
// retrieve nnr: the last indptr entry is the number of stored entries in the slice
RType nnr = 0;
CUDA_CALL(cudaMemcpyAsync(&nnr,
&out_indptr[indptr_len - 1],
sizeof(RType),
cudaMemcpyDeviceToHost,
mshadow::Stream<gpu>::GetStream(s)));
CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
// return a csr ndarray of all zeros if nnr == 0
if (nnr == 0) {
out.set_aux_shape(kIdx, Shape1(0));
return;
}
out.CheckAndAllocAuxData(kIdx, Shape1(nnr));
out.CheckAndAllocData(Shape1(nnr));
IType* out_idx = out.aux_data(kIdx).dptr<IType>();
DType* out_data = out.data().dptr<DType>();
Kernel<SliceDimTwoCsrAssign, gpu>::Launch(s,
indptr_len - 1,
out_idx,
out_data,
out_indptr,
in_idx,
in_data,
in_indptr + begin_row,
begin_col,
end_col);
});
});
});
}
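// Parameter block passed by value to split_tensor_kernel: the single input
// pointer, pointers to up to MaxSections output tensors, and the split indices
// along the split axis. When there are more outputs than MaxSections, the host
// loop in SplitOpForwardGPU launches the kernel once per chunk of sections.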
template <typename DType>
struct split_tensor_data {
static const int MaxSections = 128;
size_t num_sections;
DType* outputs[MaxSections];
size_t indices[MaxSections + 1];
DType* inputs[1];
};
template <bool split_last_axis, typename LType, typename DType>
__global__ void split_tensor_kernel(size_t input_size,
const split_tensor_data<DType> params,
size_t split_axis_size,
size_t tail_size,
size_t last_axis_size,
size_t blocks_last_axis) {
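// Copies the input into the per-section outputs using vectorized LType loads
// (entries_per_load DType elements per load). When splitting the last axis each
// thread, otherwise each block, locates its owning section with a binary search
// over params.indices.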
const int entries_per_load = sizeof(LType) / sizeof(DType);
const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
const size_t last_axis_size_aligned =
entries_per_load > 0 ? last_axis_size / entries_per_load : last_axis_size;
if (split_last_axis) {
size_t input_offset_leading = (blockIdx.x / blocks_last_axis) * last_axis_size_aligned;
size_t position_last_axis = (blockIdx.x % blocks_last_axis) * blockDim.x * entries_per_load +
params.indices[0] + threadIdx.x * entries_per_load;
if (position_last_axis < params.indices[params.num_sections]) {
size_t position_last_axis_aligned =
entries_per_load > 0 ? position_last_axis / entries_per_load : position_last_axis;
LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
// Binary search to find section of each thread
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (position_last_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
size_t section_size_aligned =
entries_per_load > 0 ? section_size / entries_per_load : section_size;
size_t index_aligned = entries_per_load > 0 ? params.indices[section] / entries_per_load :
params.indices[section];
size_t output_offset_leading = (blockIdx.x / blocks_last_axis) * section_size_aligned;
size_t output_position = output_offset_leading + position_last_axis_aligned - index_aligned;
out_aligned[output_position] = input_data;
}
} else {
size_t split_axis_size_iter = params.indices[params.num_sections] - params.indices[0];
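// blocks needed to cover one leading-dimension slice: the split-axis positions
// handled in this launch, times tail_size, times blocks_last_axis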
size_t blocks_per_leading_dim = (split_axis_size_iter * tail_size * blocks_last_axis);
// input offsets: leading (axes pre-split-axis), at split-axis, tail, and blocks_last_axis
size_t input_offset_leading = (blockIdx.x / blocks_per_leading_dim) * split_axis_size *
tail_size * last_axis_size_aligned;
size_t pos_in_split_axis =
(blockIdx.x / (tail_size * blocks_last_axis)) % split_axis_size_iter + params.indices[0];
size_t input_offset_split_axis = pos_in_split_axis * tail_size * last_axis_size_aligned;
size_t offset_tail = ((blockIdx.x / blocks_last_axis) % tail_size) * last_axis_size_aligned;
size_t input_offset = input_offset_leading + input_offset_split_axis + offset_tail +
(blockIdx.x % blocks_last_axis) * blockDim.x;
// Binary search to find section for this block
size_t lower = 0;
size_t upper = params.num_sections - 1;
while (lower < upper) {
size_t mid = (lower + upper + 1) / 2;
if (pos_in_split_axis >= params.indices[mid])
lower = mid;
else
upper = mid - 1;
}
size_t section = upper;
size_t section_size = params.indices[section + 1] - params.indices[section];
LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
// output offsets: leading (axes pre-split-axis), at split-axis, and blocks_last_axis
size_t output_offset_leading =
(blockIdx.x / blocks_per_leading_dim) * section_size * tail_size * last_axis_size_aligned;
size_t output_offset_split_axis =
((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
((params.indices[section] - params.indices[0]) * tail_size)) *
last_axis_size_aligned;
size_t output_offset = output_offset_leading + output_offset_split_axis +
(blockIdx.x % blocks_last_axis) * blockDim.x;
if (threadIdx.x < last_axis_size_aligned) {
LType input_data = in_aligned[input_offset + threadIdx.x];
out_aligned[output_offset + threadIdx.x] = input_data;
}
}
}
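// Select the widest vectorized load type, returned as an mshadow type flag for
// MXNET_LOAD_TYPE_SWITCH, such that the last-axis size in bytes is a multiple of
// the load width and, when splitting the last axis, every section size is a
// multiple as well. Wider loads improve memory throughput.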
template <typename DType>
int get_load_type_split(size_t last_axis_size,
bool splitting_last_axis,
size_t n_sections,
size_t* indices) {
using namespace mshadow;
int sections_largest_multiple = 8;
if (splitting_last_axis) {
for (size_t i = 0; i < n_sections; ++i) {
size_t size_section = indices[i + 1] - indices[i];
if (size_section * sizeof(DType) % 8)
sections_largest_multiple = std::min(sections_largest_multiple, 4);
if (size_section * sizeof(DType) % 4)
sections_largest_multiple = std::min(sections_largest_multiple, 2);
if (size_section * sizeof(DType) % 2)
sections_largest_multiple = std::min(sections_largest_multiple, 1);
}
}
if (last_axis_size * sizeof(DType) % 8 == 0 && sections_largest_multiple == 8) {
return kFloat64;
} else if (last_axis_size * sizeof(DType) % 4 == 0 && sections_largest_multiple >= 4) {
return kFloat32;
} else if (last_axis_size * sizeof(DType) % 2 == 0 && sections_largest_multiple >= 2) {
return kFloat16;
} else {
return kUint8;
}
}
inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
const TBlob& input_data = inputs[split_enum::kData];
int real_axis = param.axis;
if (real_axis < 0) {
real_axis += input_data.ndim();
}
size_t last_axis_size = input_data.shape_[inputs[0].ndim() - 1];
size_t split_axis_size = input_data.shape_[real_axis];
size_t tail_size = 1; // does not include last dim
for (int i = real_axis + 1; i < input_data.ndim() - 1; ++i) {
tail_size *= input_data.shape_[i];
}
if (last_axis_size < 128) {
// the custom kernel is not efficient with fewer than 128 elements in the last axis
SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
} else {
Stream<gpu>* s = ctx.get_stream<gpu>();
CHECK_LT(real_axis, input_data.ndim());
const mxnet::TShape& ishape = input_data.shape_;
const mxnet::TShape split_pts =
(param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
std::vector<size_t> indices;
for (const auto& split_pos : split_pts) {
indices.push_back(split_pos);
}
if (param.sections == 0) {
indices.push_back(ishape[real_axis]);
}
size_t n_sections = indices.size() - 1;
bool splitting_last_axis = (real_axis == inputs[0].ndim() - 1);
for (size_t sections_processed = 0; sections_processed < n_sections;) {
size_t remaining_sections = n_sections - sections_processed;
MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
// set parameters
split_tensor_data<DType> params{};
params.num_sections = std::min<size_t>(remaining_sections, params.MaxSections);
params.inputs[0] = input_data.dptr<DType>();
for (size_t i = 0; i < params.num_sections; ++i) {
params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
params.indices[i] = indices[sections_processed + i];
}
params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
// load type: the last axis size must be a multiple of the load type width
// and, if splitting_last_axis, every section size must be as well
int ltype = get_load_type_split<DType>(
last_axis_size, splitting_last_axis, params.num_sections, params.indices);
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
CHECK_LE(sizeof(DType), sizeof(LType));
const size_t entries_per_load = sizeof(LType) / sizeof(DType);
size_t block_size = 32;
size_t max_threads_block = 256;
size_t last_axis_elements =
entries_per_load > 0 ? (last_axis_size / entries_per_load) : 0;
if (splitting_last_axis) {
// may not cover the whole axis if there are more sections than fit in one launch
last_axis_elements =
entries_per_load > 0 ?
((params.indices[params.num_sections] - params.indices[0]) / entries_per_load) :
0;
}
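// grow the block size in warp-sized (32-thread) steps, up to max_threads_block,
// so that a single block covers the aligned last-axis elements when possible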
while (block_size < last_axis_elements && (block_size < max_threads_block)) {
block_size += 32;
}
size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
size_t n_blocks = blocks_last_axis;
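// accumulate the grid size: blocks_last_axis blocks per last-axis row, for every
// combination of the remaining dimensions, with the split axis restricted to the
// sections handled in this launch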
for (int i = 0; i < input_data.ndim() - 1; ++i) {
if (i == real_axis) {
// may include fewer than all sections if there are more than fit in one launch
n_blocks *= (params.indices[params.num_sections] - params.indices[0]);
} else {
n_blocks *= input_data.shape_[i];
}
}
if (splitting_last_axis) {
split_tensor_kernel<true, LType>
<<<n_blocks, block_size, 0, s->stream_>>>(input_data.Size(),
params,
split_axis_size,
tail_size,
last_axis_size,
blocks_last_axis);
} else {
split_tensor_kernel<false, LType>
<<<n_blocks, block_size, 0, s->stream_>>>(input_data.Size(),
params,
split_axis_size,
tail_size,
last_axis_size,
blocks_last_axis);
}
});
sections_processed += params.num_sections;
});
}
}
}
NNVM_REGISTER_OP(Reshape).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(Flatten).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(transpose).set_attr<FCompute>("FCompute<gpu>", Transpose<gpu>);
NNVM_REGISTER_OP(expand_dims).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(slice)
.set_attr<FCompute>("FCompute<gpu>", SliceOpForward<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SliceEx<gpu>);
NNVM_REGISTER_OP(_backward_slice).set_attr<FCompute>("FCompute<gpu>", SliceOpBackward<gpu>);
NNVM_REGISTER_OP(_slice_assign).set_attr<FCompute>("FCompute<gpu>", SliceAssignOpForward<gpu>);
NNVM_REGISTER_OP(_slice_assign_scalar)
.set_attr<FCompute>("FCompute<gpu>", SliceAssignScalarOpForward<gpu>);
NNVM_REGISTER_OP(slice_axis).set_attr<FCompute>("FCompute<gpu>", SliceAxis<gpu>);
NNVM_REGISTER_OP(_backward_slice_axis).set_attr<FCompute>("FCompute<gpu>", SliceAxisGrad_<gpu>);
NNVM_REGISTER_OP(slice_like).set_attr<FCompute>("FCompute<gpu>", SliceLikeForward<gpu>);
NNVM_REGISTER_OP(_backward_slice_like).set_attr<FCompute>("FCompute<gpu>", SliceLikeBackward<gpu>);
NNVM_REGISTER_OP(clip)
.set_attr<FCompute>("FCompute<gpu>", Clip<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ClipEx<gpu>);
NNVM_REGISTER_OP(_backward_clip).set_attr<FCompute>("FCompute<gpu>", ClipGrad_<gpu>);
NNVM_REGISTER_OP(repeat).set_attr<FCompute>("FCompute<gpu>", RepeatOpForward<gpu>);
NNVM_REGISTER_OP(_backward_repeat).set_attr<FCompute>("FCompute<gpu>", RepeatOpBackward<gpu>);
NNVM_REGISTER_OP(tile).set_attr<FCompute>("FCompute<gpu>", TileOpForward<gpu>);
NNVM_REGISTER_OP(_backward_tile).set_attr<FCompute>("FCompute<gpu>", TileOpBackward<gpu>);
NNVM_REGISTER_OP(reverse)
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
[](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
NNVM_REGISTER_OP(_backward_reverse)
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
[](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
NNVM_REGISTER_OP(stack).set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
NNVM_REGISTER_OP(_backward_stack).set_attr<FCompute>("FCompute<gpu>", StackOpBackward<gpu>);
NNVM_REGISTER_OP(squeeze).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(_backward_squeeze)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(depth_to_space).set_attr<FCompute>("FCompute<gpu>", DepthToSpaceOpForward<gpu>);
NNVM_REGISTER_OP(space_to_depth).set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);
NNVM_REGISTER_OP(_split_v2)
// Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpForwardImpl
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
[](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);
NNVM_REGISTER_OP(_split_v2_backward)
// Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpBackwardImpl
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
[](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
} // namespace op
} // namespace mxnet