/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015-2017 by Contributors
* \file broadcast_reduce-inl.cuh
* \brief CUDA implementations for binary broadcast and reduce
* \author Antti-Pekka Hynninen
*/
#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
using namespace mshadow::cuda;
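// Computes out = OP(lhs, rhs) with broadcasting. Each thread handles `unroll` output
// elements per grid-stride iteration: the flat output index is unraveled against oshape
// and re-raveled with lstride/rstride to locate the matching lhs/rhs elements. Only the
// final stores are bounds-checked against N.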
template<int ndim, typename DType, typename OP, int unroll>
__launch_bounds__(kMaxThreadsPerBlock)
__global__ void binary_broadcast_kernel(const int N, const bool addto,
const DType* __restrict lhs,
const DType* __restrict rhs, DType *out,
const Shape<ndim> lstride, const Shape<ndim> rstride,
const Shape<ndim> oshape) {
for (int idx = blockIdx.x * blockDim.x * unroll + threadIdx.x; idx < N;
idx += blockDim.x * gridDim.x * unroll)
{
int j[unroll];
int k[unroll];
DType val[unroll];
#pragma unroll
for (int i=0;i < unroll;i++) {
unravel_dot(idx + i*blockDim.x, oshape, lstride, rstride, &j[i], &k[i]);
val[i] = OP::Map(lhs[j[i]], rhs[k[i]]);
}
#pragma unroll
for (int i=0;i < unroll;i++) {
if (idx + i*blockDim.x < N) assign(&out[idx + i*blockDim.x], addto, val[i]);
}
}
}
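// Host-side launcher: computes the broadcast strides of lhs and rhs with calc_stride and
// launches binary_broadcast_kernel over the whole output, unrolling by 2 elements per thread.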
template<int ndim, typename DType, typename OP>
void BinaryBroadcastComputeImpl(Stream<gpu> *s, const OpReqType req,
const TBlob& lhs, const TBlob& rhs, const TBlob& out) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
int N = out.shape_.Size();
const int warpSize = 32;
const int unroll = 2;
int nthread = std::min(kMaxThreadsPerBlock, ((N + warpSize - 1)/warpSize)*warpSize );
int ngrid = std::min(kBaseGridNum, (N + nthread*unroll - 1) / (nthread*unroll));
Shape<ndim> lstride = calc_stride(lhs.shape_.get<ndim>());
Shape<ndim> rstride = calc_stride(rhs.shape_.get<ndim>());
binary_broadcast_kernel<ndim, DType, OP, unroll><<<ngrid, nthread, 0, stream>>>(
N, req == kAddTo, lhs.dptr<DType>(), rhs.dptr<DType>(), out.dptr<DType>(), lstride, rstride,
out.shape_.get<ndim>());
}
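// Usage sketch (illustrative only; assumes mshadow_op::plus is available as OP): adding a
// rhs of shape [1, m] to a lhs of shape [n, m] into an out of shape [n, m], with all TBlob
// shapes already padded to ndim == 2:
//   BinaryBroadcastComputeImpl<2, float, mshadow_op::plus>(s, kWriteTo, lhs, rhs, out);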
const int nthread_reduce = kMaxThreadsPerBlock;
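// Main reduction kernel. N is the number of output (small) elements and M the number of
// big elements folded into each output. The reduction axis is split into Mnext chunks,
// one per blockIdx.y; the partial result for chunk m0 is written to small[idx + m0*N].
// Within a block, by threads cooperate on one output and (when by > 1) merge their partial
// values with a shared-memory tree reduction. When do_transpose is set, the roles of
// threadIdx.x and threadIdx.y are swapped so that the fastest-varying threads walk the
// reduction axis and loads stay coalesced.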
template<typename Reducer, int ndim, typename DType, typename OP, int unroll>
__launch_bounds__(nthread_reduce)
__global__ void reduce_kernel(const int N, const int M, const bool addto,
const DType* __restrict big, DType *small,
const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
const Shape<ndim> big_shape, const Shape<ndim> big_stride,
const int Mnext, const bool do_transpose) {
extern __shared__ char shTileChar[];
DType* shTile = (DType*)(shTileChar);
const int tid = threadIdx.x + threadIdx.y*blockDim.x;
const int bx = (do_transpose) ? blockDim.y : blockDim.x;
const int by = (do_transpose) ? blockDim.x : blockDim.y;
const int tidx = (do_transpose) ? tid / by : threadIdx.x;
const int tidy = (do_transpose) ? tid % by : threadIdx.y;
for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
// This TB handles M range [Mstart, ...., Mend - 1]
const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
int idx = idx0 + tidx;
Shape<ndim> coord = unravel(idx, small_shape);
int idx_big0 = ravel(coord, big_shape0);
DType val, residual;
Reducer::SetInitValue(val, residual);
if (idx < N) {
for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
int idx_big[unroll];
#pragma unroll
for (int u=0;u < unroll;u++) {
idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
}
DType tmp[unroll];
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) {
tmp[u] = OP::Map(big[idx_big[u]]);
}
}
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
}
}
}
// Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
if (by > 1) {
// Fix bx to avoid bank conflicts. Assumes warpSize number of banks
const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
const int it0 = tidx + tidy*fbx;
shTile[it0 * 2] = val;
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
}
__syncthreads();
Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
__syncthreads();
}
if (idx < N && tidy == 0) {
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
}
}
}
}
}
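// Variant of reduce_kernel that fuses an elementwise binary op into the reduction:
// instead of OP(big[...]), each reduced term is OP1(big[...], OP2(lhs[...], rhs[...])).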
template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
__launch_bounds__(nthread_reduce)
__global__ void reduce_kernel(const int N, const int M, const bool addto,
const DType* __restrict big, const DType* __restrict lhs,
const DType* __restrict rhs, DType *small,
const Shape<ndim> big_shape0, const Shape<ndim> lhs_shape0,
const Shape<ndim> rhs_shape0, const Shape<ndim> small_shape,
const Shape<ndim> big_shape, const Shape<ndim> lhs_shape,
const Shape<ndim> rhs_shape, const Shape<ndim> big_stride,
const Shape<ndim> lhs_stride, const Shape<ndim> rhs_stride,
const int Mnext, const bool do_transpose) {
extern __shared__ char shTileChar[];
DType* shTile = (DType*)(shTileChar);
const int tid = threadIdx.x + threadIdx.y*blockDim.x;
const int bx = (do_transpose) ? blockDim.y : blockDim.x;
const int by = (do_transpose) ? blockDim.x : blockDim.y;
const int tidx = (do_transpose) ? tid / by : threadIdx.x;
const int tidy = (do_transpose) ? tid % by : threadIdx.y;
for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
// This TB handles M range [Mstart, ...., Mend - 1]
const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
int idx = idx0 + tidx;
Shape<ndim> coord = unravel(idx, small_shape);
int idx_big0 = ravel(coord, big_shape0);
int idx_lhs0 = ravel(coord, lhs_shape0);
int idx_rhs0 = ravel(coord, rhs_shape0);
DType val, residual;
Reducer::SetInitValue(val, residual);
if (idx < N) {
for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
int idx_big[unroll];
int idx_lhs[unroll];
int idx_rhs[unroll];
#pragma unroll
for (int u=0;u < unroll;u++) {
idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride);
idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride);
}
DType tmp[unroll];
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) {
tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]]));
}
}
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
}
}
}
// Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
if (by > 1) {
// Fix bx to avoid bank conflicts. Assumes warpSize number of banks
const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
const int it0 = tidx + tidy*fbx;
shTile[it0 * 2] = val;
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
}
__syncthreads();
Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
__syncthreads();
}
if (idx < N && tidy == 0) {
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
}
}
}
}
}
// Simple reduction of lines when M is small
template<typename Reducer, typename DType>
__launch_bounds__(kMaxThreadsPerBlock)
__global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
const int small_in_stride, const DType* __restrict small_in, DType *small_out) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
DType val, residual;
Reducer::SetInitValue(val, residual);
for (int k = 0; k < M; k++) {
Reducer::Reduce(val, small_in[idx + k*small_in_stride], residual);
}
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small_out[idx], addto, val);
}
}
}
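// Special case M == 1: there is nothing to reduce over, so each output element is just
// OP applied to the corresponding input element (still routed through the Reducer so that
// Finalize and the addto semantics apply).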
template<typename Reducer, int ndim, typename DType, typename OP>
__global__ void reduce_kernel_M1(const int N, const bool addto,
const DType* __restrict big, DType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
int j = ravel(coord, bshape);
DType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP::Map(big[j]), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}
}
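// M == 1 special case of the fused variant: computes OP1(big, OP2(lhs, rhs)) element-wise.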
template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
__global__ void reduce_kernel_M1(const int N, const bool addto,
const DType* __restrict big,
const DType* __restrict lhs,
const DType* __restrict rhs,
DType *small,
const Shape<ndim> big_shape,
const Shape<ndim> lhs_shape,
const Shape<ndim> rhs_shape,
const Shape<ndim> small_shape) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, small_shape);
int idx_big = ravel(coord, big_shape);
int idx_lhs = ravel(coord, lhs_shape);
int idx_rhs = ravel(coord, rhs_shape);
DType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}
}
// Returns the stride with which the fastest dimension is moving.
// Used to detect memory access scatter.
template<int ndim>
MSHADOW_XINLINE int fastest_stride(const Shape<ndim>& small, const Shape<ndim>& big,
const Shape<ndim>& big_stride) {
for (int i = ndim-1; i >= 0; --i) {
if (big[i] != 1) {
return (small[i] == big[i]) ? 1 : big_stride[i];
}
}
return 1;
}
// Returns a/b integer division rounded up
template<typename Type>
Type ceil_idiv(const Type a, const Type b) {
return (a + b - 1)/b;
}
// Configuration for ReduceImpl()
template<int ndim>
struct ReduceImplConfig {
static const int warpSize = 32;
static const int unroll_reduce = 2;
static const int maxLoopPerTB = 64;
int N;
int M;
int Mnext;
struct {
dim3 blockDim;
dim3 gridDim;
int shMemSize;
bool do_transpose;
} kernel_1;
struct {
int blockSize;
int gridSize;
} kernel_2;
size_t workspace_size;
Shape<ndim> rshape, rstride;
Shape<ndim> lhs_shape, lhs_stride;
Shape<ndim> rhs_shape, rhs_stride;
};
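// Rough estimate of the number of memory transactions issued when X threads each make Y
// passes over three inputs whose fastest-moving strides are given in strides[]: a unit
// stride costs about one transaction per warp, larger strides proportionally more, capped
// at one per thread. Used below to choose between the transposed and non-transposed layouts.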
static inline uint64_t calc_num_load(const int X, const int Y, const int* strides) {
const int warpSize = ReduceImplConfig<1>::warpSize;
// Number of full warps
uint64_t num_full_warp = X / warpSize;
// Length of the partial warp i.e. number of threads that are performing loads
uint64_t len_part_warp = X % warpSize;
uint64_t num_load_full = (std::min(warpSize, strides[0]) +
std::min(warpSize, strides[1]) +
std::min(warpSize, strides[2]))*num_full_warp;
uint64_t num_load_part =
(std::min(len_part_warp, ceil_idiv<uint64_t>(len_part_warp*strides[0], warpSize)) +
std::min(len_part_warp, ceil_idiv<uint64_t>(len_part_warp*strides[1], warpSize)) +
std::min(len_part_warp, ceil_idiv<uint64_t>(len_part_warp*strides[2], warpSize)))*
(len_part_warp != 0);
uint64_t num_load = (num_load_full + num_load_part)*(uint64_t)Y;
return num_load;
}
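// Builds the launch configuration for ReduceImpl(): picks between the transposed and
// non-transposed thread layouts based on the estimated number of loads, chooses block and
// grid dimensions and the shared-memory size, splits the reduction axis into Mnext chunks,
// and computes the workspace needed for partial results when Mnext > 1.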
template<int ndim, typename DType>
ReduceImplConfig<ndim> ConfigureReduceImpl(const mxnet::TShape& small, const mxnet::TShape& big,
                                           const mxnet::TShape* lhs, const mxnet::TShape* rhs) {
ReduceImplConfig<ndim> config;
diff(small.get<ndim>(), big.get<ndim>(), &config.rshape, &config.rstride);
config.N = small.Size();
config.M = config.rshape.Size();
bool multiOp = false;
if (lhs != NULL) {
CHECK_NOTNULL(rhs);
diff(small.get<ndim>(), lhs->get<ndim>(), &config.lhs_shape,
&config.lhs_stride);
diff(small.get<ndim>(), rhs->get<ndim>(), &config.rhs_shape,
&config.rhs_stride);
multiOp = true;
}
config.workspace_size = 0;
if (config.M == 1) {
config.kernel_1.blockDim.x = kMaxThreadsPerBlock;
config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum,
(config.N + config.kernel_1.blockDim.x - 1)/config.kernel_1.blockDim.x);
} else {
int reduce_strides[3];
reduce_strides[0] = fastest_stride(small.get<ndim>(), big.get<ndim>(),
big.get<ndim>());
reduce_strides[1] = (multiOp) ? fastest_stride(small.get<ndim>(),
lhs->get<ndim>(), lhs->get<ndim>()) : 1;
reduce_strides[2] = (multiOp) ? fastest_stride(small.get<ndim>(),
rhs->get<ndim>(), rhs->get<ndim>()) : 1;
int reduce_strides_transp[3];
reduce_strides_transp[0] = fastest_stride(small.get<ndim>(), config.rshape,
config.rstride);
reduce_strides_transp[1] = (multiOp) ?
fastest_stride(small.get<ndim>(), config.lhs_shape, config.lhs_stride) : 1;
reduce_strides_transp[2] = (multiOp) ?
fastest_stride(small.get<ndim>(), config.rhs_shape, config.rhs_stride) : 1;
uint64_t num_load = calc_num_load(config.N, config.M, reduce_strides);
uint64_t num_load_transp = calc_num_load(config.M, config.N, reduce_strides_transp);
config.Mnext = 1;
config.kernel_1.do_transpose = (num_load > num_load_transp);
config.kernel_1.blockDim.x = 0;
config.kernel_1.blockDim.y = 0;
if (config.kernel_1.do_transpose) {
// Fastest thread ID goes through M
// Loop over N has step size config.kernel_1.blockDim.y
if (config.N < 8) {
config.kernel_1.blockDim.y = 1;
} else if (config.N < 256) {
config.kernel_1.blockDim.y = 4;
} else {
if (config.M < 8) {
config.kernel_1.blockDim.x = 1;
} else if (config.M < 256) {
config.kernel_1.blockDim.x = 4;
} else {
config.kernel_1.blockDim.x = config.warpSize;
}
}
} else {
// Fastest thread ID goes through N
// Loop over M has step size config.kernel_1.blockDim.y
if (config.M < 8) {
config.kernel_1.blockDim.y = 1;
} else if (config.M < 256) {
config.kernel_1.blockDim.y = 4;
} else {
if (config.N < 8) {
config.kernel_1.blockDim.x = 1;
} else if (config.N < 256) {
config.kernel_1.blockDim.x = 4;
} else {
config.kernel_1.blockDim.x = config.warpSize;
}
}
}
if (config.kernel_1.blockDim.x == 0 && config.kernel_1.blockDim.y == 0) {
LOG(FATAL) << "Unable to set blockDim";
} else if (config.kernel_1.blockDim.x == 0) {
config.kernel_1.blockDim.x = nthread_reduce / config.kernel_1.blockDim.y;
} else if (config.kernel_1.blockDim.y == 0) {
config.kernel_1.blockDim.y = nthread_reduce / config.kernel_1.blockDim.x;
}
if (config.kernel_1.do_transpose) {
// Fastest thread ID goes through M
config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum,
ceil_idiv<unsigned int>(config.N, config.kernel_1.blockDim.y));
config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext);
int by = config.kernel_1.blockDim.y;
if (config.kernel_1.blockDim.y % config.warpSize == 0) {
// Fix shared memory bank conflict
by++;
}
config.kernel_1.shMemSize = (config.kernel_1.blockDim.x > 1) ?
config.kernel_1.blockDim.x*by*sizeof(DType) * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = config.kernel_1.blockDim.x*config.maxLoopPerTB;
config.Mnext = (config.M + maxMblock - 1) / maxMblock;
} else {
// Fastest thread ID goes through N
config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum,
ceil_idiv<unsigned int>(config.N, config.kernel_1.blockDim.x));
config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext);
config.kernel_1.shMemSize = (config.kernel_1.blockDim.y > 1) ?
config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = config.kernel_1.blockDim.y*config.maxLoopPerTB;
config.Mnext = (config.M + maxMblock - 1) / maxMblock;
}
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
config.workspace_size += config.N*config.Mnext*sizeof(DType);
// Set gridDim.y to Mnext
config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext);
}
if (config.Mnext > 1) {
config.kernel_2.blockSize = kMaxThreadsPerBlock;
config.kernel_2.gridSize = std::min((int)kBaseGridNum,
(config.N + config.kernel_2.blockSize - 1)/config.kernel_2.blockSize );
}
}
return config;
}
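// Instantiates the kernel launch twice, once with unrollVar == unrollAmount and once with
// unrollVar == 1, and selects between the two instantiations at runtime via do_unroll.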
#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
if (do_unroll) { \
const int unrollVar = unrollAmount; \
{__VA_ARGS__} \
} else { \
const int unrollVar = 1; \
{__VA_ARGS__} \
}
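// Runs the reduction big -> small. For M == 1 the trivial kernel is used. Otherwise
// reduce_kernel either writes the final result directly (Mnext == 1) or writes Mnext
// partial results per output element into `workspace`, which reduce_lines_kernel then
// folds into `small`.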
template<typename Reducer, int ndim, typename DType, typename OP>
void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig<ndim>& config) {
if (config.M == 1) {
reduce_kernel_M1<Reducer, ndim, DType, OP>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
small.shape_.get<ndim>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
} else {
DType* small_dptr = small.dptr<DType>();
bool addto = (req == kAddTo);
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
addto = false;
// Check that the workspace is contiguous
CHECK_EQ(workspace.CheckContiguous(), true);
// Check that we have enough storage
CHECK_GE(workspace.size(0), config.workspace_size);
}
const int by = (config.kernel_1.do_transpose) ?
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce );
KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig<ndim>::unroll_reduce, UNROLL, {
reduce_kernel<Reducer, ndim, DType, OP, UNROLL>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
small.shape_.get<ndim>(), config.rshape, config.rstride, config.Mnext,
config.kernel_1.do_transpose);
});
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
if (config.Mnext > 1) {
reduce_lines_kernel<Reducer, DType>
<<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
}
}
}
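// Fused variant: same two-stage strategy, but reduces OP1(big, OP2(lhs, rhs)).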
template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs,
const OpReqType req, const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig<ndim>& config) {
if (config.M == 1) {
reduce_kernel_M1<Reducer, ndim, DType, OP1, OP2>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
} else {
DType* small_dptr = small.dptr<DType>();
bool addto = (req == kAddTo);
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
addto = false;
// Check that the workspace is contiguous
CHECK_EQ(workspace.CheckContiguous(), true);
// Check that we have enough storage
CHECK_GE(workspace.size(0), config.workspace_size);
}
const int by = (config.kernel_1.do_transpose) ?
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce );
KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig<ndim>::unroll_reduce, UNROLL, {
reduce_kernel<Reducer, ndim, DType, OP1, OP2, UNROLL>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
config.N, config.M, addto, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
small_dptr, big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape, config.lhs_shape,
config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext,
config.kernel_1.do_transpose);
});
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
if (config.Mnext > 1) {
reduce_lines_kernel<Reducer, DType>
<<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
}
}
}
#undef KERNEL_UNROLL_SWITCH
template<typename Reducer, int ndim, typename DType, typename OP>
void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, DType, OP>(stream, small, req, big, workspace, config);
}
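// Usage sketch (illustrative only; assumes mshadow::red::sum and mshadow_op::identity are
// available as Reducer and OP): summing a [n, m] `big` into an [n, 1] `small`:
//   size_t ws_bytes = ReduceWorkspaceSize<2, float>(s, small.shape_, kWriteTo, big.shape_);
//   // ... obtain a Tensor<gpu, 1, char> `workspace` of at least ws_bytes bytes ...
//   Reduce<mshadow::red::sum, 2, float, mshadow_op::identity>(s, small, kWriteTo, workspace, big);

// ReduceWithExtraMem is an empty stub on the GPU path.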
template <typename Reducer, int ndim, typename DType, typename OP>
void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {}
template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big,
const TBlob& lhs, const TBlob& rhs) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_);
ReduceImpl<Reducer, ndim, DType, OP1, OP2>(stream, small, lhs, rhs, req, big, workspace, config);
}
template<int ndim, typename DType>
size_t ReduceWorkspaceSize(Stream<gpu> *s, const mxnet::TShape& small, const OpReqType req,
const mxnet::TShape& big) {
if (req == kNullOp) return 0;
ReduceImplConfig<ndim> config = ConfigureReduceImpl<ndim, DType>(small, big, NULL, NULL);
return config.workspace_size;
}
template<int ndim, typename DType>
size_t ReduceWorkspaceSize(Stream<gpu> *s, const mxnet::TShape& small, const OpReqType req,
const mxnet::TShape& big, const mxnet::TShape& lhs, const mxnet::TShape& rhs) {
if (req == kNullOp) return 0;
ReduceImplConfig<ndim> config = ConfigureReduceImpl<ndim, DType>(small, big, &lhs, &rhs);
return config.workspace_size;
}
#endif  // MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_