/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "elemwise_binary_scalar_op.h"
#if MXNET_USE_CUDA
#include "../../common/cuda/rtc/vectorization-inl.h"
#include "../../common/cuda/rtc.h"
#endif // MXNET_USE_CUDA
namespace mxnet {
namespace op {
#if MXNET_USE_CUDA
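
// Host-side copy of the parameter struct that is passed by value to the
// RTC-generated kernels below. Its layout must stay bit-identical to the
// struct redeclared inside the kernel source strings.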
struct binary_scalar_kernel_params {
  const void* inputs[2];
  void* outputs[1];
  double scalar;
};
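
// Source for the forward kernel, compiled with NVRTC at first use.
// InputType0, OutputType0, nvec, aligned, OP and req are not defined here;
// they are supplied at compile time by the launcher and by the `code`
// preamble built in BinaryScalarRTCCompute::operator() below.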
const char binary_scalar_kernel_fwd[] = R"code(

struct binary_scalar_kernel_params {
  const void *inputs[2];
  void *outputs[1];
  double scalar;
};

__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_scalar_kernel(const binary_scalar_kernel_params params,
                                     const index_t lead_dim,
                                     const index_t other_dim,
                                     const index_t N,
                                     const index_t num_aligned_elements) {
  using namespace vector;
  using type_util::mixed_type;
  VectorizedLoader<InputType0, nvec, aligned> loader(
      reinterpret_cast<const InputType0*>(params.inputs[0]), N);
  VectorizedStorer<OutputType0, nvec, aligned> storer(
      reinterpret_cast<OutputType0*>(params.outputs[0]), N);

  using IType = AccType<InputType0>;
  using OType = AccType<OutputType0>;

  const index_t M = num_aligned_elements;
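
  // Grid-stride loop over aligned vectors; every iteration loads, transforms
  // and stores nvec consecutive elements.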
  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    loader.load(tid, N);
    if (req == OpReqType::kAddTo) {
      storer.load(tid, N);
    }
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const auto input = IType::from(loader.separate()[i]);
      // the cast to mixed_type enables OP to return a different type
      const auto temp = OP(input,
                           static_cast<mixed_type<typename IType::type,
                                                  typename OType::type>>(params.scalar));
      if (req == OpReqType::kAddTo) {
        // temp2 may have a wider type than either temp
        // or OType
        const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
        storer.separate()[i] = OType::to(temp2);
      } else {
        storer.separate()[i] = OType::to(temp);
      }
    }
    storer.store(tid, N);
  }
}

)code";
void BinaryScalarRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
                                        const OpContext& ctx,
                                        const std::vector<TBlob>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (req[0] == kNullOp)
    return;
  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);

  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
  const double alpha = param.scalar;

  const std::string code = std::string("const OpReqType req = ") + util::to_string(req[0]) +
                           ";\n" + "#define OP op::" + OP + "\n";
  const int nvec = common::mshadow_type_info(outputs[0].type_flag_).size == 8 ? 2 : 4;

  const index_t size = outputs[0].Size();
  binary_scalar_kernel_params params = {{inputs[0].dptr_, nullptr}, {outputs[0].dptr_}, alpha};

  VectorizedKernelRTCLauncher(code,
                              "binary_scalar_kernel",
                              binary_scalar_kernel_fwd,
                              nvec,
                              size,
                              1,
                              s,
                              params,
                              inputs,
                              outputs,
                              ctx.run_ctx.get_ctx().dev_id);
}
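
// Sparse (CSR/row-sparse) forward: copies the input's index geometry to the
// output and then applies the dense TBlob path to the stored values only.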
void BinaryScalarRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
                                        const OpContext& ctx,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  InitStorageGeometry<1, 1>(attrs, inputs, outputs);
  CHECK_NE(outputs[0].storage_type(), kDefaultStorage)
      << "This function works only for sparse types.";
  CHECK_EQ(inputs[0].storage_type(), outputs[0].storage_type())
      << "The storage type of both inputs and outputs needs to be the same.";
  AllocateGeometry(&outputs[0], req[0], &inputs[0]);
  CopyGeometryBlobs<gpu>(ctx.get_stream<gpu>(), &outputs[0], req[0], inputs[0]);
  outputs[0].CheckAndAllocData(inputs[0].storage_shape());
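  // An empty storage shape means the array holds no stored values, so there
  // is nothing to compute.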
  if (inputs[0].storage_shape().Size()) {
    std::vector<TBlob> in_blobs, out_blobs;
    in_blobs.reserve(inputs.size());
    out_blobs.reserve(outputs.size());
    for (auto& input : inputs) {
      in_blobs.emplace_back(input.data());
    }
    for (auto& output : outputs) {
      out_blobs.emplace_back(output.data());
    }
    this->operator()(attrs, ctx, in_blobs, req, out_blobs);
  }
}
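
// Source for the backward kernel. Here OP is expected to expand to the
// derivative of the forward operator with respect to its tensor input; the
// kernel multiplies it by the incoming gradient (chain rule).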
const char binary_scalar_kernel_bwd[] = R"code(

struct binary_scalar_kernel_params {
  const void *inputs[2];
  void *outputs[1];
  double scalar;
};

__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_scalar_kernel_bwd(const binary_scalar_kernel_params params,
                                         const index_t lead_dim,
                                         const index_t other_dim,
                                         const index_t N,
                                         const index_t num_aligned_elements) {
  using namespace vector;
  using type_util::mixed_type;
  VectorizedLoader<InputType0, nvec, aligned> ograd_loader(
      reinterpret_cast<const InputType0*>(params.inputs[0]), N);
  VectorizedLoader<InputType1, nvec, aligned> input_loader(
      reinterpret_cast<const InputType1*>(params.inputs[1]), N);
  VectorizedStorer<OutputType0, nvec, aligned> storer(
      reinterpret_cast<OutputType0*>(params.outputs[0]), N);

  using GType = AccType<InputType0>;
  using IType = AccType<InputType1>;
  using OType = AccType<OutputType0>;

  const index_t M = num_aligned_elements;
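
  // Grid-stride loop: igrad[i] (+)= ograd[i] * OP(input[i], scalar).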
  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    ograd_loader.load(tid, N);
    input_loader.load(tid, N);
    if (req == OpReqType::kAddTo) {
      storer.load(tid, N);
    }
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const auto ograd = GType::from(ograd_loader.separate()[i]);
      const auto input = IType::from(input_loader.separate()[i]);
      // the cast to mixed_type enables OP to return a different type
      const auto temp = op::mul(ograd,
                                OP(input,
                                   static_cast<mixed_type<typename IType::type,
                                                          typename OType::type>>(params.scalar)));
      if (req == OpReqType::kAddTo) {
        // temp2 may have a wider type than either temp
        // or OType
        const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
        storer.separate()[i] = OType::to(temp2);
      } else {
        storer.separate()[i] = OType::to(temp);
      }
    }
    storer.store(tid, N);
  }
}

)code";
void BinaryScalarRTCBackward::operator()(const nnvm::NodeAttrs& attrs,
                                         const OpContext& ctx,
                                         const std::vector<TBlob>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (req[0] == kNullOp)
    return;
  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
  const double alpha = param.scalar;

  const std::string code = std::string("const OpReqType req = ") + util::to_string(req[0]) +
                           ";\n"
                           "#define OP op::" +
                           OP + "\n";
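  // Same vector-width rule as the forward pass: 2-wide for float64,
  // 4-wide otherwise.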
  const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 2 : 4;

  const index_t size = outputs[0].Size();
  binary_scalar_kernel_params params = {
      {inputs[0].dptr_, inputs[1].dptr_}, {outputs[0].dptr_}, alpha};

  VectorizedKernelRTCLauncher(code,
                              "binary_scalar_kernel_bwd",
                              binary_scalar_kernel_bwd,
                              nvec,
                              size,
                              1,
                              s,
                              params,
                              inputs,
                              outputs,
                              ctx.run_ctx.get_ctx().dev_id);
}

#endif  // MXNET_USE_CUDA

}  // namespace op
}  // namespace mxnet