| /*! |
| * Copyright (c) 2015 by Contributors |
| * \file cuda_utils.h |
| * \brief CUDA debugging utilities. |
| */ |
| #ifndef MXNET_COMMON_CUDA_UTILS_H_ |
| #define MXNET_COMMON_CUDA_UTILS_H_ |
| |
| #include <dmlc/logging.h> |
| #include <mshadow/base.h> |
| |
/*!
 * \brief Stand-in definitions that help CLion parse CUDA files (*.cu, *.cuh).
 *
 * Everything in this section is compiled only when __JETBRAINS_IDE__ is
 * defined.
 */
#ifdef __JETBRAINS_IDE__
// Define __CUDACC__ for the IDE parser.
#define __CUDACC__ 1

// CUDA qualifiers expand to nothing.
#define __global__
#define __host__
#define __device__
#define __forceinline__
#define __shared__

// Empty placeholders for the barrier/fence intrinsics.
inline void __threadfence_block() {}
inline void __syncthreads() {}

// Placeholder only: hands its argument back unchanged.
template <typename T>
inline T __clz(const T val) {
  return val;
}

// Three-int aggregate used as the type of the index/dimension placeholders.
struct __cuda_fake_struct {
  int x;
  int y;
  int z;
};
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
extern __cuda_fake_struct blockDim;
#endif  // __JETBRAINS_IDE__
| |
| #if MXNET_USE_CUDA |
| |
| #include <cuda_runtime.h> |
| #include <cublas_v2.h> |
| #include <curand.h> |
| |
| namespace mxnet { |
| namespace common { |
| /*! \brief common utils for cuda */ |
| namespace cuda { |
| /*! |
| * \brief Get string representation of cuBLAS errors. |
| * \param error The error. |
| * \return String representation. |
| */ |
| inline const char* CublasGetErrorString(cublasStatus_t error) { |
| switch (error) { |
| case CUBLAS_STATUS_SUCCESS: |
| return "CUBLAS_STATUS_SUCCESS"; |
| case CUBLAS_STATUS_NOT_INITIALIZED: |
| return "CUBLAS_STATUS_NOT_INITIALIZED"; |
| case CUBLAS_STATUS_ALLOC_FAILED: |
| return "CUBLAS_STATUS_ALLOC_FAILED"; |
| case CUBLAS_STATUS_INVALID_VALUE: |
| return "CUBLAS_STATUS_INVALID_VALUE"; |
| case CUBLAS_STATUS_ARCH_MISMATCH: |
| return "CUBLAS_STATUS_ARCH_MISMATCH"; |
| case CUBLAS_STATUS_MAPPING_ERROR: |
| return "CUBLAS_STATUS_MAPPING_ERROR"; |
| case CUBLAS_STATUS_EXECUTION_FAILED: |
| return "CUBLAS_STATUS_EXECUTION_FAILED"; |
| case CUBLAS_STATUS_INTERNAL_ERROR: |
| return "CUBLAS_STATUS_INTERNAL_ERROR"; |
| case CUBLAS_STATUS_NOT_SUPPORTED: |
| return "CUBLAS_STATUS_NOT_SUPPORTED"; |
| default: |
| break; |
| } |
| return "Unknown cuBLAS status"; |
| } |
| |
| /*! |
| * \brief Get string representation of cuRAND errors. |
| * \param status The status. |
| * \return String representation. |
| */ |
| inline const char* CurandGetErrorString(curandStatus_t status) { |
| switch (status) { |
| case CURAND_STATUS_SUCCESS: |
| return "CURAND_STATUS_SUCCESS"; |
| case CURAND_STATUS_VERSION_MISMATCH: |
| return "CURAND_STATUS_VERSION_MISMATCH"; |
| case CURAND_STATUS_NOT_INITIALIZED: |
| return "CURAND_STATUS_NOT_INITIALIZED"; |
| case CURAND_STATUS_ALLOCATION_FAILED: |
| return "CURAND_STATUS_ALLOCATION_FAILED"; |
| case CURAND_STATUS_TYPE_ERROR: |
| return "CURAND_STATUS_TYPE_ERROR"; |
| case CURAND_STATUS_OUT_OF_RANGE: |
| return "CURAND_STATUS_OUT_OF_RANGE"; |
| case CURAND_STATUS_LENGTH_NOT_MULTIPLE: |
| return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; |
| case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: |
| return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; |
| case CURAND_STATUS_LAUNCH_FAILURE: |
| return "CURAND_STATUS_LAUNCH_FAILURE"; |
| case CURAND_STATUS_PREEXISTING_FAILURE: |
| return "CURAND_STATUS_PREEXISTING_FAILURE"; |
| case CURAND_STATUS_INITIALIZATION_FAILED: |
| return "CURAND_STATUS_INITIALIZATION_FAILED"; |
| case CURAND_STATUS_ARCH_MISMATCH: |
| return "CURAND_STATUS_ARCH_MISMATCH"; |
| case CURAND_STATUS_INTERNAL_ERROR: |
| return "CURAND_STATUS_INTERNAL_ERROR"; |
| } |
| return "Unknown cuRAND status"; |
| } |
| |
| } // namespace cuda |
| } // namespace common |
| } // namespace mxnet |
| |
/*!
 * \brief Check CUDA error.
 * \param msg Message to print if an error occurred.
 *
 * Reads the error via cudaGetLastError() and fails a CHECK_EQ against
 * cudaSuccess, streaming \p msg followed by the CUDA error string.
 *
 * The status lives in a local with a deliberately unusual name so that a
 * \p msg expression mentioning a caller variable called `e` keeps referring
 * to the caller's variable instead of being captured by the macro's local.
 */
#define CHECK_CUDA_ERROR(msg)                                    \
  {                                                              \
    cudaError_t mxnet_cuda_status_ = cudaGetLastError();         \
    CHECK_EQ(mxnet_cuda_status_, cudaSuccess)                    \
        << (msg) << " CUDA: "                                    \
        << cudaGetErrorString(mxnet_cuda_status_);               \
  }
| |
/*!
 * \brief Protected CUDA call.
 * \param func Expression to call.
 *
 * Evaluates \p func once and fails a CHECK unless the result is cudaSuccess
 * or cudaErrorCudartUnloading; the failure message carries the CUDA error
 * string.
 *
 * The result is stored in a local with a deliberately unusual name. With a
 * short name such as `e`, an argument like CUDA_CALL(cudaFoo(e)) would bind
 * `e` to the macro's own, not yet initialized, local instead of the caller's
 * variable.
 */
#define CUDA_CALL(func)                                          \
  {                                                              \
    cudaError_t mxnet_cuda_status_ = (func);                     \
    CHECK(mxnet_cuda_status_ == cudaSuccess ||                   \
          mxnet_cuda_status_ == cudaErrorCudartUnloading)        \
        << "CUDA: " << cudaGetErrorString(mxnet_cuda_status_);   \
  }
| |
/*!
 * \brief Protected cuBLAS call.
 * \param func Expression to call.
 *
 * Evaluates \p func once and fails a CHECK_EQ unless the result is
 * CUBLAS_STATUS_SUCCESS; the failure message carries the status name.
 *
 * The result is stored in a local with a deliberately unusual name so that
 * a caller variable named `e` used inside \p func is not captured by the
 * macro. The helper is referenced by its fully qualified name so the macro
 * also expands correctly outside namespace mxnet.
 */
#define CUBLAS_CALL(func)                                                 \
  {                                                                       \
    cublasStatus_t mxnet_cublas_status_ = (func);                         \
    CHECK_EQ(mxnet_cublas_status_, CUBLAS_STATUS_SUCCESS)                 \
        << "cuBLAS: "                                                     \
        << ::mxnet::common::cuda::CublasGetErrorString(                   \
               mxnet_cublas_status_);                                     \
  }
| |
/*!
 * \brief Protected cuRAND call.
 * \param func Expression to call.
 *
 * Evaluates \p func once and fails a CHECK_EQ unless the result is
 * CURAND_STATUS_SUCCESS; the failure message carries the status name.
 *
 * The result is stored in a local with a deliberately unusual name so that
 * a caller variable named `e` used inside \p func is not captured by the
 * macro. The helper is referenced by its fully qualified name so the macro
 * also expands correctly outside namespace mxnet.
 */
#define CURAND_CALL(func)                                                 \
  {                                                                       \
    curandStatus_t mxnet_curand_status_ = (func);                         \
    CHECK_EQ(mxnet_curand_status_, CURAND_STATUS_SUCCESS)                 \
        << "cuRAND: "                                                     \
        << ::mxnet::common::cuda::CurandGetErrorString(                   \
               mxnet_curand_status_);                                     \
  }
| |
| #endif // MXNET_USE_CUDA |
| |
| #if MXNET_USE_CUDNN |
| |
| #include <cudnn.h> |
| |
/*!
 * \brief Protected cuDNN call.
 * \param func Expression to call.
 *
 * Evaluates \p func once and fails a CHECK_EQ unless the result is
 * CUDNN_STATUS_SUCCESS; the failure message carries the text returned by
 * cudnnGetErrorString().
 *
 * The result is stored in a local with a deliberately unusual name so that
 * a caller variable named `e` used inside \p func is not captured by the
 * macro.
 */
#define CUDNN_CALL(func)                                                  \
  {                                                                       \
    cudnnStatus_t mxnet_cudnn_status_ = (func);                           \
    CHECK_EQ(mxnet_cudnn_status_, CUDNN_STATUS_SUCCESS)                   \
        << "cuDNN: " << cudnnGetErrorString(mxnet_cudnn_status_);         \
  }
| |
| #endif // MXNET_USE_CUDNN |
| |
// atomicAdd overload for doubles, compiled only for device code targeting
// __CUDA_ARCH__ below 600.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// From CUDA Programming Guide
/*!
 * \brief Atomically add \p val to the double stored at \p address.
 * \param address Location of the double to update.
 * \param val Value to add.
 *
 * Implemented as a compare-and-swap retry loop over the 64-bit integer view
 * of the double. Nothing is returned.
 */
static inline __device__ void atomicAdd(double *address, double val) {
  typedef unsigned long long bits_t;  // NOLINT(*)
  bits_t *const word = reinterpret_cast<bits_t *>(address);
  bits_t observed = *word;
  bits_t expected;

  do {
    expected = observed;
    const double sum = val + __longlong_as_double(expected);
    // atomicCAS returns the value it found; a mismatch means another thread
    // updated the word first and the addition has to be retried.
    observed = atomicCAS(word, expected, __double_as_longlong(sum));

    // Compare the raw integer bits, not the doubles: NaN != NaN would
    // otherwise keep the loop spinning forever.
  } while (expected != observed);
}
#endif
| |
// atomicAdd overload for mshadow half precision values.
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
#if defined(__CUDA_ARCH__)
/*!
 * \brief Atomically add \p val to the half stored at \p address.
 * \param address Location of the half to update.
 * \param val Value to add.
 *
 * The update is a compare-and-swap retry loop on the 32-bit word that
 * contains the half; only the 16 bits belonging to \p address are changed.
 * Nothing is returned.
 */
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // Bit 1 of the address selects which 16-bit half of the enclosing 32-bit
  // word is targeted: set -> bits 16..31, clear -> bits 0..15.
  const bool in_upper_bits = (reinterpret_cast<size_t>(address) & 2) != 0;

  // Step back two bytes when bit 1 is set to reach the start of the word.
  unsigned int *const word = reinterpret_cast<unsigned int *>(
      reinterpret_cast<char *>(address) - (in_upper_bits ? 2 : 0));
  unsigned int observed = *word;
  unsigned int expected;

  do {
    expected = observed;

    // Extract the current half, add in half precision.
    mshadow::half::half_t sum;
    sum.half_ = in_upper_bits ? (observed >> 16) : (observed & 0xffff);
    sum += val;

    // Splice the new 16 bits back in, leaving the neighbouring half intact.
    const unsigned int desired =
        in_upper_bits ? (observed & 0xffff) | (sum.half_ << 16)
                      : (observed & 0xffff0000) | sum.half_;

    // A returned value different from `expected` means the word changed
    // underneath us; retry with the freshly observed contents.
    observed = atomicCAS(word, expected, desired);
  } while (expected != observed);
}
#endif
| |
| #endif // MXNET_COMMON_CUDA_UTILS_H_ |