| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #ifndef MXNET_RTC_H_ |
| #define MXNET_RTC_H_ |
| #include "./base.h" |
| #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC |
| #include <nvrtc.h> |
| #include <cuda.h> |
| |
| #include <vector> |
| #include <string> |
| #include <memory> |
| #include <utility> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include "./ndarray.h" |
| |
| namespace mxnet { |
| namespace rtc { |
| |
| /*! \brief Cuda runtime compile module. */ |
| class CudaModule { |
| private: |
| /*! \brief Structure for holding internal info. */ |
| struct Chunk { |
| /*! |
| * \brief Constructs cuda module. |
| * \param source cuda source code. |
| * \param exports export symbols before mangling. |
| */ |
| Chunk(const char* source, |
| const std::vector<std::string>& options, |
| const std::vector<std::string>& exports); |
| /*! \brief deconstrutor */ |
| ~Chunk(); |
| /*! |
| * \brief Get handle to cuda kernel from loaded module |
| * \param mangled_name mangled kernel name |
| * \param ctx context to run kernel on |
| * \return loaded function handle |
| */ |
| CUfunction GetFunction(const std::string& mangled_name, const Context& ctx); |
| /*! \brief nvrtc program handle. */ |
| nvrtcProgram prog_; |
| /*! \brief compiled cuda PTX */ |
| char* ptx_; |
| /*! \brief lazily loaded cuda module */ |
| std::unordered_map<int, CUmodule> mod_; |
| /*! \brief exported names */ |
| std::unordered_set<std::string> exports_; |
| }; |
| /*! \brief pointer to Chunk */ |
| std::shared_ptr<Chunk> ptr_; |
| |
| public: |
| /*! \brief cuda kernel argument descriptor */ |
| struct ArgType { |
| /*! \brief whether argument is NDArray */ |
| bool is_ndarray; |
| /*! \brief whether argument is constant (input) */ |
| bool is_const; |
| /*! \brief data type of argument */ |
| mshadow::TypeFlag dtype; |
| }; |
| /*! \brief Cuda kernel */ |
| class Kernel { |
| public: |
| /*! \brief Launch the kernel */ |
| void Launch(const Context& ctx, const std::vector<dmlc::any>& args, |
| uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z, |
| uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z, |
| uint32_t shared_mem); |
| /*! \brief kernel interface signature */ |
| const std::vector<ArgType>& signature() { return signature_; } |
| |
| private: |
| friend class CudaModule; |
| /*! |
| * \brief constructor |
| * \param mod module of this kernel |
| * \param mangled_name mangled kernel name |
| * \param signature kernel argument signature |
| */ |
| Kernel(const std::shared_ptr<Chunk>& mod, |
| const std::string& mangled_name, |
| const std::vector<ArgType>& signature); |
| /*! \brief mangled kernel name */ |
| std::string mangled_name_; |
| /*! \brief kernel argument signature */ |
| std::vector<ArgType> signature_; |
| /*! \brief module of this kernel */ |
| std::shared_ptr<Chunk> mod_; |
| /*! \brief cached kernel function on each device */ |
| std::unordered_map<int, CUfunction> func_; |
| }; |
| /*! |
| * \brief CudaModule constructor |
| * \param source cuda source code. |
| * \param exports export symbols before mangling. |
| */ |
| CudaModule(const char* source, |
| const std::vector<std::string>& options, |
| const std::vector<std::string>& exports) |
| : ptr_(std::make_shared<Chunk>(source, options, exports)) {} |
| /*! |
| * \brief Get cuda kernal from module by name |
| * \param name kernel name |
| * \param signature kernel signature |
| * \return shared pointer to cuda kernel |
| */ |
| std::shared_ptr<Kernel> GetKernel(const std::string& name, |
| const std::vector<ArgType>& signature); |
| }; |
| |
| } // namespace rtc |
| } // namespace mxnet |
| |
| #endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC |
| #endif // MXNET_RTC_H_ |