/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
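// Runtime-compiled (RTC) GPU implementations of elementwise binary broadcast operators:
// the forward compute path plus the UseNone and UseIn backward passes.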
#include <string>
#if MXNET_USE_CUDA
#include <cuda_runtime.h>
#endif // MXNET_USE_CUDA
#include "broadcast_reduce-inl.h"
#include "elemwise_binary_broadcast_op.h"
#if MXNET_USE_CUDA
#include "../../common/cuda/rtc/vectorization-inl.h"
#include "../../common/cuda/rtc.h"
#endif // MXNET_USE_CUDA
namespace mxnet {
namespace op {
#if MXNET_USE_CUDA
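// Parameter pack passed by value to the RTC broadcast kernels below. A mirror of this
// struct is declared inside each kernel code string so the host and device layouts match.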
struct binary_broadcast_params {
  const void* inputs[2];
  void* outputs[1];
  index_t stride[2][broadcast::MAX_DIM];
  index_t oshape[broadcast::MAX_DIM];
  index_t size[2];
};
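// Kernel code string, compiled at runtime by the RTC infrastructure, for the general case:
// both inputs are read through vectorized loads and the broadcast index of each input is
// recovered from the output coordinate with util::unravel_dot and the per-input stride tables.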
const char broadcast_kernel_fwd[] = R"code(
struct binary_broadcast_params {
  const void* inputs[2];
  void* outputs[1];
  index_t stride[2][util::MAX_DIM];
  index_t oshape[util::MAX_DIM];
  index_t size[2];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_broadcast_kernel(
    const binary_broadcast_params param,
    const index_t lead_dim,
    const index_t other_dim,
    const index_t N,
    const index_t num_aligned_elements) {
  using namespace vector;
  const index_t M = num_aligned_elements * other_dim;
  VectorizedLoader<InputType0, nvec, aligned> lloader(
    reinterpret_cast<const InputType0*>(param.inputs[0]), param.size[0]);
  VectorizedLoader<InputType1, nvec, aligned> rloader(
    reinterpret_cast<const InputType1*>(param.inputs[1]), param.size[1]);
  using IType0 = AccType<InputType0>;
  using IType1 = AccType<InputType1>;
  using OType = AccType<OutputType0>;
  for (index_t idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < M;
       idx += gridDim.x * blockDim.x) {
    OutputType0 * current_output_pointer;
    index_t output_size;
    index_t output_idx;
    if (aligned) {
      // Simplified case
      index_t lindex, rindex;
      util::unravel_dot<ndim>(idx * nvec, param.oshape,
                              param.stride[0], param.stride[1],
                              &lindex, &rindex);
      lloader.load(lindex / nvec, param.size[0]);
      rloader.load(rindex / nvec, param.size[1]);
      current_output_pointer = reinterpret_cast<OutputType0*>(param.outputs[0]);
      output_size = N;
      output_idx = idx;
    } else {
      const index_t row = idx / num_aligned_elements;
      const index_t lead_dim_idx = idx - row * num_aligned_elements;
      index_t lindex, rindex;
      const index_t original_idx = max(lead_dim_idx * nvec - lloader.alignment(),
                                       static_cast<index_t>(0)) +
                                   row * lead_dim;
      util::unravel_dot<ndim>(original_idx, param.oshape,
                              param.stride[0], param.stride[1],
                              &lindex, &rindex);
      lloader.load((lindex + lloader.alignment()) / nvec, param.size[0]);
      rloader.load((rindex + lloader.alignment()) / nvec, param.size[1]);
      current_output_pointer = reinterpret_cast<OutputType0*>(param.outputs[0]) + row * lead_dim;
      output_size = lead_dim;
      output_idx = lead_dim_idx;
    }
    VectorizedStorer<OutputType0, nvec, aligned> storer(current_output_pointer, output_size);
    if (req == OpReqType::kAddTo) {
      storer.load(output_idx, output_size);
    }
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const auto temp = OP(IType0::from(lloader.separate()[i]),
                           IType1::from(rloader.separate()[i]));
      if (req == OpReqType::kAddTo) {
        const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
        storer.separate()[i] = OType::to(temp2);
      } else {
        storer.separate()[i] = OType::to(temp);
      }
    }
    storer.store(output_idx, output_size);
  }
}
)code";
const char single_side_broadcast_kernel_fwd[] = R"code(
struct binary_broadcast_params {
  const void* inputs[2];
  void* outputs[1];
  index_t stride[2][util::MAX_DIM];
  index_t oshape[util::MAX_DIM];
  index_t size[2];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void single_side_binary_broadcast_kernel(
    const binary_broadcast_params param,
    const index_t lead_dim,
    const index_t other_dim,
    const index_t N,
    const index_t num_aligned_elements) {
  using namespace vector;
  const index_t M = num_aligned_elements * other_dim;
  constexpr int other_side = 1 - side;
  VectorizedLoader<DType, nvec, aligned> lloader(
    reinterpret_cast<const DType*>(param.inputs[side]), param.size[side]);
  using IType = AccType<DType>;
  using IType2 = AccType<DType2>;
  using OType = AccType<OutputType0>;
  for (index_t idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < M;
       idx += gridDim.x * blockDim.x) {
    index_t original_idx;
    OutputType0 * current_output_pointer;
    index_t output_size;
    index_t output_idx;
    if (aligned) {
      // Simplified case
      original_idx = idx * nvec;
      const index_t lindex = util::unravel_dot<ndim>(original_idx, param.oshape,
                                                     param.stride[side]);
      lloader.load(lindex / nvec, param.size[side]);
      current_output_pointer = reinterpret_cast<OutputType0*>(param.outputs[0]);
      output_size = N;
      output_idx = idx;
    } else {
      const index_t row = idx / num_aligned_elements;
      const index_t lead_dim_idx = idx - row * num_aligned_elements;
      original_idx = lead_dim_idx * nvec -
                     lloader.alignment() + row * lead_dim;
      const index_t original_idx_clamped = max(lead_dim_idx * nvec - lloader.alignment(),
                                               static_cast<index_t>(0)) +
                                           row * lead_dim;
      const index_t lindex = util::unravel_dot<ndim>(original_idx_clamped, param.oshape,
                                                     param.stride[side]);
      lloader.load((lindex + lloader.alignment()) / nvec, param.size[side]);
      current_output_pointer = reinterpret_cast<OutputType0*>(param.outputs[0]) + row * lead_dim;
      output_size = lead_dim;
      output_idx = lead_dim_idx;
    }
    VectorizedStorer<OutputType0, nvec, aligned> storer(current_output_pointer, output_size);
    if (req == OpReqType::kAddTo) {
      storer.load(output_idx, output_size);
    }
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const index_t rindex = min(max(util::unravel_dot<ndim>(original_idx + i,
                                                             param.oshape,
                                                             param.stride[other_side]),
                                     static_cast<index_t>(0)),
                                 param.size[other_side] - 1);
      const auto rinput = IType2::from(
        reinterpret_cast<const DType2*>(param.inputs[other_side])
        [rindex]);
      typename OType::type temp;
      if (side == 0) {
        // Left side is vectorized
        temp = OP(IType::from(lloader.separate()[i]),
                  rinput);
      } else {
        // Right side is vectorized
        temp = OP(rinput,
                  IType::from(lloader.separate()[i]));
      }
      if (req == OpReqType::kAddTo) {
        const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
        storer.separate()[i] = OType::to(temp2);
      } else {
        storer.separate()[i] = OType::to(temp);
      }
    }
    storer.store(output_idx, output_size);
  }
}
)code";
namespace {
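// Compute row-major strides for `shape`; dimensions of size 1 get stride 0 so that
// unravel_dot maps every output coordinate along a broadcast axis back to the same element.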
std::vector<index_t> calc_stride(const mxnet::TShape& shape, int ndim) {
  CHECK_EQ(ndim, shape.ndim());
  std::vector<index_t> stride(ndim);
  index_t cumprod = 1;
  for (int i = shape.ndim() - 1; i >= 0; --i) {
    stride[i] = (shape[i] > 1) ? cumprod : 0;
    cumprod *= shape[i];
  }
  return stride;
}
} // namespace
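// Forward pass: compact the broadcast shapes, fall back to the plain elementwise RTC kernel
// when no broadcasting is left, and otherwise fill the parameter struct and launch one of the
// two runtime-compiled kernels above.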
void BinaryBroadcastRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
                                           const std::vector<TBlob>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (outputs[0].shape_.Size() == 0U)
    return;
  if (req[0] == kNullOp)
    return;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  mxnet::TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(
      inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape);
  // Pad the ndim
  BROADCAST_NDIM_SWITCH(ndim, NDim, {
    if (ndim != 0) {
      ndim = NDim;
    }
  });
  if (!ndim) {
    ElemwiseBinaryRTCCompute{OP}(attrs, ctx, inputs, req, outputs); // NOLINT
  } else {
    mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
    const TBlob& lhs = inputs[0].reshape(new_lshape);
    const TBlob& rhs = inputs[1].reshape(new_rshape);
    const TBlob& output = outputs[0].reshape(new_oshape);
    const auto& lstride = calc_stride(lhs.shape_, ndim);
    const auto& rstride = calc_stride(rhs.shape_, ndim);
    size_t output_type_size = common::mshadow_type_info(outputs[0].type_flag_).size;
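    // Vectorize in units of up to 8 bytes: nvec is how many output elements fit in a
    // uint64_t-sized access (1 for element types wider than 8 bytes).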
    const int nvec =
        output_type_size <= sizeof(uint64_t) ? (sizeof(uint64_t) / output_type_size) : 1;
    binary_broadcast_params params{};
    params.inputs[0] = lhs.dptr_;
    params.inputs[1] = rhs.dptr_;
    params.outputs[0] = output.dptr_;
    for (int i = 0; i < ndim; ++i) {
      params.stride[0][i] = lstride[i];
      params.stride[1][i] = rstride[i];
      params.oshape[i] = new_oshape[i];
    }
    params.size[0] = lhs.shape_.Size();
    params.size[1] = rhs.shape_.Size();
    index_t lead_dim = 1;
    for (int i = ndim - 1; i >= 0; --i) {
      /* Find the first non-1 dimension
         to check the alignment
      */
      if (params.oshape[i] != 1) {
        lead_dim = params.oshape[i];
        break;
      }
    }
    const index_t other_dim = output.shape_.Size() / lead_dim;
    int first_different = -1;
    int common_shape = 1;
    for (int i = ndim - 1; i >= 0; --i) {
      if (params.stride[0][i] == params.stride[1][i]) {
        common_shape *= params.oshape[i];
      } else {
        first_different = i;
        break;
      }
    }
    int lead_input_num = 0;
    std::string code = std::string("const OpReqType req = ") + util::to_string(req[0]) +
                       ";\n"
                       "#define OP op::" +
                       OP +
                       "\n"
                       "const int ndim = " +
                       std::to_string(ndim) + ";\n";
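    // If the innermost dimensions have matching strides for both inputs (common_shape != 1),
    // both sides can be read with vectorized loads; otherwise vectorize only the side that is
    // not broadcast in the first dimension where the strides differ.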
    if (common_shape != 1) {
      VectorizedKernelRTCLauncher(code,
                                  "binary_broadcast_kernel",
                                  broadcast_kernel_fwd,
                                  nvec,
                                  lead_dim,
                                  other_dim,
                                  s,
                                  params,
                                  inputs,
                                  outputs,
                                  ctx.run_ctx.get_ctx().dev_id,
                                  lead_input_num);
    } else {
      if (params.stride[0][first_different] == 0) {
        lead_input_num = 1;
        code +=
            "const int side = 1;\n"
            "using DType = InputType1;\n"
            "using DType2 = InputType0;\n";
      } else {
        code +=
            "const int side = 0;\n"
            "using DType = InputType0;\n"
            "using DType2 = InputType1;\n";
      }
      VectorizedKernelRTCLauncher(code,
                                  "single_side_binary_broadcast_kernel",
                                  single_side_broadcast_kernel_fwd,
                                  nvec,
                                  lead_dim,
                                  other_dim,
                                  s,
                                  params,
                                  inputs,
                                  outputs,
                                  ctx.run_ctx.get_ctx().dev_id,
                                  lead_input_num);
    }
  }
}
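// Backward pass for operators whose gradients do not need the forward inputs (UseNone):
// each input gradient is the output gradient mapped through LOP / ROP and sum-reduced over
// that input's broadcast axes; a zero-size ograd just zero-fills any non-empty gradients.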
void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs,
                                                   const OpContext& ctx,
                                                   const std::vector<TBlob>& inputs,
                                                   const std::vector<OpReqType>& req,
                                                   const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);
  mxnet::TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_,
                                         outputs[1].shape_,
                                         inputs[0].shape_,
                                         &new_lshape,
                                         &new_rshape,
                                         &new_oshape);
  if (!ndim) {
    ElemwiseBinaryRTCBwdUseNone{LOP, ROP}(attrs, ctx, inputs, req, outputs); // NOLINT
  } else {
    Stream<gpu>* s = ctx.get_stream<gpu>();
    const TBlob lhs = outputs[0].reshape(new_lshape);
    const TBlob rhs = outputs[1].reshape(new_rshape);
    const TBlob out = inputs[0].reshape(new_oshape);
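    // lgrad = sum of LOP(ograd) over lhs's broadcast axes; rgrad likewise with ROP.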
    BROADCAST_NDIM_SWITCH(ndim, NDim, {
      // Request temporary storage
      size_t workspace_size = new_oshape.Size();
      Tensor<gpu, 1, char> workspace = ctx.requested[0].get_space_typed<gpu, 1, char>(
          Shape1(workspace_size * sizeof(index_t)), s);
      if (out.shape_.Size() != 0) {
        broadcast::RTCReduce(ctx, lhs, req[0], workspace, out, "red::sum{}", NDim, LOP);
        broadcast::RTCReduce(ctx, rhs, req[1], workspace, out, "red::sum{}", NDim, ROP);
      } else {
        using namespace common::cuda::rtc::util;
        if (lhs.shape_.Size() != 0) {
          cudaMemsetAsync(lhs.dptr_,
                          0,
                          lhs.shape_.Size() * common::mshadow_type_info(lhs.type_flag_).size,
                          Stream<gpu>::GetStream(s));
        }
        if (rhs.shape_.Size() != 0) {
          cudaMemsetAsync(rhs.dptr_,
                          0,
                          rhs.shape_.Size() * common::mshadow_type_info(rhs.type_flag_).size,
                          Stream<gpu>::GetStream(s));
        }
      }
    });
  }
}
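// Backward pass for operators whose gradients need the forward inputs (UseIn): each input
// gradient is ograd combined ("mul") with the corresponding partial derivative (LOP / ROP
// applied to lhs and rhs) and then sum-reduced over that input's broadcast axes.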
void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs,
                                                 const OpContext& ctx,
                                                 const std::vector<TBlob>& inputs,
                                                 const std::vector<OpReqType>& req,
                                                 const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);
  // skip kernel launch for zero-size tensors
  if (inputs[0].shape_.Size() == 0U) {
    return;
  }
  mxnet::TShape new_lshape, new_rshape, new_oshape;
  const bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_,
                                                   outputs[1].shape_,
                                                   inputs[0].shape_,
                                                   &new_lshape,
                                                   &new_rshape,
                                                   &new_oshape) != 0;
  if (!need_bc) {
    ElemwiseBinaryRTCBwdUseIn{LOP, ROP}(attrs, ctx, inputs, req, outputs); // NOLINT
  } else {
    BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
      using namespace mshadow;
      Stream<gpu>* s = ctx.get_stream<gpu>();
      const TBlob lgrad = outputs[0].reshape(new_lshape);
      const TBlob rgrad = outputs[1].reshape(new_rshape);
      const TBlob ograd = inputs[0].reshape(new_oshape);
      const TBlob lhs = inputs[1].reshape(new_lshape);
      const TBlob rhs = inputs[2].reshape(new_rshape);
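      // Size the shared workspace for whichever of the two reductions needs more scratch space.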
      size_t workspace_size_l = broadcast::ReduceWorkspaceSize(
          s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
      size_t workspace_size_r = broadcast::ReduceWorkspaceSize(
          s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
      Tensor<gpu, 1, char> workspace =
          ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
      if (req[0] != kNullOp) {
        broadcast::RTCReduce(
            ctx, lgrad, req[0], workspace, ograd, lhs, rhs, "red::sum{}", NDim, "mul", LOP);
      }
      if (req[1] != kNullOp) {
        broadcast::RTCReduce(
            ctx, rgrad, req[1], workspace, ograd, lhs, rhs, "red::sum{}", NDim, "mul", ROP);
      }
    });
  }
}
#endif // MXNET_USE_CUDA
} // namespace op
} // namespace mxnet