| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file utils.h |
| * \brief Common CUDA utilities. |
| */ |
| #ifndef MXNET_COMMON_CUDA_UTILS_H_ |
| #define MXNET_COMMON_CUDA_UTILS_H_ |
| |
| #include <dmlc/logging.h> |
| #include <dmlc/parameter.h> |
| #include <dmlc/optional.h> |
| #include <mshadow/base.h> |
| #include <mxnet/libinfo.h> |
| |
/*! \brief Macros/inlines to assist CLion in parsing CUDA files (*.cu, *.cuh) */
| #ifdef __JETBRAINS_IDE__ |
| #define __CUDACC__ 1 |
| #define __host__ |
| #define __device__ |
| #define __global__ |
| #define __forceinline__ |
| #define __shared__ |
| inline void __syncthreads() {} |
| inline void __threadfence_block() {} |
| template <class T> |
| inline T __clz(const T val) { |
| return val; |
| } |
| struct __cuda_fake_struct { |
| int x; |
| int y; |
| int z; |
| }; |
| extern __cuda_fake_struct blockDim; |
| extern __cuda_fake_struct threadIdx; |
| extern __cuda_fake_struct blockIdx; |
| #endif |
| |
| #define QUOTE(x) #x |
| #define QUOTEVALUE(x) QUOTE(x) |
| |
| #if MXNET_USE_CUDA |
| |
| #include <cuda_runtime.h> |
| #include <cublas_v2.h> |
| #include <curand.h> |
| #if MXNET_USE_NVML |
| #include <nvml.h> |
| #endif // MXNET_USE_NVML |
| |
| #include <vector> |
| |
| #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \ |
| static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \ |
| QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \ |
| QUOTEVALUE(min_version) " or later.") |
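
// Illustrative use of STATIC_ASSERT_CUDA_VERSION_GE (a sketch, not part of the
// original header): guard a translation unit that relies on, e.g., CUDA 10.1
// features. CUDA_VERSION encodes 10.1 as 10010.
//
//   STATIC_ASSERT_CUDA_VERSION_GE(10010);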
| |
| /*! |
 * \brief When compiling a __device__ function, check that the architecture is >= Kepler (3.0).
 *        Note that __CUDA_ARCH__ is not defined outside of a __device__ function.
| */ |
| #ifdef __CUDACC__ |
| inline __device__ bool __is_supported_cuda_architecture() { |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300 |
| #error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)" |
| return false; |
| #else |
| return true; |
| #endif // __CUDA_ARCH__ < 300 |
| } |
| #endif // __CUDACC__ |
| |
| /*! |
| * \brief Check CUDA error. |
 * \param msg Message to print if an error occurred.
| */ |
| #define CHECK_CUDA_ERROR(msg) \ |
| { \ |
| cudaError_t e = cudaGetLastError(); \ |
| CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Protected CUDA call. |
| * \param func Expression to call. |
| * |
| * It checks for CUDA errors after invocation of the expression. |
| */ |
| #define CUDA_CALL(func) \ |
| { \ |
| cudaError_t e = (func); \ |
| CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Protected cuBLAS call. |
| * \param func Expression to call. |
| * |
| * It checks for cuBLAS errors after invocation of the expression. |
| */ |
| #define CUBLAS_CALL(func) \ |
| { \ |
| cublasStatus_t e = (func); \ |
| CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ |
| << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Protected cuSolver call. |
| * \param func Expression to call. |
| * |
| * It checks for cuSolver errors after invocation of the expression. |
| */ |
| #define CUSOLVER_CALL(func) \ |
| { \ |
| cusolverStatus_t e = (func); \ |
| CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ |
| << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Protected cuRAND call. |
| * \param func Expression to call. |
| * |
| * It checks for cuRAND errors after invocation of the expression. |
| */ |
| #define CURAND_CALL(func) \ |
| { \ |
| curandStatus_t e = (func); \ |
| CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ |
| << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Protected NVRTC call. |
 * \param x Expression to call.
| * |
| * It checks for NVRTC errors after invocation of the expression. |
| */ |
| #define NVRTC_CALL(x) \ |
| { \ |
| nvrtcResult result = x; \ |
| CHECK_EQ(result, NVRTC_SUCCESS) << #x " failed with error " << nvrtcGetErrorString(result); \ |
| } |
| |
| /*! |
| * \brief Protected CUDA driver call. |
| * \param func Expression to call. |
| * |
| * It checks for CUDA driver errors after invocation of the expression. |
| */ |
| #define CUDA_DRIVER_CALL(func) \ |
| { \ |
| CUresult e = (func); \ |
| if (e != CUDA_SUCCESS) { \ |
| char const* err_msg = nullptr; \ |
| if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ |
| LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ |
| } else { \ |
| LOG(FATAL) << "CUDA Driver: " << e << " " << err_msg; \ |
| } \ |
| } \ |
| } |
| |
| #if MXNET_USE_NVML |
| /*! |
| * \brief Protected NVML call. |
| * \param func Expression to call. |
| * |
| * It checks for NVML errors after invocation of the expression. |
| */ |
| #define NVML_CALL(func) \ |
| { \ |
| nvmlReturn_t result = (func); \ |
| CHECK_EQ(result, NVML_SUCCESS) << #func " failed with error " << nvmlErrorString(result); \ |
| } |
| #endif // MXNET_USE_NVML |
| |
| #if !defined(_MSC_VER) |
| #define CUDA_UNROLL _Pragma("unroll") |
| #define CUDA_NOUNROLL _Pragma("nounroll") |
| #else |
| #define CUDA_UNROLL |
| #define CUDA_NOUNROLL |
| #endif |
| |
| namespace mxnet { |
| namespace common { |
| /*! \brief common utils for cuda */ |
| namespace cuda { |
| /*! |
| * \brief Converts between C++ datatypes and enums/constants needed by cuBLAS. |
| */ |
| template <typename DType> |
| struct CublasType; |
| |
// With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
// datatype cublasDataType_t. The older cublasDataType_t values could be
// included below, but since this class was introduced to support the cuBLAS v8
// call cublasGemmEx(), burdening the class with the legacy type values
// was not needed.
| |
| template <> |
| struct CublasType<float> { |
| static const int kFlag = mshadow::kFloat32; |
| #if CUDA_VERSION >= 8000 |
| static const cudaDataType_t kCudaFlag = CUDA_R_32F; |
| #endif |
| typedef float ScaleType; |
| static const float one; |
| static const float zero; |
| }; |
| template <> |
| struct CublasType<double> { |
| static const int kFlag = mshadow::kFloat64; |
| #if CUDA_VERSION >= 8000 |
| static const cudaDataType_t kCudaFlag = CUDA_R_64F; |
| #endif |
| typedef double ScaleType; |
| static const double one; |
| static const double zero; |
| }; |
| template <> |
| struct CublasType<mshadow::half::half_t> { |
| static const int kFlag = mshadow::kFloat16; |
| #if CUDA_VERSION >= 8000 |
| static const cudaDataType_t kCudaFlag = CUDA_R_16F; |
| #endif |
| typedef float ScaleType; |
| static const mshadow::half::half_t one; |
| static const mshadow::half::half_t zero; |
| }; |
| template <> |
| struct CublasType<uint8_t> { |
| static const int kFlag = mshadow::kUint8; |
| #if CUDA_VERSION >= 8000 |
| static const cudaDataType_t kCudaFlag = CUDA_R_8I; |
| #endif |
| typedef uint8_t ScaleType; |
| static const uint8_t one = 1; |
| static const uint8_t zero = 0; |
| }; |
| template <> |
| struct CublasType<int32_t> { |
| static const int kFlag = mshadow::kInt32; |
| #if CUDA_VERSION >= 8000 |
| static const cudaDataType_t kCudaFlag = CUDA_R_32I; |
| #endif |
| typedef int32_t ScaleType; |
| static const int32_t one = 1; |
| static const int32_t zero = 0; |
| }; |
| |
| /*! |
| * \brief Get string representation of cuBLAS errors. |
| * \param error The error. |
| * \return String representation. |
| */ |
| inline const char* CublasGetErrorString(cublasStatus_t error) { |
| switch (error) { |
| case CUBLAS_STATUS_SUCCESS: |
| return "CUBLAS_STATUS_SUCCESS"; |
| case CUBLAS_STATUS_NOT_INITIALIZED: |
| return "CUBLAS_STATUS_NOT_INITIALIZED"; |
| case CUBLAS_STATUS_ALLOC_FAILED: |
| return "CUBLAS_STATUS_ALLOC_FAILED"; |
| case CUBLAS_STATUS_INVALID_VALUE: |
| return "CUBLAS_STATUS_INVALID_VALUE"; |
| case CUBLAS_STATUS_ARCH_MISMATCH: |
| return "CUBLAS_STATUS_ARCH_MISMATCH"; |
| case CUBLAS_STATUS_MAPPING_ERROR: |
| return "CUBLAS_STATUS_MAPPING_ERROR"; |
| case CUBLAS_STATUS_EXECUTION_FAILED: |
| return "CUBLAS_STATUS_EXECUTION_FAILED"; |
| case CUBLAS_STATUS_INTERNAL_ERROR: |
| return "CUBLAS_STATUS_INTERNAL_ERROR"; |
| case CUBLAS_STATUS_NOT_SUPPORTED: |
| return "CUBLAS_STATUS_NOT_SUPPORTED"; |
| default: |
| break; |
| } |
| return "Unknown cuBLAS status"; |
| } |
| |
| #if CUDA_VERSION >= 8000 |
| /*! |
| * \brief Create the proper constant for indicating cuBLAS transposition, if desired. |
| * \param transpose Whether transposition should be performed. |
| * \return the yes/no transposition-indicating constant. |
| */ |
| inline cublasOperation_t CublasTransposeOp(bool transpose) { |
| return transpose ? CUBLAS_OP_T : CUBLAS_OP_N; |
| } |
| #endif |
| |
| /*! |
| * \brief Get string representation of cuSOLVER errors. |
| * \param error The error. |
| * \return String representation. |
| */ |
| inline const char* CusolverGetErrorString(cusolverStatus_t error) { |
| switch (error) { |
| case CUSOLVER_STATUS_SUCCESS: |
| return "CUSOLVER_STATUS_SUCCESS"; |
| case CUSOLVER_STATUS_NOT_INITIALIZED: |
| return "CUSOLVER_STATUS_NOT_INITIALIZED"; |
| case CUSOLVER_STATUS_ALLOC_FAILED: |
| return "CUSOLVER_STATUS_ALLOC_FAILED"; |
| case CUSOLVER_STATUS_INVALID_VALUE: |
| return "CUSOLVER_STATUS_INVALID_VALUE"; |
| case CUSOLVER_STATUS_ARCH_MISMATCH: |
| return "CUSOLVER_STATUS_ARCH_MISMATCH"; |
| case CUSOLVER_STATUS_EXECUTION_FAILED: |
| return "CUSOLVER_STATUS_EXECUTION_FAILED"; |
| case CUSOLVER_STATUS_INTERNAL_ERROR: |
| return "CUSOLVER_STATUS_INTERNAL_ERROR"; |
| case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: |
| return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; |
| default: |
| break; |
| } |
| return "Unknown cuSOLVER status"; |
| } |
| |
| /*! |
| * \brief Get string representation of cuRAND errors. |
| * \param status The status. |
| * \return String representation. |
| */ |
| inline const char* CurandGetErrorString(curandStatus_t status) { |
| switch (status) { |
| case CURAND_STATUS_SUCCESS: |
| return "CURAND_STATUS_SUCCESS"; |
| case CURAND_STATUS_VERSION_MISMATCH: |
| return "CURAND_STATUS_VERSION_MISMATCH"; |
| case CURAND_STATUS_NOT_INITIALIZED: |
| return "CURAND_STATUS_NOT_INITIALIZED"; |
| case CURAND_STATUS_ALLOCATION_FAILED: |
| return "CURAND_STATUS_ALLOCATION_FAILED"; |
| case CURAND_STATUS_TYPE_ERROR: |
| return "CURAND_STATUS_TYPE_ERROR"; |
| case CURAND_STATUS_OUT_OF_RANGE: |
| return "CURAND_STATUS_OUT_OF_RANGE"; |
| case CURAND_STATUS_LENGTH_NOT_MULTIPLE: |
| return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; |
| case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: |
| return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; |
| case CURAND_STATUS_LAUNCH_FAILURE: |
| return "CURAND_STATUS_LAUNCH_FAILURE"; |
| case CURAND_STATUS_PREEXISTING_FAILURE: |
| return "CURAND_STATUS_PREEXISTING_FAILURE"; |
| case CURAND_STATUS_INITIALIZATION_FAILED: |
| return "CURAND_STATUS_INITIALIZATION_FAILED"; |
| case CURAND_STATUS_ARCH_MISMATCH: |
| return "CURAND_STATUS_ARCH_MISMATCH"; |
| case CURAND_STATUS_INTERNAL_ERROR: |
| return "CURAND_STATUS_INTERNAL_ERROR"; |
| } |
| return "Unknown cuRAND status"; |
| } |
| |
| template <typename DType> |
| inline DType __device__ CudaMax(DType a, DType b) { |
| return a > b ? a : b; |
| } |
| |
| template <typename DType> |
| inline DType __device__ CudaMin(DType a, DType b) { |
| return a < b ? a : b; |
| } |
| |
| class DeviceStore { |
| public: |
  /*! \brief Constructor; optionally restores the previously current device on destruction */
| explicit DeviceStore(int requested_device = -1, bool restore = true) |
| : restore_device_(-1), current_device_(requested_device), restore_(restore) { |
| if (restore_) |
| CUDA_CALL(cudaGetDevice(&restore_device_)); |
| if (requested_device != restore_device_) { |
| SetDevice(requested_device); |
| } |
| } |
| |
| ~DeviceStore() { |
| if (restore_ && current_device_ != restore_device_ && current_device_ != -1 && |
| restore_device_ != -1) |
| CUDA_CALL(cudaSetDevice(restore_device_)); |
| } |
| |
| void SetDevice(int device) { |
| if (device != -1) { |
| CUDA_CALL(cudaSetDevice(device)); |
| current_device_ = device; |
| } |
| } |
| |
| private: |
| int restore_device_; |
| int current_device_; |
| bool restore_; |
| }; |
| |
| /*! |
| * \brief Get the largest datatype suitable to read |
| * requested number of bytes. |
| * |
| * \input Number of bytes to be read |
| * \return mshadow representation of type that could |
| * be used for reading |
| */ |
| int get_load_type(size_t N); |
| |
| /*! |
 * \brief Determine how many rows of a 2D matrix a block of threads
 *        should handle, based on the row size and the number of
 *        threads in a block.
| * \param row_size Size of the row expressed in the number of reads required to fully |
| * load it. For example, if the row has N elements, but each thread |
| * reads 2 elements with a single read, row_size should be N / 2. |
| * \param num_threads_per_block Number of threads in a block. |
| * \return the number of rows that should be handled by a single block. |
| */ |
| int get_rows_per_block(size_t row_size, int num_threads_per_block); |
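
// Illustrative calls (a sketch with hypothetical values): a row of 128 float16
// elements read via 8-byte loads (4 elements per read) gives row_size = 128 / 4.
//
//   int load_type = get_load_type(128 * sizeof(mshadow::half::half_t));
//   int rows      = get_rows_per_block(128 / 4, 256 /* threads per block */);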
| |
| } // namespace cuda |
| } // namespace common |
| } // namespace mxnet |
| |
| /*! \brief Maximum number of GPUs */ |
| constexpr size_t kMaxNumGpus = 64; |
| |
| // The implementations below assume that accesses of 32-bit ints are inherently atomic and |
| // can be read/written by multiple threads without locks. The values held should be < 2^31. |
| |
| /*! |
 * \brief Return an attribute of GPU `device_id`.
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \param cached_values An array of attributes for already-looked-up GPUs. |
| * \param attr The attribute, by number. |
| * \param attr_name A string representation of the attribute, for error messages. |
| * \return the gpu's attribute value. |
| */ |
| inline int cudaAttributeLookup(int device_id, |
| std::vector<int32_t>* cached_values, |
| cudaDeviceAttr attr, |
| const char* attr_name) { |
| if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) { |
| LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id; |
| } else if ((*cached_values)[device_id] < 0) { |
| int temp = -1; |
| CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id)); |
| (*cached_values)[device_id] = static_cast<int32_t>(temp); |
| } |
| return (*cached_values)[device_id]; |
| } |
| |
| /*! |
| * \brief Determine major version number of the gpu's cuda compute architecture. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the major version number of the gpu's cuda compute architecture. |
| */ |
| inline int ComputeCapabilityMajor(int device_id) { |
| static std::vector<int32_t> capability_major(kMaxNumGpus, -1); |
| return cudaAttributeLookup( |
| device_id, &capability_major, cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor"); |
| } |
| |
| /*! |
| * \brief Determine minor version number of the gpu's cuda compute architecture. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the minor version number of the gpu's cuda compute architecture. |
| */ |
| inline int ComputeCapabilityMinor(int device_id) { |
| static std::vector<int32_t> capability_minor(kMaxNumGpus, -1); |
| return cudaAttributeLookup( |
| device_id, &capability_minor, cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor"); |
| } |
| |
| /*! |
| * \brief Return the integer SM architecture (e.g. Volta = 70). |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the gpu's cuda compute architecture as an int. |
| */ |
| inline int SMArch(int device_id) { |
| auto major = ComputeCapabilityMajor(device_id); |
| auto minor = ComputeCapabilityMinor(device_id); |
| return 10 * major + minor; |
| } |
| |
| /*! |
| * \brief Return the number of streaming multiprocessors of GPU `device_id`. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the gpu's count of streaming multiprocessors. |
| */ |
| inline int MultiprocessorCount(int device_id) { |
| static std::vector<int32_t> sm_counts(kMaxNumGpus, -1); |
| return cudaAttributeLookup( |
| device_id, &sm_counts, cudaDevAttrMultiProcessorCount, "MultiprocessorCount"); |
| } |
| |
| /*! |
| * \brief Return the shared memory size in bytes of each of the GPU's streaming multiprocessors. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the shared memory size per streaming multiprocessor. |
| */ |
| inline int MaxSharedMemoryPerMultiprocessor(int device_id) { |
  static std::vector<int32_t> max_smem_per_multiprocessor(kMaxNumGpus, -1);
  return cudaAttributeLookup(device_id,
                             &max_smem_per_multiprocessor,
| cudaDevAttrMaxSharedMemoryPerMultiprocessor, |
| "MaxSharedMemoryPerMultiprocessor"); |
| } |
| |
| /*! |
| * \brief Return whether the GPU `device_id` supports cooperative-group kernel launching. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return the gpu's ability to run cooperative-group kernels. |
| */ |
| inline bool SupportsCooperativeLaunch(int device_id) { |
| static std::vector<int32_t> coop_launch(kMaxNumGpus, -1); |
| return cudaAttributeLookup( |
| device_id, &coop_launch, cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch"); |
| } |
| |
| /*! |
| * \brief Determine whether a cuda-capable gpu's architecture supports float16 math. |
| * Assume not if device_id is negative. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return whether the gpu's architecture supports float16 math. |
| */ |
| inline bool SupportsFloat16Compute(int device_id) { |
| if (device_id < 0) { |
| return false; |
| } else { |
| // Kepler and most Maxwell GPUs do not support fp16 compute |
| int computeCapabilityMajor = ComputeCapabilityMajor(device_id); |
| return (computeCapabilityMajor > 5) || |
| (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3); |
| } |
| } |
| |
| /*! |
| * \brief Determine whether a cuda-capable gpu's architecture supports Tensor Core math. |
| * Assume not if device_id is negative. |
| * \param device_id The device index of the cuda-capable gpu of interest. |
| * \return whether the gpu's architecture supports Tensor Core math. |
| */ |
| inline bool SupportsTensorCore(int device_id) { |
| // Volta (sm_70) supports TensorCore algos |
| return device_id >= 0 && ComputeCapabilityMajor(device_id) >= 7; |
| } |
| |
| // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE |
| #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true |
| |
| /*! |
| * \brief Returns global policy for TensorCore algo use. |
| * \return whether to allow TensorCore algo (if not specified by the Operator locally). |
| */ |
| inline bool GetEnvAllowTensorCore() { |
| // Since these statics are in the '.h' file, they will exist and will be set |
| // separately in each compilation unit. Not ideal, but cleaner than creating a |
| // cuda_utils.cc solely to have a single instance and initialization. |
| static bool allow_tensor_core = false; |
| static bool is_set = false; |
| if (!is_set) { |
| // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal. |
| bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT; |
| allow_tensor_core = |
| dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", dmlc::optional<bool>(default_value)).value(); |
| is_set = true; |
| } |
| return allow_tensor_core; |
| } |
| |
// The policy if the user hasn't set the environment variable
// MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
| #define MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT false |
| |
| /*! |
| * \brief Returns global policy for TensorCore implicit type casting |
| */ |
| inline bool GetEnvAllowTensorCoreConversion() { |
| // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be |
| // legal. |
| bool default_value = MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT; |
| return dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION", |
| dmlc::optional<bool>(default_value)) |
| .value(); |
| } |
| |
| #if CUDA_VERSION >= 9000 |
| // Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous. |
| inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) { |
| auto handle_math_mode = CUBLAS_DEFAULT_MATH; |
| CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode)); |
| CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type)); |
| return handle_math_mode; |
| } |
| #endif |
| |
| #endif // MXNET_USE_CUDA |
| |
| #if MXNET_USE_CUDNN |
| |
| #include <cudnn.h> |
| |
// Creating CUDNN_VERSION_AS_STRING as follows avoids a static_assert error message that shows
// the formula for CUDNN_VERSION, i.e. "1000 * 7 + 100 * 6 + 0", rather than the number "7600".
| static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10, |
| "CUDNN_VERSION_AS_STRING macro assumptions violated."); |
| #if CUDNN_PATCHLEVEL >= 10 |
| #define CUDNN_VERSION_AS_STRING \ |
| QUOTEVALUE(CUDNN_MAJOR) \ |
| QUOTEVALUE(CUDNN_MINOR) \ |
| QUOTEVALUE(CUDNN_PATCHLEVEL) |
| #else |
| #define CUDNN_VERSION_AS_STRING \ |
| QUOTEVALUE(CUDNN_MAJOR) \ |
| QUOTEVALUE(CUDNN_MINOR) \ |
| "0" QUOTEVALUE(CUDNN_PATCHLEVEL) |
| #endif |
| |
| #define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \ |
| static_assert( \ |
| CUDNN_VERSION >= min_version, \ |
| "Compiled-against cuDNN version " CUDNN_VERSION_AS_STRING \ |
| " is too old, please upgrade system to version " QUOTEVALUE(min_version) " or later.") |
| |
| #define CUDNN_CALL_S(f, s) \ |
| { \ |
| cudnnStatus_t unclash_cxx_e = (f); \ |
| if (unclash_cxx_e != CUDNN_STATUS_SUCCESS) \ |
| LOG(s) << "cuDNN: " << cudnnGetErrorString(unclash_cxx_e); \ |
| } |
| |
| #define CUDNN_CALL(f) CUDNN_CALL_S(f, FATAL) |
| #define CUDNN_CALL_NONFATAL(f) CUDNN_CALL_S(f, WARNING) |
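
// Illustrative usage of the fatal and non-fatal variants (a sketch, not part of
// the original header):
//
//   cudnnHandle_t handle;
//   CUDNN_CALL(cudnnCreate(&handle));           // LOG(FATAL) on failure
//   CUDNN_CALL_NONFATAL(cudnnDestroy(handle));  // only LOG(WARNING) on failure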
| |
| #define CUTENSOR_CALL(func) \ |
| { \ |
| cutensorStatus_t e = (func); \ |
| CHECK_EQ(e, CUTENSOR_STATUS_SUCCESS) << "cuTensor: " << cutensorGetErrorString(e); \ |
| } |
| |
| /*! |
| * \brief Return max number of perf structs cudnnFindConvolutionForwardAlgorithm() |
| * may want to populate. |
| * \param cudnn_handle cudnn handle needed to perform the inquiry. |
| * \return max number of perf structs cudnnFindConvolutionForwardAlgorithm() may |
| * want to populate. |
| */ |
| inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) { |
| STATIC_ASSERT_CUDNN_VERSION_GE(7000); |
| int max_algos = 0; |
| CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos)); |
| return max_algos; |
| } |
| |
| /*! |
| * \brief Return max number of perf structs cudnnFindConvolutionBackwardFilterAlgorithm() |
| * may want to populate. |
| * \param cudnn_handle cudnn handle needed to perform the inquiry. |
| * \return max number of perf structs cudnnFindConvolutionBackwardFilterAlgorithm() may |
| * want to populate. |
| */ |
| inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) { |
| STATIC_ASSERT_CUDNN_VERSION_GE(7000); |
| int max_algos = 0; |
| CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos)); |
| return max_algos; |
| } |
| |
| /*! |
| * \brief Return max number of perf structs cudnnFindConvolutionBackwardDataAlgorithm() |
| * may want to populate. |
| * \param cudnn_handle cudnn handle needed to perform the inquiry. |
| * \return max number of perf structs cudnnFindConvolutionBackwardDataAlgorithm() may |
| * want to populate. |
| */ |
| inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) { |
| STATIC_ASSERT_CUDNN_VERSION_GE(7000); |
| int max_algos = 0; |
| CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos)); |
| return max_algos; |
| } |
| |
| #endif // MXNET_USE_CUDNN |
| |
// Overload atomicAdd to work for doubles on all architectures (native support starts at sm_60)
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 |
| // From CUDA Programming Guide |
| static inline __device__ void atomicAdd(double* address, double val) { |
| unsigned long long* address_as_ull = // NOLINT(*) |
| reinterpret_cast<unsigned long long*>(address); // NOLINT(*) |
| unsigned long long old = *address_as_ull; // NOLINT(*) |
| unsigned long long assumed; // NOLINT(*) |
| |
| do { |
| assumed = old; |
| old = atomicCAS( |
| address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); |
| |
| // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) |
| } while (assumed != old); |
| } |
| #endif |
| |
| // Overload atomicAdd for half precision |
| // Taken from: |
| // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh |
| #ifdef __CUDACC__ |
| static inline __device__ void atomicAdd(mshadow::half::half_t* address, mshadow::half::half_t val) { |
| unsigned int* address_as_ui = reinterpret_cast<unsigned int*>( |
| reinterpret_cast<char*>(address) - (reinterpret_cast<size_t>(address) & 2)); |
| unsigned int old = *address_as_ui; |
| unsigned int assumed; |
| |
| do { |
| assumed = old; |
| mshadow::half::half_t hsum; |
| hsum.half_ = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff); |
| hsum += val; |
| old = reinterpret_cast<size_t>(address) & 2 ? (old & 0xffff) | (hsum.half_ << 16) : |
| (old & 0xffff0000) | hsum.half_; |
| old = atomicCAS(address_as_ui, assumed, old); |
| } while (assumed != old); |
| } |
| |
| static inline __device__ void atomicAdd(uint8_t* address, uint8_t val) { |
| unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3)); |
| unsigned int old = *address_as_ui; |
| unsigned int shift = (((size_t)address & 0x3) << 3); |
| unsigned int sum; |
| unsigned int assumed; |
| |
| do { |
| assumed = old; |
| sum = val + static_cast<uint8_t>((old >> shift) & 0xff); |
| old = (old & ~(0x000000ff << shift)) | (sum << shift); |
| old = atomicCAS(address_as_ui, assumed, old); |
| } while (assumed != old); |
| } |
| |
| static inline __device__ void atomicAdd(int8_t* address, int8_t val) { |
| unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3)); |
| unsigned int old = *address_as_ui; |
| unsigned int shift = (((size_t)address & 0x3) << 3); |
| unsigned int sum; |
| unsigned int assumed; |
| |
| do { |
| assumed = old; |
| sum = val + static_cast<int8_t>((old >> shift) & 0xff); |
| old = (old & ~(0x000000ff << shift)) | (sum << shift); |
| old = atomicCAS(address_as_ui, assumed, old); |
| } while (assumed != old); |
| } |
| |
| // Overload atomicAdd to work for signed int64 on all architectures |
| static inline __device__ void atomicAdd(int64_t* address, int64_t val) { |
| atomicAdd(reinterpret_cast<unsigned long long*>(address), // NOLINT |
| static_cast<unsigned long long>(val)); // NOLINT |
| } |
| |
| template <typename DType> |
| __device__ inline DType ldg(const DType* address) { |
| #if __CUDA_ARCH__ >= 350 |
| return __ldg(address); |
| #else |
| return *address; |
| #endif |
| } |
| |
| namespace mxnet { |
| namespace common { |
| /*! \brief common utils for cuda */ |
| namespace cuda { |
| |
| static constexpr const int warp_size = 32; |
| |
| /*! \brief Reduction inside a warp. |
| * Template parameters: |
| * NVALUES - number of values to reduce (defaults to warp_size). |
 * \param value - value to be reduced (one per lane).
| * \param redfun - function used to perform reduction. |
| */ |
| template <int NVALUES = warp_size, typename OP, typename T> |
| __device__ inline T warp_reduce(T value, OP redfun) { |
| #pragma unroll |
| for (int i = warp_size / 2; i >= 1; i /= 2) { |
| if (NVALUES > i) |
| value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); |
| } |
| return value; |
| } |
| |
| template <typename OP, typename T> |
| __device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) { |
| for (int i = 1; i < group_size; i *= 2) { |
| value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); |
| } |
| return __shfl_sync(0xffffffff, value, 0, group_size); |
| } |
| |
| template <int NValues = warp_size, typename OP> |
| __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { |
| float v = static_cast<float>(value); |
| #pragma unroll |
| for (int i = warp_size / 2; i >= 1; i /= 2) { |
| if (NValues > i) |
| v = redfun(v, __shfl_down_sync(0xffffffff, v, i)); |
| } |
| return mshadow::half::half_t(v); |
| } |
| |
| /*! \brief Reduction inside a block, requires all threads in a block to participate. |
| * It uses a 2 step approach: |
| * - all warps in a block perform intermediate reduction |
| * - first warp reduces the intermediate results. |
| * Template parameters: |
| * NTHREADS - number of threads in a block. |
| * all_reduce - whether all threads need the result of the reduction. If set to |
| * true, then all threads return with the same value. If set to |
| * false, then only thread 0 has the valid result. Defaults to true. |
| * \param value - value from each thread to be reduced |
| * \param redfun - function used to perform reduction |
| */ |
| template <int NTHREADS, bool all_reduce = true, typename OP, typename T> |
| __device__ inline T reduce(const T& value, OP redfun) { |
| static_assert(NTHREADS <= warp_size * warp_size, "Number of threads too large for reduction"); |
| __shared__ T scratch[NTHREADS / warp_size]; |
| const int thread_idx_in_warp = threadIdx.x % warp_size; |
| const int warp_id = threadIdx.x / warp_size; |
| const T my_val = warp_reduce<warp_size>(value, redfun); |
| if (thread_idx_in_warp == 0) { |
| scratch[warp_id] = my_val; |
| } |
| __syncthreads(); |
| T ret = 0; |
| if (warp_id == 0) { |
| const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0; |
| const T my_val = warp_reduce<NTHREADS / warp_size>(prev_val, redfun); |
| if (all_reduce) { |
| scratch[threadIdx.x] = my_val; |
| } else { |
| ret = my_val; |
| } |
| } |
| // Necessary to synchronize in order to use this function again |
| // as the shared memory scratch space is reused between calls |
| __syncthreads(); |
| if (all_reduce) { |
| ret = scratch[0]; |
| __syncthreads(); |
| } |
| return ret; |
| } |
| |
| } // namespace cuda |
| } // namespace common |
| } // namespace mxnet |
| |
| #endif // __CUDACC__ |
| |
| #endif // MXNET_COMMON_CUDA_UTILS_H_ |