blob: e808e11215bfa390c2c4cfc95c8f96f07612410c [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file mxrtc.cc
* \brief Wrapper for NVRTC
* \author Junyuan Xie
*/
#include <mxnet/mxrtc.h>
#include <memory>
#include <string>
#include <vector>
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
namespace mxnet {
// Element type spelled into the generated kernel signature; decorate() uses it
// for every input and output pointer parameter.
const char MXRtc::str_type[] = "float";
// Maps decorated kernel source to a PTX buffer. The constructor looks a kernel
// up here (keyed by the full decorated source) before falling back to compile().
std::unordered_map<std::string, char*> MXRtc::kernel_registry;
/*!
 * \brief Build a runtime-compiled kernel from a user-supplied kernel body.
 *
 * The body is wrapped into a full kernel by decorate(); the resulting source is
 * used as the key into kernel_registry so identical kernels are compiled once.
 *
 * \param name   name of the generated kernel function
 * \param input  (parameter name, array) pairs emitted as const pointer parameters
 * \param output (parameter name, array) pairs emitted as mutable pointer parameters
 * \param kernel body of the kernel, inserted verbatim into the generated function
 */
MXRtc::MXRtc(const std::string& name,
             std::vector<std::pair<std::string, NDArray> > const& input,
             std::vector<std::pair<std::string, NDArray> > const& output,
             const std::string& kernel) {
  name_ = name;
  num_input_ = input.size();
  num_output_ = output.size();
  code_ = decorate(name, input, output, kernel);

  // Single lookup instead of find() followed by operator[].
  auto cached = MXRtc::kernel_registry.find(code_);
  if (cached != MXRtc::kernel_registry.end()) {
    ptx_ = cached->second;
  } else {
    ptx_ = compile(name, code_);
    // Record the freshly compiled PTX so later instances with the same source
    // reuse it; without this the registry stays empty and the lookup above can
    // never hit.
    // NOTE(review): the registry keeps this buffer for as long as the entry
    // exists, which matches how a cache hit already shares the pointer between
    // instances. Confirm that no destructor frees ptx_.
    // NOTE(review): the registry is accessed without synchronization here.
    MXRtc::kernel_registry[code_] = ptx_;
  }
}
/*!
 * \brief Schedule one launch of the compiled kernel through the engine.
 *
 * The kernel function is loaded from the PTX the first time a device id is
 * seen and remembered in func_ / module_ for later calls. The launch itself
 * happens inside the closure handed to Engine::PushSync.
 *
 * \param input        arrays bound to the kernel's input parameters; the count
 *                     must equal the number given at construction
 * \param output       arrays bound to the kernel's output parameters; the count
 *                     must equal the number given at construction and be non-zero
 * \param grid_dim_X   grid size along X
 * \param grid_dim_Y   grid size along Y
 * \param grid_dim_Z   grid size along Z
 * \param block_dim_X  block size along X
 * \param block_dim_Y  block size along Y
 * \param block_dim_Z  block size along Z
 */
void MXRtc::push(std::vector<NDArray> const& input,
                 std::vector<NDArray> const& output,
                 unsigned int grid_dim_X,
                 unsigned int grid_dim_Y,
                 unsigned int grid_dim_Z,
                 unsigned int block_dim_X,
                 unsigned int block_dim_Y,
                 unsigned int block_dim_Z) {
  CHECK_EQ(num_input_, input.size());
  CHECK_EQ(num_output_, output.size());
  CHECK(output.size());

  cudaError_enum err;
  CUfunction func;
  // The device is taken from the first output array.
  int dev_id = output[0].ctx().dev_id;

  auto loaded = func_.find(dev_id);
  if (loaded == func_.end()) {
    // First use on this device id: load the PTX and resolve the entry point.
    CUmodule module;
    CHECK_EQ(err = cuModuleLoadDataEx(&module, ptx_, 0, 0, 0), CUDA_SUCCESS)
        << "CudaError: " << err;
    CHECK_EQ(err = cuModuleGetFunction(&func, module, name_.c_str()), CUDA_SUCCESS)
        << "CudaError: " << err;
    module_[dev_id] = module;
    func_[dev_id] = func;
  } else {
    func = loaded->second;
  }

  // The arrays are captured by value so the closure holds its own handles.
  auto op = [this, func, input, output,
             grid_dim_X, grid_dim_Y, grid_dim_Z,
             block_dim_X, block_dim_Y, block_dim_Z](RunContext rctx) {
    // Data pointers in parameter order: every input, then every output.
    std::vector<float*> device_ptrs;
    for (const auto& arr : input) {
      device_ptrs.push_back(static_cast<float*>(arr.data().dptr_));
    }
    for (const auto& arr : output) {
      device_ptrs.push_back(static_cast<float*>(arr.data().dptr_));
    }

    // cuLaunchKernel expects an array of pointers to the argument values;
    // device_ptrs is complete at this point, so element addresses are stable.
    std::vector<void*> args;
    for (auto& ptr : device_ptrs) {
      args.push_back(&ptr);
    }

    cudaError_enum err;
    cudaError_t cuerr;
    CHECK_EQ(err = cuLaunchKernel(func,
                                  grid_dim_X, grid_dim_Y, grid_dim_Z,
                                  block_dim_X, block_dim_Y, block_dim_Z,
                                  0, rctx.get_stream<mshadow::gpu>()->stream_,
                                  args.data(), 0), CUDA_SUCCESS) << "CudaError: " << err;
    // Wait for the stream so the closure returns only after the kernel is done.
    CHECK_EQ(cuerr = cudaStreamSynchronize(rctx.get_stream<mshadow::gpu>()->stream_),
             cudaSuccess) << "CudaError: " << cuerr;
  };

  // Engine variables of the inputs and outputs, passed to PushSync as its two
  // variable lists.
  std::vector<Engine::VarHandle> var_in, var_out;
  for (const auto& arr : input) {
    var_in.push_back(arr.var());
  }
  for (const auto& arr : output) {
    var_out.push_back(arr.var());
  }

  Engine::Get()->PushSync(op, output[0].ctx(), var_in, var_out,
                          FnProperty::kNormal, 0, PROFILER_MESSAGE("MXRtc"));
}
/*!
 * \brief Wrap a kernel body into a complete extern "C" __global__ function.
 *
 * The generated source declares one pointer parameter per array (inputs are
 * const) using str_type as the element type, then for every array emits
 * `const int <name>_ndim = N;` and `const int <name>_dims[] = {d0,d1,...};`
 * ahead of the user-supplied body.
 *
 * \param name   name of the generated kernel function
 * \param input  (parameter name, array) pairs for the const pointer parameters
 * \param output (parameter name, array) pairs for the mutable pointer parameters
 * \param kernel kernel body, inserted verbatim
 * \return the complete kernel source
 */
std::string MXRtc::decorate(const std::string& name,
                            std::vector<std::pair<std::string, NDArray> > const& input,
                            std::vector<std::pair<std::string, NDArray> > const& output,
                            const std::string kernel) {
  typedef std::vector<std::pair<std::string, NDArray> > NamedArrays;

  std::string source;

  // Removes the separator left by a comma-joined list. Only a comma is ever
  // removed, so an empty list leaves the opening '(' or '{' intact instead of
  // deleting it.
  auto drop_trailing_comma = [&source]() {
    if (!source.empty() && source.back() == ',') source.pop_back();
  };

  // Emits the <name>_ndim and <name>_dims constants for each array.
  auto append_shape_constants = [&source, &drop_trailing_comma](const NamedArrays& arrays) {
    for (const auto& arg : arrays) {
      const auto ndim = arg.second.shape().ndim();
      source += "const int ";
      source += arg.first;
      source += "_ndim = ";
      source += std::to_string(ndim);
      source += ";\n";
      source += "const int ";
      source += arg.first;
      source += "_dims[] = {";
      for (index_t j = 0; j < ndim; ++j) {
        source += std::to_string(arg.second.shape()[j]);
        source += ",";
      }
      drop_trailing_comma();
      source += "};\n";
    }
  };

  // Function signature: inputs first (const), then outputs.
  source += "\nextern \"C\" __global__ void ";
  source += name;
  source += "(";
  for (const auto& arg : input) {
    source += "const ";
    source += str_type;
    source += "* ";
    source += arg.first;
    source += ",";
  }
  for (const auto& arg : output) {
    source += str_type;
    source += "* ";
    source += arg.first;
    source += ",";
  }
  drop_trailing_comma();
  source += ") {\n";

  // Shape constants, in the same order as the parameters.
  append_shape_constants(input);
  append_shape_constants(output);

  source += kernel;
  source += "\n}\n";
  return source;
}
char* MXRtc::compile(const std::string& name, const std::string& code) {
nvrtcProgram prog;
CHECK_EQ(nvrtcCreateProgram(&prog,
code.c_str(),
(name+".cu").c_str(),
0,
NULL,
NULL), NVRTC_SUCCESS);
nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, NULL);
size_t log_size;
CHECK_EQ(nvrtcGetProgramLogSize(prog, &log_size), NVRTC_SUCCESS);
char *log = new char[log_size];
CHECK_EQ(nvrtcGetProgramLog(prog, log), NVRTC_SUCCESS);
CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
CHECK_EQ(nvrtcGetPTXSize(prog, &ptx_size), NVRTC_SUCCESS);
char *ptx = new char[ptx_size];
CHECK_EQ(nvrtcGetPTX(prog, ptx), NVRTC_SUCCESS);
CHECK_EQ(nvrtcDestroyProgram(&prog), NVRTC_SUCCESS);
return ptx;
}
} // namespace mxnet
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))