blob: 76c3064db71a6877f67c45d8247171c8acb400cc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_RTC_H_
#define MXNET_RTC_H_
#include "./base.h"
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
#include <nvrtc.h>
#include <cuda.h>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "./ndarray.h"
namespace mxnet {
namespace rtc {
/*!
 * \brief Cuda runtime compile module.
 *
 * CudaModule is a light handle: its only data member is a shared pointer to a
 * Chunk, which is created from the source text in the constructor. Kernel
 * objects handed out by GetKernel() are described by a mangled name plus an
 * argument signature.
 */
class CudaModule {
 private:
  /*! \brief Structure for holding internal info. */
  struct Chunk {
    /*!
     * \brief Constructs cuda module.
     * \param source cuda source code.
     * \param options compiler options (presumably forwarded to the runtime
     *        compiler; confirm against the definition).
     * \param exports export symbols before mangling.
     */
    Chunk(const char* source,
          const std::vector<std::string>& options,
          const std::vector<std::string>& exports);
    /*! \brief destructor */
    ~Chunk();
    /*!
     * \brief Chunk is neither copyable nor movable.
     *
     * It holds raw handles (prog_, ptx_, mod_) and declares a destructor, so an
     * implicit member-wise copy would leave two objects referring to the same
     * handles. Instances are meant to be shared through std::shared_ptr.
     */
    Chunk(const Chunk&) = delete;
    Chunk& operator=(const Chunk&) = delete;
    /*!
     * \brief Get handle to cuda kernel from loaded module
     * \param mangled_name mangled kernel name
     * \param ctx context to run kernel on
     * \return loaded function handle
     */
    CUfunction GetFunction(const std::string& mangled_name, const Context& ctx);
    /*! \brief nvrtc program handle. */
    nvrtcProgram prog_;
    /*! \brief compiled cuda PTX */
    char* ptx_;
    /*! \brief lazily loaded cuda module (int key is presumably the device id) */
    std::unordered_map<int, CUmodule> mod_;
    /*! \brief exported names */
    std::unordered_set<std::string> exports_;
  };
  /*! \brief pointer to Chunk */
  std::shared_ptr<Chunk> ptr_;

 public:
  /*!
   * \brief cuda kernel argument descriptor
   * \note Kept as a plain aggregate so that brace-initialization keeps working.
   */
  struct ArgType {
    /*! \brief whether argument is NDArray */
    bool is_ndarray;
    /*! \brief whether argument is constant (input) */
    bool is_const;
    /*! \brief data type of argument */
    mshadow::TypeFlag dtype;
  };
  /*! \brief Cuda kernel */
  class Kernel {
   public:
    /*!
     * \brief Launch the kernel
     * \param ctx context to launch the kernel on
     * \param args kernel arguments; expected to line up with signature()
     *        (assumption - the check, if any, lives in the definition)
     * \param grid_dim_x grid dimension along x
     * \param grid_dim_y grid dimension along y
     * \param grid_dim_z grid dimension along z
     * \param block_dim_x block dimension along x
     * \param block_dim_y block dimension along y
     * \param block_dim_z block dimension along z
     * \param shared_mem shared memory amount (presumably dynamic shared memory
     *        in bytes; confirm against the definition)
     */
    void Launch(const Context& ctx, const std::vector<dmlc::any>& args,
                uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
                uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
                uint32_t shared_mem);
    /*!
     * \brief kernel interface signature
     * \return read-only reference to the argument descriptors of this kernel
     */
    const std::vector<ArgType>& signature() const { return signature_; }

   private:
    // CudaModule needs access to the private constructor below.
    friend class CudaModule;
    /*!
     * \brief constructor
     * \param mod module of this kernel
     * \param mangled_name mangled kernel name
     * \param signature kernel argument signature
     */
    Kernel(const std::shared_ptr<Chunk>& mod,
           const std::string& mangled_name,
           const std::vector<ArgType>& signature);
    /*! \brief mangled kernel name */
    std::string mangled_name_;
    /*! \brief kernel argument signature */
    std::vector<ArgType> signature_;
    /*! \brief module of this kernel */
    std::shared_ptr<Chunk> mod_;
    /*! \brief cached kernel function on each device */
    std::unordered_map<int, CUfunction> func_;
  };
  /*!
   * \brief CudaModule constructor
   * \param source cuda source code.
   * \param options compiler options, passed through unchanged to Chunk.
   * \param exports export symbols before mangling.
   */
  CudaModule(const char* source,
             const std::vector<std::string>& options,
             const std::vector<std::string>& exports)
      : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
  /*!
   * \brief Get cuda kernel from module by name
   * \param name kernel name
   * \param signature kernel signature
   * \return shared pointer to cuda kernel
   */
  std::shared_ptr<Kernel> GetKernel(const std::string& name,
                                    const std::vector<ArgType>& signature);
};
} // namespace rtc
} // namespace mxnet
#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
#endif // MXNET_RTC_H_