| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file codegen_cuda.cc |
| */ |
| |
| #include "codegen_cuda.h" |
| |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/index_map.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <cmath> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "../../tir/transforms/ir_utils.h" |
| #include "literal/cuda_half_t.h" |
| #include "ptx.h" |
| |
| namespace tvm { |
| namespace codegen { |
| |
| CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } |
| |
| void CodeGenCUDA::Init(bool output_ssa) { |
| CodeGenC::Init(output_ssa); |
| vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state); |
| vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect"); |
| ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); |
| } |
| |
| void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; } |
| |
| class ThreadIdxExtractor : public tir::StmtVisitor { |
| private: |
| void VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == tir::attr::thread_extent) { |
| IterVar iv = Downcast<IterVar>(op->node); |
| if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { |
| threadIdx_x_ext = op->value; |
| } |
| if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") { |
| threadIdx_y_ext = op->value; |
| } |
| if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") { |
| threadIdx_z_ext = op->value; |
| } |
| } |
| StmtVisitor::VisitStmt_(op); |
| } |
| |
| public: |
| PrimExpr threadIdx_x_ext = Integer(1); |
| PrimExpr threadIdx_y_ext = Integer(1); |
| PrimExpr threadIdx_z_ext = Integer(1); |
| }; |
| |
| void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { |
| ThreadIdxExtractor extractor; |
| extractor(f->body); |
| arith::Analyzer analyzer; |
| PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * |
| extractor.threadIdx_z_ext); |
| if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) { |
| if (threadIdx_ext_int->value == 1) { |
      // Unable to extract the number of threads per block; skip emitting launch bounds.
| return; |
| } |
| os << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; |
| } |
| } |
| |
| std::string CodeGenCUDA::Finish() { |
| if (enable_fp16_) { |
| decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; |
| decl_stream << "#include <cuda_fp16.h>\n"; |
| decl_stream << "__device__ half max" |
| << "(half a, half b)\n" |
| << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n"; |
| decl_stream << "__device__ half min(half a, half b)\n" |
| << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n"; |
| decl_stream << "#else\n"; |
| decl_stream << _cuda_half_t_def; |
| decl_stream << "#endif\n\n"; |
| decl_stream << _cuda_half_util; |
| } |
| |
| if (enable_bf16_) { |
| decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n"; |
| decl_stream << "#include <cuda_bf16.h>\n"; |
| decl_stream << "__device__ nv_bfloat16 max" |
| << "(nv_bfloat16 a, nv_bfloat16 b)\n" |
| << "{\n return __hgt(a, b) ? a : b;\n}\n"; |
| decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n" |
| << "{\n return __hlt(a, b) ? a : b;\n}\n"; |
| decl_stream << "#endif\n\n"; |
| decl_stream << _cuda_bfloat16_util; |
| } |
| |
| if (enable_fp8_) { |
| decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n"; |
| decl_stream << "#include <cuda_fp8.h>\n"; |
| decl_stream << "#endif\n\n"; |
| } |
| |
| if (enable_warp_shuffle_) { |
| decl_stream << _cuda_warp_intrinsic_util; |
| } |
| |
| if (enable_int8_) { |
| decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; |
| decl_stream << "#include <sm_61_intrinsics.h>\n"; |
| decl_stream << "#endif\n"; |
| } |
| |
| if (need_math_constants_h_) { |
| decl_stream << "#include <math_constants.h>\n"; |
| } |
| |
| if (need_mma_h_) { |
| decl_stream << "#include <mma.h>\n"; |
| } |
| |
| if (need_cast_smem_ptr_to_int_) { |
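    // Emit a helper converting a generic shared-memory pointer to the 32-bit
    // shared-space address operand expected by PTX instructions such as
    // ldmatrix and cp.async (which set this flag below).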
| decl_stream << "__forceinline__ __device__ unsigned int\n"; |
| decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n"; |
| decl_stream << "{\n"; |
| decl_stream << " unsigned int smem_int;\n"; |
| decl_stream << " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; " |
| "cvt.u32.u64 %0, smem_int; }\"\n"; |
| decl_stream << " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n"; |
| decl_stream << " return smem_int;\n"; |
| decl_stream << "}\n"; |
| } |
| |
| decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n"; |
| decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n"; |
| decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n"; |
| decl_stream << "#else\n"; |
| decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n"; |
| decl_stream << "#endif\n"; |
| |
| decl_stream << "\n#ifdef _WIN32\n"; |
| decl_stream << " using uint = unsigned int;\n"; |
| decl_stream << " using uchar = unsigned char;\n"; |
| decl_stream << " using ushort = unsigned short;\n"; |
| decl_stream << " using int64_t = long long;\n"; |
| decl_stream << " using uint64_t = unsigned long long;\n"; |
| decl_stream << "#else\n"; |
| decl_stream << " #define uint unsigned int\n"; |
| decl_stream << " #define uchar unsigned char\n"; |
| decl_stream << " #define ushort unsigned short\n"; |
| decl_stream << " #define int64_t long long\n"; |
| decl_stream << " #define uint64_t unsigned long long\n"; |
| decl_stream << "#endif\n"; |
| |
| return CodeGenC::Finish(); |
| } |
| |
| void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { |
| ICHECK(is_const_int(op->min, 0)); |
| if (op->kind == tir::ForKind::kUnrolled) { |
| PrintIndent(); |
| stream << "#pragma unroll\n"; |
| } |
| CodeGenC::VisitStmt_(op); |
| } |
| |
| void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { |
| ICHECK(!var_idmap_.count(iv->var.get())); |
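  // Map the loop variable straight to the CUDA built-in named by the thread
  // tag (e.g. "threadIdx.x"), inserting a cast when the variable's dtype is
  // not uint32.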
| var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); |
| } |
| |
| void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
| int lanes = t.lanes(); |
| if (t.is_handle()) { |
| ICHECK(t.is_scalar()) << "do not yet support vector types"; |
| os << "void*"; |
| return; |
| } |
| |
| if (t.is_void()) { |
| os << "void"; |
| return; |
| } |
| |
| bool fail = false; |
| if (t.is_float()) { |
| switch (t.bits()) { |
| case 16: |
| enable_fp16_ = true; |
| if (t.is_scalar()) { |
| os << "half"; |
| } else if (lanes <= 8) { |
| // Emit CUDA code to access fp16 vector elements. |
| // |
| // half4 is stored as uint2 |
| // |
| // h4.x is emitted as *(half2*)(&(u2.x)).x |
| // h4.y is emitted as *(half2*)(&(u2.x)).y |
| // h4.z is emitted as *(half2*)(&(u2.y)).x |
| // h4.w is emitted as *(half2*)(&(u2.y)).y |
| // |
| ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; |
| os << "uint" << lanes / 2; |
| } else { |
| fail = true; |
| } |
| break; |
| case 32: |
| if (lanes <= 4) { |
| os << "float"; |
| } else if (lanes <= 8) { |
| // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. |
| // |
| // float8 is stored as ulonglong4 |
| // |
| // f8.v1 is emitted as *(float2*)(&(ul4.x)).x |
| // f8.v2 is emitted as *(float2*)(&(ul4.x)).y |
| // |
| ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; |
| os << "ulonglong" << lanes / 2; |
| } else { |
| fail = true; |
| } |
| break; |
| case 64: |
| os << "double"; |
| break; |
| default: |
| fail = true; |
| break; |
| } |
| if (!fail && (t.is_scalar() || t.bits() == 16)) return; |
| if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; |
| if (!fail && (lanes >= 2 && lanes <= 4)) { |
| os << lanes; |
| return; |
| } |
| } else if (t.is_bfloat16()) { |
| enable_bf16_ = true; |
| if (t.is_scalar()) { |
| os << "nv_bfloat16"; |
| } else if (lanes <= 8) { |
| ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; |
| os << "uint" << lanes / 2; |
| } else { |
| fail = true; |
| } |
| if (!fail) return; |
| } else if (t.is_float8()) { |
| if (t.is_scalar()) { |
| os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char |
| } else if (lanes == 2) { |
| os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short |
| } else if (lanes == 4) { |
| os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int |
| } else { |
| fail = true; |
| } |
| if (!fail) return; |
| } else if (t == DataType::Bool()) { |
| os << "bool"; |
| return; |
| } else if (t.is_vector_bool()) { |
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent them instead.
| int n = t.lanes(); |
| if (n <= 4) { |
| os << "ushort" << n; |
| return; |
| } |
| } else if (t.is_uint() || t.is_int()) { |
| if (t.is_uint()) { |
| os << "u"; |
| } |
| switch (t.bits()) { |
| case 1: { |
| if (t.is_scalar()) { |
| os << "int"; |
| return; |
| } else if (t.lanes() == 8) { |
| os << "int8_t"; |
| return; |
| } else if (t.lanes() == 16) { |
| os << "int16_t"; |
| return; |
| } else if (t.lanes() == 32) { |
| os << "int"; |
| return; |
| } else { |
| LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; |
| } |
| } |
| case 4: { |
| if (t.is_scalar()) { |
| os << "int"; |
| return; |
| } else if (t.lanes() == 4) { |
| os << "int16_t"; |
| return; |
| } else if (t.lanes() == 8) { |
          // Pack eight 4-bit ints directly into one 32-bit integer.
| os << "int"; |
| return; |
| } else if (t.lanes() == 16) { |
| os << "int2"; |
| return; |
| } else if (t.lanes() == 32) { |
| os << "int4"; |
| return; |
| } else if (t.lanes() == 64) { |
| os << "int8"; |
| return; |
| } else { |
| LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; |
| } |
| } |
| case 8: { |
| if (t.lanes() == 4) { |
          // Pack four 8-bit ints directly into one 32-bit integer.
| enable_int8_ = true; |
| |
| // We use int for int8x4 instead of char4 because using char4 is |
| // likely to produce extra instructions to pack four int8 elements |
| // into 32-bit data. |
| os << "int"; |
| return; |
| } else if (t.lanes() == 8) { |
| enable_int8_ = true; |
| os << "int2"; |
| return; |
| } else if (t.lanes() == 16) { |
| enable_int8_ = true; |
| os << "int4"; |
| return; |
| } else if (!t.is_uint() && t.is_scalar()) { |
| os << "signed char"; |
| break; |
| } else { |
| os << "char"; |
| break; |
| } |
| } |
| case 16: { |
| if (t.is_scalar()) { |
| os << "short"; |
| } else if (t.lanes() <= 4) { |
| os << "short" << lanes; |
| } else if (t.lanes() <= 8) { |
| // Emit CUDA code to access int16 vector elements. |
| // |
| // short4 is stored as int2 |
| // |
| // s4.x is emitted as *(short2*)(&(i2.x)).x |
| // s4.y is emitted as *(short2*)(&(i2.x)).y |
| // s4.z is emitted as *(short2*)(&(i2.y)).x |
| // s4.w is emitted as *(short2*)(&(i2.y)).y |
| // |
| ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; |
| os << "int" << t.lanes() / 2; |
| } else { |
| fail = true; |
| } |
| if (!fail) { |
| return; |
| } |
| break; |
| } |
| case 32: { |
| if (t.is_scalar()) { |
| os << "int"; |
| } else if (t.lanes() <= 4) { |
| os << "int" << t.lanes(); |
| } else if (t.lanes() <= 8) { |
| // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. |
| // |
| // int8 is stored as longlong4 |
| // |
| // i8.v1 is emitted as *(int2*)(&(l4.x)).x |
| // i8.v2 is emitted as *(int2*)(&(l4.x)).y |
| // |
| ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4"; |
| os << "longlong" << lanes / 2; |
| } else { |
| fail = true; |
| } |
| if (!fail) { |
| return; |
| } |
| break; |
| } |
| case 64: { |
| if (t.is_scalar()) { |
| os << "int64_t"; |
| } else if (t.lanes() == 2) { |
| os << "longlong2"; |
| } else if (t.lanes() == 3) { |
| os << "longlong3"; |
| } else if (t.lanes() == 4) { |
| os << "longlong4"; |
| } |
| return; |
| } |
| default: |
| fail = true; |
| break; |
| } |
| if (!fail && lanes == 1) { |
| return; |
| } |
| if (!fail && (lanes >= 2 && lanes <= 4)) { |
| os << lanes; |
| return; |
| } |
| } |
| LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; |
| } |
| |
| void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { |
| os << "make_"; |
| PrintType(t, os); |
| } |
| |
| void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, |
| std::ostream& os) { // NOLINT(*) |
  // Declare the result.
| std::string sret = name_supply_->FreshName("_"); |
| this->PrintIndent(); |
| this->PrintType(t, stream); |
| stream << ' ' << sret << ";\n"; |
| int ssa_scope = BeginScope(); |
| { |
| // Unpack into individual ops. |
| std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); |
| std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); |
| |
| for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { |
| std::ostringstream value_temp; |
| if (isalpha(op[0])) { |
| value_temp << op << "("; |
| PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); |
| value_temp << ", "; |
| PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); |
| value_temp << ")"; |
| } else { |
| value_temp << "("; |
| PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); |
| value_temp << op; |
| PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); |
| value_temp << ")"; |
| } |
| PrintVecElemStore(sret, t, i, value_temp.str()); |
| } |
| } |
| EndScope(ssa_scope); |
| os << sret; |
| } |
| |
| void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, |
| std::ostream& os) { // NOLINT(*) |
| if (t.is_scalar()) { |
| os << vec; |
| return; |
| } |
| |
| static const char access[] = {'x', 'y', 'z', 'w'}; |
| ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); |
| if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
| std::string type_name = t.is_int() ? "char" : "unsigned char"; |
| if (t.lanes() == 2 || t.lanes() == 3) { |
| os << vec << "." << access[i % t.lanes()]; |
| } else { |
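      // Four lanes are packed into one 32-bit word; lane i is extracted by
      // shifting, e.g. lane 2 of an int8x4 value "v" prints as
      // ((char)(v >> 16)).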
| std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); |
| os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; |
| } |
| } else if (t.is_float16()) { |
| os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
| } else if (t.is_bfloat16()) { |
| os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
| } else if (t.lanes() > 4 && t.lanes() <= 8) { |
| std::string type_name; |
| if (t.bits() == 16) { |
| if (t.is_int()) { |
| type_name = "short"; |
| } else if (t.is_uint()) { |
| type_name = "ushort"; |
| } |
| } else if (t.bits() == 32) { |
| if (t.is_int()) { |
| type_name = "int"; |
| } else if (t.is_uint()) { |
| type_name = "uint"; |
| } else if (t.is_float()) { |
| type_name = "float"; |
| } |
| } |
| ICHECK(!type_name.empty()); |
| os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
| } else { |
| os << vec << "." << access[i]; |
| } |
| } |
| |
| void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, |
| const std::string& value) { |
| this->PrintIndent(); |
| static const char access[] = {'x', 'y', 'z', 'w'}; |
| ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); |
| if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
| if (t.lanes() == 2 || t.lanes() == 3) { |
| stream << vec << '.' << access[i % t.lanes()] << "=" |
| << "(" << value << ");\n"; |
| } else { |
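      // Read-modify-write a single byte of the packed word; e.g. storing
      // lane 1 of an int8x4 value "v" emits roughly:
      //   v = v & ~(0x000000ff << 8) | (value << 8);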
| std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); |
| stream << ac << "="; |
      // For lane 0, skip reading the old word since its contents are undefined.
| if (i != 0) { |
| stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; |
| } |
| stream << "(" << value << " << " << i % 4 * 8 << ");\n"; |
| } |
| } else if (t.is_float16()) { |
| stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " |
| << value << ";\n"; |
| } else if (t.is_bfloat16()) { |
| stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] |
| << " = " << value << ";\n"; |
| } else if (t.lanes() > 4 && t.lanes() <= 8) { |
| std::string type_name; |
| if (t.bits() == 16) { |
| if (t.is_int()) { |
| type_name = "short"; |
| } else if (t.is_uint()) { |
| type_name = "ushort"; |
| } |
| } else if (t.bits() == 32) { |
| if (t.is_int()) { |
| type_name = "int"; |
| } else if (t.is_uint()) { |
| type_name = "uint"; |
| } else if (t.is_float()) { |
| type_name = "float"; |
| } |
| } |
| ICHECK(!type_name.empty()); |
| stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" |
| << access[i % 2] << " = " << value << ";\n"; |
| } else { |
| stream << vec << "." << access[i] << " = " << value << ";\n"; |
| } |
| } |
| |
| void CodeGenCUDA::PrintStorageSync(const CallNode* op) { |
| const std::string& sync = op->args[0].as<StringImmNode>()->value; |
| if (sync == "warp") { |
    // Do nothing.
| } else if (sync == "shared" || sync == "shared.dyn") { |
| this->PrintIndent(); |
| this->stream << "__syncthreads();\n"; |
| } else if (sync == "global") { |
| if (!need_global_barrier_) { |
| need_global_barrier_ = true; |
| this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ |
| << ";\n"; |
| } |
| // global synchronizer |
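    // Emitted pattern: blocks whose predicate holds atomically increment the
    // counter and spin on a volatile read until the expected number of blocks
    // has arrived; all threads then re-synchronize via __syncthreads().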
| std::string is_load = PrintExpr(op->args[1]); |
| std::string num_blocks = PrintExpr(op->args[2]); |
| this->PrintIndent(); |
    // In theory only a threadfence is needed,
    // but we observed correctness problems with threadfence alone.
| this->stream << "__threadfence_system();\n"; |
| this->PrintIndent(); |
| this->stream << "if (" << is_load << ") {\n"; |
| int wb = this->BeginScope(); |
| this->PrintIndent(); |
| this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; |
| this->PrintIndent(); |
| std::string ptr = name_supply_->FreshName("pf"); |
| this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; |
| this->PrintIndent(); |
| this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; |
| this->PrintIndent(); |
| this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; |
| this->EndScope(wb); |
| this->PrintIndent(); |
| this->stream << "}\n"; |
| this->PrintIndent(); |
| this->stream << "__syncthreads();\n"; |
| } |
| } |
| |
| void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) |
| ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " |
| "all global arrays as input instead"; |
| if (scope == "shared") { |
| os << "__shared__ "; |
| } else if (scope == "shared.dyn") { |
| os << "extern __shared__ "; |
| } |
| } |
| |
| std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) { |
| if (from == target) return value; |
| std::ostringstream os; |
| os << "(("; |
| this->PrintType(target, os); |
| os << ")"; |
| if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { |
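    // Casting half directly to an 8-bit integer can be ambiguous in CUDA,
    // so go through (u)int first, e.g. ((signed char)(int)x).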
| os << "("; |
| if (target.is_uint()) { |
| os << "u"; |
| } |
| os << "int)"; |
| } |
| os << value << ")"; |
| return os.str(); |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { |
| DataType from_ty = op->value.dtype(); |
| DataType target_ty = op->dtype; |
| ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); |
| |
| // Emit simple C-style type conversion. |
| if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); |
| |
  // We could emit make_float4-like calls, but the emitted code looks
| // too compact to read. Emit this as vectorized unary ops. |
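  // e.g. casting a float4 value "v" to int4 emits roughly (illustrative):
  //   int4 _1;
  //   _1.x = (int)(v.x); ... _1.w = (int)(v.w);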
| std::string sret = name_supply_->FreshName("_"); |
| this->PrintIndent(); |
| this->PrintType(target_ty, stream); |
| stream << ' ' << sret << ";\n"; |
| { |
| std::string src = SSAGetID(PrintExpr(op->value), from_ty); |
| for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { |
| std::ostringstream val; |
| val << "("; |
| PrintType(target_ty.element_of(), val); |
| val << ")("; |
| PrintVecElemLoad(src, from_ty, i, val); |
| val << ")"; |
| PrintVecElemStore(sret, target_ty, i, val.str()); |
| } |
| } |
| os << sret; |
| } |
| |
| void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
| bool skip_first_arg, std::ostream& os) { // NOLINT(*) |
| DataType ret_dtype = GetRuntimeDataType(ret_type); |
| if (ret_dtype.is_vector()) { |
| // |
| // Emit an unsupported vector call |
| // |
| // v = intrin_f((float4*)A[0], (float4*)B[0]) |
| // |
| // as |
| // |
| // float4 __ret; |
| // { |
| // float4 __arg0 = ((float4*)A)[0]; |
| // float4 __arg1 = ((float4*)B)[0]; |
| // __ret.x = intrin_f(__arg0.x, __arg1.x); |
| // __ret.y = intrin_f(__arg0.y, __arg1.y); |
| // __ret.z = intrin_f(__arg0.z, __arg1.z); |
| // __ret.w = intrin_f(__arg0.w, __arg1.w); |
| // } |
| // v = __ret; |
| // |
| // Declare the result vector. |
| std::string sret = name_supply_->FreshName("_"); |
| this->PrintIndent(); |
| this->PrintType(ret_dtype, stream); |
| stream << ' ' << sret << ";\n"; |
| { |
| // Load arguments. |
| std::vector<std::string> sargs; |
| size_t arg_begin = static_cast<size_t>(skip_first_arg); |
| for (size_t i = arg_begin; i < args.size(); ++i) { |
| std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); |
| sargs.push_back(std::move(val)); |
| } |
| |
| // Emit a scalar call for each lane. |
| for (int i = 0; i < ret_dtype.lanes(); ++i) { |
| std::ostringstream scall; |
| scall << global_symbol << "("; |
| for (size_t j = 0; j < sargs.size(); ++j) { |
| if (j > 0) scall << ", "; |
| PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); |
| } |
| scall << ")"; |
| PrintVecElemStore(sret, ret_dtype, i, scall.str()); |
| } |
| } |
| os << sret; |
| } else { |
| CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); |
| } |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { |
| if (auto opt_call_opt = op->op.as<Op>()) { |
| Op call_op = opt_call_opt.value(); |
    // This is only for backward compatibility with __shfl_{up/down}.
    // A macro will be used to replace *_sync calls with the legacy ones.
| if (op_need_warp_shuffle_.get(call_op, false)) { |
| enable_warp_shuffle_ = true; |
| } |
| } |
| |
| if (op->op.same_as(builtin::tvm_fill_fragment())) { |
| need_mma_h_ = true; |
| ICHECK_EQ(op->args.size(), 6U); |
| os << "nvcuda::wmma::fill_fragment("; |
| this->PrintExpr(op->args[0], os); |
| os << "["; |
| this->PrintExpr(op->args[4], os); |
| os << "], "; |
| this->PrintExpr(op->args[5], os); |
| os << ")"; |
| } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { |
| need_mma_h_ = true; |
| ICHECK_EQ(op->args.size(), 8U); |
| os << "nvcuda::wmma::load_matrix_sync("; |
| this->PrintExpr(op->args[0], os); |
| os << "["; |
| this->PrintExpr(op->args[4], os); |
| os << "], "; |
| this->PrintExpr(op->args[5], os); |
| os << ", "; |
| this->PrintExpr(op->args[6], os); |
| os << ")"; |
| } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { |
| need_mma_h_ = true; |
| ICHECK_EQ(op->args.size(), 8U); |
| os << "nvcuda::wmma::store_matrix_sync("; |
| this->PrintExpr(op->args[5], os); |
| os << ", "; |
| this->PrintExpr(op->args[0], os); |
| os << "["; |
| this->PrintExpr(op->args[4], os); |
| os << "], "; |
| this->PrintExpr(op->args[6], os); |
| if (const StringImmNode* str = op->args[7].as<StringImmNode>()) { |
| os << ", nvcuda::wmma::mem_" << str->value; |
| } else { |
| LOG(FATAL) << "Invalid parameters"; |
| } |
| os << ")"; |
| } else if (op->op.same_as(builtin::tvm_mma_sync())) { |
| need_mma_h_ = true; |
| ICHECK_EQ(op->args.size(), 8U); |
| os << "nvcuda::wmma::mma_sync("; |
| for (int i = 0; i < 4; ++i) { |
| this->PrintExpr(op->args[i * 2], os); |
| os << "["; |
| this->PrintExpr(op->args[i * 2 + 1], os); |
| os << "]" << ((i < 3) ? ", " : ")"); |
| } |
| } else if (op->op.same_as(builtin::tvm_bmma_sync())) { |
| need_mma_h_ = true; |
| ICHECK_EQ(op->args.size(), 8U); |
| os << "nvcuda::wmma::bmma_sync("; |
| for (int i = 0; i < 4; ++i) { |
| this->PrintExpr(op->args[i * 2], os); |
| os << "["; |
| this->PrintExpr(op->args[i * 2 + 1], os); |
| os << "]" << ((i < 3) ? ", " : ")"); |
| } |
| } else if (op->op.same_as(builtin::ptx_mma())) { |
| // arg 0: shape: mXnXkX |
| // arg 1: A layout: row/col |
| // arg 2: B layout: row/col |
| // arg 3: A precision: fp16, fp64, ... |
| // arg 4: B precision: fp16, fp64, ... |
| // arg 5: C precision: fp32, fp64, ... |
| // arg 6: A multiplicand |
| // arg 7: A multiplicand index |
| // arg 8: B multiplicand |
| // arg 9: B multiplicand index |
| // arg 10: C accumulator |
| // arg 11: C accumulator index |
| // arg 12: saturate |
| // arg 13: (optional) 1-bit operator (xor or and) |
| ICHECK(op->args.size() == 13U || op->args.size() == 14U); |
| std::string shape = Downcast<StringImm>(op->args[0])->value; |
| std::string A_layout = Downcast<StringImm>(op->args[1])->value; |
| std::string B_layout = Downcast<StringImm>(op->args[2])->value; |
| std::string A_dtype = Downcast<StringImm>(op->args[3])->value; |
| std::string B_dtype = Downcast<StringImm>(op->args[4])->value; |
| std::string C_dtype = Downcast<StringImm>(op->args[5])->value; |
| std::string a_ref = this->PrintExpr(op->args[6]); |
| std::string a_bias = this->PrintExpr(op->args[7]); |
| std::string b_ref = this->PrintExpr(op->args[8]); |
| std::string b_bias = this->PrintExpr(op->args[9]); |
| std::string c_ref = this->PrintExpr(op->args[10]); |
| std::string c_bias = this->PrintExpr(op->args[11]); |
| bool saturate = Downcast<Bool>(op->args[12])->value; |
| std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : ""; |
| std::string asm_code = |
| PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, |
| b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); |
| |
| this->stream << asm_code; |
| } else if (op->op.same_as(builtin::ptx_mma_sp())) { |
| // arg 0: shape: mXnXkX |
| // arg 1: A layout: row/col |
| // arg 2: B layout: row/col |
| // arg 3: A precision: fp16, fp32, ... |
| // arg 4: B precision: fp16, fp32, ... |
| // arg 5: C precision: fp16, fp32, ... |
| // arg 6: A multiplicand pointer |
| // arg 7: A multiplicand index |
| // arg 8: B multiplicand pointer |
| // arg 9: B multiplicand index |
| // arg 10: C accumulator pointer |
| // arg 11: C accumulator index |
| // arg 12: metadata |
| // arg 13: metadata index |
| // arg 14: sparse_selector |
| // arg 15: saturate |
| ICHECK_EQ(op->args.size(), 16U); |
| std::string shape = Downcast<StringImm>(op->args[0])->value; |
| std::string A_layout = Downcast<StringImm>(op->args[1])->value; |
| std::string B_layout = Downcast<StringImm>(op->args[2])->value; |
| std::string A_dtype = Downcast<StringImm>(op->args[3])->value; |
| std::string B_dtype = Downcast<StringImm>(op->args[4])->value; |
| std::string C_dtype = Downcast<StringImm>(op->args[5])->value; |
| std::string a_ref = this->PrintExpr(op->args[6]); |
| std::string a_offset = this->PrintExpr(op->args[7]); |
| std::string b_ref = this->PrintExpr(op->args[8]); |
| std::string b_offset = this->PrintExpr(op->args[9]); |
| std::string c_ref = this->PrintExpr(op->args[10]); |
| std::string c_offset = this->PrintExpr(op->args[11]); |
| std::string metadata = this->PrintExpr(op->args[12]); |
| std::string metadata_offset = this->PrintExpr(op->args[13]); |
| std::string sparse_selector = this->PrintExpr(op->args[14]); |
| bool saturate = Downcast<Bool>(op->args[15])->value; |
| std::string asm_code = PrintMMAAssembly( |
| shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, |
| c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); |
| this->stream << asm_code; |
| } else if (op->op.same_as(builtin::ptx_ldmatrix())) { |
| // arg 0: whether the matrix is loaded in column major format or not. |
| // arg 1: number of matrices to load. |
| // arg 2: The data type in the matrix, .b16 is the only accepted data type. |
| // arg 3: pointer to local buffer. |
| // arg 4: The offset of the element to store in the local buffer. |
| // arg 5: pointer to the shared memory buffer to load. |
| // arg 6: The offset of the start element of the row to load in shared memory. |
| ICHECK_EQ(op->args.size(), 7U); |
| bool trans = Downcast<Bool>(op->args[0])->value; |
| int num = Downcast<Integer>(op->args[1])->value; |
| std::string type = Downcast<StringImm>(op->args[2])->value; |
| std::string local_ptr = this->PrintExpr(op->args[3]); |
| std::string local_elem_offset = this->PrintExpr(op->args[4]); |
| std::string smem_ptr = this->PrintExpr(op->args[5]); |
| if (trans && op->dtype.bits() == 8) { |
| // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an |
| // int8 matrix. |
| std::string smem_stride = this->PrintExpr(op->args[6]); |
| ICHECK(num == 4); |
| os << "for (int i = 0; i < 16; ++i) {\n"; |
| os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr |
| << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + |
| "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n"; |
| os << "}\n"; |
| } else { |
| std::string smem_elem_offset = this->PrintExpr(op->args[6]); |
| need_cast_smem_ptr_to_int_ = true; |
| this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, |
| smem_ptr, smem_elem_offset); |
| } |
| } else if (op->op.same_as(builtin::mma_store())) { |
| int m = Downcast<Integer>(op->args[0])->value; |
| int n = Downcast<Integer>(op->args[1])->value; |
| std::string dst = this->PrintExpr(op->args[2]); |
| std::string src = this->PrintExpr(op->args[3]); |
| std::string src_offset = this->PrintExpr(op->args[4]); |
| PrimExpr stride = op->args[5]; |
| |
| ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; |
| |
    // Each thread in a warp holds a certain number of elements of an MMA output.
    // For example, if we compute a 16x16 tile using MMA, each thread holds 8
    // elements in its registers. So conceptually, the warp's memory is organized
    // as a 32x8 block. A map from a 16x16 tile to a 32x8 block of memory is
    // specified by the index map below.

    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each of the 8
    // elements.
| |
| const auto* index_map_func = |
| runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); |
| ICHECK(index_map_func); |
| |
| arith::Analyzer analyzer; |
| auto inverse_index_map = |
| IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer); |
| auto indices_16x16 = inverse_index_map->final_indices; |
| |
| // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. |
| // FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them |
| // to the plain ones here. |
| class LowerFloorDivMod : public ExprMutator { |
| public: |
| PrimExpr VisitExpr_(const FloorDivNode* op) { |
| return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); |
| } |
| PrimExpr VisitExpr_(const FloorModNode* op) { |
| return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); |
| } |
| }; |
| |
| auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); |
| |
| var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; |
| var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; |
| |
| os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; |
| os << dst << "[" + this->PrintExpr(dst_ind) + "]" |
| << " = " << src << "[" << src_offset << " + local_id];\n"; |
| os << "}\n"; |
| |
| } else if (op->op.same_as(builtin::mma_fill())) { |
| std::string num_elem = this->PrintExpr(op->args[0]); |
| std::string dst = this->PrintExpr(op->args[1]); |
| std::string dst_offset = this->PrintExpr(op->args[2]); |
| |
| os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; |
| os << dst << "[" << dst_offset << " + i] = 0.0;"; |
| os << "}\n"; |
| } else if (op->op.same_as(builtin::ptx_cp_async())) { |
| std::string dst = this->PrintExpr(op->args[0]); |
| std::string dst_offset = this->PrintExpr(op->args[1]); |
| std::string src = this->PrintExpr(op->args[2]); |
| std::string src_offset = this->PrintExpr(op->args[3]); |
| std::string size = this->PrintExpr(op->args[4]); |
| need_cast_smem_ptr_to_int_ = true; |
| // use size of argument list to indicate whether or not to use predicated cp.async |
| if (op->args.size() == 5) { |
| this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); |
| } else { |
| this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, |
| this->PrintExpr(op->args[5])); |
| } |
| } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { |
| need_cast_smem_ptr_to_int_ = true; |
| std::string dst = this->PrintExpr(op->args[0]); |
| std::string dst_offset = this->PrintExpr(op->args[1]); |
| std::string src = this->PrintExpr(op->args[2]); |
| std::string src_offset = this->PrintExpr(op->args[3]); |
| std::string size = this->PrintExpr(op->args[4]); |
| int barrier_id = Downcast<IntImm>(op->args[5])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); |
| } else if (op->op.same_as(builtin::ptx_commit_group())) { |
| this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; |
| } else if (op->op.same_as(builtin::ptx_wait_group())) { |
| int n = Downcast<IntImm>(op->args[0])->value; |
| this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n"; |
| } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { |
| need_cast_smem_ptr_to_int_ = true; |
| int barrier_id = Downcast<IntImm>(op->args[0])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| this->stream << PrintCpAsyncBarrierAsm(barrier); |
| } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { |
| need_cast_smem_ptr_to_int_ = true; |
| int barrier_id = Downcast<IntImm>(op->args[0])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| std::string thread_count = this->PrintExpr(op->args[1]); |
| this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); |
| } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { |
| need_cast_smem_ptr_to_int_ = true; |
| int barrier_id = Downcast<IntImm>(op->args[0])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| this->stream << PrintArriveBarrierAsm(barrier); |
| } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { |
| need_cast_smem_ptr_to_int_ = true; |
| int barrier_id = Downcast<IntImm>(op->args[0])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| std::string byte_count = this->PrintExpr(op->args[1]); |
| this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count); |
| } else if (op->op.same_as(builtin::ptx_wait_barrier())) { |
| need_cast_smem_ptr_to_int_ = true; |
| int barrier_id = Downcast<IntImm>(op->args[0])->value; |
| CHECK(barrier_id < barrier_count_); |
| std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; |
| this->stream << PrintWaitBarrierAsm(barrier); |
| } else if (op->op.same_as(builtin::create_barriers())) { |
| CHECK_EQ(barrier_count_, -1); |
| int barrier_count = Downcast<IntImm>(op->args[0])->value; |
    // Pad the barrier count so the shared array satisfies the required alignment.
| CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0); |
| int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t); |
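    // e.g. with 16-byte alignment this is two uint64_t slots, so a requested
    // count of 3 barriers is padded up to 4 (illustrative).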
| if (barrier_count % barrier_alignment_count != 0) { |
| barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count; |
| } |
| barrier_count_ = barrier_count; |
| this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t " |
| << barrier_name_ << "[" << barrier_count << "];\n"; |
| this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_ |
| << "[i] = 0; }\n"; |
| } else if (op->op.same_as(builtin::ptx_ldg32())) { |
| /* |
| asm volatile ( |
| "{.reg .pred p;\n" |
| " setp.ne.b32 p, %2, 0;\n" |
| // " @p ld.global.nc.f32 %0, [%1];}\n"t |
| " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n" |
| : "=f"(reg) |
| : "l"(addr), "r"((int)guard) |
| ); |
| */ |
| |
    // Destination register (element of the local buffer).
    std::string reg = this->PrintExpr(op->args[0]);
    // Predicate guarding the load.
    std::string guard = this->PrintExpr(op->args[1]);
| const BufferLoadNode* addr_buffer = op->args[2].as<BufferLoadNode>(); |
| std::string global_addr = this->PrintExpr(addr_buffer->indices[0]); |
| std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data); |
| std::string local_addr = this->PrintExpr(op->args[3]); |
| this->stream << "asm volatile (\n"; |
| this->stream << "\"{.reg .pred p;\\n\"\n"; |
| this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n"; |
| this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; |
| this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n"; |
| // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; |
| stream << ": \"=f\"(" << reg << "[" << local_addr << "]" |
| << ")\n"; |
| stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" |
| << guard << ")\n"; |
| stream << ");\n"; |
| } else { |
| CodeGenC::VisitExpr_(op, os); |
| } |
| } |
| |
| void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { |
| if (op->attr_key == tir::attr::fragment_shape) { |
| const VarNode* buffer = op->node.as<VarNode>(); |
| const StringImmNode* shape_str = op->value.as<StringImmNode>(); |
| fragment_shapes[buffer] = shape_str->value; |
| } else if (op->attr_key == tir::attr::fragment_layout) { |
| const VarNode* buffer = op->node.as<VarNode>(); |
| const StringImmNode* layout_str = op->value.as<StringImmNode>(); |
| fragment_layouts[buffer] = layout_str->value; |
| } else if (op->attr_key == tir::attr::async_commit_queue_scope) { |
| const IntImmNode* queue_id = op->value.as<IntImmNode>(); |
| ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; |
| this->VisitStmt(op->body); |
| auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); |
| this->VisitExpr(commit_group, this->stream); |
| return; |
| } else if (op->attr_key == tir::attr::async_wait_queue_scope) { |
| auto wait_attrs = GetAsyncWaitAttributes(op); |
| auto queue_id = wait_attrs.first.as<IntImmNode>(); |
| ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; |
| auto wait_cnt = wait_attrs.second; |
| auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); |
| this->VisitExpr(wait_group, this->stream); |
| auto inner = op->body.as<AttrStmtNode>(); |
| ICHECK(inner); |
| this->VisitStmt(inner->body); |
| return; |
| } |
| CodeGenC::VisitStmt_(op); |
| } |
| |
| void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { |
| ICHECK(!is_zero(op->condition)); |
| std::string vid = AllocVarID(op->buffer_var.get()); |
| |
| this->PrintIndent(); |
| std::string scope = GetPtrStorageScope(op->buffer_var); |
| const VarNode* buffer = op->buffer_var.as<VarNode>(); |
| if (scope.find("wmma.") == 0) { |
| if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { |
| ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || |
| op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || |
| op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || |
| op->dtype == DataType::BFloat(16)) |
| << "Matrix_a and matrix_b only support half or char or unsigned char " |
| << "or uint4 or int4 or int1 type for now"; |
| } else { |
| ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || |
| op->dtype == DataType::Int(32)) |
| << "Accumulator only support half, float and int type for now"; |
| } |
| PrintWmmaScope(scope, op->dtype, buffer, stream); |
| } else { |
| PrintStorageScope(scope, stream); |
| PrintType(op->dtype, stream); |
| } |
| |
| if (scope == "shared.dyn") { |
| stream << ' ' << vid << "[];\n"; |
| } else { |
| size_t constant_size = op->ConstantAllocationSize(); |
| ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; |
| |
| if (scope.find("wmma.") == 0) { |
| constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); |
| } |
| if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || |
| op->dtype == DataType::Int(1)) && |
| scope == "shared") { |
| constant_size = constant_size / (32 / op->dtype.bits()); |
| } |
| stream << ' ' << vid << '[' << constant_size << "];\n"; |
| } |
| |
| RegisterHandleType(op->buffer_var.get(), op->dtype); |
| this->PrintStmt(op->body); |
| } |
| |
| void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { |
| if (is_const_int(op->value)) return; |
| const CallNode* call = op->value.as<CallNode>(); |
| if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { |
| PrintIndent(); |
| stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; |
| PrintIndent(); |
| stream << "if (threadIdx.x == 0) {\n"; |
| PrintIndent(); |
| stream << " " << vid_global_barrier_expect_ << " = 0;\n"; |
| PrintIndent(); |
| stream << "}\n"; |
| } else { |
| CodeGenC::VisitStmt_(op); |
| } |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { |
| CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; |
| PrintVecConstructor(op->dtype, os); |
| os << "("; |
| for (int i = 0; i < op->lanes; i++) { |
| os << "(" << PrintExpr(op->base) << ")" |
| << "+(" << PrintExpr(op->stride) << "*" << i << ")"; |
| if (i != op->lanes - 1) os << ", "; |
| } |
| os << ")"; |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
| if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { |
| // make_int8x4 |
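    // Replicate the byte across all four lanes of one 32-bit word,
    // e.g. broadcasting 3 prints (int)50529027, i.e. 0x03030303.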
| const int64_t* p = as_const_int(op->value); |
| ICHECK(p); |
| int64_t v = *p & 0xFF; |
| v = (v << 24) | (v << 16) | (v << 8) | v; |
| if (op->dtype.is_uint()) { |
| os << "(uint)" << v; |
| } else { |
| os << "(int)" << v; |
| } |
| return; |
| } |
| |
| if (op->dtype.is_float16()) { |
| std::string v = PrintExpr(op->value); |
| PrintVecConstructor(op->dtype, os); |
| os << '('; |
| for (int i = 0; i < op->lanes / 2; ++i) { |
| if (i != 0) os << ", "; |
| os << "__pack_half2(" << v << ", " << v << ")"; |
| } |
| os << ')'; |
| return; |
| } |
| |
| if (op->dtype.is_bfloat16()) { |
| std::string v = PrintExpr(op->value); |
| PrintVecConstructor(op->dtype, os); |
| os << '('; |
| for (int i = 0; i < op->lanes / 2; ++i) { |
| if (i != 0) os << ", "; |
| os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; |
| } |
| os << ')'; |
| return; |
| } |
| |
| if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { |
| bool fail = false; |
| const int64_t* p = as_const_int(op->value); |
| ICHECK(p); |
| int64_t v = *p & 0xF; |
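    // Pack the nibble into every 4-bit lane of a 16- or 32-bit word,
    // e.g. broadcasting 1 across 8 lanes yields 0x11111111
    // (emitted as (int)286331153).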
| |
| if (op->lanes == 4) { |
| v = (v << 12) | (v << 8) | (v << 4) | v; |
| if (op->dtype.is_uint()) { |
| os << "(uint16_t)" << v; |
| } else { |
| os << "(int16_t)" << v; |
| } |
| } else { |
| v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; |
| if (op->lanes == 8) { |
| if (op->dtype.is_uint()) { |
| os << "(uint)" << v; |
| } else { |
| os << "(int)" << v; |
| } |
| } else if (op->lanes == 16 || op->lanes == 32) { |
| PrintVecConstructor(op->dtype, os); |
| os << '('; |
| for (int i = 0; i < op->lanes / 8; ++i) { |
| if (i != 0) os << ", "; |
| if (op->dtype.is_uint()) { |
| os << "(uint)" << v; |
| } else { |
| os << "(int)" << v; |
| } |
| } |
| os << ')'; |
| } else { |
| fail = true; |
| } |
| } |
| |
| if (!fail) { |
| return; |
| } |
| } |
| |
| std::string v = PrintExpr(op->value); |
| PrintVecConstructor(op->dtype, os); |
| os << '('; |
| for (int i = 0; i < op->lanes; ++i) { |
| if (i != 0) os << ", "; |
| os << v; |
| } |
| os << ')'; |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { |
| // Non-vector cases. |
| if (!op->dtype.is_vector()) { |
| CodeGenC::VisitExpr_(op, os); |
| return; |
| } |
| |
| // Codegen vector condition case by serializing the select op. |
| ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && |
| op->dtype.lanes() == op->condition.dtype().lanes()); |
| |
| std::string r_var = name_supply_->FreshName("_"); |
| this->PrintIndent(); |
| this->PrintType(op->dtype, stream); |
| stream << ' ' << r_var << ";\n"; |
| { |
| std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); |
| std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); |
| std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); |
| |
| // The condition is stored as an ushort vector. |
| int lanes = op->dtype.lanes(); |
| DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes); |
| |
| for (int i = 0; i < lanes; ++i) { |
| std::ostringstream item; |
| item << "(bool("; |
| PrintVecElemLoad(c_var, memory_ty, i, item); |
| item << ")?"; |
| PrintVecElemLoad(t_var, op->dtype, i, item); |
| item << ':'; |
| PrintVecElemLoad(f_var, op->dtype, i, item); |
| item << ')'; |
| PrintVecElemStore(r_var, op->dtype, i, item.str()); |
| } |
| } |
| os << r_var; |
| } |
| |
| inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) |
| // Type code is kBFloat |
| if (op->dtype.is_bfloat16()) { |
| os << "__float2bfloat16_rn"; |
| os << '(' << std::scientific << op->value << 'f' << ')'; |
| return; |
| } |
| // Type code is kFloat |
| switch (op->dtype.bits()) { |
| case 64: |
| case 32: { |
| std::ostringstream temp; |
| if (std::isinf(op->value)) { |
| if (op->value < 0) { |
| temp << "-"; |
| } |
| temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); |
| p->need_math_constants_h_ = true; |
| } else if (std::isnan(op->value)) { |
| temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); |
| p->need_math_constants_h_ = true; |
| } else { |
| temp << std::scientific << op->value; |
| if (op->dtype.bits() == 32) temp << 'f'; |
| } |
| p->MarkConst(temp.str()); |
| os << temp.str(); |
| break; |
| } |
| case 16: { |
| os << "__float2half_rn" << '('; |
| FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); |
| PrintConst(const_f32.get(), os, p); |
| os << ')'; |
| break; |
| } |
| default: |
| LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; |
| } |
| } |
| |
| void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
| PrintConst(op, os, this); |
| } |
| |
| void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, |
| std::ostream& os) { |
| std::stringstream type; |
| PrintType(t, type); |
| ICHECK(fragment_shapes.count(variable)) |
| << "Cannot find shape of the wmma fragment " << variable->name_hint; |
| std::string shape_str = fragment_shapes.at(variable); |
| if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { |
| type.str(std::string()); |
| if (t.is_int()) { |
| if (t.bits() == 4) { |
| type << "nvcuda::wmma::experimental::precision::s4"; |
| } else if (t.bits() == 1) { |
| type << "nvcuda::wmma::experimental::precision::b1"; |
| } else { |
| LOG(FATAL) << "Unhandled interger type for wmma fragment!"; |
| } |
| } else if (t.is_uint()) { |
| if (t.bits() == 4) { |
| type << "nvcuda::wmma::experimental::precision::u4"; |
| } else { |
| LOG(FATAL) << "Unhandled interger type for wmma fragment!"; |
| } |
| } |
| } |
| if (scope == "wmma.matrix_a") { |
| need_mma_h_ = true; |
| std::string layout_str = fragment_layouts[variable]; |
| ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; |
| os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str() |
| << ", nvcuda::wmma::" << layout_str << ">"; |
| } else if (scope == "wmma.matrix_b") { |
| need_mma_h_ = true; |
| std::string layout_str = fragment_layouts[variable]; |
| ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; |
| os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str() |
| << ", nvcuda::wmma::" << layout_str << ">"; |
| } else if (scope == "wmma.accumulator") { |
| need_mma_h_ = true; |
| os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str() |
| << ">"; |
| } |
| } |
| |
| int stoi(const std::string& str) { |
| try { |
| return std::stoi(str); |
| } catch (std::invalid_argument& e) { |
| LOG(FATAL) << "Cannot convert \"" << str << "\" to int"; |
| throw; |
| } |
| } |
| |
| int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, |
| int32_t size) { |
| ICHECK(fragment_shapes.count(variable)) |
| << "Cannot find shape of the wmma fragment " << variable->name_hint; |
| std::string shape_str = fragment_shapes.at(variable); |
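  // e.g. a 256-element matrix_a buffer with fragment shape "16, 16, 16"
  // holds 256 / (16 * 16) = 1 fragment (illustrative).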
| std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope); |
| if (dim.first * dim.second != 0) |
| return size / dim.first / dim.second; |
| else |
| return 0; |
| } |
| |
| void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, |
| std::ostream& os) { |
| // Cast away volatile qualifier for fp16 types. That is, only loads and |
| // stores are volatile. The loaded objects are not marked as volatile. |
| // |
| if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { |
| os << "("; |
| PrintType(op->dtype, os); |
| os << ")(" << value << ")"; |
| } else { |
| os << value; |
| } |
| } |
| |
| void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, |
| std::ostream& os) { |
| ICHECK_GT(t.lanes(), 1); |
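  // Called once per lane; across the lanes the emitted pieces concatenate
  // into a single constructor expression, e.g. "make_float2(v0, v1)" for a
  // float32x2 (illustrative).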
| if (t.bits() == 8 && (t.is_int() || t.is_uint())) { |
| if (!(t.lanes() == 2 || t.lanes() == 3)) { |
| if (i != 0) { |
| os << "|"; |
| } |
| os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; |
| return; |
| } |
| } |
| |
| if (t.is_float16()) { |
| if (i == 0) { |
| PrintVecConstructor(t, os); |
| os << '('; |
| } |
| if (i % 2 == 0) { |
| os << "__pack_half2(" << value; |
| } else { |
| os << "," << value << ")"; |
| if (i != t.lanes() - 1) { |
| os << ","; |
| } else { |
| os << ")"; |
| } |
| } |
| return; |
| } |
| |
| if (t.is_bfloat16()) { |
| if (i == 0) { |
| PrintVecConstructor(t, os); |
| os << '('; |
| } |
| if (i % 2 == 0) { |
| os << "__pack_bfloat162(" << value; |
| } else { |
| os << "," << value << ")"; |
| if (i != t.lanes() - 1) { |
| os << ","; |
| } else { |
| os << ")"; |
| } |
| } |
| return; |
| } |
| |
| if (i == 0) { |
| PrintVecConstructor(t, os); |
| os << "("; |
| } |
| os << value; |
| if (i != t.lanes() - 1) { |
| os << ","; |
| } else { |
| os << ")"; |
| } |
| return; |
| } |
| |
| } // namespace codegen |
| } // namespace tvm |